""" Elasticsearch 连接和工具模块 用于聊天记录存储和向量搜索 """ from elasticsearch import Elasticsearch, helpers from datetime import datetime from typing import List, Dict, Any, Optional import logging import json import openai logger = logging.getLogger(__name__) # ==================== 配置 ==================== # ES 配置 ES_CONFIG = { "host": "http://222.128.1.157:19200", "index_chat_history": "agent_chat_history", # 聊天记录索引 } # Embedding 配置 EMBEDDING_CONFIG = { "api_key": "dummy", "base_url": "http://222.128.1.157:18008/v1", "model": "qwen3-embedding-8b", "dims": 4096, # 向量维度 } # ==================== ES 客户端 ==================== class ESClient: """Elasticsearch 客户端封装""" def __init__(self): self.es = Elasticsearch([ES_CONFIG["host"]], request_timeout=60) self.chat_index = ES_CONFIG["index_chat_history"] # 初始化 OpenAI 客户端用于 embedding self.embedding_client = openai.OpenAI( api_key=EMBEDDING_CONFIG["api_key"], base_url=EMBEDDING_CONFIG["base_url"], ) self.embedding_model = EMBEDDING_CONFIG["model"] # 初始化索引 self.create_chat_history_index() def create_chat_history_index(self): """创建聊天记录索引""" if self.es.indices.exists(index=self.chat_index): logger.info(f"索引 {self.chat_index} 已存在") return mappings = { "properties": { "session_id": {"type": "keyword"}, # 会话ID "user_id": {"type": "keyword"}, # 用户ID "user_nickname": {"type": "text"}, # 用户昵称 "user_avatar": {"type": "keyword"}, # 用户头像URL "message_type": {"type": "keyword"}, # user / assistant "message": {"type": "text"}, # 消息内容 "message_embedding": { # 消息向量 "type": "dense_vector", "dims": EMBEDDING_CONFIG["dims"], "index": True, "similarity": "cosine" }, "plan": {"type": "text"}, # 执行计划(仅 assistant) "steps": {"type": "text"}, # 执行步骤(仅 assistant) "timestamp": {"type": "date"}, # 时间戳 "created_at": {"type": "date"}, # 创建时间 } } self.es.indices.create(index=self.chat_index, body={"mappings": mappings}) logger.info(f"创建索引: {self.chat_index}") def generate_embedding(self, text: str) -> List[float]: """生成文本向量""" try: if not text or len(text.strip()) == 0: return [] # 截断过长文本 text = text[:16000] if len(text) > 16000 else text response = self.embedding_client.embeddings.create( model=self.embedding_model, input=[text] ) return response.data[0].embedding except Exception as e: logger.error(f"Embedding 生成失败: {e}") return [] def save_chat_message( self, session_id: str, user_id: str, user_nickname: str, user_avatar: str, message_type: str, # "user" or "assistant" message: str, plan: Optional[str] = None, steps: Optional[str] = None, ) -> str: """ 保存聊天消息 Args: session_id: 会话ID user_id: 用户ID user_nickname: 用户昵称 user_avatar: 用户头像URL message_type: 消息类型 (user/assistant) message: 消息内容 plan: 执行计划(可选) steps: 执行步骤(可选) Returns: 文档ID """ try: # 生成向量 embedding = self.generate_embedding(message) doc = { "session_id": session_id, "user_id": user_id, "user_nickname": user_nickname, "user_avatar": user_avatar, "message_type": message_type, "message": message, "message_embedding": embedding if embedding else None, "plan": plan, "steps": steps, "timestamp": datetime.now(), "created_at": datetime.now(), } result = self.es.index(index=self.chat_index, body=doc) logger.info(f"保存聊天记录: {result['_id']}") return result["_id"] except Exception as e: logger.error(f"保存聊天记录失败: {e}") raise def get_chat_sessions(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]: """ 获取用户的聊天会话列表 Args: user_id: 用户ID limit: 返回数量 Returns: 会话列表,每个会话包含:session_id, last_message, last_timestamp """ try: # 聚合查询:按 session_id 分组,获取每个会话的最后一条消息 query = { "query": { "term": {"user_id": user_id} }, "aggs": { "sessions": { "terms": { "field": "session_id", "size": limit, "order": {"last_message": "desc"} }, "aggs": { "last_message": { "max": {"field": "timestamp"} }, "last_message_content": { "top_hits": { "size": 1, "sort": [{"timestamp": {"order": "desc"}}], "_source": ["message", "timestamp", "message_type"] } } } } }, "size": 0 } result = self.es.search(index=self.chat_index, body=query) sessions = [] for bucket in result["aggregations"]["sessions"]["buckets"]: session_data = bucket["last_message_content"]["hits"]["hits"][0]["_source"] sessions.append({ "session_id": bucket["key"], "last_message": session_data["message"], "last_timestamp": session_data["timestamp"], "message_count": bucket["doc_count"], }) return sessions except Exception as e: logger.error(f"获取会话列表失败: {e}") return [] def get_chat_history( self, session_id: str, limit: int = 100 ) -> List[Dict[str, Any]]: """ 获取指定会话的聊天历史 Args: session_id: 会话ID limit: 返回数量 Returns: 聊天记录列表 """ try: query = { "query": { "term": {"session_id": session_id} }, "sort": [{"timestamp": {"order": "asc"}}], "size": limit } result = self.es.search(index=self.chat_index, body=query) messages = [] for hit in result["hits"]["hits"]: doc = hit["_source"] messages.append({ "message_type": doc["message_type"], "message": doc["message"], "plan": doc.get("plan"), "steps": doc.get("steps"), "timestamp": doc["timestamp"], }) return messages except Exception as e: logger.error(f"获取聊天历史失败: {e}") return [] def search_chat_history( self, user_id: str, query_text: str, top_k: int = 10 ) -> List[Dict[str, Any]]: """ 向量搜索聊天历史 Args: user_id: 用户ID query_text: 查询文本 top_k: 返回数量 Returns: 相关聊天记录列表 """ try: # 生成查询向量 query_embedding = self.generate_embedding(query_text) if not query_embedding: return [] # 向量搜索 query = { "query": { "bool": { "must": [ {"term": {"user_id": user_id}}, { "script_score": { "query": {"match_all": {}}, "script": { "source": "cosineSimilarity(params.query_vector, 'message_embedding') + 1.0", "params": {"query_vector": query_embedding} } } } ] } }, "size": top_k } result = self.es.search(index=self.chat_index, body=query) messages = [] for hit in result["hits"]["hits"]: doc = hit["_source"] messages.append({ "session_id": doc["session_id"], "message_type": doc["message_type"], "message": doc["message"], "timestamp": doc["timestamp"], "score": hit["_score"], }) return messages except Exception as e: logger.error(f"向量搜索失败: {e}") return [] # ==================== 全局实例 ==================== # 创建全局 ES 客户端 es_client = ESClient()