diff --git a/kimi_integration.py b/kimi_integration.py new file mode 100644 index 00000000..3f74c99a --- /dev/null +++ b/kimi_integration.py @@ -0,0 +1,361 @@ +""" +Kimi API 集成示例 +演示如何将MCP工具与Kimi大模型结合使用 +""" + +from openai import OpenAI +import json +from typing import List, Dict, Any +from mcp_client_example import MCPClient + +# Kimi API配置 +KIMI_API_KEY = "sk-TzB4VYJfCoXGcGrGMiewukVRzjuDsbVCkaZXi2LvkS8s60E5" +KIMI_BASE_URL = "https://api.moonshot.cn/v1" +KIMI_MODEL = "kimi-k2-turbo-preview" + +# 初始化Kimi客户端 +kimi_client = OpenAI( + api_key=KIMI_API_KEY, + base_url=KIMI_BASE_URL, +) + +# 初始化MCP客户端 +mcp_client = MCPClient() + + +def convert_mcp_tools_to_kimi_format() -> tuple[List[Dict], Dict]: + """ + 将MCP工具转换为Kimi API的tools格式 + + Returns: + tools: Kimi格式的工具列表 + tool_map: 工具名称到执行函数的映射 + """ + # 获取所有MCP工具 + mcp_tools_response = mcp_client.list_tools() + mcp_tools = mcp_tools_response["tools"] + + # 转换为Kimi格式 + kimi_tools = [] + tool_map = {} + + for tool in mcp_tools: + # Kimi工具格式 + kimi_tool = { + "type": "function", + "function": { + "name": tool["name"], + "description": tool["description"], + "parameters": tool["parameters"] + } + } + kimi_tools.append(kimi_tool) + + # 创建工具执行函数 + tool_name = tool["name"] + tool_map[tool_name] = lambda args, name=tool_name: execute_mcp_tool(name, args) + + return kimi_tools, tool_map + + +def execute_mcp_tool(tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """ + 执行MCP工具 + + Args: + tool_name: 工具名称 + arguments: 工具参数 + + Returns: + 工具执行结果 + """ + print(f"[工具调用] {tool_name}") + print(f"[参数] {json.dumps(arguments, ensure_ascii=False, indent=2)}") + + result = mcp_client.call_tool(tool_name, arguments) + + print(f"[结果] 成功: {result.get('success', False)}") + + return result + + +def chat_with_kimi(user_message: str, verbose: bool = True) -> str: + """ + 与Kimi进行对话,支持工具调用 + + Args: + user_message: 用户消息 + verbose: 是否打印详细信息 + + Returns: + Kimi的回复 + """ + # 获取Kimi格式的工具 + tools, tool_map = convert_mcp_tools_to_kimi_format() + + if verbose: + print(f"\n{'='*60}") + print(f"加载了 {len(tools)} 个工具") + print(f"{'='*60}\n") + + # 初始化对话 + messages = [ + { + "role": "system", + "content": """你是一个专业的金融数据分析助手,由 Moonshot AI 提供支持。 +你可以使用各种工具来帮助用户查询和分析金融数据,包括: +- 新闻搜索(全球新闻、中国新闻、医疗新闻) +- 公司研究(路演信息、研究报告) +- 概念板块分析 +- 股票分析(涨停分析、财务数据、交易数据) +- 财务报表(资产负债表、现金流量表) + +请根据用户的问题,选择合适的工具来获取信息,并提供专业的分析和建议。""" + }, + { + "role": "user", + "content": user_message + } + ] + + if verbose: + print(f"[用户]: {user_message}\n") + + # 对话循环,处理工具调用 + finish_reason = None + iteration = 0 + max_iterations = 10 # 防止无限循环 + + while finish_reason is None or finish_reason == "tool_calls": + iteration += 1 + if iteration > max_iterations: + print("[警告] 达到最大迭代次数") + break + + if verbose and iteration > 1: + print(f"\n[轮次 {iteration}]") + + # 调用Kimi API + completion = kimi_client.chat.completions.create( + model=KIMI_MODEL, + messages=messages, + temperature=0.6, # Kimi推荐的temperature值 + tools=tools, + ) + + choice = completion.choices[0] + finish_reason = choice.finish_reason + + if verbose: + print(f"[Kimi] finish_reason: {finish_reason}") + + # 处理工具调用 + if finish_reason == "tool_calls": + # 将Kimi的消息添加到上下文 + messages.append(choice.message) + + # 执行每个工具调用 + for tool_call in choice.message.tool_calls: + tool_name = tool_call.function.name + tool_arguments = json.loads(tool_call.function.arguments) + + # 执行工具 + tool_result = tool_map[tool_name](tool_arguments) + + # 将工具结果添加到消息中 + messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "name": tool_name, + "content": json.dumps(tool_result, ensure_ascii=False), + }) + + if verbose: + print() # 空行分隔 + + # 返回最终回复 + final_response = choice.message.content + + if verbose: + print(f"\n[Kimi]: {final_response}\n") + print(f"{'='*60}") + + return final_response + + +def demo_simple_query(): + """演示1: 简单查询""" + print("\n" + "="*60) + print("演示1: 简单新闻查询") + print("="*60) + + response = chat_with_kimi("帮我查找关于人工智能的最新新闻") + return response + + +def demo_stock_analysis(): + """演示2: 股票分析""" + print("\n" + "="*60) + print("演示2: 股票财务分析") + print("="*60) + + response = chat_with_kimi("帮我分析贵州茅台(600519)的财务状况") + return response + + +def demo_concept_research(): + """演示3: 概念研究""" + print("\n" + "="*60) + print("演示3: 概念板块研究") + print("="*60) + + response = chat_with_kimi("查找新能源汽车相关的概念板块,并告诉我涨幅最高的是哪些") + return response + + +def demo_industry_comparison(): + """演示4: 行业对比""" + print("\n" + "="*60) + print("演示4: 行业内股票对比") + print("="*60) + + response = chat_with_kimi("帮我找出半导体行业的龙头股票,并对比它们的财务指标") + return response + + +def demo_comprehensive_analysis(): + """演示5: 综合分析""" + print("\n" + "="*60) + print("演示5: 综合分析") + print("="*60) + + response = chat_with_kimi(""" + 我想投资白酒行业,请帮我: + 1. 搜索白酒行业的主要上市公司 + 2. 对比贵州茅台和五粮液的财务数据 + 3. 查看最近的行业新闻 + 4. 给出投资建议 + """) + return response + + +def interactive_chat(): + """交互式对话""" + print("\n" + "="*60) + print("Kimi 金融助手 - 交互模式") + print("="*60) + print("提示:输入 'quit' 或 'exit' 退出") + print("="*60 + "\n") + + while True: + try: + user_input = input("你: ").strip() + + if not user_input: + continue + + if user_input.lower() in ['quit', 'exit', '退出']: + print("\n再见!") + break + + response = chat_with_kimi(user_input) + + except KeyboardInterrupt: + print("\n\n再见!") + break + except Exception as e: + print(f"\n[错误] {str(e)}\n") + + +def test_kimi_connection(): + """测试Kimi API连接""" + print("\n" + "="*60) + print("测试 Kimi API 连接") + print("="*60 + "\n") + + try: + # 简单的测试请求 + response = kimi_client.chat.completions.create( + model=KIMI_MODEL, + messages=[ + {"role": "user", "content": "你好,请介绍一下你自己"} + ], + temperature=0.6 + ) + + print("[✓] 连接成功!") + print(f"[✓] 模型: {KIMI_MODEL}") + print(f"[✓] 回复: {response.choices[0].message.content}\n") + + return True + except Exception as e: + print(f"[✗] 连接失败: {str(e)}\n") + return False + + +def show_available_tools(): + """显示所有可用工具""" + print("\n" + "="*60) + print("可用工具列表") + print("="*60 + "\n") + + tools, _ = convert_mcp_tools_to_kimi_format() + + for i, tool in enumerate(tools, 1): + func = tool["function"] + print(f"{i}. {func['name']}") + print(f" 描述: {func['description'][:80]}...") + print() + + print(f"总计: {len(tools)} 个工具\n") + + +if __name__ == "__main__": + import sys + + # 首先测试连接 + if not test_kimi_connection(): + print("请检查API Key和网络连接") + sys.exit(1) + + # 显示可用工具 + show_available_tools() + + # 运行演示 + print("\n选择运行模式:") + print("1. 简单查询演示") + print("2. 股票分析演示") + print("3. 概念研究演示") + print("4. 行业对比演示") + print("5. 综合分析演示") + print("6. 交互式对话") + print("7. 运行所有演示") + + try: + choice = input("\n请选择 (1-7): ").strip() + + if choice == "1": + demo_simple_query() + elif choice == "2": + demo_stock_analysis() + elif choice == "3": + demo_concept_research() + elif choice == "4": + demo_industry_comparison() + elif choice == "5": + demo_comprehensive_analysis() + elif choice == "6": + interactive_chat() + elif choice == "7": + demo_simple_query() + demo_stock_analysis() + demo_concept_research() + demo_industry_comparison() + demo_comprehensive_analysis() + else: + print("无效选择") + + except KeyboardInterrupt: + print("\n\n程序已退出") + finally: + mcp_client.close() diff --git a/mcp_client_example.py b/mcp_client_example.py new file mode 100644 index 00000000..19727435 --- /dev/null +++ b/mcp_client_example.py @@ -0,0 +1,248 @@ +""" +MCP客户端使用示例 +演示如何调用MCP服务器的各种工具 +""" + +import httpx +import json +from typing import Dict, Any + + +class MCPClient: + """MCP客户端""" + + def __init__(self, base_url: str = "http://localhost:8900"): + self.base_url = base_url + self.client = httpx.Client(timeout=60.0) + + def list_tools(self): + """列出所有可用工具""" + response = self.client.get(f"{self.base_url}/tools") + response.raise_for_status() + return response.json() + + def get_tool(self, tool_name: str): + """获取特定工具的定义""" + response = self.client.get(f"{self.base_url}/tools/{tool_name}") + response.raise_for_status() + return response.json() + + def call_tool(self, tool_name: str, arguments: Dict[str, Any]): + """调用工具""" + payload = { + "tool": tool_name, + "arguments": arguments + } + response = self.client.post(f"{self.base_url}/tools/call", json=payload) + response.raise_for_status() + return response.json() + + def close(self): + """关闭客户端""" + self.client.close() + + +def print_result(title: str, result: Dict[str, Any]): + """打印结果""" + print(f"\n{'=' * 60}") + print(f"{title}") + print(f"{'=' * 60}") + print(json.dumps(result, ensure_ascii=False, indent=2)) + + +def main(): + """主函数 - 演示各种工具的使用""" + + client = MCPClient() + + try: + # 1. 列出所有工具 + print("\n示例1: 列出所有可用工具") + tools = client.list_tools() + print(f"可用工具数量: {len(tools['tools'])}") + for tool in tools['tools']: + print(f" - {tool['name']}: {tool['description'][:50]}...") + + # 2. 搜索中国新闻 + print("\n示例2: 搜索中国新闻(关键词:人工智能)") + result = client.call_tool( + "search_china_news", + { + "query": "人工智能", + "top_k": 5 + } + ) + if result['success']: + print_result("中国新闻搜索结果", result['data']) + + # 3. 搜索概念板块(按涨跌幅排序) + print("\n示例3: 搜索概念板块(关键词:新能源,按涨跌幅排序)") + result = client.call_tool( + "search_concepts", + { + "query": "新能源", + "size": 5, + "sort_by": "change_pct" + } + ) + if result['success']: + print_result("概念搜索结果", result['data']) + + # 4. 获取股票的相关概念 + print("\n示例4: 获取股票相关概念(股票代码:600519)") + result = client.call_tool( + "get_stock_concepts", + { + "stock_code": "600519", + "size": 10 + } + ) + if result['success']: + print_result("股票概念结果", result['data']) + + # 5. 搜索涨停股票 + print("\n示例5: 搜索涨停股票(关键词:锂电池)") + result = client.call_tool( + "search_limit_up_stocks", + { + "query": "锂电池", + "mode": "hybrid", + "page_size": 5 + } + ) + if result['success']: + print_result("涨停股票搜索结果", result['data']) + + # 6. 搜索研究报告 + print("\n示例6: 搜索研究报告(关键词:投资策略)") + result = client.call_tool( + "search_research_reports", + { + "query": "投资策略", + "mode": "hybrid", + "size": 3 + } + ) + if result['success']: + print_result("研究报告搜索结果", result['data']) + + # 7. 获取概念统计数据 + print("\n示例7: 获取概念统计(最近7天)") + result = client.call_tool( + "get_concept_statistics", + { + "days": 7, + "min_stock_count": 3 + } + ) + if result['success']: + print_result("概念统计结果", result['data']) + + # 8. 搜索路演信息 + print("\n示例8: 搜索路演信息(关键词:业绩)") + result = client.call_tool( + "search_roadshows", + { + "query": "业绩", + "size": 3 + } + ) + if result['success']: + print_result("路演搜索结果", result['data']) + + # 9. 获取股票基本信息 + print("\n示例9: 获取股票基本信息(股票:600519)") + result = client.call_tool( + "get_stock_basic_info", + { + "seccode": "600519" + } + ) + if result['success']: + print_result("股票基本信息", result['data']) + + # 10. 获取股票财务指标 + print("\n示例10: 获取股票财务指标(股票:600519,最近5期)") + result = client.call_tool( + "get_stock_financial_index", + { + "seccode": "600519", + "limit": 5 + } + ) + if result['success']: + print_result("财务指标", result['data']) + + # 11. 获取股票交易数据 + print("\n示例11: 获取股票交易数据(股票:600519,最近10天)") + result = client.call_tool( + "get_stock_trade_data", + { + "seccode": "600519", + "limit": 10 + } + ) + if result['success']: + print_result("交易数据", result['data']) + + # 12. 按行业搜索股票 + print("\n示例12: 按行业搜索股票(行业:半导体)") + result = client.call_tool( + "search_stocks_by_criteria", + { + "industry": "半导体", + "limit": 10 + } + ) + if result['success']: + print_result("行业股票", result['data']) + + # 13. 股票对比分析 + print("\n示例13: 股票对比分析(600519 vs 000858)") + result = client.call_tool( + "get_stock_comparison", + { + "seccodes": ["600519", "000858"], + "metric": "financial" + } + ) + if result['success']: + print_result("股票对比", result['data']) + + except Exception as e: + print(f"\n错误: {str(e)}") + + finally: + client.close() + + +def test_single_tool(): + """测试单个工具(用于快速测试)""" + client = MCPClient() + + try: + # 修改这里来测试不同的工具 + result = client.call_tool( + "search_china_news", + { + "query": "芯片", + "exact_match": True, + "top_k": 3 + } + ) + + print_result("测试结果", result) + + except Exception as e: + print(f"错误: {str(e)}") + + finally: + client.close() + + +if __name__ == "__main__": + # 运行完整示例 + main() + + # 或者测试单个工具 + # test_single_tool() diff --git a/mcp_config.py b/mcp_config.py new file mode 100644 index 00000000..fd82ea95 --- /dev/null +++ b/mcp_config.py @@ -0,0 +1,108 @@ +""" +MCP服务器配置文件 +集中管理所有配置项 +""" + +from typing import Dict +from pydantic import BaseSettings + +class Settings(BaseSettings): + """应用配置""" + + # 服务器配置 + SERVER_HOST: str = "0.0.0.0" + SERVER_PORT: int = 8900 + DEBUG: bool = True + + # 后端API服务端点 + NEWS_API_URL: str = "http://222.128.1.157:21891" + ROADSHOW_API_URL: str = "http://222.128.1.157:19800" + CONCEPT_API_URL: str = "http://222.128.1.157:16801" + STOCK_ANALYSIS_API_URL: str = "http://222.128.1.157:8811" + + # HTTP客户端配置 + HTTP_TIMEOUT: float = 60.0 + HTTP_MAX_RETRIES: int = 3 + + # 日志配置 + LOG_LEVEL: str = "INFO" + LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + # CORS配置 + CORS_ORIGINS: list = ["*"] + CORS_CREDENTIALS: bool = True + CORS_METHODS: list = ["*"] + CORS_HEADERS: list = ["*"] + + # LLM配置(如果需要集成) + LLM_PROVIDER: str = "openai" # openai, anthropic, etc. + LLM_API_KEY: str = "" + LLM_MODEL: str = "gpt-4" + LLM_BASE_URL: str = "" + + # 速率限制 + RATE_LIMIT_ENABLED: bool = False + RATE_LIMIT_PER_MINUTE: int = 60 + + # 缓存配置 + CACHE_ENABLED: bool = True + CACHE_TTL: int = 300 # 秒 + + class Config: + env_file = ".env" + case_sensitive = True + + +# 全局设置实例 +settings = Settings() + + +# 工具类别映射(用于组织和展示) +TOOL_CATEGORIES: Dict[str, list] = { + "新闻搜索": [ + "search_news", + "search_china_news", + "search_medical_news" + ], + "公司研究": [ + "search_roadshows", + "search_research_reports" + ], + "概念板块": [ + "search_concepts", + "get_concept_details", + "get_stock_concepts", + "get_concept_statistics" + ], + "股票分析": [ + "search_limit_up_stocks", + "get_daily_stock_analysis" + ] +} + + +# 工具优先级(用于LLM选择工具时的提示) +TOOL_PRIORITIES: Dict[str, int] = { + "search_china_news": 10, # 最常用 + "search_concepts": 9, + "search_limit_up_stocks": 8, + "search_research_reports": 8, + "get_stock_concepts": 7, + "search_news": 6, + "get_daily_stock_analysis": 5, + "get_concept_statistics": 5, + "search_medical_news": 4, + "search_roadshows": 4, + "get_concept_details": 3, +} + + +# 默认参数配置 +DEFAULT_PARAMS = { + "top_k": 20, + "page_size": 20, + "size": 10, + "sort_by": "change_pct", + "mode": "hybrid", + "exact_match": False, +} diff --git a/mcp_database.py b/mcp_database.py new file mode 100644 index 00000000..ec2df704 --- /dev/null +++ b/mcp_database.py @@ -0,0 +1,546 @@ +""" +MySQL数据库查询模块 +提供股票财务数据查询功能 +""" + +import aiomysql +import logging +from typing import Dict, List, Any, Optional +from datetime import datetime, date +from decimal import Decimal +import json + +logger = logging.getLogger(__name__) + +# MySQL连接配置 +MYSQL_CONFIG = { + 'host': '222.128.1.157', + 'port': 33060, + 'user': 'root', + 'password': 'Zzl5588161!', + 'db': 'stock', + 'charset': 'utf8mb4', + 'autocommit': True +} + +# 全局连接池 +_pool = None + + +class DateTimeEncoder(json.JSONEncoder): + """JSON编码器,处理datetime和Decimal类型""" + def default(self, obj): + if isinstance(obj, (datetime, date)): + return obj.isoformat() + if isinstance(obj, Decimal): + return float(obj) + return super().default(obj) + + +async def get_pool(): + """获取MySQL连接池""" + global _pool + if _pool is None: + _pool = await aiomysql.create_pool( + host=MYSQL_CONFIG['host'], + port=MYSQL_CONFIG['port'], + user=MYSQL_CONFIG['user'], + password=MYSQL_CONFIG['password'], + db=MYSQL_CONFIG['db'], + charset=MYSQL_CONFIG['charset'], + autocommit=MYSQL_CONFIG['autocommit'], + minsize=1, + maxsize=10 + ) + logger.info("MySQL connection pool created") + return _pool + + +async def close_pool(): + """关闭MySQL连接池""" + global _pool + if _pool: + _pool.close() + await _pool.wait_closed() + _pool = None + logger.info("MySQL connection pool closed") + + +def convert_row(row: Dict) -> Dict: + """转换数据库行,处理特殊类型""" + if not row: + return {} + + result = {} + for key, value in row.items(): + if isinstance(value, Decimal): + result[key] = float(value) + elif isinstance(value, (datetime, date)): + result[key] = value.isoformat() + else: + result[key] = value + return result + + +async def get_stock_basic_info(seccode: str) -> Optional[Dict[str, Any]]: + """ + 获取股票基本信息 + + Args: + seccode: 股票代码 + + Returns: + 股票基本信息字典 + """ + pool = await get_pool() + + async with pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cursor: + query = """ + SELECT + SECCODE, SECNAME, ORGNAME, + F001V as english_name, + F003V as legal_representative, + F004V as registered_address, + F005V as office_address, + F010D as establishment_date, + F011V as website, + F012V as email, + F013V as phone, + F015V as main_business, + F016V as business_scope, + F017V as company_profile, + F030V as industry_level1, + F032V as industry_level2, + F034V as sw_industry_level1, + F036V as sw_industry_level2, + F026V as province, + F028V as city, + F041V as chairman, + F042V as general_manager, + UPDATE_DATE as update_date + FROM ea_baseinfo + WHERE SECCODE = %s + LIMIT 1 + """ + + await cursor.execute(query, (seccode,)) + result = await cursor.fetchone() + + if result: + return convert_row(result) + return None + + +async def get_stock_financial_index( + seccode: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + limit: int = 10 +) -> List[Dict[str, Any]]: + """ + 获取股票财务指标 + + Args: + seccode: 股票代码 + start_date: 开始日期 YYYY-MM-DD + end_date: 结束日期 YYYY-MM-DD + limit: 返回条数 + + Returns: + 财务指标列表 + """ + pool = await get_pool() + + async with pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cursor: + # 构建查询 + query = """ + SELECT + SECCODE, SECNAME, ENDDATE, STARTDATE, + F069D as report_year, + F003N as eps, -- 每股收益 + F004N as basic_eps, + F008N as bps, -- 每股净资产 + F014N as roe, -- 净资产收益率 + F016N as roa, -- 总资产报酬率 + F017N as net_profit_margin, -- 净利润率 + F022N as receivable_turnover, -- 应收账款周转率 + F023N as inventory_turnover, -- 存货周转率 + F025N as total_asset_turnover, -- 总资产周转率 + F041N as debt_ratio, -- 资产负债率 + F042N as current_ratio, -- 流动比率 + F043N as quick_ratio, -- 速动比率 + F052N as revenue_growth, -- 营业收入增长率 + F053N as profit_growth, -- 净利润增长率 + F089N as revenue, -- 营业收入 + F090N as operating_cost, -- 营业成本 + F101N as net_profit, -- 净利润 + F102N as net_profit_parent -- 归母净利润 + FROM ea_financialindex + WHERE SECCODE = %s + """ + + params = [seccode] + + if start_date: + query += " AND ENDDATE >= %s" + params.append(start_date) + + if end_date: + query += " AND ENDDATE <= %s" + params.append(end_date) + + query += " ORDER BY ENDDATE DESC LIMIT %s" + params.append(limit) + + await cursor.execute(query, params) + results = await cursor.fetchall() + + return [convert_row(row) for row in results] + + +async def get_stock_trade_data( + seccode: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + limit: int = 30 +) -> List[Dict[str, Any]]: + """ + 获取股票交易数据 + + Args: + seccode: 股票代码 + start_date: 开始日期 YYYY-MM-DD + end_date: 结束日期 YYYY-MM-DD + limit: 返回条数 + + Returns: + 交易数据列表 + """ + pool = await get_pool() + + async with pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cursor: + query = """ + SELECT + SECCODE, SECNAME, TRADEDATE, + F002N as prev_close, -- 昨日收盘价 + F003N as open_price, -- 开盘价 + F005N as high_price, -- 最高价 + F006N as low_price, -- 最低价 + F007N as close_price, -- 收盘价 + F004N as volume, -- 成交量 + F011N as turnover, -- 成交金额 + F009N as change_amount, -- 涨跌额 + F010N as change_pct, -- 涨跌幅 + F012N as turnover_rate, -- 换手率 + F013N as amplitude, -- 振幅 + F026N as pe_ratio, -- 市盈率 + F020N as total_shares, -- 总股本 + F021N as circulating_shares -- 流通股本 + FROM ea_trade + WHERE SECCODE = %s + """ + + params = [seccode] + + if start_date: + query += " AND TRADEDATE >= %s" + params.append(start_date) + + if end_date: + query += " AND TRADEDATE <= %s" + params.append(end_date) + + query += " ORDER BY TRADEDATE DESC LIMIT %s" + params.append(limit) + + await cursor.execute(query, params) + results = await cursor.fetchall() + + return [convert_row(row) for row in results] + + +async def get_stock_balance_sheet( + seccode: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + limit: int = 8 +) -> List[Dict[str, Any]]: + """ + 获取资产负债表数据 + + Args: + seccode: 股票代码 + start_date: 开始日期 + end_date: 结束日期 + limit: 返回条数 + + Returns: + 资产负债表数据列表 + """ + pool = await get_pool() + + async with pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cursor: + query = """ + SELECT + SECCODE, SECNAME, ENDDATE, + F001D as report_year, + F006N as cash, -- 货币资金 + F009N as receivables, -- 应收账款 + F015N as inventory, -- 存货 + F019N as current_assets, -- 流动资产合计 + F023N as long_term_investment, -- 长期股权投资 + F025N as fixed_assets, -- 固定资产 + F037N as noncurrent_assets, -- 非流动资产合计 + F038N as total_assets, -- 资产总计 + F039N as short_term_loan, -- 短期借款 + F042N as payables, -- 应付账款 + F052N as current_liabilities, -- 流动负债合计 + F053N as long_term_loan, -- 长期借款 + F060N as noncurrent_liabilities, -- 非流动负债合计 + F061N as total_liabilities, -- 负债合计 + F062N as share_capital, -- 股本 + F063N as capital_reserve, -- 资本公积 + F065N as retained_earnings, -- 未分配利润 + F070N as total_equity -- 所有者权益合计 + FROM ea_asset + WHERE SECCODE = %s + """ + + params = [seccode] + + if start_date: + query += " AND ENDDATE >= %s" + params.append(start_date) + + if end_date: + query += " AND ENDDATE <= %s" + params.append(end_date) + + query += " ORDER BY ENDDATE DESC LIMIT %s" + params.append(limit) + + await cursor.execute(query, params) + results = await cursor.fetchall() + + return [convert_row(row) for row in results] + + +async def get_stock_cashflow( + seccode: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + limit: int = 8 +) -> List[Dict[str, Any]]: + """ + 获取现金流量表数据 + + Args: + seccode: 股票代码 + start_date: 开始日期 + end_date: 结束日期 + limit: 返回条数 + + Returns: + 现金流量表数据列表 + """ + pool = await get_pool() + + async with pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cursor: + query = """ + SELECT + SECCODE, SECNAME, ENDDATE, STARTDATE, + F001D as report_year, + F009N as operating_cash_inflow, -- 经营活动现金流入 + F014N as operating_cash_outflow, -- 经营活动现金流出 + F015N as net_operating_cashflow, -- 经营活动现金流量净额 + F021N as investing_cash_inflow, -- 投资活动现金流入 + F026N as investing_cash_outflow, -- 投资活动现金流出 + F027N as net_investing_cashflow, -- 投资活动现金流量净额 + F031N as financing_cash_inflow, -- 筹资活动现金流入 + F035N as financing_cash_outflow, -- 筹资活动现金流出 + F036N as net_financing_cashflow, -- 筹资活动现金流量净额 + F039N as net_cash_increase, -- 现金及现金等价物净增加额 + F044N as net_profit, -- 净利润 + F046N as depreciation, -- 固定资产折旧 + F060N as net_operating_cashflow_adjusted -- 经营活动现金流量净额(补充) + FROM ea_cashflow + WHERE SECCODE = %s + """ + + params = [seccode] + + if start_date: + query += " AND ENDDATE >= %s" + params.append(start_date) + + if end_date: + query += " AND ENDDATE <= %s" + params.append(end_date) + + query += " ORDER BY ENDDATE DESC LIMIT %s" + params.append(limit) + + await cursor.execute(query, params) + results = await cursor.fetchall() + + return [convert_row(row) for row in results] + + +async def search_stocks_by_criteria( + industry: Optional[str] = None, + province: Optional[str] = None, + min_market_cap: Optional[float] = None, + max_market_cap: Optional[float] = None, + limit: int = 50 +) -> List[Dict[str, Any]]: + """ + 按条件搜索股票 + + Args: + industry: 行业名称 + province: 省份 + min_market_cap: 最小市值(亿元) + max_market_cap: 最大市值(亿元) + limit: 返回条数 + + Returns: + 股票列表 + """ + pool = await get_pool() + + async with pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cursor: + query = """ + SELECT DISTINCT + b.SECCODE, + b.SECNAME, + b.F030V as industry_level1, + b.F032V as industry_level2, + b.F034V as sw_industry_level1, + b.F026V as province, + b.F028V as city, + b.F015V as main_business, + t.F007N as latest_price, + t.F010N as change_pct, + t.F026N as pe_ratio, + t.TRADEDATE as latest_trade_date + FROM ea_baseinfo b + LEFT JOIN ( + SELECT SECCODE, MAX(TRADEDATE) as max_date + FROM ea_trade + GROUP BY SECCODE + ) latest ON b.SECCODE = latest.SECCODE + LEFT JOIN ea_trade t ON b.SECCODE = t.SECCODE + AND t.TRADEDATE = latest.max_date + WHERE 1=1 + """ + + params = [] + + if industry: + query += " AND (b.F030V LIKE %s OR b.F032V LIKE %s OR b.F034V LIKE %s)" + pattern = f"%{industry}%" + params.extend([pattern, pattern, pattern]) + + if province: + query += " AND b.F026V = %s" + params.append(province) + + if min_market_cap or max_market_cap: + # 市值 = 最新价 * 总股本 / 100000000(转换为亿元) + if min_market_cap: + query += " AND (t.F007N * t.F020N / 100000000) >= %s" + params.append(min_market_cap) + + if max_market_cap: + query += " AND (t.F007N * t.F020N / 100000000) <= %s" + params.append(max_market_cap) + + query += " ORDER BY t.TRADEDATE DESC LIMIT %s" + params.append(limit) + + await cursor.execute(query, params) + results = await cursor.fetchall() + + return [convert_row(row) for row in results] + + +async def get_stock_comparison( + seccodes: List[str], + metric: str = "financial" +) -> Dict[str, Any]: + """ + 股票对比分析 + + Args: + seccodes: 股票代码列表 + metric: 对比指标类型 (financial/trade) + + Returns: + 对比数据 + """ + pool = await get_pool() + + if not seccodes or len(seccodes) < 2: + return {"error": "至少需要2个股票代码进行对比"} + + async with pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cursor: + placeholders = ','.join(['%s'] * len(seccodes)) + + if metric == "financial": + # 对比最新财务指标 + query = f""" + SELECT + f.SECCODE, f.SECNAME, f.ENDDATE, + f.F003N as eps, + f.F008N as bps, + f.F014N as roe, + f.F017N as net_profit_margin, + f.F041N as debt_ratio, + f.F052N as revenue_growth, + f.F053N as profit_growth, + f.F089N as revenue, + f.F101N as net_profit + FROM ea_financialindex f + INNER JOIN ( + SELECT SECCODE, MAX(ENDDATE) as max_date + FROM ea_financialindex + WHERE SECCODE IN ({placeholders}) + GROUP BY SECCODE + ) latest ON f.SECCODE = latest.SECCODE + AND f.ENDDATE = latest.max_date + """ + else: # trade + # 对比最新交易数据 + query = f""" + SELECT + t.SECCODE, t.SECNAME, t.TRADEDATE, + t.F007N as close_price, + t.F010N as change_pct, + t.F012N as turnover_rate, + t.F026N as pe_ratio, + t.F020N as total_shares, + t.F021N as circulating_shares + FROM ea_trade t + INNER JOIN ( + SELECT SECCODE, MAX(TRADEDATE) as max_date + FROM ea_trade + WHERE SECCODE IN ({placeholders}) + GROUP BY SECCODE + ) latest ON t.SECCODE = latest.SECCODE + AND t.TRADEDATE = latest.max_date + """ + + await cursor.execute(query, seccodes) + results = await cursor.fetchall() + + return { + "comparison_type": metric, + "stocks": [convert_row(row) for row in results] + } diff --git a/mcp_server.py b/mcp_server.py new file mode 100644 index 00000000..54e67c73 --- /dev/null +++ b/mcp_server.py @@ -0,0 +1,1066 @@ +""" +MCP Server for Financial Data Search +基于FastAPI的MCP服务端,整合多个金融数据搜索API +支持LLM调用和Web聊天功能 +""" + +from fastapi import FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field +from typing import List, Dict, Any, Optional, Literal +from datetime import datetime, date +import logging +import httpx +from enum import Enum +import mcp_database as db + +# 配置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# 创建FastAPI应用 +app = FastAPI( + title="Financial Data MCP Server", + description="Model Context Protocol server for financial data search and analysis", + version="1.0.0" +) + +# 添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# ==================== 配置 ==================== + +class ServiceEndpoints: + """API服务端点配置""" + NEWS_API = "http://222.128.1.157:21891" # 新闻API + ROADSHOW_API = "http://222.128.1.157:19800" # 路演API + CONCEPT_API = "http://localhost:6801" # 概念API(本地) + STOCK_ANALYSIS_API = "http://222.128.1.157:8811" # 涨停分析+研报API + +# HTTP客户端配置 +HTTP_CLIENT = httpx.AsyncClient(timeout=60.0) + +# ==================== MCP协议数据模型 ==================== + +class ToolParameter(BaseModel): + """工具参数定义""" + type: str + description: str + enum: Optional[List[str]] = None + default: Optional[Any] = None + +class ToolDefinition(BaseModel): + """工具定义""" + name: str + description: str + parameters: Dict[str, Dict[str, Any]] + required: List[str] = [] + +class ToolCallRequest(BaseModel): + """工具调用请求""" + tool: str + arguments: Dict[str, Any] = {} + +class ToolCallResponse(BaseModel): + """工具调用响应""" + success: bool + data: Optional[Any] = None + error: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + +# ==================== MCP工具定义 ==================== + +TOOLS: List[ToolDefinition] = [ + ToolDefinition( + name="search_news", + description="搜索全球新闻,支持关键词搜索和日期过滤。适用于查找国际新闻、行业动态等。", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "搜索关键词,例如:'人工智能'、'新能源汽车'" + }, + "source": { + "type": "string", + "description": "新闻来源筛选,可选" + }, + "start_date": { + "type": "string", + "description": "开始日期,格式:YYYY-MM-DD" + }, + "end_date": { + "type": "string", + "description": "结束日期,格式:YYYY-MM-DD" + }, + "top_k": { + "type": "integer", + "description": "返回结果数量,默认20", + "default": 20 + } + }, + "required": ["query"] + } + ), + ToolDefinition( + name="search_china_news", + description="搜索中国新闻,使用KNN语义搜索。支持精确匹配模式,适合查找股票、公司相关新闻。", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "搜索关键词" + }, + "exact_match": { + "type": "boolean", + "description": "是否精确匹配(用于股票代码、公司名称等),默认false", + "default": False + }, + "source": { + "type": "string", + "description": "新闻来源筛选" + }, + "start_date": { + "type": "string", + "description": "开始日期,格式:YYYY-MM-DD" + }, + "end_date": { + "type": "string", + "description": "结束日期,格式:YYYY-MM-DD" + }, + "top_k": { + "type": "integer", + "description": "返回结果数量,默认20", + "default": 20 + } + }, + "required": ["query"] + } + ), + ToolDefinition( + name="search_medical_news", + description="搜索医疗健康类新闻,包括医药、医疗设备、生物技术等领域。", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "搜索关键词" + }, + "source": { + "type": "string", + "description": "新闻来源" + }, + "start_date": { + "type": "string", + "description": "开始日期,格式:YYYY-MM-DD" + }, + "end_date": { + "type": "string", + "description": "结束日期,格式:YYYY-MM-DD" + }, + "top_k": { + "type": "integer", + "description": "返回结果数量", + "default": 10 + } + }, + "required": ["query"] + } + ), + ToolDefinition( + name="search_roadshows", + description="搜索上市公司路演、投资者交流活动记录。可按公司代码、日期范围搜索。", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "搜索关键词,可以是公司名称、主题等" + }, + "company_code": { + "type": "string", + "description": "公司股票代码,例如:'600519.SH'" + }, + "start_date": { + "type": "string", + "description": "开始日期,格式:YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS" + }, + "end_date": { + "type": "string", + "description": "结束日期,格式:YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS" + }, + "size": { + "type": "integer", + "description": "返回结果数量", + "default": 10 + } + }, + "required": ["query"] + } + ), + ToolDefinition( + name="search_concepts", + description="搜索股票概念板块,支持按涨跌幅、股票数量排序。返回概念详情及相关股票列表。", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "搜索关键词,例如:'新能源'、'人工智能'" + }, + "size": { + "type": "integer", + "description": "每页结果数量", + "default": 10 + }, + "page": { + "type": "integer", + "description": "页码", + "default": 1 + }, + "sort_by": { + "type": "string", + "description": "排序方式:change_pct(涨跌幅), _score(相关度), stock_count(股票数), concept_name(名称)", + "enum": ["change_pct", "_score", "stock_count", "concept_name"], + "default": "change_pct" + }, + "trade_date": { + "type": "string", + "description": "交易日期,格式:YYYY-MM-DD,默认最新" + } + }, + "required": ["query"] + } + ), + ToolDefinition( + name="get_concept_details", + description="根据概念ID获取详细信息,包括描述、相关股票、涨跌幅数据等。", + parameters={ + "type": "object", + "properties": { + "concept_id": { + "type": "string", + "description": "概念ID" + }, + "trade_date": { + "type": "string", + "description": "交易日期,格式:YYYY-MM-DD" + } + }, + "required": ["concept_id"] + } + ), + ToolDefinition( + name="get_stock_concepts", + description="查询指定股票的所有相关概念板块,包括涨跌幅信息。", + parameters={ + "type": "object", + "properties": { + "stock_code": { + "type": "string", + "description": "股票代码或名称" + }, + "size": { + "type": "integer", + "description": "返回概念数量", + "default": 50 + }, + "sort_by": { + "type": "string", + "description": "排序方式", + "enum": ["stock_count", "concept_name", "recent"], + "default": "stock_count" + }, + "trade_date": { + "type": "string", + "description": "交易日期,格式:YYYY-MM-DD" + } + }, + "required": ["stock_code"] + } + ), + ToolDefinition( + name="get_concept_statistics", + description="获取概念板块统计数据,包括涨幅榜、跌幅榜、活跃榜、波动榜、连涨榜。", + parameters={ + "type": "object", + "properties": { + "days": { + "type": "integer", + "description": "统计天数(与start_date/end_date互斥)" + }, + "start_date": { + "type": "string", + "description": "开始日期,格式:YYYY-MM-DD" + }, + "end_date": { + "type": "string", + "description": "结束日期,格式:YYYY-MM-DD" + }, + "min_stock_count": { + "type": "integer", + "description": "最少股票数量过滤", + "default": 3 + } + }, + "required": [] + } + ), + ToolDefinition( + name="search_limit_up_stocks", + description="搜索涨停股票,支持按日期、关键词、板块等条件搜索。包括混合语义搜索。", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "搜索关键词(涨停原因、公司名称等)" + }, + "date": { + "type": "string", + "description": "日期,格式:YYYYMMDD" + }, + "mode": { + "type": "string", + "description": "搜索模式", + "enum": ["hybrid", "text", "vector"], + "default": "hybrid" + }, + "sectors": { + "type": "array", + "items": {"type": "string"}, + "description": "板块筛选" + }, + "page_size": { + "type": "integer", + "description": "每页结果数", + "default": 20 + } + }, + "required": ["query"] + } + ), + ToolDefinition( + name="get_daily_stock_analysis", + description="获取指定日期的涨停股票分析,包括板块分析、词云、趋势图表等。", + parameters={ + "type": "object", + "properties": { + "date": { + "type": "string", + "description": "日期,格式:YYYYMMDD" + } + }, + "required": ["date"] + } + ), + ToolDefinition( + name="search_research_reports", + description="搜索研究报告,支持文本和语义混合搜索。可按作者、证券、日期等筛选。", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "搜索关键词" + }, + "mode": { + "type": "string", + "description": "搜索模式", + "enum": ["hybrid", "text", "vector"], + "default": "hybrid" + }, + "exact_match": { + "type": "string", + "description": "是否精确匹配:0=模糊,1=精确", + "enum": ["0", "1"], + "default": "0" + }, + "security_code": { + "type": "string", + "description": "证券代码筛选" + }, + "start_date": { + "type": "string", + "description": "开始日期,格式:YYYY-MM-DD" + }, + "end_date": { + "type": "string", + "description": "结束日期,格式:YYYY-MM-DD" + }, + "size": { + "type": "integer", + "description": "返回结果数量", + "default": 10 + } + }, + "required": ["query"] + } + ), + ToolDefinition( + name="get_stock_basic_info", + description="获取股票基本信息,包括公司名称、行业、地址、主营业务、高管等基础数据。", + parameters={ + "type": "object", + "properties": { + "seccode": { + "type": "string", + "description": "股票代码,例如:600519" + } + }, + "required": ["seccode"] + } + ), + ToolDefinition( + name="get_stock_financial_index", + description="获取股票财务指标,包括每股收益、净资产收益率、营收增长率等关键财务数据。", + parameters={ + "type": "object", + "properties": { + "seccode": { + "type": "string", + "description": "股票代码" + }, + "start_date": { + "type": "string", + "description": "开始日期,格式:YYYY-MM-DD" + }, + "end_date": { + "type": "string", + "description": "结束日期,格式:YYYY-MM-DD" + }, + "limit": { + "type": "integer", + "description": "返回条数,默认10", + "default": 10 + } + }, + "required": ["seccode"] + } + ), + ToolDefinition( + name="get_stock_trade_data", + description="获取股票交易数据,包括价格、成交量、涨跌幅、换手率等日线行情数据。", + parameters={ + "type": "object", + "properties": { + "seccode": { + "type": "string", + "description": "股票代码" + }, + "start_date": { + "type": "string", + "description": "开始日期,格式:YYYY-MM-DD" + }, + "end_date": { + "type": "string", + "description": "结束日期,格式:YYYY-MM-DD" + }, + "limit": { + "type": "integer", + "description": "返回条数,默认30", + "default": 30 + } + }, + "required": ["seccode"] + } + ), + ToolDefinition( + name="get_stock_balance_sheet", + description="获取股票资产负债表,包括资产、负债、所有者权益等财务状况数据。", + parameters={ + "type": "object", + "properties": { + "seccode": { + "type": "string", + "description": "股票代码" + }, + "start_date": { + "type": "string", + "description": "开始日期,格式:YYYY-MM-DD" + }, + "end_date": { + "type": "string", + "description": "结束日期,格式:YYYY-MM-DD" + }, + "limit": { + "type": "integer", + "description": "返回条数,默认8", + "default": 8 + } + }, + "required": ["seccode"] + } + ), + ToolDefinition( + name="get_stock_cashflow", + description="获取股票现金流量表,包括经营、投资、筹资活动现金流数据。", + parameters={ + "type": "object", + "properties": { + "seccode": { + "type": "string", + "description": "股票代码" + }, + "start_date": { + "type": "string", + "description": "开始日期,格式:YYYY-MM-DD" + }, + "end_date": { + "type": "string", + "description": "结束日期,格式:YYYY-MM-DD" + }, + "limit": { + "type": "integer", + "description": "返回条数,默认8", + "default": 8 + } + }, + "required": ["seccode"] + } + ), + ToolDefinition( + name="search_stocks_by_criteria", + description="按条件搜索股票,支持按行业、地区、市值等条件筛选股票列表。", + parameters={ + "type": "object", + "properties": { + "industry": { + "type": "string", + "description": "行业名称,支持模糊匹配" + }, + "province": { + "type": "string", + "description": "省份名称" + }, + "min_market_cap": { + "type": "number", + "description": "最小市值(亿元)" + }, + "max_market_cap": { + "type": "number", + "description": "最大市值(亿元)" + }, + "limit": { + "type": "integer", + "description": "返回条数,默认50", + "default": 50 + } + }, + "required": [] + } + ), + ToolDefinition( + name="get_stock_comparison", + description="股票对比分析,支持多只股票的财务指标或交易数据对比。", + parameters={ + "type": "object", + "properties": { + "seccodes": { + "type": "array", + "items": {"type": "string"}, + "description": "股票代码列表,至少2个" + }, + "metric": { + "type": "string", + "description": "对比指标类型", + "enum": ["financial", "trade"], + "default": "financial" + } + }, + "required": ["seccodes"] + } + ), +] + +# ==================== MCP协议端点 ==================== + +@app.get("/") +async def root(): + """服务根端点""" + return { + "name": "Financial Data MCP Server", + "version": "1.0.0", + "protocol": "MCP", + "description": "Model Context Protocol server for financial data search and analysis" + } + +@app.get("/tools") +async def list_tools(): + """列出所有可用工具""" + return { + "tools": [tool.dict() for tool in TOOLS] + } + +@app.get("/tools/{tool_name}") +async def get_tool(tool_name: str): + """获取特定工具的定义""" + tool = next((t for t in TOOLS if t.name == tool_name), None) + if not tool: + raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found") + return tool.dict() + +@app.post("/tools/call") +async def call_tool(request: ToolCallRequest): + """调用工具""" + logger.info(f"Tool call: {request.tool} with args: {request.arguments}") + + try: + # 路由到对应的工具处理函数 + handler = TOOL_HANDLERS.get(request.tool) + if not handler: + raise HTTPException(status_code=404, detail=f"Tool '{request.tool}' not found") + + result = await handler(request.arguments) + + return ToolCallResponse( + success=True, + data=result, + metadata={ + "tool": request.tool, + "timestamp": datetime.now().isoformat() + } + ) + + except Exception as e: + logger.error(f"Tool call error: {str(e)}", exc_info=True) + return ToolCallResponse( + success=False, + error=str(e), + metadata={ + "tool": request.tool, + "timestamp": datetime.now().isoformat() + } + ) + +# ==================== 工具处理函数 ==================== + +async def handle_search_news(args: Dict[str, Any]) -> Any: + """处理新闻搜索""" + params = { + "query": args.get("query"), + "source": args.get("source"), + "start_date": args.get("start_date"), + "end_date": args.get("end_date"), + "top_k": args.get("top_k", 20) + } + # 移除None值 + params = {k: v for k, v in params.items() if v is not None} + + response = await HTTP_CLIENT.get(f"{ServiceEndpoints.NEWS_API}/search_news", params=params) + response.raise_for_status() + return response.json() + +async def handle_search_china_news(args: Dict[str, Any]) -> Any: + """处理中国新闻搜索""" + params = { + "query": args.get("query"), + "exact_match": args.get("exact_match", False), + "source": args.get("source"), + "start_date": args.get("start_date"), + "end_date": args.get("end_date"), + "top_k": args.get("top_k", 20) + } + params = {k: v for k, v in params.items() if v is not None} + + response = await HTTP_CLIENT.get(f"{ServiceEndpoints.NEWS_API}/search_china_news", params=params) + response.raise_for_status() + return response.json() + +async def handle_search_medical_news(args: Dict[str, Any]) -> Any: + """处理医疗新闻搜索""" + params = { + "query": args["query"], + "source": args.get("source"), + "start_date": args.get("start_date"), + "end_date": args.get("end_date"), + "top_k": args.get("top_k", 10) + } + params = {k: v for k, v in params.items() if v is not None} + + response = await HTTP_CLIENT.get(f"{ServiceEndpoints.NEWS_API}/search_medical_news", params=params) + response.raise_for_status() + return response.json() + +async def handle_search_roadshows(args: Dict[str, Any]) -> Any: + """处理路演搜索""" + params = { + "query": args["query"], + "company_code": args.get("company_code"), + "start_date": args.get("start_date"), + "end_date": args.get("end_date"), + "size": args.get("size", 10) + } + params = {k: v for k, v in params.items() if v is not None} + + response = await HTTP_CLIENT.get(f"{ServiceEndpoints.ROADSHOW_API}/search", params=params) + response.raise_for_status() + return response.json() + +async def handle_search_concepts(args: Dict[str, Any]) -> Any: + """处理概念搜索""" + payload = { + "query": args["query"], + "size": args.get("size", 10), + "page": args.get("page", 1), + "search_size": 100, + "sort_by": args.get("sort_by", "change_pct"), + "use_knn": True + } + if args.get("trade_date"): + payload["trade_date"] = args["trade_date"] + + response = await HTTP_CLIENT.post(f"{ServiceEndpoints.CONCEPT_API}/search", json=payload) + response.raise_for_status() + return response.json() + +async def handle_get_concept_details(args: Dict[str, Any]) -> Any: + """处理概念详情获取""" + concept_id = args["concept_id"] + params = {} + if args.get("trade_date"): + params["trade_date"] = args["trade_date"] + + response = await HTTP_CLIENT.get( + f"{ServiceEndpoints.CONCEPT_API}/concept/{concept_id}", + params=params + ) + response.raise_for_status() + return response.json() + +async def handle_get_stock_concepts(args: Dict[str, Any]) -> Any: + """处理股票概念获取""" + stock_code = args["stock_code"] + params = { + "size": args.get("size", 50), + "sort_by": args.get("sort_by", "stock_count"), + "include_description": True + } + if args.get("trade_date"): + params["trade_date"] = args["trade_date"] + + response = await HTTP_CLIENT.get( + f"{ServiceEndpoints.CONCEPT_API}/stock/{stock_code}/concepts", + params=params + ) + response.raise_for_status() + return response.json() + +async def handle_get_concept_statistics(args: Dict[str, Any]) -> Any: + """处理概念统计获取""" + params = {} + if args.get("days"): + params["days"] = args["days"] + if args.get("start_date"): + params["start_date"] = args["start_date"] + if args.get("end_date"): + params["end_date"] = args["end_date"] + if args.get("min_stock_count"): + params["min_stock_count"] = args["min_stock_count"] + + response = await HTTP_CLIENT.get(f"{ServiceEndpoints.CONCEPT_API}/statistics", params=params) + response.raise_for_status() + return response.json() + +async def handle_search_limit_up_stocks(args: Dict[str, Any]) -> Any: + """处理涨停股票搜索""" + payload = { + "query": args["query"], + "mode": args.get("mode", "hybrid"), + "page_size": args.get("page_size", 20) + } + if args.get("date"): + payload["date"] = args["date"] + if args.get("sectors"): + payload["sectors"] = args["sectors"] + + response = await HTTP_CLIENT.post( + f"{ServiceEndpoints.STOCK_ANALYSIS_API}/api/v1/stocks/search/hybrid", + json=payload + ) + response.raise_for_status() + return response.json() + +async def handle_get_daily_stock_analysis(args: Dict[str, Any]) -> Any: + """处理每日股票分析获取""" + date = args["date"] + response = await HTTP_CLIENT.get( + f"{ServiceEndpoints.STOCK_ANALYSIS_API}/api/v1/analysis/daily/{date}" + ) + response.raise_for_status() + return response.json() + +async def handle_search_research_reports(args: Dict[str, Any]) -> Any: + """处理研报搜索""" + params = { + "query": args["query"], + "mode": args.get("mode", "hybrid"), + "exact_match": args.get("exact_match", "0"), + "size": args.get("size", 10) + } + if args.get("security_code"): + params["security_code"] = args["security_code"] + if args.get("start_date"): + params["start_date"] = args["start_date"] + if args.get("end_date"): + params["end_date"] = args["end_date"] + + response = await HTTP_CLIENT.get(f"{ServiceEndpoints.STOCK_ANALYSIS_API}/search", params=params) + response.raise_for_status() + return response.json() + +async def handle_get_stock_basic_info(args: Dict[str, Any]) -> Any: + """处理股票基本信息查询""" + seccode = args["seccode"] + result = await db.get_stock_basic_info(seccode) + if result: + return {"success": True, "data": result} + else: + return {"success": False, "error": f"未找到股票代码 {seccode} 的信息"} + +async def handle_get_stock_financial_index(args: Dict[str, Any]) -> Any: + """处理股票财务指标查询""" + seccode = args["seccode"] + start_date = args.get("start_date") + end_date = args.get("end_date") + limit = args.get("limit", 10) + + result = await db.get_stock_financial_index(seccode, start_date, end_date, limit) + return { + "success": True, + "data": result, + "count": len(result) + } + +async def handle_get_stock_trade_data(args: Dict[str, Any]) -> Any: + """处理股票交易数据查询""" + seccode = args["seccode"] + start_date = args.get("start_date") + end_date = args.get("end_date") + limit = args.get("limit", 30) + + result = await db.get_stock_trade_data(seccode, start_date, end_date, limit) + return { + "success": True, + "data": result, + "count": len(result) + } + +async def handle_get_stock_balance_sheet(args: Dict[str, Any]) -> Any: + """处理资产负债表查询""" + seccode = args["seccode"] + start_date = args.get("start_date") + end_date = args.get("end_date") + limit = args.get("limit", 8) + + result = await db.get_stock_balance_sheet(seccode, start_date, end_date, limit) + return { + "success": True, + "data": result, + "count": len(result) + } + +async def handle_get_stock_cashflow(args: Dict[str, Any]) -> Any: + """处理现金流量表查询""" + seccode = args["seccode"] + start_date = args.get("start_date") + end_date = args.get("end_date") + limit = args.get("limit", 8) + + result = await db.get_stock_cashflow(seccode, start_date, end_date, limit) + return { + "success": True, + "data": result, + "count": len(result) + } + +async def handle_search_stocks_by_criteria(args: Dict[str, Any]) -> Any: + """处理按条件搜索股票""" + industry = args.get("industry") + province = args.get("province") + min_market_cap = args.get("min_market_cap") + max_market_cap = args.get("max_market_cap") + limit = args.get("limit", 50) + + result = await db.search_stocks_by_criteria( + industry, province, min_market_cap, max_market_cap, limit + ) + return { + "success": True, + "data": result, + "count": len(result) + } + +async def handle_get_stock_comparison(args: Dict[str, Any]) -> Any: + """处理股票对比分析""" + seccodes = args["seccodes"] + metric = args.get("metric", "financial") + + result = await db.get_stock_comparison(seccodes, metric) + return { + "success": True, + "data": result + } + +# 工具处理函数映射 +TOOL_HANDLERS = { + "search_news": handle_search_news, + "search_china_news": handle_search_china_news, + "search_medical_news": handle_search_medical_news, + "search_roadshows": handle_search_roadshows, + "search_concepts": handle_search_concepts, + "get_concept_details": handle_get_concept_details, + "get_stock_concepts": handle_get_stock_concepts, + "get_concept_statistics": handle_get_concept_statistics, + "search_limit_up_stocks": handle_search_limit_up_stocks, + "get_daily_stock_analysis": handle_get_daily_stock_analysis, + "search_research_reports": handle_search_research_reports, + "get_stock_basic_info": handle_get_stock_basic_info, + "get_stock_financial_index": handle_get_stock_financial_index, + "get_stock_trade_data": handle_get_stock_trade_data, + "get_stock_balance_sheet": handle_get_stock_balance_sheet, + "get_stock_cashflow": handle_get_stock_cashflow, + "search_stocks_by_criteria": handle_search_stocks_by_criteria, + "get_stock_comparison": handle_get_stock_comparison, +} + +# ==================== Web聊天接口 ==================== + +class ChatMessage(BaseModel): + """聊天消息""" + role: Literal["user", "assistant", "system"] + content: str + +class ChatRequest(BaseModel): + """聊天请求""" + messages: List[ChatMessage] + stream: bool = False + +@app.post("/chat") +async def chat(request: ChatRequest): + """ + Web聊天接口 + + 这是一个简化的接口,实际应该集成LLM API(如OpenAI、Claude等) + 这里只是演示如何使用工具 + """ + # TODO: 集成实际的LLM API + # 1. 将消息发送给LLM + # 2. LLM返回需要调用的工具 + # 3. 调用工具并获取结果 + # 4. 将工具结果返回给LLM + # 5. LLM生成最终回复 + + return { + "message": "Chat endpoint placeholder - integrate with your LLM provider", + "available_tools": len(TOOLS), + "hint": "Use POST /tools/call to invoke tools" + } + +# ==================== 健康检查 ==================== + +@app.get("/health") +async def health_check(): + """健康检查""" + # 检查各个后端服务的健康状态 + services_status = {} + + try: + response = await HTTP_CLIENT.get(f"{ServiceEndpoints.NEWS_API}/search_news?query=test&top_k=1", timeout=5.0) + services_status["news_api"] = "healthy" if response.status_code == 200 else "unhealthy" + except: + services_status["news_api"] = "unhealthy" + + try: + response = await HTTP_CLIENT.get(f"{ServiceEndpoints.CONCEPT_API}/", timeout=5.0) + services_status["concept_api"] = "healthy" if response.status_code == 200 else "unhealthy" + except: + services_status["concept_api"] = "unhealthy" + + try: + response = await HTTP_CLIENT.get(f"{ServiceEndpoints.STOCK_ANALYSIS_API}/api/v1/health", timeout=5.0) + services_status["stock_analysis_api"] = "healthy" if response.status_code == 200 else "unhealthy" + except: + services_status["stock_analysis_api"] = "unhealthy" + + return { + "status": "healthy", + "timestamp": datetime.now().isoformat(), + "services": services_status + } + +# ==================== 错误处理 ==================== + +@app.exception_handler(HTTPException) +async def http_exception_handler(request: Request, exc: HTTPException): + """HTTP异常处理""" + return JSONResponse( + status_code=exc.status_code, + content={ + "success": False, + "error": exc.detail, + "timestamp": datetime.now().isoformat() + } + ) + +@app.exception_handler(Exception) +async def general_exception_handler(request: Request, exc: Exception): + """通用异常处理""" + logger.error(f"Unexpected error: {str(exc)}", exc_info=True) + return JSONResponse( + status_code=500, + content={ + "success": False, + "error": "Internal server error", + "detail": str(exc), + "timestamp": datetime.now().isoformat() + } + ) + +# ==================== 应用启动/关闭 ==================== + +@app.on_event("startup") +async def startup_event(): + """应用启动""" + logger.info("MCP Server starting up...") + logger.info(f"Registered {len(TOOLS)} tools") + # 初始化数据库连接池 + try: + await db.get_pool() + logger.info("MySQL connection pool initialized") + except Exception as e: + logger.error(f"Failed to initialize MySQL pool: {e}") + +@app.on_event("shutdown") +async def shutdown_event(): + """应用关闭""" + logger.info("MCP Server shutting down...") + await HTTP_CLIENT.aclose() + # 关闭数据库连接池 + try: + await db.close_pool() + logger.info("MySQL connection pool closed") + except Exception as e: + logger.error(f"Failed to close MySQL pool: {e}") + +# ==================== 主程序 ==================== + +if __name__ == "__main__": + import uvicorn + + uvicorn.run( + "mcp_server:app", + host="0.0.0.0", + port=8900, + reload=True, + log_level="info" + )