agent功能开发增加MCP后端

This commit is contained in:
2025-11-07 21:46:50 +08:00
parent d1d8d1a25d
commit 3a058fd805
5 changed files with 2098 additions and 6 deletions

320
mcp_elasticsearch.py Normal file
View File

@@ -0,0 +1,320 @@
"""
Elasticsearch connection and utility module.
Used for chat-history storage and vector search.
"""
from elasticsearch import Elasticsearch, helpers
from datetime import datetime
from typing import List, Dict, Any, Optional
import logging
import json
import openai
logger = logging.getLogger(__name__)
# ==================== 配置 ====================
# Elasticsearch settings.
ES_CONFIG = dict(
    host="http://222.128.1.157:19200",
    # Name of the index that stores chat-history documents.
    index_chat_history="agent_chat_history",
)
# Embedding service settings.
EMBEDDING_CONFIG = dict(
    api_key="dummy",
    base_url="http://222.128.1.157:18008/v1",
    model="qwen3-embedding-8b",
    # Dimensionality of the embedding vectors.
    dims=4096,
)
# ==================== ES 客户端 ====================
class ESClient:
    """Elasticsearch client wrapper for chat-history storage and vector search.

    Instantiating the class builds the Elasticsearch client from
    ``ES_CONFIG``, builds an OpenAI client from ``EMBEDDING_CONFIG`` for
    embedding generation, and creates the chat-history index when it does
    not exist yet.
    """

    # Maximum number of characters sent to the embedding model; longer text
    # is truncated before the request.
    MAX_EMBEDDING_CHARS = 16000

    def __init__(self):
        """Set up the ES and embedding clients and ensure the index exists."""
        self.es = Elasticsearch([ES_CONFIG["host"]], request_timeout=60)
        self.chat_index = ES_CONFIG["index_chat_history"]
        # OpenAI client, used here only to request embeddings.
        self.embedding_client = openai.OpenAI(
            api_key=EMBEDDING_CONFIG["api_key"],
            base_url=EMBEDDING_CONFIG["base_url"],
        )
        self.embedding_model = EMBEDDING_CONFIG["model"]
        # Make sure the index is there before any read or write.
        self.create_chat_history_index()

    def create_chat_history_index(self):
        """Create the chat-history index unless it already exists."""
        if self.es.indices.exists(index=self.chat_index):
            logger.info(f"索引 {self.chat_index} 已存在")
            return
        mappings = {
            "properties": {
                "session_id": {"type": "keyword"},     # conversation ID
                "user_id": {"type": "keyword"},        # user ID
                "user_nickname": {"type": "text"},     # user nickname
                "user_avatar": {"type": "keyword"},    # user avatar URL
                "message_type": {"type": "keyword"},   # "user" / "assistant"
                "message": {"type": "text"},           # message content
                "message_embedding": {                 # message vector
                    "type": "dense_vector",
                    "dims": EMBEDDING_CONFIG["dims"],
                    "index": True,
                    "similarity": "cosine",
                },
                "plan": {"type": "text"},              # execution plan (assistant only)
                "steps": {"type": "text"},             # execution steps (assistant only)
                "timestamp": {"type": "date"},         # message timestamp
                "created_at": {"type": "date"},        # creation time
            }
        }
        self.es.indices.create(index=self.chat_index, body={"mappings": mappings})
        logger.info(f"创建索引: {self.chat_index}")

    def generate_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for ``text``.

        Args:
            text: Text to embed. Only the first ``MAX_EMBEDDING_CHARS``
                characters are sent to the model.

        Returns:
            The embedding as a list of floats, or an empty list when the
            text is empty/blank or the embedding request fails (the failure
            is logged, not raised).
        """
        try:
            if not text or not text.strip():
                return []
            # Slicing is a no-op for short text, so no length check is needed.
            response = self.embedding_client.embeddings.create(
                model=self.embedding_model,
                input=[text[: self.MAX_EMBEDDING_CHARS]],
            )
            return response.data[0].embedding
        except Exception as e:
            # Best effort: callers treat an empty list as "no embedding".
            logger.error(f"Embedding 生成失败: {e}")
            return []

    def save_chat_message(
        self,
        session_id: str,
        user_id: str,
        user_nickname: str,
        user_avatar: str,
        message_type: str,  # "user" or "assistant"
        message: str,
        plan: Optional[str] = None,
        steps: Optional[str] = None,
    ) -> str:
        """Index one chat message together with its embedding.

        Args:
            session_id: Conversation ID.
            user_id: User ID.
            user_nickname: User nickname.
            user_avatar: User avatar URL.
            message_type: Message type ("user" or "assistant").
            message: Message content.
            plan: Execution plan (optional).
            steps: Execution steps (optional).

        Returns:
            The ID of the indexed document.

        Raises:
            Exception: Any error raised while indexing is logged and
                re-raised.
        """
        try:
            embedding = self.generate_embedding(message)
            # Take the time once so both date fields carry exactly the same
            # value (still a naive local datetime, as before).
            now = datetime.now()
            doc = {
                "session_id": session_id,
                "user_id": user_id,
                "user_nickname": user_nickname,
                "user_avatar": user_avatar,
                "message_type": message_type,
                "message": message,
                # Store null rather than an empty list when no embedding
                # could be generated.
                "message_embedding": embedding or None,
                "plan": plan,
                "steps": steps,
                "timestamp": now,
                "created_at": now,
            }
            result = self.es.index(index=self.chat_index, body=doc)
            logger.info(f"保存聊天记录: {result['_id']}")
            return result["_id"]
        except Exception as e:
            logger.error(f"保存聊天记录失败: {e}")
            raise

    def get_chat_sessions(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """List a user's chat sessions, most recently active first.

        Args:
            user_id: User ID.
            limit: Maximum number of sessions to return.

        Returns:
            A list of dicts with the keys ``session_id``, ``last_message``,
            ``last_timestamp`` and ``message_count``. An empty list is
            returned when the query fails (the failure is logged).
        """
        try:
            # Group by session_id, order the groups by their newest
            # timestamp, and fetch the newest message of each group.
            query = {
                "query": {
                    "term": {"user_id": user_id}
                },
                "aggs": {
                    "sessions": {
                        "terms": {
                            "field": "session_id",
                            "size": limit,
                            "order": {"last_message": "desc"},
                        },
                        "aggs": {
                            "last_message": {
                                "max": {"field": "timestamp"}
                            },
                            "last_message_content": {
                                "top_hits": {
                                    "size": 1,
                                    "sort": [{"timestamp": {"order": "desc"}}],
                                    "_source": ["message", "timestamp", "message_type"],
                                }
                            },
                        },
                    }
                },
                # Only the aggregation is needed, not the raw hits.
                "size": 0,
            }
            result = self.es.search(index=self.chat_index, body=query)
            sessions = []
            for bucket in result["aggregations"]["sessions"]["buckets"]:
                latest = bucket["last_message_content"]["hits"]["hits"][0]["_source"]
                sessions.append({
                    "session_id": bucket["key"],
                    "last_message": latest["message"],
                    "last_timestamp": latest["timestamp"],
                    "message_count": bucket["doc_count"],
                })
            return sessions
        except Exception as e:
            logger.error(f"获取会话列表失败: {e}")
            return []

    def get_chat_history(
        self,
        session_id: str,
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """Return the messages of one session in chronological order.

        Args:
            session_id: Conversation ID.
            limit: Maximum number of messages to return. The query sorts by
                ascending timestamp, so when a session holds more than
                ``limit`` messages the oldest ``limit`` are returned.

        Returns:
            A list of dicts with the keys ``message_type``, ``message``,
            ``plan``, ``steps`` and ``timestamp``. An empty list is returned
            when the query fails (the failure is logged).
        """
        try:
            query = {
                "query": {
                    "term": {"session_id": session_id}
                },
                "sort": [{"timestamp": {"order": "asc"}}],
                "size": limit,
            }
            result = self.es.search(index=self.chat_index, body=query)
            sources = (hit["_source"] for hit in result["hits"]["hits"])
            return [
                {
                    "message_type": doc["message_type"],
                    "message": doc["message"],
                    "plan": doc.get("plan"),
                    "steps": doc.get("steps"),
                    "timestamp": doc["timestamp"],
                }
                for doc in sources
            ]
        except Exception as e:
            logger.error(f"获取聊天历史失败: {e}")
            return []

    def search_chat_history(
        self,
        user_id: str,
        query_text: str,
        top_k: int = 10
    ) -> List[Dict[str, Any]]:
        """Find a user's messages most similar to ``query_text``.

        Args:
            user_id: User ID.
            query_text: Text to search for.
            top_k: Maximum number of messages to return.

        Returns:
            A list of dicts with the keys ``session_id``, ``message_type``,
            ``message``, ``timestamp`` and ``score``. An empty list is
            returned when no query embedding could be generated or when the
            search fails (the failure is logged).
        """
        try:
            query_embedding = self.generate_embedding(query_text)
            if not query_embedding:
                return []
            # Messages saved without an embedding have no vector value, and
            # Elasticsearch raises a script error when a vector function
            # runs on such a document, which would fail the whole search.
            # Guard with size() == 0 so those documents score 0 instead.
            similarity_script = (
                "doc['message_embedding'].size() == 0 ? 0 : "
                "cosineSimilarity(params.query_vector, 'message_embedding') + 1.0"
            )
            query = {
                "query": {
                    "bool": {
                        "must": [
                            {"term": {"user_id": user_id}},
                            {
                                "script_score": {
                                    "query": {"match_all": {}},
                                    "script": {
                                        "source": similarity_script,
                                        "params": {"query_vector": query_embedding},
                                    },
                                }
                            },
                        ]
                    }
                },
                "size": top_k,
            }
            result = self.es.search(index=self.chat_index, body=query)
            return [
                {
                    "session_id": hit["_source"]["session_id"],
                    "message_type": hit["_source"]["message_type"],
                    "message": hit["_source"]["message"],
                    "timestamp": hit["_source"]["timestamp"],
                    "score": hit["_score"],
                }
                for hit in result["hits"]["hits"]
            ]
        except Exception as e:
            logger.error(f"向量搜索失败: {e}")
            return []
# ==================== Global instance ====================
# Create the module-level ES client. ESClient.__init__ calls
# create_chat_history_index(), so this line queries Elasticsearch at import time.
es_client = ESClient()