1562 lines
59 KiB
Python
1562 lines
59 KiB
Python
import asyncio
import json
import logging
import os
import re
import time
from contextlib import asynccontextmanager
from datetime import date, datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional, Union

import aiomysql
import openai
from elasticsearch import Elasticsearch, NotFoundError
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict, Field

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared clients: created by `lifespan` at startup and released at shutdown.
es_client = None
openai_client = None
mysql_pool = None

# Service configuration. Each value can be overridden with the environment
# variable named next to it; the literals are only fallbacks that keep the
# previous behavior when nothing is set.
ES_HOST = os.environ.get("CONCEPT_ES_HOST", "http://127.0.0.1:9200")
OPENAI_BASE_URL = os.environ.get("CONCEPT_OPENAI_BASE_URL", "http://127.0.0.1:8000/v1")
OPENAI_API_KEY = os.environ.get("CONCEPT_OPENAI_API_KEY", "dummy")
EMBEDDING_MODEL = os.environ.get("CONCEPT_EMBEDDING_MODEL", "qwen3-embedding-8b")
INDEX_NAME = os.environ.get("CONCEPT_INDEX_NAME", "concept_library")

# MySQL configuration: keyword arguments passed to aiomysql.create_pool().
# SECURITY NOTE: a root password is committed as the fallback below. Provide
# CONCEPT_MYSQL_PASSWORD in every deployment, then delete this literal and
# rotate the password.
MYSQL_CONFIG = {
    'host': os.environ.get('CONCEPT_MYSQL_HOST', '192.168.1.8'),
    'user': os.environ.get('CONCEPT_MYSQL_USER', 'root'),
    'password': os.environ.get('CONCEPT_MYSQL_PASSWORD', 'Zzl5588161!'),
    'db': os.environ.get('CONCEPT_MYSQL_DB', 'stock'),
    'charset': 'utf8mb4',
    'autocommit': True,
    'minsize': 1,
    'maxsize': 10,
}


# API lifespan management
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared backend clients on startup and release them on shutdown.

    Populates the module globals ``es_client``, ``openai_client`` and
    ``mysql_pool``. A MySQL connection failure is logged and tolerated:
    ``mysql_pool`` stays ``None`` and the price helpers return empty results.
    Cleanup runs in a ``finally`` block so that resources are released even
    when the application exits with an error.
    """
    global es_client, openai_client, mysql_pool

    # Elasticsearch client with explicit timeout / retry settings.
    es_client = Elasticsearch(
        [ES_HOST],
        timeout=30,
        max_retries=3,
        retry_on_timeout=True,
    )
    logger.info(f"Connected to Elasticsearch at {ES_HOST}")

    # OpenAI-compatible client used for embeddings.
    openai_client = openai.OpenAI(
        api_key=OPENAI_API_KEY,
        base_url=OPENAI_BASE_URL,
        timeout=60,
    )
    logger.info("Initialized OpenAI client")

    # MySQL pool (best effort: the API still serves searches without it).
    try:
        mysql_pool = await aiomysql.create_pool(**MYSQL_CONFIG)
        logger.info(f"Connected to MySQL at {MYSQL_CONFIG['host']}")
    except Exception as e:
        mysql_pool = None
        logger.error(f"Failed to connect to MySQL: {e}")

    try:
        yield
    finally:
        if es_client is not None:
            es_client.close()
        if openai_client is not None:
            # Releases the client's underlying HTTP connection pool.
            openai_client.close()
        if mysql_pool is not None:
            mysql_pool.close()
            await mysql_pool.wait_closed()
        es_client = openai_client = mysql_pool = None
        logger.info("Cleanup completed")


# FastAPI application; `lifespan` owns the shared backend clients.
app = FastAPI(
    lifespan=lifespan,
    title="概念搜索API",
    description="支持语义和关键词混合搜索的概念库API,包含概念涨跌幅数据",
    version="1.2.0",
)

# CORS: every origin, method and header is accepted, with credentials.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True lets
# any site make credentialed cross-origin requests (Starlette appears to echo
# the request Origin in that case) — confirm, and restrict the origin list
# before exposing this service publicly.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)


# Request and response models

class SearchRequest(BaseModel):
    # Body of POST /search.
    query: str = Field(default=..., description="搜索查询文本")

    # Pagination is applied locally to the re-sorted candidate set.
    size: int = Field(default=10, ge=1, le=100, description="每页返回结果数量")
    page: int = Field(default=1, ge=1, description="页码")
    search_size: int = Field(default=100, ge=10, le=1000, description="搜索数量(从ES获取的结果数),用于排序后分页")

    # Retrieval / ranking options.
    semantic_weight: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="语义搜索权重(0-1),None表示自动计算")
    filter_stocks: Optional[List[str]] = Field(default=None, description="过滤特定股票代码或名称")
    trade_date: Optional[date] = Field(default=None, description="交易日期,格式:YYYY-MM-DD,默认返回最新日期数据")
    sort_by: str = Field(default="change_pct", description="排序方式: change_pct, _score, stock_count, concept_name, added_date")
    use_knn: bool = Field(default=True, description="是否使用KNN搜索优化语义搜索")


class StockInfo(BaseModel):
    # One stock entry of a concept. `industry` / `project` use the Chinese keys
    # "行业" / "项目" of the Elasticsearch documents as aliases (used when the
    # model is serialized by alias, FastAPI's default for response models).
    # populate_by_name also allows constructing them by field name, which the
    # search endpoint does.
    # Pydantic v2 style configuration (replaces the deprecated inner `class Config`).
    model_config = ConfigDict(populate_by_name=True)

    stock_name: str
    stock_code: str
    reason: Optional[str] = None
    industry: Optional[str] = Field(None, alias="行业")
    project: Optional[str] = Field(None, alias="项目")


class ConceptPriceInfo(BaseModel):
    # Price statistics of one concept on one trading day (a concept_daily_stats row).
    trade_date: date
    avg_change_pct: Optional[float] = Field(default=None, description="平均涨跌幅(%)")


class ConceptResult(BaseModel):
    # One concept returned by POST /search.
    concept_id: str
    concept: str  # concept name
    # No default value: under Pydantic v2 the field must be supplied, but it may be None.
    description: Optional[str]
    stocks: List[StockInfo]
    stock_count: int  # number of stocks in the concept
    happened_times: Optional[List[str]] = None
    score: float  # Elasticsearch relevance score of the hit
    match_type: str  # "semantic", "keyword", "hybrid", "semantic_knn", "hybrid_knn"
    # Highlighted fragments keyed by field name.
    highlights: Optional[Dict[str, List[str]]] = None
    price_info: Optional[ConceptPriceInfo] = None


class SearchResponse(BaseModel):
    # Response of POST /search.
    total: int  # hit count reported by Elasticsearch
    took_ms: int
    results: List[ConceptResult]  # the requested page only
    search_info: Dict[str, Any]
    price_date: Optional[date] = None

    # Pagination metadata.
    page: int = Field(default=1, description="当前页码")
    total_pages: int = Field(default=1, description="总页数")


class StockConceptInfo(BaseModel):
    # One concept that contains the queried stock (GET /stock/{stock_code}/concepts).
    concept_id: str
    concept: str  # concept name
    stock_count: int  # number of stocks in the concept
    happened_times: Optional[List[str]] = None
    description: Optional[str] = None
    # The concept's raw entry for the queried stock, if one was found.
    stock_detail: Optional[Dict[str, Any]] = None
    price_info: Optional[ConceptPriceInfo] = None


class StockConceptsResponse(BaseModel):
    # Response of GET /stock/{stock_code}/concepts.
    stock_code: str  # the code (or name) given in the URL
    stats: Dict[str, Any]  # summary counters and category breakdown
    concepts: List[StockConceptInfo]
    price_date: Optional[date] = None  # trade date of the attached price data


class StockSearchResult(BaseModel):
    # One stock matched by GET /stock/search.
    stock_code: str
    stock_name: str
    # Filled from the aggregation bucket doc_count of the stock code.
    concept_count: int


class StockSearchResponse(BaseModel):
    # Response of GET /stock/search.
    total: int
    stocks: List[StockSearchResult]


class PriceTimeSeriesItem(BaseModel):
    # One trading day of a concept's price series.
    trade_date: date
    avg_change_pct: Optional[float] = Field(default=None, description="平均涨跌幅(%)")
    stock_count: Optional[int] = Field(default=None, description="当日股票数量")


class PriceTimeSeriesResponse(BaseModel):
    # Response of GET /concept/{concept_id}/price-timeseries.
    concept_id: str
    concept_name: str  # the handler falls back to concept_id when no name is known
    start_date: date
    end_date: date
    data_points: int  # number of items in `timeseries`
    timeseries: List[PriceTimeSeriesItem]


# Helper functions

async def get_concept_price_data(
    concept_ids: List[str], trade_date: Optional[date] = None
) -> Dict[str, ConceptPriceInfo]:
    """Look up daily change statistics for the given concepts.

    Args:
        concept_ids: Concept ids to read from ``concept_daily_stats``.
        trade_date: Trading day to read. When omitted, each concept's most
            recent row is used, so concepts may come from different days.

    Returns:
        Mapping of concept id to its price info; concepts without a row are
        absent. An empty dict is returned when MySQL is unavailable, no ids are
        given, or the query fails (the error is logged).
    """
    if not (mysql_pool and concept_ids):
        return {}

    marks = ','.join('%s' for _ in concept_ids)
    if trade_date:
        sql = f"""
            SELECT concept_id, concept_name, trade_date, avg_change_pct
            FROM concept_daily_stats
            WHERE concept_id IN ({marks}) AND trade_date = %s
        """
        args = (*concept_ids, trade_date)
    else:
        # Join each concept with its own latest trade_date.
        sql = f"""
            SELECT cds.concept_id, cds.concept_name, cds.trade_date, cds.avg_change_pct
            FROM concept_daily_stats cds
            INNER JOIN (
                SELECT concept_id, MAX(trade_date) as max_date
                FROM concept_daily_stats
                WHERE concept_id IN ({marks})
                GROUP BY concept_id
            ) latest ON cds.concept_id = latest.concept_id
                AND cds.trade_date = latest.max_date
        """
        args = concept_ids

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql, args)
                rows = await cursor.fetchall()
                # avg_change_pct is converted with float() (the driver value may be a Decimal).
                return {
                    row['concept_id']: ConceptPriceInfo(
                        trade_date=row['trade_date'],
                        avg_change_pct=None if row['avg_change_pct'] is None else float(row['avg_change_pct']),
                    )
                    for row in rows
                }
    except Exception as e:
        logger.error(f"Error fetching concept price data: {e}")
        return {}


async def get_latest_trade_date() -> Optional[date]:
    """Return the newest trade_date in ``concept_daily_stats``.

    Returns None when MySQL is unavailable, the table has no dates, or the
    query fails (the error is logged).
    """
    if not mysql_pool:
        return None

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("SELECT MAX(trade_date) as latest_date FROM concept_daily_stats")
                row = await cursor.fetchone()
    except Exception as e:
        logger.error(f"Error fetching latest trade date: {e}")
        return None

    if not row or not row[0]:
        return None
    return row[0]


def generate_embedding(text: str) -> List[float]:
    """Embed ``text`` with the configured OpenAI-compatible embedding model.

    Input longer than 16000 characters is truncated. An empty list is returned
    (with a logged warning) for blank input, when the client is not initialized,
    when the API returns an empty vector, or on any error; callers treat an
    empty list as "fall back to keyword search".
    """
    try:
        if not text or not text.strip():
            logger.warning("Empty text provided for embedding generation")
            return []

        text = text[:16000]  # no-op for shorter inputs

        if not openai_client:
            logger.warning("OpenAI client not initialized, falling back to keyword search")
            return []

        vector = openai_client.embeddings.create(model=EMBEDDING_MODEL, input=[text]).data[0].embedding
        if not vector:
            logger.warning("Empty embedding returned from OpenAI API")
            return []

        logger.debug(f"Successfully generated embedding for text: {text[:100]}...")
        return vector

    except openai.OpenAIError as e:
        logger.warning(f"OpenAI API error during embedding generation: {e}. Falling back to keyword search.")
        return []
    except Exception as e:
        logger.warning(f"Unexpected error during embedding generation: {e}. Falling back to keyword search.")
        return []


def calculate_semantic_weight(query: str) -> float:
    """Estimate the weight of the semantic (vector) score for ``query``.

    Long, sentence-like queries lean semantic; short queries and queries that
    contain a 6-digit stock code lean keyword. The weight starts at 0.3, adds
    the larger of a length factor and a whitespace-token factor, adds a
    pattern adjustment, and is clamped to [0.3, 0.9].

    Returns:
        The semantic weight; the keyword weight is ``1 - weight``.
    """
    base_weight = 0.3
    query_length = len(query)
    # Whitespace tokens only: unspaced Chinese text counts as one token, so the
    # length factor is what distinguishes long Chinese queries.
    word_count = len(query.split())

    if query_length < 10:
        length_factor = 0.0
    elif query_length < 50:
        length_factor = 0.2
    elif query_length < 200:
        length_factor = 0.4
    elif query_length < 500:
        length_factor = 0.5
    else:
        length_factor = 0.6

    if word_count < 3:
        word_factor = 0.0
    elif word_count < 10:
        word_factor = 0.2
    elif word_count < 30:
        word_factor = 0.3
    else:
        word_factor = 0.4

    # Digit lookarounds instead of \b: CJK characters are word characters, so
    # \b\d{6}\b misses codes written directly next to Chinese text ("茅台600519").
    if re.search(r'(?<!\d)\d{6}(?!\d)', query):
        pattern_factor = -0.2  # stock-code lookup -> favor keywords
    elif word_count <= 2 and query_length < 20:
        pattern_factor = -0.1  # short term-like query
    elif query_length > 100 or any(mark in query for mark in ('。', ',', '，')):
        pattern_factor = 0.2  # sentence-like query (ASCII or full-width punctuation)
    else:
        pattern_factor = 0.0

    semantic_weight = base_weight + max(length_factor, word_factor) + pattern_factor
    return max(0.3, min(0.9, semantic_weight))


def extract_stock_filter(query: str) -> tuple[str, List[str]]:
    """Split stock references out of a free-text query.

    Recognizes 6-digit stock codes and 2-6 character Chinese names directly
    followed by 股票/股份/集团/公司 (the suffix itself stays in the query).

    Returns:
        ``(cleaned_query, stock_patterns)``: the query with every recognized
        code/name removed and whitespace collapsed, and the recognized codes
        followed by the names, in order of appearance and without duplicates.
    """
    # Digit lookarounds instead of \b: CJK characters are word characters, so
    # \b would miss a code written right next to Chinese text ("茅台600519").
    codes = re.findall(r'(?<!\d)(\d{6})(?!\d)', query)
    names = re.findall(r'([\u4e00-\u9fa5]{2,6})(?:股票|股份|集团|公司)', query)

    # dict.fromkeys drops duplicates while keeping first-seen order.
    stock_patterns = list(dict.fromkeys(codes + names))

    cleaned_query = query
    for pattern in stock_patterns:
        cleaned_query = cleaned_query.replace(pattern, '')

    return ' '.join(cleaned_query.split()), stock_patterns


def build_keyword_query(query: str, stock_filters: Optional[List[str]] = None) -> Dict:
    """Build the Elasticsearch keyword (full-text) query.

    Args:
        query: Text matched with the ik_smart analyzer against the concept-level
            text fields and, through a nested query, against the per-stock
            fields. Blank text adds no text clause.
        stock_filters: Stock codes or names. When given, a concept must contain
            at least one stock whose name or code equals one of them.

    Returns:
        ``{"bool": {"must": [...]}}``. An empty ``must`` list matches every
        document.
    """
    must_queries = []

    if query.strip():
        must_queries.append({
            "bool": {
                "should": [
                    {
                        "multi_match": {
                            "query": query,
                            "fields": [
                                "concept^3",
                                "description^2",
                                "stocks_reason",
                                "stocks_full_text",
                            ],
                            "type": "best_fields",
                            "analyzer": "ik_smart",
                        }
                    },
                    # `stocks` is a nested field (this file queries it with
                    # `nested` everywhere). Nested sub-fields are indexed as
                    # hidden documents and never match from the root document,
                    # so they must be searched inside a nested query.
                    {
                        "nested": {
                            "path": "stocks",
                            "score_mode": "max",
                            "query": {
                                "multi_match": {
                                    "query": query,
                                    "fields": [
                                        "stocks.stock_name^2",
                                        "stocks.stock_code^2",
                                        "stocks.项目",
                                        "stocks.reason",
                                        "stocks.行业",
                                    ],
                                    "type": "best_fields",
                                    "analyzer": "ik_smart",
                                }
                            },
                        }
                    },
                ],
                "minimum_should_match": 1,
            }
        })

    if stock_filters:
        must_queries.append({
            "nested": {
                "path": "stocks",
                "query": {
                    "bool": {
                        "should": [
                            {"terms": {"stocks.stock_name": stock_filters}},
                            {"terms": {"stocks.stock_code": stock_filters}},
                        ]
                    }
                },
            }
        })

    return {"bool": {"must": must_queries}}


def build_semantic_query(embedding: List[float]) -> Dict:
    """Score every document by cosine similarity between ``embedding`` and ``description_embedding``.

    The +1.0 shift maps cosineSimilarity's [-1, 1] range onto [0, 2], keeping
    the script score non-negative as Elasticsearch's script_score requires.
    """
    script = {
        "source": "cosineSimilarity(params.query_vector, 'description_embedding') + 1.0",
        "params": {"query_vector": embedding},
    }
    return {"script_score": {"query": {"match_all": {}}, "script": script}}


def build_hybrid_query(query: str, embedding: List[float], semantic_weight: float,
                       stock_filters: Optional[List[str]] = None) -> Dict:
    """Combine keyword and vector scoring in one ``bool.should`` query (non-KNN fallback).

    Args:
        query: Text for the keyword clause.
        embedding: Query vector; when empty, the semantic clause is omitted.
        semantic_weight: Weight in [0, 1] of the semantic clause. The keyword
            clause is boosted by ``1 - semantic_weight`` and omitted at 0.
        stock_filters: Optional stock codes/names forwarded to the keyword clause.

    Returns:
        ``{"bool": {"should": [...], "minimum_should_match": 1}}``.
    """
    keyword_weight = 1.0 - semantic_weight
    clauses = []

    if keyword_weight > 0:
        clauses.append({
            "bool": {
                "must": [build_keyword_query(query, stock_filters)],
                "boost": keyword_weight,
            }
        })

    if semantic_weight > 0 and embedding:
        clauses.append({
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    # Negative similarities are clamped to 0 because
                    # Elasticsearch rejects negative script scores, which would
                    # otherwise fail the whole search. The weight is a script
                    # parameter (not part of the source) so a single compiled
                    # script serves every weight instead of one compilation per
                    # distinct value.
                    "source": "Math.max(cosineSimilarity(params.query_vector, 'description_embedding'), 0.0) * params.weight",
                    "params": {
                        "query_vector": embedding,
                        "weight": semantic_weight,
                    },
                },
                "boost": 1.0,
            }
        })

    return {"bool": {"should": clauses, "minimum_should_match": 1}}


def build_hybrid_knn_query(
    query: str,
    embedding: List[float],
    semantic_weight: float,
    stock_filters: List[str] = None,
    k: int = 100
) -> Dict:
    """Build a search body that combines approximate KNN with a keyword query.

    Elasticsearch sums the KNN score (boosted by ``semantic_weight``) and the
    keyword-query score (boosted by ``1 - semantic_weight``). When
    ``stock_filters`` is given it is applied both as a KNN pre-filter and inside
    the keyword query.

    Returns:
        A dict with ``knn`` and ``query`` entries; the caller adds ``size`` etc.
    """
    knn_clause = {
        "field": "description_embedding",
        "query_vector": embedding,
        "k": k,
        # num_candidates must exceed k: at least k + 50, otherwise 2k (capped at 10000).
        "num_candidates": max(k + 50, min(k * 2, 10000)),
        "boost": semantic_weight,
    }

    if stock_filters:
        knn_clause["filter"] = {
            "nested": {
                "path": "stocks",
                "query": {
                    "bool": {
                        "should": [
                            {"terms": {"stocks.stock_name": stock_filters}},
                            {"terms": {"stocks.stock_code": stock_filters}},
                        ]
                    }
                },
            }
        }

    return {
        "knn": knn_clause,
        "query": {
            "bool": {
                "must": [build_keyword_query(query, stock_filters)],
                "boost": 1.0 - semantic_weight,
            }
        },
    }


# API endpoints

@app.get("/", tags=["Health"])
async def root():
    """Health check: report that the service is up, with its name and version."""
    payload = {"status": "healthy", "service": "概念搜索API", "version": "1.2.0"}
    return payload


def _build_stock_nested_filter(stock_filters: List[str]) -> Dict:
    """Nested query: the concept contains a stock whose name or code is in ``stock_filters``."""
    return {
        "nested": {
            "path": "stocks",
            "query": {
                "bool": {
                    "should": [
                        {"terms": {"stocks.stock_name": stock_filters}},
                        {"terms": {"stocks.stock_code": stock_filters}},
                    ]
                }
            },
        }
    }


def _build_search_body(request: SearchRequest, cleaned_query: str, stock_filters: List[str],
                       embedding: List[float], semantic_weight: float,
                       search_size: int) -> tuple[Dict, str]:
    """Pick the retrieval strategy and return ``(search_body, match_type)``.

    weight 0 -> keyword; weight 1 with KNN -> pure KNN; KNN enabled -> KNN +
    keyword hybrid; otherwise -> script_score hybrid.
    """
    text = cleaned_query or request.query

    if semantic_weight == 0:
        return {"query": build_keyword_query(text, stock_filters), "size": search_size}, "keyword"

    if semantic_weight == 1.0 and request.use_knn and embedding:
        knn = {
            "field": "description_embedding",
            "query_vector": embedding,
            "k": search_size,
            # num_candidates must exceed k.
            "num_candidates": max(search_size + 50, min(search_size * 2, 10000)),
        }
        if stock_filters:
            knn["filter"] = _build_stock_nested_filter(stock_filters)
        return {"knn": knn, "size": search_size}, "semantic_knn"

    if request.use_knn and embedding:
        body = build_hybrid_knn_query(text, embedding, semantic_weight, stock_filters, k=search_size)
        body["size"] = search_size
        return body, "hybrid_knn"

    es_query = build_hybrid_query(text, embedding, semantic_weight, stock_filters)
    return {"query": es_query, "size": search_size}, "hybrid"


def _hit_to_result(hit: Dict, match_type: str) -> ConceptResult:
    """Convert one Elasticsearch hit into a ConceptResult (first 20 stocks, no price info yet)."""
    source = hit['_source']
    all_stocks = source.get('stocks', [])
    stocks = [
        StockInfo(
            # `or ''` also covers keys that are present but null.
            stock_name=stock.get('stock_name') or '',
            stock_code=stock.get('stock_code') or '',
            reason=stock.get('reason'),
            industry=stock.get('行业'),
            project=stock.get('项目'),
        )
        for stock in all_stocks[:20]
    ]
    return ConceptResult(
        concept_id=source['concept_id'],
        concept=source['concept'],
        description=source.get('description'),
        stocks=stocks,
        stock_count=len(all_stocks),
        happened_times=source.get('happened_times'),
        score=hit.get('_score') or 0.0,
        match_type=match_type,
        highlights=hit.get('highlight', {}),
        price_info=None,
    )


def _sort_results(results: List[ConceptResult], sort_by: str) -> None:
    """Sort ``results`` in place; unknown values (e.g. "_score") keep Elasticsearch's order."""
    if sort_by == "change_pct":
        # Largest change first; concepts without price data go last.
        results.sort(
            key=lambda r: (
                r.price_info.avg_change_pct
                if r.price_info and r.price_info.avg_change_pct is not None
                else float('-inf')
            ),
            reverse=True,
        )
    elif sort_by == "stock_count":
        results.sort(key=lambda r: r.stock_count, reverse=True)
    elif sort_by == "concept_name":
        results.sort(key=lambda r: r.concept)
    elif sort_by == "added_date":
        # Newest first, using the first entry of happened_times.
        results.sort(
            key=lambda r: r.happened_times[0] if r.happened_times else '1900-01-01',
            reverse=True,
        )


@app.post("/search", response_model=SearchResponse, tags=["Search"])
async def search_concepts(request: SearchRequest):
    """
    Search the concept library with keyword, KNN-semantic or hybrid retrieval.

    - Semantic search uses Elasticsearch KNN when ``use_knn`` is true.
    - ``search_size`` is the number of hits fetched from Elasticsearch before
      local re-sorting and pagination; it is expanded (x10, at most 1000) for
      the ``change_pct`` and ``added_date`` orders, which are applied locally.
    - Hybrid search combines KNN and keyword scores.
    """
    start = time.perf_counter()

    try:
        cleaned_query, stock_filters = extract_stock_filter(request.query)
        if request.filter_stocks:
            stock_filters.extend(request.filter_stocks)

        if request.semantic_weight is not None:
            semantic_weight = request.semantic_weight
        else:
            semantic_weight = calculate_semantic_weight(request.query)

        embedding: List[float] = []
        if semantic_weight > 0:
            # The OpenAI client is synchronous (60 s timeout); run it in a worker
            # thread so a slow embedding call does not block the event loop.
            embedding = await asyncio.to_thread(generate_embedding, request.query)
            if not embedding:
                # generate_embedding already logged the reason; fall back to keywords.
                semantic_weight = 0

        effective_search_size = request.search_size
        if request.sort_by in ("change_pct", "added_date"):
            # These orders are applied after retrieval, so fetch a wider candidate set.
            effective_search_size = min(1000, request.search_size * 10)
            logger.info(f"Using expanded search size {effective_search_size} for {request.sort_by} sorting")

        search_body, match_type = _build_search_body(
            request, cleaned_query, stock_filters, embedding, semantic_weight, effective_search_size
        )
        search_body.update({
            "highlight": {
                "fields": {
                    "concept": {},
                    "description": {"fragment_size": 150},
                    "stocks_reason": {"fragment_size": 150},
                    "stocks_full_text": {"fragment_size": 150},
                }
            },
            "_source": {"excludes": ["description_embedding"]},
            "track_total_hits": True,
        })

        # The Elasticsearch client is synchronous as well.
        es_response = await asyncio.to_thread(
            es_client.search, index=INDEX_NAME, body=search_body, timeout="30s"
        )

        all_results = [_hit_to_result(hit, match_type) for hit in es_response['hits']['hits']]

        # Price data is attached whatever the sort order, so clients can always show it.
        actual_price_date = None
        if all_results:
            price_data = await get_concept_price_data(
                [result.concept_id for result in all_results], request.trade_date
            )
            if price_data:
                # Without trade_date each concept uses its own latest day; report the newest.
                actual_price_date = max(info.trade_date for info in price_data.values())
            for result in all_results:
                result.price_info = price_data.get(result.concept_id)

        _sort_results(all_results, request.sort_by)

        # Local pagination over the (re-sorted) candidates.
        total_results = len(all_results)
        total_pages = max(1, (total_results + request.size - 1) // request.size)
        current_page = min(request.page, total_pages)
        start_idx = (current_page - 1) * request.size
        page_results = all_results[start_idx:start_idx + request.size]

        took_ms = int((time.perf_counter() - start) * 1000)

        return SearchResponse(
            total=es_response['hits']['total']['value'],
            took_ms=took_ms,
            results=page_results,
            search_info={
                "query": request.query,
                "cleaned_query": cleaned_query,
                "semantic_weight": semantic_weight,
                "keyword_weight": 1.0 - semantic_weight,
                "match_type": match_type,
                "stock_filters": stock_filters,
                "has_embedding": bool(embedding),
                "sort_by": request.sort_by,
                "search_size": request.search_size,
                "effective_search_size": effective_search_size,
                "use_knn": request.use_knn,
                "actual_results": total_results,
            },
            price_date=actual_price_date,
            page=current_page,
            total_pages=total_pages,
        )

    except Exception as e:
        logger.error(f"Search error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/concept/{concept_id}", tags=["Concepts"])
async def get_concept(
    concept_id: str,
    trade_date: Optional[date] = Query(None, description="交易日期,格式:YYYY-MM-DD,默认返回最新日期数据")
):
    """Return one concept document by id, together with its price information.

    ``stocks_reason`` is returned decoded: a JSON string is parsed, an
    already-structured value is passed through, and anything else becomes ``{}``.

    Raises:
        HTTPException: 404 when Elasticsearch reports the concept as missing,
            500 on any other error.
    """
    try:
        # Synchronous client call; keep it off the event loop.
        result = await asyncio.to_thread(es_client.get, index=INDEX_NAME, id=concept_id)
        source = result['_source']

        raw_reason = source.get('stocks_reason')
        stocks_reason = {}
        if isinstance(raw_reason, (dict, list)):
            stocks_reason = raw_reason
        elif raw_reason:
            try:
                stocks_reason = json.loads(raw_reason)
            except (TypeError, ValueError):
                logger.warning(f"Invalid stocks_reason JSON for concept {concept_id}")

        price_data = await get_concept_price_data([concept_id], trade_date)

        return {
            "concept_id": source['concept_id'],
            "concept": source['concept'],
            "description": source.get('description'),
            "stocks": source.get('stocks', []),
            "stocks_reason": stocks_reason,
            "happened_times": source.get('happened_times'),
            "created_at": source.get('created_at'),
            "price_info": price_data.get(concept_id),
        }
    except NotFoundError:
        raise HTTPException(status_code=404, detail="概念不存在")
    except Exception as e:
        logger.error(f"Get concept error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


def _categorize_concept(concept_name: str) -> str:
    """Map a concept name to a coarse sector bucket by keyword; the first match wins."""
    if '新能源' in concept_name:
        return '新能源'
    if '半导体' in concept_name or '芯片' in concept_name:
        return '半导体'
    if '人工智能' in concept_name or 'AI' in concept_name:
        return '人工智能'
    if '医' in concept_name:
        return '医药'
    if '金融' in concept_name or '银行' in concept_name:
        return '金融'
    return '其他'


@app.get("/stock/{stock_code}/concepts", response_model=StockConceptsResponse, tags=["Stocks"])
async def get_stock_concepts(
    stock_code: str,
    size: int = Query(50, ge=1, le=200, description="返回概念数量"),
    sort_by: str = Query("stock_count", description="排序方式: stock_count, concept_name, recent"),
    include_description: bool = Query(True, description="是否包含概念描述"),
    trade_date: Optional[date] = Query(None, description="交易日期,格式:YYYY-MM-DD,默认返回最新日期数据")
):
    """List the concepts that contain a stock (matched by code or by exact name), with price data.

    Unknown ``sort_by`` values fall back to relevance order.

    Raises:
        HTTPException: 500 on any error.
    """
    try:
        query = {
            "nested": {
                "path": "stocks",
                "query": {
                    "bool": {
                        "should": [
                            {"term": {"stocks.stock_code": stock_code}},
                            {"term": {"stocks.stock_name": stock_code}},
                        ]
                    }
                },
            }
        }

        if sort_by == "stock_count":
            sort_rules = [{
                "_script": {
                    "type": "number",
                    "script": {"source": "params._source.stocks.size()"},
                    "order": "desc",
                }
            }]
        elif sort_by == "concept_name":
            sort_rules = [{"concept.keyword": {"order": "asc"}}]
        elif sort_by == "recent":
            sort_rules = [{"happened_times": {"order": "desc", "missing": "_last"}}]
        else:
            sort_rules = [{"_score": {"order": "desc"}}]

        # Fetch the description only when it will be returned.
        source_fields = ["concept_id", "concept", "stocks", "happened_times", "created_at"]
        if include_description:
            source_fields.append("description")

        search_body = {
            "query": query,
            "size": size,
            "sort": sort_rules,
            "_source": {
                "includes": source_fields,
                "excludes": ["description_embedding"],
            },
        }

        # Synchronous client call; keep it off the event loop.
        es_response = await asyncio.to_thread(es_client.search, index=INDEX_NAME, body=search_body)
        hits = es_response['hits']['hits']

        price_data = await get_concept_price_data([hit['_source']['concept_id'] for hit in hits], trade_date)
        # Without trade_date each concept uses its own latest day; report the newest.
        actual_price_date = max((info.trade_date for info in price_data.values()), default=None)

        concepts = []
        stock_info = None
        for hit in hits:
            source = hit['_source']
            stocks = source.get('stocks', [])

            # This concept's entry for the requested stock, if present.
            stock_detail = next(
                (s for s in stocks if s.get('stock_code') == stock_code or s.get('stock_name') == stock_code),
                None,
            )
            if stock_detail is not None and stock_info is None:
                stock_info = {
                    "stock_code": stock_detail.get('stock_code'),
                    "stock_name": stock_detail.get('stock_name'),
                }

            concepts.append(StockConceptInfo(
                concept_id=source['concept_id'],
                concept=source['concept'],
                stock_count=len(stocks),
                happened_times=source.get('happened_times'),
                description=source.get('description') if include_description else None,
                stock_detail=stock_detail,
                price_info=price_data.get(source['concept_id']),
            ))

        concept_categories: Dict[str, int] = {}
        for concept in concepts:
            category = _categorize_concept(concept.concept)
            concept_categories[category] = concept_categories.get(category, 0) + 1

        stats = {
            "total_concepts": es_response['hits']['total']['value'],
            "returned_concepts": len(concepts),
            "stock_info": stock_info,
            "concept_categories": concept_categories,
        }

        return StockConceptsResponse(
            stock_code=stock_code,
            stats=stats,
            concepts=concepts,
            price_date=actual_price_date,
        )

    except Exception as e:
        logger.error(f"Get stock concepts error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


def _escape_wildcard(text: str) -> str:
    """Escape the wildcard-query metacharacters (backslash, ``*``, ``?``) in ``text``."""
    return text.replace('\\', '\\\\').replace('*', '\\*').replace('?', '\\?')


@app.get("/stock/search", response_model=StockSearchResponse, tags=["Stocks"])
async def search_stocks(
    keyword: str = Query(..., description="股票关键词"),
    size: int = Query(20, ge=1, le=100, description="返回数量")
):
    """Search stocks whose code contains ``keyword`` or whose name matches it.

    Results are de-duplicated by stock code and ordered by the number of
    concepts listing the stock.

    Raises:
        HTTPException: 500 on any error.
    """
    try:
        # One stock-level condition serves both to select concepts and, inside
        # the nested aggregation, to select the stock entries that get bucketed.
        # Without the aggregation filter every stock of a matched concept would
        # be bucketed, and unrelated popular stocks could crowd the matching
        # ones out of the top `size * 2` buckets.
        stock_match = {
            "bool": {
                "should": [
                    {"wildcard": {"stocks.stock_code": f"*{_escape_wildcard(keyword)}*"}},
                    {"match": {"stocks.stock_name": keyword}},
                ]
            }
        }

        search_body = {
            "query": {"nested": {"path": "stocks", "query": stock_match}},
            "size": 0,
            "aggs": {
                "unique_stocks": {
                    "nested": {"path": "stocks"},
                    "aggs": {
                        "matching": {
                            "filter": stock_match,
                            "aggs": {
                                "stock_codes": {
                                    "terms": {"field": "stocks.stock_code", "size": size * 2},
                                    "aggs": {
                                        "stock_names": {
                                            "terms": {"field": "stocks.stock_name", "size": 1}
                                        }
                                    },
                                }
                            },
                        }
                    },
                }
            },
        }

        # Synchronous client call; keep it off the event loop.
        es_response = await asyncio.to_thread(es_client.search, index=INDEX_NAME, body=search_body)
        buckets = es_response['aggregations']['unique_stocks']['matching']['stock_codes']['buckets']

        stocks = []
        for bucket in buckets:
            stock_code = bucket['key']
            name_buckets = bucket['stock_names']['buckets']
            stock_name = name_buckets[0]['key'] if name_buckets else stock_code

            # Keep only literal substring matches (the analyzed name match can be looser).
            if keyword in stock_code or keyword in stock_name:
                stocks.append(StockSearchResult(
                    stock_code=stock_code,
                    stock_name=stock_name,
                    concept_count=bucket['doc_count'],
                ))

        stocks.sort(key=lambda s: s.concept_count, reverse=True)

        return StockSearchResponse(total=len(stocks), stocks=stocks[:size])

    except Exception as e:
        logger.error(f"Search stocks error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/price/latest", tags=["Price"])
async def get_latest_price_date():
    """Report the newest trade date available in the price table, and whether one exists."""
    try:
        newest = await get_latest_trade_date()
    except Exception as e:
        logger.error(f"Get latest price date error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    return {"latest_trade_date": newest, "has_data": newest is not None}


@app.get("/concept/{concept_id}/price-timeseries", response_model=PriceTimeSeriesResponse, tags=["Price"])
async def get_concept_price_timeseries(
    concept_id: str,
    start_date: date = Query(..., description="开始日期,格式:YYYY-MM-DD"),
    end_date: date = Query(..., description="结束日期,格式:YYYY-MM-DD")
):
    """Return a concept's daily change series between two dates (inclusive), oldest first.

    Without MySQL an empty series is returned (no error), and this check runs
    before the date-order check.

    Raises:
        HTTPException: 400 when ``start_date`` is after ``end_date``; 404 when
            no rows exist in the range; 500 on database errors.
    """
    if not mysql_pool:
        logger.warning("[PriceTimeseries] MySQL 连接不可用,返回空时间序列数据")
        # Degrade to an empty series instead of a 503.
        return PriceTimeSeriesResponse(
            concept_id=concept_id,
            concept_name=concept_id,  # the name cannot be looked up without MySQL
            start_date=start_date,
            end_date=end_date,
            data_points=0,
            timeseries=[],
        )

    if start_date > end_date:
        raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")

    sql = """
        SELECT
            trade_date,
            concept_name,
            avg_change_pct,
            stock_count
        FROM concept_daily_stats
        WHERE concept_id = %s
            AND trade_date >= %s
            AND trade_date <= %s
        ORDER BY trade_date ASC
    """

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql, (concept_id, start_date, end_date))
                rows = await cursor.fetchall()

                if not rows:
                    raise HTTPException(
                        status_code=404,
                        detail=f"未找到概念ID {concept_id} 在 {start_date} 至 {end_date} 期间的数据"
                    )

                # The first non-empty concept_name (in date order) labels the series.
                concept_name = next((row['concept_name'] for row in rows if row.get('concept_name')), "")
                timeseries = [
                    PriceTimeSeriesItem(
                        trade_date=row['trade_date'],
                        avg_change_pct=None if row['avg_change_pct'] is None else float(row['avg_change_pct']),
                        stock_count=row.get('stock_count'),
                    )
                    for row in rows
                ]

                return PriceTimeSeriesResponse(
                    concept_id=concept_id,
                    concept_name=concept_name or concept_id,
                    start_date=start_date,
                    end_date=end_date,
                    data_points=len(timeseries),
                    timeseries=timeseries,
                )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching concept price timeseries: {e}")
        raise HTTPException(status_code=500, detail=f"获取时间序列数据失败: {str(e)}")


# Data models for concept statistics

class ConceptStatItem(BaseModel):
    # One entry of a statistics leaderboard. Every metric is optional;
    # presumably each leaderboard fills only the metrics it needs (the visible
    # fallback data sets name / change_pct / stock_count / news_count) — confirm.
    name: str
    concept_id: Optional[str] = None
    change_pct: Optional[float] = None
    stock_count: Optional[int] = None
    news_count: Optional[int] = None
    report_count: Optional[int] = None
    total_mentions: Optional[int] = None
    volatility: Optional[float] = None
    avg_change: Optional[float] = None
    max_change: Optional[float] = None
    consecutive_days: Optional[int] = None
    total_change: Optional[float] = None
    avg_daily: Optional[float] = None


class ConceptStatistics(BaseModel):
    # Leaderboards returned by GET /statistics: gainers, losers, most active,
    # most volatile and momentum (consecutive-rise) concepts, plus a summary.
    hot_concepts: List[ConceptStatItem]
    cold_concepts: List[ConceptStatItem]
    active_concepts: List[ConceptStatItem]
    volatile_concepts: List[ConceptStatItem]
    momentum_concepts: List[ConceptStatItem]
    summary: Dict[str, Any]


class ConceptStatisticsResponse(BaseModel):
    # Response envelope of GET /statistics.
    success: bool
    data: ConceptStatistics
    # Echo of the resolved request parameters (days, min_stock_count, dates)
    params: Dict[str, Any]
    # Human-readable explanation, set when sample data is returned
    note: Union[str, None] = None
|
||
|
||
|
||
def _resolve_statistics_date_range(days, start_date, end_date):
    """Resolve and validate the reporting window for ``/statistics``.

    Accepted forms:

    * ``days`` alone: ``end_date`` is today and ``start_date`` is ``days``
      calendar days earlier. Both ends are inclusive in the SQL, so the window
      spans ``days + 1`` calendar days.
    * ``start_date`` and/or ``end_date``: a missing ``end_date`` defaults to
      today, and a missing ``start_date`` defaults to 7 days before ``end_date``.
    * nothing: the 7 days before today, through today.

    Args:
        days: Optional window length in days (mutually exclusive with dates).
        start_date: Optional inclusive start ``date``.
        end_date: Optional inclusive end ``date``.

    Returns:
        A ``(start_date, end_date)`` tuple of ``date`` objects.

    Raises:
        HTTPException: 400 when ``days`` is combined with explicit dates, when
            the start is after the end, or when the window exceeds 90 days.
    """
    from datetime import datetime, timedelta

    if days is not None and (start_date is not None or end_date is not None):
        raise HTTPException(status_code=400, detail="days参数与start_date/end_date参数不能同时使用")

    today = datetime.now().date()
    if days is not None:
        # ``days`` is already bounded to 1..90 by the Query validator.
        return today - timedelta(days=days), today

    if end_date is None:
        end_date = today
    if start_date is None:
        start_date = end_date - timedelta(days=7)

    # Also catches a lone start_date that lies in the future.
    if start_date > end_date:
        raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")
    if (end_date - start_date).days > 90:
        raise HTTPException(status_code=400, detail="日期范围不能超过90天")
    return start_date, end_date


def _statistics_window_days(start_date, end_date):
    """Return the inclusive number of calendar days in the window (at least 1)."""
    return max(1, (end_date - start_date).days + 1)


def _round_or_zero(value):
    """Round a numeric DB value to 2 decimals; falsy values (None, 0) give 0.0."""
    return round(float(value), 2) if value else 0.0


def _int_or_zero(value):
    """Truncate a numeric DB value to ``int``; falsy values (None, 0) give 0."""
    return int(value) if value else 0


def _statistics_summary(total_concepts, positive_count, negative_count, avg_change, start_date, end_date):
    """Build the ``summary`` dict shared by real and sample responses."""
    return {
        'total_concepts': total_concepts,
        'positive_count': positive_count,
        'negative_count': negative_count,
        'avg_change': avg_change,
        'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'date_range': f"{start_date} 至 {end_date}",
        'days': _statistics_window_days(start_date, end_date),  # inclusive of both ends
        'start_date': str(start_date),
        'end_date': str(end_date)
    }


def _statistics_response(statistics, start_date, end_date, min_stock_count, note=None):
    """Wrap *statistics* in the response envelope, echoing the resolved params."""
    return ConceptStatisticsResponse(
        success=True,
        data=statistics,
        params={
            'days': _statistics_window_days(start_date, end_date),
            'min_stock_count': min_stock_count,
            'start_date': str(start_date),
            'end_date': str(end_date)
        },
        note=note
    )


def _sample_hot_concepts():
    """Return placeholder gainers used when the database cannot be queried."""
    return [
        ConceptStatItem(name="小米大模型", change_pct=12.45, stock_count=24, news_count=18),
        ConceptStatItem(name="人工智能", change_pct=8.76, stock_count=45, news_count=12),
        ConceptStatItem(name="新能源汽车", change_pct=6.54, stock_count=38, news_count=8),
        ConceptStatItem(name="芯片概念", change_pct=5.43, stock_count=52, news_count=15),
        ConceptStatItem(name="生物医药", change_pct=4.21, stock_count=28, news_count=6),
    ]


def _sample_hot_supplement():
    """Return placeholder gainers used to fill an empty real gainers list."""
    return [
        ConceptStatItem(name="人工智能", change_pct=8.76, stock_count=45, news_count=12),
        ConceptStatItem(name="新能源汽车", change_pct=6.54, stock_count=38, news_count=8),
        ConceptStatItem(name="芯片概念", change_pct=5.43, stock_count=52, news_count=15),
        ConceptStatItem(name="生物医药", change_pct=4.21, stock_count=28, news_count=6),
        ConceptStatItem(name="5G通信", change_pct=3.98, stock_count=35, news_count=9),
    ]


def _sample_cold_concepts():
    """Return placeholder losers (used both for full fallback and for filling)."""
    return [
        ConceptStatItem(name="房地产", change_pct=-5.76, stock_count=33, news_count=5),
        ConceptStatItem(name="煤炭开采", change_pct=-4.32, stock_count=25, news_count=3),
        ConceptStatItem(name="钢铁冶炼", change_pct=-3.21, stock_count=28, news_count=4),
        ConceptStatItem(name="传统零售", change_pct=-2.98, stock_count=19, news_count=2),
        ConceptStatItem(name="纺织服装", change_pct=-2.45, stock_count=15, news_count=2),
    ]


def _fallback_statistics(start_date, end_date):
    """Return an entirely synthetic ConceptStatistics for the given window."""
    return ConceptStatistics(
        hot_concepts=_sample_hot_concepts(),
        cold_concepts=_sample_cold_concepts(),
        active_concepts=[
            ConceptStatItem(name="人工智能", news_count=45, report_count=15, total_mentions=60),
            ConceptStatItem(name="芯片概念", news_count=42, report_count=12, total_mentions=54),
            ConceptStatItem(name="新能源汽车", news_count=38, report_count=8, total_mentions=46),
            ConceptStatItem(name="生物医药", news_count=28, report_count=6, total_mentions=34),
            ConceptStatItem(name="量子科技", news_count=25, report_count=5, total_mentions=30),
        ],
        volatile_concepts=[
            ConceptStatItem(name="区块链", volatility=25.6, avg_change=2.1, max_change=15.2),
            ConceptStatItem(name="元宇宙", volatility=23.8, avg_change=1.8, max_change=13.9),
            ConceptStatItem(name="虚拟现实", volatility=21.2, avg_change=-0.5, max_change=10.1),
            ConceptStatItem(name="游戏概念", volatility=19.7, avg_change=3.2, max_change=12.8),
            ConceptStatItem(name="在线教育", volatility=18.3, avg_change=-1.1, max_change=8.1),
        ],
        momentum_concepts=[
            ConceptStatItem(name="数字经济", consecutive_days=6, total_change=19.2, avg_daily=3.2),
            ConceptStatItem(name="云计算", consecutive_days=5, total_change=16.8, avg_daily=3.36),
            ConceptStatItem(name="物联网", consecutive_days=4, total_change=13.1, avg_daily=3.28),
            ConceptStatItem(name="大数据", consecutive_days=4, total_change=12.4, avg_daily=3.1),
            ConceptStatItem(name="工业互联网", consecutive_days=3, total_change=9.6, avg_daily=3.2),
        ],
        summary=_statistics_summary(500, 320, 180, 1.8, start_date, end_date),
    )


def _build_concept_items(rows, item_type):
    """Convert DictCursor rows into ConceptStatItem objects for one leaderboard.

    Args:
        rows: Rows from the matching leaderboard query.
        item_type: One of ``'hot'``, ``'cold'``, ``'active'``, ``'volatile'``,
            ``'momentum'``; selects which optional fields are filled.
    """
    items = []
    for row in rows:
        item = ConceptStatItem(name=row['concept_name'], concept_id=row.get('concept_id'))

        if item_type in ('hot', 'cold'):
            item.change_pct = _round_or_zero(row['avg_change_pct'])
            item.stock_count = _int_or_zero(row['avg_stock_count'])
            # No news data is queried: news_count carries the number of trading days.
            item.news_count = _int_or_zero(row['trading_days'])

        elif item_type == 'active':
            item.news_count = _int_or_zero(row['trading_days'])
            item.stock_count = _int_or_zero(row['avg_stock_count'])
            # Synthetic estimate (30% of trading days, minimum 1), not real report data.
            item.report_count = max(1, int(row['trading_days'] * 0.3)) if row['trading_days'] else 1
            item.total_mentions = item.news_count + item.report_count

        elif item_type == 'volatile':
            item.volatility = _round_or_zero(row['volatility'])
            item.avg_change = _round_or_zero(row['avg_change_pct'])
            item.max_change = _round_or_zero(row['max_change_pct'])

        elif item_type == 'momentum':
            # Number of positive days in the window; they need not be consecutive.
            item.consecutive_days = _int_or_zero(row['positive_days'])
            item.total_change = _round_or_zero(row['total_change'])
            item.avg_daily = _round_or_zero(row['avg_change_pct'])

        items.append(item)
    return items


async def _fetch_statistics(cursor, start_date, end_date, min_stock_count):
    """Run the leaderboard and summary queries and assemble ConceptStatistics.

    Any leaderboard may come back empty when the window has too little data.
    Database errors propagate to the caller.
    """
    window = (start_date, end_date, min_stock_count)

    # 1. Gainers: highest average daily change (at least 2 rows in the window)
    hot_query = """
        SELECT
            concept_id,
            concept_name,
            AVG(avg_change_pct) as avg_change_pct,
            AVG(stock_count) as avg_stock_count,
            COUNT(*) as trading_days
        FROM concept_daily_stats
        WHERE trade_date >= %s AND trade_date <= %s
            AND avg_change_pct IS NOT NULL
            AND stock_count >= %s
        GROUP BY concept_id, concept_name
        HAVING COUNT(*) >= 2
        ORDER BY AVG(avg_change_pct) DESC
        LIMIT 5
    """
    await cursor.execute(hot_query, window)
    hot_rows = await cursor.fetchall()

    # 2. Losers: lowest average daily change (at least 2 rows in the window)
    cold_query = """
        SELECT
            concept_id,
            concept_name,
            AVG(avg_change_pct) as avg_change_pct,
            AVG(stock_count) as avg_stock_count,
            COUNT(*) as trading_days
        FROM concept_daily_stats
        WHERE trade_date >= %s AND trade_date <= %s
            AND avg_change_pct IS NOT NULL
            AND stock_count >= %s
        GROUP BY concept_id, concept_name
        HAVING COUNT(*) >= 2
        ORDER BY AVG(avg_change_pct) ASC
        LIMIT 5
    """
    await cursor.execute(cold_query, window)
    cold_rows = await cursor.fetchall()

    # 3. Most active: most trading days, ties broken by average stock count
    active_query = """
        SELECT
            concept_id,
            concept_name,
            COUNT(*) as trading_days,
            AVG(stock_count) as avg_stock_count,
            MAX(stock_count) as max_stock_count
        FROM concept_daily_stats
        WHERE trade_date >= %s AND trade_date <= %s
            AND stock_count >= %s
        GROUP BY concept_id, concept_name
        ORDER BY COUNT(*) DESC, AVG(stock_count) DESC
        LIMIT 5
    """
    await cursor.execute(active_query, window)
    active_rows = await cursor.fetchall()

    # 4. Most volatile: standard deviation of the daily change (at least 3 rows)
    volatile_query = """
        SELECT
            concept_id,
            concept_name,
            STDDEV(avg_change_pct) as volatility,
            AVG(avg_change_pct) as avg_change_pct,
            MAX(avg_change_pct) as max_change_pct,
            AVG(stock_count) as avg_stock_count
        FROM concept_daily_stats
        WHERE trade_date >= %s AND trade_date <= %s
            AND avg_change_pct IS NOT NULL
            AND stock_count >= %s
        GROUP BY concept_id, concept_name
        HAVING COUNT(*) >= 3 AND STDDEV(avg_change_pct) IS NOT NULL
        ORDER BY STDDEV(avg_change_pct) DESC
        LIMIT 5
    """
    await cursor.execute(volatile_query, window)
    volatile_rows = await cursor.fetchall()

    # 5. Momentum (simplified): most positive days, not necessarily consecutive
    momentum_query = """
        SELECT
            concept_id,
            concept_name,
            COUNT(*) as positive_days,
            SUM(avg_change_pct) as total_change,
            AVG(avg_change_pct) as avg_change_pct,
            AVG(stock_count) as avg_stock_count
        FROM concept_daily_stats
        WHERE trade_date >= %s AND trade_date <= %s
            AND avg_change_pct > 0
            AND stock_count >= %s
        GROUP BY concept_id, concept_name
        HAVING COUNT(*) >= 2
        ORDER BY COUNT(*) DESC, AVG(avg_change_pct) DESC
        LIMIT 5
    """
    await cursor.execute(momentum_query, window)
    momentum_rows = await cursor.fetchall()

    # 6. Overall summary. A concept counts as positive/negative if it had at
    # least one such day, so positive + negative may exceed the total.
    total_query = """
        SELECT
            COUNT(DISTINCT concept_id) as total_concepts,
            COUNT(DISTINCT CASE WHEN avg_change_pct > 0 THEN concept_id END) as positive_concepts,
            COUNT(DISTINCT CASE WHEN avg_change_pct < 0 THEN concept_id END) as negative_concepts,
            AVG(avg_change_pct) as overall_avg_change
        FROM concept_daily_stats
        WHERE trade_date >= %s AND trade_date <= %s
            AND avg_change_pct IS NOT NULL
    """
    await cursor.execute(total_query, (start_date, end_date))
    total_row = await cursor.fetchone()

    return ConceptStatistics(
        hot_concepts=_build_concept_items(hot_rows, 'hot'),
        cold_concepts=_build_concept_items(cold_rows, 'cold'),
        active_concepts=_build_concept_items(active_rows, 'active'),
        volatile_concepts=_build_concept_items(volatile_rows, 'volatile'),
        momentum_concepts=_build_concept_items(momentum_rows, 'momentum'),
        summary=_statistics_summary(
            _int_or_zero(total_row['total_concepts']),
            _int_or_zero(total_row['positive_concepts']),
            _int_or_zero(total_row['negative_concepts']),
            _round_or_zero(total_row['overall_avg_change']),
            start_date,
            end_date,
        ),
    )


@app.get("/statistics", response_model=ConceptStatisticsResponse, tags=["Statistics"])
async def get_concept_statistics(
        days: Optional[int] = Query(None, ge=1, le=90, description="统计天数范围(与start_date/end_date互斥)"),
        start_date: Optional[date] = Query(None, description="开始日期,格式:YYYY-MM-DD"),
        end_date: Optional[date] = Query(None, description="结束日期,格式:YYYY-MM-DD"),
        min_stock_count: int = Query(3, ge=1, description="最少股票数量过滤")
):
    """获取概念板块统计数据 - 涨幅榜、跌幅榜、活跃榜、波动榜、连涨榜

    Concept-sector statistics for a date window: gainers, losers, most active,
    most volatile and momentum leaderboards, plus an aggregate summary.

    The window is given either by ``days`` or by ``start_date``/``end_date``.
    These are mutually exclusive, and invalid combinations return HTTP 400.
    When MySQL is unavailable or a query fails, placeholder sample data is
    returned with ``success=True`` and an explanatory ``note``.
    """
    # Validate before any fallback logic, so 400 errors reach the client
    # instead of being swallowed by the broad ``except`` below.
    start_date, end_date = _resolve_statistics_date_range(days, start_date, end_date)

    # If MySQL is unavailable, return sample data rather than a 503.
    if not mysql_pool:
        logger.warning("[Statistics] MySQL 连接不可用,使用示例数据")
        return _statistics_response(
            _fallback_statistics(start_date, end_date),
            start_date, end_date, min_stock_count,
            note="MySQL 连接不可用,使用示例数据",
        )

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                statistics = await _fetch_statistics(cursor, start_date, end_date, min_stock_count)
    except Exception as e:
        # Deliberate best-effort degradation to sample data; keep the traceback in the logs.
        logger.exception(f"Error fetching concept statistics: {e}")
        return _statistics_response(
            _fallback_statistics(start_date, end_date),
            start_date, end_date, min_stock_count,
            note=f"使用示例数据,原因: {str(e)}",
        )

    # Fill empty gainer/loser lists with sample data, and flag it in ``note``
    # so clients can tell placeholder rows from real ones.
    note = None
    if not statistics.hot_concepts:
        statistics.hot_concepts = _sample_hot_supplement()
        note = "部分榜单数据不足,已使用示例数据补充"
    if not statistics.cold_concepts:
        statistics.cold_concepts = _sample_cold_concepts()
        note = "部分榜单数据不足,已使用示例数据补充"

    return _statistics_response(statistics, start_date, end_date, min_stock_count, note=note)
|
||
|
||
|
||
# Script entry point: serve the API with uvicorn when run directly.
if __name__ == "__main__":
    import uvicorn

    # With reload=True, uvicorn needs the app as an import string rather than an
    # object. This assumes the module is importable as ``concept_api`` — TODO confirm.
    server_options = {
        "host": "0.0.0.0",
        "port": 6801,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("concept_api:app", **server_options)