vf_react/concept_api_v4.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
概念搜索API V4 - 适配 concept_library_v4 索引
新特性:
1. 支持股票原因(reason)搜索
2. 优化的股票信息结构
3. 保留层级结构(lv1/lv2/lv3)支持
4. 新增reason搜索接口
"""
import json
import openai
from typing import List, Dict, Optional, Union, Any
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, Field
from elasticsearch import Elasticsearch
from datetime import datetime, date
import logging
import re
from contextlib import asynccontextmanager
import aiomysql
import os
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ==================== Configuration ====================
# Elasticsearch
ES_HOST = 'http://127.0.0.1:9200'
INDEX_NAME = 'concept_library_v3'
# Embedding service
OPENAI_BASE_URL = "http://127.0.0.1:8000/v1"
OPENAI_API_KEY = "dummy"
EMBEDDING_MODEL = "qwen3-embedding-8b"
# Hierarchy structure file
HIERARCHY_FILE = 'concept_hierarchy_v3.json'
# MySQL
MYSQL_CONFIG = {
'host': '110.42.32.207',
'port': 3306,
'user': 'root',
'password': 'Zzl33818!',
'db': 'stock',
'charset': 'utf8mb4',
'autocommit': True,
'minsize': 1,
'maxsize': 10
}
# ==================== Globals ====================
es_client = None
openai_client = None
mysql_pool = None
# Hierarchy structure state
hierarchy_data = {}
concept_to_hierarchy = {}
def load_hierarchy():
"""加载层级结构并建立概念到层级的映射"""
global hierarchy_data, concept_to_hierarchy
hierarchy_path = os.path.join(os.path.dirname(__file__), HIERARCHY_FILE)
if not os.path.exists(hierarchy_path):
logger.warning(f"层级文件不存在: {hierarchy_path}")
return
try:
with open(hierarchy_path, 'r', encoding='utf-8') as f:
hierarchy_data = json.load(f)
for lv1 in hierarchy_data.get('hierarchy', []):
lv1_name = lv1.get('lv1', '')
lv1_id = lv1.get('lv1_id', '')
for child in lv1.get('children', []):
lv2_name = child.get('lv2', '')
lv2_id = child.get('lv2_id', '')
if 'children' in child:
for lv3_child in child.get('children', []):
lv3_name = lv3_child.get('lv3', '')
lv3_id = lv3_child.get('lv3_id', '')
for concept in lv3_child.get('concepts', []):
concept_to_hierarchy[concept] = {
'lv1': lv1_name, 'lv1_id': lv1_id,
'lv2': lv2_name, 'lv2_id': lv2_id,
'lv3': lv3_name, 'lv3_id': lv3_id
}
else:
for concept in child.get('concepts', []):
concept_to_hierarchy[concept] = {
'lv1': lv1_name, 'lv1_id': lv1_id,
'lv2': lv2_name, 'lv2_id': lv2_id,
'lv3': None, 'lv3_id': None
}
logger.info(f"加载层级结构完成,共 {len(concept_to_hierarchy)} 个概念有层级信息")
except Exception as e:
logger.error(f"加载层级结构失败: {e}")
def get_concept_hierarchy(concept_name: str) -> Optional[Dict]:
"""获取概念的层级信息"""
return concept_to_hierarchy.get(concept_name)
# ==================== Lifespan management ====================
@asynccontextmanager
async def lifespan(app: FastAPI):
global es_client, openai_client, mysql_pool
load_hierarchy()
es_client = Elasticsearch(
[ES_HOST],
timeout=30,
max_retries=3,
retry_on_timeout=True
)
logger.info(f"Connected to Elasticsearch at {ES_HOST}, index: {INDEX_NAME}")
openai_client = openai.OpenAI(
api_key=OPENAI_API_KEY,
base_url=OPENAI_BASE_URL,
timeout=60,
)
logger.info(f"Initialized OpenAI client")
try:
mysql_pool = await aiomysql.create_pool(**MYSQL_CONFIG)
logger.info(f"Connected to MySQL at {MYSQL_CONFIG['host']}")
except Exception as e:
logger.error(f"Failed to connect to MySQL: {e}")
yield
if es_client:
es_client.close()
if mysql_pool:
mysql_pool.close()
await mysql_pool.wait_closed()
logger.info("Cleanup completed")
# ==================== FastAPI application ====================
app = FastAPI(
title="概念搜索API V4",
description="支持股票原因搜索的概念库API适配 concept_library_v4 索引",
version="4.0.0",
lifespan=lifespan
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ==================== WeChat Mini Program verification ====================
@app.get("/DfASFmNQoo.txt", response_class=PlainTextResponse)
async def wechat_verification():
return "ebd78eb22819b1393a34c6ae1e8fcce6"
# ==================== Data models ====================
class HierarchyInfo(BaseModel):
lv1: Optional[str] = None
lv1_id: Optional[str] = None
lv2: Optional[str] = None
lv2_id: Optional[str] = None
lv3: Optional[str] = None
lv3_id: Optional[str] = None
class StockInfo(BaseModel):
"""股票信息(包含原因)"""
name: str
code: Optional[str] = None
reason: Optional[str] = None
class ConceptPriceInfo(BaseModel):
trade_date: date
avg_change_pct: Optional[float] = None
class SearchRequest(BaseModel):
query: str
size: int = Field(10, ge=1, le=100)
page: int = Field(1, ge=1)
search_size: int = Field(100, ge=10, le=1000)
semantic_weight: Optional[float] = Field(None, ge=0.0, le=1.0)
filter_stocks: Optional[List[str]] = None
filter_lv1: Optional[str] = None
filter_lv2: Optional[str] = None
trade_date: Optional[date] = None
sort_by: str = Field("_score")
use_knn: bool = Field(True)
search_reason: bool = Field(False, description="是否同时搜索股票原因")
include_stock_reasons: bool = Field(False, description="是否在返回的股票列表中包含原因")
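# Example /search payload (illustrative values; defaults are defined on SearchRequest above):
# {
#   "query": "半导体设备",
#   "size": 10,
#   "page": 1,
#   "sort_by": "change_pct",
#   "search_reason": true,
#   "include_stock_reasons": true
# }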
class ReasonSearchRequest(BaseModel):
"""股票原因搜索请求"""
query: str = Field(..., description="搜索关键词")
size: int = Field(20, ge=1, le=100)
page: int = Field(1, ge=1)
include_stock_details: bool = Field(True, description="是否返回匹配的股票详情")
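# Example /search/reason payload (illustrative):
# {"query": "固态电池", "size": 20, "page": 1, "include_stock_details": true}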
class ConceptResult(BaseModel):
concept_id: str
concept: str
description: Optional[str] = None
tags: Optional[List[str]] = None
outbreak_dates: Optional[List[str]] = None
stocks: List[StockInfo] = Field(default_factory=list, description="股票列表是否包含原因取决于include_stock_reasons参数")
stock_count: int = 0
hierarchy: Optional[HierarchyInfo] = None
score: float = 0.0
match_type: str = "keyword"
highlights: Optional[Dict[str, List[str]]] = None
price_info: Optional[ConceptPriceInfo] = None
matched_stocks: Optional[List[StockInfo]] = Field(None, description="匹配的股票(含原因),仅在search_reason时返回")
class SearchResponse(BaseModel):
total: int
took_ms: int
results: List[ConceptResult]
search_info: Dict[str, Any]
price_date: Optional[date] = None
page: int = 1
total_pages: int = 1
class ReasonSearchResult(BaseModel):
"""原因搜索结果"""
concept_id: str
concept: str
description: Optional[str] = None
score: float
matched_stocks: List[StockInfo] = Field(default_factory=list, description="匹配的股票及原因")
reason_highlights: Optional[List[str]] = None
class ReasonSearchResponse(BaseModel):
"""原因搜索响应"""
total: int
took_ms: int
query: str
results: List[ReasonSearchResult]
page: int = 1
total_pages: int = 1
class ConceptDetailResponse(BaseModel):
concept_id: str
concept: str
description: Optional[str] = None
insight: Optional[str] = None
tags: Optional[List[str]] = None
outbreak_dates: Optional[List[str]] = None
stocks: List[StockInfo] = Field(default_factory=list, description="股票列表(含原因)")
stock_count: int = 0
hierarchy: Optional[HierarchyInfo] = None
folders: Optional[List[str]] = None
created_at: Optional[str] = None
price_info: Optional[ConceptPriceInfo] = None
class HierarchyLevel(BaseModel):
id: str
name: str
concept_count: int = 0
children: Optional[List['HierarchyLevel']] = None
concepts: Optional[List[str]] = None
class HierarchyResponse(BaseModel):
hierarchy: List[HierarchyLevel]
total_concepts: int
class PriceTimeSeriesItem(BaseModel):
trade_date: date
avg_change_pct: Optional[float] = None
stock_count: Optional[int] = None
class PriceTimeSeriesResponse(BaseModel):
concept_id: str
concept_name: str
start_date: date
end_date: date
data_points: int
timeseries: List[PriceTimeSeriesItem]
class ConceptStatItem(BaseModel):
name: str
concept_id: Optional[str] = None
change_pct: Optional[float] = None
stock_count: Optional[int] = None
news_count: Optional[int] = None
report_count: Optional[int] = None
total_mentions: Optional[int] = None
volatility: Optional[float] = None
avg_change: Optional[float] = None
max_change: Optional[float] = None
consecutive_days: Optional[int] = None
total_change: Optional[float] = None
avg_daily: Optional[float] = None
hierarchy: Optional[HierarchyInfo] = None
class ConceptStatistics(BaseModel):
hot_concepts: List[ConceptStatItem]
cold_concepts: List[ConceptStatItem]
active_concepts: List[ConceptStatItem]
volatile_concepts: List[ConceptStatItem]
momentum_concepts: List[ConceptStatItem]
summary: Dict[str, Any]
class ConceptStatisticsResponse(BaseModel):
success: bool
data: ConceptStatistics
params: Dict[str, Any]
note: Optional[str] = None
class ConceptPriceItem(BaseModel):
concept_id: str
concept_name: str
concept_type: str
trade_date: date
avg_change_pct: Optional[float] = None
stock_count: Optional[int] = None
hierarchy: Optional[HierarchyInfo] = None
class ConceptPriceListResponse(BaseModel):
trade_date: date
total: int
concepts: List[ConceptPriceItem]
class StockSearchResult(BaseModel):
stock_code: str
stock_name: str
concept_count: int
class StockSearchResponse(BaseModel):
total: int
stocks: List[StockSearchResult]
# ==================== Helpers ====================
def generate_embedding(text: str) -> List[float]:
try:
if not text or len(text.strip()) == 0:
return []
text = text[:8000] if len(text) > 8000 else text
if not openai_client:
return []
response = openai_client.embeddings.create(model=EMBEDDING_MODEL, input=[text])
return response.data[0].embedding
except Exception as e:
logger.warning(f"Embedding生成失败: {e}")
return []
def calculate_semantic_weight(query: str) -> float:
if not query or not query.strip():
return 0.0
query_length = len(query.strip())
if query_length < 10:
return 0.3
elif query_length < 50:
return 0.5
elif query_length < 200:
return 0.6
else:
return 0.7
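# Rough effect of the thresholds above (illustrative): a short keyword such as "AI" gets
# weight 0.3 (keyword-dominated), a one-sentence query of ~30 characters gets 0.5, and a
# long descriptive paragraph gets 0.6-0.7 (increasingly semantic).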
async def get_concept_price_data(concept_ids: List[str], trade_date: Optional[date] = None) -> Dict[str, ConceptPriceInfo]:
if not mysql_pool or not concept_ids:
return {}
try:
async with mysql_pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
placeholders = ','.join(['%s'] * len(concept_ids))
if trade_date:
query = f"""
SELECT concept_id, trade_date, avg_change_pct
FROM concept_daily_stats
WHERE concept_id IN ({placeholders}) AND trade_date = %s
"""
await cursor.execute(query, (*concept_ids, trade_date))
else:
query = f"""
SELECT cds.concept_id, cds.trade_date, cds.avg_change_pct
FROM concept_daily_stats cds
INNER JOIN (
SELECT concept_id, MAX(trade_date) as max_date
FROM concept_daily_stats
WHERE concept_id IN ({placeholders})
GROUP BY concept_id
) latest ON cds.concept_id = latest.concept_id AND cds.trade_date = latest.max_date
"""
await cursor.execute(query, concept_ids)
rows = await cursor.fetchall()
result = {}
for row in rows:
result[row['concept_id']] = ConceptPriceInfo(
trade_date=row['trade_date'],
avg_change_pct=float(row['avg_change_pct']) if row['avg_change_pct'] is not None else None
)
return result
except Exception as e:
logger.error(f"获取涨跌幅数据失败: {e}")
return {}
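# The MySQL helpers in this file assume a concept_daily_stats table roughly like the sketch
# below (reconstructed from the queries in this file, not from an actual schema dump; column
# types are guesses):
#
#   CREATE TABLE concept_daily_stats (
#     concept_id     VARCHAR(64),
#     concept_name   VARCHAR(255),
#     concept_type   VARCHAR(16),    -- 'leaf' / 'lv1' / 'lv2' / 'lv3'
#     trade_date     DATE,
#     avg_change_pct DECIMAL(8, 2),
#     stock_count    INT
#   );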
async def get_top_concepts_by_change(
trade_date: Optional[date],
limit: int = 100,
offset: int = 0,
filter_lv1: Optional[str] = None,
filter_lv2: Optional[str] = None
) -> tuple:
"""
直接从 MySQL 获取涨跌幅排序的概念列表(用于空查询优化)
返回: (概念列表, 总数, 实际查询日期)
"""
if not mysql_pool:
return [], 0, None
try:
async with mysql_pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
query_date = trade_date
if query_date is None:
await cursor.execute("SELECT MAX(trade_date) as max_date FROM concept_daily_stats WHERE concept_type = 'leaf'")
result = await cursor.fetchone()
if not result or not result['max_date']:
return [], 0, None
query_date = result['max_date']
where_conditions = ["trade_date = %s", "concept_type = 'leaf'"]
params = [query_date]
if filter_lv1 or filter_lv2:
filtered_concepts = []
for concept_name, hierarchy in concept_to_hierarchy.items():
if filter_lv1 and hierarchy.get('lv1') != filter_lv1:
continue
if filter_lv2 and hierarchy.get('lv2') != filter_lv2:
continue
filtered_concepts.append(concept_name)
if not filtered_concepts:
return [], 0, query_date
placeholders = ','.join(['%s'] * len(filtered_concepts))
where_conditions.append(f"concept_name IN ({placeholders})")
params.extend(filtered_concepts)
where_clause = " AND ".join(where_conditions)
count_query = f"SELECT COUNT(*) as cnt FROM concept_daily_stats WHERE {where_clause}"
await cursor.execute(count_query, params)
total = (await cursor.fetchone())['cnt']
data_query = f"""
SELECT concept_id, concept_name, avg_change_pct, stock_count
FROM concept_daily_stats
WHERE {where_clause}
ORDER BY avg_change_pct DESC
LIMIT %s OFFSET %s
"""
await cursor.execute(data_query, params + [limit, offset])
rows = await cursor.fetchall()
concepts = []
for row in rows:
concepts.append({
'concept_id': row['concept_id'],
'concept_name': row['concept_name'],
'avg_change_pct': float(row['avg_change_pct']) if row['avg_change_pct'] is not None else None,
'stock_count': row['stock_count'],
'trade_date': query_date
})
return concepts, total, query_date
except Exception as e:
logger.error(f"获取涨跌幅排序概念失败: {e}")
return [], 0, None
async def get_latest_trade_date() -> Optional[date]:
"""获取最新的交易日期"""
if not mysql_pool:
return None
try:
async with mysql_pool.acquire() as conn:
async with conn.cursor() as cursor:
await cursor.execute("SELECT MAX(trade_date) as latest_date FROM concept_daily_stats")
result = await cursor.fetchone()
return result[0] if result and result[0] else None
except Exception as e:
logger.error(f"获取最新交易日期失败: {e}")
return None
def get_fallback_statistics(start_date, end_date) -> ConceptStatistics:
"""返回示例统计数据"""
return ConceptStatistics(
hot_concepts=[
ConceptStatItem(name="人工智能", change_pct=8.76, stock_count=45, news_count=12),
ConceptStatItem(name="机器人", change_pct=7.54, stock_count=38, news_count=10),
ConceptStatItem(name="半导体", change_pct=6.43, stock_count=52, news_count=15),
],
cold_concepts=[
ConceptStatItem(name="房地产", change_pct=-5.76, stock_count=33, news_count=5),
ConceptStatItem(name="煤炭", change_pct=-4.32, stock_count=25, news_count=3),
],
active_concepts=[
ConceptStatItem(name="人工智能", news_count=45, report_count=15, total_mentions=60),
],
volatile_concepts=[
ConceptStatItem(name="数字货币", volatility=25.6, avg_change=2.1, max_change=15.2),
],
momentum_concepts=[
ConceptStatItem(name="AI应用", consecutive_days=6, total_change=19.2, avg_daily=3.2),
],
summary={
'total_concepts': 500,
'positive_count': 320,
'negative_count': 180,
'avg_change': 1.8,
'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'date_range': f"{start_date}{end_date}",
'days': (end_date - start_date).days + 1 if start_date and end_date else 7,
'start_date': str(start_date) if start_date else '',
'end_date': str(end_date) if end_date else ''
}
)
def build_keyword_query(query: str, stock_filters: List[str] = None, search_reason: bool = False) -> Dict:
"""构建关键词查询 - V4版本"""
must_queries = []
should_queries = []
if query.strip():
# Concept name and description search
should_queries.append({
"multi_match": {
"query": query,
"fields": ["concept^3", "description^2", "tags.text^2", "tags^1.5"],
"type": "best_fields",
"analyzer": "ik_smart"
}
})
# Stock name search
should_queries.append({
"match": {
"stock_names.text": {
"query": query,
"boost": 1.5
}
}
})
# Optional: also search stock reasons
if search_reason:
should_queries.append({
"match": {
"stock_reasons_text": {
"query": query,
"boost": 1.0
}
}
})
if stock_filters:
# Exact match on stock code or name
stock_query = {
"bool": {
"should": [
{"terms": {"stock_codes": stock_filters}},
{"terms": {"stock_names": stock_filters}}
]
}
}
must_queries.append(stock_query)
if should_queries:
must_queries.append({
"bool": {
"should": should_queries,
"minimum_should_match": 1
}
})
return {
"bool": {
"must": must_queries if must_queries else [{"match_all": {}}]
}
}
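# For a query like "光伏" with search_reason=True and no stock filter, build_keyword_query()
# returns roughly the following bool query (illustrative, trimmed):
# {
#   "bool": {"must": [{"bool": {"minimum_should_match": 1, "should": [
#     {"multi_match": {"query": "光伏", "fields": ["concept^3", "description^2", ...]}},
#     {"match": {"stock_names.text": {"query": "光伏", "boost": 1.5}}},
#     {"match": {"stock_reasons_text": {"query": "光伏", "boost": 1.0}}}
#   ]}}]}
# }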
def parse_stocks_from_es(source: Dict, limit: int = 30, include_reasons: bool = False) -> List[StockInfo]:
"""从ES文档解析股票信息
Args:
source: ES文档源数据
limit: 返回股票数量限制
include_reasons: 是否包含股票原因
Returns:
股票信息列表
"""
stocks_data = source.get('stocks', [])
stocks = []
for s in stocks_data[:limit]:
stocks.append(StockInfo(
name=s.get('name', ''),
code=s.get('code'),
reason=s.get('reason') if include_reasons else None
))
return stocks
def parse_stocks_full(source: Dict) -> List[StockInfo]:
"""从ES文档解析完整股票信息含原因无数量限制"""
stocks_data = source.get('stocks', [])
return [
StockInfo(
name=s.get('name', ''),
code=s.get('code'),
reason=s.get('reason')
)
for s in stocks_data
]
# ==================== API endpoints ====================
@app.get("/", tags=["Health"])
async def root():
return {
"status": "healthy",
"service": "概念搜索API V4",
"version": "4.0.0",
"index": INDEX_NAME,
"hierarchy_concepts": len(concept_to_hierarchy),
"features": ["reason_search", "stock_search", "semantic_search"]
}
@app.post("/search", response_model=SearchResponse, tags=["Search"])
async def search_concepts(request: SearchRequest):
"""搜索概念 - 支持语义搜索和原因搜索"""
start_time = datetime.now()
try:
# ========== Fast path: empty query + sort by change percentage ==========
is_empty_query = not request.query or not request.query.strip()
is_change_sort = request.sort_by == "change_pct"
no_stock_filter = not request.filter_stocks
if is_empty_query and is_change_sort and no_stock_filter:
logger.info(f"[Search优化] 使用 MySQL 快速路径: page={request.page}, size={request.size}")
offset = (request.page - 1) * request.size
top_concepts, total, actual_date = await get_top_concepts_by_change(
trade_date=request.trade_date,
limit=request.size,
offset=offset,
filter_lv1=request.filter_lv1,
filter_lv2=request.filter_lv2
)
if not top_concepts:
took_ms = int((datetime.now() - start_time).total_seconds() * 1000)
return SearchResponse(
total=0, took_ms=took_ms, results=[],
search_info={"query": "", "semantic_weight": 0, "match_type": "mysql_optimized"},
price_date=actual_date, page=request.page, total_pages=0
)
concept_ids = [c['concept_id'] for c in top_concepts]
es_body = {
"query": {"terms": {"concept_id": concept_ids}},
"size": len(concept_ids),
"_source": {"excludes": ["description_embedding", "insight", "stock_reasons_text"]}
}
es_response = es_client.search(index=INDEX_NAME, body=es_body, timeout="10s")
es_details = {}
for hit in es_response['hits']['hits']:
source = hit['_source']
es_details[source.get('concept_id', '')] = source
results = []
for mysql_concept in top_concepts:
cid = mysql_concept['concept_id']
es_data = es_details.get(cid, {})
concept_name = mysql_concept['concept_name']
stocks = parse_stocks_from_es(es_data, include_reasons=request.include_stock_reasons)
hierarchy_info = get_concept_hierarchy(concept_name)
hierarchy = HierarchyInfo(**hierarchy_info) if hierarchy_info else None
result = ConceptResult(
concept_id=cid,
concept=concept_name,
description=es_data.get('description'),
tags=es_data.get('tags', []),
outbreak_dates=es_data.get('outbreak_dates', []),
stocks=stocks,
stock_count=mysql_concept['stock_count'] or len(es_data.get('stocks', [])),
hierarchy=hierarchy,
score=0.0,
match_type="mysql_optimized",
highlights=None,
price_info=ConceptPriceInfo(
trade_date=actual_date,
avg_change_pct=mysql_concept['avg_change_pct']
)
)
results.append(result)
took_ms = int((datetime.now() - start_time).total_seconds() * 1000)
total_pages = max(1, (total + request.size - 1) // request.size)
return SearchResponse(
total=total, took_ms=took_ms, results=results,
search_info={
"query": "", "semantic_weight": 0, "match_type": "mysql_optimized",
"filter_lv1": request.filter_lv1, "filter_lv2": request.filter_lv2,
"sort_by": request.sort_by
},
price_date=actual_date, page=request.page, total_pages=total_pages
)
# ========== Regular path ==========
# Compute the semantic weight
if request.semantic_weight is not None:
semantic_weight = request.semantic_weight
else:
semantic_weight = calculate_semantic_weight(request.query)
# Generate the embedding
embedding = []
if semantic_weight > 0:
embedding = generate_embedding(request.query)
if not embedding:
semantic_weight = 0
effective_search_size = request.search_size
if request.sort_by in ["change_pct", "outbreak_date"]:
effective_search_size = min(1000, request.search_size * 10)
# Build the query
search_body = {}
match_type = "keyword"
if semantic_weight == 0:
search_body = {
"query": build_keyword_query(request.query, request.filter_stocks, request.search_reason),
"size": effective_search_size
}
match_type = "keyword"
elif request.use_knn and embedding:
keyword_weight = 1.0 - semantic_weight
search_body = {
"knn": {
"field": "description_embedding",
"query_vector": embedding,
"k": effective_search_size,
"num_candidates": min(effective_search_size * 2, 10000),
"boost": semantic_weight
},
"query": {
"bool": {
"must": [build_keyword_query(request.query, request.filter_stocks, request.search_reason)],
"boost": keyword_weight
}
},
"size": effective_search_size
}
match_type = "hybrid_knn"
else:
search_body = {
"query": build_keyword_query(request.query, request.filter_stocks, request.search_reason),
"size": effective_search_size
}
match_type = "keyword"
search_body.update({
"highlight": {
"fields": {
"concept": {},
"description": {"fragment_size": 150},
"tags": {},
"stock_reasons_text": {"fragment_size": 150} if request.search_reason else {}
}
},
"_source": {"excludes": ["description_embedding", "insight", "stock_reasons_text"]},
"track_total_hits": True
})
es_response = es_client.search(index=INDEX_NAME, body=search_body, timeout="30s")
all_results = []
concept_ids = []
for hit in es_response['hits']['hits']:
source = hit['_source']
concept_name = source.get('concept', '')
concept_id = source.get('concept_id', '')
concept_ids.append(concept_id)
stocks = parse_stocks_from_es(source, include_reasons=request.include_stock_reasons)
hierarchy_info = get_concept_hierarchy(concept_name)
hierarchy = HierarchyInfo(**hierarchy_info) if hierarchy_info else None
# Hierarchy filtering
if request.filter_lv1 and (not hierarchy_info or hierarchy_info.get('lv1') != request.filter_lv1):
continue
if request.filter_lv2 and (not hierarchy_info or hierarchy_info.get('lv2') != request.filter_lv2):
continue
highlights = hit.get('highlight', {})
result = ConceptResult(
concept_id=concept_id,
concept=concept_name,
description=source.get('description'),
tags=source.get('tags', []),
outbreak_dates=source.get('outbreak_dates', []),
stocks=stocks,
stock_count=source.get('stock_count', len(source.get('stocks', []))),
hierarchy=hierarchy,
score=hit['_score'] or 0,
match_type=match_type,
highlights=highlights,
price_info=None
)
all_results.append(result)
# Fetch change-percentage data
price_data = {}
actual_price_date = None
if concept_ids:
price_data = await get_concept_price_data(concept_ids, request.trade_date)
if price_data:
actual_price_date = next(iter(price_data.values())).trade_date
for result in all_results:
result.price_info = price_data.get(result.concept_id)
# Sorting
if request.sort_by == "change_pct":
all_results.sort(
key=lambda x: x.price_info.avg_change_pct if x.price_info and x.price_info.avg_change_pct is not None else -999,
reverse=True
)
elif request.sort_by == "stock_count":
all_results.sort(key=lambda x: x.stock_count, reverse=True)
elif request.sort_by == "outbreak_date":
all_results.sort(
key=lambda x: x.outbreak_dates[0] if x.outbreak_dates else '1900-01-01',
reverse=True
)
# Pagination
total_results = len(all_results)
total_pages = max(1, (total_results + request.size - 1) // request.size)
current_page = min(request.page, total_pages)
start_idx = (current_page - 1) * request.size
end_idx = start_idx + request.size
page_results = all_results[start_idx:end_idx]
took_ms = int((datetime.now() - start_time).total_seconds() * 1000)
return SearchResponse(
total=es_response['hits']['total']['value'],
took_ms=took_ms,
results=page_results,
search_info={
"query": request.query,
"semantic_weight": semantic_weight,
"match_type": match_type,
"search_reason": request.search_reason,
"filter_lv1": request.filter_lv1,
"filter_lv2": request.filter_lv2,
"sort_by": request.sort_by,
"has_embedding": bool(embedding)
},
price_date=actual_price_date,
page=current_page,
total_pages=total_pages
)
except Exception as e:
logger.error(f"搜索失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/search/reason", response_model=ReasonSearchResponse, tags=["Search"])
async def search_by_reason(request: ReasonSearchRequest):
"""
按股票原因搜索概念
搜索包含特定关键词的股票关联原因
例如:搜索"AI芯片"可以找到所有reason中提到AI芯片的概念和股票
"""
start_time = datetime.now()
try:
search_body = {
"query": {
"bool": {
"should": [
# Search the aggregated reason text (better performance)
{
"match": {
"stock_reasons_text": {
"query": request.query,
"boost": 2.0
}
}
},
# The nested query pinpoints the specific stocks
{
"nested": {
"path": "stocks",
"query": {
"match": {
"stocks.reason": request.query
}
},
"inner_hits": {
"size": 5,
"highlight": {
"fields": {
"stocks.reason": {"fragment_size": 200}
}
}
} if request.include_stock_details else {}
}
}
]
}
},
"highlight": {
"fields": {
"stock_reasons_text": {"fragment_size": 200, "number_of_fragments": 3}
}
},
"_source": ["concept_id", "concept", "description", "stocks"],
"size": request.size * request.page # 获取足够数据进行分页
}
es_response = es_client.search(index=INDEX_NAME, body=search_body, timeout="30s")
results = []
for hit in es_response['hits']['hits']:
source = hit['_source']
# Extract the matched stocks
matched_stocks = []
if request.include_stock_details and 'inner_hits' in hit and 'stocks' in hit['inner_hits']:
for inner in hit['inner_hits']['stocks']['hits']['hits']:
stock_source = inner['_source']
reason = stock_source.get('reason', '')
# Prefer the highlighted fragment when available
if 'highlight' in inner and 'stocks.reason' in inner['highlight']:
reason = inner['highlight']['stocks.reason'][0]
matched_stocks.append(StockInfo(
name=stock_source.get('name', ''),
code=stock_source.get('code'),
reason=reason
))
# Extract reason highlights
reason_highlights = None
if 'highlight' in hit and 'stock_reasons_text' in hit['highlight']:
reason_highlights = hit['highlight']['stock_reasons_text']
results.append(ReasonSearchResult(
concept_id=source.get('concept_id', ''),
concept=source.get('concept', ''),
description=source.get('description'),
score=hit['_score'] or 0,
matched_stocks=matched_stocks,
reason_highlights=reason_highlights
))
# Pagination
total = es_response['hits']['total']['value']
total_pages = max(1, (total + request.size - 1) // request.size)
start_idx = (request.page - 1) * request.size
page_results = results[start_idx:start_idx + request.size]
took_ms = int((datetime.now() - start_time).total_seconds() * 1000)
return ReasonSearchResponse(
total=total,
took_ms=took_ms,
query=request.query,
results=page_results,
page=request.page,
total_pages=total_pages
)
except Exception as e:
logger.error(f"原因搜索失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@app.get("/concept/{concept_id}", response_model=ConceptDetailResponse, tags=["Concepts"])
async def get_concept_detail(
concept_id: str,
trade_date: Optional[date] = Query(None)
):
"""获取概念详情(包含完整股票信息和原因)"""
try:
result = es_client.get(index=INDEX_NAME, id=concept_id)
source = result['_source']
concept_name = source.get('concept', '')
# Parse the full stock list, including reasons
stocks_full = parse_stocks_full(source)
hierarchy_info = get_concept_hierarchy(concept_name)
hierarchy = HierarchyInfo(**hierarchy_info) if hierarchy_info else None
price_data = await get_concept_price_data([concept_id], trade_date)
price_info = price_data.get(concept_id)
return ConceptDetailResponse(
concept_id=concept_id,
concept=concept_name,
description=source.get('description'),
insight=source.get('insight'),
tags=source.get('tags', []),
outbreak_dates=source.get('outbreak_dates', []),
stocks=stocks_full,
stock_count=len(stocks_full),
hierarchy=hierarchy,
folders=source.get('folders', []),
created_at=source.get('created_at'),
price_info=price_info
)
except Exception as e:
if "NotFoundError" in str(type(e)):
raise HTTPException(status_code=404, detail="概念不存在")
logger.error(f"获取概念详情失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/stock/{stock_code}/concepts", tags=["Stocks"])
async def get_stock_concepts(
stock_code: str,
size: int = Query(50, ge=1, le=200),
trade_date: Optional[date] = Query(None),
include_reason: bool = Query(True, description="是否返回关联原因")
):
"""根据股票代码或名称查询相关概念"""
try:
# Fast match on keyword fields
query = {
"bool": {
"should": [
{"term": {"stock_codes": stock_code}},
{"term": {"stock_names": stock_code}},
{"match": {"stock_names.text": stock_code}}
],
"minimum_should_match": 1
}
}
search_body = {
"query": query,
"size": size,
"_source": ["concept_id", "concept", "description", "tags", "stocks", "stock_count", "outbreak_dates"]
}
# When reasons are requested, add a nested query with inner_hits
if include_reason:
search_body["query"] = {
"bool": {
"should": [
{"term": {"stock_codes": stock_code}},
{"term": {"stock_names": stock_code}},
{
"nested": {
"path": "stocks",
"query": {
"bool": {
"should": [
{"term": {"stocks.code": stock_code}},
{"match": {"stocks.name.text": stock_code}}
]
}
},
"inner_hits": {
"size": 1,
"_source": ["name", "code", "reason"]
}
}
}
],
"minimum_should_match": 1
}
}
es_response = es_client.search(index=INDEX_NAME, body=search_body)
concept_ids = []
concepts = []
for hit in es_response['hits']['hits']:
source = hit['_source']
concept_name = source.get('concept', '')
concept_id = source.get('concept_id', '')
concept_ids.append(concept_id)
hierarchy_info = get_concept_hierarchy(concept_name)
concept_data = {
"concept_id": concept_id,
"concept": concept_name,
"description": source.get('description'),
"tags": source.get('tags', []),
"stock_count": source.get('stock_count', len(source.get('stocks', []))),
"hierarchy": hierarchy_info,
"outbreak_dates": source.get('outbreak_dates', [])
}
# Extract this stock's association reason
if include_reason and 'inner_hits' in hit and 'stocks' in hit['inner_hits']:
inner = hit['inner_hits']['stocks']['hits']['hits']
if inner:
stock_info = inner[0]['_source']
concept_data['stock_reason'] = stock_info.get('reason')
concept_data['matched_stock'] = {
'name': stock_info.get('name'),
'code': stock_info.get('code')
}
concepts.append(concept_data)
# Fetch change percentages
price_data = await get_concept_price_data(concept_ids, trade_date)
for concept in concepts:
price_info = price_data.get(concept['concept_id'])
concept['price_info'] = {
'trade_date': str(price_info.trade_date),
'avg_change_pct': price_info.avg_change_pct
} if price_info else None
return {
"stock_code": stock_code,
"total": es_response['hits']['total']['value'],
"concepts": concepts
}
except Exception as e:
logger.error(f"查询股票概念失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/stock/search", tags=["Stocks"])
async def search_stocks(
keyword: str = Query(..., description="股票关键词(代码或名称)"),
size: int = Query(20, ge=1, le=100)
):
"""搜索股票名称或代码"""
try:
# The V4 index keeps separate stock_names and stock_codes fields
search_body = {
"query": {
"bool": {
"should": [
{"wildcard": {"stock_codes": f"*{keyword}*"}},
{"match": {"stock_names.text": keyword}},
{"wildcard": {"stock_names": f"*{keyword}*"}}
]
}
},
"size": 0,
"aggs": {
"by_stock_code": {
"terms": {
"field": "stock_codes",
"size": size * 3
}
},
"by_stock_name": {
"terms": {
"field": "stock_names",
"size": size * 3
}
}
}
}
es_response = es_client.search(index=INDEX_NAME, body=search_body)
# Merge results
stocks_map = {}
# From the stock_codes aggregation
for bucket in es_response['aggregations']['by_stock_code']['buckets']:
code = bucket['key']
if keyword in code:
if code not in stocks_map:
stocks_map[code] = {'code': code, 'name': None, 'count': bucket['doc_count']}
else:
stocks_map[code]['count'] = max(stocks_map[code]['count'], bucket['doc_count'])
# From the stock_names aggregation
for bucket in es_response['aggregations']['by_stock_name']['buckets']:
name = bucket['key']
if keyword in name:
# The corresponding code would still need to be looked up
stocks_map[name] = {'code': None, 'name': name, 'count': bucket['doc_count']}
stocks = []
for key, data in stocks_map.items():
stocks.append({
'stock_code': data.get('code') or '',
'stock_name': data.get('name') or key,
'concept_count': data['count']
})
stocks.sort(key=lambda x: x['concept_count'], reverse=True)
return {
"total": len(stocks),
"stocks": stocks[:size]
}
except Exception as e:
logger.error(f"搜索股票失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
# ==================== Hierarchy endpoints ====================
@app.get("/hierarchy", response_model=HierarchyResponse, tags=["Hierarchy"])
async def get_hierarchy():
"""获取完整层级结构"""
try:
result = []
for lv1 in hierarchy_data.get('hierarchy', []):
lv1_node = HierarchyLevel(
id=lv1.get('lv1_id', ''),
name=lv1.get('lv1', ''),
concept_count=0,
children=[]
)
for child in lv1.get('children', []):
lv2_node = HierarchyLevel(
id=child.get('lv2_id', ''),
name=child.get('lv2', ''),
concept_count=0,
children=[]
)
if 'children' in child:
for lv3_child in child.get('children', []):
concepts = lv3_child.get('concepts', [])
lv3_node = HierarchyLevel(
id=lv3_child.get('lv3_id', ''),
name=lv3_child.get('lv3', ''),
concept_count=len(concepts),
concepts=concepts
)
lv2_node.children.append(lv3_node)
lv2_node.concept_count += len(concepts)
else:
concepts = child.get('concepts', [])
lv2_node.concept_count = len(concepts)
lv2_node.concepts = concepts
lv1_node.children.append(lv2_node)
lv1_node.concept_count += lv2_node.concept_count
result.append(lv1_node)
total_concepts = sum(lv1.concept_count for lv1 in result)
return HierarchyResponse(hierarchy=result, total_concepts=total_concepts)
except Exception as e:
logger.error(f"获取层级结构失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Note: /hierarchy/price must be registered before /hierarchy/{lv1_id}
@app.get("/hierarchy/price", tags=["Hierarchy"])
async def get_hierarchy_price(
trade_date: Optional[date] = Query(None, description="交易日期,默认最新"),
lv1_filter: Optional[str] = Query(None, description="筛选特定一级分类")
):
"""获取所有概念的涨跌幅数据包含母概念lv1/lv2/lv3和叶子概念leaf"""
if not mysql_pool:
raise HTTPException(status_code=503, detail="MySQL连接不可用")
try:
async with mysql_pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
query_date = trade_date
if query_date is None:
await cursor.execute("SELECT MAX(trade_date) as max_date FROM concept_daily_stats")
result = await cursor.fetchone()
if not result or not result['max_date']:
raise HTTPException(status_code=404, detail="无涨跌幅数据")
query_date = result['max_date']
base_query = """
SELECT concept_id, concept_name, concept_type, trade_date, avg_change_pct, stock_count
FROM concept_daily_stats WHERE trade_date = %s AND concept_type = %s
"""
if lv1_filter:
base_query += " AND concept_name LIKE %s"
base_query += " ORDER BY avg_change_pct DESC"
lv1_concepts, lv2_concepts, lv3_concepts, leaf_concepts = [], [], [], []
for ctype, target_list in [('lv1', lv1_concepts), ('lv2', lv2_concepts),
('lv3', lv3_concepts), ('leaf', leaf_concepts)]:
params = (query_date, ctype, f"%{lv1_filter}%") if lv1_filter else (query_date, ctype)
await cursor.execute(base_query, params)
for row in await cursor.fetchall():
item = {
"concept_id": row['concept_id'],
"concept_name": row['concept_name'],
"concept_type": ctype,
"trade_date": str(row['trade_date']),
"avg_change_pct": float(row['avg_change_pct']) if row['avg_change_pct'] else None,
"stock_count": row['stock_count']
}
if ctype == 'leaf':
item["hierarchy"] = get_concept_hierarchy(row['concept_name'])
target_list.append(item)
return {
"trade_date": str(query_date),
"lv1_concepts": lv1_concepts,
"lv2_concepts": lv2_concepts,
"lv3_concepts": lv3_concepts,
"leaf_concepts": leaf_concepts,
"total_count": len(lv1_concepts) + len(lv2_concepts) + len(lv3_concepts) + len(leaf_concepts),
"summary": {
"lv1_count": len(lv1_concepts), "lv2_count": len(lv2_concepts),
"lv3_count": len(lv3_concepts), "leaf_count": len(leaf_concepts)
}
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取层级涨跌幅失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/hierarchy/{lv1_id}", tags=["Hierarchy"])
async def get_hierarchy_level(
lv1_id: str,
lv2_id: Optional[str] = Query(None, description="二级分类ID")
):
"""获取指定层级的概念列表"""
try:
for lv1 in hierarchy_data.get('hierarchy', []):
if lv1.get('lv1_id') == lv1_id:
if lv2_id:
for child in lv1.get('children', []):
if child.get('lv2_id') == lv2_id:
concepts = []
if 'children' in child:
for lv3_child in child.get('children', []):
concepts.extend(lv3_child.get('concepts', []))
else:
concepts = child.get('concepts', [])
return {
"lv1": lv1.get('lv1'), "lv1_id": lv1_id,
"lv2": child.get('lv2'), "lv2_id": lv2_id,
"concepts": concepts, "concept_count": len(concepts)
}
raise HTTPException(status_code=404, detail=f"未找到二级分类: {lv2_id}")
else:
concepts = []
for child in lv1.get('children', []):
if 'children' in child:
for lv3_child in child.get('children', []):
concepts.extend(lv3_child.get('concepts', []))
else:
concepts.extend(child.get('concepts', []))
return {
"lv1": lv1.get('lv1'), "lv1_id": lv1_id,
"concepts": concepts, "concept_count": len(concepts),
"children": lv1.get('children', [])
}
raise HTTPException(status_code=404, detail=f"未找到一级分类: {lv1_id}")
except HTTPException:
raise
except Exception as e:
logger.error(f"获取层级详情失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/statistics/hierarchy", tags=["Statistics"])
async def get_hierarchy_statistics():
"""获取层级统计信息"""
try:
stats = []
for lv1 in hierarchy_data.get('hierarchy', []):
lv1_stats = {
"lv1": lv1.get('lv1'), "lv1_id": lv1.get('lv1_id'),
"lv2_count": len(lv1.get('children', [])), "concept_count": 0, "children": []
}
for child in lv1.get('children', []):
lv2_count = 0
if 'children' in child:
for lv3_child in child.get('children', []):
lv2_count += len(lv3_child.get('concepts', []))
else:
lv2_count = len(child.get('concepts', []))
lv1_stats['children'].append({
"lv2": child.get('lv2'), "lv2_id": child.get('lv2_id'), "concept_count": lv2_count
})
lv1_stats['concept_count'] += lv2_count
stats.append(lv1_stats)
stats.sort(key=lambda x: x['concept_count'], reverse=True)
return {
"total_lv1": len(stats),
"total_concepts": sum(s['concept_count'] for s in stats),
"statistics": stats
}
except Exception as e:
logger.error(f"获取层级统计失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
# ==================== Price endpoints ====================
@app.get("/price/latest", tags=["Price"])
async def get_latest_price_date():
"""获取最新的涨跌幅数据日期"""
try:
latest_date = await get_latest_trade_date()
return {"latest_trade_date": latest_date, "has_data": latest_date is not None}
except Exception as e:
logger.error(f"获取最新价格日期失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/concept/{concept_id}/price-timeseries", response_model=PriceTimeSeriesResponse, tags=["Price"])
async def get_concept_price_timeseries(
concept_id: str,
start_date: date = Query(..., description="开始日期"),
end_date: date = Query(..., description="结束日期")
):
"""获取概念在指定日期范围内的涨跌幅时间序列数据"""
if not mysql_pool:
return PriceTimeSeriesResponse(
concept_id=concept_id, concept_name=concept_id,
start_date=start_date, end_date=end_date, data_points=0, timeseries=[]
)
if start_date > end_date:
raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")
try:
async with mysql_pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
query = """
SELECT trade_date, concept_name, avg_change_pct, stock_count
FROM concept_daily_stats
WHERE concept_id = %s AND trade_date >= %s AND trade_date <= %s
ORDER BY trade_date ASC
"""
await cursor.execute(query, (concept_id, start_date, end_date))
rows = await cursor.fetchall()
if not rows:
raise HTTPException(status_code=404, detail=f"未找到概念ID {concept_id} 的数据")
timeseries = []
concept_name = ""
for row in rows:
if not concept_name:
concept_name = row['concept_name']
timeseries.append(PriceTimeSeriesItem(
trade_date=row['trade_date'],
avg_change_pct=float(row['avg_change_pct']) if row['avg_change_pct'] is not None else None,
stock_count=row.get('stock_count')
))
return PriceTimeSeriesResponse(
concept_id=concept_id, concept_name=concept_name or concept_id,
start_date=start_date, end_date=end_date,
data_points=len(timeseries), timeseries=timeseries
)
except HTTPException:
raise
except Exception as e:
logger.error(f"获取价格时间序列失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/price/list", response_model=ConceptPriceListResponse, tags=["Price"])
async def get_concept_price_list(
trade_date: Optional[date] = Query(None, description="交易日期,默认最新"),
concept_type: Optional[str] = Query(None, description="概念类型: leaf/lv1/lv2/lv3"),
sort_by: str = Query("change_desc", description="排序方式"),
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0)
):
"""批量获取概念涨跌幅列表"""
if not mysql_pool:
raise HTTPException(status_code=503, detail="MySQL连接不可用")
try:
async with mysql_pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
if trade_date is None:
await cursor.execute("SELECT MAX(trade_date) as max_date FROM concept_daily_stats")
result = await cursor.fetchone()
if not result or not result['max_date']:
raise HTTPException(status_code=404, detail="无涨跌幅数据")
trade_date = result['max_date']
where_conditions = ["trade_date = %s"]
params = [trade_date]
if concept_type and concept_type in ['leaf', 'lv1', 'lv2', 'lv3']:
where_conditions.append("concept_type = %s")
params.append(concept_type)
order_clause = "avg_change_pct DESC"
if sort_by == "change_asc":
order_clause = "avg_change_pct ASC"
elif sort_by == "stock_count":
order_clause = "stock_count DESC"
count_query = f"SELECT COUNT(*) as cnt FROM concept_daily_stats WHERE {' AND '.join(where_conditions)}"
await cursor.execute(count_query, params)
total = (await cursor.fetchone())['cnt']
query = f"""
SELECT concept_id, concept_name, concept_type, trade_date, avg_change_pct, stock_count
FROM concept_daily_stats WHERE {' AND '.join(where_conditions)}
ORDER BY {order_clause} LIMIT %s OFFSET %s
"""
await cursor.execute(query, params + [limit, offset])
rows = await cursor.fetchall()
concepts = []
for row in rows:
hierarchy = None
if row['concept_type'] == 'leaf':
hi = get_concept_hierarchy(row['concept_name'])
if hi:
hierarchy = HierarchyInfo(**hi)
concepts.append(ConceptPriceItem(
concept_id=row['concept_id'], concept_name=row['concept_name'],
concept_type=row['concept_type'], trade_date=row['trade_date'],
avg_change_pct=float(row['avg_change_pct']) if row['avg_change_pct'] is not None else None,
stock_count=row['stock_count'], hierarchy=hierarchy
))
return ConceptPriceListResponse(trade_date=trade_date, total=total, concepts=concepts)
except HTTPException:
raise
except Exception as e:
logger.error(f"获取概念涨跌幅列表失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/hierarchy/{hierarchy_id}/price-timeseries", tags=["Hierarchy"])
async def get_hierarchy_price_timeseries(
hierarchy_id: str,
start_date: date = Query(..., description="开始日期"),
end_date: date = Query(..., description="结束日期")
):
"""获取层级概念lv1/lv2/lv3的价格时间序列"""
if not mysql_pool:
raise HTTPException(status_code=503, detail="MySQL连接不可用")
if start_date > end_date:
raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")
try:
async with mysql_pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
query = """
SELECT concept_id, concept_name, concept_type, trade_date, avg_change_pct, stock_count
FROM concept_daily_stats
WHERE concept_id = %s AND trade_date >= %s AND trade_date <= %s
ORDER BY trade_date ASC
"""
await cursor.execute(query, (hierarchy_id, start_date, end_date))
rows = await cursor.fetchall()
if not rows:
raise HTTPException(status_code=404, detail=f"未找到层级概念 {hierarchy_id} 的数据")
timeseries = []
concept_name, concept_type = "", ""
for row in rows:
if not concept_name:
concept_name = row['concept_name']
concept_type = row['concept_type']
timeseries.append({
'trade_date': row['trade_date'],
'avg_change_pct': float(row['avg_change_pct']) if row['avg_change_pct'] is not None else None,
'stock_count': row['stock_count']
})
return {
'concept_id': hierarchy_id, 'concept_name': concept_name,
'concept_type': concept_type, 'start_date': start_date, 'end_date': end_date,
'data_points': len(timeseries), 'timeseries': timeseries
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取层级价格时间序列失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
# ==================== Statistics endpoints ====================
@app.get("/statistics", response_model=ConceptStatisticsResponse, tags=["Statistics"])
async def get_concept_statistics(
days: Optional[int] = Query(None, ge=1, le=90),
start_date: Optional[date] = Query(None),
end_date: Optional[date] = Query(None),
min_stock_count: int = Query(3, ge=1),
concept_type: Optional[str] = Query(None)
):
"""获取概念板块统计数据"""
from datetime import timedelta
if days is not None and (start_date is not None or end_date is not None):
raise HTTPException(status_code=400, detail="days参数与start_date/end_date参数不能同时使用")
if start_date is not None and end_date is not None:
if start_date > end_date:
raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")
elif days is not None:
end_date = datetime.now().date()
start_date = end_date - timedelta(days=days)
elif start_date is not None:
end_date = datetime.now().date()
elif end_date is not None:
start_date = end_date - timedelta(days=7)
else:
end_date = datetime.now().date()
start_date = end_date - timedelta(days=7)
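# Illustrative resolution of the parameters above: days=7 on 2024-05-10 yields
# start_date=2024-05-03 and end_date=2024-05-10 (an inclusive 8-day window); passing only
# end_date applies the same 7-day look-back ending on that date.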
if not mysql_pool:
return ConceptStatisticsResponse(
success=True, data=get_fallback_statistics(start_date, end_date),
params={'start_date': str(start_date), 'end_date': str(end_date)},
note="MySQL连接不可用"
)
type_filter = concept_type if concept_type in ['leaf', 'lv1', 'lv2', 'lv3'] else 'leaf'
try:
async with mysql_pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cursor:
# Top gainers
await cursor.execute("""
SELECT concept_id, concept_name, AVG(avg_change_pct) as avg_change_pct,
AVG(stock_count) as avg_stock_count, COUNT(*) as trading_days
FROM concept_daily_stats
WHERE trade_date >= %s AND trade_date <= %s AND avg_change_pct IS NOT NULL
AND stock_count >= %s AND concept_type = %s
GROUP BY concept_id, concept_name HAVING COUNT(*) >= 2
ORDER BY AVG(avg_change_pct) DESC LIMIT 10
""", (start_date, end_date, min_stock_count, type_filter))
hot_rows = await cursor.fetchall()
# Top losers
await cursor.execute("""
SELECT concept_id, concept_name, AVG(avg_change_pct) as avg_change_pct,
AVG(stock_count) as avg_stock_count, COUNT(*) as trading_days
FROM concept_daily_stats
WHERE trade_date >= %s AND trade_date <= %s AND avg_change_pct IS NOT NULL
AND stock_count >= %s AND concept_type = %s
GROUP BY concept_id, concept_name HAVING COUNT(*) >= 2
ORDER BY AVG(avg_change_pct) ASC LIMIT 10
""", (start_date, end_date, min_stock_count, type_filter))
cold_rows = await cursor.fetchall()
# Most active
await cursor.execute("""
SELECT concept_id, concept_name, COUNT(*) as trading_days, AVG(stock_count) as avg_stock_count
FROM concept_daily_stats
WHERE trade_date >= %s AND trade_date <= %s AND stock_count >= %s AND concept_type = %s
GROUP BY concept_id, concept_name ORDER BY COUNT(*) DESC LIMIT 10
""", (start_date, end_date, min_stock_count, type_filter))
active_rows = await cursor.fetchall()
# Most volatile
await cursor.execute("""
SELECT concept_id, concept_name, STDDEV(avg_change_pct) as volatility,
AVG(avg_change_pct) as avg_change_pct, MAX(avg_change_pct) as max_change_pct
FROM concept_daily_stats
WHERE trade_date >= %s AND trade_date <= %s AND avg_change_pct IS NOT NULL
AND stock_count >= %s AND concept_type = %s
GROUP BY concept_id, concept_name
HAVING COUNT(*) >= 3 AND STDDEV(avg_change_pct) IS NOT NULL
ORDER BY STDDEV(avg_change_pct) DESC LIMIT 10
""", (start_date, end_date, min_stock_count, type_filter))
volatile_rows = await cursor.fetchall()
# Momentum (consecutive gainers)
await cursor.execute("""
SELECT concept_id, concept_name, COUNT(*) as positive_days,
SUM(avg_change_pct) as total_change, AVG(avg_change_pct) as avg_change_pct
FROM concept_daily_stats
WHERE trade_date >= %s AND trade_date <= %s AND avg_change_pct > 0
AND stock_count >= %s AND concept_type = %s
GROUP BY concept_id, concept_name HAVING COUNT(*) >= 2
ORDER BY COUNT(*) DESC LIMIT 10
""", (start_date, end_date, min_stock_count, type_filter))
momentum_rows = await cursor.fetchall()
# Overall summary
await cursor.execute("""
SELECT COUNT(DISTINCT concept_id) as total_concepts,
COUNT(DISTINCT CASE WHEN avg_change_pct > 0 THEN concept_id END) as positive_concepts,
COUNT(DISTINCT CASE WHEN avg_change_pct < 0 THEN concept_id END) as negative_concepts,
AVG(avg_change_pct) as overall_avg_change
FROM concept_daily_stats
WHERE trade_date >= %s AND trade_date <= %s AND avg_change_pct IS NOT NULL AND concept_type = %s
""", (start_date, end_date, type_filter))
total_row = await cursor.fetchone()
def build_items(rows, item_type):
items = []
for row in rows:
hi = get_concept_hierarchy(row['concept_name'])
item = ConceptStatItem(
name=row['concept_name'], concept_id=row.get('concept_id'),
hierarchy=HierarchyInfo(**hi) if hi else None
)
if item_type in ['hot', 'cold']:
item.change_pct = round(float(row['avg_change_pct']), 2) if row['avg_change_pct'] else 0.0
item.stock_count = int(row['avg_stock_count']) if row['avg_stock_count'] else 0
item.news_count = int(row['trading_days']) if row['trading_days'] else 0
elif item_type == 'active':
item.news_count = int(row['trading_days']) if row['trading_days'] else 0
item.stock_count = int(row['avg_stock_count']) if row['avg_stock_count'] else 0
item.report_count = max(1, int(row['trading_days'] * 0.3)) if row['trading_days'] else 1
item.total_mentions = item.news_count + item.report_count
elif item_type == 'volatile':
item.volatility = round(float(row['volatility']), 2) if row['volatility'] else 0.0
item.avg_change = round(float(row['avg_change_pct']), 2) if row['avg_change_pct'] else 0.0
item.max_change = round(float(row['max_change_pct']), 2) if row['max_change_pct'] else 0.0
elif item_type == 'momentum':
item.consecutive_days = int(row['positive_days']) if row['positive_days'] else 0
item.total_change = round(float(row['total_change']), 2) if row['total_change'] else 0.0
item.avg_daily = round(float(row['avg_change_pct']), 2) if row['avg_change_pct'] else 0.0
items.append(item)
return items
statistics = ConceptStatistics(
hot_concepts=build_items(hot_rows, 'hot'),
cold_concepts=build_items(cold_rows, 'cold'),
active_concepts=build_items(active_rows, 'active'),
volatile_concepts=build_items(volatile_rows, 'volatile'),
momentum_concepts=build_items(momentum_rows, 'momentum'),
summary={
'total_concepts': int(total_row['total_concepts']) if total_row['total_concepts'] else 0,
'positive_count': int(total_row['positive_concepts']) if total_row['positive_concepts'] else 0,
'negative_count': int(total_row['negative_concepts']) if total_row['negative_concepts'] else 0,
'avg_change': round(float(total_row['overall_avg_change']), 2) if total_row['overall_avg_change'] else 0.0,
'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'date_range': f"{start_date}{end_date}",
'days': (end_date - start_date).days + 1,
'start_date': str(start_date), 'end_date': str(end_date)
}
)
return ConceptStatisticsResponse(
success=True, data=statistics,
params={
'days': (end_date - start_date).days + 1, 'min_stock_count': min_stock_count,
'start_date': str(start_date), 'end_date': str(end_date), 'concept_type': type_filter
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"获取概念统计失败: {e}")
return ConceptStatisticsResponse(
success=True, data=get_fallback_statistics(start_date, end_date),
params={'start_date': str(start_date), 'end_date': str(end_date)},
note=f"使用示例数据: {str(e)}"
)
# ==================== Main ====================
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"concept_api_v4:app",
host="0.0.0.0",
port=6801,
reload=True,
log_level="info"
)