agent功能开发增加MCP后端
This commit is contained in:
320
mcp_elasticsearch.py
Normal file
320
mcp_elasticsearch.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
Elasticsearch 连接和工具模块
|
||||
用于聊天记录存储和向量搜索
|
||||
"""
|
||||
|
||||
from elasticsearch import Elasticsearch, helpers
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
import logging
|
||||
import json
|
||||
import openai
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
# ES 配置
|
||||
ES_CONFIG = {
|
||||
"host": "http://222.128.1.157:19200",
|
||||
"index_chat_history": "agent_chat_history", # 聊天记录索引
|
||||
}
|
||||
|
||||
# Embedding 配置
|
||||
EMBEDDING_CONFIG = {
|
||||
"api_key": "dummy",
|
||||
"base_url": "http://222.128.1.157:18008/v1",
|
||||
"model": "qwen3-embedding-8b",
|
||||
"dims": 4096, # 向量维度
|
||||
}
|
||||
|
||||
# ==================== ES 客户端 ====================
|
||||
|
||||
class ESClient:
|
||||
"""Elasticsearch 客户端封装"""
|
||||
|
||||
def __init__(self):
|
||||
self.es = Elasticsearch([ES_CONFIG["host"]], request_timeout=60)
|
||||
self.chat_index = ES_CONFIG["index_chat_history"]
|
||||
|
||||
# 初始化 OpenAI 客户端用于 embedding
|
||||
self.embedding_client = openai.OpenAI(
|
||||
api_key=EMBEDDING_CONFIG["api_key"],
|
||||
base_url=EMBEDDING_CONFIG["base_url"],
|
||||
)
|
||||
self.embedding_model = EMBEDDING_CONFIG["model"]
|
||||
|
||||
# 初始化索引
|
||||
self.create_chat_history_index()
|
||||
|
||||
def create_chat_history_index(self):
|
||||
"""创建聊天记录索引"""
|
||||
if self.es.indices.exists(index=self.chat_index):
|
||||
logger.info(f"索引 {self.chat_index} 已存在")
|
||||
return
|
||||
|
||||
mappings = {
|
||||
"properties": {
|
||||
"session_id": {"type": "keyword"}, # 会话ID
|
||||
"user_id": {"type": "keyword"}, # 用户ID
|
||||
"user_nickname": {"type": "text"}, # 用户昵称
|
||||
"user_avatar": {"type": "keyword"}, # 用户头像URL
|
||||
"message_type": {"type": "keyword"}, # user / assistant
|
||||
"message": {"type": "text"}, # 消息内容
|
||||
"message_embedding": { # 消息向量
|
||||
"type": "dense_vector",
|
||||
"dims": EMBEDDING_CONFIG["dims"],
|
||||
"index": True,
|
||||
"similarity": "cosine"
|
||||
},
|
||||
"plan": {"type": "text"}, # 执行计划(仅 assistant)
|
||||
"steps": {"type": "text"}, # 执行步骤(仅 assistant)
|
||||
"timestamp": {"type": "date"}, # 时间戳
|
||||
"created_at": {"type": "date"}, # 创建时间
|
||||
}
|
||||
}
|
||||
|
||||
self.es.indices.create(index=self.chat_index, body={"mappings": mappings})
|
||||
logger.info(f"创建索引: {self.chat_index}")
|
||||
|
||||
def generate_embedding(self, text: str) -> List[float]:
|
||||
"""生成文本向量"""
|
||||
try:
|
||||
if not text or len(text.strip()) == 0:
|
||||
return []
|
||||
|
||||
# 截断过长文本
|
||||
text = text[:16000] if len(text) > 16000 else text
|
||||
|
||||
response = self.embedding_client.embeddings.create(
|
||||
model=self.embedding_model,
|
||||
input=[text]
|
||||
)
|
||||
return response.data[0].embedding
|
||||
except Exception as e:
|
||||
logger.error(f"Embedding 生成失败: {e}")
|
||||
return []
|
||||
|
||||
def save_chat_message(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
user_nickname: str,
|
||||
user_avatar: str,
|
||||
message_type: str, # "user" or "assistant"
|
||||
message: str,
|
||||
plan: Optional[str] = None,
|
||||
steps: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
保存聊天消息
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
user_id: 用户ID
|
||||
user_nickname: 用户昵称
|
||||
user_avatar: 用户头像URL
|
||||
message_type: 消息类型 (user/assistant)
|
||||
message: 消息内容
|
||||
plan: 执行计划(可选)
|
||||
steps: 执行步骤(可选)
|
||||
|
||||
Returns:
|
||||
文档ID
|
||||
"""
|
||||
try:
|
||||
# 生成向量
|
||||
embedding = self.generate_embedding(message)
|
||||
|
||||
doc = {
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"user_nickname": user_nickname,
|
||||
"user_avatar": user_avatar,
|
||||
"message_type": message_type,
|
||||
"message": message,
|
||||
"message_embedding": embedding if embedding else None,
|
||||
"plan": plan,
|
||||
"steps": steps,
|
||||
"timestamp": datetime.now(),
|
||||
"created_at": datetime.now(),
|
||||
}
|
||||
|
||||
result = self.es.index(index=self.chat_index, body=doc)
|
||||
logger.info(f"保存聊天记录: {result['_id']}")
|
||||
return result["_id"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存聊天记录失败: {e}")
|
||||
raise
|
||||
|
||||
def get_chat_sessions(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户的聊天会话列表
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
limit: 返回数量
|
||||
|
||||
Returns:
|
||||
会话列表,每个会话包含:session_id, last_message, last_timestamp
|
||||
"""
|
||||
try:
|
||||
# 聚合查询:按 session_id 分组,获取每个会话的最后一条消息
|
||||
query = {
|
||||
"query": {
|
||||
"term": {"user_id": user_id}
|
||||
},
|
||||
"aggs": {
|
||||
"sessions": {
|
||||
"terms": {
|
||||
"field": "session_id",
|
||||
"size": limit,
|
||||
"order": {"last_message": "desc"}
|
||||
},
|
||||
"aggs": {
|
||||
"last_message": {
|
||||
"max": {"field": "timestamp"}
|
||||
},
|
||||
"last_message_content": {
|
||||
"top_hits": {
|
||||
"size": 1,
|
||||
"sort": [{"timestamp": {"order": "desc"}}],
|
||||
"_source": ["message", "timestamp", "message_type"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"size": 0
|
||||
}
|
||||
|
||||
result = self.es.search(index=self.chat_index, body=query)
|
||||
|
||||
sessions = []
|
||||
for bucket in result["aggregations"]["sessions"]["buckets"]:
|
||||
session_data = bucket["last_message_content"]["hits"]["hits"][0]["_source"]
|
||||
sessions.append({
|
||||
"session_id": bucket["key"],
|
||||
"last_message": session_data["message"],
|
||||
"last_timestamp": session_data["timestamp"],
|
||||
"message_count": bucket["doc_count"],
|
||||
})
|
||||
|
||||
return sessions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取会话列表失败: {e}")
|
||||
return []
|
||||
|
||||
def get_chat_history(
|
||||
self,
|
||||
session_id: str,
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定会话的聊天历史
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
limit: 返回数量
|
||||
|
||||
Returns:
|
||||
聊天记录列表
|
||||
"""
|
||||
try:
|
||||
query = {
|
||||
"query": {
|
||||
"term": {"session_id": session_id}
|
||||
},
|
||||
"sort": [{"timestamp": {"order": "asc"}}],
|
||||
"size": limit
|
||||
}
|
||||
|
||||
result = self.es.search(index=self.chat_index, body=query)
|
||||
|
||||
messages = []
|
||||
for hit in result["hits"]["hits"]:
|
||||
doc = hit["_source"]
|
||||
messages.append({
|
||||
"message_type": doc["message_type"],
|
||||
"message": doc["message"],
|
||||
"plan": doc.get("plan"),
|
||||
"steps": doc.get("steps"),
|
||||
"timestamp": doc["timestamp"],
|
||||
})
|
||||
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天历史失败: {e}")
|
||||
return []
|
||||
|
||||
def search_chat_history(
|
||||
self,
|
||||
user_id: str,
|
||||
query_text: str,
|
||||
top_k: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
向量搜索聊天历史
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
query_text: 查询文本
|
||||
top_k: 返回数量
|
||||
|
||||
Returns:
|
||||
相关聊天记录列表
|
||||
"""
|
||||
try:
|
||||
# 生成查询向量
|
||||
query_embedding = self.generate_embedding(query_text)
|
||||
if not query_embedding:
|
||||
return []
|
||||
|
||||
# 向量搜索
|
||||
query = {
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"user_id": user_id}},
|
||||
{
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.query_vector, 'message_embedding') + 1.0",
|
||||
"params": {"query_vector": query_embedding}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"size": top_k
|
||||
}
|
||||
|
||||
result = self.es.search(index=self.chat_index, body=query)
|
||||
|
||||
messages = []
|
||||
for hit in result["hits"]["hits"]:
|
||||
doc = hit["_source"]
|
||||
messages.append({
|
||||
"session_id": doc["session_id"],
|
||||
"message_type": doc["message_type"],
|
||||
"message": doc["message"],
|
||||
"timestamp": doc["timestamp"],
|
||||
"score": hit["_score"],
|
||||
})
|
||||
|
||||
return messages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"向量搜索失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# ==================== 全局实例 ====================
|
||||
|
||||
# 创建全局 ES 客户端
|
||||
es_client = ESClient()
|
||||
Reference in New Issue
Block a user