agent功能开发增加MCP后端
This commit is contained in:
237
mcp_database.py
237
mcp_database.py
@@ -544,3 +544,240 @@ async def get_stock_comparison(
|
|||||||
"comparison_type": metric,
|
"comparison_type": metric,
|
||||||
"stocks": [convert_row(row) for row in results]
|
"stocks": [convert_row(row) for row in results]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_favorite_stocks(user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取用户自选股列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
limit: 返回条数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
自选股列表(包含最新行情数据)
|
||||||
|
"""
|
||||||
|
pool = await get_pool()
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
async with conn.cursor(aiomysql.DictCursor) as cursor:
|
||||||
|
# 查询用户自选股(假设有 user_favorites 表)
|
||||||
|
# 如果没有此表,可以根据实际情况调整
|
||||||
|
query = """
|
||||||
|
SELECT
|
||||||
|
f.user_id,
|
||||||
|
f.stock_code,
|
||||||
|
b.SECNAME as stock_name,
|
||||||
|
b.F030V as industry,
|
||||||
|
t.F007N as current_price,
|
||||||
|
t.F010N as change_pct,
|
||||||
|
t.F012N as turnover_rate,
|
||||||
|
t.F026N as pe_ratio,
|
||||||
|
t.TRADEDATE as latest_trade_date,
|
||||||
|
f.created_at as favorite_time
|
||||||
|
FROM user_favorites f
|
||||||
|
INNER JOIN ea_baseinfo b ON f.stock_code = b.SECCODE
|
||||||
|
LEFT JOIN (
|
||||||
|
SELECT SECCODE, MAX(TRADEDATE) as max_date
|
||||||
|
FROM ea_trade
|
||||||
|
GROUP BY SECCODE
|
||||||
|
) latest ON b.SECCODE = latest.SECCODE
|
||||||
|
LEFT JOIN ea_trade t ON b.SECCODE = t.SECCODE
|
||||||
|
AND t.TRADEDATE = latest.max_date
|
||||||
|
WHERE f.user_id = %s AND f.is_deleted = 0
|
||||||
|
ORDER BY f.created_at DESC
|
||||||
|
LIMIT %s
|
||||||
|
"""
|
||||||
|
|
||||||
|
await cursor.execute(query, [user_id, limit])
|
||||||
|
results = await cursor.fetchall()
|
||||||
|
|
||||||
|
return [convert_row(row) for row in results]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_favorite_events(user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取用户自选事件列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
limit: 返回条数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
自选事件列表
|
||||||
|
"""
|
||||||
|
pool = await get_pool()
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
async with conn.cursor(aiomysql.DictCursor) as cursor:
|
||||||
|
# 查询用户自选事件(假设有 user_event_favorites 表)
|
||||||
|
query = """
|
||||||
|
SELECT
|
||||||
|
f.user_id,
|
||||||
|
f.event_id,
|
||||||
|
e.title,
|
||||||
|
e.description,
|
||||||
|
e.event_date,
|
||||||
|
e.importance,
|
||||||
|
e.related_stocks,
|
||||||
|
e.category,
|
||||||
|
f.created_at as favorite_time
|
||||||
|
FROM user_event_favorites f
|
||||||
|
INNER JOIN events e ON f.event_id = e.id
|
||||||
|
WHERE f.user_id = %s AND f.is_deleted = 0
|
||||||
|
ORDER BY e.event_date DESC
|
||||||
|
LIMIT %s
|
||||||
|
"""
|
||||||
|
|
||||||
|
await cursor.execute(query, [user_id, limit])
|
||||||
|
results = await cursor.fetchall()
|
||||||
|
|
||||||
|
return [convert_row(row) for row in results]
|
||||||
|
|
||||||
|
|
||||||
|
async def add_favorite_stock(user_id: str, stock_code: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
添加自选股
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
stock_code: 股票代码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
操作结果
|
||||||
|
"""
|
||||||
|
pool = await get_pool()
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
async with conn.cursor(aiomysql.DictCursor) as cursor:
|
||||||
|
# 检查是否已存在
|
||||||
|
check_query = """
|
||||||
|
SELECT id, is_deleted
|
||||||
|
FROM user_favorites
|
||||||
|
WHERE user_id = %s AND stock_code = %s
|
||||||
|
"""
|
||||||
|
await cursor.execute(check_query, [user_id, stock_code])
|
||||||
|
existing = await cursor.fetchone()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
if existing['is_deleted'] == 1:
|
||||||
|
# 恢复已删除的记录
|
||||||
|
update_query = """
|
||||||
|
UPDATE user_favorites
|
||||||
|
SET is_deleted = 0, updated_at = NOW()
|
||||||
|
WHERE id = %s
|
||||||
|
"""
|
||||||
|
await cursor.execute(update_query, [existing['id']])
|
||||||
|
return {"success": True, "message": "已恢复自选股"}
|
||||||
|
else:
|
||||||
|
return {"success": False, "message": "该股票已在自选中"}
|
||||||
|
|
||||||
|
# 插入新记录
|
||||||
|
insert_query = """
|
||||||
|
INSERT INTO user_favorites (user_id, stock_code, created_at, updated_at, is_deleted)
|
||||||
|
VALUES (%s, %s, NOW(), NOW(), 0)
|
||||||
|
"""
|
||||||
|
await cursor.execute(insert_query, [user_id, stock_code])
|
||||||
|
return {"success": True, "message": "添加自选股成功"}
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_favorite_stock(user_id: str, stock_code: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
删除自选股
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
stock_code: 股票代码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
操作结果
|
||||||
|
"""
|
||||||
|
pool = await get_pool()
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
async with conn.cursor(aiomysql.DictCursor) as cursor:
|
||||||
|
query = """
|
||||||
|
UPDATE user_favorites
|
||||||
|
SET is_deleted = 1, updated_at = NOW()
|
||||||
|
WHERE user_id = %s AND stock_code = %s AND is_deleted = 0
|
||||||
|
"""
|
||||||
|
result = await cursor.execute(query, [user_id, stock_code])
|
||||||
|
|
||||||
|
if result > 0:
|
||||||
|
return {"success": True, "message": "删除自选股成功"}
|
||||||
|
else:
|
||||||
|
return {"success": False, "message": "未找到该自选股"}
|
||||||
|
|
||||||
|
|
||||||
|
async def add_favorite_event(user_id: str, event_id: int) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
添加自选事件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
event_id: 事件ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
操作结果
|
||||||
|
"""
|
||||||
|
pool = await get_pool()
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
async with conn.cursor(aiomysql.DictCursor) as cursor:
|
||||||
|
# 检查是否已存在
|
||||||
|
check_query = """
|
||||||
|
SELECT id, is_deleted
|
||||||
|
FROM user_event_favorites
|
||||||
|
WHERE user_id = %s AND event_id = %s
|
||||||
|
"""
|
||||||
|
await cursor.execute(check_query, [user_id, event_id])
|
||||||
|
existing = await cursor.fetchone()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
if existing['is_deleted'] == 1:
|
||||||
|
# 恢复已删除的记录
|
||||||
|
update_query = """
|
||||||
|
UPDATE user_event_favorites
|
||||||
|
SET is_deleted = 0, updated_at = NOW()
|
||||||
|
WHERE id = %s
|
||||||
|
"""
|
||||||
|
await cursor.execute(update_query, [existing['id']])
|
||||||
|
return {"success": True, "message": "已恢复自选事件"}
|
||||||
|
else:
|
||||||
|
return {"success": False, "message": "该事件已在自选中"}
|
||||||
|
|
||||||
|
# 插入新记录
|
||||||
|
insert_query = """
|
||||||
|
INSERT INTO user_event_favorites (user_id, event_id, created_at, updated_at, is_deleted)
|
||||||
|
VALUES (%s, %s, NOW(), NOW(), 0)
|
||||||
|
"""
|
||||||
|
await cursor.execute(insert_query, [user_id, event_id])
|
||||||
|
return {"success": True, "message": "添加自选事件成功"}
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_favorite_event(user_id: str, event_id: int) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
删除自选事件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
event_id: 事件ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
操作结果
|
||||||
|
"""
|
||||||
|
pool = await get_pool()
|
||||||
|
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
async with conn.cursor(aiomysql.DictCursor) as cursor:
|
||||||
|
query = """
|
||||||
|
UPDATE user_event_favorites
|
||||||
|
SET is_deleted = 1, updated_at = NOW()
|
||||||
|
WHERE user_id = %s AND event_id = %s AND is_deleted = 0
|
||||||
|
"""
|
||||||
|
result = await cursor.execute(query, [user_id, event_id])
|
||||||
|
|
||||||
|
if result > 0:
|
||||||
|
return {"success": True, "message": "删除自选事件成功"}
|
||||||
|
else:
|
||||||
|
return {"success": False, "message": "未找到该自选事件"}
|
||||||
|
|||||||
320
mcp_elasticsearch.py
Normal file
320
mcp_elasticsearch.py
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
"""
|
||||||
|
Elasticsearch 连接和工具模块
|
||||||
|
用于聊天记录存储和向量搜索
|
||||||
|
"""
|
||||||
|
|
||||||
|
from elasticsearch import Elasticsearch, helpers
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
import openai
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ==================== 配置 ====================
|
||||||
|
|
||||||
|
# ES 配置
|
||||||
|
ES_CONFIG = {
|
||||||
|
"host": "http://222.128.1.157:19200",
|
||||||
|
"index_chat_history": "agent_chat_history", # 聊天记录索引
|
||||||
|
}
|
||||||
|
|
||||||
|
# Embedding 配置
|
||||||
|
EMBEDDING_CONFIG = {
|
||||||
|
"api_key": "dummy",
|
||||||
|
"base_url": "http://222.128.1.157:18008/v1",
|
||||||
|
"model": "qwen3-embedding-8b",
|
||||||
|
"dims": 4096, # 向量维度
|
||||||
|
}
|
||||||
|
|
||||||
|
# ==================== ES 客户端 ====================
|
||||||
|
|
||||||
|
class ESClient:
|
||||||
|
"""Elasticsearch 客户端封装"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.es = Elasticsearch([ES_CONFIG["host"]], request_timeout=60)
|
||||||
|
self.chat_index = ES_CONFIG["index_chat_history"]
|
||||||
|
|
||||||
|
# 初始化 OpenAI 客户端用于 embedding
|
||||||
|
self.embedding_client = openai.OpenAI(
|
||||||
|
api_key=EMBEDDING_CONFIG["api_key"],
|
||||||
|
base_url=EMBEDDING_CONFIG["base_url"],
|
||||||
|
)
|
||||||
|
self.embedding_model = EMBEDDING_CONFIG["model"]
|
||||||
|
|
||||||
|
# 初始化索引
|
||||||
|
self.create_chat_history_index()
|
||||||
|
|
||||||
|
def create_chat_history_index(self):
|
||||||
|
"""创建聊天记录索引"""
|
||||||
|
if self.es.indices.exists(index=self.chat_index):
|
||||||
|
logger.info(f"索引 {self.chat_index} 已存在")
|
||||||
|
return
|
||||||
|
|
||||||
|
mappings = {
|
||||||
|
"properties": {
|
||||||
|
"session_id": {"type": "keyword"}, # 会话ID
|
||||||
|
"user_id": {"type": "keyword"}, # 用户ID
|
||||||
|
"user_nickname": {"type": "text"}, # 用户昵称
|
||||||
|
"user_avatar": {"type": "keyword"}, # 用户头像URL
|
||||||
|
"message_type": {"type": "keyword"}, # user / assistant
|
||||||
|
"message": {"type": "text"}, # 消息内容
|
||||||
|
"message_embedding": { # 消息向量
|
||||||
|
"type": "dense_vector",
|
||||||
|
"dims": EMBEDDING_CONFIG["dims"],
|
||||||
|
"index": True,
|
||||||
|
"similarity": "cosine"
|
||||||
|
},
|
||||||
|
"plan": {"type": "text"}, # 执行计划(仅 assistant)
|
||||||
|
"steps": {"type": "text"}, # 执行步骤(仅 assistant)
|
||||||
|
"timestamp": {"type": "date"}, # 时间戳
|
||||||
|
"created_at": {"type": "date"}, # 创建时间
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.es.indices.create(index=self.chat_index, body={"mappings": mappings})
|
||||||
|
logger.info(f"创建索引: {self.chat_index}")
|
||||||
|
|
||||||
|
def generate_embedding(self, text: str) -> List[float]:
|
||||||
|
"""生成文本向量"""
|
||||||
|
try:
|
||||||
|
if not text or len(text.strip()) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 截断过长文本
|
||||||
|
text = text[:16000] if len(text) > 16000 else text
|
||||||
|
|
||||||
|
response = self.embedding_client.embeddings.create(
|
||||||
|
model=self.embedding_model,
|
||||||
|
input=[text]
|
||||||
|
)
|
||||||
|
return response.data[0].embedding
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Embedding 生成失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def save_chat_message(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
user_id: str,
|
||||||
|
user_nickname: str,
|
||||||
|
user_avatar: str,
|
||||||
|
message_type: str, # "user" or "assistant"
|
||||||
|
message: str,
|
||||||
|
plan: Optional[str] = None,
|
||||||
|
steps: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
保存聊天消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
user_id: 用户ID
|
||||||
|
user_nickname: 用户昵称
|
||||||
|
user_avatar: 用户头像URL
|
||||||
|
message_type: 消息类型 (user/assistant)
|
||||||
|
message: 消息内容
|
||||||
|
plan: 执行计划(可选)
|
||||||
|
steps: 执行步骤(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
文档ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 生成向量
|
||||||
|
embedding = self.generate_embedding(message)
|
||||||
|
|
||||||
|
doc = {
|
||||||
|
"session_id": session_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_nickname": user_nickname,
|
||||||
|
"user_avatar": user_avatar,
|
||||||
|
"message_type": message_type,
|
||||||
|
"message": message,
|
||||||
|
"message_embedding": embedding if embedding else None,
|
||||||
|
"plan": plan,
|
||||||
|
"steps": steps,
|
||||||
|
"timestamp": datetime.now(),
|
||||||
|
"created_at": datetime.now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
result = self.es.index(index=self.chat_index, body=doc)
|
||||||
|
logger.info(f"保存聊天记录: {result['_id']}")
|
||||||
|
return result["_id"]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存聊天记录失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_chat_sessions(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取用户的聊天会话列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
limit: 返回数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
会话列表,每个会话包含:session_id, last_message, last_timestamp
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 聚合查询:按 session_id 分组,获取每个会话的最后一条消息
|
||||||
|
query = {
|
||||||
|
"query": {
|
||||||
|
"term": {"user_id": user_id}
|
||||||
|
},
|
||||||
|
"aggs": {
|
||||||
|
"sessions": {
|
||||||
|
"terms": {
|
||||||
|
"field": "session_id",
|
||||||
|
"size": limit,
|
||||||
|
"order": {"last_message": "desc"}
|
||||||
|
},
|
||||||
|
"aggs": {
|
||||||
|
"last_message": {
|
||||||
|
"max": {"field": "timestamp"}
|
||||||
|
},
|
||||||
|
"last_message_content": {
|
||||||
|
"top_hits": {
|
||||||
|
"size": 1,
|
||||||
|
"sort": [{"timestamp": {"order": "desc"}}],
|
||||||
|
"_source": ["message", "timestamp", "message_type"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"size": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
result = self.es.search(index=self.chat_index, body=query)
|
||||||
|
|
||||||
|
sessions = []
|
||||||
|
for bucket in result["aggregations"]["sessions"]["buckets"]:
|
||||||
|
session_data = bucket["last_message_content"]["hits"]["hits"][0]["_source"]
|
||||||
|
sessions.append({
|
||||||
|
"session_id": bucket["key"],
|
||||||
|
"last_message": session_data["message"],
|
||||||
|
"last_timestamp": session_data["timestamp"],
|
||||||
|
"message_count": bucket["doc_count"],
|
||||||
|
})
|
||||||
|
|
||||||
|
return sessions
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取会话列表失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_chat_history(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
limit: int = 100
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取指定会话的聊天历史
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
limit: 返回数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
聊天记录列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query = {
|
||||||
|
"query": {
|
||||||
|
"term": {"session_id": session_id}
|
||||||
|
},
|
||||||
|
"sort": [{"timestamp": {"order": "asc"}}],
|
||||||
|
"size": limit
|
||||||
|
}
|
||||||
|
|
||||||
|
result = self.es.search(index=self.chat_index, body=query)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
for hit in result["hits"]["hits"]:
|
||||||
|
doc = hit["_source"]
|
||||||
|
messages.append({
|
||||||
|
"message_type": doc["message_type"],
|
||||||
|
"message": doc["message"],
|
||||||
|
"plan": doc.get("plan"),
|
||||||
|
"steps": doc.get("steps"),
|
||||||
|
"timestamp": doc["timestamp"],
|
||||||
|
})
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取聊天历史失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def search_chat_history(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
query_text: str,
|
||||||
|
top_k: int = 10
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
向量搜索聊天历史
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
query_text: 查询文本
|
||||||
|
top_k: 返回数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相关聊天记录列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 生成查询向量
|
||||||
|
query_embedding = self.generate_embedding(query_text)
|
||||||
|
if not query_embedding:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 向量搜索
|
||||||
|
query = {
|
||||||
|
"query": {
|
||||||
|
"bool": {
|
||||||
|
"must": [
|
||||||
|
{"term": {"user_id": user_id}},
|
||||||
|
{
|
||||||
|
"script_score": {
|
||||||
|
"query": {"match_all": {}},
|
||||||
|
"script": {
|
||||||
|
"source": "cosineSimilarity(params.query_vector, 'message_embedding') + 1.0",
|
||||||
|
"params": {"query_vector": query_embedding}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"size": top_k
|
||||||
|
}
|
||||||
|
|
||||||
|
result = self.es.search(index=self.chat_index, body=query)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
for hit in result["hits"]["hits"]:
|
||||||
|
doc = hit["_source"]
|
||||||
|
messages.append({
|
||||||
|
"session_id": doc["session_id"],
|
||||||
|
"message_type": doc["message_type"],
|
||||||
|
"message": doc["message"],
|
||||||
|
"timestamp": doc["timestamp"],
|
||||||
|
"score": hit["_score"],
|
||||||
|
})
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"向量搜索失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 全局实例 ====================
|
||||||
|
|
||||||
|
# 创建全局 ES 客户端
|
||||||
|
es_client = ESClient()
|
||||||
145
mcp_server.py
145
mcp_server.py
@@ -17,6 +17,8 @@ import mcp_database as db
|
|||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from mcp_elasticsearch import es_client
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -135,6 +137,10 @@ class AgentChatRequest(BaseModel):
|
|||||||
"""聊天请求"""
|
"""聊天请求"""
|
||||||
message: str
|
message: str
|
||||||
conversation_history: List[ConversationMessage] = []
|
conversation_history: List[ConversationMessage] = []
|
||||||
|
user_id: Optional[str] = None # 用户ID
|
||||||
|
user_nickname: Optional[str] = None # 用户昵称
|
||||||
|
user_avatar: Optional[str] = None # 用户头像URL
|
||||||
|
session_id: Optional[str] = None # 会话ID(如果为空则创建新会话)
|
||||||
|
|
||||||
# ==================== MCP工具定义 ====================
|
# ==================== MCP工具定义 ====================
|
||||||
|
|
||||||
@@ -1023,7 +1029,14 @@ class MCPAgentIntegrated:
|
|||||||
for tool in tools
|
for tool in tools
|
||||||
])
|
])
|
||||||
|
|
||||||
return f"""你是一个专业的金融研究助手。根据用户问题,制定详细的执行计划。
|
return f"""你是"价小前",北京价值前沿科技公司的AI投研聊天助手。
|
||||||
|
|
||||||
|
## 你的人格特征
|
||||||
|
- **名字**: 价小前
|
||||||
|
- **身份**: 北京价值前沿科技公司的专业AI投研助手
|
||||||
|
- **专业领域**: 股票投资研究、市场分析、新闻解读、财务分析
|
||||||
|
- **性格**: 专业、严谨、友好,擅长用简洁的语言解释复杂的金融概念
|
||||||
|
- **服务宗旨**: 帮助投资者做出更明智的投资决策,提供数据驱动的研究支持
|
||||||
|
|
||||||
## 可用工具
|
## 可用工具
|
||||||
|
|
||||||
@@ -1040,7 +1053,7 @@ class MCPAgentIntegrated:
|
|||||||
- 概念板块: 相同题材股票分类
|
- 概念板块: 相同题材股票分类
|
||||||
|
|
||||||
## 任务
|
## 任务
|
||||||
分析用户问题,制定执行计划。返回 JSON:
|
分析用户问题,制定详细的执行计划。返回 JSON:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{{
|
{{
|
||||||
@@ -1533,7 +1546,32 @@ async def chat(request: ChatRequest):
|
|||||||
@app.post("/agent/chat", response_model=AgentResponse)
|
@app.post("/agent/chat", response_model=AgentResponse)
|
||||||
async def agent_chat(request: AgentChatRequest):
|
async def agent_chat(request: AgentChatRequest):
|
||||||
"""智能代理对话端点(非流式)"""
|
"""智能代理对话端点(非流式)"""
|
||||||
logger.info(f"Agent chat: {request.message}")
|
logger.info(f"Agent chat: {request.message} (user: {request.user_id})")
|
||||||
|
|
||||||
|
# ==================== 权限检查 ====================
|
||||||
|
# 仅允许 max 用户使用
|
||||||
|
if request.user_id != "max":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="很抱歉,「价小前投研」功能目前仅对特定用户开放。如需使用,请联系管理员。"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== 会话管理 ====================
|
||||||
|
# 如果没有提供 session_id,创建新会话
|
||||||
|
session_id = request.session_id or str(uuid.uuid4())
|
||||||
|
|
||||||
|
# 保存用户消息到 ES
|
||||||
|
try:
|
||||||
|
es_client.save_chat_message(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=request.user_id or "anonymous",
|
||||||
|
user_nickname=request.user_nickname or "匿名用户",
|
||||||
|
user_avatar=request.user_avatar or "",
|
||||||
|
message_type="user",
|
||||||
|
message=request.message,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存用户消息失败: {e}")
|
||||||
|
|
||||||
# 获取工具列表
|
# 获取工具列表
|
||||||
tools = [tool.dict() for tool in TOOLS]
|
tools = [tool.dict() for tool in TOOLS]
|
||||||
@@ -1565,7 +1603,31 @@ async def agent_chat(request: AgentChatRequest):
|
|||||||
tool_handlers=TOOL_HANDLERS,
|
tool_handlers=TOOL_HANDLERS,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
# 保存 Agent 回复到 ES
|
||||||
|
try:
|
||||||
|
# 将执行步骤转换为JSON字符串
|
||||||
|
steps_json = json.dumps(
|
||||||
|
[{"tool": step.tool, "result": step.result} for step in response.steps],
|
||||||
|
ensure_ascii=False
|
||||||
|
)
|
||||||
|
|
||||||
|
es_client.save_chat_message(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=request.user_id or "anonymous",
|
||||||
|
user_nickname=request.user_nickname or "匿名用户",
|
||||||
|
user_avatar=request.user_avatar or "",
|
||||||
|
message_type="assistant",
|
||||||
|
message=response.final_answer,
|
||||||
|
plan=response.plan,
|
||||||
|
steps=steps_json,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存 Agent 回复失败: {e}")
|
||||||
|
|
||||||
|
# 在响应中返回 session_id
|
||||||
|
response_dict = response.dict()
|
||||||
|
response_dict["session_id"] = session_id
|
||||||
|
return response_dict
|
||||||
|
|
||||||
@app.post("/agent/chat/stream")
|
@app.post("/agent/chat/stream")
|
||||||
async def agent_chat_stream(request: AgentChatRequest):
|
async def agent_chat_stream(request: AgentChatRequest):
|
||||||
@@ -1610,6 +1672,81 @@ async def agent_chat_stream(request: AgentChatRequest):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ==================== 聊天记录管理 API ====================
|
||||||
|
|
||||||
|
@app.get("/agent/sessions")
|
||||||
|
async def get_chat_sessions(user_id: str, limit: int = 50):
|
||||||
|
"""
|
||||||
|
获取用户的聊天会话列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
limit: 返回数量(默认50)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
会话列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
sessions = es_client.get_chat_sessions(user_id, limit)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": sessions,
|
||||||
|
"count": len(sessions)
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取会话列表失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/agent/history/{session_id}")
|
||||||
|
async def get_chat_history(session_id: str, limit: int = 100):
|
||||||
|
"""
|
||||||
|
获取指定会话的聊天历史
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
limit: 返回数量(默认100)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
聊天记录列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
messages = es_client.get_chat_history(session_id, limit)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": messages,
|
||||||
|
"count": len(messages)
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取聊天历史失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/agent/search")
|
||||||
|
async def search_chat_history(user_id: str, query: str, top_k: int = 10):
|
||||||
|
"""
|
||||||
|
向量搜索聊天历史
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
query: 查询文本
|
||||||
|
top_k: 返回数量(默认10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相关聊天记录列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
results = es_client.search_chat_history(user_id, query, top_k)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": results,
|
||||||
|
"count": len(results)
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"向量搜索失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
# ==================== 健康检查 ====================
|
# ==================== 健康检查 ====================
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
|
|||||||
@@ -157,8 +157,8 @@ export const routeConfig = [
|
|||||||
protection: PROTECTION_MODES.MODAL,
|
protection: PROTECTION_MODES.MODAL,
|
||||||
layout: 'main',
|
layout: 'main',
|
||||||
meta: {
|
meta: {
|
||||||
title: 'AI投资助手',
|
title: '价小前投研',
|
||||||
description: '基于MCP的智能投资顾问'
|
description: '北京价值前沿科技公司的AI投研聊天助手'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|||||||
Reference in New Issue
Block a user