"""
MCP Chat Endpoint - 添加到 mcp_server.py

集成LLM实现智能对话,自动调用MCP工具并总结结果
"""
||
# Standard library imports.
import json
import logging
import os
from typing import Any, Dict, List, Optional

# Third-party imports.
from openai import OpenAI
from pydantic import BaseModel

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
||
# ==================== LLM configuration ====================

# Supported LLM providers. All three expose OpenAI-compatible APIs, so a
# single OpenAI client works for each; only key/URL/model differ.
LLM_PROVIDERS = {
    "openai": {
        "api_key": os.getenv("OPENAI_API_KEY", ""),
        "base_url": "https://api.openai.com/v1",
        "model": "gpt-4o-mini",  # cheap and fast
    },
    "qwen": {
        "api_key": os.getenv("DASHSCOPE_API_KEY", ""),
        "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "model": "qwen-plus",
    },
    "deepseek": {
        "api_key": os.getenv("DEEPSEEK_API_KEY", ""),
        "base_url": "https://api.deepseek.com/v1",
        "model": "deepseek-chat",
    },
}

# Default LLM provider. Qwen (Tongyi Qianwen) is recommended for its low price.
DEFAULT_PROVIDER = "qwen"

# ==================== Data models ====================
||
class Message(BaseModel):
    """A single chat message."""

    # Role of the speaker: "system", "user" or "assistant".
    role: str
    content: str
class ChatRequest(BaseModel):
    """Incoming chat request payload."""

    message: str
    # Prior turns as raw dicts; pydantic deep-copies mutable defaults,
    # so the empty-list default is not shared between instances.
    conversation_history: List[Dict[str, str]] = []
    provider: Optional[str] = DEFAULT_PROVIDER
class ChatResponse(BaseModel):
    """Chat response returned to the client."""

    success: bool
    message: str
    # Name of the MCP tool that was invoked, when one was used.
    tool_used: Optional[str] = None
    # Raw tool output, returned alongside the natural-language summary.
    raw_data: Optional[Any] = None
    error: Optional[str] = None
# ==================== LLM assistant class ====================

class MCPChatAssistant:
    """MCP chat assistant - integrates an LLM with MCP tool calling.

    Flow: ask the LLM to emit a JSON tool-call instruction, execute the
    matching MCP tool handler, then ask the LLM again to summarize the
    tool output in natural language.  When no API key is configured for
    the chosen provider, it degrades to simple keyword matching
    (``fallback_chat``).
    """

    def __init__(self, provider: str = DEFAULT_PROVIDER):
        """Create an LLM client for *provider*, or enter fallback mode.

        Args:
            provider: Key into ``LLM_PROVIDERS`` ("openai"/"qwen"/"deepseek").
        """
        self.provider = provider
        # Explicitly None in fallback mode so the attribute always exists.
        self.model = None
        config = LLM_PROVIDERS.get(provider)

        if not config or not config["api_key"]:
            # Unknown provider or missing API key: chat() will route to
            # fallback_chat() instead of calling the LLM.
            logger.warning(f"LLM provider '{provider}' not configured, using fallback mode")
            self.client = None
        else:
            self.client = OpenAI(
                api_key=config["api_key"],
                base_url=config["base_url"],
            )
            self.model = config["model"]

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (``` or ```json), if any.

        LLMs frequently wrap JSON replies in a fenced code block even when
        asked for bare JSON; stripping it makes parsing robust.
        """
        stripped = text.strip()
        if stripped.startswith("```"):
            stripped = stripped[3:]
            if stripped.lower().startswith("json"):
                stripped = stripped[4:]
            if stripped.endswith("```"):
                stripped = stripped[:-3]
        return stripped.strip()

    def get_system_prompt(self, tools: List[dict]) -> str:
        """Build the system prompt listing the available MCP tools.

        Args:
            tools: Tool specs, each with ``name``, ``description`` and
                ``parameters`` keys.

        Returns:
            The full system prompt (in Chinese) instructing the LLM to
            answer with a JSON tool-call instruction.
        """
        tools_desc = "\n\n".join([
            f"**{tool['name']}**\n描述:{tool['description']}\n参数:{json.dumps(tool['parameters'], ensure_ascii=False, indent=2)}"
            for tool in tools
        ])

        return f"""你是一个专业的金融投资助手。你可以使用以下工具来帮助用户查询信息:

{tools_desc}

## 工作流程
1. **理解用户意图**:分析用户问题,确定需要什么信息
2. **选择工具**:从上面的工具中选择最合适的一个或多个
3. **提取参数**:从用户输入中提取工具需要的参数
4. **返回工具调用指令**(JSON格式):
{{"tool": "工具名", "arguments": {{...}}}}

## 重要规则
- 贵州茅台的股票代码是 **600519**
- 如果用户提到股票名称,尝试推断股票代码
- 如果不确定需要什么信息,使用 search_china_news 搜索相关新闻
- 涨停是指股票当日涨幅达到10%左右
- 只返回工具调用指令,不要额外解释

## 示例
用户:"查询贵州茅台的股票信息"
你:{{"tool": "get_stock_basic_info", "arguments": {{"seccode": "600519"}}}}

用户:"今日涨停的股票有哪些"
你:{{"tool": "search_limit_up_stocks", "arguments": {{"query": "", "mode": "hybrid", "page_size": 10}}}}

用户:"新能源概念板块表现如何"
你:{{"tool": "search_concepts", "arguments": {{"query": "新能源", "size": 10, "sort_by": "change_pct"}}}}
"""

    async def chat(self, user_message: str, conversation_history: List[Dict[str, str]], tools: List[dict]) -> ChatResponse:
        """Run one turn of intelligent conversation.

        Args:
            user_message: The user's latest message.
            conversation_history: Prior turns as dicts; ``isUser`` marks
                user messages, ``content`` holds the text.
            tools: Available MCP tool specs (fed into the system prompt).

        Returns:
            A ChatResponse carrying the LLM summary and, when a tool was
            invoked, its name and raw result.  Never raises; errors are
            reported through ``success=False`` / ``error``.
        """
        try:
            if not self.client:
                # No LLM configured: degrade to simple keyword matching.
                return await self.fallback_chat(user_message)

            # 1. Build the message history.
            messages = [
                {"role": "system", "content": self.get_system_prompt(tools)},
            ]

            # Keep at most the last 20 entries (~10 user/assistant rounds).
            for msg in conversation_history[-20:]:
                messages.append({
                    "role": "user" if msg.get("isUser") else "assistant",
                    "content": msg.get("content", ""),
                })

            messages.append({"role": "user", "content": user_message})

            # 2. Ask the LLM for a tool-call instruction.
            logger.info(f"Calling LLM with {len(messages)} messages")
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.3,  # low temperature for deterministic tool selection
                max_tokens=500,
            )

            tool_call_instruction = response.choices[0].message.content.strip()
            logger.info(f"LLM response: {tool_call_instruction}")

            # 3. Parse the tool-call instruction.
            try:
                # Strip a possible Markdown code fence before parsing.
                tool_call = json.loads(self._strip_code_fence(tool_call_instruction))
                tool_name = tool_call.get("tool")
                tool_args = tool_call.get("arguments", {})

                if not tool_name:
                    raise ValueError("No tool specified")

                # 4. Execute the tool via mcp_server's handler registry
                # (imported lazily to avoid a circular import).
                from mcp_server import TOOL_HANDLERS

                handler = TOOL_HANDLERS.get(tool_name)
                if not handler:
                    raise ValueError(f"Tool '{tool_name}' not found")

                tool_result = await handler(tool_args)

                # 5. Ask the LLM to summarize the tool output.
                summary_messages = messages + [
                    {"role": "assistant", "content": tool_call_instruction},
                    {"role": "system", "content": f"工具 {tool_name} 返回的数据:\n{json.dumps(tool_result, ensure_ascii=False, indent=2)}\n\n请用自然语言总结这些数据,给用户一个简洁清晰的回复(不超过200字)。"}
                ]

                summary_response = self.client.chat.completions.create(
                    model=self.model,
                    messages=summary_messages,
                    temperature=0.7,
                    max_tokens=300,
                )

                summary = summary_response.choices[0].message.content

                return ChatResponse(
                    success=True,
                    message=summary,
                    tool_used=tool_name,
                    raw_data=tool_result,
                )

            except json.JSONDecodeError:
                # The LLM replied in plain text rather than JSON; return it verbatim.
                return ChatResponse(
                    success=True,
                    message=tool_call_instruction,
                )
            except Exception as tool_error:
                logger.error(f"Tool execution error: {str(tool_error)}")
                return ChatResponse(
                    success=False,
                    message="工具调用失败",
                    error=str(tool_error),
                )

        except Exception as e:
            logger.error(f"Chat error: {str(e)}", exc_info=True)
            return ChatResponse(
                success=False,
                message="对话处理失败",
                error=str(e),
            )

    async def fallback_chat(self, user_message: str) -> ChatResponse:
        """Fallback path: route the query by simple keyword matching.

        Used when no LLM API key is configured.  Scans the user message
        for known Chinese keywords and dispatches to the matching MCP
        tool, defaulting to a news search.
        """
        from mcp_server import TOOL_HANDLERS

        try:
            # Special-case Kweichow Moutai (stock code 600519).
            if "茅台" in user_message or "贵州茅台" in user_message:
                handler = TOOL_HANDLERS.get("get_stock_basic_info")
                result = await handler({"seccode": "600519"})
                return ChatResponse(
                    success=True,
                    message="已为您查询贵州茅台(600519)的股票信息:",
                    tool_used="get_stock_basic_info",
                    raw_data=result,
                )

            # Limit-up ("涨停") stock analysis.
            elif "涨停" in user_message:
                handler = TOOL_HANDLERS.get("search_limit_up_stocks")
                query = user_message.replace("涨停", "").strip()
                result = await handler({"query": query, "mode": "hybrid", "page_size": 10})
                return ChatResponse(
                    success=True,
                    message="已为您查询涨停股票信息:",
                    tool_used="search_limit_up_stocks",
                    raw_data=result,
                )

            # Concept / sector ("概念"/"板块") queries.
            elif "概念" in user_message or "板块" in user_message:
                handler = TOOL_HANDLERS.get("search_concepts")
                query = user_message.replace("概念", "").replace("板块", "").strip()
                result = await handler({"query": query, "size": 10, "sort_by": "change_pct"})
                return ChatResponse(
                    success=True,
                    message=f"已为您查询'{query}'相关概念板块:",
                    tool_used="search_concepts",
                    raw_data=result,
                )

            # Default: search the news index with the raw message.
            else:
                handler = TOOL_HANDLERS.get("search_china_news")
                result = await handler({"query": user_message, "top_k": 5})
                return ChatResponse(
                    success=True,
                    message="已为您搜索相关新闻:",
                    tool_used="search_china_news",
                    raw_data=result,
                )

        except Exception as e:
            logger.error(f"Fallback chat error: {str(e)}")
            return ChatResponse(
                success=False,
                message="查询失败",
                error=str(e),
            )
# ==================== FastAPI endpoint ====================

# Add the following code to mcp_server.py:

"""
from mcp_chat_endpoint import MCPChatAssistant, ChatRequest, ChatResponse

# 创建聊天助手实例
chat_assistant = MCPChatAssistant(provider="qwen")  # 或 "openai", "deepseek"

@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    \"\"\"智能对话端点 - 使用LLM理解意图并调用工具\"\"\"
    logger.info(f"Chat request: {request.message}")

    # 获取可用工具列表
    tools = [tool.dict() for tool in TOOLS]

    # 调用聊天助手
    response = await chat_assistant.chat(
        user_message=request.message,
        conversation_history=request.conversation_history,
        tools=tools,
    )

    return response
"""