""" Flask API for ES-based Stock Analysis System with Research Reports Integration 支持前后端分离的涨停股票分析系统API + 研究报告搜索API """ from flask import Flask, request, jsonify from flask_cors import CORS from elasticsearch import Elasticsearch, RequestError from datetime import datetime, timedelta from collections import defaultdict import json from typing import Dict, List, Any, Optional import logging import openai import time from enum import Enum app = Flask(__name__) CORS(app) # 启用跨域支持 # ES连接配置 ES_HOST = "http://192.168.1.231:9200" es = Elasticsearch( [ES_HOST], timeout=60, retry_on_timeout=True, max_retries=3, verify_certs=False ) # Embedding配置 EMBEDDING_BASE_URLS = [ "http://192.168.1.231:8000/v1" ] EMBEDDING_MODEL = "qwen3-embedding-8b" EMBEDDING_DIMENSION = 4096 current_embedding_index = 0 # 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # === 枚举类型 === class SearchMode(str, Enum): HYBRID = "hybrid" # 混合搜索(文本+向量) TEXT = "text" # 仅文本搜索 VECTOR = "vector" # 仅向量搜索 # === Embedding相关函数 === def get_embedding_client(): """获取embedding客户端,轮询使用不同的服务""" global current_embedding_index url = EMBEDDING_BASE_URLS[current_embedding_index % len(EMBEDDING_BASE_URLS)] current_embedding_index += 1 return openai.OpenAI( api_key="dummy", base_url=url, timeout=60, ), url def get_embedding(text: str) -> Optional[List[float]]: """生成文本向量""" if not text or len(text.strip()) == 0: return None try: client, service_url = get_embedding_client() # 限制文本长度 text = text[:8000] if len(text) > 8000 else text response = client.embeddings.create( model=EMBEDDING_MODEL, input=[text] ) logger.info(f"使用embedding服务: {service_url}") return response.data[0].embedding except Exception as e: logger.error(f"生成向量失败: {e}") return None class StockAnalysisAPI: """股票分析API类""" @staticmethod def format_date(date_str: str) -> str: """格式化日期字符串 YYYYMMDD -> YYYY-MM-DD""" if len(date_str) == 8: return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:]}" return date_str @staticmethod def parse_date(date_str: str) -> str: """解析日期字符串 YYYY-MM-DD -> YYYYMMDD""" if '-' in date_str: return date_str.replace('-', '') return date_str # ========== 原有的股票分析API路由 ========== @app.route('/api/v1/dates/available', methods=['GET']) def get_available_dates(): """ 获取所有可用的日期列表 用于日历组件显示 """ try: # 从ES获取所有日期的统计数据 query = { "size": 0, "aggs": { "dates": { "terms": { "field": "date", "size": 1000, "order": {"_key": "desc"} }, "aggs": { "stock_count": { "cardinality": { "field": "scode" } } } } } } result = es.search(index="zt_stocks", body=query) events = [] for bucket in result['aggregations']['dates']['buckets']: date = bucket['key'] count = bucket['doc_count'] formatted_date = StockAnalysisAPI.format_date(date) events.append({ 'title': f'{count}只', 'start': formatted_date, 'end': formatted_date, 'className': 'bg-gradient-primary', 'allDay': True, 'date': date, # 原始日期格式 'count': count }) return jsonify({ 'success': True, 'events': events, 'total': len(events) }) except Exception as e: logger.error(f"Error getting available dates: {str(e)}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/analysis/daily/', methods=['GET']) def get_daily_analysis(date): """ 获取指定日期的分析数据 包括板块分析、词云数据、图表数据等 """ try: logger.info(f"========== Getting daily analysis for date: {date} ==========") # 首先尝试从缓存获取 cache_query = { "query": { "bool": { "must": [ {"term": {"cache_key": date}}, {"term": {"cache_type": "daily_analysis"}} ] } } } try: cache_result = es.search(index="zt_cache", body=cache_query, size=1) logger.info(f"Cache query result: {cache_result['hits']['total']['value']} hits") if cache_result['hits']['total']['value'] > 0: cached_data = cache_result['hits']['hits'][0]['_source'] expires_at = datetime.fromisoformat(cached_data['expires_at'].replace('Z', '+00:00')) if expires_at > datetime.now(): logger.info(f"Returning cached data for {date}") return jsonify({ 'success': True, 'data': cached_data['data'], 'from_cache': True }) else: logger.info(f"Cache expired for {date}") except Exception as e: logger.warning(f"Cache query failed: {e}, continuing without cache") # 获取统计数据 stats_query = { "query": { "term": {"date": date} }, "_source": ["sector_stats", "word_freq", "chart_data", "summary"] } logger.info(f"Querying zt_daily_stats with: {json.dumps(stats_query)}") stats_result = es.search(index="zt_daily_stats", body=stats_query, size=1) logger.info(f"Stats result total: {stats_result['hits']['total']['value']}") if stats_result['hits']['total']['value'] > 0: stats = stats_result['hits']['hits'][0]['_source'] logger.info(f"Stats fields found: {list(stats.keys())}") logger.info(f"Sector_stats count: {len(stats.get('sector_stats', []))}") logger.info(f"Word_freq count: {len(stats.get('word_freq', []))}") else: stats = {} logger.warning(f"No stats found in zt_daily_stats for {date}") # 获取当日所有股票 stock_query = { "query": { "term": {"date": date} }, "size": 10000, "sort": [{"zt_time": "asc"}], "_source": { "exclude": ["content_embedding"] } } logger.info(f"Querying zt_stocks...") stock_result = es.search(index="zt_stocks", body=stock_query) logger.info(f"Found {stock_result['hits']['total']['value']} stocks") stocks = [] for hit in stock_result['hits']['hits']: stock = hit['_source'] if 'content_embedding' in stock: del stock['content_embedding'] stocks.append(stock) logger.info(f"Processed {len(stocks)} stocks") # 处理板块数据 if stats and 'sector_stats' in stats and len(stats['sector_stats']) > 0: logger.info(f"Using sector_stats from zt_daily_stats") sector_data = process_sector_data_from_stats(stats['sector_stats'], stocks) logger.info(f"Processed sector_data with {len(sector_data)} sectors") else: logger.info(f"No sector_stats found, generating from stocks") sector_data = process_sector_data(stocks) logger.info(f"Generated sector_data with {len(sector_data)} sectors") # 计算板块关联关系(用于TOP10横向柱状图) sector_relations_top10 = calculate_sector_relations_top10(stocks) logger.info(f"Calculated sector_relations_top10: {len(sector_relations_top10.get('labels', []))} items") # 准备响应数据 response_data = { 'date': date, 'formatted_date': StockAnalysisAPI.format_date(date), 'total_stocks': len(stocks), 'sector_data': sector_data, 'chart_data': stats.get('chart_data', generate_chart_data(sector_data)), 'word_freq_data': stats.get('word_freq', []), 'sector_relations_top10': sector_relations_top10, 'summary': stats.get('summary', generate_summary(stocks, sector_data)) } logger.info(f"Response data prepared successfully") logger.info( f"Response summary: total_stocks={response_data['total_stocks']}, sectors={len(response_data['sector_data'])}") # 更新缓存 try: update_cache(date, 'daily_analysis', response_data) logger.info(f"Cache updated for {date}") except Exception as e: logger.warning(f"Failed to update cache: {e}") return jsonify({ 'success': True, 'data': response_data, 'from_cache': False }) except Exception as e: logger.error(f"Error getting daily analysis for {date}: {str(e)}", exc_info=True) return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/stocks/search', methods=['POST']) def search_stocks(): """ 搜索股票 支持按股票代码、名称、板块等条件搜索 """ try: data = request.json keyword = data.get('keyword', '') date = data.get('date') sectors = data.get('sectors', []) page = data.get('page', 1) page_size = data.get('page_size', 20) # 构建查询条件 must_conditions = [] if date: must_conditions.append({"term": {"date": date}}) if sectors: must_conditions.append({ "terms": {"core_sectors": sectors} }) should_conditions = [] if keyword: should_conditions.extend([ {"match": {"sname": keyword}}, {"term": {"scode": keyword}}, {"match": {"brief": keyword}}, {"match": {"summary": keyword}} ]) query = { "query": { "bool": { "must": must_conditions, "should": should_conditions if should_conditions else None, "minimum_should_match": 1 if should_conditions else 0 } }, "from": (page - 1) * page_size, "size": page_size, "sort": [{"zt_time": "desc"}], "_source": { "excludes": ["content_embedding"] # 排除向量字段 }, "highlight": { "fields": { "sname": {}, "brief": {}, "summary": {} } } } # 移除None值 query["query"]["bool"] = {k: v for k, v in query["query"]["bool"].items() if v is not None} result = es.search(index="zt_stocks", body=query) stocks = [] for hit in result['hits']['hits']: stock = hit['_source'] if 'highlight' in hit: stock['highlight'] = hit['highlight'] stocks.append(stock) return jsonify({ 'success': True, 'data': { 'stocks': stocks, 'total': result['hits']['total']['value'], 'page': page, 'page_size': page_size, 'total_pages': (result['hits']['total']['value'] + page_size - 1) // page_size } }) except Exception as e: logger.error(f"Error searching stocks: {str(e)}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/stocks/search/hybrid', methods=['POST']) def hybrid_search_stocks(): """ 混合搜索涨停原因 结合关键词搜索和向量相似度搜索 优化逻辑: 1. 先严格匹配股票名称/代码,如果匹配到则只返回该股票的记录 2. 如果没有匹配到股票名称/代码,再走智能检索逻辑 3. 结果按涨停时间由新到旧排序 """ try: data = request.json query_text = data.get('query', '').strip() mode = data.get('mode', 'hybrid') # hybrid, text, vector date = data.get('date') date_range = data.get('date_range', {}) # {'start': 'YYYYMMDD', 'end': 'YYYYMMDD'} sectors = data.get('sectors', []) page = data.get('page', 1) page_size = data.get('page_size', 20) min_score = data.get('min_score') if not query_text: return jsonify({ 'success': False, 'error': 'Query text is required' }), 400 # 构建日期过滤条件 date_filter = [] if date: date_filter.append({"term": {"date": date}}) elif date_range: range_query = {} if date_range.get('start'): range_query['gte'] = date_range['start'] if date_range.get('end'): range_query['lte'] = date_range['end'] if range_query: date_filter.append({"range": {"date": range_query}}) if sectors: date_filter.append({"terms": {"core_sectors": sectors}}) # ========== 第一步:尝试精确匹配股票名称或代码 ========== exact_match_query = { "query": { "bool": { "must": date_filter.copy() if date_filter else [], "should": [ {"term": {"sname.keyword": {"value": query_text, "boost": 10.0}}}, {"term": {"scode": {"value": query_text, "boost": 10.0}}}, # 也支持部分匹配股票名称(如输入"比亚迪"匹配"比亚迪") {"match_phrase": {"sname": {"query": query_text, "boost": 8.0}}} ], "minimum_should_match": 1 } }, "size": 1000, # 获取足够多的结果用于后续过滤 "_source": { "excludes": ["content_embedding"] } } exact_result = es.search(index="zt_stocks", body=exact_match_query) exact_hits = exact_result['hits']['hits'] # 检查是否有精确匹配的股票 matched_stock_codes = set() is_stock_search = False if exact_hits: # 检查是否真的匹配到了股票名称或代码 for hit in exact_hits: stock = hit['_source'] sname = stock.get('sname', '') scode = stock.get('scode', '') # 判断是否精确匹配股票名称或代码 if query_text == sname or query_text == scode or query_text in sname: matched_stock_codes.add(scode) is_stock_search = True # ========== 第二步:根据匹配情况决定搜索策略 ========== if is_stock_search and matched_stock_codes: # 用户搜索的是特定股票,只返回该股票的所有涨停记录 logger.info(f"Stock name/code search detected: {query_text} -> {matched_stock_codes}") final_query = { "query": { "bool": { "must": date_filter.copy() if date_filter else [], "filter": [ {"terms": {"scode": list(matched_stock_codes)}} ] } }, "from": (page - 1) * page_size, "size": page_size, "sort": [ {"date": {"order": "desc"}}, # 先按日期降序 {"zt_time": {"order": "desc"}} # 再按涨停时间降序 ], "_source": { "excludes": ["content_embedding"] } } search_mode_actual = "exact_stock" else: # 用户搜索的不是股票名称,走智能检索逻辑 logger.info(f"Semantic search: {query_text}") must_conditions = date_filter.copy() if date_filter else [] should_conditions = [] # 文本搜索部分 if mode in ['text', 'hybrid']: text_queries = [ {"match": {"brief": {"query": query_text, "boost": 2.0}}}, {"match": {"summary": {"query": query_text, "boost": 1.5}}}, {"match": {"sname": {"query": query_text, "boost": 1.0}}} ] should_conditions.extend(text_queries) # 向量搜索部分 if mode in ['vector', 'hybrid']: query_vector = get_embedding(query_text) if query_vector: vector_score_query = { "script_score": { "query": {"match_all": {}}, "script": { "source": """ if (!doc['content_embedding'].empty) { double similarity = cosineSimilarity(params.query_vector, 'content_embedding') + 1.0; return similarity * params.boost; } return 0; """, "params": { "query_vector": query_vector, "boost": 2.0 if mode == 'vector' else 1.0 } } } } if mode == 'vector': final_query = { "query": { "bool": { "must": [vector_score_query] + must_conditions } } } else: should_conditions.append(vector_score_query) # 构建最终查询 if mode != 'vector': final_query = { "query": { "bool": { "must": must_conditions, "should": should_conditions, "minimum_should_match": 1 if should_conditions else 0 } } } # 添加分页、排序和高亮 final_query.update({ "from": (page - 1) * page_size, "size": page_size, "sort": [ {"date": {"order": "desc"}}, # 先按日期降序 {"zt_time": {"order": "desc"}} # 再按涨停时间降序 ], "_source": { "excludes": ["content_embedding"] }, "highlight": { "fields": { "brief": {"fragment_size": 150}, "summary": {"fragment_size": 200}, "sname": {} } } }) # 添加最小分数过滤 if min_score: final_query["min_score"] = min_score search_mode_actual = mode # ========== 第三步:执行搜索 ========== start_time = time.time() result = es.search(index="zt_stocks", body=final_query) took_ms = int((time.time() - start_time) * 1000) # 处理结果 stocks = [] for hit in result['hits']['hits']: stock = hit['_source'] stock['_score'] = hit['_score'] # 添加高亮信息 if 'highlight' in hit: stock['highlight'] = hit['highlight'] # 格式化时间 if 'zt_time' in stock: try: zt_time = datetime.fromisoformat(stock['zt_time'].replace('Z', '+00:00')) stock['formatted_time'] = zt_time.strftime('%H:%M:%S') except: pass stocks.append(stock) return jsonify({ 'success': True, 'data': { 'stocks': stocks, 'total': result['hits']['total']['value'], 'page': page, 'page_size': page_size, 'total_pages': (result['hits']['total']['value'] + page_size - 1) // page_size, 'took_ms': took_ms, 'search_mode': search_mode_actual, 'is_stock_search': is_stock_search, 'matched_stocks': list(matched_stock_codes) if is_stock_search else [] } }) except Exception as e: logger.error(f"Error in hybrid search: {str(e)}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/limit-analyse/high-position-stocks', methods=['GET']) def get_high_position_stocks(): """ 获取高位股数据 高位股是指连续涨停天数较多的股票,通常风险较高 :param date: YYYYMMDD格式的日期 :return: 高位股列表及统计信息 """ try: date = request.args.get('date') if not date: return jsonify({ 'success': False, 'error': 'Date parameter is required' }), 400 # 查询当日所有涨停股票 # 由于continuous_days是keyword类型且格式为"X天Y板",无法直接用range查询 # 需要获取所有数据后在Python中过滤 query = { "query": { "bool": { "must": [ {"term": {"date": date}}, {"exists": {"field": "continuous_days"}} # 确保有continuous_days字段 ] } }, "size": 10000, "sort": [ {"zt_time": {"order": "asc"}} # 按涨停时间升序 ], "_source": { "exclude": ["content_embedding"] } } result = es.search(index="zt_stocks", body=query) stocks = [] total_continuous_days = 0 max_continuous_days = 0 for hit in result['hits']['hits']: stock = hit['_source'] # 解析continuous_days字段,格式如"4天3板" continuous_days_str = stock.get('continuous_days', '') continuous_days = 0 if continuous_days_str: # 尝试提取数字,格式通常是"X天Y板" import re match = re.search(r'(\d+)天', continuous_days_str) if match: continuous_days = int(match.group(1)) else: # 如果格式不同,尝试提取任意数字 numbers = re.findall(r'\d+', continuous_days_str) if numbers: continuous_days = int(numbers[0]) # 只保留连续涨停天数>=2的股票(高位股) if continuous_days < 2: continue total_continuous_days += continuous_days max_continuous_days = max(max_continuous_days, continuous_days) # 构建股票数据(匹配前端需要的字段) stock_data = { 'stock_code': stock.get('scode', ''), 'stock_name': stock.get('sname', ''), 'price': 0, # 数据中没有价格字段,返回0或可以从其他接口获取 'increase_rate': 10.0, # 涨停一般是10% 'continuous_limit_up': continuous_days, 'continuous_days_str': continuous_days_str, # 原始字符串 'industry': stock.get('core_sectors', [''])[0] if stock.get('core_sectors') else '其他', 'turnover_rate': 0, # 数据中没有换手率字段 'brief': stock.get('brief', '').replace('
', ' ').replace('\n', ' '), 'summary': stock.get('summary', ''), 'zt_time': stock.get('zt_time', ''), 'formatted_time': stock.get('formatted_time', ''), 'core_sectors': stock.get('core_sectors', []) } # 格式化涨停时间 if not stock_data['formatted_time'] and stock_data['zt_time']: try: from datetime import datetime zt_time = datetime.fromisoformat(stock_data['zt_time'].replace('Z', '+00:00')) stock_data['formatted_time'] = zt_time.strftime('%H:%M:%S') except: pass stocks.append(stock_data) # 按连续涨停天数降序排序 stocks.sort(key=lambda x: x['continuous_limit_up'], reverse=True) # 计算统计信息 total_count = len(stocks) avg_continuous_days = round(total_continuous_days / total_count, 1) if total_count > 0 else 0 # 按连续涨停天数分组统计 distribution = {} for stock in stocks: days = stock['continuous_limit_up'] if days not in distribution: distribution[days] = 0 distribution[days] += 1 response_data = { 'stocks': stocks[:50], # 限制返回最多50只股票 'statistics': { 'total_count': total_count, 'avg_continuous_days': avg_continuous_days, 'max_continuous_days': max_continuous_days, 'distribution': distribution # 各连板天数的分布 }, 'date': date, 'formatted_date': StockAnalysisAPI.format_date(date) } logger.info( f"Found {total_count} high position stocks for date {date} (max continuous days: {max_continuous_days})") return jsonify({ 'success': True, 'data': response_data }) except Exception as e: logger.error(f"Error getting high position stocks for {date}: {str(e)}", exc_info=True) return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/sectors/relations/', methods=['GET']) def get_sector_relations(date): """ 获取板块关联关系数据 用于绘制板块关系图 """ try: # 获取当日所有股票 query = { "query": { "term": {"date": date} }, "size": 10000, "_source": { "include": ["scode", "sname", "core_sectors"] # 只获取需要的字段 } } result = es.search(index="zt_stocks", body=query) stocks = [hit['_source'] for hit in result['hits']['hits']] # 计算板块关联关系 relations = calculate_sector_relations(stocks) return jsonify({ 'success': True, 'data': relations }) except Exception as e: logger.error(f"Error getting sector relations for {date}: {str(e)}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/stocks/detail/', methods=['GET']) def get_stock_detail(scode): """ 获取股票详细信息 """ try: date = request.args.get('date') query = { "query": { "bool": { "must": [ {"term": {"scode": scode}} ] } }, "_source": { "exclude": ["content_embedding"] # ES 8.x 使用 exclude } } if date: query["query"]["bool"]["must"].append({"term": {"date": date}}) result = es.search(index="zt_stocks", body=query, size=1) if result['hits']['total']['value'] == 0: return jsonify({ 'success': False, 'error': 'Stock not found' }), 404 stock = result['hits']['hits'][0]['_source'] # 确保删除content_embedding if 'content_embedding' in stock: del stock['content_embedding'] # 获取历史涨停记录 history_query = { "query": { "term": {"scode": scode} }, "sort": [{"date": "desc"}], "size": 30, "_source": { "exclude": ["content_embedding"] } } history_result = es.search(index="zt_stocks", body=history_query) history = [] for hit in history_result['hits']['hits']: h = hit['_source'] if 'content_embedding' in h: del h['content_embedding'] history.append(h) return jsonify({ 'success': True, 'data': { 'stock': stock, 'history': history } }) except Exception as e: logger.error(f"Error getting stock detail for {scode}: {str(e)}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/analysis/trends', methods=['GET']) def get_market_trends(): """ 获取市场趋势分析 分析最近N天的板块轮动情况 """ try: days = int(request.args.get('days', 7)) end_date = datetime.now() start_date = end_date - timedelta(days=days) # 获取日期范围内的统计数据 query = { "query": { "range": { "date": { "gte": start_date.strftime('%Y%m%d'), "lte": end_date.strftime('%Y%m%d') } } }, "size": 0, "aggs": { "by_date": { "terms": { "field": "date", "size": days, "order": {"_key": "asc"} }, "aggs": { "sectors": { "terms": { "field": "core_sectors", "size": 20 } } } } } } result = es.search(index="zt_stocks", body=query) trends = [] for bucket in result['aggregations']['by_date']['buckets']: date = bucket['key'] sectors = [ { 'name': sector_bucket['key'], 'count': sector_bucket['doc_count'] } for sector_bucket in bucket['sectors']['buckets'] ] trends.append({ 'date': date, 'formatted_date': StockAnalysisAPI.format_date(date), 'sectors': sectors, 'total': bucket['doc_count'] }) return jsonify({ 'success': True, 'data': { 'trends': trends, 'days': days } }) except Exception as e: logger.error(f"Error getting market trends: {str(e)}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/analysis/wordcloud/', methods=['GET']) def get_wordcloud_data(date): """ 获取词云图数据 专门为词云图组件提供数据 :param date: YYYYMMDD格式的日期 """ try: # 从zt_daily_stats获取词云数据 query = { "query": { "term": {"date": date} }, "_source": ["word_freq"] } result = es.search(index="zt_daily_stats", body=query, size=1) if result['hits']['total']['value'] > 0: word_freq = result['hits']['hits'][0]['_source'].get('word_freq', []) # 确保数据格式正确(echarts-wordcloud需要的格式) formatted_data = [] for item in word_freq: formatted_data.append({ 'name': item.get('name', ''), 'value': max(100, item.get('value', 100)) # 确保最小值为100 }) return jsonify({ 'success': True, 'data': formatted_data[:200] # 限制最多200个词 }) else: # 如果没有预计算的词云数据,从股票数据生成 return generate_wordcloud_from_stocks(date) except Exception as e: logger.error(f"Error getting wordcloud data for {date}: {str(e)}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/sectors/top-relations/', methods=['GET']) def get_sector_top_relations(date): """ 获取板块关联TOP10数据 专门为横向柱状图提供数据 :param date: YYYYMMDD格式的日期 """ try: # 获取当日所有股票 query = { "query": { "term": {"date": date} }, "_source": ["scode", "core_sectors"], "size": 10000 } result = es.search(index="zt_stocks", body=query) stocks = [hit['_source'] for hit in result['hits']['hits']] # 计算板块关联TOP10 relations_data = calculate_sector_relations_top10(stocks) return jsonify({ 'success': True, 'data': relations_data }) except Exception as e: logger.error(f"Error getting sector top relations for {date}: {str(e)}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/stocks/batch-detail', methods=['POST']) def get_stocks_batch_detail(): """ 批量获取股票详情 用于板块展开时显示所有股票 """ try: data = request.json stock_codes = data.get('codes', []) date = data.get('date') if not stock_codes: return jsonify({ 'success': False, 'error': 'No stock codes provided' }), 400 # 构建查询 must_conditions = [ {"terms": {"scode": stock_codes}} ] if date: must_conditions.append({"term": {"date": date}}) query = { "query": { "bool": { "must": must_conditions } }, "size": len(stock_codes), "sort": [{"zt_time": "asc"}], "_source": { "exclude": ["content_embedding"] # ES 8.x 使用 exclude } } result = es.search(index="zt_stocks", body=query) stocks = [] for hit in result['hits']['hits']: stock = hit['_source'] # 确保删除content_embedding if 'content_embedding' in stock: del stock['content_embedding'] # 格式化时间 if 'zt_time' in stock: try: zt_time = datetime.fromisoformat(stock['zt_time'].replace('Z', '+00:00')) stock['formatted_time'] = zt_time.strftime('%H:%M:%S') except: stock['formatted_time'] = '' # 处理brief和summary的格式 stock['brief'] = stock.get('brief', '').replace('\n', '
').replace('#', '').replace('*', '') stock['summary'] = stock.get('summary', '').replace('\n', '
').replace('#', '').replace('*', '') stocks.append(stock) return jsonify({ 'success': True, 'data': stocks }) except Exception as e: logger.error(f"Error getting batch stock details: {str(e)}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/init/data', methods=['GET']) def initialize_data(): """ 初始化页面数据 获取最新日期的所有必要数据 减少前端多次请求 """ try: # 获取最新日期 date_query = { "size": 0, "aggs": { "latest_date": { "terms": { "field": "date", "size": 1, "order": {"_key": "desc"} } } } } date_result = es.search(index="zt_stocks", body=date_query) if not date_result['aggregations']['latest_date']['buckets']: return jsonify({ 'success': False, 'error': 'No data available' }), 404 latest_date = date_result['aggregations']['latest_date']['buckets'][0]['key'] # 获取该日期的完整分析数据 analysis_response = get_daily_analysis(latest_date) analysis_data = json.loads(analysis_response.data) if not analysis_data['success']: return analysis_response # 获取可用日期列表 dates_response = get_available_dates() dates_data = json.loads(dates_response.data) # 组合返回数据 return jsonify({ 'success': True, 'data': { 'latest_date': latest_date, 'formatted_date': StockAnalysisAPI.format_date(latest_date), 'analysis': analysis_data['data'], 'available_dates': dates_data['events'] } }) except Exception as e: logger.error(f"Error initializing data: {str(e)}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/statistics/overview', methods=['GET']) def get_statistics_overview(): """ 获取统计概览 包括总体统计、热门板块、连板统计等 """ try: # 多维度聚合查询 query = { "size": 0, "aggs": { "total_stocks": { "cardinality": { "field": "scode" } }, "total_days": { "cardinality": { "field": "date" } }, "top_sectors": { "terms": { "field": "core_sectors", "size": 10 } }, "continuous_days_stats": { "terms": { "field": "continuous_days", "size": 10 } }, "recent_trends": { "date_histogram": { "field": "zt_time", "calendar_interval": "day", "min_doc_count": 0, "extended_bounds": { "min": "now-30d", "max": "now" } } } } } result = es.search(index="zt_stocks", body=query) aggs = result['aggregations'] overview = { 'total_stocks': aggs['total_stocks']['value'], 'total_days': aggs['total_days']['value'], 'top_sectors': [ { 'name': bucket['key'], 'count': bucket['doc_count'] } for bucket in aggs['top_sectors']['buckets'] ], 'continuous_stats': [ { 'days': bucket['key'], 'count': bucket['doc_count'] } for bucket in aggs['continuous_days_stats']['buckets'] ], 'recent_trends': [ { 'date': bucket['key_as_string'], 'count': bucket['doc_count'] } for bucket in aggs['recent_trends']['buckets'] ] } return jsonify({ 'success': True, 'data': overview }) except Exception as e: logger.error(f"Error getting statistics overview: {str(e)}") return jsonify({ 'success': False, 'error': str(e) }), 500 # ========== 研究报告搜索API路由 ========== @app.route('/search', methods=['GET']) def search_reports(): """ 搜索研究报告 搜索功能: - 文本搜索:在content字段中搜索 - 向量搜索:基于content_embedding的语义搜索 - 混合搜索:结合文本和向量搜索 - 精确搜索:使用exact_match参数进行精确匹配 - 可选过滤:证券代码、日期范围 """ try: # 获取参数 query = request.args.get('query', '') mode = request.args.get('mode', 'hybrid') # hybrid, text, vector exact_match = request.args.get('exact_match', '0') # 0: 模糊匹配, 1: 精确匹配 security_code = request.args.get('security_code') start_date = request.args.get('start_date') end_date = request.args.get('end_date') size = int(request.args.get('size', 10)) from_ = int(request.args.get('from', 0)) min_score = request.args.get('min_score', type=float) if not query: return jsonify({ 'success': False, 'error': 'Query parameter is required' }), 400 start_time = time.time() # 构建查询条件 must_conditions = [] should_conditions = [] # 证券代码过滤 if security_code: must_conditions.append({"term": {"security_code": security_code}}) # 日期范围过滤 if start_date or end_date: date_range = {} if start_date: date_range["gte"] = start_date if end_date: date_range["lte"] = end_date must_conditions.append({"range": {"declare_date": date_range}}) # 文本搜索 - 根据exact_match参数选择匹配方式 if mode in ['text', 'hybrid']: if exact_match == '1': # 精确匹配模式 # 使用match_phrase进行短语精确匹配 text_queries = [ { "match_phrase": { "content": { "query": query, "boost": 2.0 } } }, { "match_phrase": { "report_title": { "query": query, "boost": 3.0 } } } ] # 对于证券名称和作者,也支持精确匹配 if len(query) <= 10: # 短查询可能是名称或作者 text_queries.extend([ { "term": { "security_name.keyword": { "value": query, "boost": 2.5 } } }, { "term": { "author.keyword": { "value": query, "boost": 2.0 } } }, { "term": { "publisher.keyword": { "value": query, "boost": 1.5 } } } ]) should_conditions.extend(text_queries) else: # 模糊匹配模式(原有逻辑) text_queries = [ { "match": { "content": { "query": query, "operator": "or", "boost": 1.0 } } }, { "match": { "report_title": { "query": query, "boost": 2.0 } } }, { "match": { "security_name": { "query": query, "boost": 1.5 } } }, { "match": { "author": { "query": query, "boost": 1.2 } } }, { "match": { "publisher": { "query": query, "boost": 1.0 } } } ] should_conditions.extend(text_queries) # 向量搜索 if mode in ['vector', 'hybrid']: query_vector = get_embedding(query) if query_vector: vector_query = { "script_score": { "query": {"match_all": {}} if mode == 'vector' else { "bool": {"must": must_conditions}}, "script": { "source": """ double similarity = cosineSimilarity(params.query_vector, 'content_embedding') + 1.0; return similarity; """, "params": { "query_vector": query_vector } } } } if mode == 'vector': # 纯向量搜索 if must_conditions: vector_query["query"] = {"bool": {"must": must_conditions}} final_query = { "query": vector_query, "from": from_, "size": size } else: # 混合搜索时,向量搜索作为should条件 should_conditions.append(vector_query) # 构建最终查询(非纯向量搜索) if mode != 'vector': final_query = { "query": { "bool": { "must": must_conditions, "should": should_conditions, "minimum_should_match": 1 if should_conditions else 0 } }, "from": from_, "size": size } # 返回所有字段(除了embedding) final_query["_source"] = { "excludes": ["content_embedding"] } # 添加高亮显示 if mode in ['text', 'hybrid']: final_query["highlight"] = { "fields": { "content": { "fragment_size": 200, "number_of_fragments": 3, "pre_tags": [""], "post_tags": [""] }, "report_title": { "fragment_size": 100, "number_of_fragments": 1, "pre_tags": [""], "post_tags": [""] } } } # 添加最小得分过滤 if min_score is not None: final_query["min_score"] = min_score logger.info(f"执行查询: mode={mode}, exact_match={exact_match}, query={query[:50]}") logger.debug(f"查询详情: {json.dumps(final_query, ensure_ascii=False)[:500]}") # 执行搜索 response = es.search(index="research_reports", body=final_query) # 处理结果 results = [] for hit in response['hits']['hits']: source = hit['_source'] result = { 'object_id': str(source.get('object_id', '')), 'report_title': source.get('report_title', ''), 'security_code': source.get('security_code'), 'security_name': source.get('security_name'), 'author': source.get('author'), 'rating': source.get('rating'), 'publisher': source.get('publisher'), 'declare_date': source.get('declare_date'), 'content': source.get('content', ''), 'content_url': source.get('content_url'), 'processed_time': source.get('processed_time'), 'score': hit['_score'] } # 添加高亮内容 if 'highlight' in hit: result['highlight'] = hit['highlight'] results.append(result) # 计算执行时间 took_ms = int((time.time() - start_time) * 1000) return jsonify({ 'success': True, 'data': { 'total': response['hits']['total']['value'], 'results': results, 'took_ms': took_ms, 'search_mode': mode, 'exact_match': exact_match == '1' } }) except RequestError as e: logger.error(f"Elasticsearch查询失败: {e.info}") error_detail = e.info.get('error', {}).get('root_cause', [{}])[0].get('reason', str(e)) return jsonify({ 'success': False, 'error': f"搜索查询错误: {error_detail}" }), 400 except Exception as e: logger.error(f"搜索时发生意外错误: {str(e)}", exc_info=True) return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/reports/', methods=['GET']) def get_report_by_id(object_id): """根据ID获取特定研究报告的详细信息""" try: # 排除embedding字段 response = es.get( index="research_reports", id=object_id, _source_excludes=["content_embedding"] ) if not response['found']: return jsonify({ 'success': False, 'error': f'Report with ID {object_id} not found' }), 404 source = response['_source'] # 确保object_id是字符串 source['object_id'] = str(source.get('object_id', '')) return jsonify({ 'success': True, 'data': source }) except RequestError as e: if e.status_code == 404: return jsonify({ 'success': False, 'error': f'Report with ID {object_id} not found' }), 404 logger.error(f"Elasticsearch错误: {e}") return jsonify({ 'success': False, 'error': str(e) }), 500 except Exception as e: logger.error(f"意外错误: {str(e)}", exc_info=True) return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/reports//similar', methods=['GET']) def get_similar_reports(object_id): """获取与指定研报相似的其他研报(基于向量相似度)""" try: size = int(request.args.get('size', 5)) # 首先获取原研报的向量 original = es.get( index="research_reports", id=object_id, _source=["content_embedding"] ) if not original['found']: return jsonify({ 'success': False, 'error': f'Report with ID {object_id} not found' }), 404 embedding = original['_source'].get('content_embedding') if not embedding: return jsonify({ 'success': False, 'error': 'Report does not have embedding data' }), 400 # 向量相似度搜索 query = { "query": { "script_score": { "query": { "bool": { "must_not": [ {"term": {"object_id": object_id}} # 排除自己 ] } }, "script": { "source": """ double similarity = cosineSimilarity(params.query_vector, 'content_embedding') + 1.0; return similarity; """, "params": { "query_vector": embedding } } } }, "size": size, "_source": { "excludes": ["content_embedding", "content"] # 不返回内容和向量 } } response = es.search(index="research_reports", body=query) similar_reports = [] for hit in response['hits']['hits']: source = hit['_source'] source['object_id'] = str(source.get('object_id', '')) source['similarity_score'] = hit['_score'] similar_reports.append(source) return jsonify({ 'success': True, 'data': { 'object_id': object_id, 'similar_reports': similar_reports, 'total': len(similar_reports) } }) except Exception as e: logger.error(f"获取相似研报失败: {str(e)}", exc_info=True) return jsonify({ 'success': False, 'error': str(e) }), 500 # ========== 辅助函数 ========== def process_sector_data_from_stats(sector_stats: List[Dict], stocks: List[Dict]) -> Dict: """从zt_daily_stats的sector_stats处理板块数据""" logger.info(f"process_sector_data_from_stats called with {len(sector_stats)} sectors and {len(stocks)} stocks") sector_data = {} # 创建股票代码到股票信息的映射 stock_map = {stock['scode']: stock for stock in stocks} logger.info(f"Created stock_map with {len(stock_map)} entries") # 处理每个板块 for idx, sector_info in enumerate(sector_stats): sector_name = sector_info['sector_name'] stock_codes = sector_info.get('stock_codes', []) logger.debug(f"Processing sector {idx}: {sector_name} with {len(stock_codes)} stocks") # 获取该板块的所有股票详情 sector_stocks = [] missing_stocks = [] for scode in stock_codes: if scode in stock_map: stock = stock_map[scode].copy() # 使用copy避免修改原始数据 # 格式化时间 if 'zt_time' in stock: try: zt_time = datetime.fromisoformat(stock['zt_time'].replace('Z', '+00:00')) stock['formatted_time'] = zt_time.strftime('%H:%M:%S') except: stock['formatted_time'] = '' else: stock['formatted_time'] = '' # 处理brief和summary stock['brief'] = stock.get('brief', '') if not stock['brief'] and stock.get('summary'): stock['brief'] = stock['summary'][:100] + '...' # 清理格式 try: stock['brief'] = stock['brief'].replace('\n', '
').replace('#', '').replace('*', '') stock['summary'] = stock.get('summary', '').replace('\n', '
').replace('#', '').replace('*', '') except: pass sector_stocks.append(stock) else: missing_stocks.append(scode) if missing_stocks: logger.warning(f"Sector {sector_name}: Missing stocks in stock_map: {missing_stocks}") # 按涨停时间排序 sector_stocks.sort(key=lambda x: x.get('zt_time', '')) sector_data[sector_name] = { 'count': sector_info['count'], 'stocks': sector_stocks } logger.debug(f"Sector {sector_name}: found {len(sector_stocks)}/{len(stock_codes)} stocks") # 按照特定顺序排序 sorted_items = [] announcement_item = None other_item = None normal_items = [] for sector_name, data in sector_data.items(): if sector_name == '公告': announcement_item = (sector_name, data) elif sector_name == '其他': other_item = (sector_name, data) else: normal_items.append((sector_name, data)) # 按数量排序普通板块 normal_items.sort(key=lambda x: x[1]['count'], reverse=True) # 组合最终顺序 if announcement_item: sorted_items.append(announcement_item) sorted_items.extend(normal_items) if other_item: sorted_items.append(other_item) result = dict(sorted_items) logger.info(f"process_sector_data_from_stats completed with {len(result)} sectors") return result def process_sector_data(stocks: List[Dict]) -> Dict: """处理板块数据""" sector_data = defaultdict(lambda: {"count": 0, "stocks": []}) # 统计板块出现次数 sector_counts = defaultdict(int) for stock in stocks: for sector in stock.get('core_sectors', []): sector_counts[sector] += 1 # 筛选小板块(出现次数少于2的) small_sectors = {sector for sector, count in sector_counts.items() if count < 2} # 分类股票 for stock in stocks: # 处理公告类 if stock.get('effective_announcements'): sector_data['公告']['count'] += 1 sector_data['公告']['stocks'].append(stock) continue # 获取有效板块 valid_sectors = [sector for sector in stock.get('core_sectors', []) if sector not in small_sectors] if valid_sectors: for sector in valid_sectors: sector_data[sector]['count'] += 1 sector_data[sector]['stocks'].append(stock) else: sector_data['其他']['count'] += 1 sector_data['其他']['stocks'].append(stock) # 排序 sorted_data = dict(sorted( sector_data.items(), key=lambda x: (x[0] != '公告', -x[1]['count']) )) return sorted_data def generate_chart_data(sector_data: Dict) -> Dict: """生成图表数据""" labels = [] counts = [] for sector, data in sector_data.items(): if sector not in ['其他', '公告']: labels.append(sector) counts.append(data['count']) return { 'labels': labels, 'counts': counts } def generate_summary(stocks: List[Dict], sector_data: Dict) -> Dict: """生成汇总信息""" announcement_stocks = sum(1 for stock in stocks if stock.get('effective_announcements')) # 计算涨停时间分布 time_distribution = { 'morning': 0, # 9:30-11:30 'midday': 0, # 11:30-13:00 'afternoon': 0 # 13:00-15:00 } for stock in stocks: if 'zt_time' in stock: try: zt_time = datetime.fromisoformat(stock['zt_time'].replace('Z', '+00:00')) hour = zt_time.hour if 9 <= hour < 11 or (hour == 11 and zt_time.minute <= 30): time_distribution['morning'] += 1 elif hour == 11 or hour == 12: time_distribution['midday'] += 1 else: time_distribution['afternoon'] += 1 except: pass # 找出最大板块 top_sector = None top_count = 0 for sector, data in sector_data.items(): if sector not in ['其他', '公告'] and data['count'] > top_count: top_sector = sector top_count = data['count'] return { 'total_stocks': len(stocks), 'announcement_stocks': announcement_stocks, 'top_sector': top_sector, 'top_sector_count': top_count, 'zt_time_distribution': time_distribution } def calculate_sector_relations(stocks: List[Dict]) -> Dict: """计算板块关联关系""" # 建立股票-板块映射 stock_sector_map = defaultdict(set) sector_stocks = defaultdict(set) for stock in stocks: scode = stock['scode'] for sector in stock.get('core_sectors', []): stock_sector_map[scode].add(sector) sector_stocks[sector].add(scode) # 计算板块关联 nodes = [] links = [] relation_map = defaultdict(int) # 创建节点 for sector, stock_codes in sector_stocks.items(): if sector not in ['其他', '公告']: nodes.append({ 'name': sector, 'value': len(stock_codes), 'symbolSize': min(50, max(10, len(stock_codes) * 3)) }) # 计算板块间共同股票数 sector_list = list(sector_stocks.keys()) for i in range(len(sector_list)): for j in range(i + 1, len(sector_list)): sector1, sector2 = sector_list[i], sector_list[j] if sector1 not in ['其他', '公告'] and sector2 not in ['其他', '公告']: common_stocks = sector_stocks[sector1] & sector_stocks[sector2] if len(common_stocks) > 0: links.append({ 'source': sector1, 'target': sector2, 'value': len(common_stocks) }) return { 'nodes': nodes, 'links': links } def calculate_sector_relations_top10(stocks: List[Dict]) -> Dict: """计算板块关联关系TOP10(用于横向柱状图)""" relations = defaultdict(int) stock_sector_map = defaultdict(set) # 建立股票-板块映射 for stock in stocks: scode = stock['scode'] sectors = stock.get('core_sectors', []) for sector in sectors: stock_sector_map[scode].add(sector) # 基于股票维度计算板块关联 for scode, sectors in stock_sector_map.items(): sector_list = list(sectors) # 计算板块两两组合 for i in range(len(sector_list)): for j in range(i + 1, len(sector_list)): pair = tuple(sorted([sector_list[i], sector_list[j]])) relations[pair] += 1 # 转换为数组并排序,取TOP10 sorted_relations = sorted(relations.items(), key=lambda x: x[1], reverse=True)[:10] return { 'labels': [f"{pair[0]} - {pair[1]}" for pair, _ in sorted_relations], 'counts': [count for _, count in sorted_relations] } def generate_wordcloud_from_stocks(date: str): """从股票数据生成词云(备用方案)""" try: # 获取当日所有股票的文本内容 query = { "query": { "term": {"date": date} }, "_source": ["brief", "summary", "core_sectors"], "size": 10000 } result = es.search(index="zt_stocks", body=query) # 统计词频 word_count = defaultdict(int) for hit in result['hits']['hits']: stock = hit['_source'] # 统计板块词频 for sector in stock.get('core_sectors', []): if sector not in ['其他', '公告']: word_count[sector] += 500 # 板块权重更高 # 这里可以添加对brief和summary的分词处理 # 但需要额外的中文分词库支持 # 转换为词云格式 word_freq = [ {'name': word, 'value': count} for word, count in word_count.items() ] # 按频率排序并限制数量 word_freq.sort(key=lambda x: x['value'], reverse=True) return jsonify({ 'success': True, 'data': word_freq[:100] }) except Exception as e: logger.error(f"Error generating wordcloud from stocks: {str(e)}") return jsonify({ 'success': False, 'error': str(e) }), 500 def update_cache(cache_key: str, cache_type: str, data: Any, ttl_hours: int = 24): """更新缓存""" try: doc = { 'cache_key': cache_key, 'cache_type': cache_type, 'data': data, 'created_at': datetime.now().isoformat(), 'updated_at': datetime.now().isoformat(), 'expires_at': (datetime.now() + timedelta(hours=ttl_hours)).isoformat() } # 使用upsert确保更新或创建 es.index( index='zt_cache', id=f"{cache_key}_{cache_type}", body=doc ) except Exception as e: logger.error(f"Error updating cache: {str(e)}") # 健康检查端点 @app.route('/api/v1/health', methods=['GET']) def health_check(): """健康检查""" try: # 检查ES连接 es_health = es.cluster.health() # 检查embedding服务 embedding_status = [] for url in EMBEDDING_BASE_URLS: try: client = openai.OpenAI(api_key="dummy", base_url=url, timeout=5) # 简单测试 client.embeddings.create(model=EMBEDDING_MODEL, input=["test"]) embedding_status.append({"url": url, "status": "healthy"}) except: embedding_status.append({"url": url, "status": "unhealthy"}) return jsonify({ 'success': True, 'status': 'healthy', 'elasticsearch': { 'status': es_health['status'], 'cluster_name': es_health['cluster_name'] }, 'embedding_services': embedding_status }) except Exception as e: return jsonify({ 'success': False, 'status': 'unhealthy', 'error': str(e) }), 503 # 错误处理 @app.errorhandler(404) def not_found(error): return jsonify({ 'success': False, 'error': 'Endpoint not found' }), 404 @app.errorhandler(500) def internal_error(error): return jsonify({ 'success': False, 'error': 'Internal server error' }), 500 if __name__ == '__main__': app.run(debug=True, host='0.0.0.0', port=8800)