"""
Elasticsearch 连接和工具模块
用于聊天记录存储和向量搜索
"""
from elasticsearch import Elasticsearch, helpers
from datetime import datetime
from typing import List, Dict, Any, Optional
import logging
import json
import openai
logger = logging.getLogger(__name__)
# ==================== Configuration ====================

# Elasticsearch settings
ES_CONFIG = {
    "host": "http://222.128.1.157:19200",
    "index_chat_history": "agent_chat_history",  # chat-history index
}

# Embedding settings
EMBEDDING_CONFIG = {
    "api_key": "dummy",
    "base_url": "http://222.128.1.157:18008/v1",
    "model": "qwen3-embedding-8b",
    "dims": 4096,  # vector dimensionality
}
# ==================== ES client ====================

class ESClient:
    """Elasticsearch client wrapper."""

    def __init__(self):
        self.es = Elasticsearch([ES_CONFIG["host"]], request_timeout=60)
        self.chat_index = ES_CONFIG["index_chat_history"]
        # OpenAI-compatible client used to generate embeddings
        self.embedding_client = openai.OpenAI(
            api_key=EMBEDDING_CONFIG["api_key"],
            base_url=EMBEDDING_CONFIG["base_url"],
        )
        self.embedding_model = EMBEDDING_CONFIG["model"]
        # Make sure the index exists
        self.create_chat_history_index()
    def create_chat_history_index(self):
        """Create the chat-history index if it does not already exist."""
        if self.es.indices.exists(index=self.chat_index):
            logger.info(f"Index {self.chat_index} already exists")
            return
        mappings = {
            "properties": {
                "session_id": {"type": "keyword"},      # session ID
                "user_id": {"type": "keyword"},         # user ID
                "user_nickname": {"type": "text"},      # user nickname
                "user_avatar": {"type": "keyword"},     # user avatar URL
                "message_type": {"type": "keyword"},    # user / assistant
                "message": {"type": "text"},            # message content
                "message_embedding": {                  # message vector
                    "type": "dense_vector",
                    "dims": EMBEDDING_CONFIG["dims"],
                    "index": True,
                    "similarity": "cosine"
                },
                "plan": {"type": "text"},               # execution plan (assistant only)
                "steps": {"type": "text"},              # execution steps (assistant only)
                "timestamp": {"type": "date"},          # message timestamp
                "created_at": {"type": "date"},         # creation time
            }
        }
        self.es.indices.create(index=self.chat_index, body={"mappings": mappings})
        logger.info(f"Created index: {self.chat_index}")
    def generate_embedding(self, text: str) -> List[float]:
        """Generate an embedding vector for the given text."""
        try:
            if not text or len(text.strip()) == 0:
                return []
            # Truncate overly long text
            text = text[:16000] if len(text) > 16000 else text
            response = self.embedding_client.embeddings.create(
                model=self.embedding_model,
                input=[text]
            )
            return response.data[0].embedding
        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            return []
    def save_chat_message(
        self,
        session_id: str,
        user_id: str,
        user_nickname: str,
        user_avatar: str,
        message_type: str,  # "user" or "assistant"
        message: str,
        plan: Optional[str] = None,
        steps: Optional[str] = None,
    ) -> str:
        """
        Save a chat message.

        Args:
            session_id: session ID
            user_id: user ID
            user_nickname: user nickname
            user_avatar: user avatar URL
            message_type: message type (user/assistant)
            message: message content
            plan: execution plan (optional)
            steps: execution steps (optional)

        Returns:
            Document ID
        """
        try:
            # Generate the message embedding
            embedding = self.generate_embedding(message)
            doc = {
                "session_id": session_id,
                "user_id": user_id,
                "user_nickname": user_nickname,
                "user_avatar": user_avatar,
                "message_type": message_type,
                "message": message,
                "message_embedding": embedding if embedding else None,
                "plan": plan,
                "steps": steps,
                "timestamp": datetime.now(),
                "created_at": datetime.now(),
            }
            result = self.es.index(index=self.chat_index, body=doc)
            logger.info(f"Saved chat message: {result['_id']}")
            return result["_id"]
        except Exception as e:
            logger.error(f"Failed to save chat message: {e}")
            raise
    def get_chat_sessions(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """
        List a user's chat sessions.

        Args:
            user_id: user ID
            limit: maximum number of sessions to return

        Returns:
            List of sessions, each containing session_id, last_message,
            last_timestamp and message_count.
        """
        try:
            # Aggregation: group by session_id and fetch the latest message of each session
            query = {
                "query": {
                    "term": {"user_id": user_id}
                },
                "aggs": {
                    "sessions": {
                        "terms": {
                            "field": "session_id",
                            "size": limit,
                            "order": {"last_message": "desc"}
                        },
                        "aggs": {
                            "last_message": {
                                "max": {"field": "timestamp"}
                            },
                            "last_message_content": {
                                "top_hits": {
                                    "size": 1,
                                    "sort": [{"timestamp": {"order": "desc"}}],
                                    "_source": ["message", "timestamp", "message_type"]
                                }
                            }
                        }
                    }
                },
                "size": 0
            }
            result = self.es.search(index=self.chat_index, body=query)
            sessions = []
            for bucket in result["aggregations"]["sessions"]["buckets"]:
                session_data = bucket["last_message_content"]["hits"]["hits"][0]["_source"]
                sessions.append({
                    "session_id": bucket["key"],
                    "last_message": session_data["message"],
                    "last_timestamp": session_data["timestamp"],
                    "message_count": bucket["doc_count"],
                })
            return sessions
        except Exception as e:
            logger.error(f"Failed to list chat sessions: {e}")
            return []
    def get_chat_history(
        self,
        session_id: str,
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Get the chat history of a given session.

        Args:
            session_id: session ID
            limit: maximum number of messages to return

        Returns:
            List of chat messages in chronological order.
        """
        try:
            query = {
                "query": {
                    "term": {"session_id": session_id}
                },
                "sort": [{"timestamp": {"order": "asc"}}],
                "size": limit
            }
            result = self.es.search(index=self.chat_index, body=query)
            messages = []
            for hit in result["hits"]["hits"]:
                doc = hit["_source"]
                messages.append({
                    "message_type": doc["message_type"],
                    "message": doc["message"],
                    "plan": doc.get("plan"),
                    "steps": doc.get("steps"),
                    "timestamp": doc["timestamp"],
                })
            return messages
        except Exception as e:
            logger.error(f"Failed to get chat history: {e}")
            return []
    def search_chat_history(
        self,
        user_id: str,
        query_text: str,
        top_k: int = 10
    ) -> List[Dict[str, Any]]:
        """
        Vector search over a user's chat history.

        Args:
            user_id: user ID
            query_text: query text
            top_k: maximum number of results to return

        Returns:
            List of relevant chat messages.
        """
        try:
            # Generate the query vector
            query_embedding = self.generate_embedding(query_text)
            if not query_embedding:
                return []
            # Cosine-similarity search via script_score. Filter to this user's
            # messages and skip documents without an embedding, since
            # cosineSimilarity errors out on documents missing the vector field.
            query = {
                "query": {
                    "script_score": {
                        "query": {
                            "bool": {
                                "filter": [
                                    {"term": {"user_id": user_id}},
                                    {"exists": {"field": "message_embedding"}}
                                ]
                            }
                        },
                        "script": {
                            "source": "cosineSimilarity(params.query_vector, 'message_embedding') + 1.0",
                            "params": {"query_vector": query_embedding}
                        }
                    }
                },
                "size": top_k
            }
            result = self.es.search(index=self.chat_index, body=query)
            messages = []
            for hit in result["hits"]["hits"]:
                doc = hit["_source"]
                messages.append({
                    "session_id": doc["session_id"],
                    "message_type": doc["message_type"],
                    "message": doc["message"],
                    "timestamp": doc["timestamp"],
                    "score": hit["_score"],
                })
            return messages
        except Exception as e:
            logger.error(f"Vector search failed: {e}")
            return []
# ==================== Global instance ====================

# Create a global ES client
es_client = ESClient()
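
# ==================== Usage example (illustrative) ====================
# A minimal sketch of how this module might be used from other code. The
# session and user values below are hypothetical, and the Elasticsearch and
# embedding services configured above must be reachable for this to run.
if __name__ == "__main__":
    # Save a user message for one session
    es_client.save_chat_message(
        session_id="demo-session-1",
        user_id="demo-user",
        user_nickname="Demo User",
        user_avatar="",
        message_type="user",
        message="How do I query the order index?",
    )
    # Replay the conversation in chronological order
    for msg in es_client.get_chat_history("demo-session-1"):
        print(msg["message_type"], msg["message"])
    # Semantic search over this user's history
    for hit in es_client.search_chat_history("demo-user", "order index", top_k=5):
        print(hit["score"], hit["message"])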