agent功能开发增加MCP后端
This commit is contained in:
470
mcp_agent_system.py
Normal file
470
mcp_agent_system.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
MCP Agent System - 基于 DeepResearch 逻辑的智能代理系统
|
||||
三阶段流程:计划制定 → 工具执行 → 结果总结
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
from openai import OpenAI
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ==================== 数据模型 ====================
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
"""工具调用"""
|
||||
tool: str
|
||||
arguments: Dict[str, Any]
|
||||
reason: str # 为什么要调用这个工具
|
||||
|
||||
class ExecutionPlan(BaseModel):
|
||||
"""执行计划"""
|
||||
goal: str # 用户的目标
|
||||
steps: List[ToolCall] # 执行步骤
|
||||
reasoning: str # 规划reasoning
|
||||
|
||||
class StepResult(BaseModel):
|
||||
"""单步执行结果"""
|
||||
step_index: int
|
||||
tool: str
|
||||
arguments: Dict[str, Any]
|
||||
status: Literal["success", "failed", "skipped"]
|
||||
result: Optional[Any] = None
|
||||
error: Optional[str] = None
|
||||
execution_time: float = 0
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
"""Agent响应"""
|
||||
success: bool
|
||||
message: str # 自然语言总结
|
||||
plan: Optional[ExecutionPlan] = None # 执行计划
|
||||
step_results: List[StepResult] = [] # 每步的结果
|
||||
final_summary: Optional[str] = None # 最终总结
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""聊天请求"""
|
||||
message: str
|
||||
conversation_history: List[Dict[str, str]] = []
|
||||
stream: bool = False # 是否流式输出
|
||||
|
||||
# ==================== Agent 系统 ====================
|
||||
|
||||
class MCPAgent:
|
||||
"""MCP 智能代理 - 三阶段执行"""
|
||||
|
||||
def __init__(self, provider: str = "qwen"):
|
||||
self.provider = provider
|
||||
|
||||
# LLM 配置
|
||||
config = {
|
||||
"qwen": {
|
||||
"api_key": os.getenv("DASHSCOPE_API_KEY", ""),
|
||||
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"model": "qwen-plus",
|
||||
},
|
||||
"deepseek": {
|
||||
"api_key": os.getenv("DEEPSEEK_API_KEY", ""),
|
||||
"base_url": "https://api.deepseek.com/v1",
|
||||
"model": "deepseek-chat",
|
||||
},
|
||||
"openai": {
|
||||
"api_key": os.getenv("OPENAI_API_KEY", ""),
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
}.get(provider)
|
||||
|
||||
if not config or not config["api_key"]:
|
||||
raise ValueError(f"Provider '{provider}' not configured. Please set API key.")
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=config["api_key"],
|
||||
base_url=config["base_url"],
|
||||
)
|
||||
self.model = config["model"]
|
||||
|
||||
# ==================== 阶段 1: 计划制定 ====================
|
||||
|
||||
def get_planning_prompt(self, tools: List[dict]) -> str:
|
||||
"""获取计划制定的系统提示词"""
|
||||
tools_desc = "\n\n".join([
|
||||
f"**{tool['name']}**\n"
|
||||
f"描述:{tool['description']}\n"
|
||||
f"参数:{json.dumps(tool['parameters'], ensure_ascii=False, indent=2)}"
|
||||
for tool in tools
|
||||
])
|
||||
|
||||
return f"""你是一个专业的金融研究助手。你需要根据用户的问题,制定一个详细的执行计划。
|
||||
|
||||
## 可用工具
|
||||
|
||||
{tools_desc}
|
||||
|
||||
## 重要知识
|
||||
- 贵州茅台股票代码: 600519
|
||||
- 涨停: 股价单日涨幅约10%
|
||||
- 概念板块: 相同题材的股票分类
|
||||
|
||||
## 特殊工具说明
|
||||
- **summarize_with_llm**: 这是一个特殊工具,用于让你总结和分析收集到的数据
|
||||
- 当需要对多个数据源进行综合分析时使用
|
||||
- 当需要生成研究报告时使用
|
||||
- 参数: {{"data": "要分析的数据", "task": "分析任务描述"}}
|
||||
|
||||
## 任务
|
||||
分析用户问题,制定执行计划。返回 JSON 格式:
|
||||
|
||||
```json
|
||||
{{
|
||||
"goal": "用户的目标(一句话概括)",
|
||||
"reasoning": "你的分析思路(为什么这样规划)",
|
||||
"steps": [
|
||||
{{
|
||||
"tool": "工具名称",
|
||||
"arguments": {{"参数名": "参数值"}},
|
||||
"reason": "为什么要执行这一步"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
## 规划原则
|
||||
1. **从简到繁**: 先获取基础信息,再深入分析
|
||||
2. **数据先行**: 先收集数据,再总结分析
|
||||
3. **合理组合**: 可以调用多个工具,但不要超过5个
|
||||
4. **包含总结**: 最后一步通常是 summarize_with_llm
|
||||
|
||||
## 示例
|
||||
|
||||
用户:"帮我全面分析一下贵州茅台这只股票"
|
||||
|
||||
你的计划:
|
||||
```json
|
||||
{{
|
||||
"goal": "全面分析贵州茅台股票",
|
||||
"reasoning": "需要获取基本信息、财务指标、交易数据,然后综合分析",
|
||||
"steps": [
|
||||
{{
|
||||
"tool": "get_stock_basic_info",
|
||||
"arguments": {{"seccode": "600519"}},
|
||||
"reason": "获取股票基本信息(公司名称、行业、市值等)"
|
||||
}},
|
||||
{{
|
||||
"tool": "get_stock_financial_index",
|
||||
"arguments": {{"seccode": "600519", "limit": 5}},
|
||||
"reason": "获取最近5期财务指标(营收、利润、ROE等)"
|
||||
}},
|
||||
{{
|
||||
"tool": "get_stock_trade_data",
|
||||
"arguments": {{"seccode": "600519", "limit": 30}},
|
||||
"reason": "获取最近30天交易数据(价格走势、成交量)"
|
||||
}},
|
||||
{{
|
||||
"tool": "search_china_news",
|
||||
"arguments": {{"query": "贵州茅台", "top_k": 5}},
|
||||
"reason": "获取最新新闻,了解市场动态"
|
||||
}},
|
||||
{{
|
||||
"tool": "summarize_with_llm",
|
||||
"arguments": {{
|
||||
"data": "前面收集的所有数据",
|
||||
"task": "综合分析贵州茅台的投资价值,包括基本面、财务状况、股价走势、市场情绪"
|
||||
}},
|
||||
"reason": "综合所有数据,生成投资分析报告"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
只返回JSON,不要额外解释。"""
|
||||
|
||||
async def create_plan(self, user_query: str, tools: List[dict]) -> ExecutionPlan:
|
||||
"""阶段1: 创建执行计划"""
|
||||
logger.info(f"[Planning] Creating plan for: {user_query}")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": self.get_planning_prompt(tools)},
|
||||
{"role": "user", "content": user_query},
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
max_tokens=1500,
|
||||
)
|
||||
|
||||
plan_json = response.choices[0].message.content.strip()
|
||||
logger.info(f"[Planning] Raw response: {plan_json}")
|
||||
|
||||
# 清理可能的代码块标记
|
||||
if "```json" in plan_json:
|
||||
plan_json = plan_json.split("```json")[1].split("```")[0].strip()
|
||||
elif "```" in plan_json:
|
||||
plan_json = plan_json.split("```")[1].split("```")[0].strip()
|
||||
|
||||
plan_data = json.loads(plan_json)
|
||||
|
||||
plan = ExecutionPlan(
|
||||
goal=plan_data["goal"],
|
||||
reasoning=plan_data.get("reasoning", ""),
|
||||
steps=[
|
||||
ToolCall(**step) for step in plan_data["steps"]
|
||||
],
|
||||
)
|
||||
|
||||
logger.info(f"[Planning] Plan created: {len(plan.steps)} steps")
|
||||
return plan
|
||||
|
||||
# ==================== 阶段 2: 工具执行 ====================
|
||||
|
||||
async def execute_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
tool_handlers: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""执行单个工具"""
|
||||
|
||||
# 特殊处理:summarize_with_llm
|
||||
if tool_name == "summarize_with_llm":
|
||||
return await self.summarize_with_llm(
|
||||
data=arguments.get("data", ""),
|
||||
task=arguments.get("task", "总结数据"),
|
||||
)
|
||||
|
||||
# 调用 MCP 工具
|
||||
handler = tool_handlers.get(tool_name)
|
||||
if not handler:
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
|
||||
result = await handler(arguments)
|
||||
return result
|
||||
|
||||
async def execute_plan(
|
||||
self,
|
||||
plan: ExecutionPlan,
|
||||
tool_handlers: Dict[str, Any],
|
||||
) -> List[StepResult]:
|
||||
"""阶段2: 执行计划中的所有步骤"""
|
||||
logger.info(f"[Execution] Starting execution: {len(plan.steps)} steps")
|
||||
|
||||
results = []
|
||||
collected_data = {} # 收集的数据,供后续步骤使用
|
||||
|
||||
for i, step in enumerate(plan.steps):
|
||||
logger.info(f"[Execution] Step {i+1}/{len(plan.steps)}: {step.tool}")
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
# 替换 arguments 中的占位符
|
||||
arguments = step.arguments.copy()
|
||||
if step.tool == "summarize_with_llm" and arguments.get("data") == "前面收集的所有数据":
|
||||
# 将收集的数据传递给总结工具
|
||||
arguments["data"] = json.dumps(collected_data, ensure_ascii=False, indent=2)
|
||||
|
||||
# 执行工具
|
||||
result = await self.execute_tool(step.tool, arguments, tool_handlers)
|
||||
|
||||
execution_time = (datetime.now() - start_time).total_seconds()
|
||||
|
||||
# 保存结果
|
||||
step_result = StepResult(
|
||||
step_index=i,
|
||||
tool=step.tool,
|
||||
arguments=arguments,
|
||||
status="success",
|
||||
result=result,
|
||||
execution_time=execution_time,
|
||||
)
|
||||
results.append(step_result)
|
||||
|
||||
# 收集数据
|
||||
collected_data[f"step_{i+1}_{step.tool}"] = result
|
||||
|
||||
logger.info(f"[Execution] Step {i+1} completed in {execution_time:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Execution] Step {i+1} failed: {str(e)}")
|
||||
|
||||
execution_time = (datetime.now() - start_time).total_seconds()
|
||||
|
||||
step_result = StepResult(
|
||||
step_index=i,
|
||||
tool=step.tool,
|
||||
arguments=step.arguments,
|
||||
status="failed",
|
||||
error=str(e),
|
||||
execution_time=execution_time,
|
||||
)
|
||||
results.append(step_result)
|
||||
|
||||
# 根据错误类型决定是否继续
|
||||
if "not found" in str(e).lower():
|
||||
logger.warning(f"[Execution] Stopping due to critical error")
|
||||
break
|
||||
else:
|
||||
logger.warning(f"[Execution] Continuing despite error")
|
||||
continue
|
||||
|
||||
logger.info(f"[Execution] Execution completed: {len(results)} steps")
|
||||
return results
|
||||
|
||||
async def summarize_with_llm(self, data: str, task: str) -> str:
|
||||
"""特殊工具:使用 LLM 总结数据"""
|
||||
logger.info(f"[LLM Summary] Task: {task}")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业的金融分析师。根据提供的数据,完成指定的分析任务。"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"## 任务\n{task}\n\n## 数据\n{data}\n\n请根据数据完成分析任务,用专业且易懂的语言呈现。"
|
||||
},
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
max_tokens=2000,
|
||||
)
|
||||
|
||||
summary = response.choices[0].message.content
|
||||
return summary
|
||||
|
||||
# ==================== 阶段 3: 结果总结 ====================
|
||||
|
||||
async def generate_final_summary(
|
||||
self,
|
||||
user_query: str,
|
||||
plan: ExecutionPlan,
|
||||
step_results: List[StepResult],
|
||||
) -> str:
|
||||
"""阶段3: 生成最终总结"""
|
||||
logger.info("[Summary] Generating final summary")
|
||||
|
||||
# 收集所有成功的结果
|
||||
successful_results = [r for r in step_results if r.status == "success"]
|
||||
|
||||
if not successful_results:
|
||||
return "很抱歉,所有步骤都执行失败,无法生成分析报告。"
|
||||
|
||||
# 构建总结提示
|
||||
results_text = "\n\n".join([
|
||||
f"**步骤 {r.step_index + 1}: {r.tool}**\n"
|
||||
f"结果: {json.dumps(r.result, ensure_ascii=False, indent=2)[:1000]}..."
|
||||
for r in successful_results
|
||||
])
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业的金融研究助手。根据执行结果,生成一份简洁清晰的报告。"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
用户问题:{user_query}
|
||||
|
||||
执行计划:{plan.goal}
|
||||
|
||||
执行结果:
|
||||
{results_text}
|
||||
|
||||
请根据以上信息,生成一份专业的分析报告(300字以内)。
|
||||
"""
|
||||
},
|
||||
]
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
summary = response.choices[0].message.content
|
||||
logger.info("[Summary] Final summary generated")
|
||||
return summary
|
||||
|
||||
# ==================== 主流程 ====================
|
||||
|
||||
async def process_query(
|
||||
self,
|
||||
user_query: str,
|
||||
tools: List[dict],
|
||||
tool_handlers: Dict[str, Any],
|
||||
) -> AgentResponse:
|
||||
"""主流程:处理用户查询"""
|
||||
logger.info(f"[Agent] Processing query: {user_query}")
|
||||
|
||||
try:
|
||||
# 阶段 1: 创建计划
|
||||
plan = await self.create_plan(user_query, tools)
|
||||
|
||||
# 阶段 2: 执行计划
|
||||
step_results = await self.execute_plan(plan, tool_handlers)
|
||||
|
||||
# 阶段 3: 生成总结
|
||||
final_summary = await self.generate_final_summary(
|
||||
user_query, plan, step_results
|
||||
)
|
||||
|
||||
return AgentResponse(
|
||||
success=True,
|
||||
message=final_summary,
|
||||
plan=plan,
|
||||
step_results=step_results,
|
||||
final_summary=final_summary,
|
||||
metadata={
|
||||
"total_steps": len(plan.steps),
|
||||
"successful_steps": len([r for r in step_results if r.status == "success"]),
|
||||
"failed_steps": len([r for r in step_results if r.status == "failed"]),
|
||||
"total_execution_time": sum(r.execution_time for r in step_results),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Agent] Error: {str(e)}", exc_info=True)
|
||||
return AgentResponse(
|
||||
success=False,
|
||||
message=f"处理失败: {str(e)}",
|
||||
)
|
||||
|
||||
# ==================== FastAPI 端点 ====================
|
||||
|
||||
"""
|
||||
在 mcp_server.py 中添加:
|
||||
|
||||
from mcp_agent_system import MCPAgent, ChatRequest, AgentResponse
|
||||
|
||||
# 创建 Agent 实例
|
||||
agent = MCPAgent(provider="qwen")
|
||||
|
||||
@app.post("/agent/chat", response_model=AgentResponse)
|
||||
async def agent_chat(request: ChatRequest):
|
||||
\"\"\"智能代理对话端点\"\"\"
|
||||
logger.info(f"Agent chat: {request.message}")
|
||||
|
||||
# 获取工具列表和处理器
|
||||
tools = [tool.dict() for tool in TOOLS]
|
||||
|
||||
# 处理查询
|
||||
response = await agent.process_query(
|
||||
user_query=request.message,
|
||||
tools=tools,
|
||||
tool_handlers=TOOL_HANDLERS,
|
||||
)
|
||||
|
||||
return response
|
||||
"""
|
||||
Reference in New Issue
Block a user