471 lines
15 KiB
Python
471 lines
15 KiB
Python
"""
|
||
MCP Agent System - 基于 DeepResearch 逻辑的智能代理系统
|
||
三阶段流程:计划制定 → 工具执行 → 结果总结
|
||
"""
|
||
|
||
from pydantic import BaseModel
|
||
from typing import List, Dict, Any, Optional, Literal
|
||
from datetime import datetime
|
||
import json
|
||
import logging
|
||
from openai import OpenAI
|
||
import asyncio
|
||
import os
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
class ToolCall(BaseModel):
|
||
"""工具调用"""
|
||
tool: str
|
||
arguments: Dict[str, Any]
|
||
reason: str # 为什么要调用这个工具
|
||
|
||
class ExecutionPlan(BaseModel):
|
||
"""执行计划"""
|
||
goal: str # 用户的目标
|
||
steps: List[ToolCall] # 执行步骤
|
||
reasoning: str # 规划reasoning
|
||
|
||
class StepResult(BaseModel):
|
||
"""单步执行结果"""
|
||
step_index: int
|
||
tool: str
|
||
arguments: Dict[str, Any]
|
||
status: Literal["success", "failed", "skipped"]
|
||
result: Optional[Any] = None
|
||
error: Optional[str] = None
|
||
execution_time: float = 0
|
||
|
||
class AgentResponse(BaseModel):
|
||
"""Agent响应"""
|
||
success: bool
|
||
message: str # 自然语言总结
|
||
plan: Optional[ExecutionPlan] = None # 执行计划
|
||
step_results: List[StepResult] = [] # 每步的结果
|
||
final_summary: Optional[str] = None # 最终总结
|
||
metadata: Optional[Dict[str, Any]] = None
|
||
|
||
class ChatRequest(BaseModel):
|
||
"""聊天请求"""
|
||
message: str
|
||
conversation_history: List[Dict[str, str]] = []
|
||
stream: bool = False # 是否流式输出
|
||
|
||
# ==================== Agent 系统 ====================
|
||
|
||
class MCPAgent:
|
||
"""MCP 智能代理 - 三阶段执行"""
|
||
|
||
def __init__(self, provider: str = "qwen"):
|
||
self.provider = provider
|
||
|
||
# LLM 配置
|
||
config = {
|
||
"qwen": {
|
||
"api_key": os.getenv("DASHSCOPE_API_KEY", ""),
|
||
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||
"model": "qwen-plus",
|
||
},
|
||
"deepseek": {
|
||
"api_key": os.getenv("DEEPSEEK_API_KEY", ""),
|
||
"base_url": "https://api.deepseek.com/v1",
|
||
"model": "deepseek-chat",
|
||
},
|
||
"openai": {
|
||
"api_key": os.getenv("OPENAI_API_KEY", ""),
|
||
"base_url": "https://api.openai.com/v1",
|
||
"model": "gpt-4o-mini",
|
||
},
|
||
}.get(provider)
|
||
|
||
if not config or not config["api_key"]:
|
||
raise ValueError(f"Provider '{provider}' not configured. Please set API key.")
|
||
|
||
self.client = OpenAI(
|
||
api_key=config["api_key"],
|
||
base_url=config["base_url"],
|
||
)
|
||
self.model = config["model"]
|
||
|
||
# ==================== 阶段 1: 计划制定 ====================
|
||
|
||
def get_planning_prompt(self, tools: List[dict]) -> str:
|
||
"""获取计划制定的系统提示词"""
|
||
tools_desc = "\n\n".join([
|
||
f"**{tool['name']}**\n"
|
||
f"描述:{tool['description']}\n"
|
||
f"参数:{json.dumps(tool['parameters'], ensure_ascii=False, indent=2)}"
|
||
for tool in tools
|
||
])
|
||
|
||
return f"""你是一个专业的金融研究助手。你需要根据用户的问题,制定一个详细的执行计划。
|
||
|
||
## 可用工具
|
||
|
||
{tools_desc}
|
||
|
||
## 重要知识
|
||
- 贵州茅台股票代码: 600519
|
||
- 涨停: 股价单日涨幅约10%
|
||
- 概念板块: 相同题材的股票分类
|
||
|
||
## 特殊工具说明
|
||
- **summarize_with_llm**: 这是一个特殊工具,用于让你总结和分析收集到的数据
|
||
- 当需要对多个数据源进行综合分析时使用
|
||
- 当需要生成研究报告时使用
|
||
- 参数: {{"data": "要分析的数据", "task": "分析任务描述"}}
|
||
|
||
## 任务
|
||
分析用户问题,制定执行计划。返回 JSON 格式:
|
||
|
||
```json
|
||
{{
|
||
"goal": "用户的目标(一句话概括)",
|
||
"reasoning": "你的分析思路(为什么这样规划)",
|
||
"steps": [
|
||
{{
|
||
"tool": "工具名称",
|
||
"arguments": {{"参数名": "参数值"}},
|
||
"reason": "为什么要执行这一步"
|
||
}}
|
||
]
|
||
}}
|
||
```
|
||
|
||
## 规划原则
|
||
1. **从简到繁**: 先获取基础信息,再深入分析
|
||
2. **数据先行**: 先收集数据,再总结分析
|
||
3. **合理组合**: 可以调用多个工具,但不要超过5个
|
||
4. **包含总结**: 最后一步通常是 summarize_with_llm
|
||
|
||
## 示例
|
||
|
||
用户:"帮我全面分析一下贵州茅台这只股票"
|
||
|
||
你的计划:
|
||
```json
|
||
{{
|
||
"goal": "全面分析贵州茅台股票",
|
||
"reasoning": "需要获取基本信息、财务指标、交易数据,然后综合分析",
|
||
"steps": [
|
||
{{
|
||
"tool": "get_stock_basic_info",
|
||
"arguments": {{"seccode": "600519"}},
|
||
"reason": "获取股票基本信息(公司名称、行业、市值等)"
|
||
}},
|
||
{{
|
||
"tool": "get_stock_financial_index",
|
||
"arguments": {{"seccode": "600519", "limit": 5}},
|
||
"reason": "获取最近5期财务指标(营收、利润、ROE等)"
|
||
}},
|
||
{{
|
||
"tool": "get_stock_trade_data",
|
||
"arguments": {{"seccode": "600519", "limit": 30}},
|
||
"reason": "获取最近30天交易数据(价格走势、成交量)"
|
||
}},
|
||
{{
|
||
"tool": "search_china_news",
|
||
"arguments": {{"query": "贵州茅台", "top_k": 5}},
|
||
"reason": "获取最新新闻,了解市场动态"
|
||
}},
|
||
{{
|
||
"tool": "summarize_with_llm",
|
||
"arguments": {{
|
||
"data": "前面收集的所有数据",
|
||
"task": "综合分析贵州茅台的投资价值,包括基本面、财务状况、股价走势、市场情绪"
|
||
}},
|
||
"reason": "综合所有数据,生成投资分析报告"
|
||
}}
|
||
]
|
||
}}
|
||
```
|
||
|
||
只返回JSON,不要额外解释。"""
|
||
|
||
async def create_plan(self, user_query: str, tools: List[dict]) -> ExecutionPlan:
|
||
"""阶段1: 创建执行计划"""
|
||
logger.info(f"[Planning] Creating plan for: {user_query}")
|
||
|
||
messages = [
|
||
{"role": "system", "content": self.get_planning_prompt(tools)},
|
||
{"role": "user", "content": user_query},
|
||
]
|
||
|
||
response = self.client.chat.completions.create(
|
||
model=self.model,
|
||
messages=messages,
|
||
temperature=0.3,
|
||
max_tokens=1500,
|
||
)
|
||
|
||
plan_json = response.choices[0].message.content.strip()
|
||
logger.info(f"[Planning] Raw response: {plan_json}")
|
||
|
||
# 清理可能的代码块标记
|
||
if "```json" in plan_json:
|
||
plan_json = plan_json.split("```json")[1].split("```")[0].strip()
|
||
elif "```" in plan_json:
|
||
plan_json = plan_json.split("```")[1].split("```")[0].strip()
|
||
|
||
plan_data = json.loads(plan_json)
|
||
|
||
plan = ExecutionPlan(
|
||
goal=plan_data["goal"],
|
||
reasoning=plan_data.get("reasoning", ""),
|
||
steps=[
|
||
ToolCall(**step) for step in plan_data["steps"]
|
||
],
|
||
)
|
||
|
||
logger.info(f"[Planning] Plan created: {len(plan.steps)} steps")
|
||
return plan
|
||
|
||
# ==================== 阶段 2: 工具执行 ====================
|
||
|
||
async def execute_tool(
|
||
self,
|
||
tool_name: str,
|
||
arguments: Dict[str, Any],
|
||
tool_handlers: Dict[str, Any],
|
||
) -> Dict[str, Any]:
|
||
"""执行单个工具"""
|
||
|
||
# 特殊处理:summarize_with_llm
|
||
if tool_name == "summarize_with_llm":
|
||
return await self.summarize_with_llm(
|
||
data=arguments.get("data", ""),
|
||
task=arguments.get("task", "总结数据"),
|
||
)
|
||
|
||
# 调用 MCP 工具
|
||
handler = tool_handlers.get(tool_name)
|
||
if not handler:
|
||
raise ValueError(f"Tool '{tool_name}' not found")
|
||
|
||
result = await handler(arguments)
|
||
return result
|
||
|
||
async def execute_plan(
|
||
self,
|
||
plan: ExecutionPlan,
|
||
tool_handlers: Dict[str, Any],
|
||
) -> List[StepResult]:
|
||
"""阶段2: 执行计划中的所有步骤"""
|
||
logger.info(f"[Execution] Starting execution: {len(plan.steps)} steps")
|
||
|
||
results = []
|
||
collected_data = {} # 收集的数据,供后续步骤使用
|
||
|
||
for i, step in enumerate(plan.steps):
|
||
logger.info(f"[Execution] Step {i+1}/{len(plan.steps)}: {step.tool}")
|
||
|
||
start_time = datetime.now()
|
||
|
||
try:
|
||
# 替换 arguments 中的占位符
|
||
arguments = step.arguments.copy()
|
||
if step.tool == "summarize_with_llm" and arguments.get("data") == "前面收集的所有数据":
|
||
# 将收集的数据传递给总结工具
|
||
arguments["data"] = json.dumps(collected_data, ensure_ascii=False, indent=2)
|
||
|
||
# 执行工具
|
||
result = await self.execute_tool(step.tool, arguments, tool_handlers)
|
||
|
||
execution_time = (datetime.now() - start_time).total_seconds()
|
||
|
||
# 保存结果
|
||
step_result = StepResult(
|
||
step_index=i,
|
||
tool=step.tool,
|
||
arguments=arguments,
|
||
status="success",
|
||
result=result,
|
||
execution_time=execution_time,
|
||
)
|
||
results.append(step_result)
|
||
|
||
# 收集数据
|
||
collected_data[f"step_{i+1}_{step.tool}"] = result
|
||
|
||
logger.info(f"[Execution] Step {i+1} completed in {execution_time:.2f}s")
|
||
|
||
except Exception as e:
|
||
logger.error(f"[Execution] Step {i+1} failed: {str(e)}")
|
||
|
||
execution_time = (datetime.now() - start_time).total_seconds()
|
||
|
||
step_result = StepResult(
|
||
step_index=i,
|
||
tool=step.tool,
|
||
arguments=step.arguments,
|
||
status="failed",
|
||
error=str(e),
|
||
execution_time=execution_time,
|
||
)
|
||
results.append(step_result)
|
||
|
||
# 根据错误类型决定是否继续
|
||
if "not found" in str(e).lower():
|
||
logger.warning(f"[Execution] Stopping due to critical error")
|
||
break
|
||
else:
|
||
logger.warning(f"[Execution] Continuing despite error")
|
||
continue
|
||
|
||
logger.info(f"[Execution] Execution completed: {len(results)} steps")
|
||
return results
|
||
|
||
async def summarize_with_llm(self, data: str, task: str) -> str:
|
||
"""特殊工具:使用 LLM 总结数据"""
|
||
logger.info(f"[LLM Summary] Task: {task}")
|
||
|
||
messages = [
|
||
{
|
||
"role": "system",
|
||
"content": "你是一个专业的金融分析师。根据提供的数据,完成指定的分析任务。"
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": f"## 任务\n{task}\n\n## 数据\n{data}\n\n请根据数据完成分析任务,用专业且易懂的语言呈现。"
|
||
},
|
||
]
|
||
|
||
response = self.client.chat.completions.create(
|
||
model=self.model,
|
||
messages=messages,
|
||
temperature=0.7,
|
||
max_tokens=2000,
|
||
)
|
||
|
||
summary = response.choices[0].message.content
|
||
return summary
|
||
|
||
# ==================== 阶段 3: 结果总结 ====================
|
||
|
||
async def generate_final_summary(
|
||
self,
|
||
user_query: str,
|
||
plan: ExecutionPlan,
|
||
step_results: List[StepResult],
|
||
) -> str:
|
||
"""阶段3: 生成最终总结"""
|
||
logger.info("[Summary] Generating final summary")
|
||
|
||
# 收集所有成功的结果
|
||
successful_results = [r for r in step_results if r.status == "success"]
|
||
|
||
if not successful_results:
|
||
return "很抱歉,所有步骤都执行失败,无法生成分析报告。"
|
||
|
||
# 构建总结提示
|
||
results_text = "\n\n".join([
|
||
f"**步骤 {r.step_index + 1}: {r.tool}**\n"
|
||
f"结果: {json.dumps(r.result, ensure_ascii=False, indent=2)[:1000]}..."
|
||
for r in successful_results
|
||
])
|
||
|
||
messages = [
|
||
{
|
||
"role": "system",
|
||
"content": "你是一个专业的金融研究助手。根据执行结果,生成一份简洁清晰的报告。"
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": f"""
|
||
用户问题:{user_query}
|
||
|
||
执行计划:{plan.goal}
|
||
|
||
执行结果:
|
||
{results_text}
|
||
|
||
请根据以上信息,生成一份专业的分析报告(300字以内)。
|
||
"""
|
||
},
|
||
]
|
||
|
||
response = self.client.chat.completions.create(
|
||
model=self.model,
|
||
messages=messages,
|
||
temperature=0.7,
|
||
max_tokens=1000,
|
||
)
|
||
|
||
summary = response.choices[0].message.content
|
||
logger.info("[Summary] Final summary generated")
|
||
return summary
|
||
|
||
# ==================== 主流程 ====================
|
||
|
||
async def process_query(
|
||
self,
|
||
user_query: str,
|
||
tools: List[dict],
|
||
tool_handlers: Dict[str, Any],
|
||
) -> AgentResponse:
|
||
"""主流程:处理用户查询"""
|
||
logger.info(f"[Agent] Processing query: {user_query}")
|
||
|
||
try:
|
||
# 阶段 1: 创建计划
|
||
plan = await self.create_plan(user_query, tools)
|
||
|
||
# 阶段 2: 执行计划
|
||
step_results = await self.execute_plan(plan, tool_handlers)
|
||
|
||
# 阶段 3: 生成总结
|
||
final_summary = await self.generate_final_summary(
|
||
user_query, plan, step_results
|
||
)
|
||
|
||
return AgentResponse(
|
||
success=True,
|
||
message=final_summary,
|
||
plan=plan,
|
||
step_results=step_results,
|
||
final_summary=final_summary,
|
||
metadata={
|
||
"total_steps": len(plan.steps),
|
||
"successful_steps": len([r for r in step_results if r.status == "success"]),
|
||
"failed_steps": len([r for r in step_results if r.status == "failed"]),
|
||
"total_execution_time": sum(r.execution_time for r in step_results),
|
||
},
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"[Agent] Error: {str(e)}", exc_info=True)
|
||
return AgentResponse(
|
||
success=False,
|
||
message=f"处理失败: {str(e)}",
|
||
)
|
||
|
||
# ==================== FastAPI 端点 ====================
|
||
|
||
"""
|
||
在 mcp_server.py 中添加:
|
||
|
||
from mcp_agent_system import MCPAgent, ChatRequest, AgentResponse
|
||
|
||
# 创建 Agent 实例
|
||
agent = MCPAgent(provider="qwen")
|
||
|
||
@app.post("/agent/chat", response_model=AgentResponse)
|
||
async def agent_chat(request: ChatRequest):
|
||
\"\"\"智能代理对话端点\"\"\"
|
||
logger.info(f"Agent chat: {request.message}")
|
||
|
||
# 获取工具列表和处理器
|
||
tools = [tool.dict() for tool in TOOLS]
|
||
|
||
# 处理查询
|
||
response = await agent.process_query(
|
||
user_query=request.message,
|
||
tools=tools,
|
||
tool_handlers=TOOL_HANDLERS,
|
||
)
|
||
|
||
return response
|
||
"""
|