"""
|
||
Flask API for ES-based Stock Analysis System with Research Reports Integration
|
||
支持前后端分离的涨停股票分析系统API + 研究报告搜索API
|
||
"""
|
||
|
||
from flask import Flask, request, jsonify
|
||
from flask_cors import CORS
|
||
from elasticsearch import Elasticsearch, RequestError
|
||
from datetime import datetime, timedelta
|
||
from collections import defaultdict
|
||
import json
|
||
from typing import Dict, List, Any, Optional
|
||
import logging
|
||
import openai
|
||
import time
|
||
from enum import Enum
|
||
|
||
app = Flask(__name__)
CORS(app)  # enable cross-origin requests for the decoupled front end

# Elasticsearch connection settings
ES_HOST = "http://192.168.1.231:9200"
es = Elasticsearch(
    [ES_HOST],
    timeout=60,
    retry_on_timeout=True,
    max_retries=3,
    verify_certs=False
)

# Embedding service settings
EMBEDDING_BASE_URLS = [
    "http://192.168.1.231:8000/v1"
]
EMBEDDING_MODEL = "qwen3-embedding-8b"
EMBEDDING_DIMENSION = 4096
current_embedding_index = 0

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# === Enums ===
class SearchMode(str, Enum):
    HYBRID = "hybrid"  # hybrid search (text + vector)
    TEXT = "text"      # text-only search
    VECTOR = "vector"  # vector-only search


# === Embedding helpers ===
def get_embedding_client():
    """Return an embedding client, rotating round-robin across the configured services."""
    global current_embedding_index
    url = EMBEDDING_BASE_URLS[current_embedding_index % len(EMBEDDING_BASE_URLS)]
    current_embedding_index += 1

    return openai.OpenAI(
        api_key="dummy",
        base_url=url,
        timeout=60,
    ), url


def get_embedding(text: str) -> Optional[List[float]]:
    """Generate an embedding vector for the given text."""
    if not text or len(text.strip()) == 0:
        return None

    try:
        client, service_url = get_embedding_client()

        # Cap the input length
        text = text[:8000] if len(text) > 8000 else text

        response = client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=[text]
        )

        logger.info(f"Using embedding service: {service_url}")
        return response.data[0].embedding
    except Exception as e:
        logger.error(f"Failed to generate embedding: {e}")
        return None
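
# Illustrative use (hypothetical query text): on success get_embedding returns a
# list of EMBEDDING_DIMENSION (4096) floats suitable for cosine-similarity
# scoring in ES; it returns None on empty input or service failure.
#   vec = get_embedding("光伏 逆变器 出口")
#   if vec is not None:
#       assert len(vec) == EMBEDDING_DIMENSION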


class StockAnalysisAPI:
    """Helpers for the stock-analysis API."""

    @staticmethod
    def format_date(date_str: str) -> str:
        """Format a date string: YYYYMMDD -> YYYY-MM-DD."""
        if len(date_str) == 8:
            return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:]}"
        return date_str

    @staticmethod
    def parse_date(date_str: str) -> str:
        """Parse a date string: YYYY-MM-DD -> YYYYMMDD."""
        if '-' in date_str:
            return date_str.replace('-', '')
        return date_str
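
# Round trip (illustrative):
#   StockAnalysisAPI.format_date("20240105")  -> "2024-01-05"
#   StockAnalysisAPI.parse_date("2024-01-05") -> "20240105"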


# ========== Stock-analysis API routes ==========

@app.route('/api/v1/dates/available', methods=['GET'])
def get_available_dates():
    """
    List every date that has data.

    Used to populate the calendar component.
    """
    try:
        # Aggregate per-date statistics from ES
        query = {
            "size": 0,
            "aggs": {
                "dates": {
                    "terms": {
                        "field": "date",
                        "size": 1000,
                        "order": {"_key": "desc"}
                    },
                    "aggs": {
                        "stock_count": {
                            "cardinality": {
                                "field": "scode"
                            }
                        }
                    }
                }
            }
        }

        result = es.search(index="zt_stocks", body=query)

        events = []
        for bucket in result['aggregations']['dates']['buckets']:
            date = bucket['key']
            count = bucket['doc_count']
            formatted_date = StockAnalysisAPI.format_date(date)

            events.append({
                'title': f'{count}只',
                'start': formatted_date,
                'end': formatted_date,
                'className': 'bg-gradient-primary',
                'allDay': True,
                'date': date,  # original YYYYMMDD format
                'count': count
            })

        return jsonify({
            'success': True,
            'events': events,
            'total': len(events)
        })

    except Exception as e:
        logger.error(f"Error getting available dates: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


@app.route('/api/v1/analysis/daily/<date>', methods=['GET'])
def get_daily_analysis(date):
    """
    Return the analysis data for a given date:
    sector breakdown, word-cloud data, chart data, and summary.
    """
    try:
        logger.info(f"========== Getting daily analysis for date: {date} ==========")

        # Try the cache first
        cache_query = {
            "query": {
                "bool": {
                    "must": [
                        {"term": {"cache_key": date}},
                        {"term": {"cache_type": "daily_analysis"}}
                    ]
                }
            }
        }

        try:
            cache_result = es.search(index="zt_cache", body=cache_query, size=1)
            logger.info(f"Cache query result: {cache_result['hits']['total']['value']} hits")

            if cache_result['hits']['total']['value'] > 0:
                cached_data = cache_result['hits']['hits'][0]['_source']
                expires_at = datetime.fromisoformat(cached_data['expires_at'].replace('Z', '+00:00'))

                if expires_at > datetime.now():
                    logger.info(f"Returning cached data for {date}")
                    return jsonify({
                        'success': True,
                        'data': cached_data['data'],
                        'from_cache': True
                    })
                else:
                    logger.info(f"Cache expired for {date}")
        except Exception as e:
            logger.warning(f"Cache query failed: {e}, continuing without cache")

        # Fetch the precomputed daily statistics
        stats_query = {
            "query": {
                "term": {"date": date}
            },
            "_source": ["sector_stats", "word_freq", "chart_data", "summary"]
        }

        logger.info(f"Querying zt_daily_stats with: {json.dumps(stats_query)}")
        stats_result = es.search(index="zt_daily_stats", body=stats_query, size=1)
        logger.info(f"Stats result total: {stats_result['hits']['total']['value']}")

        if stats_result['hits']['total']['value'] > 0:
            stats = stats_result['hits']['hits'][0]['_source']
            logger.info(f"Stats fields found: {list(stats.keys())}")
            logger.info(f"Sector_stats count: {len(stats.get('sector_stats', []))}")
            logger.info(f"Word_freq count: {len(stats.get('word_freq', []))}")
        else:
            stats = {}
            logger.warning(f"No stats found in zt_daily_stats for {date}")

        # Fetch every stock for the day
        stock_query = {
            "query": {
                "term": {"date": date}
            },
            "size": 10000,
            "sort": [{"zt_time": "asc"}],
            "_source": {
                "excludes": ["content_embedding"]
            }
        }

        logger.info("Querying zt_stocks...")
        stock_result = es.search(index="zt_stocks", body=stock_query)
        logger.info(f"Found {stock_result['hits']['total']['value']} stocks")

        stocks = []
        for hit in stock_result['hits']['hits']:
            stock = hit['_source']
            if 'content_embedding' in stock:
                del stock['content_embedding']
            stocks.append(stock)

        logger.info(f"Processed {len(stocks)} stocks")

        # Build the sector data
        if stats and 'sector_stats' in stats and len(stats['sector_stats']) > 0:
            logger.info("Using sector_stats from zt_daily_stats")
            sector_data = process_sector_data_from_stats(stats['sector_stats'], stocks)
            logger.info(f"Processed sector_data with {len(sector_data)} sectors")
        else:
            logger.info("No sector_stats found, generating from stocks")
            sector_data = process_sector_data(stocks)
            logger.info(f"Generated sector_data with {len(sector_data)} sectors")

        # Compute sector co-occurrence (for the TOP-10 horizontal bar chart)
        sector_relations_top10 = calculate_sector_relations_top10(stocks)
        logger.info(f"Calculated sector_relations_top10: {len(sector_relations_top10.get('labels', []))} items")

        # Assemble the response
        response_data = {
            'date': date,
            'formatted_date': StockAnalysisAPI.format_date(date),
            'total_stocks': len(stocks),
            'sector_data': sector_data,
            'chart_data': stats.get('chart_data', generate_chart_data(sector_data)),
            'word_freq_data': stats.get('word_freq', []),
            'sector_relations_top10': sector_relations_top10,
            'summary': stats.get('summary', generate_summary(stocks, sector_data))
        }

        logger.info("Response data prepared successfully")
        logger.info(
            f"Response summary: total_stocks={response_data['total_stocks']}, sectors={len(response_data['sector_data'])}")

        # Refresh the cache
        try:
            update_cache(date, 'daily_analysis', response_data)
            logger.info(f"Cache updated for {date}")
        except Exception as e:
            logger.warning(f"Failed to update cache: {e}")

        return jsonify({
            'success': True,
            'data': response_data,
            'from_cache': False
        })

    except Exception as e:
        logger.error(f"Error getting daily analysis for {date}: {str(e)}", exc_info=True)
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
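
# Example request (illustrative; host/port assume the dev defaults in app.run below):
#   curl 'http://localhost:8800/api/v1/analysis/daily/20240105'
# -> {"success": true, "data": {...}, "from_cache": false}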


@app.route('/api/v1/stocks/search', methods=['POST'])
def search_stocks():
    """
    Search stocks by code, name, sector, and other criteria.
    """
    try:
        data = request.json
        keyword = data.get('keyword', '')
        date = data.get('date')
        sectors = data.get('sectors', [])
        page = data.get('page', 1)
        page_size = data.get('page_size', 20)

        # Build the filter conditions
        must_conditions = []

        if date:
            must_conditions.append({"term": {"date": date}})

        if sectors:
            must_conditions.append({
                "terms": {"core_sectors": sectors}
            })

        should_conditions = []
        if keyword:
            should_conditions.extend([
                {"match": {"sname": keyword}},
                {"term": {"scode": keyword}},
                {"match": {"brief": keyword}},
                {"match": {"summary": keyword}}
            ])

        query = {
            "query": {
                "bool": {
                    "must": must_conditions,
                    "should": should_conditions if should_conditions else None,
                    "minimum_should_match": 1 if should_conditions else 0
                }
            },
            "from": (page - 1) * page_size,
            "size": page_size,
            "sort": [{"zt_time": "desc"}],
            "_source": {
                "excludes": ["content_embedding"]  # drop the vector field
            },
            "highlight": {
                "fields": {
                    "sname": {},
                    "brief": {},
                    "summary": {}
                }
            }
        }

        # Strip None values
        query["query"]["bool"] = {k: v for k, v in query["query"]["bool"].items() if v is not None}

        result = es.search(index="zt_stocks", body=query)

        stocks = []
        for hit in result['hits']['hits']:
            stock = hit['_source']
            if 'highlight' in hit:
                stock['highlight'] = hit['highlight']
            stocks.append(stock)

        return jsonify({
            'success': True,
            'data': {
                'stocks': stocks,
                'total': result['hits']['total']['value'],
                'page': page,
                'page_size': page_size,
                'total_pages': (result['hits']['total']['value'] + page_size - 1) // page_size
            }
        })

    except Exception as e:
        logger.error(f"Error searching stocks: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


@app.route('/api/v1/stocks/search/hybrid', methods=['POST'])
def hybrid_search_stocks():
    """
    Hybrid search over limit-up reasons, combining keyword search with
    vector-similarity search.

    Strategy:
    1. First try a strict match on stock name/code; on a hit, return only
       that stock's records.
    2. Otherwise fall back to the smart (text/vector) retrieval logic.
    3. Results are sorted by limit-up time, newest first.
    """
    try:
        data = request.json
        query_text = data.get('query', '').strip()
        mode = data.get('mode', 'hybrid')  # hybrid, text, vector
        date = data.get('date')
        date_range = data.get('date_range', {})  # {'start': 'YYYYMMDD', 'end': 'YYYYMMDD'}
        sectors = data.get('sectors', [])
        page = data.get('page', 1)
        page_size = data.get('page_size', 20)
        min_score = data.get('min_score')

        if not query_text:
            return jsonify({
                'success': False,
                'error': 'Query text is required'
            }), 400

        # Build the date filters
        date_filter = []
        if date:
            date_filter.append({"term": {"date": date}})
        elif date_range:
            range_query = {}
            if date_range.get('start'):
                range_query['gte'] = date_range['start']
            if date_range.get('end'):
                range_query['lte'] = date_range['end']
            if range_query:
                date_filter.append({"range": {"date": range_query}})

        if sectors:
            date_filter.append({"terms": {"core_sectors": sectors}})

        # ========== Step 1: try an exact match on stock name or code ==========
        exact_match_query = {
            "query": {
                "bool": {
                    "must": date_filter.copy() if date_filter else [],
                    "should": [
                        {"term": {"sname.keyword": {"value": query_text, "boost": 10.0}}},
                        {"term": {"scode": {"value": query_text, "boost": 10.0}}},
                        # Also allow partial name matches (e.g. a query of "比亚迪"
                        # matching names that contain it)
                        {"match_phrase": {"sname": {"query": query_text, "boost": 8.0}}}
                    ],
                    "minimum_should_match": 1
                }
            },
            "size": 1000,  # fetch plenty of candidates for the filtering below
            "_source": {
                "excludes": ["content_embedding"]
            }
        }

        exact_result = es.search(index="zt_stocks", body=exact_match_query)
        exact_hits = exact_result['hits']['hits']

        # Check whether any hit really matches a stock name or code
        matched_stock_codes = set()
        is_stock_search = False

        if exact_hits:
            for hit in exact_hits:
                stock = hit['_source']
                sname = stock.get('sname', '')
                scode = stock.get('scode', '')

                # Treat the query as a stock search if it equals, or is contained in, the name/code
                if query_text == sname or query_text == scode or query_text in sname:
                    matched_stock_codes.add(scode)
                    is_stock_search = True

        # ========== Step 2: pick a search strategy ==========
        if is_stock_search and matched_stock_codes:
            # The user searched for a specific stock: return only its limit-up records
            logger.info(f"Stock name/code search detected: {query_text} -> {matched_stock_codes}")

            final_query = {
                "query": {
                    "bool": {
                        "must": date_filter.copy() if date_filter else [],
                        "filter": [
                            {"terms": {"scode": list(matched_stock_codes)}}
                        ]
                    }
                },
                "from": (page - 1) * page_size,
                "size": page_size,
                "sort": [
                    {"date": {"order": "desc"}},    # newest date first
                    {"zt_time": {"order": "desc"}}  # then latest limit-up time
                ],
                "_source": {
                    "excludes": ["content_embedding"]
                }
            }

            search_mode_actual = "exact_stock"
        else:
            # Not a stock-name search: use the smart retrieval logic
            logger.info(f"Semantic search: {query_text}")

            must_conditions = date_filter.copy() if date_filter else []
            should_conditions = []

            # Text search clauses
            if mode in ['text', 'hybrid']:
                text_queries = [
                    {"match": {"brief": {"query": query_text, "boost": 2.0}}},
                    {"match": {"summary": {"query": query_text, "boost": 1.5}}},
                    {"match": {"sname": {"query": query_text, "boost": 1.0}}}
                ]
                should_conditions.extend(text_queries)

            # Vector search clause
            if mode in ['vector', 'hybrid']:
                query_vector = get_embedding(query_text)

                if query_vector:
                    vector_score_query = {
                        "script_score": {
                            "query": {"match_all": {}},
                            "script": {
                                "source": """
                                    if (!doc['content_embedding'].empty) {
                                        double similarity = cosineSimilarity(params.query_vector, 'content_embedding') + 1.0;
                                        return similarity * params.boost;
                                    }
                                    return 0;
                                """,
                                "params": {
                                    "query_vector": query_vector,
                                    "boost": 2.0 if mode == 'vector' else 1.0
                                }
                            }
                        }
                    }

                    if mode == 'vector':
                        final_query = {
                            "query": {
                                "bool": {
                                    "must": [vector_score_query] + must_conditions
                                }
                            }
                        }
                    else:
                        should_conditions.append(vector_score_query)
                elif mode == 'vector':
                    # A pure vector search cannot run without an embedding
                    return jsonify({
                        'success': False,
                        'error': 'Embedding service unavailable for vector search'
                    }), 503

            # Build the final query (non-vector modes)
            if mode != 'vector':
                final_query = {
                    "query": {
                        "bool": {
                            "must": must_conditions,
                            "should": should_conditions,
                            "minimum_should_match": 1 if should_conditions else 0
                        }
                    }
                }

            # Add paging, sorting, and highlighting
            final_query.update({
                "from": (page - 1) * page_size,
                "size": page_size,
                "sort": [
                    {"date": {"order": "desc"}},    # newest date first
                    {"zt_time": {"order": "desc"}}  # then latest limit-up time
                ],
                "_source": {
                    "excludes": ["content_embedding"]
                },
                "highlight": {
                    "fields": {
                        "brief": {"fragment_size": 150},
                        "summary": {"fragment_size": 200},
                        "sname": {}
                    }
                }
            })

            # Apply a minimum-score cutoff if requested (0 is a valid cutoff)
            if min_score is not None:
                final_query["min_score"] = min_score

            search_mode_actual = mode

        # ========== Step 3: run the search ==========
        start_time = time.time()
        result = es.search(index="zt_stocks", body=final_query)
        took_ms = int((time.time() - start_time) * 1000)

        # Post-process the hits
        stocks = []
        for hit in result['hits']['hits']:
            stock = hit['_source']
            stock['_score'] = hit['_score']

            # Attach highlights
            if 'highlight' in hit:
                stock['highlight'] = hit['highlight']

            # Format the limit-up time
            if 'zt_time' in stock:
                try:
                    zt_time = datetime.fromisoformat(stock['zt_time'].replace('Z', '+00:00'))
                    stock['formatted_time'] = zt_time.strftime('%H:%M:%S')
                except (ValueError, AttributeError):
                    pass

            stocks.append(stock)

        return jsonify({
            'success': True,
            'data': {
                'stocks': stocks,
                'total': result['hits']['total']['value'],
                'page': page,
                'page_size': page_size,
                'total_pages': (result['hits']['total']['value'] + page_size - 1) // page_size,
                'took_ms': took_ms,
                'search_mode': search_mode_actual,
                'is_stock_search': is_stock_search,
                'matched_stocks': list(matched_stock_codes) if is_stock_search else []
            }
        })

    except Exception as e:
        logger.error(f"Error in hybrid search: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
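
# Example request body (illustrative):
#   POST /api/v1/stocks/search/hybrid
#   {"query": "固态电池", "mode": "hybrid",
#    "date_range": {"start": "20240101", "end": "20240131"},
#    "page": 1, "page_size": 20}
# A query matching a stock name/code (e.g. "比亚迪") instead returns only that
# stock's limit-up records, with search_mode "exact_stock".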


@app.route('/api/limit-analyse/high-position-stocks', methods=['GET'])
def get_high_position_stocks():
    """
    Return "high-position" stocks: those with several consecutive limit-up
    days, which usually carry higher risk.

    :param date: date in YYYYMMDD format
    :return: list of high-position stocks plus summary statistics
    """
    try:
        date = request.args.get('date')
        if not date:
            return jsonify({
                'success': False,
                'error': 'Date parameter is required'
            }), 400

        # Fetch every limit-up stock for the day.
        # continuous_days is a keyword field formatted like "4天3板", so a range
        # query won't work; fetch everything and filter in Python instead.
        query = {
            "query": {
                "bool": {
                    "must": [
                        {"term": {"date": date}},
                        {"exists": {"field": "continuous_days"}}  # require the field
                    ]
                }
            },
            "size": 10000,
            "sort": [
                {"zt_time": {"order": "asc"}}  # earliest limit-up first
            ],
            "_source": {
                "excludes": ["content_embedding"]
            }
        }

        result = es.search(index="zt_stocks", body=query)

        stocks = []
        total_continuous_days = 0
        max_continuous_days = 0

        for hit in result['hits']['hits']:
            stock = hit['_source']

            # Parse the continuous_days field, e.g. "4天3板"
            continuous_days_str = stock.get('continuous_days', '')
            continuous_days = 0

            if continuous_days_str:
                # The usual format is "X天Y板"; extract X
                match = re.search(r'(\d+)天', continuous_days_str)
                if match:
                    continuous_days = int(match.group(1))
                else:
                    # Fall back to the first number found
                    numbers = re.findall(r'\d+', continuous_days_str)
                    if numbers:
                        continuous_days = int(numbers[0])

            # Keep only stocks with >= 2 consecutive limit-up days (high position)
            if continuous_days < 2:
                continue

            total_continuous_days += continuous_days
            max_continuous_days = max(max_continuous_days, continuous_days)

            # Shape the record for the front end
            stock_data = {
                'stock_code': stock.get('scode', ''),
                'stock_name': stock.get('sname', ''),
                'price': 0,  # no price field in the data; return 0, fetch elsewhere if needed
                'increase_rate': 10.0,  # a limit-up is normally +10%
                'continuous_limit_up': continuous_days,
                'continuous_days_str': continuous_days_str,  # raw string
                'industry': stock.get('core_sectors', [''])[0] if stock.get('core_sectors') else '其他',
                'turnover_rate': 0,  # no turnover-rate field in the data
                'brief': stock.get('brief', '').replace('<br>', ' ').replace('\n', ' '),
                'summary': stock.get('summary', ''),
                'zt_time': stock.get('zt_time', ''),
                'formatted_time': stock.get('formatted_time', ''),
                'core_sectors': stock.get('core_sectors', [])
            }

            # Format the limit-up time if missing
            if not stock_data['formatted_time'] and stock_data['zt_time']:
                try:
                    zt_time = datetime.fromisoformat(stock_data['zt_time'].replace('Z', '+00:00'))
                    stock_data['formatted_time'] = zt_time.strftime('%H:%M:%S')
                except (ValueError, AttributeError):
                    pass

            stocks.append(stock_data)

        # Sort by consecutive limit-up days, descending
        stocks.sort(key=lambda x: x['continuous_limit_up'], reverse=True)

        # Summary statistics
        total_count = len(stocks)
        avg_continuous_days = round(total_continuous_days / total_count, 1) if total_count > 0 else 0

        # Distribution by consecutive limit-up days
        distribution = {}
        for stock in stocks:
            days = stock['continuous_limit_up']
            if days not in distribution:
                distribution[days] = 0
            distribution[days] += 1

        response_data = {
            'stocks': stocks[:50],  # cap the response at 50 stocks
            'statistics': {
                'total_count': total_count,
                'avg_continuous_days': avg_continuous_days,
                'max_continuous_days': max_continuous_days,
                'distribution': distribution  # counts per consecutive-days bucket
            },
            'date': date,
            'formatted_date': StockAnalysisAPI.format_date(date)
        }

        logger.info(
            f"Found {total_count} high position stocks for date {date} (max continuous days: {max_continuous_days})")

        return jsonify({
            'success': True,
            'data': response_data
        })

    except Exception as e:
        logger.error(f"Error getting high position stocks: {str(e)}", exc_info=True)
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
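
# Illustrative parse (hypothetical value): "5天4板" yields continuous_days = 5;
# the label conventionally reads as "4 limit-ups within the last 5 trading days".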


@app.route('/api/v1/sectors/relations/<date>', methods=['GET'])
def get_sector_relations(date):
    """
    Return sector co-occurrence data for the sector-relation graph.
    """
    try:
        # Fetch every stock for the day
        query = {
            "query": {
                "term": {"date": date}
            },
            "size": 10000,
            "_source": {
                "includes": ["scode", "sname", "core_sectors"]  # only the fields we need
            }
        }

        result = es.search(index="zt_stocks", body=query)
        stocks = [hit['_source'] for hit in result['hits']['hits']]

        # Compute the sector relations
        relations = calculate_sector_relations(stocks)

        return jsonify({
            'success': True,
            'data': relations
        })

    except Exception as e:
        logger.error(f"Error getting sector relations for {date}: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


@app.route('/api/v1/stocks/detail/<scode>', methods=['GET'])
def get_stock_detail(scode):
    """
    Return detailed information for one stock.
    """
    try:
        date = request.args.get('date')

        query = {
            "query": {
                "bool": {
                    "must": [
                        {"term": {"scode": scode}}
                    ]
                }
            },
            "_source": {
                "excludes": ["content_embedding"]
            }
        }

        if date:
            query["query"]["bool"]["must"].append({"term": {"date": date}})

        result = es.search(index="zt_stocks", body=query, size=1)

        if result['hits']['total']['value'] == 0:
            return jsonify({
                'success': False,
                'error': 'Stock not found'
            }), 404

        stock = result['hits']['hits'][0]['_source']
        # Belt and braces: drop content_embedding if it slipped through
        if 'content_embedding' in stock:
            del stock['content_embedding']

        # Fetch the stock's limit-up history
        history_query = {
            "query": {
                "term": {"scode": scode}
            },
            "sort": [{"date": "desc"}],
            "size": 30,
            "_source": {
                "excludes": ["content_embedding"]
            }
        }

        history_result = es.search(index="zt_stocks", body=history_query)
        history = []
        for hit in history_result['hits']['hits']:
            h = hit['_source']
            if 'content_embedding' in h:
                del h['content_embedding']
            history.append(h)

        return jsonify({
            'success': True,
            'data': {
                'stock': stock,
                'history': history
            }
        })

    except Exception as e:
        logger.error(f"Error getting stock detail for {scode}: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


@app.route('/api/v1/analysis/trends', methods=['GET'])
def get_market_trends():
    """
    Market-trend analysis: sector rotation over the last N days.
    """
    try:
        days = int(request.args.get('days', 7))
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)

        # Aggregate stats over the date range
        query = {
            "query": {
                "range": {
                    "date": {
                        "gte": start_date.strftime('%Y%m%d'),
                        "lte": end_date.strftime('%Y%m%d')
                    }
                }
            },
            "size": 0,
            "aggs": {
                "by_date": {
                    "terms": {
                        "field": "date",
                        "size": days,
                        "order": {"_key": "asc"}
                    },
                    "aggs": {
                        "sectors": {
                            "terms": {
                                "field": "core_sectors",
                                "size": 20
                            }
                        }
                    }
                }
            }
        }

        result = es.search(index="zt_stocks", body=query)

        trends = []
        for bucket in result['aggregations']['by_date']['buckets']:
            date = bucket['key']
            sectors = [
                {
                    'name': sector_bucket['key'],
                    'count': sector_bucket['doc_count']
                }
                for sector_bucket in bucket['sectors']['buckets']
            ]

            trends.append({
                'date': date,
                'formatted_date': StockAnalysisAPI.format_date(date),
                'sectors': sectors,
                'total': bucket['doc_count']
            })

        return jsonify({
            'success': True,
            'data': {
                'trends': trends,
                'days': days
            }
        })

    except Exception as e:
        logger.error(f"Error getting market trends: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


@app.route('/api/v1/analysis/wordcloud/<date>', methods=['GET'])
def get_wordcloud_data(date):
    """
    Word-cloud data for the word-cloud chart component.

    :param date: date in YYYYMMDD format
    """
    try:
        # Fetch the precomputed word frequencies from zt_daily_stats
        query = {
            "query": {
                "term": {"date": date}
            },
            "_source": ["word_freq"]
        }

        result = es.search(index="zt_daily_stats", body=query, size=1)

        if result['hits']['total']['value'] > 0:
            word_freq = result['hits']['hits'][0]['_source'].get('word_freq', [])

            # Normalize into the shape echarts-wordcloud expects
            formatted_data = []
            for item in word_freq:
                formatted_data.append({
                    'name': item.get('name', ''),
                    'value': max(100, item.get('value', 100))  # enforce a floor of 100
                })

            return jsonify({
                'success': True,
                'data': formatted_data[:200]  # cap at 200 words
            })
        else:
            # No precomputed word cloud: build one from the stock data
            return generate_wordcloud_from_stocks(date)

    except Exception as e:
        logger.error(f"Error getting wordcloud data for {date}: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


@app.route('/api/v1/sectors/top-relations/<date>', methods=['GET'])
def get_sector_top_relations(date):
    """
    TOP-10 sector co-occurrence data for the horizontal bar chart.

    :param date: date in YYYYMMDD format
    """
    try:
        # Fetch every stock for the day
        query = {
            "query": {
                "term": {"date": date}
            },
            "_source": ["scode", "core_sectors"],
            "size": 10000
        }

        result = es.search(index="zt_stocks", body=query)
        stocks = [hit['_source'] for hit in result['hits']['hits']]

        # Compute the TOP-10 sector pairs
        relations_data = calculate_sector_relations_top10(stocks)

        return jsonify({
            'success': True,
            'data': relations_data
        })

    except Exception as e:
        logger.error(f"Error getting sector top relations for {date}: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


@app.route('/api/v1/stocks/batch-detail', methods=['POST'])
def get_stocks_batch_detail():
    """
    Fetch details for several stocks at once, used when a sector panel
    is expanded to show all of its stocks.
    """
    try:
        data = request.json
        stock_codes = data.get('codes', [])
        date = data.get('date')

        if not stock_codes:
            return jsonify({
                'success': False,
                'error': 'No stock codes provided'
            }), 400

        # Build the query
        must_conditions = [
            {"terms": {"scode": stock_codes}}
        ]

        if date:
            must_conditions.append({"term": {"date": date}})

        query = {
            "query": {
                "bool": {
                    "must": must_conditions
                }
            },
            "size": len(stock_codes),
            "sort": [{"zt_time": "asc"}],
            "_source": {
                "excludes": ["content_embedding"]
            }
        }

        result = es.search(index="zt_stocks", body=query)

        stocks = []
        for hit in result['hits']['hits']:
            stock = hit['_source']
            # Belt and braces: drop content_embedding if present
            if 'content_embedding' in stock:
                del stock['content_embedding']

            # Format the limit-up time
            if 'zt_time' in stock:
                try:
                    zt_time = datetime.fromisoformat(stock['zt_time'].replace('Z', '+00:00'))
                    stock['formatted_time'] = zt_time.strftime('%H:%M:%S')
                except (ValueError, AttributeError):
                    stock['formatted_time'] = ''

            # Clean up brief and summary for display
            stock['brief'] = stock.get('brief', '').replace('\n', '<br>').replace('#', '').replace('*', '')
            stock['summary'] = stock.get('summary', '').replace('\n', '<br>').replace('#', '').replace('*', '')

            stocks.append(stock)

        return jsonify({
            'success': True,
            'data': stocks
        })

    except Exception as e:
        logger.error(f"Error getting batch stock details: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


@app.route('/api/v1/init/data', methods=['GET'])
def initialize_data():
    """
    Bootstrap the page with everything it needs for the latest date,
    so the front end doesn't have to make several requests.
    """
    try:
        # Find the latest date
        date_query = {
            "size": 0,
            "aggs": {
                "latest_date": {
                    "terms": {
                        "field": "date",
                        "size": 1,
                        "order": {"_key": "desc"}
                    }
                }
            }
        }

        date_result = es.search(index="zt_stocks", body=date_query)

        if not date_result['aggregations']['latest_date']['buckets']:
            return jsonify({
                'success': False,
                'error': 'No data available'
            }), 404

        latest_date = date_result['aggregations']['latest_date']['buckets'][0]['key']

        # Full analysis data for that date
        analysis_response = get_daily_analysis(latest_date)
        analysis_data = json.loads(analysis_response.data)

        if not analysis_data['success']:
            return analysis_response

        # Available dates for the calendar
        dates_response = get_available_dates()
        dates_data = json.loads(dates_response.data)

        # Combine everything into one payload
        return jsonify({
            'success': True,
            'data': {
                'latest_date': latest_date,
                'formatted_date': StockAnalysisAPI.format_date(latest_date),
                'analysis': analysis_data['data'],
                'available_dates': dates_data['events']
            }
        })

    except Exception as e:
        logger.error(f"Error initializing data: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


@app.route('/api/v1/statistics/overview', methods=['GET'])
def get_statistics_overview():
    """
    Statistics overview: global totals, hottest sectors, consecutive
    limit-up distribution, and recent trends.
    """
    try:
        # Multi-facet aggregation
        query = {
            "size": 0,
            "aggs": {
                "total_stocks": {
                    "cardinality": {
                        "field": "scode"
                    }
                },
                "total_days": {
                    "cardinality": {
                        "field": "date"
                    }
                },
                "top_sectors": {
                    "terms": {
                        "field": "core_sectors",
                        "size": 10
                    }
                },
                "continuous_days_stats": {
                    "terms": {
                        "field": "continuous_days",
                        "size": 10
                    }
                },
                "recent_trends": {
                    "date_histogram": {
                        "field": "zt_time",
                        "calendar_interval": "day",
                        "min_doc_count": 0,
                        "extended_bounds": {
                            "min": "now-30d",
                            "max": "now"
                        }
                    }
                }
            }
        }

        result = es.search(index="zt_stocks", body=query)
        aggs = result['aggregations']

        overview = {
            'total_stocks': aggs['total_stocks']['value'],
            'total_days': aggs['total_days']['value'],
            'top_sectors': [
                {
                    'name': bucket['key'],
                    'count': bucket['doc_count']
                }
                for bucket in aggs['top_sectors']['buckets']
            ],
            'continuous_stats': [
                {
                    'days': bucket['key'],
                    'count': bucket['doc_count']
                }
                for bucket in aggs['continuous_days_stats']['buckets']
            ],
            'recent_trends': [
                {
                    'date': bucket['key_as_string'],
                    'count': bucket['doc_count']
                }
                for bucket in aggs['recent_trends']['buckets']
            ]
        }

        return jsonify({
            'success': True,
            'data': overview
        })

    except Exception as e:
        logger.error(f"Error getting statistics overview: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


# ========== Research-report search API routes ==========

@app.route('/search', methods=['GET'])
def search_reports():
    """
    Search research reports.

    Capabilities:
    - Text search over the content field
    - Vector (semantic) search via content_embedding
    - Hybrid search combining both
    - Exact matching via the exact_match parameter
    - Optional filters: security code, date range
    """
    try:
        # Read the parameters
        query = request.args.get('query', '')
        mode = request.args.get('mode', 'hybrid')  # hybrid, text, vector
        exact_match = request.args.get('exact_match', '0')  # 0: fuzzy, 1: exact
        security_code = request.args.get('security_code')
        start_date = request.args.get('start_date')
        end_date = request.args.get('end_date')
        size = int(request.args.get('size', 10))
        from_ = int(request.args.get('from', 0))
        min_score = request.args.get('min_score', type=float)

        if not query:
            return jsonify({
                'success': False,
                'error': 'Query parameter is required'
            }), 400

        start_time = time.time()

        # Build the query conditions
        must_conditions = []
        should_conditions = []

        # Security-code filter
        if security_code:
            must_conditions.append({"term": {"security_code": security_code}})

        # Date-range filter
        if start_date or end_date:
            date_range = {}
            if start_date:
                date_range["gte"] = start_date
            if end_date:
                date_range["lte"] = end_date
            must_conditions.append({"range": {"declare_date": date_range}})

        # Text search: pick the matching style from exact_match
        if mode in ['text', 'hybrid']:
            if exact_match == '1':
                # Exact mode: match_phrase for phrase-level precision
                text_queries = [
                    {
                        "match_phrase": {
                            "content": {
                                "query": query,
                                "boost": 2.0
                            }
                        }
                    },
                    {
                        "match_phrase": {
                            "report_title": {
                                "query": query,
                                "boost": 3.0
                            }
                        }
                    }
                ]

                # Security names and authors also get exact term matches
                if len(query) <= 10:  # a short query may be a name or an author
                    text_queries.extend([
                        {
                            "term": {
                                "security_name.keyword": {
                                    "value": query,
                                    "boost": 2.5
                                }
                            }
                        },
                        {
                            "term": {
                                "author.keyword": {
                                    "value": query,
                                    "boost": 2.0
                                }
                            }
                        },
                        {
                            "term": {
                                "publisher.keyword": {
                                    "value": query,
                                    "boost": 1.5
                                }
                            }
                        }
                    ])

                should_conditions.extend(text_queries)
            else:
                # Fuzzy mode (the original behaviour)
                text_queries = [
                    {
                        "match": {
                            "content": {
                                "query": query,
                                "operator": "or",
                                "boost": 1.0
                            }
                        }
                    },
                    {
                        "match": {
                            "report_title": {
                                "query": query,
                                "boost": 2.0
                            }
                        }
                    },
                    {
                        "match": {
                            "security_name": {
                                "query": query,
                                "boost": 1.5
                            }
                        }
                    },
                    {
                        "match": {
                            "author": {
                                "query": query,
                                "boost": 1.2
                            }
                        }
                    },
                    {
                        "match": {
                            "publisher": {
                                "query": query,
                                "boost": 1.0
                            }
                        }
                    }
                ]
                should_conditions.extend(text_queries)

        # Vector search
        if mode in ['vector', 'hybrid']:
            query_vector = get_embedding(query)

            if query_vector:
                vector_query = {
                    "script_score": {
                        "query": {"match_all": {}} if mode == 'vector' else {
                            "bool": {"must": must_conditions}},
                        "script": {
                            "source": """
                                double similarity = cosineSimilarity(params.query_vector, 'content_embedding') + 1.0;
                                return similarity;
                            """,
                            "params": {
                                "query_vector": query_vector
                            }
                        }
                    }
                }

                if mode == 'vector':
                    # Pure vector search
                    if must_conditions:
                        vector_query["query"] = {"bool": {"must": must_conditions}}

                    final_query = {
                        "query": vector_query,
                        "from": from_,
                        "size": size
                    }
                else:
                    # In hybrid mode the vector clause joins the should list
                    should_conditions.append(vector_query)
            elif mode == 'vector':
                # A pure vector search cannot run without an embedding
                return jsonify({
                    'success': False,
                    'error': 'Embedding service unavailable for vector search'
                }), 503

        # Final query for the non-pure-vector modes
        if mode != 'vector':
            final_query = {
                "query": {
                    "bool": {
                        "must": must_conditions,
                        "should": should_conditions,
                        "minimum_should_match": 1 if should_conditions else 0
                    }
                },
                "from": from_,
                "size": size
            }

        # Return every field except the embedding
        final_query["_source"] = {
            "excludes": ["content_embedding"]
        }

        # Highlighting
        if mode in ['text', 'hybrid']:
            final_query["highlight"] = {
                "fields": {
                    "content": {
                        "fragment_size": 200,
                        "number_of_fragments": 3,
                        "pre_tags": ["<mark>"],
                        "post_tags": ["</mark>"]
                    },
                    "report_title": {
                        "fragment_size": 100,
                        "number_of_fragments": 1,
                        "pre_tags": ["<mark>"],
                        "post_tags": ["</mark>"]
                    }
                }
            }

        # Minimum-score cutoff
        if min_score is not None:
            final_query["min_score"] = min_score

        logger.info(f"Running query: mode={mode}, exact_match={exact_match}, query={query[:50]}")
        logger.debug(f"Query detail: {json.dumps(final_query, ensure_ascii=False)[:500]}")

        # Run the search
        response = es.search(index="research_reports", body=final_query)

        # Shape the hits
        results = []
        for hit in response['hits']['hits']:
            source = hit['_source']

            result = {
                'object_id': str(source.get('object_id', '')),
                'report_title': source.get('report_title', ''),
                'security_code': source.get('security_code'),
                'security_name': source.get('security_name'),
                'author': source.get('author'),
                'rating': source.get('rating'),
                'publisher': source.get('publisher'),
                'declare_date': source.get('declare_date'),
                'content': source.get('content', ''),
                'content_url': source.get('content_url'),
                'processed_time': source.get('processed_time'),
                'score': hit['_score']
            }

            # Attach highlights
            if 'highlight' in hit:
                result['highlight'] = hit['highlight']

            results.append(result)

        # Elapsed time
        took_ms = int((time.time() - start_time) * 1000)

        return jsonify({
            'success': True,
            'data': {
                'total': response['hits']['total']['value'],
                'results': results,
                'took_ms': took_ms,
                'search_mode': mode,
                'exact_match': exact_match == '1'
            }
        })

    except RequestError as e:
        logger.error(f"Elasticsearch query failed: {e.info}")
        error_detail = e.info.get('error', {}).get('root_cause', [{}])[0].get('reason', str(e))
        return jsonify({
            'success': False,
            'error': f"Search query error: {error_detail}"
        }), 400
    except Exception as e:
        logger.error(f"Unexpected error during search: {str(e)}", exc_info=True)
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
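
# Example request (illustrative):
#   curl 'http://localhost:8800/search?query=锂电池&mode=hybrid&exact_match=0&size=5'
# -> {"success": true, "data": {"total": ..., "results": [...], "took_ms": ...}}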


@app.route('/reports/<object_id>', methods=['GET'])
def get_report_by_id(object_id):
    """Return the full details of one research report by ID."""
    try:
        # Exclude the embedding field
        response = es.get(
            index="research_reports",
            id=object_id,
            _source_excludes=["content_embedding"]
        )

        if not response['found']:
            return jsonify({
                'success': False,
                'error': f'Report with ID {object_id} not found'
            }), 404

        source = response['_source']
        # Make sure object_id serializes as a string
        source['object_id'] = str(source.get('object_id', ''))

        return jsonify({
            'success': True,
            'data': source
        })

    except RequestError as e:
        if e.status_code == 404:
            return jsonify({
                'success': False,
                'error': f'Report with ID {object_id} not found'
            }), 404
        logger.error(f"Elasticsearch error: {e}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


@app.route('/reports/<object_id>/similar', methods=['GET'])
def get_similar_reports(object_id):
    """Return reports similar to the given one, by vector similarity."""
    try:
        size = int(request.args.get('size', 5))

        # First fetch the source report's embedding
        original = es.get(
            index="research_reports",
            id=object_id,
            _source=["content_embedding"]
        )

        if not original['found']:
            return jsonify({
                'success': False,
                'error': f'Report with ID {object_id} not found'
            }), 404

        embedding = original['_source'].get('content_embedding')
        if not embedding:
            return jsonify({
                'success': False,
                'error': 'Report does not have embedding data'
            }), 400

        # Vector-similarity search
        query = {
            "query": {
                "script_score": {
                    "query": {
                        "bool": {
                            "must_not": [
                                {"term": {"object_id": object_id}}  # exclude the report itself
                            ]
                        }
                    },
                    "script": {
                        "source": """
                            double similarity = cosineSimilarity(params.query_vector, 'content_embedding') + 1.0;
                            return similarity;
                        """,
                        "params": {
                            "query_vector": embedding
                        }
                    }
                }
            },
            "size": size,
            "_source": {
                "excludes": ["content_embedding", "content"]  # skip the body and the vector
            }
        }

        response = es.search(index="research_reports", body=query)

        similar_reports = []
        for hit in response['hits']['hits']:
            source = hit['_source']
            source['object_id'] = str(source.get('object_id', ''))
            source['similarity_score'] = hit['_score']
            similar_reports.append(source)

        return jsonify({
            'success': True,
            'data': {
                'object_id': object_id,
                'similar_reports': similar_reports,
                'total': len(similar_reports)
            }
        })

    except Exception as e:
        logger.error(f"Failed to fetch similar reports: {str(e)}", exc_info=True)
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
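
# Note: cosineSimilarity returns values in [-1, 1]; the scripts above add 1.0
# because script_score must not produce negative scores, so similarity_score
# lands in [0, 2] (2.0 = identical direction).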


# ========== Helper functions ==========

def process_sector_data_from_stats(sector_stats: List[Dict], stocks: List[Dict]) -> Dict:
    """Build sector data from the sector_stats stored in zt_daily_stats."""
    logger.info(f"process_sector_data_from_stats called with {len(sector_stats)} sectors and {len(stocks)} stocks")

    sector_data = {}

    # Map stock code -> stock record
    stock_map = {stock['scode']: stock for stock in stocks}
    logger.info(f"Created stock_map with {len(stock_map)} entries")

    # Process each sector
    for idx, sector_info in enumerate(sector_stats):
        sector_name = sector_info['sector_name']
        stock_codes = sector_info.get('stock_codes', [])

        logger.debug(f"Processing sector {idx}: {sector_name} with {len(stock_codes)} stocks")

        # Collect full records for the sector's stocks
        sector_stocks = []
        missing_stocks = []

        for scode in stock_codes:
            if scode in stock_map:
                stock = stock_map[scode].copy()  # copy so the original record stays untouched

                # Format the limit-up time
                if 'zt_time' in stock:
                    try:
                        zt_time = datetime.fromisoformat(stock['zt_time'].replace('Z', '+00:00'))
                        stock['formatted_time'] = zt_time.strftime('%H:%M:%S')
                    except (ValueError, AttributeError):
                        stock['formatted_time'] = ''
                else:
                    stock['formatted_time'] = ''

                # Fall back to a truncated summary when brief is empty
                stock['brief'] = stock.get('brief', '')
                if not stock['brief'] and stock.get('summary'):
                    stock['brief'] = stock['summary'][:100] + '...'

                # Clean up markup
                try:
                    stock['brief'] = stock['brief'].replace('\n', '<br>').replace('#', '').replace('*', '')
                    stock['summary'] = stock.get('summary', '').replace('\n', '<br>').replace('#', '').replace('*', '')
                except AttributeError:
                    pass

                sector_stocks.append(stock)
            else:
                missing_stocks.append(scode)

        if missing_stocks:
            logger.warning(f"Sector {sector_name}: Missing stocks in stock_map: {missing_stocks}")

        # Sort by limit-up time
        sector_stocks.sort(key=lambda x: x.get('zt_time', ''))

        sector_data[sector_name] = {
            'count': sector_info['count'],
            'stocks': sector_stocks
        }

        logger.debug(f"Sector {sector_name}: found {len(sector_stocks)}/{len(stock_codes)} stocks")

    # Order the sectors: 公告 (announcements) first, 其他 (other) last,
    # the rest by stock count descending
    sorted_items = []
    announcement_item = None
    other_item = None
    normal_items = []

    for sector_name, data in sector_data.items():
        if sector_name == '公告':
            announcement_item = (sector_name, data)
        elif sector_name == '其他':
            other_item = (sector_name, data)
        else:
            normal_items.append((sector_name, data))

    # Sort the regular sectors by count
    normal_items.sort(key=lambda x: x[1]['count'], reverse=True)

    # Stitch the final order together
    if announcement_item:
        sorted_items.append(announcement_item)
    sorted_items.extend(normal_items)
    if other_item:
        sorted_items.append(other_item)

    result = dict(sorted_items)
    logger.info(f"process_sector_data_from_stats completed with {len(result)} sectors")

    return result


def process_sector_data(stocks: List[Dict]) -> Dict:
    """Build sector data directly from the stock records."""
    sector_data = defaultdict(lambda: {"count": 0, "stocks": []})

    # Count how often each sector appears
    sector_counts = defaultdict(int)
    for stock in stocks:
        for sector in stock.get('core_sectors', []):
            sector_counts[sector] += 1

    # Sectors that appear fewer than twice count as "small"
    small_sectors = {sector for sector, count in sector_counts.items() if count < 2}

    # Assign stocks to sectors
    for stock in stocks:
        # Announcement-driven stocks go to 公告
        if stock.get('effective_announcements'):
            sector_data['公告']['count'] += 1
            sector_data['公告']['stocks'].append(stock)
            continue

        # Keep only the non-small sectors
        valid_sectors = [sector for sector in stock.get('core_sectors', [])
                         if sector not in small_sectors]

        if valid_sectors:
            for sector in valid_sectors:
                sector_data[sector]['count'] += 1
                sector_data[sector]['stocks'].append(stock)
        else:
            sector_data['其他']['count'] += 1
            sector_data['其他']['stocks'].append(stock)

    # Sort: 公告 first, then by count descending
    sorted_data = dict(sorted(
        sector_data.items(),
        key=lambda x: (x[0] != '公告', -x[1]['count'])
    ))

    return sorted_data
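
# The sort key (x[0] != '公告', -count) works because False sorts before True,
# e.g. (illustrative): [('芯片', 5), ('公告', 2), ('光伏', 9)] orders as
# 公告(2), 光伏(9), 芯片(5).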


def generate_chart_data(sector_data: Dict) -> Dict:
    """Build bar-chart data from the sector breakdown."""
    labels = []
    counts = []

    for sector, data in sector_data.items():
        if sector not in ['其他', '公告']:
            labels.append(sector)
            counts.append(data['count'])

    return {
        'labels': labels,
        'counts': counts
    }


def generate_summary(stocks: List[Dict], sector_data: Dict) -> Dict:
    """Build the summary block for a day's data."""
    announcement_stocks = sum(1 for stock in stocks if stock.get('effective_announcements'))

    # Distribution of limit-up times across the trading day
    time_distribution = {
        'morning': 0,    # 9:30-11:30
        'midday': 0,     # 11:30-13:00
        'afternoon': 0   # 13:00-15:00
    }

    for stock in stocks:
        if 'zt_time' in stock:
            try:
                zt_time = datetime.fromisoformat(stock['zt_time'].replace('Z', '+00:00'))
                hour = zt_time.hour

                if 9 <= hour < 11 or (hour == 11 and zt_time.minute <= 30):
                    time_distribution['morning'] += 1
                elif hour == 11 or hour == 12:
                    time_distribution['midday'] += 1
                else:
                    time_distribution['afternoon'] += 1
            except (ValueError, AttributeError):
                pass

    # Find the largest sector
    top_sector = None
    top_count = 0
    for sector, data in sector_data.items():
        if sector not in ['其他', '公告'] and data['count'] > top_count:
            top_sector = sector
            top_count = data['count']

    return {
        'total_stocks': len(stocks),
        'announcement_stocks': announcement_stocks,
        'top_sector': top_sector,
        'top_sector_count': top_count,
        'zt_time_distribution': time_distribution
    }


def calculate_sector_relations(stocks: List[Dict]) -> Dict:
    """Compute sector co-occurrence for the relation graph."""
    # Stock -> sectors and sector -> stocks maps
    stock_sector_map = defaultdict(set)
    sector_stocks = defaultdict(set)

    for stock in stocks:
        scode = stock['scode']
        for sector in stock.get('core_sectors', []):
            stock_sector_map[scode].add(sector)
            sector_stocks[sector].add(scode)

    # Graph nodes and links
    nodes = []
    links = []

    # One node per sector
    for sector, stock_codes in sector_stocks.items():
        if sector not in ['其他', '公告']:
            nodes.append({
                'name': sector,
                'value': len(stock_codes),
                'symbolSize': min(50, max(10, len(stock_codes) * 3))
            })

    # One link per sector pair that shares at least one stock
    sector_list = list(sector_stocks.keys())
    for i in range(len(sector_list)):
        for j in range(i + 1, len(sector_list)):
            sector1, sector2 = sector_list[i], sector_list[j]
            if sector1 not in ['其他', '公告'] and sector2 not in ['其他', '公告']:
                common_stocks = sector_stocks[sector1] & sector_stocks[sector2]
                if len(common_stocks) > 0:
                    links.append({
                        'source': sector1,
                        'target': sector2,
                        'value': len(common_stocks)
                    })

    return {
        'nodes': nodes,
        'links': links
    }


def calculate_sector_relations_top10(stocks: List[Dict]) -> Dict:
    """Compute the TOP-10 sector pairs (for the horizontal bar chart)."""
    relations = defaultdict(int)
    stock_sector_map = defaultdict(set)

    # Stock -> sectors map
    for stock in stocks:
        scode = stock['scode']
        sectors = stock.get('core_sectors', [])
        for sector in sectors:
            stock_sector_map[scode].add(sector)

    # Count sector pairs per stock
    for scode, sectors in stock_sector_map.items():
        sector_list = list(sectors)
        # Every unordered pair of the stock's sectors
        for i in range(len(sector_list)):
            for j in range(i + 1, len(sector_list)):
                pair = tuple(sorted([sector_list[i], sector_list[j]]))
                relations[pair] += 1

    # Sort by count and keep the TOP 10
    sorted_relations = sorted(relations.items(), key=lambda x: x[1], reverse=True)[:10]

    return {
        'labels': [f"{pair[0]} - {pair[1]}" for pair, _ in sorted_relations],
        'counts': [count for _, count in sorted_relations]
    }
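
# Worked example (illustrative): one stock tagged ['锂电池', '储能'] and another
# tagged ['锂电池', '储能', '光伏'] yield pair counts
# {('储能', '锂电池'): 2, ('储能', '光伏'): 1, ('光伏', '锂电池'): 1}.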


def generate_wordcloud_from_stocks(date: str):
    """Fallback: build word-cloud data from the stock records."""
    try:
        # Fetch the day's text fields
        query = {
            "query": {
                "term": {"date": date}
            },
            "_source": ["brief", "summary", "core_sectors"],
            "size": 10000
        }

        result = es.search(index="zt_stocks", body=query)

        # Count term frequencies
        word_count = defaultdict(int)

        for hit in result['hits']['hits']:
            stock = hit['_source']

            # Sector tags get a high weight
            for sector in stock.get('core_sectors', []):
                if sector not in ['其他', '公告']:
                    word_count[sector] += 500

            # Tokenizing brief/summary could be added here, but that
            # needs a Chinese word-segmentation library.

        # Convert to the word-cloud format
        word_freq = [
            {'name': word, 'value': count}
            for word, count in word_count.items()
        ]

        # Sort by frequency and cap the size
        word_freq.sort(key=lambda x: x['value'], reverse=True)

        return jsonify({
            'success': True,
            'data': word_freq[:100]
        })

    except Exception as e:
        logger.error(f"Error generating wordcloud from stocks: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


def update_cache(cache_key: str, cache_type: str, data: Any, ttl_hours: int = 24):
    """Write (or overwrite) a cache entry."""
    try:
        doc = {
            'cache_key': cache_key,
            'cache_type': cache_type,
            'data': data,
            'created_at': datetime.now().isoformat(),
            'updated_at': datetime.now().isoformat(),
            'expires_at': (datetime.now() + timedelta(hours=ttl_hours)).isoformat()
        }

        # Indexing with a fixed id creates the document or replaces it in place
        es.index(
            index='zt_cache',
            id=f"{cache_key}_{cache_type}",
            body=doc
        )

    except Exception as e:
        logger.error(f"Error updating cache: {str(e)}")


# Health-check endpoint
@app.route('/api/v1/health', methods=['GET'])
def health_check():
    """Health check for ES and the embedding services."""
    try:
        # Check the ES connection
        es_health = es.cluster.health()

        # Probe each embedding service
        embedding_status = []
        for url in EMBEDDING_BASE_URLS:
            try:
                client = openai.OpenAI(api_key="dummy", base_url=url, timeout=5)
                # Minimal smoke test
                client.embeddings.create(model=EMBEDDING_MODEL, input=["test"])
                embedding_status.append({"url": url, "status": "healthy"})
            except Exception:
                embedding_status.append({"url": url, "status": "unhealthy"})

        return jsonify({
            'success': True,
            'status': 'healthy',
            'elasticsearch': {
                'status': es_health['status'],
                'cluster_name': es_health['cluster_name']
            },
            'embedding_services': embedding_status
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'status': 'unhealthy',
            'error': str(e)
        }), 503


# Error handlers
@app.errorhandler(404)
def not_found(error):
    return jsonify({
        'success': False,
        'error': 'Endpoint not found'
    }), 404


@app.errorhandler(500)
def internal_error(error):
    return jsonify({
        'success': False,
        'error': 'Internal server error'
    }), 500


if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=8800)