Compare commits
5 Commits
a3a82794ca
...
cb662c8a37
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb662c8a37 | ||
| f5023d9ce6 | |||
| c589516633 | |||
| c88f13db89 | |||
| 5804aa27c4 |
Binary file not shown.
Binary file not shown.
@@ -1030,3 +1030,51 @@ async def get_stock_intraday_statistics(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[ClickHouse] 日内统计失败: {e}", exc_info=True)
|
logger.error(f"[ClickHouse] 日内统计失败: {e}", exc_info=True)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
async def get_stock_code_by_name(stock_name: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
根据股票名称查询股票代码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stock_name: 股票名称(支持模糊匹配)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
匹配的股票列表,包含代码和名称
|
||||||
|
"""
|
||||||
|
pool = await get_pool()
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
async with conn.cursor(aiomysql.DictCursor) as cursor:
|
||||||
|
# 使用 LIKE 进行模糊匹配
|
||||||
|
query = """
|
||||||
|
SELECT DISTINCT
|
||||||
|
SECCODE as code,
|
||||||
|
SECNAME as name,
|
||||||
|
F030V as industry
|
||||||
|
FROM ea_baseinfo
|
||||||
|
WHERE SECNAME LIKE %s
|
||||||
|
OR SECNAME = %s
|
||||||
|
ORDER BY
|
||||||
|
CASE WHEN SECNAME = %s THEN 0 ELSE 1 END,
|
||||||
|
SECCODE
|
||||||
|
LIMIT 10
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 精确匹配和模糊匹配
|
||||||
|
like_pattern = f"%{stock_name}%"
|
||||||
|
|
||||||
|
await cursor.execute(query, (like_pattern, stock_name, stock_name))
|
||||||
|
results = await cursor.fetchall()
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"未找到名称包含 '{stock_name}' 的股票"
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": results,
|
||||||
|
"count": len(results)
|
||||||
|
}
|
||||||
|
|||||||
@@ -495,6 +495,20 @@ TOOLS: List[ToolDefinition] = [
|
|||||||
"required": ["query"]
|
"required": ["query"]
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
|
ToolDefinition(
|
||||||
|
name="get_stock_code_by_name",
|
||||||
|
description="根据股票名称查询股票代码,支持模糊匹配。当只知道股票名称不知道代码时使用。",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"stock_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "股票名称,例如:'贵州茅台'、'舒泰神'、'比亚迪'"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["stock_name"]
|
||||||
|
}
|
||||||
|
),
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="get_stock_basic_info",
|
name="get_stock_basic_info",
|
||||||
description="获取股票基本信息,包括公司名称、行业、地址、主营业务、高管等基础数据。",
|
description="获取股票基本信息,包括公司名称、行业、地址、主营业务、高管等基础数据。",
|
||||||
@@ -1494,7 +1508,11 @@ async def handle_get_concept_details(args: Dict[str, Any]) -> Any:
|
|||||||
|
|
||||||
async def handle_get_stock_concepts(args: Dict[str, Any]) -> Any:
|
async def handle_get_stock_concepts(args: Dict[str, Any]) -> Any:
|
||||||
"""处理股票概念获取"""
|
"""处理股票概念获取"""
|
||||||
stock_code = args["stock_code"]
|
# 兼容不同的参数名: stock_code, seccode, code
|
||||||
|
stock_code = args.get("stock_code") or args.get("seccode") or args.get("code")
|
||||||
|
if not stock_code:
|
||||||
|
raise ValueError("缺少股票代码参数 (stock_code/seccode/code)")
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"size": args.get("size", 50),
|
"size": args.get("size", 50),
|
||||||
"sort_by": args.get("sort_by", "stock_count"),
|
"sort_by": args.get("sort_by", "stock_count"),
|
||||||
@@ -1503,6 +1521,7 @@ async def handle_get_stock_concepts(args: Dict[str, Any]) -> Any:
|
|||||||
if args.get("trade_date"):
|
if args.get("trade_date"):
|
||||||
params["trade_date"] = args["trade_date"]
|
params["trade_date"] = args["trade_date"]
|
||||||
|
|
||||||
|
logger.info(f"[get_stock_concepts] 查询股票 {stock_code} 的概念")
|
||||||
response = await HTTP_CLIENT.get(
|
response = await HTTP_CLIENT.get(
|
||||||
f"{ServiceEndpoints.CONCEPT_API}/stock/{stock_code}/concepts",
|
f"{ServiceEndpoints.CONCEPT_API}/stock/{stock_code}/concepts",
|
||||||
params=params
|
params=params
|
||||||
@@ -1573,9 +1592,24 @@ async def handle_search_research_reports(args: Dict[str, Any]) -> Any:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
async def handle_get_stock_code_by_name(args: Dict[str, Any]) -> Any:
|
||||||
|
"""根据股票名称查询股票代码"""
|
||||||
|
# 兼容不同的参数名: stock_name, name
|
||||||
|
stock_name = args.get("stock_name") or args.get("name")
|
||||||
|
if not stock_name:
|
||||||
|
return {"success": False, "error": "缺少股票名称参数 (stock_name/name)"}
|
||||||
|
|
||||||
|
logger.info(f"[get_stock_code_by_name] 查询股票名称: {stock_name}")
|
||||||
|
result = await db.get_stock_code_by_name(stock_name)
|
||||||
|
return result
|
||||||
|
|
||||||
async def handle_get_stock_basic_info(args: Dict[str, Any]) -> Any:
|
async def handle_get_stock_basic_info(args: Dict[str, Any]) -> Any:
|
||||||
"""处理股票基本信息查询"""
|
"""处理股票基本信息查询"""
|
||||||
seccode = args["seccode"]
|
# 兼容不同的参数名: seccode, stock_code, code
|
||||||
|
seccode = args.get("seccode") or args.get("stock_code") or args.get("code")
|
||||||
|
if not seccode:
|
||||||
|
return {"success": False, "error": "缺少股票代码参数 (seccode/stock_code/code)"}
|
||||||
|
|
||||||
result = await db.get_stock_basic_info(seccode)
|
result = await db.get_stock_basic_info(seccode)
|
||||||
if result:
|
if result:
|
||||||
return {"success": True, "data": result}
|
return {"success": True, "data": result}
|
||||||
@@ -1803,6 +1837,7 @@ TOOL_HANDLERS = {
|
|||||||
"search_limit_up_stocks": handle_search_limit_up_stocks,
|
"search_limit_up_stocks": handle_search_limit_up_stocks,
|
||||||
"get_daily_stock_analysis": handle_get_daily_stock_analysis,
|
"get_daily_stock_analysis": handle_get_daily_stock_analysis,
|
||||||
"search_research_reports": handle_search_research_reports,
|
"search_research_reports": handle_search_research_reports,
|
||||||
|
"get_stock_code_by_name": handle_get_stock_code_by_name,
|
||||||
"get_stock_basic_info": handle_get_stock_basic_info,
|
"get_stock_basic_info": handle_get_stock_basic_info,
|
||||||
"get_stock_financial_index": handle_get_stock_financial_index,
|
"get_stock_financial_index": handle_get_stock_financial_index,
|
||||||
"get_stock_trade_data": handle_get_stock_trade_data,
|
"get_stock_trade_data": handle_get_stock_trade_data,
|
||||||
@@ -2549,8 +2584,15 @@ A股交易时间: 上午 9:30-11:30,下午 13:00-15:00
|
|||||||
# 如果没有原生工具调用,尝试从文本内容中解析
|
# 如果没有原生工具调用,尝试从文本内容中解析
|
||||||
if not native_tool_calls and assistant_message.content:
|
if not native_tool_calls and assistant_message.content:
|
||||||
content = assistant_message.content
|
content = assistant_message.content
|
||||||
# 检查是否包含工具调用标记
|
# 检查是否包含工具调用标记(包括 DSML 格式)
|
||||||
if '<tool_call>' in content or '```tool_call' in content or '"tool":' in content:
|
has_tool_markers = (
|
||||||
|
'<tool_call>' in content or
|
||||||
|
'```tool_call' in content or
|
||||||
|
'"tool":' in content or
|
||||||
|
'DSML' in content or # DeepSeek DSML 格式
|
||||||
|
'|DSML|' in content # 全角竖线版本
|
||||||
|
)
|
||||||
|
if has_tool_markers:
|
||||||
logger.info(f"[Agent Stream] 尝试从文本内容解析工具调用")
|
logger.info(f"[Agent Stream] 尝试从文本内容解析工具调用")
|
||||||
logger.info(f"[Agent Stream] 内容预览: {content[:500]}")
|
logger.info(f"[Agent Stream] 内容预览: {content[:500]}")
|
||||||
text_tool_calls = self._parse_text_tool_calls(content)
|
text_tool_calls = self._parse_text_tool_calls(content)
|
||||||
@@ -2947,6 +2989,7 @@ A股交易时间: 上午 9:30-11:30,下午 13:00-15:00
|
|||||||
支持的格式:
|
支持的格式:
|
||||||
1. <tool_call> <function=xxx> <parameter=yyy> value </parameter> </function> </tool_call>
|
1. <tool_call> <function=xxx> <parameter=yyy> value </parameter> </function> </tool_call>
|
||||||
2. ```tool_call\n{"name": "xxx", "arguments": {...}}\n```
|
2. ```tool_call\n{"name": "xxx", "arguments": {...}}\n```
|
||||||
|
3. DeepSeek DSML 格式: <|DSML|function_calls> <|DSML|invoke name="xxx"> <|DSML|parameter name="yyy" string="true">value</|DSML|parameter> </|DSML|invoke> </|DSML|function_calls>
|
||||||
|
|
||||||
返回: [{"name": "tool_name", "arguments": {...}}, ...]
|
返回: [{"name": "tool_name", "arguments": {...}}, ...]
|
||||||
"""
|
"""
|
||||||
@@ -3006,6 +3049,47 @@ A股交易时间: 上午 9:30-11:30,下午 13:00-15:00
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# 格式4: DeepSeek DSML 格式(使用全角竖线 |)
|
||||||
|
# <|DSML|function_calls> <|DSML|invoke name="search_research_reports"> <|DSML|parameter name="query" string="true">AI概念股</|DSML|parameter> </|DSML|invoke> </|DSML|function_calls>
|
||||||
|
# 注意:| 是全角字符
|
||||||
|
dsml_pattern = r'<[|\|]DSML[|\|]function_calls>(.*?)</[|\|]DSML[|\|]function_calls>'
|
||||||
|
dsml_matches = re.findall(dsml_pattern, content, re.DOTALL)
|
||||||
|
|
||||||
|
for dsml_content in dsml_matches:
|
||||||
|
# 解析 invoke 标签
|
||||||
|
invoke_pattern = r'<[|\|]DSML[|\|]invoke\s+name="(\w+)">(.*?)</[|\|]DSML[|\|]invoke>'
|
||||||
|
invoke_matches = re.findall(invoke_pattern, dsml_content, re.DOTALL)
|
||||||
|
|
||||||
|
for func_name, params_str in invoke_matches:
|
||||||
|
arguments = {}
|
||||||
|
# 解析参数: <|DSML|parameter name="xxx" string="true/false">value</|DSML|parameter>
|
||||||
|
param_pattern = r'<[|\|]DSML[|\|]parameter\s+name="(\w+)"\s+string="(true|false)">(.*?)</[|\|]DSML[|\|]parameter>'
|
||||||
|
param_matches = re.findall(param_pattern, params_str, re.DOTALL)
|
||||||
|
|
||||||
|
for param_name, is_string, param_value in param_matches:
|
||||||
|
param_value = param_value.strip()
|
||||||
|
if is_string == "false":
|
||||||
|
# 不是字符串,尝试解析为数字或 JSON
|
||||||
|
try:
|
||||||
|
arguments[param_name] = json.loads(param_value)
|
||||||
|
except:
|
||||||
|
# 尝试转为整数或浮点数
|
||||||
|
try:
|
||||||
|
arguments[param_name] = int(param_value)
|
||||||
|
except:
|
||||||
|
try:
|
||||||
|
arguments[param_name] = float(param_value)
|
||||||
|
except:
|
||||||
|
arguments[param_name] = param_value
|
||||||
|
else:
|
||||||
|
# 是字符串
|
||||||
|
arguments[param_name] = param_value
|
||||||
|
|
||||||
|
tool_calls.append({
|
||||||
|
"name": func_name,
|
||||||
|
"arguments": arguments
|
||||||
|
})
|
||||||
|
|
||||||
logger.info(f"[Text Tool Call] 解析到 {len(tool_calls)} 个工具调用: {tool_calls}")
|
logger.info(f"[Text Tool Call] 解析到 {len(tool_calls)} 个工具调用: {tool_calls}")
|
||||||
return tool_calls
|
return tool_calls
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user