From 2756e6e379ab54f1ecb21d6a195a2628b4863191 Mon Sep 17 00:00:00 2001 From: zzlgreat Date: Sat, 8 Nov 2025 11:32:01 +0800 Subject: [PATCH] =?UTF-8?q?agent=E5=8A=9F=E8=83=BD=E5=BC=80=E5=8F=91?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0MCP=E5=90=8E=E7=AB=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mcp_server.py | 153 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 137 insertions(+), 16 deletions(-) diff --git a/mcp_server.py b/mcp_server.py index 066aeb37..ca497619 100644 --- a/mcp_server.py +++ b/mcp_server.py @@ -48,6 +48,7 @@ class ServiceEndpoints: ROADSHOW_API = "http://222.128.1.157:19800" # 路演API CONCEPT_API = "http://222.128.1.157:16801" # 概念API(本地) STOCK_ANALYSIS_API = "http://222.128.1.157:8811" # 涨停分析+研报API + MAIN_APP_API = "http://49.232.185.254:5001" # 主应用API(自选股、自选事件等) # HTTP客户端配置 HTTP_CLIENT = httpx.AsyncClient(timeout=60.0) @@ -648,6 +649,34 @@ TOOLS: List[ToolDefinition] = [ "required": ["seccodes"] } ), + ToolDefinition( + name="get_user_watchlist", + description="获取用户的自选股列表及实时行情数据。返回用户关注的股票及其当前价格、涨跌幅等信息。", + parameters={ + "type": "object", + "properties": { + "user_id": { + "type": "string", + "description": "用户ID(可选,如果不提供则使用当前会话用户)" + } + }, + "required": [] + } + ), + ToolDefinition( + name="get_user_following_events", + description="获取用户关注的事件列表。返回用户关注的热点事件及其基本信息(标题、类型、热度、关注人数等)。", + parameters={ + "type": "object", + "properties": { + "user_id": { + "type": "string", + "description": "用户ID(可选,如果不提供则使用当前会话用户)" + } + }, + "required": [] + } + ), ] # ==================== MCP协议端点 ==================== @@ -979,6 +1008,88 @@ async def handle_get_stock_comparison(args: Dict[str, Any]) -> Any: "data": result } +async def handle_get_user_watchlist(args: Dict[str, Any]) -> Any: + """获取用户自选股列表及实时行情""" + try: + # 从 agent 实例获取 cookies(如果可用) + cookies = getattr(agent, 'cookies', {}) + + # 调用主应用的自选股API + response = await HTTP_CLIENT.get( + f"{ServiceEndpoints.MAIN_APP_API}/api/account/watchlist/realtime", + headers={ + "Content-Type": "application/json" + }, + cookies=cookies # 传递用户的 session cookie + ) + + if response.status_code == 200: + data = response.json() + logger.info(f"[Watchlist] 成功获取 {len(data.get('data', []))} 只自选股") + return data + elif response.status_code == 401: + logger.warning("[Watchlist] 未登录或会话已过期") + return { + "success": False, + "error": "未登录或会话已过期", + "data": [] + } + else: + logger.error(f"[Watchlist] 获取失败: {response.status_code}") + return { + "success": False, + "error": f"获取自选股失败: {response.status_code}", + "data": [] + } + except Exception as e: + logger.error(f"[Watchlist] 获取用户自选股失败: {e}", exc_info=True) + return { + "success": False, + "error": str(e), + "data": [] + } + +async def handle_get_user_following_events(args: Dict[str, Any]) -> Any: + """获取用户关注的事件列表""" + try: + # 从 agent 实例获取 cookies(如果可用) + cookies = getattr(agent, 'cookies', {}) + + # 调用主应用的关注事件API + response = await HTTP_CLIENT.get( + f"{ServiceEndpoints.MAIN_APP_API}/api/account/events/following", + headers={ + "Content-Type": "application/json" + }, + cookies=cookies # 传递用户的 session cookie + ) + + if response.status_code == 200: + data = response.json() + logger.info(f"[FollowingEvents] 成功获取 {len(data.get('data', []))} 个关注事件") + return data + elif response.status_code == 401: + logger.warning("[FollowingEvents] 未登录或会话已过期") + return { + "success": False, + "error": "未登录或会话已过期", + "data": [] + } + else: + logger.error(f"[FollowingEvents] 获取失败: {response.status_code}") + return { + "success": False, + "error": f"获取关注事件失败: {response.status_code}", + "data": [] + } + except Exception as e: + logger.error(f"[FollowingEvents] 获取用户关注事件失败: {e}", exc_info=True) + return { + "success": False, + "error": str(e), + "data": [] + } + # 工具处理函数映射 TOOL_HANDLERS = { "search_news": handle_search_news, @@ -999,6 +1110,8 @@ TOOL_HANDLERS = { "get_stock_cashflow": handle_get_stock_cashflow, "search_stocks_by_criteria": handle_search_stocks_by_criteria, "get_stock_comparison": handle_get_stock_comparison, + "get_user_watchlist": handle_get_user_watchlist, + "get_user_following_events": handle_get_user_following_events, } # ==================== Agent系统实现 ==================== @@ -1465,10 +1578,14 @@ class MCPAgentIntegrated: user_id: str = None, user_nickname: str = None, user_avatar: str = None, + cookies: dict = None, ) -> AsyncGenerator[str, None]: """主流程(流式输出)- 逐步返回执行结果""" logger.info(f"[Agent Stream] 处理查询: {user_query}") + # 将 cookies 存储为实例属性,供工具调用时使用 + self.cookies = cookies or {} + try: # 发送开始事件 yield self._format_sse("status", {"stage": "start", "message": "开始处理查询"}) @@ -1992,9 +2109,12 @@ async def agent_chat(request: AgentChatRequest): return response_dict @app.post("/agent/chat/stream") -async def agent_chat_stream(request: AgentChatRequest): +async def agent_chat_stream(chat_request: AgentChatRequest, request: Request): """智能代理对话端点(流式 SSE)""" - logger.info(f"Agent chat stream: {request.message}") + logger.info(f"Agent chat stream: {chat_request.message}") + + # 获取请求的 cookies(用于转发到需要认证的 API) + cookies = request.cookies # ==================== 权限检查 ==================== # 订阅等级判断函数(与 app.py 保持一致) @@ -2004,7 +2124,7 @@ async def agent_chat_stream(request: AgentChatRequest): return mapping.get((sub_type or 'free').lower(), 0) # 获取用户订阅类型(默认为 free) - user_subscription = (request.subscription_type or 'free').lower() + user_subscription = (chat_request.subscription_type or 'free').lower() required_level = 'max' # 权限检查:仅允许 max 用户访问(与传导链分析权限保持一致) @@ -2012,8 +2132,8 @@ async def agent_chat_stream(request: AgentChatRequest): if not has_access: logger.warning( - f"[Stream] 权限检查失败 - user_id: {request.user_id}, " - f"nickname: {request.user_nickname}, " + f"[Stream] 权限检查失败 - user_id: {chat_request.user_id}, " + f"nickname: {chat_request.user_nickname}, " f"subscription_type: {user_subscription}, " f"required: {required_level}" ) @@ -2023,23 +2143,23 @@ async def agent_chat_stream(request: AgentChatRequest): ) logger.info( - f"[Stream] 权限检查通过 - user_id: {request.user_id}, " - f"nickname: {request.user_nickname}, " + f"[Stream] 权限检查通过 - user_id: {chat_request.user_id}, " + f"nickname: {chat_request.user_nickname}, " f"subscription_type: {user_subscription}" ) # 如果没有提供 session_id,创建新会话 - session_id = request.session_id or str(uuid.uuid4()) + session_id = chat_request.session_id or str(uuid.uuid4()) # 保存用户消息到 ES try: es_client.save_chat_message( session_id=session_id, - user_id=request.user_id or "anonymous", - user_nickname=request.user_nickname or "匿名用户", - user_avatar=request.user_avatar or "", + user_id=chat_request.user_id or "anonymous", + user_nickname=chat_request.user_nickname or "匿名用户", + user_avatar=chat_request.user_avatar or "", message_type="user", - message=request.message, + message=chat_request.message, ) logger.info(f"[ES] 用户消息已保存到会话 {session_id}") except Exception as e: @@ -2071,13 +2191,14 @@ async def agent_chat_stream(request: AgentChatRequest): # 返回流式响应 return StreamingResponse( agent.process_query_stream( - user_query=request.message, + user_query=chat_request.message, tools=tools, tool_handlers=TOOL_HANDLERS, session_id=session_id, - user_id=request.user_id, - user_nickname=request.user_nickname, - user_avatar=request.user_avatar, + user_id=chat_request.user_id, + user_nickname=chat_request.user_nickname, + user_avatar=chat_request.user_avatar, + cookies=cookies, # 传递 cookies 用于认证 API 调用 ), media_type="text/event-stream", headers={