Files
vf_react/concept_api_v2.py
2025-12-12 08:44:45 +08:00

2005 lines
75 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
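"""
Concept search API V2 - adapted to the concept_library_v3 index.
New features:
1. Supports the hierarchy structure (lv1/lv2/lv3)
2. Returns new fields such as tags, outbreak_dates and insight
3. Adds hierarchy browsing endpoints
"""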
import json
import openai
from typing import List, Dict, Optional, Union, Any
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from elasticsearch import Elasticsearch
from datetime import datetime, date
import logging
import re
from contextlib import asynccontextmanager
import aiomysql
import os
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ==================== Configuration ====================
# Elasticsearch settings
ES_HOST = 'http://127.0.0.1:9200'
INDEX_NAME = 'concept_library_v3'
# Embedding settings (the OpenAI client is pointed at this base URL)
OPENAI_BASE_URL = "http://127.0.0.1:8000/v1"
OPENAI_API_KEY = "dummy"
EMBEDDING_MODEL = "qwen3-embedding-8b"
# Hierarchy structure file
HIERARCHY_FILE = 'concept_hierarchy_v3.json'
# MySQL settings
# NOTE(review): the database credentials are hard-coded in source; consider
# loading them from the environment or a secrets store.
MYSQL_CONFIG = {
    'host': '192.168.1.8',
    'port': 3306,
    'user': 'root',
    'password': 'Zzl5588161!',
    'db': 'stock',
    'charset': 'utf8mb4',
    'autocommit': True,
    'minsize': 1,
    'maxsize': 10
}
# ==================== Global state ====================
# Clients / pool start as None.
es_client = None
openai_client = None
mysql_pool = None
# Hierarchy-related state
hierarchy_data = {}  # raw hierarchy data
concept_to_hierarchy = {}  # mapping: concept name -> hierarchy info
def load_hierarchy():
    """Load the hierarchy JSON file and index every concept by name.

    Reads HIERARCHY_FILE from the directory of this module, stores the parsed
    document in ``hierarchy_data`` and fills ``concept_to_hierarchy`` with one
    entry per concept (lv1/lv2 names and ids, plus lv3 name and id when the
    lv2 node has lv3 children, otherwise ``None``). A missing file only logs
    a warning; any other failure is logged as an error and swallowed.
    """
    global hierarchy_data, concept_to_hierarchy
    hierarchy_path = os.path.join(os.path.dirname(__file__), HIERARCHY_FILE)
    if not os.path.exists(hierarchy_path):
        logger.warning(f"层级文件不存在: {hierarchy_path}")
        return
    try:
        with open(hierarchy_path, 'r', encoding='utf-8') as f:
            hierarchy_data = json.load(f)
        # Walk lv1 -> lv2 -> (optional lv3) and register each concept.
        for top in hierarchy_data.get('hierarchy', []):
            top_name = top.get('lv1', '')
            top_id = top.get('lv1_id', '')
            for mid in top.get('children', []):
                mid_name = mid.get('lv2', '')
                mid_id = mid.get('lv2_id', '')
                if 'children' in mid:
                    # lv2 node with lv3 children: one group per lv3 node (lazy,
                    # so each node is read only when its turn comes).
                    groups = (
                        (leaf.get('lv3', ''), leaf.get('lv3_id', ''), leaf.get('concepts', []))
                        for leaf in mid.get('children', [])
                    )
                else:
                    # lv2 node that holds concepts directly (no lv3 level).
                    groups = ((None, None, mid.get('concepts', [])),)
                for leaf_name, leaf_id, concept_names in groups:
                    for concept_name in concept_names:
                        concept_to_hierarchy[concept_name] = {
                            'lv1': top_name,
                            'lv1_id': top_id,
                            'lv2': mid_name,
                            'lv2_id': mid_id,
                            'lv3': leaf_name,
                            'lv3_id': leaf_id
                        }
        logger.info(f"加载层级结构完成,共 {len(concept_to_hierarchy)} 个概念有层级信息")
    except Exception as e:
        logger.error(f"加载层级结构失败: {e}")
def get_concept_hierarchy(concept_name: str) -> Optional[Dict]:
    """Return the hierarchy entry registered for ``concept_name``, or None."""
    entry = concept_to_hierarchy.get(concept_name, None)
    return entry
# ==================== Lifespan management ====================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create shared clients on startup and release them on shutdown.

    Startup: loads the hierarchy file, builds the Elasticsearch and OpenAI
    clients and tries to create the MySQL pool. A MySQL failure is logged and
    the service keeps running with ``mysql_pool`` left unset.

    Shutdown: closes the Elasticsearch client and the MySQL pool. The cleanup
    runs in a ``finally`` clause so it also happens when an exception or a
    cancellation is thrown into the generator at the ``yield``.
    """
    global es_client, openai_client, mysql_pool
    # Load the hierarchy structure
    load_hierarchy()
    # Initialise the Elasticsearch client
    es_client = Elasticsearch(
        [ES_HOST],
        timeout=30,
        max_retries=3,
        retry_on_timeout=True
    )
    logger.info(f"Connected to Elasticsearch at {ES_HOST}, index: {INDEX_NAME}")
    # Initialise the OpenAI client (used for embeddings)
    openai_client = openai.OpenAI(
        api_key=OPENAI_API_KEY,
        base_url=OPENAI_BASE_URL,
        timeout=60,
    )
    logger.info("Initialized OpenAI client")
    # Initialise the MySQL connection pool (best effort)
    try:
        mysql_pool = await aiomysql.create_pool(**MYSQL_CONFIG)
        logger.info(f"Connected to MySQL at {MYSQL_CONFIG['host']}")
    except Exception as e:
        logger.error(f"Failed to connect to MySQL: {e}")
    try:
        yield
    finally:
        # Release resources
        if es_client:
            es_client.close()
        if mysql_pool:
            mysql_pool.close()
            await mysql_pool.wait_closed()
        logger.info("Cleanup completed")
# ==================== FastAPI application ====================
# The lifespan handler above creates and disposes the ES / OpenAI / MySQL clients.
app = FastAPI(
    title="概念搜索API V2",
    description="支持层级结构的概念库搜索API适配 concept_library_v3 索引",
    version="2.0.0",
    lifespan=lifespan
)
# NOTE(review): CORS is fully open (any origin, method and header) with
# credentials allowed; confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ==================== Data models ====================
class HierarchyInfo(BaseModel):
    """Hierarchy position of a concept: names and ids for lv1, lv2 and lv3."""
    lv1: Optional[str] = Field(None, description="一级分类")
    lv1_id: Optional[str] = None
    lv2: Optional[str] = Field(None, description="二级分类")
    lv2_id: Optional[str] = None
    # lv3 / lv3_id are None for concepts attached directly to an lv2 node.
    lv3: Optional[str] = Field(None, description="三级分类")
    lv3_id: Optional[str] = None
class StockInfo(BaseModel):
    """Stock entry of a concept: required name and optional code."""
    name: str = Field(..., description="股票名称")
    code: Optional[str] = Field(None, description="股票代码")
class ConceptPriceInfo(BaseModel):
    """Price-change information of a concept for one trade date."""
    trade_date: date
    avg_change_pct: Optional[float] = Field(None, description="平均涨跌幅(%)")
class SearchRequest(BaseModel):
    """Request body of the /search endpoint."""
    query: str = Field(..., description="搜索查询文本")
    # Page size and 1-based page number.
    size: int = Field(10, ge=1, le=100, description="每页返回结果数量")
    page: int = Field(1, ge=1, description="页码")
    # Number of hits requested from Elasticsearch before in-memory sort/paging.
    search_size: int = Field(100, ge=10, le=1000, description="搜索数量")
    # When None, the weight is derived from the query length.
    semantic_weight: Optional[float] = Field(None, ge=0.0, le=1.0, description="语义搜索权重")
    filter_stocks: Optional[List[str]] = Field(None, description="过滤特定股票")
    filter_lv1: Optional[str] = Field(None, description="过滤一级分类")
    filter_lv2: Optional[str] = Field(None, description="过滤二级分类")
    trade_date: Optional[date] = Field(None, description="交易日期")
    sort_by: str = Field("_score", description="排序: _score, change_pct, stock_count, outbreak_date")
    use_knn: bool = Field(True, description="是否使用KNN搜索")
class ConceptResult(BaseModel):
    """One concept entry in a search response."""
    concept_id: str
    concept: str
    description: Optional[str] = None
    tags: Optional[List[str]] = Field(None, description="标签")
    outbreak_dates: Optional[List[str]] = Field(None, description="爆发日期")
    stocks: List[StockInfo] = Field(default_factory=list)
    stock_count: int = 0
    hierarchy: Optional[HierarchyInfo] = Field(None, description="层级信息")
    score: float = 0.0
    match_type: str = "keyword"
    highlights: Optional[Dict[str, List[str]]] = None
    price_info: Optional[ConceptPriceInfo] = None
class SearchResponse(BaseModel):
    """Response body of the /search endpoint."""
    total: int
    took_ms: int
    results: List[ConceptResult]
    search_info: Dict[str, Any]
    price_date: Optional[date] = None
    page: int = 1
    total_pages: int = 1
class HierarchyLevel(BaseModel):
    """Node of the hierarchy tree; nests further HierarchyLevel nodes."""
    id: str
    name: str
    concept_count: int = 0
    # Self-referential: child nodes use the same model.
    children: Optional[List['HierarchyLevel']] = None
    concepts: Optional[List[str]] = None
class HierarchyResponse(BaseModel):
    """Response body carrying the hierarchy tree and the total concept count."""
    hierarchy: List[HierarchyLevel]
    total_concepts: int
class ConceptDetailResponse(BaseModel):
    """Response body of the concept detail endpoint."""
    concept_id: str
    concept: str
    description: Optional[str] = None
    insight: Optional[str] = None
    tags: Optional[List[str]] = None
    outbreak_dates: Optional[List[str]] = None
    # Free-form stock dicts (unlike ConceptResult.stocks, which uses StockInfo).
    stocks: List[Dict[str, Any]] = Field(default_factory=list)
    stock_count: int = 0
    hierarchy: Optional[HierarchyInfo] = None
    folders: Optional[List[str]] = None
    created_at: Optional[str] = None
    price_info: Optional[ConceptPriceInfo] = None
# ==================== Helper functions ====================
def generate_embedding(text: str) -> List[float]:
    """Return the embedding vector for ``text``, or ``[]`` when unavailable.

    An empty or whitespace-only text, a missing OpenAI client, or any error
    raised while calling the embedding endpoint all yield an empty list.
    Only the first 8000 characters of the text are sent.
    """
    try:
        if not text or not text.strip():
            return []
        truncated = text[:8000]
        if not openai_client:
            return []
        embedding_response = openai_client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=[truncated]
        )
        first_item = embedding_response.data[0]
        return first_item.embedding
    except Exception as e:
        logger.warning(f"Embedding生成失败: {e}")
        return []
def calculate_semantic_weight(query: str) -> float:
    """Derive the semantic-search weight from the stripped query length.

    Empty or whitespace-only queries get 0.0; otherwise the weight is
    0.3 (< 10 chars), 0.5 (< 50), 0.6 (< 200) or 0.7 (200 and above).
    """
    stripped = query.strip() if query else ""
    if not stripped:
        # An empty query needs no semantic search.
        return 0.0
    length = len(stripped)
    for upper_bound, weight in ((10, 0.3), (50, 0.5), (200, 0.6)):
        if length < upper_bound:
            return weight
    return 0.7
async def get_top_concepts_by_change(
    trade_date: Optional[date],
    limit: int = 100,
    offset: int = 0,
    filter_lv1: Optional[str] = None,
    filter_lv2: Optional[str] = None
) -> tuple[List[Dict], int, Optional[date]]:
    """Fetch leaf concepts ordered by price change directly from MySQL.

    Used as the fast path for empty-query searches sorted by change.

    Args:
        trade_date: Date to query; when None the latest date that has leaf
            rows in ``concept_daily_stats`` is used.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip (pagination).
        filter_lv1: Keep only concepts whose hierarchy lv1 name equals this.
        filter_lv2: Keep only concepts whose hierarchy lv2 name equals this.

    Returns:
        ``(concepts, total, query_date)``. ``concepts`` is a list of dicts
        with concept_id, concept_name, avg_change_pct, stock_count and
        trade_date; ``total`` is the row count before pagination;
        ``query_date`` is the date actually queried. When the pool is missing,
        no data exists or the query fails, ``([], 0, None)`` is returned
        (``([], 0, query_date)`` when the hierarchy filter matches nothing).
    """
    if not mysql_pool:
        return [], 0, None
    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                # Resolve the trade date
                query_date = trade_date
                if query_date is None:
                    await cursor.execute("SELECT MAX(trade_date) as max_date FROM concept_daily_stats WHERE concept_type = 'leaf'")
                    result = await cursor.fetchone()
                    if not result or not result['max_date']:
                        return [], 0, None
                    query_date = result['max_date']
                # Base conditions
                where_conditions = ["trade_date = %s", "concept_type = 'leaf'"]
                params = [query_date]
                # Hierarchy filtering is resolved in memory through
                # concept_to_hierarchy and pushed down as a concept_name IN list.
                if filter_lv1 or filter_lv2:
                    filtered_concepts = [
                        concept_name
                        for concept_name, hierarchy in concept_to_hierarchy.items()
                        if (not filter_lv1 or hierarchy.get('lv1') == filter_lv1)
                        and (not filter_lv2 or hierarchy.get('lv2') == filter_lv2)
                    ]
                    if not filtered_concepts:
                        return [], 0, query_date
                    placeholders = ','.join(['%s'] * len(filtered_concepts))
                    where_conditions.append(f"concept_name IN ({placeholders})")
                    params.extend(filtered_concepts)
                where_clause = " AND ".join(where_conditions)
                # Total row count (before pagination)
                count_query = f"SELECT COUNT(*) as cnt FROM concept_daily_stats WHERE {where_clause}"
                await cursor.execute(count_query, params)
                total = (await cursor.fetchone())['cnt']
                # Page of rows ordered by change, descending. concept_id is a
                # tie-breaker so that rows with equal avg_change_pct keep a
                # deterministic order across LIMIT/OFFSET pages.
                data_query = f"""
                    SELECT concept_id, concept_name, avg_change_pct, stock_count
                    FROM concept_daily_stats
                    WHERE {where_clause}
                    ORDER BY avg_change_pct DESC, concept_id
                    LIMIT %s OFFSET %s
                """
                await cursor.execute(data_query, params + [limit, offset])
                rows = await cursor.fetchall()
                concepts = []
                for row in rows:
                    change = row['avg_change_pct']
                    concepts.append({
                        'concept_id': row['concept_id'],
                        'concept_name': row['concept_name'],
                        # Compare against None explicitly: a 0.0 change is
                        # valid data and must not be reported as missing.
                        'avg_change_pct': float(change) if change is not None else None,
                        'stock_count': row['stock_count'],
                        'trade_date': query_date
                    })
                return concepts, total, query_date
    except Exception as e:
        logger.error(f"获取涨跌幅排序概念失败: {e}")
        return [], 0, None
async def get_concept_price_data(concept_ids: List[str], trade_date: Optional[date] = None) -> Dict[str, ConceptPriceInfo]:
    """Look up price-change data for the given concept ids.

    With ``trade_date`` the rows of that date are returned; without it each
    concept contributes the row of its own latest trade date. Returns a dict
    keyed by concept_id; an unavailable pool, an empty id list or a query
    failure yields an empty dict.
    """
    if not mysql_pool or not concept_ids:
        return {}
    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                id_slots = ','.join(['%s'] * len(concept_ids))
                if trade_date:
                    sql = f"""
                        SELECT concept_id, trade_date, avg_change_pct
                        FROM concept_daily_stats
                        WHERE concept_id IN ({id_slots}) AND trade_date = %s
                    """
                    sql_args = (*concept_ids, trade_date)
                else:
                    sql = f"""
                        SELECT cds.concept_id, cds.trade_date, cds.avg_change_pct
                        FROM concept_daily_stats cds
                        INNER JOIN (
                            SELECT concept_id, MAX(trade_date) as max_date
                            FROM concept_daily_stats
                            WHERE concept_id IN ({id_slots})
                            GROUP BY concept_id
                        ) latest ON cds.concept_id = latest.concept_id
                            AND cds.trade_date = latest.max_date
                    """
                    sql_args = concept_ids
                await cursor.execute(sql, sql_args)
                fetched = await cursor.fetchall()
                price_by_id = {}
                for record in fetched:
                    raw_change = record['avg_change_pct']
                    change = None if raw_change is None else float(raw_change)
                    price_by_id[record['concept_id']] = ConceptPriceInfo(
                        trade_date=record['trade_date'],
                        avg_change_pct=change
                    )
                return price_by_id
    except Exception as e:
        logger.error(f"获取涨跌幅数据失败: {e}")
        return {}
def build_keyword_query(query: str, stock_filters: List[str] = None) -> Dict:
    """Build the Elasticsearch bool query for keyword search (v3 index).

    A non-blank ``query`` adds a boosted multi_match clause; ``stock_filters``
    adds a nested clause matching stock names or codes. With neither, the
    query falls back to match_all.
    """
    clauses = []
    if query.strip():
        text_clause = {
            "multi_match": {
                "query": query,
                "fields": [
                    "concept^3",
                    "description^2",
                    "tags.text^2",
                    "tags^1.5"
                ],
                "type": "best_fields",
                "analyzer": "ik_smart"
            }
        }
        clauses.append(text_clause)
    if stock_filters:
        name_or_code = [
            {"terms": {"stocks.name": stock_filters}},
            {"terms": {"stocks.code": stock_filters}}
        ]
        clauses.append({
            "nested": {
                "path": "stocks",
                "query": {"bool": {"should": name_or_code}}
            }
        })
    if not clauses:
        clauses = [{"match_all": {}}]
    return {"bool": {"must": clauses}}
def parse_stocks_from_es(source: Dict) -> tuple:
    """Extract stock information from an Elasticsearch document source.

    Args:
        source: The ``_source`` dict of a concept document.

    Returns:
        ``(stocks_list, stocks_full)``: ``stocks_list`` holds up to 30
        ``StockInfo`` objects built from the nested ``stocks`` field (name and
        code only); ``stocks_full`` is the list decoded from the
        ``stocks_json`` string, or ``[]`` when that field is missing, is not
        valid JSON or does not decode to a list.
    """
    # Nested stocks field; tolerate an explicit null value.
    stocks_nested = source.get('stocks') or []
    stocks_list = [
        StockInfo(name=s.get('name') or '', code=s.get('code'))
        for s in stocks_nested[:30]  # cap the number of returned stocks
    ]
    # The complete stock information is stored as a JSON string.
    stocks_full = []
    stocks_json = source.get('stocks_json', '')
    if stocks_json:
        try:
            parsed = json.loads(stocks_json)
        except (TypeError, ValueError):
            # Best effort: a malformed payload yields an empty list.
            # json.JSONDecodeError is a ValueError subclass.
            parsed = []
        if isinstance(parsed, list):
            stocks_full = parsed
    return stocks_list, stocks_full
# ==================== API endpoints ====================
@app.get("/", tags=["Health"])
async def root():
    """Health check: report service identity, index name and hierarchy size."""
    payload = {
        "status": "healthy",
        "service": "概念搜索API V2",
        "version": "2.0.0",
    }
    payload["index"] = INDEX_NAME
    payload["hierarchy_concepts"] = len(concept_to_hierarchy)
    return payload
@app.post("/search", response_model=SearchResponse, tags=["Search"])
async def search_concepts(request: SearchRequest):
    """
    Search concepts, with optional semantic (KNN) matching and hierarchy filtering.

    Two paths:
    - Fast path: an empty query sorted by change_pct with no stock filter is
      answered from MySQL (ordering and paging), then enriched with the
      matching Elasticsearch documents.
    - Regular path: Elasticsearch keyword or hybrid KNN search, followed by
      in-memory hierarchy filtering, price lookup, sorting and paging.

    Any exception is logged and reported as HTTP 500 with the error text.
    """
    start_time = datetime.now()
    try:
        # ========== Fast path: empty query + change_pct sort ==========
        # Default scenario of the concept-centre home page: take the ordered
        # result directly from MySQL.
        is_empty_query = not request.query or not request.query.strip()
        is_change_sort = request.sort_by == "change_pct"
        no_stock_filter = not request.filter_stocks
        if is_empty_query and is_change_sort and no_stock_filter:
            logger.info(f"[Search优化] 使用 MySQL 快速路径: page={request.page}, size={request.size}")
            # Pagination offset
            offset = (request.page - 1) * request.size
            # Concepts ordered by price change, from MySQL
            top_concepts, total, actual_date = await get_top_concepts_by_change(
                trade_date=request.trade_date,
                limit=request.size,
                offset=offset,
                filter_lv1=request.filter_lv1,
                filter_lv2=request.filter_lv2
            )
            if not top_concepts:
                # Nothing to show (no data, page out of range, or MySQL unavailable).
                took_ms = int((datetime.now() - start_time).total_seconds() * 1000)
                return SearchResponse(
                    total=0,
                    took_ms=took_ms,
                    results=[],
                    search_info={"query": "", "semantic_weight": 0, "match_type": "mysql_optimized"},
                    price_date=actual_date,
                    page=request.page,
                    total_pages=0
                )
            # Fetch the details for these concept_ids from ES in one request
            concept_ids = [c['concept_id'] for c in top_concepts]
            es_body = {
                "query": {"terms": {"concept_id": concept_ids}},
                "size": len(concept_ids),
                "_source": {"excludes": ["description_embedding", "insight"]}
            }
            es_response = es_client.search(index=INDEX_NAME, body=es_body, timeout="10s")
            # Map concept_id -> ES document source
            es_details = {}
            for hit in es_response['hits']['hits']:
                source = hit['_source']
                es_details[source.get('concept_id', '')] = source
            # Assemble the results, keeping the MySQL ordering
            results = []
            for mysql_concept in top_concepts:
                cid = mysql_concept['concept_id']
                # Concepts missing from ES fall back to an empty document.
                es_data = es_details.get(cid, {})
                concept_name = mysql_concept['concept_name']
                # Parse stocks
                stocks_list, _ = parse_stocks_from_es(es_data)
                # Hierarchy information
                hierarchy_info = get_concept_hierarchy(concept_name)
                hierarchy = HierarchyInfo(**hierarchy_info) if hierarchy_info else None
                result = ConceptResult(
                    concept_id=cid,
                    concept=concept_name,
                    description=es_data.get('description'),
                    tags=es_data.get('tags', []),
                    outbreak_dates=es_data.get('outbreak_dates', []),
                    stocks=stocks_list,
                    stock_count=mysql_concept['stock_count'] or len(es_data.get('stocks', [])),
                    hierarchy=hierarchy,
                    score=0.0,
                    match_type="mysql_optimized",
                    highlights=None,
                    price_info=ConceptPriceInfo(
                        trade_date=actual_date,
                        avg_change_pct=mysql_concept['avg_change_pct']
                    )
                )
                results.append(result)
            took_ms = int((datetime.now() - start_time).total_seconds() * 1000)
            total_pages = max(1, (total + request.size - 1) // request.size)
            logger.info(f"[Search优化] MySQL 快速路径完成: took={took_ms}ms, total={total}, returned={len(results)}")
            return SearchResponse(
                total=total,
                took_ms=took_ms,
                results=results,
                search_info={
                    "query": "",
                    "semantic_weight": 0,
                    "match_type": "mysql_optimized",
                    "filter_lv1": request.filter_lv1,
                    "filter_lv2": request.filter_lv2,
                    "sort_by": request.sort_by
                },
                price_date=actual_date,
                page=request.page,
                total_pages=total_pages
            )
        # ========== Regular path: search text present or other sort order ==========
        # Semantic weight: explicit value wins, otherwise derived from the query
        if request.semantic_weight is not None:
            semantic_weight = request.semantic_weight
        else:
            semantic_weight = calculate_semantic_weight(request.query)
        # Generate the embedding; without one, fall back to keyword-only search
        embedding = []
        if semantic_weight > 0:
            embedding = generate_embedding(request.query)
            if not embedding:
                semantic_weight = 0
        # Decide how many hits to request.
        # Fix: when sorting by change_pct all concepts (800+) have to be
        # fetched before sorting; the previous limit of 500 made the
        # ordering inaccurate.
        effective_search_size = request.search_size
        if request.sort_by in ["change_pct", "outbreak_date"]:
            effective_search_size = min(1000, request.search_size * 10)
        # Build the query
        search_body = {}
        match_type = "keyword"
        if semantic_weight == 0:
            # Keyword-only search
            search_body = {
                "query": build_keyword_query(request.query, request.filter_stocks),
                "size": effective_search_size
            }
            match_type = "keyword"
        elif request.use_knn and embedding:
            # Hybrid search: KNN plus keyword, boosted by complementary weights
            keyword_weight = 1.0 - semantic_weight
            search_body = {
                "knn": {
                    "field": "description_embedding",
                    "query_vector": embedding,
                    "k": effective_search_size,
                    "num_candidates": min(effective_search_size * 2, 10000),
                    "boost": semantic_weight
                },
                "query": {
                    "bool": {
                        "must": [build_keyword_query(request.query, request.filter_stocks)],
                        "boost": keyword_weight
                    }
                },
                "size": effective_search_size
            }
            match_type = "hybrid_knn"
        else:
            # "Traditional hybrid search" in the original comment; reached when
            # use_knn is False. The body built here is the same as the
            # keyword-only one above.
            search_body = {
                "query": build_keyword_query(request.query, request.filter_stocks),
                "size": effective_search_size
            }
            match_type = "keyword"
        # Add highlighting and source filtering
        search_body.update({
            "highlight": {
                "fields": {
                    "concept": {},
                    "description": {"fragment_size": 150},
                    "tags": {}
                }
            },
            "_source": {
                "excludes": ["description_embedding", "insight"]
            },
            "track_total_hits": True
        })
        # Run the search
        es_response = es_client.search(
            index=INDEX_NAME,
            body=search_body,
            timeout="30s"
        )
        # Process the hits
        all_results = []
        concept_ids = []
        for hit in es_response['hits']['hits']:
            source = hit['_source']
            concept_name = source.get('concept', '')
            concept_id = source.get('concept_id', '')
            # NOTE(review): the id is collected before the hierarchy filter
            # below, so prices are also fetched for hits that get filtered out.
            concept_ids.append(concept_id)
            # Parse stocks
            stocks_list, _ = parse_stocks_from_es(source)
            # Hierarchy information
            hierarchy_info = get_concept_hierarchy(concept_name)
            hierarchy = None
            if hierarchy_info:
                hierarchy = HierarchyInfo(**hierarchy_info)
            # Hierarchy filtering (concepts without hierarchy info never match)
            if request.filter_lv1 and (not hierarchy_info or hierarchy_info.get('lv1') != request.filter_lv1):
                continue
            if request.filter_lv2 and (not hierarchy_info or hierarchy_info.get('lv2') != request.filter_lv2):
                continue
            # Highlight fragments
            highlights = hit.get('highlight', {})
            result = ConceptResult(
                concept_id=concept_id,
                concept=concept_name,
                description=source.get('description'),
                tags=source.get('tags', []),
                outbreak_dates=source.get('outbreak_dates', []),
                stocks=stocks_list,
                stock_count=len(source.get('stocks', [])),
                hierarchy=hierarchy,
                score=hit['_score'] or 0,
                match_type=match_type,
                highlights=highlights,
                price_info=None
            )
            all_results.append(result)
        # Attach price-change data
        price_data = {}
        actual_price_date = None
        if concept_ids:
            price_data = await get_concept_price_data(concept_ids, request.trade_date)
            if price_data:
                # The reported price date is taken from the first returned entry.
                actual_price_date = next(iter(price_data.values())).trade_date
        for result in all_results:
            result.price_info = price_data.get(result.concept_id)
        # Sorting (the default "_score" keeps the ES order)
        if request.sort_by == "change_pct":
            # Results without a change value sort last (-999 sentinel).
            all_results.sort(
                key=lambda x: (
                    x.price_info.avg_change_pct
                    if x.price_info and x.price_info.avg_change_pct is not None
                    else -999
                ),
                reverse=True
            )
        elif request.sort_by == "stock_count":
            all_results.sort(key=lambda x: x.stock_count, reverse=True)
        elif request.sort_by == "outbreak_date":
            # Compares the first outbreak date as a string.
            # NOTE(review): presumably dates are ISO formatted and the first
            # entry is the relevant one - confirm against the index data.
            all_results.sort(
                key=lambda x: (
                    x.outbreak_dates[0] if x.outbreak_dates else '1900-01-01'
                ),
                reverse=True
            )
        # In-memory pagination; a page beyond the end is clamped to the last page
        total_results = len(all_results)
        total_pages = max(1, (total_results + request.size - 1) // request.size)
        current_page = min(request.page, total_pages)
        start_idx = (current_page - 1) * request.size
        end_idx = start_idx + request.size
        page_results = all_results[start_idx:end_idx]
        took_ms = int((datetime.now() - start_time).total_seconds() * 1000)
        # NOTE(review): `total` is the ES hit count, while total_pages is
        # derived from the fetched-and-filtered list; the two can disagree.
        return SearchResponse(
            total=es_response['hits']['total']['value'],
            took_ms=took_ms,
            results=page_results,
            search_info={
                "query": request.query,
                "semantic_weight": semantic_weight,
                "match_type": match_type,
                "filter_lv1": request.filter_lv1,
                "filter_lv2": request.filter_lv2,
                "sort_by": request.sort_by,
                "has_embedding": bool(embedding)
            },
            price_date=actual_price_date,
            page=current_page,
            total_pages=total_pages
        )
    except Exception as e:
        logger.error(f"搜索失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/concept/{concept_id}", response_model=ConceptDetailResponse, tags=["Concepts"])
async def get_concept_detail(
    concept_id: str,
    trade_date: Optional[date] = Query(None, description="交易日期")
):
    """Return the full detail of one concept.

    The document is read from Elasticsearch by id, then combined with the
    hierarchy entry of its concept name and the price data for ``trade_date``.
    A NotFoundError becomes HTTP 404; any other error becomes HTTP 500.
    """
    try:
        document = es_client.get(index=INDEX_NAME, id=concept_id)
        source = document['_source']
        name = source.get('concept', '')
        # Only the complete stock list (second element) is needed here.
        full_stocks = parse_stocks_from_es(source)[1]
        # Hierarchy information
        level_info = get_concept_hierarchy(name)
        if level_info:
            hierarchy = HierarchyInfo(**level_info)
        else:
            hierarchy = None
        # Price change
        prices = await get_concept_price_data([concept_id], trade_date)
        return ConceptDetailResponse(
            concept_id=concept_id,
            concept=name,
            description=source.get('description'),
            insight=source.get('insight'),
            tags=source.get('tags', []),
            outbreak_dates=source.get('outbreak_dates', []),
            stocks=full_stocks,
            stock_count=len(full_stocks),
            hierarchy=hierarchy,
            folders=source.get('folders', []),
            created_at=source.get('created_at'),
            price_info=prices.get(concept_id)
        )
    except Exception as e:
        # The exception type is matched by name.
        is_missing = "NotFoundError" in str(type(e))
        if is_missing:
            raise HTTPException(status_code=404, detail="概念不存在")
        logger.error(f"获取概念详情失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/hierarchy", response_model=HierarchyResponse, tags=["Hierarchy"])
async def get_hierarchy():
    """Return the complete hierarchy tree with per-node concept counts.

    lv2 nodes with lv3 children list them under ``children``; lv2 nodes
    without lv3 carry their concepts directly. Counts are summed upwards.
    """
    try:
        tree = []
        for top in hierarchy_data.get('hierarchy', []):
            top_node = HierarchyLevel(
                id=top.get('lv1_id', ''),
                name=top.get('lv1', ''),
                concept_count=0,
                children=[]
            )
            for mid in top.get('children', []):
                mid_node = HierarchyLevel(
                    id=mid.get('lv2_id', ''),
                    name=mid.get('lv2', ''),
                    concept_count=0,
                    children=[]
                )
                if 'children' not in mid:
                    # lv2 holds its concepts directly
                    direct_concepts = mid.get('concepts', [])
                    mid_node.concept_count = len(direct_concepts)
                    mid_node.concepts = direct_concepts
                else:
                    # lv3 level present
                    for leaf in mid.get('children', []):
                        leaf_concepts = leaf.get('concepts', [])
                        leaf_total = len(leaf_concepts)
                        mid_node.children.append(HierarchyLevel(
                            id=leaf.get('lv3_id', ''),
                            name=leaf.get('lv3', ''),
                            concept_count=leaf_total,
                            concepts=leaf_concepts
                        ))
                        mid_node.concept_count += leaf_total
                top_node.children.append(mid_node)
                top_node.concept_count += mid_node.concept_count
            tree.append(top_node)
        grand_total = sum(node.concept_count for node in tree)
        return HierarchyResponse(
            hierarchy=tree,
            total_concepts=grand_total
        )
    except Exception as e:
        logger.error(f"获取层级结构失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
def _stats_row_to_dict(row: Dict[str, Any], concept_type: str) -> Dict[str, Any]:
    """Convert one concept_daily_stats row into the response dict for a level."""
    change = row['avg_change_pct']
    return {
        "concept_id": row['concept_id'],
        "concept_name": row['concept_name'],
        "concept_type": concept_type,
        "trade_date": str(row['trade_date']),
        # Compare against None explicitly so a 0.0 change is not dropped.
        "avg_change_pct": float(change) if change is not None else None,
        "stock_count": row['stock_count']
    }


# Note: /hierarchy/price must be defined before /hierarchy/{lv1_id},
# otherwise it would be matched by that route instead.
@app.get("/hierarchy/price", tags=["Hierarchy"])
async def get_hierarchy_price_early(
    trade_date: Optional[date] = Query(None, description="交易日期,默认最新"),
    lv1_filter: Optional[str] = Query(None, description="筛选特定一级分类")
):
    """Return price-change data for all concepts of one trade date.

    Covers parent concepts (lv1/lv2/lv3) and leaf concepts, each list ordered
    by avg_change_pct descending. ``lv1_filter`` restricts the parent levels
    with a ``concept_name LIKE %filter%`` match; leaf concepts are queried
    without that filter and carry their hierarchy entry.

    Raises:
        HTTPException: 503 when the MySQL pool is unavailable, 404 when no
            price data exists, 500 on any other failure.
    """
    logger.info(f"[hierarchy/price] 请求参数: trade_date={trade_date}, lv1_filter={lv1_filter}")
    if not mysql_pool:
        logger.error("[hierarchy/price] MySQL连接池不可用")
        raise HTTPException(status_code=503, detail="MySQL连接不可用")
    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                # Resolve the trade date
                query_date = trade_date
                if query_date is None:
                    await cursor.execute("SELECT MAX(trade_date) as max_date FROM concept_daily_stats")
                    result = await cursor.fetchone()
                    if not result or not result['max_date']:
                        raise HTTPException(status_code=404, detail="无涨跌幅数据")
                    query_date = result['max_date']
                logger.info(f"[hierarchy/price] 使用查询日期: {query_date}")
                # Build the queries. The leaf query is kept separate from the
                # parent query: it is always executed with two parameters, so
                # it must not contain the optional LIKE placeholder.
                select_sql = """
                    SELECT concept_id, concept_name, concept_type, trade_date, avg_change_pct, stock_count
                    FROM concept_daily_stats
                    WHERE trade_date = %s AND concept_type = %s
                """
                order_sql = " ORDER BY avg_change_pct DESC"
                leaf_query = select_sql + order_sql
                if lv1_filter:
                    parent_query = select_sql + " AND concept_name LIKE %s" + order_sql
                else:
                    parent_query = leaf_query
                # Parent levels: lv1, lv2, lv3
                parent_concepts = {}
                for level in ('lv1', 'lv2', 'lv3'):
                    if lv1_filter:
                        level_params = (query_date, level, f"%{lv1_filter}%")
                    else:
                        level_params = (query_date, level)
                    await cursor.execute(parent_query, level_params)
                    level_rows = await cursor.fetchall()
                    parent_concepts[level] = [_stats_row_to_dict(row, level) for row in level_rows]
                lv1_concepts = parent_concepts['lv1']
                lv2_concepts = parent_concepts['lv2']
                lv3_concepts = parent_concepts['lv3']
                # Leaf concepts, each enriched with its hierarchy entry
                await cursor.execute(leaf_query, (query_date, 'leaf'))
                leaf_concepts = []
                for row in await cursor.fetchall():
                    leaf_entry = _stats_row_to_dict(row, 'leaf')
                    leaf_entry["hierarchy"] = get_concept_hierarchy(row['concept_name'])
                    leaf_concepts.append(leaf_entry)
                total = len(lv1_concepts) + len(lv2_concepts) + len(lv3_concepts) + len(leaf_concepts)
                logger.info(f"[hierarchy/price] 返回结果: lv1={len(lv1_concepts)}, lv2={len(lv2_concepts)}, lv3={len(lv3_concepts)}, leaf={len(leaf_concepts)}")
                return {
                    "trade_date": str(query_date),
                    "lv1_concepts": lv1_concepts,
                    "lv2_concepts": lv2_concepts,
                    "lv3_concepts": lv3_concepts,
                    "leaf_concepts": leaf_concepts,
                    "total_count": total,
                    "summary": {
                        "lv1_count": len(lv1_concepts),
                        "lv2_count": len(lv2_concepts),
                        "lv3_count": len(lv3_concepts),
                        "leaf_count": len(leaf_concepts)
                    }
                }
    except HTTPException as he:
        logger.error(f"[hierarchy/price] HTTPException: {he.status_code} - {he.detail}")
        raise
    except Exception as e:
        # exc_info attaches the traceback to the same log record.
        logger.error(f"[hierarchy/price] 获取层级涨跌幅失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/hierarchy/{lv1_id}", tags=["Hierarchy"])
async def get_hierarchy_level(
    lv1_id: str,
    lv2_id: Optional[str] = Query(None, description="二级分类ID")
):
    """Return the concepts under one lv1 node, or under one of its lv2 nodes.

    Without ``lv2_id`` all concepts below the lv1 node are returned together
    with its raw children; with ``lv2_id`` only that lv2 node is returned.
    Unknown ids produce HTTP 404; unexpected errors produce HTTP 500.
    """

    def concepts_under(lv2_entry):
        # Concepts of an lv2 node: merged from its lv3 children when present,
        # otherwise the list stored on the node itself.
        if 'children' in lv2_entry:
            merged = []
            for lv3_entry in lv2_entry.get('children', []):
                merged.extend(lv3_entry.get('concepts', []))
            return merged
        return lv2_entry.get('concepts', [])

    try:
        for top in hierarchy_data.get('hierarchy', []):
            if top.get('lv1_id') != lv1_id:
                continue
            if not lv2_id:
                # All concepts below this lv1 node
                everything = []
                for mid in top.get('children', []):
                    everything.extend(concepts_under(mid))
                return {
                    "lv1": top.get('lv1'),
                    "lv1_id": lv1_id,
                    "concepts": everything,
                    "concept_count": len(everything),
                    "children": top.get('children', [])
                }
            # Concepts below the requested lv2 node
            for mid in top.get('children', []):
                if mid.get('lv2_id') != lv2_id:
                    continue
                selected = concepts_under(mid)
                return {
                    "lv1": top.get('lv1'),
                    "lv1_id": lv1_id,
                    "lv2": mid.get('lv2'),
                    "lv2_id": lv2_id,
                    "concepts": selected,
                    "concept_count": len(selected)
                }
            raise HTTPException(status_code=404, detail=f"未找到二级分类: {lv2_id}")
        raise HTTPException(status_code=404, detail=f"未找到一级分类: {lv1_id}")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取层级详情失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/stock/{stock_code}/concepts", tags=["Stocks"])
async def get_stock_concepts(
    stock_code: str,
    size: int = Query(50, ge=1, le=200),
    trade_date: Optional[date] = Query(None)
):
    """List the concepts that contain the given stock.

    ``stock_code`` is matched against both the nested stock code and the
    nested stock name. Each concept carries its hierarchy entry and price
    info; errors are reported as HTTP 500.
    """
    try:
        stock_match = {
            "bool": {
                "should": [
                    {"term": {"stocks.code": stock_code}},
                    {"term": {"stocks.name": stock_code}}
                ]
            }
        }
        search_body = {
            "query": {"nested": {"path": "stocks", "query": stock_match}},
            "size": size,
            "_source": {
                "excludes": ["description_embedding", "insight", "stocks_json"]
            }
        }
        es_response = es_client.search(index=INDEX_NAME, body=search_body)
        found = []
        found_ids = []
        for hit in es_response['hits']['hits']:
            doc = hit['_source']
            name = doc.get('concept', '')
            cid = doc.get('concept_id', '')
            found_ids.append(cid)
            entry = {
                "concept_id": cid,
                "concept": name,
                "description": doc.get('description'),
                "tags": doc.get('tags', []),
                "stock_count": len(doc.get('stocks', [])),
                # Hierarchy entry of the concept (None when unknown)
                "hierarchy": get_concept_hierarchy(name),
                "outbreak_dates": doc.get('outbreak_dates', [])
            }
            found.append(entry)
        # Attach price changes
        prices = await get_concept_price_data(found_ids, trade_date)
        for entry in found:
            entry['price_info'] = prices.get(entry['concept_id'])
        return {
            "stock_code": stock_code,
            "total": es_response['hits']['total']['value'],
            "concepts": found
        }
    except Exception as e:
        logger.error(f"查询股票概念失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/statistics/hierarchy", tags=["Statistics"])
async def get_hierarchy_statistics():
    """Return concept counts per lv1 node and per lv2 child.

    lv1 entries are sorted by concept count, descending.
    """
    try:
        per_lv1 = []
        for top in hierarchy_data.get('hierarchy', []):
            lv2_entries = []
            lv1_total = 0
            for mid in top.get('children', []):
                if 'children' in mid:
                    mid_total = sum(
                        len(leaf.get('concepts', []))
                        for leaf in mid.get('children', [])
                    )
                else:
                    mid_total = len(mid.get('concepts', []))
                lv2_entries.append({
                    "lv2": mid.get('lv2'),
                    "lv2_id": mid.get('lv2_id'),
                    "concept_count": mid_total
                })
                lv1_total += mid_total
            per_lv1.append({
                "lv1": top.get('lv1'),
                "lv1_id": top.get('lv1_id'),
                "lv2_count": len(top.get('children', [])),
                "concept_count": lv1_total,
                "children": lv2_entries
            })
        # Largest categories first
        per_lv1.sort(key=lambda item: item['concept_count'], reverse=True)
        return {
            "total_lv1": len(per_lv1),
            "total_concepts": sum(item['concept_count'] for item in per_lv1),
            "statistics": per_lv1
        }
    except Exception as e:
        logger.error(f"获取层级统计失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# ==================== Stock search ====================
class StockSearchResult(BaseModel):
    """One stock in a stock search response."""
    stock_code: str
    stock_name: str
    # Filled from the aggregation bucket's doc_count in search_stocks.
    concept_count: int
class StockSearchResponse(BaseModel):
    """Response body of the stock search endpoint."""
    total: int
    stocks: List[StockSearchResult]
@app.get("/stock/search", response_model=StockSearchResponse, tags=["Stocks"])
async def search_stocks(
    keyword: str = Query(..., description="股票关键词(代码或名称)"),
    size: int = Query(20, ge=1, le=100, description="返回数量")
):
    """Search stocks by code or name.

    Runs a nested query on stocks.code / stocks.name, aggregates stock codes
    (with one name each), keeps the buckets whose code or name contains
    ``keyword`` and returns them ordered by bucket document count.
    """
    try:
        # In the v3 index the stock fields are stocks.name and stocks.code.
        stock_match = {
            "bool": {
                "should": [
                    {"wildcard": {"stocks.code": f"*{keyword}*"}},
                    {"match": {"stocks.name": keyword}}
                ]
            }
        }
        name_agg = {
            "stock_names": {
                "terms": {
                    "field": "stocks.name",
                    "size": 1
                }
            }
        }
        code_agg = {
            "stock_codes": {
                "terms": {
                    "field": "stocks.code",
                    "size": size * 2
                },
                "aggs": name_agg
            }
        }
        search_body = {
            "query": {
                "nested": {
                    "path": "stocks",
                    "query": stock_match
                }
            },
            "size": 0,
            "aggs": {
                "unique_stocks": {
                    "nested": {
                        "path": "stocks"
                    },
                    "aggs": code_agg
                }
            }
        }
        es_response = es_client.search(index=INDEX_NAME, body=search_body)
        code_buckets = es_response['aggregations']['unique_stocks']['stock_codes']['buckets']
        matched = []
        for bucket in code_buckets:
            code = bucket['key']
            hits = bucket['doc_count']
            names = bucket['stock_names']['buckets']
            # Fall back to the code when the bucket has no name.
            name = names[0]['key'] if names else code
            if keyword not in code and keyword not in name:
                continue
            matched.append(StockSearchResult(
                stock_code=code,
                stock_name=name,
                concept_count=hits
            ))
        ranked = sorted(matched, key=lambda item: item.concept_count, reverse=True)
        return StockSearchResponse(
            total=len(ranked),
            stocks=ranked[:size]
        )
    except Exception as e:
        logger.error(f"搜索股票失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# ==================== 价格相关接口 ====================
class PriceTimeSeriesItem(BaseModel):
    """A single data point of a concept's price time series."""
    # Trading day the data point belongs to.
    trade_date: date
    # Average percentage change for that day; None when no value is stored.
    avg_change_pct: Optional[float] = Field(None, description="平均涨跌幅(%)")
    # Number of stocks counted on that day; None when no value is stored.
    stock_count: Optional[int] = Field(None, description="当日股票数量")
class PriceTimeSeriesResponse(BaseModel):
    """Response for a concept's price time series over a date range."""
    # Identifier of the requested concept.
    concept_id: str
    # Concept name; the endpoint falls back to the id when no name is found.
    concept_name: str
    # Requested range, both ends inclusive.
    start_date: date
    end_date: date
    # Number of entries in `timeseries`.
    data_points: int
    # Data points ordered by trading day (ascending).
    timeseries: List[PriceTimeSeriesItem]
async def get_latest_trade_date() -> Optional[date]:
    """Return the most recent trading day stored in `concept_daily_stats`.

    Returns:
        The maximum `trade_date`, or None when the MySQL pool is not
        available, the table holds no date, or the query fails (the failure
        is logged, not raised).
    """
    if not mysql_pool:
        return None

    try:
        async with mysql_pool.acquire() as connection, connection.cursor() as cur:
            await cur.execute(
                "SELECT MAX(trade_date) as latest_date FROM concept_daily_stats"
            )
            row = await cur.fetchone()
            # Default (tuple) cursor: the single selected column is at index 0.
            if row and row[0]:
                return row[0]
            return None
    except Exception as e:
        logger.error(f"获取最新交易日期失败: {e}")
        return None
@app.get("/price/latest", tags=["Price"])
async def get_latest_price_date():
    """Report the latest date for which price-change data exists.

    Returns:
        A dict with `latest_trade_date` (a date or None) and `has_data`
        (True when a date was found).

    Raises:
        HTTPException: 500 on an unexpected failure.
    """
    try:
        newest = await get_latest_trade_date()
        payload = {
            "latest_trade_date": newest,
            "has_data": newest is not None
        }
        return payload
    except Exception as e:
        logger.error(f"获取最新价格日期失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/concept/{concept_id}/price-timeseries", response_model=PriceTimeSeriesResponse, tags=["Price"])
async def get_concept_price_timeseries(
    concept_id: str,
    start_date: date = Query(..., description="开始日期格式YYYY-MM-DD"),
    end_date: date = Query(..., description="结束日期格式YYYY-MM-DD")
):
    """Return a concept's daily price-change series for a date range.

    Args:
        concept_id: Identifier of the concept in `concept_daily_stats`.
        start_date: First day of the range (inclusive).
        end_date: Last day of the range (inclusive).

    Returns:
        PriceTimeSeriesResponse ordered by trading day. When the MySQL pool is
        not available, an empty series is returned instead of an error.

    Raises:
        HTTPException: 400 when `start_date` is after `end_date`, 404 when no
            rows exist for the concept in the range, 500 on query failure.
    """
    # Validate the request before looking at backend availability, so an
    # invalid range is rejected the same way whether or not MySQL is up.
    if start_date > end_date:
        raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")

    if not mysql_pool:
        logger.warning("[PriceTimeseries] MySQL 连接不可用")
        return PriceTimeSeriesResponse(
            concept_id=concept_id,
            concept_name=concept_id,
            start_date=start_date,
            end_date=end_date,
            data_points=0,
            timeseries=[]
        )

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                query = """
                    SELECT trade_date, concept_name, avg_change_pct, stock_count
                    FROM concept_daily_stats
                    WHERE concept_id = %s
                      AND trade_date >= %s
                      AND trade_date <= %s
                    ORDER BY trade_date ASC
                """
                await cursor.execute(query, (concept_id, start_date, end_date))
                rows = await cursor.fetchall()

        if not rows:
            # Separators added: the id and both dates used to run together.
            raise HTTPException(
                status_code=404,
                detail=f"未找到概念ID {concept_id} 在 {start_date} 至 {end_date} 期间的数据"
            )

        # Use the first non-empty name found among the rows.
        concept_name = next(
            (row['concept_name'] for row in rows if row.get('concept_name')),
            ""
        )
        timeseries = [
            PriceTimeSeriesItem(
                trade_date=row['trade_date'],
                # Explicit None check so a 0.0 change is kept as 0.0.
                avg_change_pct=float(row['avg_change_pct']) if row['avg_change_pct'] is not None else None,
                stock_count=row.get('stock_count')
            )
            for row in rows
        ]

        return PriceTimeSeriesResponse(
            concept_id=concept_id,
            concept_name=concept_name or concept_id,
            start_date=start_date,
            end_date=end_date,
            data_points=len(timeseries),
            timeseries=timeseries
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取价格时间序列失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取时间序列数据失败: {str(e)}")
# ==================== 统计接口 ====================
class ConceptStatItem(BaseModel):
    """One concept entry in any of the statistics rankings.

    Only the fields relevant to the ranking an item belongs to are filled;
    the others stay None.
    """
    # Concept name.
    name: str
    # Concept identifier; absent in the hard-coded sample data.
    concept_id: Optional[str] = None
    # Gainers/losers lists: average daily change over the range, rounded to 2 decimals.
    change_pct: Optional[float] = None
    # Average stock count over the range, truncated to an int.
    stock_count: Optional[int] = None
    # NOTE: the statistics endpoint fills this with the number of trading-day
    # rows, not with an actual news count.
    news_count: Optional[int] = None
    # NOTE: the statistics endpoint derives this as max(1, int(trading_days * 0.3)).
    report_count: Optional[int] = None
    # news_count + report_count.
    total_mentions: Optional[int] = None
    # Volatility list: standard deviation of the daily change.
    volatility: Optional[float] = None
    # Volatility list: average daily change.
    avg_change: Optional[float] = None
    # Volatility list: largest daily change.
    max_change: Optional[float] = None
    # Momentum list: number of days with a positive change in the range
    # (the query counts positive days; it does not check that they are adjacent).
    consecutive_days: Optional[int] = None
    # Momentum list: sum of the positive daily changes.
    total_change: Optional[float] = None
    # Momentum list: average of the positive daily changes.
    avg_daily: Optional[float] = None
    # Position of the concept in the hierarchy, when known.
    hierarchy: Optional[HierarchyInfo] = None
class ConceptStatistics(BaseModel):
    """Collection of concept rankings plus an overall summary."""
    # Top gainers (highest average change first).
    hot_concepts: List[ConceptStatItem]
    # Top losers (lowest average change first).
    cold_concepts: List[ConceptStatItem]
    # Concepts with the most trading-day rows in the range.
    active_concepts: List[ConceptStatItem]
    # Concepts with the highest standard deviation of daily change.
    volatile_concepts: List[ConceptStatItem]
    # Concepts with the most positive days in the range.
    momentum_concepts: List[ConceptStatItem]
    # Aggregate figures and the date range the statistics cover.
    summary: Dict[str, Any]
class ConceptStatisticsResponse(BaseModel):
    """Response envelope of the statistics endpoint."""
    # The statistics endpoint sets this to True even when sample data is served.
    success: bool
    # The rankings and summary.
    data: ConceptStatistics
    # Echo of the effective request parameters.
    params: Dict[str, Any]
    # Set when sample data was returned, explaining why.
    note: Optional[str] = None
def get_fallback_statistics(start_date, end_date) -> ConceptStatistics:
    """Build hard-coded sample statistics used when real data is unavailable.

    Args:
        start_date: Start of the range to report in the summary (may be None).
        end_date: End of the range to report in the summary (may be None).

    Returns:
        ConceptStatistics filled with fixed sample rankings. The summary
        reflects the given range, or a 7-day placeholder with empty date
        strings when either bound is missing.
    """
    has_range = bool(start_date and end_date)

    return ConceptStatistics(
        hot_concepts=[
            ConceptStatItem(name="人工智能", change_pct=8.76, stock_count=45, news_count=12),
            ConceptStatItem(name="机器人", change_pct=7.54, stock_count=38, news_count=10),
            ConceptStatItem(name="半导体", change_pct=6.43, stock_count=52, news_count=15),
            ConceptStatItem(name="低空经济", change_pct=5.21, stock_count=28, news_count=6),
            ConceptStatItem(name="量子科技", change_pct=4.98, stock_count=15, news_count=9),
        ],
        cold_concepts=[
            ConceptStatItem(name="房地产", change_pct=-5.76, stock_count=33, news_count=5),
            ConceptStatItem(name="煤炭", change_pct=-4.32, stock_count=25, news_count=3),
            ConceptStatItem(name="钢铁", change_pct=-3.21, stock_count=28, news_count=4),
            ConceptStatItem(name="传统零售", change_pct=-2.98, stock_count=19, news_count=2),
            ConceptStatItem(name="纺织服装", change_pct=-2.45, stock_count=15, news_count=2),
        ],
        active_concepts=[
            ConceptStatItem(name="人工智能", news_count=45, report_count=15, total_mentions=60),
            ConceptStatItem(name="半导体", news_count=42, report_count=12, total_mentions=54),
            ConceptStatItem(name="机器人", news_count=38, report_count=8, total_mentions=46),
            ConceptStatItem(name="新能源", news_count=28, report_count=6, total_mentions=34),
            ConceptStatItem(name="量子科技", news_count=25, report_count=5, total_mentions=30),
        ],
        volatile_concepts=[
            ConceptStatItem(name="数字货币", volatility=25.6, avg_change=2.1, max_change=15.2),
            ConceptStatItem(name="元宇宙", volatility=23.8, avg_change=1.8, max_change=13.9),
            ConceptStatItem(name="虚拟现实", volatility=21.2, avg_change=-0.5, max_change=10.1),
            ConceptStatItem(name="游戏", volatility=19.7, avg_change=3.2, max_change=12.8),
            ConceptStatItem(name="短剧", volatility=18.3, avg_change=-1.1, max_change=8.1),
        ],
        momentum_concepts=[
            ConceptStatItem(name="AI应用", consecutive_days=6, total_change=19.2, avg_daily=3.2),
            ConceptStatItem(name="算力", consecutive_days=5, total_change=16.8, avg_daily=3.36),
            ConceptStatItem(name="光通信", consecutive_days=4, total_change=13.1, avg_daily=3.28),
            ConceptStatItem(name="存储", consecutive_days=4, total_change=12.4, avg_daily=3.1),
            ConceptStatItem(name="PCB", consecutive_days=3, total_change=9.6, avg_daily=3.2),
        ],
        summary={
            'total_concepts': 500,
            'positive_count': 320,
            'negative_count': 180,
            'avg_change': 1.8,
            'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            # The two dates used to be concatenated with no separator
            # ("2024-01-012024-01-08", or "NoneNone" without a range).
            'date_range': f"{start_date} 至 {end_date}" if has_range else '',
            'days': (end_date - start_date).days + 1 if has_range else 7,
            'start_date': str(start_date) if start_date else '',
            'end_date': str(end_date) if end_date else ''
        }
    )
@app.get("/statistics", response_model=ConceptStatisticsResponse, tags=["Statistics"])
async def get_concept_statistics(
    days: Optional[int] = Query(None, ge=1, le=90, description="统计天数范围"),
    start_date: Optional[date] = Query(None, description="开始日期"),
    end_date: Optional[date] = Query(None, description="结束日期"),
    min_stock_count: int = Query(3, ge=1, description="最少股票数量过滤"),
    concept_type: Optional[str] = Query(None, description="概念类型过滤: leaf/lv1/lv2/lv3默认为leaf")
):
    """Return concept rankings: gainers, losers, most active, most volatile, momentum.

    Date range resolution:
      * `days` alone: today minus `days` days up to today.
      * `start_date` alone: from `start_date` up to today.
      * `end_date` alone: the 7 days before `end_date` up to `end_date`.
      * nothing: the 7 days before today up to today.
    Both ends are inclusive, so the reported `days` value is the difference
    in days plus one.

    Args:
        days: Look-back window; mutually exclusive with the explicit dates.
        start_date: Start of the range.
        end_date: End of the range.
        min_stock_count: Minimum per-day stock count a row needs to be counted.
        concept_type: One of leaf/lv1/lv2/lv3; anything else falls back to leaf.

    Returns:
        ConceptStatisticsResponse. Sample data is returned (with `note` set)
        when MySQL is unavailable or a query fails; empty gainers/losers lists
        are also replaced by sample entries.

    Raises:
        HTTPException: 400 when `days` is combined with explicit dates, when
            the resolved start is after the end, or when the resolved range
            exceeds 90 days.
    """
    from datetime import timedelta

    if days is not None and (start_date is not None or end_date is not None):
        raise HTTPException(status_code=400, detail="days参数与start_date/end_date参数不能同时使用")

    # Resolve the effective range first ...
    if days is not None:
        end_date = datetime.now().date()
        start_date = end_date - timedelta(days=days)
    else:
        if end_date is None:
            end_date = datetime.now().date()
        if start_date is None:
            start_date = end_date - timedelta(days=7)

    # ... then validate it. Previously these checks only ran when both dates
    # were supplied, so a lone start_date could bypass the 90-day cap or lie
    # in the future.
    if start_date > end_date:
        raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")
    if (end_date - start_date).days > 90:
        raise HTTPException(status_code=400, detail="日期范围不能超过90天")

    if not mysql_pool:
        logger.warning("[Statistics] MySQL 连接不可用,使用示例数据")
        return ConceptStatisticsResponse(
            success=True,
            data=get_fallback_statistics(start_date, end_date),
            params={
                'days': (end_date - start_date).days + 1,
                'min_stock_count': min_stock_count,
                'start_date': str(start_date),
                'end_date': str(end_date)
            },
            note="MySQL 连接不可用,使用示例数据"
        )

    # Only leaf concepts are queried unless a valid type is requested.
    type_filter = concept_type if concept_type in ['leaf', 'lv1', 'lv2', 'lv3'] else 'leaf'

    def _rounded(value):
        """Round a numeric DB value to 2 decimals; NULL/zero becomes 0.0."""
        return round(float(value), 2) if value else 0.0

    def _whole(value):
        """Truncate a numeric DB value to int; NULL/zero becomes 0."""
        return int(value) if value else 0

    def build_items(rows, item_type):
        """Convert ranking rows into ConceptStatItem objects for one ranking."""
        items = []
        for row in rows:
            concept_name = row['concept_name']
            hierarchy_info = get_concept_hierarchy(concept_name)
            item = ConceptStatItem(
                name=concept_name,
                concept_id=row.get('concept_id'),
                hierarchy=HierarchyInfo(**hierarchy_info) if hierarchy_info else None
            )
            if item_type in ['hot', 'cold']:
                item.change_pct = _rounded(row['avg_change_pct'])
                item.stock_count = _whole(row['avg_stock_count'])
                # news_count carries the number of trading-day rows.
                item.news_count = _whole(row['trading_days'])
            elif item_type == 'active':
                item.news_count = _whole(row['trading_days'])
                item.stock_count = _whole(row['avg_stock_count'])
                # report_count is derived from the trading-day count, not measured.
                item.report_count = max(1, int(row['trading_days'] * 0.3)) if row['trading_days'] else 1
                item.total_mentions = item.news_count + item.report_count
            elif item_type == 'volatile':
                item.volatility = _rounded(row['volatility'])
                item.avg_change = _rounded(row['avg_change_pct'])
                item.max_change = _rounded(row['max_change_pct'])
            elif item_type == 'momentum':
                item.consecutive_days = _whole(row['positive_days'])
                item.total_change = _rounded(row['total_change'])
                item.avg_daily = _rounded(row['avg_change_pct'])
            items.append(item)
        return items

    try:
        ranking_params = (start_date, end_date, min_stock_count, type_filter)

        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                # 1. Gainers
                hot_query = """
                    SELECT concept_id, concept_name,
                           AVG(avg_change_pct) as avg_change_pct,
                           AVG(stock_count) as avg_stock_count,
                           COUNT(*) as trading_days
                    FROM concept_daily_stats
                    WHERE trade_date >= %s AND trade_date <= %s
                      AND avg_change_pct IS NOT NULL AND stock_count >= %s
                      AND concept_type = %s
                    GROUP BY concept_id, concept_name
                    HAVING COUNT(*) >= 2
                    ORDER BY AVG(avg_change_pct) DESC
                    LIMIT 10
                """
                await cursor.execute(hot_query, ranking_params)
                hot_rows = await cursor.fetchall()

                # 2. Losers
                cold_query = """
                    SELECT concept_id, concept_name,
                           AVG(avg_change_pct) as avg_change_pct,
                           AVG(stock_count) as avg_stock_count,
                           COUNT(*) as trading_days
                    FROM concept_daily_stats
                    WHERE trade_date >= %s AND trade_date <= %s
                      AND avg_change_pct IS NOT NULL AND stock_count >= %s
                      AND concept_type = %s
                    GROUP BY concept_id, concept_name
                    HAVING COUNT(*) >= 2
                    ORDER BY AVG(avg_change_pct) ASC
                    LIMIT 10
                """
                await cursor.execute(cold_query, ranking_params)
                cold_rows = await cursor.fetchall()

                # 3. Most active
                active_query = """
                    SELECT concept_id, concept_name,
                           COUNT(*) as trading_days,
                           AVG(stock_count) as avg_stock_count
                    FROM concept_daily_stats
                    WHERE trade_date >= %s AND trade_date <= %s AND stock_count >= %s
                      AND concept_type = %s
                    GROUP BY concept_id, concept_name
                    ORDER BY COUNT(*) DESC, AVG(stock_count) DESC
                    LIMIT 10
                """
                await cursor.execute(active_query, ranking_params)
                active_rows = await cursor.fetchall()

                # 4. Most volatile
                volatile_query = """
                    SELECT concept_id, concept_name,
                           STDDEV(avg_change_pct) as volatility,
                           AVG(avg_change_pct) as avg_change_pct,
                           MAX(avg_change_pct) as max_change_pct
                    FROM concept_daily_stats
                    WHERE trade_date >= %s AND trade_date <= %s
                      AND avg_change_pct IS NOT NULL AND stock_count >= %s
                      AND concept_type = %s
                    GROUP BY concept_id, concept_name
                    HAVING COUNT(*) >= 3 AND STDDEV(avg_change_pct) IS NOT NULL
                    ORDER BY STDDEV(avg_change_pct) DESC
                    LIMIT 10
                """
                await cursor.execute(volatile_query, ranking_params)
                volatile_rows = await cursor.fetchall()

                # 5. Momentum (days with a positive change)
                momentum_query = """
                    SELECT concept_id, concept_name,
                           COUNT(*) as positive_days,
                           SUM(avg_change_pct) as total_change,
                           AVG(avg_change_pct) as avg_change_pct
                    FROM concept_daily_stats
                    WHERE trade_date >= %s AND trade_date <= %s
                      AND avg_change_pct > 0 AND stock_count >= %s
                      AND concept_type = %s
                    GROUP BY concept_id, concept_name
                    HAVING COUNT(*) >= 2
                    ORDER BY COUNT(*) DESC, AVG(avg_change_pct) DESC
                    LIMIT 10
                """
                await cursor.execute(momentum_query, ranking_params)
                momentum_rows = await cursor.fetchall()

                # 6. Overall summary
                total_query = """
                    SELECT COUNT(DISTINCT concept_id) as total_concepts,
                           COUNT(DISTINCT CASE WHEN avg_change_pct > 0 THEN concept_id END) as positive_concepts,
                           COUNT(DISTINCT CASE WHEN avg_change_pct < 0 THEN concept_id END) as negative_concepts,
                           AVG(avg_change_pct) as overall_avg_change
                    FROM concept_daily_stats
                    WHERE trade_date >= %s AND trade_date <= %s AND avg_change_pct IS NOT NULL
                      AND concept_type = %s
                """
                await cursor.execute(total_query, (start_date, end_date, type_filter))
                total_row = await cursor.fetchone()

        statistics = ConceptStatistics(
            hot_concepts=build_items(hot_rows, 'hot'),
            cold_concepts=build_items(cold_rows, 'cold'),
            active_concepts=build_items(active_rows, 'active'),
            volatile_concepts=build_items(volatile_rows, 'volatile'),
            momentum_concepts=build_items(momentum_rows, 'momentum'),
            summary={
                'total_concepts': _whole(total_row['total_concepts']),
                'positive_count': _whole(total_row['positive_concepts']),
                'negative_count': _whole(total_row['negative_concepts']),
                'avg_change': _rounded(total_row['overall_avg_change']),
                'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                # The two dates used to be concatenated with no separator.
                'date_range': f"{start_date} 至 {end_date}",
                'days': (end_date - start_date).days + 1,
                'start_date': str(start_date),
                'end_date': str(end_date)
            }
        )

        # Fill empty gainers/losers lists with sample entries.
        if not statistics.hot_concepts or not statistics.cold_concepts:
            fallback = get_fallback_statistics(start_date, end_date)
            if not statistics.hot_concepts:
                statistics.hot_concepts = fallback.hot_concepts
            if not statistics.cold_concepts:
                statistics.cold_concepts = fallback.cold_concepts

        return ConceptStatisticsResponse(
            success=True,
            data=statistics,
            params={
                'days': (end_date - start_date).days + 1,
                'min_stock_count': min_stock_count,
                'start_date': str(start_date),
                'end_date': str(end_date),
                'concept_type': type_filter
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        # Deliberate best effort: serve sample data rather than failing the request.
        logger.error(f"获取概念统计失败: {e}")
        return ConceptStatisticsResponse(
            success=True,
            data=get_fallback_statistics(start_date, end_date),
            params={
                'days': (end_date - start_date).days + 1,
                'min_stock_count': min_stock_count,
                'start_date': str(start_date),
                'end_date': str(end_date),
                'concept_type': type_filter
            },
            note=f"使用示例数据,原因: {str(e)}"
        )
# ==================== 层级概念涨跌幅接口 ====================
class HierarchyPriceItem(BaseModel):
    """Price data of one hierarchy-level concept on one trading day."""
    # Identifier of the hierarchy concept.
    concept_id: str
    # Name of the hierarchy concept.
    concept_name: str
    # Hierarchy level of the concept.
    concept_type: str  # lv1/lv2/lv3
    # Trading day the figures belong to.
    trade_date: date
    # Average percentage change; None when not available.
    avg_change_pct: Optional[float] = None
    # Number of stocks; None when not available.
    stock_count: Optional[int] = None
class HierarchyPriceResponse(BaseModel):
    """Price data of hierarchy concepts for one trading day, grouped by level.

    NOTE(review): no endpoint in this part of the file uses this model; confirm
    whether it is still referenced elsewhere.
    """
    # Trading day the figures belong to.
    trade_date: date
    # Entries per hierarchy level.
    lv1_concepts: List[HierarchyPriceItem]
    lv2_concepts: List[HierarchyPriceItem]
    lv3_concepts: List[HierarchyPriceItem]
    # Presumably the total number of entries across the three lists; verify against the producer.
    total_count: int
class ConceptPriceItem(BaseModel):
    """Price data of one concept (leaf or hierarchy level) on one trading day."""
    # Identifier of the concept.
    concept_id: str
    # Name of the concept.
    concept_name: str
    # Kind of concept the row describes.
    concept_type: str  # leaf/lv1/lv2/lv3
    # Trading day the figures belong to.
    trade_date: date
    # Average percentage change; None when not available.
    avg_change_pct: Optional[float] = None
    # Number of stocks; None when not available.
    stock_count: Optional[int] = None
    # Hierarchy position; the list endpoint only fills it for leaf concepts.
    hierarchy: Optional[HierarchyInfo] = None
class ConceptPriceListResponse(BaseModel):
    """Paged list of concept price data for one trading day."""
    # Trading day the list refers to.
    trade_date: date
    # Number of rows matching the filters, before limit/offset are applied.
    total: int
    # The requested page of concepts.
    concepts: List[ConceptPriceItem]
@app.get("/price/list", response_model=ConceptPriceListResponse, tags=["Price"])
async def get_concept_price_list(
    trade_date: Optional[date] = Query(None, description="交易日期,默认最新"),
    concept_type: Optional[str] = Query(None, description="概念类型: leaf/lv1/lv2/lv3默认全部"),
    sort_by: str = Query("change_desc", description="排序: change_desc(涨幅降序), change_asc(涨幅升序), stock_count(股票数)"),
    limit: int = Query(100, ge=1, le=1000, description="返回数量"),
    offset: int = Query(0, ge=0, description="偏移量")
):
    """Return a page of concept price changes for one trading day.

    Args:
        trade_date: Day to list; the latest day in the table when omitted.
        concept_type: Restrict to leaf/lv1/lv2/lv3; other values are ignored
            and all types are returned.
        sort_by: `change_asc`, `stock_count`, or anything else for the default
            descending-change order.
        limit: Page size.
        offset: Number of rows to skip.

    Returns:
        ConceptPriceListResponse with the total match count and the page.

    Raises:
        HTTPException: 503 when MySQL is unavailable, 404 when the table holds
            no trading day, 500 on query failure.
    """
    if not mysql_pool:
        raise HTTPException(status_code=503, detail="MySQL连接不可用")

    # ORDER BY fragments come from this fixed table only, never from user text.
    order_clauses = {
        "change_asc": "avg_change_pct ASC",
        "stock_count": "stock_count DESC",
    }

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                # Default to the most recent trading day.
                if trade_date is None:
                    await cursor.execute("SELECT MAX(trade_date) as max_date FROM concept_daily_stats")
                    result = await cursor.fetchone()
                    if not result or not result['max_date']:
                        raise HTTPException(status_code=404, detail="无涨跌幅数据")
                    trade_date = result['max_date']

                where_conditions = ["trade_date = %s"]
                params = [trade_date]
                if concept_type in ('leaf', 'lv1', 'lv2', 'lv3'):
                    where_conditions.append("concept_type = %s")
                    params.append(concept_type)
                where_clause = ' AND '.join(where_conditions)

                order_clause = order_clauses.get(sort_by, "avg_change_pct DESC")

                # Total number of matching rows (ignores limit/offset).
                count_query = f"SELECT COUNT(*) as cnt FROM concept_daily_stats WHERE {where_clause}"
                await cursor.execute(count_query, params)
                total = (await cursor.fetchone())['cnt']

                query = f"""
                    SELECT concept_id, concept_name, concept_type, trade_date, avg_change_pct, stock_count
                    FROM concept_daily_stats
                    WHERE {where_clause}
                    ORDER BY {order_clause}
                    LIMIT %s OFFSET %s
                """
                await cursor.execute(query, params + [limit, offset])
                rows = await cursor.fetchall()

        concepts = []
        for row in rows:
            concept_name = row['concept_name']

            # Hierarchy information is only looked up for leaf concepts.
            hierarchy = None
            if row['concept_type'] == 'leaf':
                hierarchy_info = get_concept_hierarchy(concept_name)
                if hierarchy_info:
                    hierarchy = HierarchyInfo(**hierarchy_info)

            change = row['avg_change_pct']
            concepts.append(ConceptPriceItem(
                concept_id=row['concept_id'],
                concept_name=concept_name,
                concept_type=row['concept_type'],
                trade_date=row['trade_date'],
                # Explicit None check: a truthiness test reported a flat day
                # (0.0) as missing data.
                avg_change_pct=float(change) if change is not None else None,
                stock_count=row['stock_count'],
                hierarchy=hierarchy
            ))

        return ConceptPriceListResponse(
            trade_date=trade_date,
            total=total,
            concepts=concepts
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取概念涨跌幅列表失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/hierarchy/{hierarchy_id}/price-timeseries", tags=["Hierarchy"])
async def get_hierarchy_price_timeseries(
    hierarchy_id: str,
    start_date: date = Query(..., description="开始日期"),
    end_date: date = Query(..., description="结束日期")
):
    """Return the daily price-change series of a hierarchy concept (lv1/lv2/lv3).

    The lookup is by `concept_id` only; the stored `concept_type` is returned
    as found and is not checked against lv1/lv2/lv3.

    Args:
        hierarchy_id: Identifier of the hierarchy concept in `concept_daily_stats`.
        start_date: First day of the range (inclusive).
        end_date: Last day of the range (inclusive).

    Returns:
        A dict with the concept's id, name, type, the requested range, the
        number of data points and the series ordered by trading day.

    Raises:
        HTTPException: 503 when MySQL is unavailable, 400 when `start_date` is
            after `end_date`, 404 when no rows exist, 500 on query failure.
    """
    if not mysql_pool:
        raise HTTPException(status_code=503, detail="MySQL连接不可用")

    if start_date > end_date:
        raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                query = """
                    SELECT concept_id, concept_name, concept_type, trade_date, avg_change_pct, stock_count
                    FROM concept_daily_stats
                    WHERE concept_id = %s
                      AND trade_date >= %s AND trade_date <= %s
                    ORDER BY trade_date ASC
                """
                await cursor.execute(query, (hierarchy_id, start_date, end_date))
                rows = await cursor.fetchall()

        if not rows:
            raise HTTPException(status_code=404, detail=f"未找到层级概念 {hierarchy_id} 的数据")

        timeseries = []
        concept_name = ""
        concept_type = ""
        for row in rows:
            # Name and type are taken from the row that supplies the name.
            if not concept_name:
                concept_name = row['concept_name']
                concept_type = row['concept_type']

            change = row['avg_change_pct']
            timeseries.append({
                'trade_date': row['trade_date'],
                # Explicit None check: a truthiness test reported a flat day
                # (0.0) as missing data.
                'avg_change_pct': float(change) if change is not None else None,
                'stock_count': row['stock_count']
            })

        return {
            'concept_id': hierarchy_id,
            'concept_name': concept_name,
            'concept_type': concept_type,
            'start_date': start_date,
            'end_date': end_date,
            'data_points': len(timeseries),
            'timeseries': timeseries
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取层级价格时间序列失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# ==================== 主函数 ====================
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn when run directly.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 6801,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("concept_api_v2:app", **server_options)