""" MCP Chat Endpoint - 添加到 mcp_server.py 集成LLM实现智能对话,自动调用MCP工具并总结结果 """ from pydantic import BaseModel from typing import List, Dict, Any, Optional import os import json from openai import OpenAI import logging logger = logging.getLogger(__name__) # ==================== LLM配置 ==================== # 支持多种LLM提供商 LLM_PROVIDERS = { "openai": { "api_key": os.getenv("OPENAI_API_KEY", ""), "base_url": "https://api.openai.com/v1", "model": "gpt-4o-mini", # 便宜且快速 }, "qwen": { "api_key": os.getenv("DASHSCOPE_API_KEY", ""), "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", "model": "qwen-plus", }, "deepseek": { "api_key": os.getenv("DEEPSEEK_API_KEY", ""), "base_url": "https://api.deepseek.com/v1", "model": "deepseek-chat", }, } # 默认使用的LLM提供商 DEFAULT_PROVIDER = "qwen" # 推荐使用通义千问,价格便宜 # ==================== 数据模型 ==================== class Message(BaseModel): """消息""" role: str # system, user, assistant content: str class ChatRequest(BaseModel): """聊天请求""" message: str conversation_history: List[Dict[str, str]] = [] provider: Optional[str] = DEFAULT_PROVIDER class ChatResponse(BaseModel): """聊天响应""" success: bool message: str tool_used: Optional[str] = None raw_data: Optional[Any] = None error: Optional[str] = None # ==================== LLM助手类 ==================== class MCPChatAssistant: """MCP聊天助手 - 集成LLM和工具调用""" def __init__(self, provider: str = DEFAULT_PROVIDER): self.provider = provider config = LLM_PROVIDERS.get(provider) if not config or not config["api_key"]: logger.warning(f"LLM provider '{provider}' not configured, using fallback mode") self.client = None else: self.client = OpenAI( api_key=config["api_key"], base_url=config["base_url"], ) self.model = config["model"] def get_system_prompt(self, tools: List[dict]) -> str: """构建系统提示词""" tools_desc = "\n\n".join([ f"**{tool['name']}**\n描述:{tool['description']}\n参数:{json.dumps(tool['parameters'], ensure_ascii=False, indent=2)}" for tool in tools ]) return f"""你是一个专业的金融投资助手。你可以使用以下工具来帮助用户查询信息: {tools_desc} ## 工作流程 1. **理解用户意图**:分析用户问题,确定需要什么信息 2. **选择工具**:从上面的工具中选择最合适的一个或多个 3. **提取参数**:从用户输入中提取工具需要的参数 4. **返回工具调用指令**(JSON格式): {{"tool": "工具名", "arguments": {{...}}}} ## 重要规则 - 贵州茅台的股票代码是 **600519** - 如果用户提到股票名称,尝试推断股票代码 - 如果不确定需要什么信息,使用 search_china_news 搜索相关新闻 - 涨停是指股票当日涨幅达到10%左右 - 只返回工具调用指令,不要额外解释 ## 示例 用户:"查询贵州茅台的股票信息" 你:{{"tool": "get_stock_basic_info", "arguments": {{"seccode": "600519"}}}} 用户:"今日涨停的股票有哪些" 你:{{"tool": "search_limit_up_stocks", "arguments": {{"query": "", "mode": "hybrid", "page_size": 10}}}} 用户:"新能源概念板块表现如何" 你:{{"tool": "search_concepts", "arguments": {{"query": "新能源", "size": 10, "sort_by": "change_pct"}}}} """ async def chat(self, user_message: str, conversation_history: List[Dict[str, str]], tools: List[dict]) -> ChatResponse: """智能对话""" try: if not self.client: # 降级到简单匹配 return await self.fallback_chat(user_message) # 1. 构建消息历史 messages = [ {"role": "system", "content": self.get_system_prompt(tools)}, ] # 添加历史对话(最多保留最近10轮) for msg in conversation_history[-20:]: messages.append({ "role": "user" if msg.get("isUser") else "assistant", "content": msg.get("content", ""), }) messages.append({"role": "user", "content": user_message}) # 2. 调用LLM获取工具调用指令 logger.info(f"Calling LLM with {len(messages)} messages") response = self.client.chat.completions.create( model=self.model, messages=messages, temperature=0.3, # 低温度,更确定性 max_tokens=500, ) tool_call_instruction = response.choices[0].message.content.strip() logger.info(f"LLM response: {tool_call_instruction}") # 3. 解析工具调用指令 try: tool_call = json.loads(tool_call_instruction) tool_name = tool_call.get("tool") tool_args = tool_call.get("arguments", {}) if not tool_name: raise ValueError("No tool specified") # 4. 调用工具(这里需要导入 mcp_server 的工具处理器) from mcp_server import TOOL_HANDLERS handler = TOOL_HANDLERS.get(tool_name) if not handler: raise ValueError(f"Tool '{tool_name}' not found") tool_result = await handler(tool_args) # 5. 让LLM总结结果 summary_messages = messages + [ {"role": "assistant", "content": tool_call_instruction}, {"role": "system", "content": f"工具 {tool_name} 返回的数据:\n{json.dumps(tool_result, ensure_ascii=False, indent=2)}\n\n请用自然语言总结这些数据,给用户一个简洁清晰的回复(不超过200字)。"} ] summary_response = self.client.chat.completions.create( model=self.model, messages=summary_messages, temperature=0.7, max_tokens=300, ) summary = summary_response.choices[0].message.content return ChatResponse( success=True, message=summary, tool_used=tool_name, raw_data=tool_result, ) except json.JSONDecodeError: # LLM没有返回JSON格式,直接返回其回复 return ChatResponse( success=True, message=tool_call_instruction, ) except Exception as tool_error: logger.error(f"Tool execution error: {str(tool_error)}") return ChatResponse( success=False, message="工具调用失败", error=str(tool_error), ) except Exception as e: logger.error(f"Chat error: {str(e)}", exc_info=True) return ChatResponse( success=False, message="对话处理失败", error=str(e), ) async def fallback_chat(self, user_message: str) -> ChatResponse: """降级方案:简单关键词匹配""" from mcp_server import TOOL_HANDLERS try: # 茅台特殊处理 if "茅台" in user_message or "贵州茅台" in user_message: handler = TOOL_HANDLERS.get("get_stock_basic_info") result = await handler({"seccode": "600519"}) return ChatResponse( success=True, message="已为您查询贵州茅台(600519)的股票信息:", tool_used="get_stock_basic_info", raw_data=result, ) # 涨停分析 elif "涨停" in user_message: handler = TOOL_HANDLERS.get("search_limit_up_stocks") query = user_message.replace("涨停", "").strip() result = await handler({"query": query, "mode": "hybrid", "page_size": 10}) return ChatResponse( success=True, message="已为您查询涨停股票信息:", tool_used="search_limit_up_stocks", raw_data=result, ) # 概念板块 elif "概念" in user_message or "板块" in user_message: handler = TOOL_HANDLERS.get("search_concepts") query = user_message.replace("概念", "").replace("板块", "").strip() result = await handler({"query": query, "size": 10, "sort_by": "change_pct"}) return ChatResponse( success=True, message=f"已为您查询'{query}'相关概念板块:", tool_used="search_concepts", raw_data=result, ) # 默认:搜索新闻 else: handler = TOOL_HANDLERS.get("search_china_news") result = await handler({"query": user_message, "top_k": 5}) return ChatResponse( success=True, message="已为您搜索相关新闻:", tool_used="search_china_news", raw_data=result, ) except Exception as e: logger.error(f"Fallback chat error: {str(e)}") return ChatResponse( success=False, message="查询失败", error=str(e), ) # ==================== FastAPI端点 ==================== # 在 mcp_server.py 中添加以下代码: """ from mcp_chat_endpoint import MCPChatAssistant, ChatRequest, ChatResponse # 创建聊天助手实例 chat_assistant = MCPChatAssistant(provider="qwen") # 或 "openai", "deepseek" @app.post("/chat", response_model=ChatResponse) async def chat_endpoint(request: ChatRequest): \"\"\"智能对话端点 - 使用LLM理解意图并调用工具\"\"\" logger.info(f"Chat request: {request.message}") # 获取可用工具列表 tools = [tool.dict() for tool in TOOLS] # 调用聊天助手 response = await chat_assistant.chat( user_message=request.message, conversation_history=request.conversation_history, tools=tools, ) return response """