Compare commits
6 Commits
193aad3458
...
feature_bu
| Author | SHA1 | Date | |
|---|---|---|---|
| f5023d9ce6 | |||
| c589516633 | |||
| c88f13db89 | |||
| 5804aa27c4 | |||
| 413e327a19 | |||
| f9163b1228 |
Binary file not shown.
Binary file not shown.
@@ -1030,3 +1030,51 @@ async def get_stock_intraday_statistics(
|
||||
except Exception as e:
|
||||
logger.error(f"[ClickHouse] 日内统计失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
async def get_stock_code_by_name(stock_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
根据股票名称查询股票代码
|
||||
|
||||
Args:
|
||||
stock_name: 股票名称(支持模糊匹配)
|
||||
|
||||
Returns:
|
||||
匹配的股票列表,包含代码和名称
|
||||
"""
|
||||
pool = await get_pool()
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
async with conn.cursor(aiomysql.DictCursor) as cursor:
|
||||
# 使用 LIKE 进行模糊匹配
|
||||
query = """
|
||||
SELECT DISTINCT
|
||||
SECCODE as code,
|
||||
SECNAME as name,
|
||||
F030V as industry
|
||||
FROM ea_baseinfo
|
||||
WHERE SECNAME LIKE %s
|
||||
OR SECNAME = %s
|
||||
ORDER BY
|
||||
CASE WHEN SECNAME = %s THEN 0 ELSE 1 END,
|
||||
SECCODE
|
||||
LIMIT 10
|
||||
"""
|
||||
|
||||
# 精确匹配和模糊匹配
|
||||
like_pattern = f"%{stock_name}%"
|
||||
|
||||
await cursor.execute(query, (like_pattern, stock_name, stock_name))
|
||||
results = await cursor.fetchall()
|
||||
|
||||
if not results:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"未找到名称包含 '{stock_name}' 的股票"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": results,
|
||||
"count": len(results)
|
||||
}
|
||||
|
||||
339
mcp_server.py
339
mcp_server.py
@@ -314,33 +314,17 @@ TOOLS: List[ToolDefinition] = [
|
||||
),
|
||||
ToolDefinition(
|
||||
name="search_concepts",
|
||||
description="搜索股票概念板块,支持按涨跌幅、股票数量排序。返回概念详情及相关股票列表。",
|
||||
description="搜索股票概念板块,返回概念详情及相关股票列表。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "搜索关键词,例如:'新能源'、'人工智能'"
|
||||
},
|
||||
"size": {
|
||||
"type": "integer",
|
||||
"description": "每页结果数量",
|
||||
"default": 10
|
||||
},
|
||||
"page": {
|
||||
"type": "integer",
|
||||
"description": "页码",
|
||||
"default": 1
|
||||
},
|
||||
"sort_by": {
|
||||
"type": "string",
|
||||
"description": "排序方式:change_pct(涨跌幅), _score(相关度), stock_count(股票数), concept_name(名称)",
|
||||
"enum": ["change_pct", "_score", "stock_count", "concept_name"],
|
||||
"default": "change_pct"
|
||||
"description": "搜索关键词,例如:'新能源'、'人工智能'、'商业航天'"
|
||||
},
|
||||
"trade_date": {
|
||||
"type": "string",
|
||||
"description": "交易日期,格式:YYYY-MM-DD,默认最新"
|
||||
"description": "交易日期,格式:YYYY-MM-DD,不传则使用今天"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
@@ -511,6 +495,20 @@ TOOLS: List[ToolDefinition] = [
|
||||
"required": ["query"]
|
||||
}
|
||||
),
|
||||
ToolDefinition(
|
||||
name="get_stock_code_by_name",
|
||||
description="根据股票名称查询股票代码,支持模糊匹配。当只知道股票名称不知道代码时使用。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"stock_name": {
|
||||
"type": "string",
|
||||
"description": "股票名称,例如:'贵州茅台'、'舒泰神'、'比亚迪'"
|
||||
}
|
||||
},
|
||||
"required": ["stock_name"]
|
||||
}
|
||||
),
|
||||
ToolDefinition(
|
||||
name="get_stock_basic_info",
|
||||
description="获取股票基本信息,包括公司名称、行业、地址、主营业务、高管等基础数据。",
|
||||
@@ -1469,18 +1467,27 @@ async def handle_search_roadshows(args: Dict[str, Any]) -> Any:
|
||||
return response.json()
|
||||
|
||||
async def handle_search_concepts(args: Dict[str, Any]) -> Any:
|
||||
"""处理概念搜索"""
|
||||
"""处理概念搜索
|
||||
|
||||
参数写死:size=12, page=1, sort_by="_score"
|
||||
trade_date 如果没传则使用今天的日期
|
||||
"""
|
||||
from datetime import date
|
||||
|
||||
# trade_date 默认今天
|
||||
trade_date = args.get("trade_date") or date.today().strftime("%Y-%m-%d")
|
||||
|
||||
payload = {
|
||||
"query": args["query"],
|
||||
"size": args.get("size", 10),
|
||||
"page": args.get("page", 1),
|
||||
"size": 12, # 写死
|
||||
"page": 1, # 写死
|
||||
"sort_by": "_score", # 写死,按相关度排序
|
||||
"trade_date": trade_date,
|
||||
"search_size": 100,
|
||||
"sort_by": args.get("sort_by", "change_pct"),
|
||||
"use_knn": True
|
||||
}
|
||||
if args.get("trade_date"):
|
||||
payload["trade_date"] = args["trade_date"]
|
||||
|
||||
logger.info(f"[search_concepts] 请求参数: {payload}")
|
||||
response = await HTTP_CLIENT.post(f"{ServiceEndpoints.CONCEPT_API}/search", json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -1501,7 +1508,11 @@ async def handle_get_concept_details(args: Dict[str, Any]) -> Any:
|
||||
|
||||
async def handle_get_stock_concepts(args: Dict[str, Any]) -> Any:
|
||||
"""处理股票概念获取"""
|
||||
stock_code = args["stock_code"]
|
||||
# 兼容不同的参数名: stock_code, seccode, code
|
||||
stock_code = args.get("stock_code") or args.get("seccode") or args.get("code")
|
||||
if not stock_code:
|
||||
raise ValueError("缺少股票代码参数 (stock_code/seccode/code)")
|
||||
|
||||
params = {
|
||||
"size": args.get("size", 50),
|
||||
"sort_by": args.get("sort_by", "stock_count"),
|
||||
@@ -1510,6 +1521,7 @@ async def handle_get_stock_concepts(args: Dict[str, Any]) -> Any:
|
||||
if args.get("trade_date"):
|
||||
params["trade_date"] = args["trade_date"]
|
||||
|
||||
logger.info(f"[get_stock_concepts] 查询股票 {stock_code} 的概念")
|
||||
response = await HTTP_CLIENT.get(
|
||||
f"{ServiceEndpoints.CONCEPT_API}/stock/{stock_code}/concepts",
|
||||
params=params
|
||||
@@ -1580,9 +1592,24 @@ async def handle_search_research_reports(args: Dict[str, Any]) -> Any:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def handle_get_stock_code_by_name(args: Dict[str, Any]) -> Any:
|
||||
"""根据股票名称查询股票代码"""
|
||||
# 兼容不同的参数名: stock_name, name
|
||||
stock_name = args.get("stock_name") or args.get("name")
|
||||
if not stock_name:
|
||||
return {"success": False, "error": "缺少股票名称参数 (stock_name/name)"}
|
||||
|
||||
logger.info(f"[get_stock_code_by_name] 查询股票名称: {stock_name}")
|
||||
result = await db.get_stock_code_by_name(stock_name)
|
||||
return result
|
||||
|
||||
async def handle_get_stock_basic_info(args: Dict[str, Any]) -> Any:
|
||||
"""处理股票基本信息查询"""
|
||||
seccode = args["seccode"]
|
||||
# 兼容不同的参数名: seccode, stock_code, code
|
||||
seccode = args.get("seccode") or args.get("stock_code") or args.get("code")
|
||||
if not seccode:
|
||||
return {"success": False, "error": "缺少股票代码参数 (seccode/stock_code/code)"}
|
||||
|
||||
result = await db.get_stock_basic_info(seccode)
|
||||
if result:
|
||||
return {"success": True, "data": result}
|
||||
@@ -1810,6 +1837,7 @@ TOOL_HANDLERS = {
|
||||
"search_limit_up_stocks": handle_search_limit_up_stocks,
|
||||
"get_daily_stock_analysis": handle_get_daily_stock_analysis,
|
||||
"search_research_reports": handle_search_research_reports,
|
||||
"get_stock_code_by_name": handle_get_stock_code_by_name,
|
||||
"get_stock_basic_info": handle_get_stock_basic_info,
|
||||
"get_stock_financial_index": handle_get_stock_financial_index,
|
||||
"get_stock_trade_data": handle_get_stock_trade_data,
|
||||
@@ -2549,9 +2577,29 @@ A股交易时间: 上午 9:30-11:30,下午 13:00-15:00
|
||||
assistant_message = response.choices[0].message
|
||||
logger.info(f"[Agent Stream] LLM 响应: finish_reason={response.choices[0].finish_reason}")
|
||||
|
||||
# 检查是否有工具调用
|
||||
if assistant_message.tool_calls:
|
||||
logger.info(f"[Agent Stream] 检测到 {len(assistant_message.tool_calls)} 个工具调用")
|
||||
# 获取工具调用(优先使用原生 tool_calls,其次解析文本格式)
|
||||
native_tool_calls = assistant_message.tool_calls or []
|
||||
text_tool_calls = []
|
||||
|
||||
# 如果没有原生工具调用,尝试从文本内容中解析
|
||||
if not native_tool_calls and assistant_message.content:
|
||||
content = assistant_message.content
|
||||
# 检查是否包含工具调用标记(包括 DSML 格式)
|
||||
has_tool_markers = (
|
||||
'<tool_call>' in content or
|
||||
'```tool_call' in content or
|
||||
'"tool":' in content or
|
||||
'DSML' in content or # DeepSeek DSML 格式
|
||||
'|DSML|' in content # 全角竖线版本
|
||||
)
|
||||
if has_tool_markers:
|
||||
logger.info(f"[Agent Stream] 尝试从文本内容解析工具调用")
|
||||
logger.info(f"[Agent Stream] 内容预览: {content[:500]}")
|
||||
text_tool_calls = self._parse_text_tool_calls(content)
|
||||
|
||||
# 检查是否有工具调用(原生或文本格式)
|
||||
if native_tool_calls:
|
||||
logger.info(f"[Agent Stream] 检测到 {len(native_tool_calls)} 个原生工具调用")
|
||||
|
||||
# 将 assistant 消息添加到历史(包含 tool_calls)
|
||||
messages.append(assistant_message)
|
||||
@@ -2564,7 +2612,7 @@ A股交易时间: 上午 9:30-11:30,下午 13:00-15:00
|
||||
"reasoning": "使用工具获取相关数据进行分析",
|
||||
"steps": []
|
||||
}
|
||||
for tc in assistant_message.tool_calls:
|
||||
for tc in native_tool_calls:
|
||||
try:
|
||||
args = json.loads(tc.function.arguments) if tc.function.arguments else {}
|
||||
except:
|
||||
@@ -2576,10 +2624,10 @@ A股交易时间: 上午 9:30-11:30,下午 13:00-15:00
|
||||
})
|
||||
|
||||
yield self._format_sse("plan", plan_data)
|
||||
yield self._format_sse("status", {"stage": "executing", "message": f"开始执行 {len(assistant_message.tool_calls)} 个工具调用"})
|
||||
yield self._format_sse("status", {"stage": "executing", "message": f"开始执行 {len(native_tool_calls)} 个工具调用"})
|
||||
|
||||
# 执行每个工具调用
|
||||
for tool_call in assistant_message.tool_calls:
|
||||
for tool_call in native_tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_call_id = tool_call.id
|
||||
|
||||
@@ -2682,6 +2730,120 @@ A股交易时间: 上午 9:30-11:30,下午 13:00-15:00
|
||||
logger.info(f"[Tool Call] ========== 工具调用结束 ==========")
|
||||
step_index += 1
|
||||
|
||||
elif text_tool_calls:
|
||||
# 处理文本格式的工具调用
|
||||
logger.info(f"[Agent Stream] 检测到 {len(text_tool_calls)} 个文本格式工具调用")
|
||||
|
||||
# 将 assistant 消息添加到历史
|
||||
messages.append({"role": "assistant", "content": assistant_message.content})
|
||||
|
||||
# 如果是第一次工具调用,发送计划事件
|
||||
if step_index == 0:
|
||||
plan_data = {
|
||||
"goal": f"分析用户问题:{user_query[:50]}...",
|
||||
"reasoning": "使用工具获取相关数据进行分析",
|
||||
"steps": [
|
||||
{"tool": tc["name"], "arguments": tc["arguments"], "reason": f"调用 {tc['name']}"}
|
||||
for tc in text_tool_calls
|
||||
]
|
||||
}
|
||||
yield self._format_sse("plan", plan_data)
|
||||
yield self._format_sse("status", {"stage": "executing", "message": f"开始执行 {len(text_tool_calls)} 个工具调用"})
|
||||
|
||||
# 执行每个工具调用
|
||||
for tc in text_tool_calls:
|
||||
tool_name = tc["name"]
|
||||
arguments = tc["arguments"]
|
||||
tool_call_id = f"text_call_{step_index}_{tool_name}"
|
||||
|
||||
logger.info(f"[Tool Call] ========== 文本工具调用开始 ==========")
|
||||
logger.info(f"[Tool Call] 工具名: {tool_name}")
|
||||
logger.info(f"[Tool Call] 参数内容: {json.dumps(arguments, ensure_ascii=False)}")
|
||||
|
||||
# 发送步骤开始事件
|
||||
yield self._format_sse("step_start", {
|
||||
"step_index": step_index,
|
||||
"tool": tool_name,
|
||||
"arguments": arguments,
|
||||
"reason": f"调用 {tool_name}",
|
||||
})
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
# 特殊处理 summarize_news
|
||||
if tool_name == "summarize_news":
|
||||
data_arg = arguments.get("data", "")
|
||||
if data_arg in ["前面的新闻数据", "前面收集的所有数据", ""]:
|
||||
arguments["data"] = json.dumps(collected_data, ensure_ascii=False, indent=2)
|
||||
|
||||
# 执行工具
|
||||
result = await self.execute_tool(tool_name, arguments, tool_handlers)
|
||||
execution_time = (datetime.now() - start_time).total_seconds()
|
||||
|
||||
# 记录结果
|
||||
step_result = StepResult(
|
||||
step_index=step_index,
|
||||
tool=tool_name,
|
||||
arguments=arguments,
|
||||
status="success",
|
||||
result=result,
|
||||
execution_time=execution_time,
|
||||
)
|
||||
step_results.append(step_result)
|
||||
collected_data[f"step_{step_index+1}_{tool_name}"] = result
|
||||
plan_steps.append({"tool": tool_name, "arguments": arguments, "reason": f"调用 {tool_name}"})
|
||||
|
||||
# 发送步骤完成事件
|
||||
yield self._format_sse("step_complete", {
|
||||
"step_index": step_index,
|
||||
"tool": tool_name,
|
||||
"status": "success",
|
||||
"result": result,
|
||||
"execution_time": execution_time,
|
||||
})
|
||||
|
||||
# 将工具结果添加到消息历史(简化格式,因为模型可能不支持标准 tool 消息)
|
||||
result_str = json.dumps(result, ensure_ascii=False) if isinstance(result, (dict, list)) else str(result)
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": f"[工具调用结果] {tool_name}: {result_str[:3000]}"
|
||||
})
|
||||
|
||||
logger.info(f"[Tool Call] 执行成功,耗时 {execution_time:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
execution_time = (datetime.now() - start_time).total_seconds()
|
||||
error_msg = str(e)
|
||||
|
||||
step_result = StepResult(
|
||||
step_index=step_index,
|
||||
tool=tool_name,
|
||||
arguments=arguments,
|
||||
status="failed",
|
||||
error=error_msg,
|
||||
execution_time=execution_time,
|
||||
)
|
||||
step_results.append(step_result)
|
||||
|
||||
yield self._format_sse("step_complete", {
|
||||
"step_index": step_index,
|
||||
"tool": tool_name,
|
||||
"status": "failed",
|
||||
"error": error_msg,
|
||||
"execution_time": execution_time,
|
||||
})
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": f"[工具调用失败] {tool_name}: {error_msg}"
|
||||
})
|
||||
|
||||
logger.error(f"[Tool Call] 执行失败: {error_msg}")
|
||||
|
||||
logger.info(f"[Tool Call] ========== 文本工具调用结束 ==========")
|
||||
step_index += 1
|
||||
|
||||
else:
|
||||
# 没有工具调用,模型生成了最终回复
|
||||
logger.info(f"[Agent Stream] 模型生成最终回复")
|
||||
@@ -2820,6 +2982,117 @@ A股交易时间: 上午 9:30-11:30,下午 13:00-15:00
|
||||
"""格式化 SSE 消息"""
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
def _parse_text_tool_calls(self, content: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析文本格式的工具调用
|
||||
|
||||
支持的格式:
|
||||
1. <tool_call> <function=xxx> <parameter=yyy> value </parameter> </function> </tool_call>
|
||||
2. ```tool_call\n{"name": "xxx", "arguments": {...}}\n```
|
||||
3. DeepSeek DSML 格式: <|DSML|function_calls> <|DSML|invoke name="xxx"> <|DSML|parameter name="yyy" string="true">value</|DSML|parameter> </|DSML|invoke> </|DSML|function_calls>
|
||||
|
||||
返回: [{"name": "tool_name", "arguments": {...}}, ...]
|
||||
"""
|
||||
import re
|
||||
|
||||
tool_calls = []
|
||||
|
||||
# 格式1: <tool_call> 标签格式
|
||||
# 例如: <tool_call> <function=get_stock_concepts> <parameter=seccode> 300274 </parameter> </function> </tool_call>
|
||||
pattern1 = r'<tool_call>\s*<function=(\w+)>(.*?)</function>\s*</tool_call>'
|
||||
matches1 = re.findall(pattern1, content, re.DOTALL)
|
||||
|
||||
for func_name, params_str in matches1:
|
||||
arguments = {}
|
||||
# 解析参数: <parameter=xxx> value </parameter>
|
||||
param_pattern = r'<parameter=(\w+)>\s*(.*?)\s*</parameter>'
|
||||
param_matches = re.findall(param_pattern, params_str, re.DOTALL)
|
||||
for param_name, param_value in param_matches:
|
||||
# 尝试解析 JSON 值,否则作为字符串
|
||||
param_value = param_value.strip()
|
||||
try:
|
||||
arguments[param_name] = json.loads(param_value)
|
||||
except:
|
||||
arguments[param_name] = param_value
|
||||
|
||||
tool_calls.append({
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
})
|
||||
|
||||
# 格式2: ```tool_call 代码块格式
|
||||
pattern2 = r'```tool_call\s*\n?(.*?)\n?```'
|
||||
matches2 = re.findall(pattern2, content, re.DOTALL)
|
||||
|
||||
for match in matches2:
|
||||
try:
|
||||
data = json.loads(match.strip())
|
||||
if isinstance(data, dict) and "name" in data:
|
||||
tool_calls.append({
|
||||
"name": data["name"],
|
||||
"arguments": data.get("arguments", {})
|
||||
})
|
||||
except:
|
||||
pass
|
||||
|
||||
# 格式3: 直接 JSON 格式 {"tool": "xxx", "arguments": {...}}
|
||||
pattern3 = r'\{\s*"tool"\s*:\s*"(\w+)"\s*,\s*"arguments"\s*:\s*(\{[^}]*\})\s*\}'
|
||||
matches3 = re.findall(pattern3, content)
|
||||
|
||||
for tool_name, args_str in matches3:
|
||||
try:
|
||||
arguments = json.loads(args_str)
|
||||
tool_calls.append({
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
})
|
||||
except:
|
||||
pass
|
||||
|
||||
# 格式4: DeepSeek DSML 格式(使用全角竖线 |)
|
||||
# <|DSML|function_calls> <|DSML|invoke name="search_research_reports"> <|DSML|parameter name="query" string="true">AI概念股</|DSML|parameter> </|DSML|invoke> </|DSML|function_calls>
|
||||
# 注意:| 是全角字符
|
||||
dsml_pattern = r'<[|\|]DSML[|\|]function_calls>(.*?)</[|\|]DSML[|\|]function_calls>'
|
||||
dsml_matches = re.findall(dsml_pattern, content, re.DOTALL)
|
||||
|
||||
for dsml_content in dsml_matches:
|
||||
# 解析 invoke 标签
|
||||
invoke_pattern = r'<[|\|]DSML[|\|]invoke\s+name="(\w+)">(.*?)</[|\|]DSML[|\|]invoke>'
|
||||
invoke_matches = re.findall(invoke_pattern, dsml_content, re.DOTALL)
|
||||
|
||||
for func_name, params_str in invoke_matches:
|
||||
arguments = {}
|
||||
# 解析参数: <|DSML|parameter name="xxx" string="true/false">value</|DSML|parameter>
|
||||
param_pattern = r'<[|\|]DSML[|\|]parameter\s+name="(\w+)"\s+string="(true|false)">(.*?)</[|\|]DSML[|\|]parameter>'
|
||||
param_matches = re.findall(param_pattern, params_str, re.DOTALL)
|
||||
|
||||
for param_name, is_string, param_value in param_matches:
|
||||
param_value = param_value.strip()
|
||||
if is_string == "false":
|
||||
# 不是字符串,尝试解析为数字或 JSON
|
||||
try:
|
||||
arguments[param_name] = json.loads(param_value)
|
||||
except:
|
||||
# 尝试转为整数或浮点数
|
||||
try:
|
||||
arguments[param_name] = int(param_value)
|
||||
except:
|
||||
try:
|
||||
arguments[param_name] = float(param_value)
|
||||
except:
|
||||
arguments[param_name] = param_value
|
||||
else:
|
||||
# 是字符串
|
||||
arguments[param_name] = param_value
|
||||
|
||||
tool_calls.append({
|
||||
"name": func_name,
|
||||
"arguments": arguments
|
||||
})
|
||||
|
||||
logger.info(f"[Text Tool Call] 解析到 {len(tool_calls)} 个工具调用: {tool_calls}")
|
||||
return tool_calls
|
||||
|
||||
# 创建 Agent 实例(全局)
|
||||
agent = MCPAgentIntegrated()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user