diff --git a/__pycache__/mcp_database.cpython-310.pyc b/__pycache__/mcp_database.cpython-310.pyc index e6d84f78..026c6759 100644 Binary files a/__pycache__/mcp_database.cpython-310.pyc and b/__pycache__/mcp_database.cpython-310.pyc differ diff --git a/__pycache__/mcp_server.cpython-310.pyc b/__pycache__/mcp_server.cpython-310.pyc index e06d1636..81c425da 100644 Binary files a/__pycache__/mcp_server.cpython-310.pyc and b/__pycache__/mcp_server.cpython-310.pyc differ diff --git a/mcp_database.py b/mcp_database.py index 4c868afb..0f467013 100644 --- a/mcp_database.py +++ b/mcp_database.py @@ -1030,3 +1030,51 @@ async def get_stock_intraday_statistics( except Exception as e: logger.error(f"[ClickHouse] 日内统计失败: {e}", exc_info=True) return {"success": False, "error": str(e)} + + +async def get_stock_code_by_name(stock_name: str) -> Dict[str, Any]: + """ + 根据股票名称查询股票代码 + + Args: + stock_name: 股票名称(支持模糊匹配) + + Returns: + 匹配的股票列表,包含代码和名称 + """ + pool = await get_pool() + + async with pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cursor: + # 使用 LIKE 进行模糊匹配 + query = """ + SELECT DISTINCT + SECCODE as code, + SECNAME as name, + F030V as industry + FROM ea_baseinfo + WHERE SECNAME LIKE %s + OR SECNAME = %s + ORDER BY + CASE WHEN SECNAME = %s THEN 0 ELSE 1 END, + SECCODE + LIMIT 10 + """ + + # 精确匹配和模糊匹配 + like_pattern = f"%{stock_name}%" + + await cursor.execute(query, (like_pattern, stock_name, stock_name)) + results = await cursor.fetchall() + + if not results: + return { + "success": False, + "error": f"未找到名称包含 '{stock_name}' 的股票" + } + + return { + "success": True, + "data": results, + "count": len(results) + } diff --git a/mcp_server.py b/mcp_server.py index 9889235d..ac55ace4 100644 --- a/mcp_server.py +++ b/mcp_server.py @@ -495,6 +495,20 @@ TOOLS: List[ToolDefinition] = [ "required": ["query"] } ), + ToolDefinition( + name="get_stock_code_by_name", + description="根据股票名称查询股票代码,支持模糊匹配。当只知道股票名称不知道代码时使用。", + parameters={ + "type": "object", + "properties": { + "stock_name": { + "type": "string", + "description": "股票名称,例如:'贵州茅台'、'舒泰神'、'比亚迪'" + } + }, + "required": ["stock_name"] + } + ), ToolDefinition( name="get_stock_basic_info", description="获取股票基本信息,包括公司名称、行业、地址、主营业务、高管等基础数据。", @@ -1494,7 +1508,11 @@ async def handle_get_concept_details(args: Dict[str, Any]) -> Any: async def handle_get_stock_concepts(args: Dict[str, Any]) -> Any: """处理股票概念获取""" - stock_code = args["stock_code"] + # 兼容不同的参数名: stock_code, seccode, code + stock_code = args.get("stock_code") or args.get("seccode") or args.get("code") + if not stock_code: + raise ValueError("缺少股票代码参数 (stock_code/seccode/code)") + params = { "size": args.get("size", 50), "sort_by": args.get("sort_by", "stock_count"), @@ -1503,6 +1521,7 @@ async def handle_get_stock_concepts(args: Dict[str, Any]) -> Any: if args.get("trade_date"): params["trade_date"] = args["trade_date"] + logger.info(f"[get_stock_concepts] 查询股票 {stock_code} 的概念") response = await HTTP_CLIENT.get( f"{ServiceEndpoints.CONCEPT_API}/stock/{stock_code}/concepts", params=params @@ -1573,9 +1592,24 @@ async def handle_search_research_reports(args: Dict[str, Any]) -> Any: response.raise_for_status() return response.json() +async def handle_get_stock_code_by_name(args: Dict[str, Any]) -> Any: + """根据股票名称查询股票代码""" + # 兼容不同的参数名: stock_name, name + stock_name = args.get("stock_name") or args.get("name") + if not stock_name: + return {"success": False, "error": "缺少股票名称参数 (stock_name/name)"} + + logger.info(f"[get_stock_code_by_name] 查询股票名称: {stock_name}") + result = await db.get_stock_code_by_name(stock_name) + return result + async def handle_get_stock_basic_info(args: Dict[str, Any]) -> Any: """处理股票基本信息查询""" - seccode = args["seccode"] + # 兼容不同的参数名: seccode, stock_code, code + seccode = args.get("seccode") or args.get("stock_code") or args.get("code") + if not seccode: + return {"success": False, "error": "缺少股票代码参数 (seccode/stock_code/code)"} + result = await db.get_stock_basic_info(seccode) if result: return {"success": True, "data": result} @@ -1803,6 +1837,7 @@ TOOL_HANDLERS = { "search_limit_up_stocks": handle_search_limit_up_stocks, "get_daily_stock_analysis": handle_get_daily_stock_analysis, "search_research_reports": handle_search_research_reports, + "get_stock_code_by_name": handle_get_stock_code_by_name, "get_stock_basic_info": handle_get_stock_basic_info, "get_stock_financial_index": handle_get_stock_financial_index, "get_stock_trade_data": handle_get_stock_trade_data, @@ -2549,8 +2584,15 @@ A股交易时间: 上午 9:30-11:30,下午 13:00-15:00 # 如果没有原生工具调用,尝试从文本内容中解析 if not native_tool_calls and assistant_message.content: content = assistant_message.content - # 检查是否包含工具调用标记 - if '' in content or '```tool_call' in content or '"tool":' in content: + # 检查是否包含工具调用标记(包括 DSML 格式) + has_tool_markers = ( + '' in content or + '```tool_call' in content or + '"tool":' in content or + 'DSML' in content or # DeepSeek DSML 格式 + '|DSML|' in content # 全角竖线版本 + ) + if has_tool_markers: logger.info(f"[Agent Stream] 尝试从文本内容解析工具调用") logger.info(f"[Agent Stream] 内容预览: {content[:500]}") text_tool_calls = self._parse_text_tool_calls(content) @@ -2947,6 +2989,7 @@ A股交易时间: 上午 9:30-11:30,下午 13:00-15:00 支持的格式: 1. value 2. ```tool_call\n{"name": "xxx", "arguments": {...}}\n``` + 3. DeepSeek DSML 格式: <|DSML|function_calls> <|DSML|invoke name="xxx"> <|DSML|parameter name="yyy" string="true">value 返回: [{"name": "tool_name", "arguments": {...}}, ...] """ @@ -3006,6 +3049,47 @@ A股交易时间: 上午 9:30-11:30,下午 13:00-15:00 except: pass + # 格式4: DeepSeek DSML 格式(使用全角竖线 |) + # <|DSML|function_calls> <|DSML|invoke name="search_research_reports"> <|DSML|parameter name="query" string="true">AI概念股 + # 注意:| 是全角字符 + dsml_pattern = r'<[|\|]DSML[|\|]function_calls>(.*?)' + dsml_matches = re.findall(dsml_pattern, content, re.DOTALL) + + for dsml_content in dsml_matches: + # 解析 invoke 标签 + invoke_pattern = r'<[|\|]DSML[|\|]invoke\s+name="(\w+)">(.*?)' + invoke_matches = re.findall(invoke_pattern, dsml_content, re.DOTALL) + + for func_name, params_str in invoke_matches: + arguments = {} + # 解析参数: <|DSML|parameter name="xxx" string="true/false">value + param_pattern = r'<[|\|]DSML[|\|]parameter\s+name="(\w+)"\s+string="(true|false)">(.*?)' + param_matches = re.findall(param_pattern, params_str, re.DOTALL) + + for param_name, is_string, param_value in param_matches: + param_value = param_value.strip() + if is_string == "false": + # 不是字符串,尝试解析为数字或 JSON + try: + arguments[param_name] = json.loads(param_value) + except: + # 尝试转为整数或浮点数 + try: + arguments[param_name] = int(param_value) + except: + try: + arguments[param_name] = float(param_value) + except: + arguments[param_name] = param_value + else: + # 是字符串 + arguments[param_name] = param_value + + tool_calls.append({ + "name": func_name, + "arguments": arguments + }) + logger.info(f"[Text Tool Call] 解析到 {len(tool_calls)} 个工具调用: {tool_calls}") return tool_calls