agent功能开发增加MCP后端

This commit is contained in:
2025-11-07 21:46:50 +08:00
parent d1d8d1a25d
commit 3a058fd805
5 changed files with 2098 additions and 6 deletions

1398
CLAUDE.md Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -544,3 +544,240 @@ async def get_stock_comparison(
"comparison_type": metric,
"stocks": [convert_row(row) for row in results]
}
async def get_user_favorite_stocks(user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
"""
获取用户自选股列表
Args:
user_id: 用户ID
limit: 返回条数
Returns:
自选股列表(包含最新行情数据)
"""
pool = await get_pool()
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
# 查询用户自选股(假设有 user_favorites 表)
# 如果没有此表,可以根据实际情况调整
query = """
SELECT
f.user_id,
f.stock_code,
b.SECNAME as stock_name,
b.F030V as industry,
t.F007N as current_price,
t.F010N as change_pct,
t.F012N as turnover_rate,
t.F026N as pe_ratio,
t.TRADEDATE as latest_trade_date,
f.created_at as favorite_time
FROM user_favorites f
INNER JOIN ea_baseinfo b ON f.stock_code = b.SECCODE
LEFT JOIN (
SELECT SECCODE, MAX(TRADEDATE) as max_date
FROM ea_trade
GROUP BY SECCODE
) latest ON b.SECCODE = latest.SECCODE
LEFT JOIN ea_trade t ON b.SECCODE = t.SECCODE
AND t.TRADEDATE = latest.max_date
WHERE f.user_id = %s AND f.is_deleted = 0
ORDER BY f.created_at DESC
LIMIT %s
"""
await cursor.execute(query, [user_id, limit])
results = await cursor.fetchall()
return [convert_row(row) for row in results]
async def get_user_favorite_events(user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
"""
获取用户自选事件列表
Args:
user_id: 用户ID
limit: 返回条数
Returns:
自选事件列表
"""
pool = await get_pool()
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
# 查询用户自选事件(假设有 user_event_favorites 表)
query = """
SELECT
f.user_id,
f.event_id,
e.title,
e.description,
e.event_date,
e.importance,
e.related_stocks,
e.category,
f.created_at as favorite_time
FROM user_event_favorites f
INNER JOIN events e ON f.event_id = e.id
WHERE f.user_id = %s AND f.is_deleted = 0
ORDER BY e.event_date DESC
LIMIT %s
"""
await cursor.execute(query, [user_id, limit])
results = await cursor.fetchall()
return [convert_row(row) for row in results]
async def add_favorite_stock(user_id: str, stock_code: str) -> Dict[str, Any]:
"""
添加自选股
Args:
user_id: 用户ID
stock_code: 股票代码
Returns:
操作结果
"""
pool = await get_pool()
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
# 检查是否已存在
check_query = """
SELECT id, is_deleted
FROM user_favorites
WHERE user_id = %s AND stock_code = %s
"""
await cursor.execute(check_query, [user_id, stock_code])
existing = await cursor.fetchone()
if existing:
if existing['is_deleted'] == 1:
# 恢复已删除的记录
update_query = """
UPDATE user_favorites
SET is_deleted = 0, updated_at = NOW()
WHERE id = %s
"""
await cursor.execute(update_query, [existing['id']])
return {"success": True, "message": "已恢复自选股"}
else:
return {"success": False, "message": "该股票已在自选中"}
# 插入新记录
insert_query = """
INSERT INTO user_favorites (user_id, stock_code, created_at, updated_at, is_deleted)
VALUES (%s, %s, NOW(), NOW(), 0)
"""
await cursor.execute(insert_query, [user_id, stock_code])
return {"success": True, "message": "添加自选股成功"}
async def remove_favorite_stock(user_id: str, stock_code: str) -> Dict[str, Any]:
"""
删除自选股
Args:
user_id: 用户ID
stock_code: 股票代码
Returns:
操作结果
"""
pool = await get_pool()
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
query = """
UPDATE user_favorites
SET is_deleted = 1, updated_at = NOW()
WHERE user_id = %s AND stock_code = %s AND is_deleted = 0
"""
result = await cursor.execute(query, [user_id, stock_code])
if result > 0:
return {"success": True, "message": "删除自选股成功"}
else:
return {"success": False, "message": "未找到该自选股"}
async def add_favorite_event(user_id: str, event_id: int) -> Dict[str, Any]:
"""
添加自选事件
Args:
user_id: 用户ID
event_id: 事件ID
Returns:
操作结果
"""
pool = await get_pool()
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
# 检查是否已存在
check_query = """
SELECT id, is_deleted
FROM user_event_favorites
WHERE user_id = %s AND event_id = %s
"""
await cursor.execute(check_query, [user_id, event_id])
existing = await cursor.fetchone()
if existing:
if existing['is_deleted'] == 1:
# 恢复已删除的记录
update_query = """
UPDATE user_event_favorites
SET is_deleted = 0, updated_at = NOW()
WHERE id = %s
"""
await cursor.execute(update_query, [existing['id']])
return {"success": True, "message": "已恢复自选事件"}
else:
return {"success": False, "message": "该事件已在自选中"}
# 插入新记录
insert_query = """
INSERT INTO user_event_favorites (user_id, event_id, created_at, updated_at, is_deleted)
VALUES (%s, %s, NOW(), NOW(), 0)
"""
await cursor.execute(insert_query, [user_id, event_id])
return {"success": True, "message": "添加自选事件成功"}
async def remove_favorite_event(user_id: str, event_id: int) -> Dict[str, Any]:
"""
删除自选事件
Args:
user_id: 用户ID
event_id: 事件ID
Returns:
操作结果
"""
pool = await get_pool()
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
query = """
UPDATE user_event_favorites
SET is_deleted = 1, updated_at = NOW()
WHERE user_id = %s AND event_id = %s AND is_deleted = 0
"""
result = await cursor.execute(query, [user_id, event_id])
if result > 0:
return {"success": True, "message": "删除自选事件成功"}
else:
return {"success": False, "message": "未找到该自选事件"}

320
mcp_elasticsearch.py Normal file
View File

@@ -0,0 +1,320 @@
"""
Elasticsearch 连接和工具模块
用于聊天记录存储和向量搜索
"""
from elasticsearch import Elasticsearch, helpers
from datetime import datetime
from typing import List, Dict, Any, Optional
import logging
import json
import openai
logger = logging.getLogger(__name__)
# ==================== 配置 ====================
# ES 配置
ES_CONFIG = {
"host": "http://222.128.1.157:19200",
"index_chat_history": "agent_chat_history", # 聊天记录索引
}
# Embedding 配置
EMBEDDING_CONFIG = {
"api_key": "dummy",
"base_url": "http://222.128.1.157:18008/v1",
"model": "qwen3-embedding-8b",
"dims": 4096, # 向量维度
}
# ==================== ES 客户端 ====================
class ESClient:
"""Elasticsearch 客户端封装"""
def __init__(self):
self.es = Elasticsearch([ES_CONFIG["host"]], request_timeout=60)
self.chat_index = ES_CONFIG["index_chat_history"]
# 初始化 OpenAI 客户端用于 embedding
self.embedding_client = openai.OpenAI(
api_key=EMBEDDING_CONFIG["api_key"],
base_url=EMBEDDING_CONFIG["base_url"],
)
self.embedding_model = EMBEDDING_CONFIG["model"]
# 初始化索引
self.create_chat_history_index()
def create_chat_history_index(self):
"""创建聊天记录索引"""
if self.es.indices.exists(index=self.chat_index):
logger.info(f"索引 {self.chat_index} 已存在")
return
mappings = {
"properties": {
"session_id": {"type": "keyword"}, # 会话ID
"user_id": {"type": "keyword"}, # 用户ID
"user_nickname": {"type": "text"}, # 用户昵称
"user_avatar": {"type": "keyword"}, # 用户头像URL
"message_type": {"type": "keyword"}, # user / assistant
"message": {"type": "text"}, # 消息内容
"message_embedding": { # 消息向量
"type": "dense_vector",
"dims": EMBEDDING_CONFIG["dims"],
"index": True,
"similarity": "cosine"
},
"plan": {"type": "text"}, # 执行计划(仅 assistant
"steps": {"type": "text"}, # 执行步骤(仅 assistant
"timestamp": {"type": "date"}, # 时间戳
"created_at": {"type": "date"}, # 创建时间
}
}
self.es.indices.create(index=self.chat_index, body={"mappings": mappings})
logger.info(f"创建索引: {self.chat_index}")
def generate_embedding(self, text: str) -> List[float]:
"""生成文本向量"""
try:
if not text or len(text.strip()) == 0:
return []
# 截断过长文本
text = text[:16000] if len(text) > 16000 else text
response = self.embedding_client.embeddings.create(
model=self.embedding_model,
input=[text]
)
return response.data[0].embedding
except Exception as e:
logger.error(f"Embedding 生成失败: {e}")
return []
def save_chat_message(
self,
session_id: str,
user_id: str,
user_nickname: str,
user_avatar: str,
message_type: str, # "user" or "assistant"
message: str,
plan: Optional[str] = None,
steps: Optional[str] = None,
) -> str:
"""
保存聊天消息
Args:
session_id: 会话ID
user_id: 用户ID
user_nickname: 用户昵称
user_avatar: 用户头像URL
message_type: 消息类型 (user/assistant)
message: 消息内容
plan: 执行计划(可选)
steps: 执行步骤(可选)
Returns:
文档ID
"""
try:
# 生成向量
embedding = self.generate_embedding(message)
doc = {
"session_id": session_id,
"user_id": user_id,
"user_nickname": user_nickname,
"user_avatar": user_avatar,
"message_type": message_type,
"message": message,
"message_embedding": embedding if embedding else None,
"plan": plan,
"steps": steps,
"timestamp": datetime.now(),
"created_at": datetime.now(),
}
result = self.es.index(index=self.chat_index, body=doc)
logger.info(f"保存聊天记录: {result['_id']}")
return result["_id"]
except Exception as e:
logger.error(f"保存聊天记录失败: {e}")
raise
def get_chat_sessions(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
"""
获取用户的聊天会话列表
Args:
user_id: 用户ID
limit: 返回数量
Returns:
会话列表每个会话包含session_id, last_message, last_timestamp
"""
try:
# 聚合查询:按 session_id 分组,获取每个会话的最后一条消息
query = {
"query": {
"term": {"user_id": user_id}
},
"aggs": {
"sessions": {
"terms": {
"field": "session_id",
"size": limit,
"order": {"last_message": "desc"}
},
"aggs": {
"last_message": {
"max": {"field": "timestamp"}
},
"last_message_content": {
"top_hits": {
"size": 1,
"sort": [{"timestamp": {"order": "desc"}}],
"_source": ["message", "timestamp", "message_type"]
}
}
}
}
},
"size": 0
}
result = self.es.search(index=self.chat_index, body=query)
sessions = []
for bucket in result["aggregations"]["sessions"]["buckets"]:
session_data = bucket["last_message_content"]["hits"]["hits"][0]["_source"]
sessions.append({
"session_id": bucket["key"],
"last_message": session_data["message"],
"last_timestamp": session_data["timestamp"],
"message_count": bucket["doc_count"],
})
return sessions
except Exception as e:
logger.error(f"获取会话列表失败: {e}")
return []
def get_chat_history(
self,
session_id: str,
limit: int = 100
) -> List[Dict[str, Any]]:
"""
获取指定会话的聊天历史
Args:
session_id: 会话ID
limit: 返回数量
Returns:
聊天记录列表
"""
try:
query = {
"query": {
"term": {"session_id": session_id}
},
"sort": [{"timestamp": {"order": "asc"}}],
"size": limit
}
result = self.es.search(index=self.chat_index, body=query)
messages = []
for hit in result["hits"]["hits"]:
doc = hit["_source"]
messages.append({
"message_type": doc["message_type"],
"message": doc["message"],
"plan": doc.get("plan"),
"steps": doc.get("steps"),
"timestamp": doc["timestamp"],
})
return messages
except Exception as e:
logger.error(f"获取聊天历史失败: {e}")
return []
def search_chat_history(
self,
user_id: str,
query_text: str,
top_k: int = 10
) -> List[Dict[str, Any]]:
"""
向量搜索聊天历史
Args:
user_id: 用户ID
query_text: 查询文本
top_k: 返回数量
Returns:
相关聊天记录列表
"""
try:
# 生成查询向量
query_embedding = self.generate_embedding(query_text)
if not query_embedding:
return []
# 向量搜索
query = {
"query": {
"bool": {
"must": [
{"term": {"user_id": user_id}},
{
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'message_embedding') + 1.0",
"params": {"query_vector": query_embedding}
}
}
}
]
}
},
"size": top_k
}
result = self.es.search(index=self.chat_index, body=query)
messages = []
for hit in result["hits"]["hits"]:
doc = hit["_source"]
messages.append({
"session_id": doc["session_id"],
"message_type": doc["message_type"],
"message": doc["message"],
"timestamp": doc["timestamp"],
"score": hit["_score"],
})
return messages
except Exception as e:
logger.error(f"向量搜索失败: {e}")
return []
# ==================== 全局实例 ====================
# 创建全局 ES 客户端
es_client = ESClient()

View File

@@ -17,6 +17,8 @@ import mcp_database as db
from openai import OpenAI
import json
import asyncio
import uuid
from mcp_elasticsearch import es_client
# 配置日志
logging.basicConfig(level=logging.INFO)
@@ -135,6 +137,10 @@ class AgentChatRequest(BaseModel):
"""聊天请求"""
message: str
conversation_history: List[ConversationMessage] = []
user_id: Optional[str] = None # 用户ID
user_nickname: Optional[str] = None # 用户昵称
user_avatar: Optional[str] = None # 用户头像URL
session_id: Optional[str] = None # 会话ID如果为空则创建新会话
# ==================== MCP工具定义 ====================
@@ -1023,7 +1029,14 @@ class MCPAgentIntegrated:
for tool in tools
])
return f"""你是一个专业的金融研究助手。根据用户问题,制定详细的执行计划
return f"""你是"价小前"北京价值前沿科技公司的AI投研聊天助手
## 你的人格特征
- **名字**: 价小前
- **身份**: 北京价值前沿科技公司的专业AI投研助手
- **专业领域**: 股票投资研究、市场分析、新闻解读、财务分析
- **性格**: 专业、严谨、友好,擅长用简洁的语言解释复杂的金融概念
- **服务宗旨**: 帮助投资者做出更明智的投资决策,提供数据驱动的研究支持
## 可用工具
@@ -1040,7 +1053,7 @@ class MCPAgentIntegrated:
- 概念板块: 相同题材股票分类
## 任务
分析用户问题,制定执行计划。返回 JSON
分析用户问题,制定详细的执行计划。返回 JSON
```json
{{
@@ -1533,7 +1546,32 @@ async def chat(request: ChatRequest):
@app.post("/agent/chat", response_model=AgentResponse)
async def agent_chat(request: AgentChatRequest):
"""智能代理对话端点(非流式)"""
logger.info(f"Agent chat: {request.message}")
logger.info(f"Agent chat: {request.message} (user: {request.user_id})")
# ==================== 权限检查 ====================
# 仅允许 max 用户使用
if request.user_id != "max":
raise HTTPException(
status_code=403,
detail="很抱歉,「价小前投研」功能目前仅对特定用户开放。如需使用,请联系管理员。"
)
# ==================== 会话管理 ====================
# 如果没有提供 session_id创建新会话
session_id = request.session_id or str(uuid.uuid4())
# 保存用户消息到 ES
try:
es_client.save_chat_message(
session_id=session_id,
user_id=request.user_id or "anonymous",
user_nickname=request.user_nickname or "匿名用户",
user_avatar=request.user_avatar or "",
message_type="user",
message=request.message,
)
except Exception as e:
logger.error(f"保存用户消息失败: {e}")
# 获取工具列表
tools = [tool.dict() for tool in TOOLS]
@@ -1565,7 +1603,31 @@ async def agent_chat(request: AgentChatRequest):
tool_handlers=TOOL_HANDLERS,
)
return response
# 保存 Agent 回复到 ES
try:
# 将执行步骤转换为JSON字符串
steps_json = json.dumps(
[{"tool": step.tool, "result": step.result} for step in response.steps],
ensure_ascii=False
)
es_client.save_chat_message(
session_id=session_id,
user_id=request.user_id or "anonymous",
user_nickname=request.user_nickname or "匿名用户",
user_avatar=request.user_avatar or "",
message_type="assistant",
message=response.final_answer,
plan=response.plan,
steps=steps_json,
)
except Exception as e:
logger.error(f"保存 Agent 回复失败: {e}")
# 在响应中返回 session_id
response_dict = response.dict()
response_dict["session_id"] = session_id
return response_dict
@app.post("/agent/chat/stream")
async def agent_chat_stream(request: AgentChatRequest):
@@ -1610,6 +1672,81 @@ async def agent_chat_stream(request: AgentChatRequest):
},
)
# ==================== 聊天记录管理 API ====================
@app.get("/agent/sessions")
async def get_chat_sessions(user_id: str, limit: int = 50):
"""
获取用户的聊天会话列表
Args:
user_id: 用户ID
limit: 返回数量默认50
Returns:
会话列表
"""
try:
sessions = es_client.get_chat_sessions(user_id, limit)
return {
"success": True,
"data": sessions,
"count": len(sessions)
}
except Exception as e:
logger.error(f"获取会话列表失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/agent/history/{session_id}")
async def get_chat_history(session_id: str, limit: int = 100):
"""
获取指定会话的聊天历史
Args:
session_id: 会话ID
limit: 返回数量默认100
Returns:
聊天记录列表
"""
try:
messages = es_client.get_chat_history(session_id, limit)
return {
"success": True,
"data": messages,
"count": len(messages)
}
except Exception as e:
logger.error(f"获取聊天历史失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/agent/search")
async def search_chat_history(user_id: str, query: str, top_k: int = 10):
"""
向量搜索聊天历史
Args:
user_id: 用户ID
query: 查询文本
top_k: 返回数量默认10
Returns:
相关聊天记录列表
"""
try:
results = es_client.search_chat_history(user_id, query, top_k)
return {
"success": True,
"data": results,
"count": len(results)
}
except Exception as e:
logger.error(f"向量搜索失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
# ==================== 健康检查 ====================
@app.get("/health")

View File

@@ -157,8 +157,8 @@ export const routeConfig = [
protection: PROTECTION_MODES.MODAL,
layout: 'main',
meta: {
title: 'AI投资助手',
description: '基于MCP的智能投资顾问'
title: '价小前投研',
description: '北京价值前沿科技公司的AI投研聊天助手'
}
},
];