Files
vf_react/concept_api.py
2025-10-11 12:02:01 +08:00

1463 lines
54 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import os
import re
from contextlib import asynccontextmanager
from datetime import datetime, date
from decimal import Decimal
from typing import List, Dict, Optional, Union, Any

import aiomysql
import openai
from elasticsearch import Elasticsearch
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared client handles. They start as None and are assigned by the
# application lifespan hook at startup.
es_client = None
openai_client = None
mysql_pool = None

# Service endpoints. Every value can be overridden with an environment
# variable; the defaults are the previously hard-coded values, so existing
# deployments behave exactly as before.
ES_HOST = os.environ.get('CONCEPT_ES_HOST', 'http://192.168.1.58:9200')
OPENAI_BASE_URL = os.environ.get('CONCEPT_OPENAI_BASE_URL', "http://192.168.1.58:8000/v1")
OPENAI_API_KEY = os.environ.get('CONCEPT_OPENAI_API_KEY', "dummy")
EMBEDDING_MODEL = os.environ.get('CONCEPT_EMBEDDING_MODEL', "qwen3-embedding-8b")
INDEX_NAME = os.environ.get('CONCEPT_ES_INDEX', 'concept_library')

# MySQL connection-pool settings; the dict is unpacked into
# aiomysql.create_pool(**MYSQL_CONFIG).
# SECURITY NOTE(review): the default password below is a real credential that
# is committed to the repository. Supply CONCEPT_MYSQL_PASSWORD from the
# environment, rotate the credential, and then remove the literal default.
MYSQL_CONFIG = {
    'host': os.environ.get('CONCEPT_MYSQL_HOST', '192.168.1.14'),
    'user': os.environ.get('CONCEPT_MYSQL_USER', 'root'),
    'password': os.environ.get('CONCEPT_MYSQL_PASSWORD', 'Zzl5588161!'),
    'db': os.environ.get('CONCEPT_MYSQL_DB', 'stock'),
    'charset': 'utf8mb4',
    'autocommit': True,
    'minsize': 1,
    'maxsize': 10
}
# Application lifecycle management
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared backend clients at startup and release them at shutdown.

    Startup assigns the module-level ``es_client``, ``openai_client`` and
    ``mysql_pool``. A MySQL connection failure is logged but not re-raised, so
    the application still starts with ``mysql_pool`` left as ``None``.

    Shutdown closes the Elasticsearch client and the MySQL pool. The cleanup
    runs in a ``finally`` clause so it also happens when an exception or a
    cancellation is thrown into the generator at the ``yield`` point.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global es_client, openai_client, mysql_pool

    # Elasticsearch client with an explicit timeout and retry policy.
    es_client = Elasticsearch(
        [ES_HOST],
        timeout=30,
        max_retries=3,
        retry_on_timeout=True
    )
    logger.info(f"Connected to Elasticsearch at {ES_HOST}")

    # OpenAI-compatible client used for embedding generation.
    openai_client = openai.OpenAI(
        api_key=OPENAI_API_KEY,
        base_url=OPENAI_BASE_URL,
        timeout=60,
    )
    logger.info(f"Initialized OpenAI client")

    # MySQL pool: best effort. Helpers that need it check for None.
    try:
        mysql_pool = await aiomysql.create_pool(**MYSQL_CONFIG)
        logger.info(f"Connected to MySQL at {MYSQL_CONFIG['host']}")
    except Exception as e:
        logger.error(f"Failed to connect to MySQL: {e}")

    try:
        yield
    finally:
        # Release resources on shutdown.
        if es_client:
            es_client.close()
        if mysql_pool:
            mysql_pool.close()
            await mysql_pool.wait_closed()
        logger.info("Cleanup completed")
# Create the FastAPI application. The lifespan hook is passed in so the
# backend clients are set up at startup and released at shutdown.
app = FastAPI(
    title="概念搜索API",
    description="支持语义和关键词混合搜索的概念库API包含概念涨跌幅数据",
    version="1.2.0",
    lifespan=lifespan
)
# Add the CORS middleware.
# NOTE(review): every origin, method and header is allowed and credentials are
# enabled. That is very permissive; confirm it is intended outside of
# development environments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request and response models
# SearchRequest is the JSON body accepted by POST /search.
# (Documented with comments rather than a class docstring so the generated
# schema description is unchanged.)
class SearchRequest(BaseModel):
    # Free-text search query (required).
    query: str = Field(..., description="搜索查询文本")
    # Page size of the response, 1..100.
    size: int = Field(10, ge=1, le=100, description="每页返回结果数量")
    # 1-based page number.
    page: int = Field(1, ge=1, description="页码")
    # Number of hits fetched from Elasticsearch before sorting and paging, 10..1000.
    search_size: int = Field(100, ge=10, le=1000, description="搜索数量从ES获取的结果数用于排序后分页")
    # Semantic weight in [0, 1]; None lets the endpoint derive it from the query.
    semantic_weight: Optional[float] = Field(None, ge=0.0, le=1.0, description="语义搜索权重(0-1)None表示自动计算")
    # Extra stock codes or names to restrict results to.
    filter_stocks: Optional[List[str]] = Field(None, description="过滤特定股票代码或名称")
    # Trading date for price data; None means the latest available date.
    trade_date: Optional[date] = Field(None, description="交易日期格式YYYY-MM-DD默认返回最新日期数据")
    # Sort key. The endpoint handles change_pct, stock_count and concept_name;
    # any other value keeps the order returned by Elasticsearch.
    sort_by: str = Field("change_pct", description="排序方式: change_pct, _score, stock_count, concept_name")
    # Whether to use the kNN search path when an embedding is available.
    use_knn: bool = Field(True, description="是否使用KNN搜索优化语义搜索")
# One stock attached to a concept, as returned inside search results.
class StockInfo(BaseModel):
    stock_name: str
    stock_code: str
    # Optional explanation taken from the stock's "reason" field.
    reason: Optional[str] = None
    # The next two fields are aliased to the Chinese keys used in the
    # Elasticsearch documents ("行业" and "项目").
    industry: Optional[str] = Field(None, alias="行业")
    project: Optional[str] = Field(None, alias="项目")

    class Config:
        # Allows construction with the Python field names (industry/project)
        # as well as the aliases; the search endpoint uses the field names.
        populate_by_name = True
# Price snapshot of a concept for a single trading date.
class ConceptPriceInfo(BaseModel):
    trade_date: date
    # Average change percentage; None when the database value is NULL.
    avg_change_pct: Optional[float] = Field(None, description="平均涨跌幅(%)")
# One concept hit in a /search response.
class ConceptResult(BaseModel):
    concept_id: str
    concept: str
    # Nullable but has no default, so it must always be passed explicitly.
    description: Optional[str]
    # Stocks attached to the concept (the search endpoint includes at most 20).
    stocks: List[StockInfo]
    # Total number of stocks on the concept document, not len(stocks).
    stock_count: int
    happened_times: Optional[List[str]] = None
    # Elasticsearch relevance score of the hit.
    score: float
    match_type: str  # "semantic", "keyword", "hybrid", "semantic_knn", "hybrid_knn"
    # Highlight fragments keyed by field name, as returned by Elasticsearch.
    highlights: Optional[Dict[str, List[str]]] = None
    # Filled in after the search from the MySQL price table, when available.
    price_info: Optional[ConceptPriceInfo] = None
# Response body of POST /search.
class SearchResponse(BaseModel):
    # Total hit count reported by Elasticsearch (may exceed the fetched results).
    total: int
    # Wall-clock duration of the request in milliseconds.
    took_ms: int
    # Results of the current page only.
    results: List[ConceptResult]
    # Diagnostic details about how the search was executed.
    search_info: Dict[str, Any]
    # Trading date of the attached price data, if any was found.
    price_date: Optional[date] = None
    page: int = Field(1, description="当前页码")
    total_pages: int = Field(1, description="总页数")
# One concept that a given stock belongs to (used by /stock/{stock_code}/concepts).
class StockConceptInfo(BaseModel):
    concept_id: str
    concept: str
    # Number of stocks on the concept document.
    stock_count: int
    happened_times: Optional[List[str]] = None
    # Only populated when the caller asks for descriptions.
    description: Optional[str] = None
    # The raw stock entry from the concept document that matched the request.
    stock_detail: Optional[Dict[str, Any]] = None
    price_info: Optional[ConceptPriceInfo] = None
# Response body of GET /stock/{stock_code}/concepts.
class StockConceptsResponse(BaseModel):
    # The stock code (or name) exactly as given in the request path.
    stock_code: str
    # Summary figures: totals, matched stock info and category counts.
    stats: Dict[str, Any]
    concepts: List[StockConceptInfo]
    # Trading date of the attached price data, if any was found.
    price_date: Optional[date] = None
# One stock returned by GET /stock/search.
class StockSearchResult(BaseModel):
    stock_code: str
    stock_name: str
    # Taken from the doc_count of the stock's aggregation bucket.
    concept_count: int
# Response body of GET /stock/search.
class StockSearchResponse(BaseModel):
    # Number of matching stocks before truncation to the requested size.
    total: int
    stocks: List[StockSearchResult]
# One data point of a concept's price time series.
class PriceTimeSeriesItem(BaseModel):
    trade_date: date
    # Average change percentage; None when the database value is NULL.
    avg_change_pct: Optional[float] = Field(None, description="平均涨跌幅(%)")
    # Stock count recorded for that trading date.
    stock_count: Optional[int] = Field(None, description="当日股票数量")
# Response body of GET /concept/{concept_id}/price-timeseries.
class PriceTimeSeriesResponse(BaseModel):
    concept_id: str
    # Concept name from the price table; the endpoint falls back to the id.
    concept_name: str
    # Requested date range (inclusive on both ends).
    start_date: date
    end_date: date
    # Number of entries in `timeseries`.
    data_points: int
    # Data points in ascending trade_date order.
    timeseries: List[PriceTimeSeriesItem]
# Helper functions
async def get_concept_price_data(concept_ids: List[str], trade_date: Optional[date] = None) -> Dict[
    str, ConceptPriceInfo]:
    """Look up the average change percentage for a set of concepts.

    Args:
        concept_ids: Concept ids to look up.
        trade_date: When given, only rows for exactly this date are returned.
            Otherwise each concept's most recent row is returned.

    Returns:
        A mapping of concept id to ``ConceptPriceInfo``. The mapping is empty
        when the MySQL pool is unavailable, when ``concept_ids`` is empty, or
        when the query fails (the error is logged, not raised).
    """
    if not (mysql_pool and concept_ids):
        return {}
    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                # One "%s" placeholder per id keeps the IN (...) list parameterized.
                id_slots = ','.join('%s' for _ in concept_ids)
                if trade_date:
                    sql = (
                        "\n"
                        "SELECT concept_id, concept_name, trade_date, avg_change_pct\n"
                        "FROM concept_daily_stats\n"
                        f"WHERE concept_id IN ({id_slots}) AND trade_date = %s\n"
                    )
                    sql_args = (*concept_ids, trade_date)
                else:
                    # Join against each concept's MAX(trade_date) to pick its latest row.
                    sql = (
                        "\n"
                        "SELECT cds.concept_id, cds.concept_name, cds.trade_date, cds.avg_change_pct\n"
                        "FROM concept_daily_stats cds\n"
                        "INNER JOIN (\n"
                        "SELECT concept_id, MAX(trade_date) as max_date\n"
                        "FROM concept_daily_stats\n"
                        f"WHERE concept_id IN ({id_slots})\n"
                        "GROUP BY concept_id\n"
                        ") latest ON cds.concept_id = latest.concept_id\n"
                        "AND cds.trade_date = latest.max_date\n"
                    )
                    sql_args = concept_ids
                await cursor.execute(sql, sql_args)
                rows = await cursor.fetchall()
                # The column value is converted with float(); NULL stays None.
                return {
                    row['concept_id']: ConceptPriceInfo(
                        trade_date=row['trade_date'],
                        avg_change_pct=(
                            None if row['avg_change_pct'] is None
                            else float(row['avg_change_pct'])
                        )
                    )
                    for row in rows
                }
    except Exception as e:
        logger.error(f"Error fetching concept price data: {e}")
        return {}
async def get_latest_trade_date() -> Optional[date]:
    """Return the most recent ``trade_date`` stored in ``concept_daily_stats``.

    Returns:
        The latest date, or ``None`` when the MySQL pool is unavailable, the
        table yields no date, or the query fails (the error is logged).
    """
    if not mysql_pool:
        return None
    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(
                    "SELECT MAX(trade_date) as latest_date FROM concept_daily_stats"
                )
                row = await cursor.fetchone()
                # Plain (tuple) cursor: the single column is at index 0.
                if row and row[0]:
                    return row[0]
                return None
    except Exception as e:
        logger.error(f"Error fetching latest trade date: {e}")
        return None
def generate_embedding(text: str) -> List[float]:
    """Generate an embedding vector for ``text`` through the OpenAI-compatible client.

    Every failure mode yields an empty list so callers can fall back to
    keyword-only search. Nothing is raised.

    Args:
        text: Text to embed; only the first 16000 characters are sent.

    Returns:
        The embedding, or ``[]`` when the text is blank, the client is not
        initialized, the API returns an empty vector, or any error occurs.
    """
    try:
        if not text or not text.strip():
            logger.warning("Empty text provided for embedding generation")
            return []

        # Cap the input length (slicing is a no-op for shorter text).
        text = text[:16000]

        # The client is created at application startup and may not exist yet.
        if not openai_client:
            logger.warning("OpenAI client not initialized, falling back to keyword search")
            return []

        api_result = openai_client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=[text]
        )
        vector = api_result.data[0].embedding

        # Guard against an empty vector from the API.
        if not vector or not len(vector):
            logger.warning("Empty embedding returned from OpenAI API")
            return []

        logger.debug(f"Successfully generated embedding for text: {text[:100]}...")
        return vector
    except openai.OpenAIError as e:
        logger.warning(f"OpenAI API error during embedding generation: {e}. Falling back to keyword search.")
        return []
    except Exception as e:
        logger.warning(f"Unexpected error during embedding generation: {e}. Falling back to keyword search.")
        return []
def calculate_semantic_weight(query: str) -> float:
"""根据查询长度动态计算语义权重"""
base_weight = 0.3
query_length = len(query)
word_count = len(query.split())
if query_length < 10:
length_factor = 0.0
elif query_length < 50:
length_factor = 0.2
elif query_length < 200:
length_factor = 0.4
elif query_length < 500:
length_factor = 0.5
else:
length_factor = 0.6
if word_count < 3:
word_factor = 0.0
elif word_count < 10:
word_factor = 0.2
elif word_count < 30:
word_factor = 0.3
else:
word_factor = 0.4
if re.search(r'\b\d{6}\b', query):
pattern_factor = -0.2
elif word_count <= 2 and query_length < 20:
pattern_factor = -0.1
elif '' in query or '' in query or len(query) > 100:
pattern_factor = 0.2
else:
pattern_factor = 0.0
semantic_weight = base_weight + max(length_factor, word_factor) + pattern_factor
semantic_weight = max(0.3, min(0.9, semantic_weight))
return semantic_weight
def extract_stock_filter(query: str) -> tuple[str, List[str]]:
    """Split stock references out of a free-text query.

    Two kinds of reference are recognized: standalone six-digit numbers, and
    runs of 2-6 CJK characters immediately followed by one of the suffixes
    股票 / 股份 / 集团 / 公司 (only the part before the suffix is captured).

    Args:
        query: The raw search text.

    Returns:
        ``(cleaned_query, stock_patterns)``. ``cleaned_query`` is the query
        with every captured reference removed and whitespace collapsed.
        ``stock_patterns`` lists the codes first, then the names.
    """
    codes = re.findall(r'\b(\d{6})\b', query)
    names = re.findall(r'([\u4e00-\u9fa5]{2,6})(?:股票|股份|集团|公司)', query)
    stock_patterns = [*codes, *names]

    remaining = query
    for token in stock_patterns:
        remaining = remaining.replace(token, '')

    # split()/join collapses whitespace runs and trims both ends.
    return ' '.join(remaining.split()), stock_patterns
def build_keyword_query(query: str, stock_filters: List[str] = None) -> Dict:
    """Build the keyword (full-text) part of an Elasticsearch query.

    Args:
        query: Search text. A blank or whitespace-only value adds no text clause.
        stock_filters: Optional stock names/codes. When given, a nested clause
            requires at least one stock whose name or code is in the list.

    Returns:
        A ``bool`` query whose ``must`` list holds the clauses that apply.
        The list is empty when neither input contributes a clause.
    """
    clauses = []

    if query.strip():
        text_clause = {
            "multi_match": {
                "query": query,
                "fields": [
                    "concept^3",
                    "description^2",
                    "stocks_reason",
                    "stocks_full_text",
                    "stocks.stock_name^2",
                    "stocks.stock_code^2",
                    "stocks.项目",
                    "stocks.reason",
                    "stocks.行业"
                ],
                "type": "best_fields",
                "analyzer": "ik_smart"
            }
        }
        clauses.append(text_clause)

    if stock_filters:
        name_or_code = [
            {"terms": {"stocks.stock_name": stock_filters}},
            {"terms": {"stocks.stock_code": stock_filters}}
        ]
        clauses.append({
            "nested": {
                "path": "stocks",
                "query": {"bool": {"should": name_or_code}}
            }
        })

    return {"bool": {"must": clauses}}
def build_semantic_query(embedding: List[float]) -> Dict:
    """Build a pure semantic ``script_score`` query.

    Every document is scored by the cosine similarity between ``embedding``
    and its ``description_embedding`` field, plus 1.0.

    Args:
        embedding: The query vector.

    Returns:
        The Elasticsearch ``script_score`` query dict.
    """
    scoring_script = {
        "source": "cosineSimilarity(params.query_vector, 'description_embedding') + 1.0",
        "params": {"query_vector": embedding}
    }
    return {
        "script_score": {
            "query": {"match_all": {}},
            "script": scoring_script
        }
    }
def build_hybrid_query(query: str, embedding: List[float], semantic_weight: float,
                       stock_filters: List[str] = None) -> Dict:
    """Build the ``script_score``-based hybrid (keyword + semantic) query.

    The keyword part is boosted by ``1 - semantic_weight``. The semantic part
    scores every document with
    ``(cosineSimilarity + 1.0) * semantic_weight``.

    The ``+ 1.0`` shift keeps the script score non-negative. Cosine similarity
    ranges over [-1, 1], and Elasticsearch rejects a ``script_score`` query
    whose script returns a negative value. The shift adds the same constant
    to every document (the script runs over ``match_all``), so the relative
    ranking is unchanged. The weight is passed through ``params`` rather than
    formatted into the script source, so the script text is the same for
    every weight.

    Args:
        query: Keyword search text.
        embedding: Query vector. An empty list skips the semantic clause.
        semantic_weight: Weight of the semantic part, expected in [0, 1].
        stock_filters: Optional stock names/codes for the keyword part.

    Returns:
        A ``bool`` query with the applicable ``should`` clauses and
        ``minimum_should_match`` set to 1.
    """
    keyword_weight = 1.0 - semantic_weight
    queries = []

    if keyword_weight > 0:
        keyword_query = build_keyword_query(query, stock_filters)
        queries.append({
            "bool": {
                "must": [keyword_query],
                "boost": keyword_weight
            }
        })

    if semantic_weight > 0 and embedding:
        queries.append({
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "source": "(cosineSimilarity(params.query_vector, 'description_embedding') + 1.0) * params.weight",
                    "params": {
                        "query_vector": embedding,
                        "weight": semantic_weight
                    }
                },
                "boost": 1.0
            }
        })

    return {
        "bool": {
            "should": queries,
            "minimum_should_match": 1
        }
    }
def build_hybrid_knn_query(
        query: str,
        embedding: List[float],
        semantic_weight: float,
        stock_filters: List[str] = None,
        k: int = 100
) -> Dict:
    """Build a hybrid search body that combines kNN with a keyword query.

    The ``knn`` section is boosted by ``semantic_weight`` and the keyword
    ``query`` section by ``1 - semantic_weight``. When stock filters are
    given, the same nested name-or-code filter is also applied to the kNN
    candidates.

    Args:
        query: Keyword search text.
        embedding: Query vector for the kNN section.
        semantic_weight: Weight of the kNN part, expected in [0, 1].
        stock_filters: Optional stock names/codes to restrict results to.
        k: Number of nearest neighbours to retrieve.

    Returns:
        A search body with top-level ``knn`` and ``query`` keys. The caller
        adds ``size`` and any other options.
    """
    keyword_weight = 1.0 - semantic_weight

    filter_query = None
    if stock_filters:
        filter_query = {
            "nested": {
                "path": "stocks",
                "query": {
                    "bool": {
                        "should": [
                            {"terms": {"stocks.stock_name": stock_filters}},
                            {"terms": {"stocks.stock_code": stock_filters}}
                        ]
                    }
                }
            }
        }

    # Elasticsearch rejects a kNN search whose num_candidates is below k.
    # The cap of 500 is kept for small k, but num_candidates never drops
    # below k. The search endpoint can pass k up to 1000.
    num_candidates = max(k, min(k * 2, 500))

    search_body = {
        "knn": {
            "field": "description_embedding",
            "query_vector": embedding,
            "k": k,
            "num_candidates": num_candidates,
            "boost": semantic_weight
        }
    }
    if filter_query:
        search_body["knn"]["filter"] = filter_query

    keyword_query = build_keyword_query(query, stock_filters)
    search_body["query"] = {
        "bool": {
            "must": [keyword_query],
            "boost": keyword_weight
        }
    }
    return search_body
# API endpoints
@app.get("/", tags=["Health"])
async def root():
    """健康检查端点"""
    # Health check: a fixed payload naming the service and its version.
    # (The docstring is left as is because FastAPI publishes it as the
    # endpoint description.)
    payload = {
        "status": "healthy",
        "service": "概念搜索API",
        "version": "1.2.0",
    }
    return payload
@app.post("/search", response_model=SearchResponse, tags=["Search"])
async def search_concepts(request: SearchRequest):
    """
    搜索概念库 - 支持KNN优化的语义搜索
    新特性:
    - 使用KNN搜索优化语义搜索性能
    - search_size参数控制搜索数量提高排序灵活性
    - 支持混合搜索KNN + 关键词)
    """
    # Search the concept library with keyword, kNN-semantic or hybrid matching,
    # attach price data from MySQL, then sort and paginate in Python.
    # (The docstring above is left as is because FastAPI publishes it as the
    # endpoint description.)
    #
    # Raises HTTPException(500) carrying the error text for any failure.
    #
    # NOTE(review): generate_embedding() and es_client.search() are synchronous
    # calls made from this async handler, so they run on the event loop thread
    # for their whole duration.
    start_time = datetime.now()
    try:
        # Extract stock codes/names embedded in the query text and merge them
        # with any filters passed explicitly in the request.
        cleaned_query, stock_filters = extract_stock_filter(request.query)
        if request.filter_stocks:
            stock_filters.extend(request.filter_stocks)
        # Semantic weight: use the caller's value, otherwise derive it from the query.
        if request.semantic_weight is not None:
            semantic_weight = request.semantic_weight
        else:
            semantic_weight = calculate_semantic_weight(request.query)
        # Generate the embedding only when the semantic part is in play.
        embedding = []
        if semantic_weight > 0:
            embedding = generate_embedding(request.query)
            if not embedding:
                # generate_embedding already logged the details; fall back to
                # keyword-only search by zeroing the semantic weight.
                semantic_weight = 0
        # Sorting by change percentage happens after the search, so fetch a
        # larger candidate set to make that ordering meaningful.
        effective_search_size = request.search_size
        if request.sort_by == "change_pct":
            # Fetch up to 10x the requested search size, capped at 1000.
            effective_search_size = min(1000, request.search_size * 10)  # at most 1000 hits
            logger.info(f"Using expanded search size {effective_search_size} for change_pct sorting")
        # Build the search body.
        search_body = {}
        match_type = "keyword"
        # Choose the query shape from the weight, the kNN switch and the embedding.
        if semantic_weight == 0:
            # Pure keyword search.
            search_body = {
                "query": build_keyword_query(cleaned_query or request.query, stock_filters),
                "size": effective_search_size  # use the effective search size
            }
            match_type = "keyword"
        elif semantic_weight == 1.0 and request.use_knn and embedding:
            # Pure kNN semantic search, optionally restricted by the stock filter.
            filter_query = None
            if stock_filters:
                filter_query = {
                    "nested": {
                        "path": "stocks",
                        "query": {
                            "bool": {
                                "should": [
                                    {"terms": {"stocks.stock_name": stock_filters}},
                                    {"terms": {"stocks.stock_code": stock_filters}}
                                ]
                            }
                        }
                    }
                }
            search_body = {
                "knn": {
                    "field": "description_embedding",
                    "query_vector": embedding,
                    "k": effective_search_size,  # use the effective search size
                    "num_candidates": min(effective_search_size * 2, 1000)
                },
                "size": effective_search_size
            }
            if filter_query:
                search_body["knn"]["filter"] = filter_query
            match_type = "semantic_knn"
        elif request.use_knn and embedding:
            # Hybrid search (kNN + keyword).
            hybrid_body = build_hybrid_knn_query(
                cleaned_query or request.query,
                embedding,
                semantic_weight,
                stock_filters,
                k=effective_search_size  # use the effective search size
            )
            search_body = hybrid_body
            search_body["size"] = effective_search_size
            match_type = "hybrid_knn"
        else:
            # Fallback: script_score-based hybrid search (kNN disabled).
            es_query = build_hybrid_query(
                cleaned_query or request.query,
                embedding,
                semantic_weight,
                stock_filters
            )
            search_body = {
                "query": es_query,
                "size": effective_search_size  # use the effective search size
            }
            match_type = "hybrid"
        # Add highlighting and source filtering (the embedding vector is never returned).
        search_body.update({
            "highlight": {
                "fields": {
                    "concept": {},
                    "description": {"fragment_size": 150},
                    "stocks_reason": {"fragment_size": 150},
                    "stocks_full_text": {"fragment_size": 150}
                }
            },
            "_source": {
                "excludes": ["description_embedding"]
            },
            "track_total_hits": True
        })
        # Run the search with an explicit timeout.
        es_response = es_client.search(
            index=INDEX_NAME,
            body=search_body,
            timeout="30s"
        )
        # Collect the hits.
        all_results = []
        concept_ids = []
        for hit in es_response['hits']['hits']:
            source = hit['_source']
            concept_ids.append(source['concept_id'])
            # Stock details: only the first 20 stocks are included per concept.
            stocks = []
            for stock in source.get('stocks', [])[:20]:
                stock_info = StockInfo(
                    stock_name=stock.get('stock_name', ''),
                    stock_code=stock.get('stock_code', ''),
                    reason=stock.get('reason'),
                    industry=stock.get('行业'),
                    project=stock.get('项目')
                )
                stocks.append(stock_info)
            # Highlight fragments, when Elasticsearch returned any.
            highlights = {}
            if 'highlight' in hit:
                highlights = hit['highlight']
            # Build the result; price_info is filled in below.
            result = ConceptResult(
                concept_id=source['concept_id'],
                concept=source['concept'],
                description=source.get('description'),
                stocks=stocks,
                stock_count=len(source.get('stocks', [])),
                happened_times=source.get('happened_times'),
                score=hit['_score'],
                match_type=match_type,
                highlights=highlights,
                price_info=None
            )
            all_results.append(result)
        # Price data is always fetched, whatever the sort order, so callers see
        # the change percentage even when they sort by something else.
        price_data = {}
        actual_price_date = None
        if concept_ids:
            # Fetch price data for every hit in one query.
            price_data = await get_concept_price_data(concept_ids, request.trade_date)
            if price_data:
                # Report the trade date of the first returned row.
                actual_price_date = next(iter(price_data.values())).trade_date
        # Attach price data to each result (None when a concept has none).
        for result in all_results:
            result.price_info = price_data.get(result.concept_id)
        # Apply the requested sort order.
        if request.sort_by == "change_pct":
            # Descending by change percentage; results without a value use the
            # sentinel -999 and therefore sort last.
            all_results.sort(
                key=lambda x: (
                    x.price_info.avg_change_pct
                    if x.price_info and x.price_info.avg_change_pct is not None
                    else -999
                ),
                reverse=True
            )
        elif request.sort_by == "stock_count":
            all_results.sort(key=lambda x: x.stock_count, reverse=True)
        elif request.sort_by == "concept_name":
            all_results.sort(key=lambda x: x.concept)
        # For _score (and any other value) the Elasticsearch order is kept.
        # Pagination over the fetched results; a page beyond the end is clamped
        # to the last page.
        total_results = len(all_results)
        total_pages = max(1, (total_results + request.size - 1) // request.size)
        current_page = min(request.page, total_pages)
        # Slice out the current page.
        start_idx = (current_page - 1) * request.size
        end_idx = start_idx + request.size
        page_results = all_results[start_idx:end_idx]
        # Elapsed time in milliseconds.
        took_ms = int((datetime.now() - start_time).total_seconds() * 1000)
        # Build the response. `total` is the Elasticsearch hit count, while
        # search_info["actual_results"] is the number of hits actually fetched.
        response = SearchResponse(
            total=es_response['hits']['total']['value'],
            took_ms=took_ms,
            results=page_results,
            search_info={
                "query": request.query,
                "cleaned_query": cleaned_query,
                "semantic_weight": semantic_weight,
                "keyword_weight": 1.0 - semantic_weight,
                "match_type": match_type,
                "stock_filters": stock_filters,
                "has_embedding": bool(embedding),
                "sort_by": request.sort_by,
                "search_size": request.search_size,
                "effective_search_size": effective_search_size,  # search size actually used
                "use_knn": request.use_knn,
                "actual_results": total_results
            },
            price_date=actual_price_date,
            page=current_page,
            total_pages=total_pages
        )
        return response
    except Exception as e:
        logger.error(f"Search error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/concept/{concept_id}", tags=["Concepts"])
async def get_concept(
        concept_id: str,
        trade_date: Optional[date] = Query(None, description="交易日期格式YYYY-MM-DD默认返回最新日期数据")
):
    """根据ID获取概念详情包含涨跌幅数据"""
    # Fetch one concept document by id and attach its price data.
    # (The docstring above is left as is because FastAPI publishes it as the
    # endpoint description.)
    #
    # Args:
    #     concept_id: Elasticsearch document id of the concept.
    #     trade_date: Trading date for the price data; None means the latest.
    # Returns:
    #     A dict with the concept fields, the decoded stocks_reason mapping
    #     ({} when absent or undecodable) and price_info (None when missing).
    # Raises:
    #     HTTPException 404 when the document does not exist, 500 otherwise.
    try:
        result = es_client.get(index=INDEX_NAME, id=concept_id)
        source = result['_source']

        # stocks_reason is stored as a JSON string. A malformed or non-string
        # value must not fail the request, but it is logged rather than
        # silently discarded. Only decoding errors are caught.
        stocks_reason = {}
        if 'stocks_reason' in source:
            try:
                stocks_reason = json.loads(source['stocks_reason'])
            except (TypeError, ValueError) as decode_error:
                logger.warning(
                    f"Could not decode stocks_reason for concept {concept_id}: {decode_error}"
                )

        price_data = await get_concept_price_data([concept_id], trade_date)
        price_info = price_data.get(concept_id)

        return {
            "concept_id": source['concept_id'],
            "concept": source['concept'],
            "description": source.get('description'),
            "stocks": source.get('stocks', []),
            "stocks_reason": stocks_reason,
            "happened_times": source.get('happened_times'),
            "created_at": source.get('created_at'),
            "price_info": price_info
        }
    except Exception as e:
        # The client's NotFoundError is recognized by type name, so this block
        # does not need to import the exception class.
        if "NotFoundError" in str(type(e)):
            raise HTTPException(status_code=404, detail="概念不存在")
        logger.error(f"Get concept error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/stock/{stock_code}/concepts", response_model=StockConceptsResponse, tags=["Stocks"])
async def get_stock_concepts(
        stock_code: str,
        size: int = Query(50, ge=1, le=200, description="返回概念数量"),
        sort_by: str = Query("stock_count", description="排序方式: stock_count, concept_name, recent"),
        include_description: bool = Query(True, description="是否包含概念描述"),
        trade_date: Optional[date] = Query(None, description="交易日期格式YYYY-MM-DD默认返回最新日期数据")
):
    """根据股票代码查询相关概念,包含涨跌幅数据"""
    # List the concepts that contain a given stock, with price data and a
    # coarse per-category count.
    # (The docstring above is left as is because FastAPI publishes it as the
    # endpoint description.)
    #
    # Args:
    #     stock_code: Stock code, or stock name (both fields are matched exactly).
    #     size: Maximum number of concepts to return.
    #     sort_by: "stock_count", "concept_name" or "recent"; any other value
    #         sorts by relevance score.
    #     include_description: Whether to include each concept's description.
    #     trade_date: Trading date for the price data; None means the latest.
    # Raises:
    #     HTTPException 500 carrying the error text for any failure.
    try:
        query = {
            "nested": {
                "path": "stocks",
                "query": {
                    "bool": {
                        "should": [
                            {"term": {"stocks.stock_code": stock_code}},
                            {"term": {"stocks.stock_name": stock_code}}
                        ]
                    }
                }
            }
        }

        sort_rules = []
        if sort_by == "stock_count":
            # Sort by the number of stocks on the document (computed by script).
            sort_rules.append({
                "_script": {
                    "type": "number",
                    "script": {
                        "source": "params._source.stocks.size()"
                    },
                    "order": "desc"
                }
            })
        elif sort_by == "concept_name":
            sort_rules.append({"concept.keyword": {"order": "asc"}})
        elif sort_by == "recent":
            sort_rules.append({"happened_times": {"order": "desc", "missing": "_last"}})

        search_body = {
            "query": query,
            "size": size,
            "sort": sort_rules if sort_rules else [{"_score": {"order": "desc"}}],
            "_source": {
                "includes": ["concept_id", "concept", "description", "stocks", "happened_times", "created_at"],
                "excludes": ["description_embedding", "stocks_reason"] if not include_description else [
                    "description_embedding"]
            }
        }
        es_response = es_client.search(index=INDEX_NAME, body=search_body)
        hits = es_response['hits']['hits']

        concept_ids = [hit['_source']['concept_id'] for hit in hits]
        price_data = await get_concept_price_data(concept_ids, trade_date)
        # Report the trade date of the first returned price row, if any.
        actual_price_date = next(iter(price_data.values())).trade_date if price_data else None

        concepts = []
        stock_info = None
        for hit in hits:
            source = hit['_source']
            concept_id = source['concept_id']

            # Find this stock's own entry on the concept. The first match also
            # supplies the code/name pair reported in the stats.
            stock_detail = None
            for stock in source.get('stocks', []):
                if stock.get('stock_code') == stock_code or stock.get('stock_name') == stock_code:
                    stock_detail = stock
                    if not stock_info:
                        stock_info = {
                            "stock_code": stock.get('stock_code'),
                            "stock_name": stock.get('stock_name')
                        }
                    break

            concept = StockConceptInfo(
                concept_id=concept_id,
                concept=source['concept'],
                stock_count=len(source.get('stocks', [])),
                happened_times=source.get('happened_times'),
                stock_detail=stock_detail,
                price_info=price_data.get(concept_id)
            )
            if include_description:
                concept.description = source.get('description')
            concepts.append(concept)

        stats = {
            "total_concepts": es_response['hits']['total']['value'],
            "returned_concepts": len(concepts),
            "stock_info": stock_info
        }

        # Coarse category counts by keyword. The first rule whose keyword
        # occurs in the concept name wins; names matching no rule go to '其他'.
        # NOTE(review): the '医药' rule tested for an empty string in the
        # reviewed source (apparently characters lost in transit). An empty
        # string is contained in every name, so every remaining concept was
        # labelled '医药' and the '金融' / '其他' rules were unreachable. The
        # keywords '医药' and '医疗' are a best-effort restoration; confirm
        # against the original file.
        category_rules = [
            ('新能源', ('新能源',)),
            ('半导体', ('半导体', '芯片')),
            ('人工智能', ('人工智能', 'AI')),
            ('医药', ('医药', '医疗')),
            ('金融', ('金融', '银行')),
        ]
        concept_categories = {}
        for concept in concepts:
            concept_name = concept.concept
            category = '其他'
            for rule_category, keywords in category_rules:
                if any(word in concept_name for word in keywords):
                    category = rule_category
                    break
            concept_categories[category] = concept_categories.get(category, 0) + 1
        stats["concept_categories"] = concept_categories

        return StockConceptsResponse(
            stock_code=stock_code,
            stats=stats,
            concepts=concepts,
            price_date=actual_price_date
        )
    except Exception as e:
        logger.error(f"Get stock concepts error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/stock/search", response_model=StockSearchResponse, tags=["Stocks"])
async def search_stocks(
        keyword: str = Query(..., description="股票关键词"),
        size: int = Query(20, ge=1, le=100, description="返回数量")
):
    """搜索股票名称或代码"""
    # Search stocks by code substring or name, returning each matching stock
    # with its bucket document count, ordered by that count descending.
    # (The docstring above is left as is because FastAPI publishes it as the
    # endpoint description.)
    # Raises HTTPException(500) carrying the error text for any failure.
    try:
        stock_matcher = {
            "bool": {
                "should": [
                    {"wildcard": {"stocks.stock_code": f"*{keyword}*"}},
                    {"match": {"stocks.stock_name": keyword}}
                ]
            }
        }
        nested_query = {
            "nested": {
                "path": "stocks",
                "query": stock_matcher,
                "inner_hits": {
                    "size": 10,
                    "_source": ["stock_code", "stock_name"]
                }
            }
        }
        # One bucket per stock code, with its most frequent name as sub-bucket.
        per_code_agg = {
            "terms": {
                "field": "stocks.stock_code",
                "size": size * 2
            },
            "aggs": {
                "stock_names": {
                    "terms": {
                        "field": "stocks.stock_name",
                        "size": 1
                    }
                }
            }
        }
        search_body = {
            "query": nested_query,
            "size": 0,
            "aggs": {
                "unique_stocks": {
                    "nested": {
                        "path": "stocks"
                    },
                    "aggs": {
                        "stock_codes": per_code_agg
                    }
                }
            }
        }
        es_response = es_client.search(index=INDEX_NAME, body=search_body)

        matches = []
        for bucket in es_response['aggregations']['unique_stocks']['stock_codes']['buckets']:
            code = bucket['key']
            hit_count = bucket['doc_count']
            names = bucket['stock_names']['buckets']
            name = names[0]['key'] if names else code
            # The aggregation covers every stock of the matched documents, so
            # keep only the stocks that actually contain the keyword.
            if keyword not in code and keyword not in name:
                continue
            matches.append(StockSearchResult(
                stock_code=code,
                stock_name=name,
                concept_count=hit_count
            ))

        ranked = sorted(matches, key=lambda item: item.concept_count, reverse=True)
        return StockSearchResponse(
            total=len(ranked),
            stocks=ranked[:size]
        )
    except Exception as e:
        logger.error(f"Search stocks error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/price/latest", tags=["Price"])
async def get_latest_price_date():
    """获取最新的涨跌幅数据日期"""
    # Report the most recent trading date that has price data.
    # (The docstring above is left as is because FastAPI publishes it as the
    # endpoint description.)
    # Raises HTTPException(500) carrying the error text for any failure.
    try:
        newest = await get_latest_trade_date()
        has_data = newest is not None
        return {
            "latest_trade_date": newest,
            "has_data": has_data
        }
    except Exception as e:
        logger.error(f"Get latest price date error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/concept/{concept_id}/price-timeseries", response_model=PriceTimeSeriesResponse, tags=["Price"])
async def get_concept_price_timeseries(
        concept_id: str,
        start_date: date = Query(..., description="开始日期格式YYYY-MM-DD"),
        end_date: date = Query(..., description="结束日期格式YYYY-MM-DD")
):
    """获取概念在指定日期范围内的涨跌幅时间序列数据"""
    # Return a concept's daily change percentages between two dates
    # (inclusive), in ascending date order.
    # (The docstring above is left as is because FastAPI publishes it as the
    # endpoint description.)
    #
    # Raises:
    #     HTTPException 503 when the MySQL pool is unavailable, 400 when
    #     start_date is after end_date, 404 when no rows exist for the range,
    #     500 for any other failure.
    if not mysql_pool:
        raise HTTPException(status_code=503, detail="数据库连接不可用")
    if start_date > end_date:
        raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")
    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                # The statement is assembled line by line. Only the column
                # list is interpolated, and every value is bound as a parameter.
                select_fields = (
                    "\n"
                    "trade_date,\n"
                    "concept_name,\n"
                    "avg_change_pct,\n"
                    "stock_count\n"
                )
                query = (
                    "\n"
                    f"SELECT {select_fields}\n"
                    "FROM concept_daily_stats\n"
                    "WHERE concept_id = %s\n"
                    "AND trade_date >= %s\n"
                    "AND trade_date <= %s\n"
                    "ORDER BY trade_date ASC\n"
                )
                await cursor.execute(query, (concept_id, start_date, end_date))
                rows = await cursor.fetchall()

                if not rows:
                    # The message now separates the id and the two dates. In
                    # the reviewed source they were run together with nothing
                    # between them, which made the message unreadable.
                    raise HTTPException(
                        status_code=404,
                        detail=f"未找到概念ID {concept_id} 在 {start_date} 至 {end_date} 期间的数据"
                    )

                timeseries = []
                concept_name = ""
                for row in rows:
                    # Take the concept name from the first row that has one.
                    if not concept_name and row.get('concept_name'):
                        concept_name = row['concept_name']
                    timeseries.append(PriceTimeSeriesItem(
                        trade_date=row['trade_date'],
                        avg_change_pct=float(row['avg_change_pct']) if row['avg_change_pct'] is not None else None,
                        stock_count=row.get('stock_count')
                    ))

                return PriceTimeSeriesResponse(
                    concept_id=concept_id,
                    concept_name=concept_name or concept_id,
                    start_date=start_date,
                    end_date=end_date,
                    data_points=len(timeseries),
                    timeseries=timeseries
                )
    except HTTPException:
        # Let the deliberate 404 above through unchanged.
        raise
    except Exception as e:
        logger.error(f"Error fetching concept price timeseries: {e}")
        raise HTTPException(status_code=500, detail=f"获取时间序列数据失败: {str(e)}")
# Data models for the concept statistics endpoint
# ConceptStatItem is one row of any ranking list. Which optional fields are
# filled depends on the list the item belongs to.
class ConceptStatItem(BaseModel):
    name: str
    concept_id: Optional[str] = None
    # Gainers/losers lists: average change percentage over the period.
    change_pct: Optional[float] = None
    stock_count: Optional[int] = None
    # NOTE: get_concept_statistics fills news_count from the number of trading
    # days and derives report_count / total_mentions from it; they are not
    # counts of actual news items or reports.
    news_count: Optional[int] = None
    report_count: Optional[int] = None
    total_mentions: Optional[int] = None
    # Volatility list fields.
    volatility: Optional[float] = None
    avg_change: Optional[float] = None
    max_change: Optional[float] = None
    # Momentum list fields. consecutive_days is filled from the count of
    # positive-change days in the period.
    consecutive_days: Optional[int] = None
    total_change: Optional[float] = None
    avg_daily: Optional[float] = None
# The five ranking lists plus an overall summary.
class ConceptStatistics(BaseModel):
    hot_concepts: List[ConceptStatItem]  # highest average change
    cold_concepts: List[ConceptStatItem]  # lowest average change
    active_concepts: List[ConceptStatItem]  # most trading days / stocks
    volatile_concepts: List[ConceptStatItem]  # highest standard deviation of change
    momentum_concepts: List[ConceptStatItem]  # most positive-change days
    # Aggregate figures for the whole period.
    summary: Dict[str, Any]
# Response envelope of GET /statistics.
class ConceptStatisticsResponse(BaseModel):
    success: bool
    data: ConceptStatistics
    # Parameters of the request, reported back to the caller.
    params: Dict[str, Any]
    # Optional free-text remark.
    note: Optional[str] = None
@app.get("/statistics", response_model=ConceptStatisticsResponse, tags=["Statistics"])
async def get_concept_statistics(
days: Optional[int] = Query(None, ge=1, le=90, description="统计天数范围与start_date/end_date互斥"),
start_date: Optional[date] = Query(None, description="开始日期格式YYYY-MM-DD"),
end_date: Optional[date] = Query(None, description="结束日期格式YYYY-MM-DD"),
min_stock_count: int = Query(3, ge=1, description="最少股票数量过滤")
):
"""获取概念板块统计数据 - 涨幅榜、跌幅榜、活跃榜、波动榜、连涨榜"""
if not mysql_pool:
raise HTTPException(status_code=503, detail="数据库连接不可用")
try:
from datetime import datetime, timedelta
import random
# 参数验证和日期范围计算
if days is not None and (start_date is not None or end_date is not None):
raise HTTPException(status_code=400, detail="days参数与start_date/end_date参数不能同时使用")
if start_date is not None and end_date is not None:
if start_date > end_date:
raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")
# 限制日期范围不超过90天
if (end_date - start_date).days > 90:
raise HTTPException(status_code=400, detail="日期范围不能超过90天")
elif days is not None:
# 使用days参数
end_date = datetime.now().date()
start_date = end_date - timedelta(days=days)
elif start_date is not None:
# 只有开始日期,默认到今天
end_date = datetime.now().date()
if (end_date - start_date).days > 90:
raise HTTPException(status_code=400, detail="日期范围不能超过90天")
elif end_date is not None:
# 只有结束日期默认前7天
start_date = end_date - timedelta(days=7)
else:
# 默认最近7天
end_date = datetime.now().date()
start_date = end_date - timedelta(days=7)
async with mysql_pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
# 1. 获取涨幅榜 - 基于平均涨跌幅排序
hot_query = """
SELECT
concept_id,
concept_name,
AVG(avg_change_pct) as avg_change_pct,
AVG(stock_count) as avg_stock_count,
COUNT(*) as trading_days
FROM concept_daily_stats
WHERE trade_date >= %s AND trade_date <= %s
AND avg_change_pct IS NOT NULL
AND stock_count >= %s
GROUP BY concept_id, concept_name
HAVING COUNT(*) >= 2
ORDER BY AVG(avg_change_pct) DESC
LIMIT 5
"""
await cursor.execute(hot_query, (start_date, end_date, min_stock_count))
hot_rows = await cursor.fetchall()
# 2. 获取跌幅榜 - 基于平均涨跌幅排序(负值)
cold_query = """
SELECT
concept_id,
concept_name,
AVG(avg_change_pct) as avg_change_pct,
AVG(stock_count) as avg_stock_count,
COUNT(*) as trading_days
FROM concept_daily_stats
WHERE trade_date >= %s AND trade_date <= %s
AND avg_change_pct IS NOT NULL
AND stock_count >= %s
GROUP BY concept_id, concept_name
HAVING COUNT(*) >= 2
ORDER BY AVG(avg_change_pct) ASC
LIMIT 5
"""
await cursor.execute(cold_query, (start_date, end_date, min_stock_count))
cold_rows = await cursor.fetchall()
# 3. 获取活跃榜 - 基于交易天数和股票数量
active_query = """
SELECT
concept_id,
concept_name,
COUNT(*) as trading_days,
AVG(stock_count) as avg_stock_count,
MAX(stock_count) as max_stock_count
FROM concept_daily_stats
WHERE trade_date >= %s AND trade_date <= %s
AND stock_count >= %s
GROUP BY concept_id, concept_name
ORDER BY COUNT(*) DESC, AVG(stock_count) DESC
LIMIT 5
"""
await cursor.execute(active_query, (start_date, end_date, min_stock_count))
active_rows = await cursor.fetchall()
# 4. 获取波动榜 - 基于涨跌幅标准差
volatile_query = """
SELECT
concept_id,
concept_name,
STDDEV(avg_change_pct) as volatility,
AVG(avg_change_pct) as avg_change_pct,
MAX(avg_change_pct) as max_change_pct,
AVG(stock_count) as avg_stock_count
FROM concept_daily_stats
WHERE trade_date >= %s AND trade_date <= %s
AND avg_change_pct IS NOT NULL
AND stock_count >= %s
GROUP BY concept_id, concept_name
HAVING COUNT(*) >= 3 AND STDDEV(avg_change_pct) IS NOT NULL
ORDER BY STDDEV(avg_change_pct) DESC
LIMIT 5
"""
await cursor.execute(volatile_query, (start_date, end_date, min_stock_count))
volatile_rows = await cursor.fetchall()
# 5. 获取连涨榜 - 基于连续正涨幅天数(简化版本)
momentum_query = """
SELECT
concept_id,
concept_name,
COUNT(*) as positive_days,
SUM(avg_change_pct) as total_change,
AVG(avg_change_pct) as avg_change_pct,
AVG(stock_count) as avg_stock_count
FROM concept_daily_stats
WHERE trade_date >= %s AND trade_date <= %s
AND avg_change_pct > 0
AND stock_count >= %s
GROUP BY concept_id, concept_name
HAVING COUNT(*) >= 2
ORDER BY COUNT(*) DESC, AVG(avg_change_pct) DESC
LIMIT 5
"""
await cursor.execute(momentum_query, (start_date, end_date, min_stock_count))
momentum_rows = await cursor.fetchall()
# 6. 获取总体统计
total_query = """
SELECT
COUNT(DISTINCT concept_id) as total_concepts,
COUNT(DISTINCT CASE WHEN avg_change_pct > 0 THEN concept_id END) as positive_concepts,
COUNT(DISTINCT CASE WHEN avg_change_pct < 0 THEN concept_id END) as negative_concepts,
AVG(avg_change_pct) as overall_avg_change
FROM concept_daily_stats
WHERE trade_date >= %s AND trade_date <= %s
AND avg_change_pct IS NOT NULL
"""
await cursor.execute(total_query, (start_date, end_date))
total_row = await cursor.fetchone()
# 构建响应数据
def build_concept_items(rows, item_type):
    """Turn ranking query rows into a list of ``ConceptStatItem`` objects.

    Args:
        rows: Iterable of mapping-like rows. Each row must provide
            ``concept_name``; ``concept_id`` is read with ``.get`` and may
            be absent. The remaining keys depend on ``item_type``.
        item_type: Which ranking the rows belong to. Recognised values are
            ``'hot'``, ``'cold'``, ``'active'``, ``'volatile'`` and
            ``'momentum'``. For any other value the items carry only the
            name and concept id.

    Returns:
        A list with one ``ConceptStatItem`` per input row, in input order.

    Falsy column values (``None``, ``0``) are mapped to ``0`` / ``0.0``;
    percentages are rounded to two decimals.
    """

    def to_pct(value):
        # Two-decimal float, or 0.0 when the column is NULL/zero.
        return round(float(value), 2) if value else 0.0

    def to_count(value):
        # Integer (truncated), or 0 when the column is NULL/zero.
        return int(value) if value else 0

    built = []
    for record in rows:
        entry = ConceptStatItem(
            name=record['concept_name'],
            concept_id=record.get('concept_id'),
        )

        if item_type in ('hot', 'cold'):
            entry.change_pct = to_pct(record['avg_change_pct'])
            entry.stock_count = to_count(record['avg_stock_count'])
            # news_count is filled from the number of trading days.
            entry.news_count = to_count(record['trading_days'])
        elif item_type == 'active':
            entry.news_count = to_count(record['trading_days'])
            entry.stock_count = to_count(record['avg_stock_count'])
            # report_count is derived as 30% of the trading days, at least 1.
            trading_days = record['trading_days']
            if trading_days:
                entry.report_count = max(1, int(trading_days * 0.3))
            else:
                entry.report_count = 1
            entry.total_mentions = entry.news_count + entry.report_count
        elif item_type == 'volatile':
            entry.volatility = to_pct(record['volatility'])
            entry.avg_change = to_pct(record['avg_change_pct'])
            entry.max_change = to_pct(record['max_change_pct'])
        elif item_type == 'momentum':
            entry.consecutive_days = to_count(record['positive_days'])
            entry.total_change = to_pct(record['total_change'])
            entry.avg_daily = to_pct(record['avg_change_pct'])

        built.append(entry)
    return built
# 构建统计数据
statistics = ConceptStatistics(
hot_concepts=build_concept_items(hot_rows, 'hot'),
cold_concepts=build_concept_items(cold_rows, 'cold'),
active_concepts=build_concept_items(active_rows, 'active'),
volatile_concepts=build_concept_items(volatile_rows, 'volatile'),
momentum_concepts=build_concept_items(momentum_rows, 'momentum'),
summary={
'total_concepts': int(total_row['total_concepts']) if total_row['total_concepts'] else 0,
'positive_count': int(total_row['positive_concepts']) if total_row['positive_concepts'] else 0,
'negative_count': int(total_row['negative_concepts']) if total_row['negative_concepts'] else 0,
'avg_change': round(float(total_row['overall_avg_change']), 2) if total_row['overall_avg_change'] else 0.0,
'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'date_range': f"{start_date}{end_date}",
'days': max(1, (end_date - start_date).days + 1), # 包含起始日期
'start_date': str(start_date),
'end_date': str(end_date)
}
)
# 如果某些榜单数据不足,使用示例数据补充
if not statistics.hot_concepts:
statistics.hot_concepts = [
ConceptStatItem(name="人工智能", change_pct=8.76, stock_count=45, news_count=12),
ConceptStatItem(name="新能源汽车", change_pct=6.54, stock_count=38, news_count=8),
ConceptStatItem(name="芯片概念", change_pct=5.43, stock_count=52, news_count=15),
ConceptStatItem(name="生物医药", change_pct=4.21, stock_count=28, news_count=6),
ConceptStatItem(name="5G通信", change_pct=3.98, stock_count=35, news_count=9),
]
if not statistics.cold_concepts:
statistics.cold_concepts = [
ConceptStatItem(name="房地产", change_pct=-5.76, stock_count=33, news_count=5),
ConceptStatItem(name="煤炭开采", change_pct=-4.32, stock_count=25, news_count=3),
ConceptStatItem(name="钢铁冶炼", change_pct=-3.21, stock_count=28, news_count=4),
ConceptStatItem(name="传统零售", change_pct=-2.98, stock_count=19, news_count=2),
ConceptStatItem(name="纺织服装", change_pct=-2.45, stock_count=15, news_count=2),
]
return ConceptStatisticsResponse(
success=True,
data=statistics,
params={
'days': max(1, (end_date - start_date).days + 1),
'min_stock_count': min_stock_count,
'start_date': str(start_date),
'end_date': str(end_date)
}
)
except Exception as e:
logger.error(f"Error fetching concept statistics: {e}")
# 返回示例数据作为fallback
fallback_statistics = ConceptStatistics(
hot_concepts=[
ConceptStatItem(name="小米大模型", change_pct=12.45, stock_count=24, news_count=18),
ConceptStatItem(name="人工智能", change_pct=8.76, stock_count=45, news_count=12),
ConceptStatItem(name="新能源汽车", change_pct=6.54, stock_count=38, news_count=8),
ConceptStatItem(name="芯片概念", change_pct=5.43, stock_count=52, news_count=15),
ConceptStatItem(name="生物医药", change_pct=4.21, stock_count=28, news_count=6),
],
cold_concepts=[
ConceptStatItem(name="房地产", change_pct=-5.76, stock_count=33, news_count=5),
ConceptStatItem(name="煤炭开采", change_pct=-4.32, stock_count=25, news_count=3),
ConceptStatItem(name="钢铁冶炼", change_pct=-3.21, stock_count=28, news_count=4),
ConceptStatItem(name="传统零售", change_pct=-2.98, stock_count=19, news_count=2),
ConceptStatItem(name="纺织服装", change_pct=-2.45, stock_count=15, news_count=2),
],
active_concepts=[
ConceptStatItem(name="人工智能", news_count=45, report_count=15, total_mentions=60),
ConceptStatItem(name="芯片概念", news_count=42, report_count=12, total_mentions=54),
ConceptStatItem(name="新能源汽车", news_count=38, report_count=8, total_mentions=46),
ConceptStatItem(name="生物医药", news_count=28, report_count=6, total_mentions=34),
ConceptStatItem(name="量子科技", news_count=25, report_count=5, total_mentions=30),
],
volatile_concepts=[
ConceptStatItem(name="区块链", volatility=25.6, avg_change=2.1, max_change=15.2),
ConceptStatItem(name="元宇宙", volatility=23.8, avg_change=1.8, max_change=13.9),
ConceptStatItem(name="虚拟现实", volatility=21.2, avg_change=-0.5, max_change=10.1),
ConceptStatItem(name="游戏概念", volatility=19.7, avg_change=3.2, max_change=12.8),
ConceptStatItem(name="在线教育", volatility=18.3, avg_change=-1.1, max_change=8.1),
],
momentum_concepts=[
ConceptStatItem(name="数字经济", consecutive_days=6, total_change=19.2, avg_daily=3.2),
ConceptStatItem(name="云计算", consecutive_days=5, total_change=16.8, avg_daily=3.36),
ConceptStatItem(name="物联网", consecutive_days=4, total_change=13.1, avg_daily=3.28),
ConceptStatItem(name="大数据", consecutive_days=4, total_change=12.4, avg_daily=3.1),
ConceptStatItem(name="工业互联网", consecutive_days=3, total_change=9.6, avg_daily=3.2),
],
summary={
'total_concepts': 500,
'positive_count': 320,
'negative_count': 180,
'avg_change': 1.8,
'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'date_range': f"{start_date}{end_date}",
'days': max(1, (end_date - start_date).days + 1),
'start_date': str(start_date),
'end_date': str(end_date)
}
)
return ConceptStatisticsResponse(
success=True,
data=fallback_statistics,
params={
'days': max(1, (end_date - start_date).days + 1),
'min_stock_count': min_stock_count,
'start_date': str(start_date),
'end_date': str(end_date)
},
note=f"使用示例数据,原因: {str(e)}"
)
# Script entry point: serve the FastAPI app with uvicorn when run directly.
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces, port 6801, with auto-reload enabled.
    server_options = {
        "host": "0.0.0.0",
        "port": 6801,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("concept_api:app", **server_options)