"""
Elasticsearch connection and utility module.

Used for chat-history storage and vector search.
"""
|
||
|
||
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

import openai
from elasticsearch import Elasticsearch, helpers
|
||
|
||
logger = logging.getLogger(__name__)

# ==================== Configuration ====================

# Elasticsearch settings. Each value can be overridden through the named
# environment variable; the defaults are the previously hard-coded values,
# so behavior is unchanged when no variable is set.
ES_CONFIG = {
    "host": os.environ.get("ES_HOST", "http://222.128.1.157:19200"),
    # Index that stores the chat history.
    "index_chat_history": os.environ.get(
        "ES_INDEX_CHAT_HISTORY", "agent_chat_history"
    ),
}

# Embedding service settings (passed to openai.OpenAI as api_key/base_url),
# overridable through environment variables in the same way.
EMBEDDING_CONFIG = {
    "api_key": os.environ.get("EMBEDDING_API_KEY", "dummy"),
    "base_url": os.environ.get(
        "EMBEDDING_BASE_URL", "http://222.128.1.157:18008/v1"
    ),
    "model": os.environ.get("EMBEDDING_MODEL", "qwen3-embedding-8b"),
    # Vector dimension; used as the dense_vector "dims" of the index mapping.
    "dims": int(os.environ.get("EMBEDDING_DIMS", "4096")),
}
|
||
|
||
# ==================== ES client ====================
|
||
|
||
class ESClient:
    """Wrapper around an Elasticsearch client for chat-history storage.

    Stores chat messages together with an embedding of the message text and
    offers session listing, per-session history and vector similarity search.
    Connection details come from the module-level ``ES_CONFIG`` and
    ``EMBEDDING_CONFIG`` dictionaries.
    """

    def __init__(self):
        """Create the ES and embedding clients and ensure the index exists."""
        self.es = Elasticsearch([ES_CONFIG["host"]], request_timeout=60)
        self.chat_index = ES_CONFIG["index_chat_history"]

        # OpenAI-style client; only its embeddings endpoint is used here.
        self.embedding_client = openai.OpenAI(
            api_key=EMBEDDING_CONFIG["api_key"],
            base_url=EMBEDDING_CONFIG["base_url"],
        )
        self.embedding_model = EMBEDDING_CONFIG["model"]

        # Make sure the index is there before any read/write happens.
        self.create_chat_history_index()

    def create_chat_history_index(self):
        """Create the chat-history index unless it already exists."""
        if self.es.indices.exists(index=self.chat_index):
            logger.info(f"索引 {self.chat_index} 已存在")
            return

        mappings = {
            "properties": {
                "session_id": {"type": "keyword"},     # conversation id
                "user_id": {"type": "keyword"},        # user id
                "user_nickname": {"type": "text"},     # user nickname
                "user_avatar": {"type": "keyword"},    # user avatar URL
                "message_type": {"type": "keyword"},   # "user" / "assistant"
                "message": {"type": "text"},           # message content
                # Embedding of the message text.
                # NOTE(review): the maximum "dims" of an indexed dense_vector
                # depends on the Elasticsearch version -- confirm the target
                # cluster accepts EMBEDDING_CONFIG["dims"].
                "message_embedding": {
                    "type": "dense_vector",
                    "dims": EMBEDDING_CONFIG["dims"],
                    "index": True,
                    "similarity": "cosine",
                },
                "plan": {"type": "text"},              # execution plan (assistant only)
                "steps": {"type": "text"},             # execution steps (assistant only)
                "timestamp": {"type": "date"},         # message timestamp
                "created_at": {"type": "date"},        # creation time
            }
        }

        self.es.indices.create(index=self.chat_index, body={"mappings": mappings})
        logger.info(f"创建索引: {self.chat_index}")

    def generate_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for ``text``.

        Args:
            text: Text to embed. Only the first 16000 characters are sent.

        Returns:
            The embedding as a list of floats, or an empty list when ``text``
            is empty/whitespace-only or when the embedding request fails
            (the failure is logged, not raised).
        """
        try:
            if not text or not text.strip():
                return []

            # Slicing is a no-op for shorter strings, so no length check needed.
            response = self.embedding_client.embeddings.create(
                model=self.embedding_model,
                input=[text[:16000]],
            )
            return response.data[0].embedding
        except Exception as e:
            logger.error(f"Embedding 生成失败: {e}")
            return []

    def save_chat_message(
        self,
        session_id: str,
        user_id: str,
        user_nickname: str,
        user_avatar: str,
        message_type: str,  # "user" or "assistant"
        message: str,
        plan: Optional[str] = None,
        steps: Optional[str] = None,
    ) -> str:
        """Index one chat message and return its document id.

        Args:
            session_id: Conversation id.
            user_id: User id.
            user_nickname: User nickname.
            user_avatar: User avatar URL.
            message_type: Message type ("user" or "assistant").
            message: Message content; also the text that gets embedded.
            plan: Execution plan (optional).
            steps: Execution steps (optional).

        Returns:
            The ``_id`` of the indexed document.

        Raises:
            Exception: Any error raised while indexing is logged and re-raised.
        """
        try:
            embedding = self.generate_embedding(message)
            # Take the time once so "timestamp" and "created_at" are identical.
            now = datetime.now()

            doc = {
                "session_id": session_id,
                "user_id": user_id,
                "user_nickname": user_nickname,
                "user_avatar": user_avatar,
                "message_type": message_type,
                "message": message,
                # An empty embedding (blank text / failed request) is stored
                # as null so the message itself is still saved.
                "message_embedding": embedding if embedding else None,
                "plan": plan,
                "steps": steps,
                "timestamp": now,
                "created_at": now,
            }

            result = self.es.index(index=self.chat_index, body=doc)
            logger.info(f"保存聊天记录: {result['_id']}")
            return result["_id"]

        except Exception as e:
            logger.error(f"保存聊天记录失败: {e}")
            raise

    def get_chat_sessions(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """List a user's chat sessions, most recently active first.

        Args:
            user_id: User id.
            limit: Maximum number of sessions to return.

        Returns:
            One dict per session with the keys ``session_id``,
            ``last_message``, ``last_timestamp`` and ``message_count``.
            An empty list is returned when the query fails (error is logged).
        """
        try:
            # Group by session_id; per bucket keep the newest timestamp (used
            # for ordering) and the newest message document.
            query = {
                "query": {
                    "term": {"user_id": user_id}
                },
                "aggs": {
                    "sessions": {
                        "terms": {
                            "field": "session_id",
                            "size": limit,
                            "order": {"last_message": "desc"},
                        },
                        "aggs": {
                            "last_message": {
                                "max": {"field": "timestamp"}
                            },
                            "last_message_content": {
                                "top_hits": {
                                    "size": 1,
                                    "sort": [{"timestamp": {"order": "desc"}}],
                                    "_source": ["message", "timestamp", "message_type"],
                                }
                            },
                        },
                    }
                },
                "size": 0,  # only the aggregation is needed, no plain hits
            }

            result = self.es.search(index=self.chat_index, body=query)

            sessions = []
            for bucket in result["aggregations"]["sessions"]["buckets"]:
                latest = bucket["last_message_content"]["hits"]["hits"][0]["_source"]
                sessions.append({
                    "session_id": bucket["key"],
                    "last_message": latest["message"],
                    "last_timestamp": latest["timestamp"],
                    "message_count": bucket["doc_count"],
                })

            return sessions

        except Exception as e:
            logger.error(f"获取会话列表失败: {e}")
            return []

    def get_chat_history(
        self,
        session_id: str,
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """Return the messages of one session in ascending timestamp order.

        Because the sort is ascending and the result is capped at ``limit``,
        a session with more than ``limit`` messages yields its earliest ones.

        Args:
            session_id: Conversation id.
            limit: Maximum number of messages to return.

        Returns:
            A list of dicts with the keys ``message_type``, ``message``,
            ``plan``, ``steps`` and ``timestamp``. An empty list is returned
            when the query fails (error is logged).
        """
        try:
            query = {
                "query": {
                    "term": {"session_id": session_id}
                },
                "sort": [{"timestamp": {"order": "asc"}}],
                "size": limit,
            }

            result = self.es.search(index=self.chat_index, body=query)

            messages = []
            for hit in result["hits"]["hits"]:
                doc = hit["_source"]
                messages.append({
                    "message_type": doc["message_type"],
                    "message": doc["message"],
                    "plan": doc.get("plan"),
                    "steps": doc.get("steps"),
                    "timestamp": doc["timestamp"],
                })

            return messages

        except Exception as e:
            logger.error(f"获取聊天历史失败: {e}")
            return []

    def search_chat_history(
        self,
        user_id: str,
        query_text: str,
        top_k: int = 10
    ) -> List[Dict[str, Any]]:
        """Find a user's messages most similar to ``query_text``.

        Args:
            user_id: User id; only this user's messages are searched.
            query_text: Text whose embedding is compared to stored messages.
            top_k: Maximum number of results.

        Returns:
            A list of dicts with the keys ``session_id``, ``message_type``,
            ``message``, ``timestamp`` and ``score``. An empty list is
            returned when no query embedding could be produced or when the
            search fails (error is logged).
        """
        try:
            query_embedding = self.generate_embedding(query_text)
            if not query_embedding:
                return []

            query = {
                "query": {
                    "bool": {
                        "must": [
                            {"term": {"user_id": user_id}},
                            {
                                "script_score": {
                                    # Messages saved without an embedding have
                                    # no vector value; cosineSimilarity raises
                                    # on such documents and would fail the
                                    # whole search, so they are excluded here.
                                    "query": {
                                        "exists": {"field": "message_embedding"}
                                    },
                                    "script": {
                                        # +1.0 keeps the score non-negative.
                                        "source": "cosineSimilarity(params.query_vector, 'message_embedding') + 1.0",
                                        "params": {"query_vector": query_embedding},
                                    },
                                }
                            },
                        ]
                    }
                },
                "size": top_k,
            }

            result = self.es.search(index=self.chat_index, body=query)

            messages = []
            for hit in result["hits"]["hits"]:
                doc = hit["_source"]
                messages.append({
                    "session_id": doc["session_id"],
                    "message_type": doc["message_type"],
                    "message": doc["message"],
                    "timestamp": doc["timestamp"],
                    "score": hit["_score"],
                })

            return messages

        except Exception as e:
            logger.error(f"向量搜索失败: {e}")
            return []
|
||
|
||
|
||
# ==================== Global instance ====================

# Module-level ES client shared by importers of this module.
# Constructing it runs ESClient.__init__, which calls
# create_chat_history_index(), so importing this module issues an
# Elasticsearch request.
es_client = ESClient()
|