1810 lines
67 KiB
Python
1810 lines
67 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
概念搜索API V2 - 适配 concept_library_v3 索引
|
||
新特性:
|
||
1. 支持层级结构(lv1/lv2/lv3)
|
||
2. 返回 tags, outbreak_dates, insight 等新字段
|
||
3. 新增层级浏览接口
|
||
"""
|
||
import json
|
||
import openai
|
||
from typing import List, Dict, Optional, Union, Any
|
||
from fastapi import FastAPI, HTTPException, Query
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel, Field
|
||
from elasticsearch import Elasticsearch
|
||
from datetime import datetime, date
|
||
import logging
|
||
import re
|
||
from contextlib import asynccontextmanager
|
||
import aiomysql
|
||
import os
|
||
|
||
# 配置日志
|
||
logging.basicConfig(level=logging.INFO)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 配置 ====================
|
||
|
||
# ES配置
|
||
ES_HOST = 'http://127.0.0.1:9200'
|
||
INDEX_NAME = 'concept_library_v3'
|
||
|
||
# Embedding配置
|
||
OPENAI_BASE_URL = "http://127.0.0.1:8000/v1"
|
||
OPENAI_API_KEY = "dummy"
|
||
EMBEDDING_MODEL = "qwen3-embedding-8b"
|
||
|
||
# 层级结构文件
|
||
HIERARCHY_FILE = 'concept_hierarchy_v3.json'
|
||
|
||
# MySQL配置
|
||
MYSQL_CONFIG = {
|
||
'host': '192.168.1.8',
|
||
'port': 3306,
|
||
'user': 'root',
|
||
'password': 'Zzl5588161!',
|
||
'db': 'stock',
|
||
'charset': 'utf8mb4',
|
||
'autocommit': True,
|
||
'minsize': 1,
|
||
'maxsize': 10
|
||
}
|
||
|
||
# ==================== 全局变量 ====================
|
||
|
||
es_client = None
|
||
openai_client = None
|
||
mysql_pool = None
|
||
|
||
# 层级结构相关
|
||
hierarchy_data = {} # 原始层级数据
|
||
concept_to_hierarchy = {} # 概念名称 -> 层级信息的映射
|
||
|
||
|
||
def load_hierarchy(path: Optional[str] = None):
    """Load the hierarchy JSON file and rebuild the concept -> hierarchy index.

    Args:
        path: Optional explicit path of the hierarchy file. Defaults to
            HIERARCHY_FILE located next to this module.

    On success, ``hierarchy_data`` is replaced with the parsed JSON, and
    ``concept_to_hierarchy`` is cleared and refilled in place. Each concept
    maps to a dict with keys lv1, lv1_id, lv2, lv2_id, lv3, lv3_id; lv3/lv3_id
    are None for concepts attached directly to an lv2 node.

    If the file is missing, or parsing or indexing fails, the problem is
    logged and the previously loaded data is left untouched. The index is
    built locally first, so a failure never leaves a half-built mapping.
    """
    global hierarchy_data

    hierarchy_path = path or os.path.join(os.path.dirname(__file__), HIERARCHY_FILE)
    if not os.path.exists(hierarchy_path):
        logger.warning(f"层级文件不存在: {hierarchy_path}")
        return

    try:
        with open(hierarchy_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)

        index: Dict[str, Dict[str, Optional[str]]] = {}
        for lv1 in loaded.get('hierarchy', []):
            lv1_part = {'lv1': lv1.get('lv1', ''), 'lv1_id': lv1.get('lv1_id', '')}

            for child in lv1.get('children', []):
                lv2_part = {**lv1_part, 'lv2': child.get('lv2', ''), 'lv2_id': child.get('lv2_id', '')}

                if 'children' in child:
                    # This lv2 node has lv3 sub-categories that hold the concepts.
                    for lv3_child in child.get('children', []):
                        entry = {**lv2_part, 'lv3': lv3_child.get('lv3', ''), 'lv3_id': lv3_child.get('lv3_id', '')}
                        for concept in lv3_child.get('concepts', []):
                            index[concept] = dict(entry)
                else:
                    # Concepts attached directly to the lv2 node (no lv3 level).
                    entry = {**lv2_part, 'lv3': None, 'lv3_id': None}
                    for concept in child.get('concepts', []):
                        index[concept] = dict(entry)
    except Exception as e:
        logger.error(f"加载层级结构失败: {e}")
        return

    # Commit only after the whole file was indexed successfully. The mapping
    # is updated in place, so any existing references to it see the new data.
    hierarchy_data = loaded
    concept_to_hierarchy.clear()
    concept_to_hierarchy.update(index)
    logger.info(f"加载层级结构完成,共 {len(concept_to_hierarchy)} 个概念有层级信息")
|
||
|
||
|
||
def get_concept_hierarchy(concept_name: str) -> Optional[Dict]:
    """Look up where a concept sits in the lv1/lv2/lv3 tree, by exact name.

    Returns the entry stored in ``concept_to_hierarchy`` (keys lv1, lv1_id,
    lv2, lv2_id, lv3, lv3_id), or None if the concept is not in the hierarchy.
    """
    placement = concept_to_hierarchy.get(concept_name)
    return placement
|
||
|
||
|
||
# ==================== 生命周期管理 ====================
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared clients on startup and release them on shutdown.

    Startup loads the hierarchy file, then builds the Elasticsearch client and
    the OpenAI-compatible embedding client, then tries to create the aiomysql
    pool. A MySQL failure is only logged: ``mysql_pool`` stays None and the
    price features degrade instead of blocking startup.

    Shutdown runs in ``finally``, so the clients are closed even when the
    application exits with an error.
    """
    global es_client, openai_client, mysql_pool

    load_hierarchy()

    es_client = Elasticsearch(
        [ES_HOST],
        timeout=30,
        max_retries=3,
        retry_on_timeout=True
    )
    logger.info(f"Connected to Elasticsearch at {ES_HOST}, index: {INDEX_NAME}")

    openai_client = openai.OpenAI(
        api_key=OPENAI_API_KEY,
        base_url=OPENAI_BASE_URL,
        timeout=60,
    )
    logger.info("Initialized OpenAI client")

    try:
        mysql_pool = await aiomysql.create_pool(**MYSQL_CONFIG)
        logger.info(f"Connected to MySQL at {MYSQL_CONFIG['host']}")
    except Exception as e:
        logger.error(f"Failed to connect to MySQL: {e}")

    try:
        yield
    finally:
        if es_client is not None:
            es_client.close()
        if openai_client is not None:
            # The OpenAI client owns an HTTP connection pool; release it too.
            openai_client.close()
        if mysql_pool is not None:
            mysql_pool.close()
            await mysql_pool.wait_closed()
        logger.info("Cleanup completed")
|
||
|
||
|
||
# ==================== FastAPI 应用 ====================
|
||
|
||
# The FastAPI application; lifespan() wires up ES / OpenAI / MySQL clients.
app = FastAPI(
    lifespan=lifespan,
    version="2.0.0",
    title="概念搜索API V2",
    description="支持层级结构的概念库搜索API,适配 concept_library_v3 索引",
)

# NOTE(review): wildcard origins combined with allow_credentials=True means
# any website can make credentialed cross-origin calls (Starlette echoes the
# request Origin in this case; verify against the installed version).
# Consider an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
|
||
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
class HierarchyInfo(BaseModel):
    """Placement of a concept in the lv1 / lv2 / lv3 category tree."""

    # Level 1 (top-level category) name and id
    lv1: Union[str, None] = Field(default=None, description="一级分类")
    lv1_id: Union[str, None] = None
    # Level 2 name and id
    lv2: Union[str, None] = Field(default=None, description="二级分类")
    lv2_id: Union[str, None] = None
    # Level 3 name and id; None when the concept hangs directly off lv2
    lv3: Union[str, None] = Field(default=None, description="三级分类")
    lv3_id: Union[str, None] = None
|
||
|
||
|
||
class StockInfo(BaseModel):
    """A stock attached to a concept: required name, optional ticker code."""

    name: str = Field(default=..., description="股票名称")
    code: Union[str, None] = Field(default=None, description="股票代码")
|
||
|
||
|
||
class ConceptPriceInfo(BaseModel):
    """A concept's average price change (in percent) on one trading day."""

    trade_date: date
    avg_change_pct: Union[float, None] = Field(default=None, description="平均涨跌幅(%)")
|
||
|
||
|
||
class SearchRequest(BaseModel):
    """Body of POST /search: query text, paging, ranking knobs and filters."""

    query: str = Field(default=..., description="搜索查询文本")
    # Paging over the in-memory result list
    size: int = Field(default=10, ge=1, le=100, description="每页返回结果数量")
    page: int = Field(default=1, ge=1, description="页码")
    # How many candidates to fetch from Elasticsearch
    search_size: int = Field(default=100, ge=10, le=1000, description="搜索数量")
    # None -> derived from the query length by calculate_semantic_weight()
    semantic_weight: Union[float, None] = Field(default=None, ge=0.0, le=1.0, description="语义搜索权重")
    filter_stocks: Union[List[str], None] = Field(default=None, description="过滤特定股票")
    filter_lv1: Union[str, None] = Field(default=None, description="过滤一级分类")
    filter_lv2: Union[str, None] = Field(default=None, description="过滤二级分类")
    trade_date: Union[date, None] = Field(default=None, description="交易日期")
    sort_by: str = Field(default="_score", description="排序: _score, change_pct, stock_count, outbreak_date")
    use_knn: bool = Field(default=True, description="是否使用KNN搜索")
|
||
|
||
|
||
class ConceptResult(BaseModel):
    """One concept hit as returned by POST /search."""

    concept_id: str
    concept: str
    description: Union[str, None] = None
    tags: Union[List[str], None] = Field(default=None, description="标签")
    outbreak_dates: Union[List[str], None] = Field(default=None, description="爆发日期")
    # Subset of the concept's stocks (name/code only)
    stocks: List[StockInfo] = Field(default_factory=list)
    stock_count: int = 0
    hierarchy: Union[HierarchyInfo, None] = Field(default=None, description="层级信息")
    score: float = 0.0
    match_type: str = "keyword"
    # Elasticsearch highlight fragments keyed by field name
    highlights: Union[Dict[str, List[str]], None] = None
    price_info: Union[ConceptPriceInfo, None] = None
|
||
|
||
|
||
class SearchResponse(BaseModel):
    """Paged response of POST /search."""

    total: int
    took_ms: int
    results: List[ConceptResult]
    # Echo of the effective search parameters (weight, match type, filters...)
    search_info: Dict[str, Any]
    # Trading day of the attached price data, if any
    price_date: Union[date, None] = None
    page: int = 1
    total_pages: int = 1
|
||
|
||
|
||
class HierarchyLevel(BaseModel):
    """A node of the category tree; nodes nest recursively via ``children``."""

    id: str
    name: str
    concept_count: int = 0
    children: Union[List['HierarchyLevel'], None] = None
    # Concept names attached directly to this node
    concepts: Union[List[str], None] = None
|
||
|
||
|
||
class HierarchyResponse(BaseModel):
    """Response of GET /hierarchy: the lv1 nodes plus the overall concept count."""

    hierarchy: List[HierarchyLevel]
    total_concepts: int
|
||
|
||
|
||
class ConceptDetailResponse(BaseModel):
    """Response of GET /concept/{concept_id}: the full concept record."""

    concept_id: str
    concept: str
    description: Union[str, None] = None
    insight: Union[str, None] = None
    tags: Union[List[str], None] = None
    outbreak_dates: Union[List[str], None] = None
    # Full per-stock records decoded from the document's stocks_json
    stocks: List[Dict[str, Any]] = Field(default_factory=list)
    stock_count: int = 0
    hierarchy: Union[HierarchyInfo, None] = None
    folders: Union[List[str], None] = None
    created_at: Union[str, None] = None
    price_info: Union[ConceptPriceInfo, None] = None
|
||
|
||
|
||
# ==================== 辅助函数 ====================
|
||
|
||
def generate_embedding(text: str) -> List[float]:
    """Embed *text* with the configured embedding model.

    Only the first 8000 characters are sent. Returns an empty list for blank
    input, when the client is not initialised, or when the remote call fails
    (the failure is logged as a warning).
    """
    try:
        if not text or not text.strip():
            return []

        snippet = text[:8000]

        client = openai_client
        if not client:
            return []

        reply = client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=[snippet]
        )
        return reply.data[0].embedding

    except Exception as e:
        logger.warning(f"Embedding生成失败: {e}")
        return []
|
||
|
||
|
||
def calculate_semantic_weight(query: str) -> float:
    """Pick the semantic (vector) weight from the query's character length.

    Longer queries read more like natural language, so they lean more on
    semantic search: <10 chars -> 0.3, <50 -> 0.5, <200 -> 0.6, else 0.7.
    """
    # (exclusive length limit, weight) pairs, checked from shortest upward
    for length_limit, weight in ((10, 0.3), (50, 0.5), (200, 0.6)):
        if len(query) < length_limit:
            return weight
    return 0.7
|
||
|
||
|
||
async def get_concept_price_data(concept_ids: List[str], trade_date: Optional[date] = None) -> Dict[str, ConceptPriceInfo]:
    """Fetch average price changes for concepts from ``concept_daily_stats``.

    Args:
        concept_ids: Concept ids to look up; each becomes one ``%s`` in the IN list.
        trade_date: Exact trading day to read. When omitted, each concept's own
            most recent row is used, so the dates may differ between concepts.

    Returns:
        Mapping ``concept_id -> ConceptPriceInfo`` for the ids that have a row.
        Returns ``{}`` when MySQL is unavailable, no ids are given, or the query
        fails (the error is logged).
    """
    if not mysql_pool or not concept_ids:
        return {}

    try:
        in_list = ','.join(['%s'] * len(concept_ids))

        if trade_date:
            sql = f"""
                SELECT concept_id, trade_date, avg_change_pct
                FROM concept_daily_stats
                WHERE concept_id IN ({in_list}) AND trade_date = %s
            """
            args = (*concept_ids, trade_date)
        else:
            # Join every concept with its own latest trade_date
            sql = f"""
                SELECT cds.concept_id, cds.trade_date, cds.avg_change_pct
                FROM concept_daily_stats cds
                INNER JOIN (
                    SELECT concept_id, MAX(trade_date) as max_date
                    FROM concept_daily_stats
                    WHERE concept_id IN ({in_list})
                    GROUP BY concept_id
                ) latest ON cds.concept_id = latest.concept_id
                    AND cds.trade_date = latest.max_date
            """
            args = concept_ids

        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql, args)
                rows = await cursor.fetchall()

        return {
            row['concept_id']: ConceptPriceInfo(
                trade_date=row['trade_date'],
                avg_change_pct=None if row['avg_change_pct'] is None else float(row['avg_change_pct'])
            )
            for row in rows
        }

    except Exception as e:
        logger.error(f"获取涨跌幅数据失败: {e}")
        return {}
|
||
|
||
|
||
def build_keyword_query(query: Optional[str], stock_filters: Optional[List[str]] = None) -> Dict:
    """Build the keyword (full-text) part of a concept search on the v3 index.

    Args:
        query: Free text matched against the concept name, description and
            tags using the ``ik_smart`` analyzer. ``None``, ``""`` or
            whitespace-only means "no text constraint".
        stock_filters: Optional stock names/codes. When given, a concept must
            contain at least one nested stock whose name or code is in the list.

    Returns:
        An Elasticsearch ``bool`` query. Without text and stock filters it
        falls back to ``match_all``.
    """
    must_queries = []

    if query and query.strip():
        must_queries.append({
            "multi_match": {
                "query": query,
                "fields": [
                    "concept^3",
                    "description^2",
                    "tags.text^2",
                    "tags^1.5"
                ],
                "type": "best_fields",
                "analyzer": "ik_smart"
            }
        })

    if stock_filters:
        must_queries.append({
            "nested": {
                "path": "stocks",
                "query": {
                    "bool": {
                        "should": [
                            {"terms": {"stocks.name": stock_filters}},
                            {"terms": {"stocks.code": stock_filters}}
                        ]
                    }
                }
            }
        })

    return {
        "bool": {
            "must": must_queries or [{"match_all": {}}]
        }
    }
|
||
|
||
|
||
def parse_stocks_from_es(source: Dict, limit: int = 30) -> tuple:
    """Extract stock information from an ES concept document's ``_source``.

    Args:
        source: The ``_source`` dict of a concept document.
        limit: Maximum number of entries taken from the nested ``stocks``
            field (default 30, keeping search payloads small).

    Returns:
        ``(stocks_list, stocks_full)``:
        - ``stocks_list``: up to ``limit`` StockInfo objects built from the
          nested ``stocks`` field (name and code only; a missing or null name
          becomes ``""``).
        - ``stocks_full``: the list decoded from ``stocks_json``. It is empty
          when the field is absent, is not valid JSON, or does not decode to a
          JSON array.
    """
    # `or []` also covers an explicit "stocks": null in the document.
    stocks_list = [
        StockInfo(name=s.get('name') or '', code=s.get('code'))
        for s in (source.get('stocks') or [])[:limit]
    ]

    stocks_full = []
    stocks_json = source.get('stocks_json')
    if stocks_json:
        try:
            decoded = json.loads(stocks_json)
        except (TypeError, ValueError) as e:
            # Best effort: a bad stocks_json must not break the whole response.
            logger.warning(f"Failed to parse stocks_json for concept {source.get('concept_id')}: {e}")
        else:
            # Callers expect a list of per-stock dicts; ignore other shapes.
            if isinstance(decoded, list):
                stocks_full = decoded

    return stocks_list, stocks_full
|
||
|
||
|
||
# ==================== API 端点 ====================
|
||
|
||
@app.get("/", tags=["Health"])
async def root():
    """Health check: service identity plus the number of concepts with hierarchy info."""
    payload = dict(
        status="healthy",
        service="概念搜索API V2",
        version="2.0.0",
        index=INDEX_NAME,
        hierarchy_concepts=len(concept_to_hierarchy),
    )
    return payload
|
||
|
||
|
||
def _build_concept_search_body(
    request: SearchRequest,
    embedding: List[float],
    semantic_weight: float,
    fetch_size: int,
) -> tuple:
    """Build the ES request body for /search; return ``(body, match_type)``."""
    keyword_query = build_keyword_query(request.query, request.filter_stocks)

    if semantic_weight > 0 and request.use_knn and embedding:
        # Hybrid search: kNN on description_embedding plus the keyword query,
        # each weighted by its boost.
        body = {
            "knn": {
                "field": "description_embedding",
                "query_vector": embedding,
                "k": fetch_size,
                "num_candidates": min(fetch_size * 2, 10000),
                "boost": semantic_weight
            },
            "query": {
                "bool": {
                    "must": [keyword_query],
                    "boost": 1.0 - semantic_weight
                }
            },
            "size": fetch_size
        }
        match_type = "hybrid_knn"
    else:
        body = {"query": keyword_query, "size": fetch_size}
        match_type = "keyword"

    body.update({
        "highlight": {
            "fields": {
                "concept": {},
                "description": {"fragment_size": 150},
                "tags": {}
            }
        },
        "_source": {
            "excludes": ["description_embedding", "insight"]
        },
        "track_total_hits": True
    })
    return body, match_type


def _matches_hierarchy_filter(
    hierarchy_info: Optional[Dict],
    filter_lv1: Optional[str],
    filter_lv2: Optional[str],
) -> bool:
    """Return True if the concept's hierarchy satisfies the lv1/lv2 filters."""
    if filter_lv1 and (not hierarchy_info or hierarchy_info.get('lv1') != filter_lv1):
        return False
    if filter_lv2 and (not hierarchy_info or hierarchy_info.get('lv2') != filter_lv2):
        return False
    return True


def _hit_to_concept_result(hit: Dict, hierarchy_info: Optional[Dict], match_type: str) -> ConceptResult:
    """Convert one ES hit into a ConceptResult (price_info is attached later)."""
    source = hit['_source']
    stocks_list, _ = parse_stocks_from_es(source)
    return ConceptResult(
        concept_id=source.get('concept_id', ''),
        concept=source.get('concept', ''),
        description=source.get('description'),
        tags=source.get('tags', []),
        outbreak_dates=source.get('outbreak_dates', []),
        stocks=stocks_list,
        stock_count=len(source.get('stocks') or []),
        hierarchy=HierarchyInfo(**hierarchy_info) if hierarchy_info else None,
        score=hit['_score'] or 0,
        match_type=match_type,
        highlights=hit.get('highlight', {}),
        price_info=None
    )


def _sort_concept_results(results: List[ConceptResult], sort_by: str) -> None:
    """Sort results in place by the requested key; ``_score`` keeps ES order."""
    if sort_by == "change_pct":
        # Concepts without price data sort last (sentinel -999).
        results.sort(
            key=lambda r: (
                r.price_info.avg_change_pct
                if r.price_info and r.price_info.avg_change_pct is not None
                else -999
            ),
            reverse=True
        )
    elif sort_by == "stock_count":
        results.sort(key=lambda r: r.stock_count, reverse=True)
    elif sort_by == "outbreak_date":
        # NOTE(review): uses the first outbreak date; assumes the list is
        # ordered newest-first — confirm against the indexer.
        results.sort(
            key=lambda r: r.outbreak_dates[0] if r.outbreak_dates else '1900-01-01',
            reverse=True
        )


@app.post("/search", response_model=SearchResponse, tags=["Search"])
async def search_concepts(request: SearchRequest):
    """Search concepts with keyword or hybrid kNN retrieval and hierarchy filters.

    Flow:
    1. Choose a semantic weight.
    2. Embed the query when kNN will use it.
    3. Fetch candidates from ES.
    4. Drop hits outside filter_lv1/filter_lv2.
    5. Attach MySQL price data.
    6. Sort, then page in memory.

    ``total`` is the ES hit count, except when hierarchy filters are active.
    In that case it is the number of candidates that passed the filters, so
    it agrees with ``total_pages``.

    Raises:
        HTTPException(500): on any unexpected error.
    """
    start_time = datetime.now()

    try:
        if request.semantic_weight is not None:
            semantic_weight = request.semantic_weight
        else:
            semantic_weight = calculate_semantic_weight(request.query)

        # Only the kNN branch consumes the embedding, so skip the blocking
        # remote call entirely when kNN is disabled.
        embedding = []
        if semantic_weight > 0 and request.use_knn:
            embedding = generate_embedding(request.query)
            if not embedding:
                semantic_weight = 0

        # Hierarchy filtering and change_pct/outbreak_date sorting happen in
        # Python after the fetch. Widen the candidate pool, otherwise they
        # would only ever see the top search_size relevance hits.
        has_hierarchy_filter = bool(request.filter_lv1 or request.filter_lv2)
        fetch_size = request.search_size
        if has_hierarchy_filter or request.sort_by in ("change_pct", "outbreak_date"):
            fetch_size = min(1000, request.search_size * 10)

        search_body, match_type = _build_concept_search_body(
            request, embedding, semantic_weight, fetch_size
        )

        es_response = es_client.search(
            index=INDEX_NAME,
            body=search_body,
            timeout="30s"
        )

        all_results = []
        for hit in es_response['hits']['hits']:
            hierarchy_info = get_concept_hierarchy(hit['_source'].get('concept', ''))
            if not _matches_hierarchy_filter(hierarchy_info, request.filter_lv1, request.filter_lv2):
                continue
            all_results.append(_hit_to_concept_result(hit, hierarchy_info, match_type))

        # Query prices only for concepts that survived filtering, so the
        # reported price_date always belongs to a returned concept.
        concept_ids = [r.concept_id for r in all_results]
        price_data = await get_concept_price_data(concept_ids, request.trade_date) if concept_ids else {}
        actual_price_date = next(iter(price_data.values())).trade_date if price_data else None
        for result in all_results:
            result.price_info = price_data.get(result.concept_id)

        _sort_concept_results(all_results, request.sort_by)

        total_results = len(all_results)
        total_pages = max(1, (total_results + request.size - 1) // request.size)
        current_page = min(request.page, total_pages)
        start_idx = (current_page - 1) * request.size
        page_results = all_results[start_idx:start_idx + request.size]

        # The ES hit count includes concepts dropped by the hierarchy filter.
        if has_hierarchy_filter:
            total = total_results
        else:
            total = es_response['hits']['total']['value']

        took_ms = int((datetime.now() - start_time).total_seconds() * 1000)

        return SearchResponse(
            total=total,
            took_ms=took_ms,
            results=page_results,
            search_info={
                "query": request.query,
                "semantic_weight": semantic_weight,
                "match_type": match_type,
                "filter_lv1": request.filter_lv1,
                "filter_lv2": request.filter_lv2,
                "sort_by": request.sort_by,
                "has_embedding": bool(embedding)
            },
            price_date=actual_price_date,
            page=current_page,
            total_pages=total_pages
        )

    except Exception as e:
        logger.error(f"搜索失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/concept/{concept_id}", response_model=ConceptDetailResponse, tags=["Concepts"])
async def get_concept_detail(
    concept_id: str,
    trade_date: Optional[date] = Query(None, description="交易日期")
):
    """Return the full record of one concept, including insight and all stocks.

    Args:
        concept_id: Elasticsearch document id of the concept.
        trade_date: Trading day for the price info; defaults to the latest row.

    Raises:
        HTTPException(404): no document with this id exists.
        HTTPException(500): any other failure.
    """
    # Match the client's exception class by type instead of searching the
    # exception's type name for "NotFoundError".
    from elasticsearch import NotFoundError

    try:
        result = es_client.get(index=INDEX_NAME, id=concept_id)
        source = result['_source']

        concept_name = source.get('concept', '')

        # Full per-stock records come from stocks_json
        _, stocks_full = parse_stocks_from_es(source)

        hierarchy_info = get_concept_hierarchy(concept_name)
        hierarchy = HierarchyInfo(**hierarchy_info) if hierarchy_info else None

        price_data = await get_concept_price_data([concept_id], trade_date)
        price_info = price_data.get(concept_id)

        return ConceptDetailResponse(
            concept_id=concept_id,
            concept=concept_name,
            description=source.get('description'),
            insight=source.get('insight'),
            tags=source.get('tags', []),
            outbreak_dates=source.get('outbreak_dates', []),
            stocks=stocks_full,
            stock_count=len(stocks_full),
            hierarchy=hierarchy,
            folders=source.get('folders', []),
            created_at=source.get('created_at'),
            price_info=price_info
        )

    except NotFoundError:
        raise HTTPException(status_code=404, detail="概念不存在")
    except Exception as e:
        logger.error(f"获取概念详情失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/hierarchy", response_model=HierarchyResponse, tags=["Hierarchy"])
async def get_hierarchy():
    """Return the whole category tree with concept counts on every node.

    lv2 nodes that have lv3 sub-categories list them in ``children`` (with
    ``concepts`` unset). lv2 nodes without lv3 carry their concept names in
    ``concepts`` and an empty ``children`` list. Counts roll up from lv3 to
    lv2 to lv1.
    """
    try:
        tree = []

        for lv1_entry in hierarchy_data.get('hierarchy', []):
            lv2_nodes = []

            for lv2_entry in lv1_entry.get('children', []):
                if 'children' in lv2_entry:
                    lv3_nodes = [
                        HierarchyLevel(
                            id=lv3_entry.get('lv3_id', ''),
                            name=lv3_entry.get('lv3', ''),
                            concept_count=len(lv3_entry.get('concepts', [])),
                            concepts=lv3_entry.get('concepts', [])
                        )
                        for lv3_entry in lv2_entry.get('children', [])
                    ]
                    lv2_node = HierarchyLevel(
                        id=lv2_entry.get('lv2_id', ''),
                        name=lv2_entry.get('lv2', ''),
                        concept_count=sum(node.concept_count for node in lv3_nodes),
                        children=lv3_nodes
                    )
                else:
                    direct_concepts = lv2_entry.get('concepts', [])
                    lv2_node = HierarchyLevel(
                        id=lv2_entry.get('lv2_id', ''),
                        name=lv2_entry.get('lv2', ''),
                        concept_count=len(direct_concepts),
                        children=[],
                        concepts=direct_concepts
                    )
                lv2_nodes.append(lv2_node)

            tree.append(HierarchyLevel(
                id=lv1_entry.get('lv1_id', ''),
                name=lv1_entry.get('lv1', ''),
                concept_count=sum(node.concept_count for node in lv2_nodes),
                children=lv2_nodes
            ))

        return HierarchyResponse(
            hierarchy=tree,
            total_concepts=sum(node.concept_count for node in tree)
        )

    except Exception as e:
        logger.error(f"获取层级结构失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
def _hierarchy_price_item(row: Dict[str, Any], concept_type: str) -> Dict[str, Any]:
    """Serialize one concept_daily_stats row into a /hierarchy/price list item."""
    change = row['avg_change_pct']
    return {
        "concept_id": row['concept_id'],
        "concept_name": row['concept_name'],
        "concept_type": concept_type,
        "trade_date": str(row['trade_date']),
        # Compare with None explicitly: a flat 0.0% day must stay 0.0, not null.
        "avg_change_pct": float(change) if change is not None else None,
        "stock_count": row['stock_count']
    }


# NOTE: /hierarchy/price must be registered before /hierarchy/{lv1_id};
# otherwise the path-parameter route would capture "price".
@app.get("/hierarchy/price", tags=["Hierarchy"])
async def get_hierarchy_price_early(
    trade_date: Optional[date] = Query(None, description="交易日期,默认最新"),
    lv1_filter: Optional[str] = Query(None, description="筛选特定一级分类")
):
    """Return price changes of every concept: parent levels (lv1/lv2/lv3) and leaves.

    Args:
        trade_date: Trading day to read; defaults to the latest day in
            concept_daily_stats.
        lv1_filter: Optional substring filter.
            - For lv1/lv2/lv3 rows it is applied to concept_name via SQL LIKE.
            - For leaf rows it is applied to the leaf's lv1 name from the
              hierarchy file; leaves without hierarchy info are dropped.

    Raises:
        HTTPException(503): MySQL pool unavailable.
        HTTPException(404): no price data at all.
        HTTPException(500): any other failure.
    """
    logger.info(f"[hierarchy/price] 请求参数: trade_date={trade_date}, lv1_filter={lv1_filter}")

    if not mysql_pool:
        logger.error("[hierarchy/price] MySQL连接池不可用")
        raise HTTPException(status_code=503, detail="MySQL连接不可用")

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                query_date = trade_date
                if query_date is None:
                    await cursor.execute("SELECT MAX(trade_date) as max_date FROM concept_daily_stats")
                    result = await cursor.fetchone()
                    if not result or not result['max_date']:
                        raise HTTPException(status_code=404, detail="无涨跌幅数据")
                    query_date = result['max_date']

                logger.info(f"[hierarchy/price] 使用查询日期: {query_date}")

                select_sql = """
                    SELECT concept_id, concept_name, concept_type, trade_date, avg_change_pct, stock_count
                    FROM concept_daily_stats
                    WHERE trade_date = %s AND concept_type = %s
                """
                order_sql = " ORDER BY avg_change_pct DESC"

                # Parent levels: the name filter is pushed into SQL, and the
                # extra placeholder always comes with its extra argument.
                # NOTE(review): for lv2/lv3 this matches the lv1 filter against
                # the lv2/lv3 concept's own name, not its parent lv1 — confirm.
                parent_sql = select_sql + (" AND concept_name LIKE %s" if lv1_filter else "") + order_sql
                like_arg = f"%{lv1_filter}%" if lv1_filter else None

                grouped: Dict[str, List[Dict[str, Any]]] = {}
                for level in ('lv1', 'lv2', 'lv3'):
                    params = (query_date, level, like_arg) if lv1_filter else (query_date, level)
                    await cursor.execute(parent_sql, params)
                    grouped[level] = [_hierarchy_price_item(row, level) for row in await cursor.fetchall()]

                # Leaves: the query never carries the LIKE clause (leaf names
                # are not lv1 names). The lv1 filter uses the leaf's hierarchy.
                await cursor.execute(select_sql + order_sql, (query_date, 'leaf'))
                leaf_concepts = []
                for row in await cursor.fetchall():
                    hierarchy_info = get_concept_hierarchy(row['concept_name'])
                    if lv1_filter and lv1_filter not in ((hierarchy_info or {}).get('lv1') or ''):
                        continue
                    item = _hierarchy_price_item(row, 'leaf')
                    item["hierarchy"] = hierarchy_info
                    leaf_concepts.append(item)

                lv1_concepts = grouped['lv1']
                lv2_concepts = grouped['lv2']
                lv3_concepts = grouped['lv3']

                total = len(lv1_concepts) + len(lv2_concepts) + len(lv3_concepts) + len(leaf_concepts)
                logger.info(f"[hierarchy/price] 返回结果: lv1={len(lv1_concepts)}, lv2={len(lv2_concepts)}, lv3={len(lv3_concepts)}, leaf={len(leaf_concepts)}")

                return {
                    "trade_date": str(query_date),
                    "lv1_concepts": lv1_concepts,
                    "lv2_concepts": lv2_concepts,
                    "lv3_concepts": lv3_concepts,
                    "leaf_concepts": leaf_concepts,
                    "total_count": total,
                    "summary": {
                        "lv1_count": len(lv1_concepts),
                        "lv2_count": len(lv2_concepts),
                        "lv3_count": len(lv3_concepts),
                        "leaf_count": len(leaf_concepts)
                    }
                }

    except HTTPException as he:
        logger.error(f"[hierarchy/price] HTTPException: {he.status_code} - {he.detail}")
        raise
    except Exception as e:
        # logger.exception records the traceback in the same log entry.
        logger.exception(f"[hierarchy/price] 获取层级涨跌幅失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
def _concepts_under_lv2(lv2_entry: Dict[str, Any]) -> List[str]:
    """Return an lv2 node's concept names, descending into its lv3 children if present."""
    if 'children' not in lv2_entry:
        return lv2_entry.get('concepts', [])
    return [
        concept
        for lv3_entry in lv2_entry.get('children', [])
        for concept in lv3_entry.get('concepts', [])
    ]


@app.get("/hierarchy/{lv1_id}", tags=["Hierarchy"])
async def get_hierarchy_level(
    lv1_id: str,
    lv2_id: Optional[str] = Query(None, description="二级分类ID")
):
    """List the concepts under one lv1 category, or under one of its lv2 nodes.

    Without ``lv2_id``, returns every concept under the lv1 category plus the
    raw ``children`` subtree. With ``lv2_id``, returns only that lv2 node's
    concepts.

    Raises:
        HTTPException(404): the lv1 id (or lv2 id) does not exist.
        HTTPException(500): any other failure.
    """
    try:
        lv1_entry = next(
            (entry for entry in hierarchy_data.get('hierarchy', []) if entry.get('lv1_id') == lv1_id),
            None
        )
        if lv1_entry is None:
            raise HTTPException(status_code=404, detail=f"未找到一级分类: {lv1_id}")

        if not lv2_id:
            every_concept = []
            for lv2_entry in lv1_entry.get('children', []):
                every_concept.extend(_concepts_under_lv2(lv2_entry))
            return {
                "lv1": lv1_entry.get('lv1'),
                "lv1_id": lv1_id,
                "concepts": every_concept,
                "concept_count": len(every_concept),
                "children": lv1_entry.get('children', [])
            }

        lv2_entry = next(
            (entry for entry in lv1_entry.get('children', []) if entry.get('lv2_id') == lv2_id),
            None
        )
        if lv2_entry is None:
            raise HTTPException(status_code=404, detail=f"未找到二级分类: {lv2_id}")

        lv2_concepts = _concepts_under_lv2(lv2_entry)
        return {
            "lv1": lv1_entry.get('lv1'),
            "lv1_id": lv1_id,
            "lv2": lv2_entry.get('lv2'),
            "lv2_id": lv2_id,
            "concepts": lv2_concepts,
            "concept_count": len(lv2_concepts)
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取层级详情失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/stock/{stock_code}/concepts", tags=["Stocks"])
async def get_stock_concepts(
    stock_code: str,
    size: int = Query(50, ge=1, le=200),
    trade_date: Optional[date] = Query(None)
):
    """List the concepts containing a stock.

    The path value is matched with ``term`` queries against both
    ``stocks.code`` and ``stocks.name``. Each concept gets its hierarchy and
    price info attached.
    """
    try:
        stock_clauses = [
            {"term": {"stocks.code": stock_code}},
            {"term": {"stocks.name": stock_code}}
        ]
        body = {
            "query": {
                "nested": {
                    "path": "stocks",
                    "query": {"bool": {"should": stock_clauses}}
                }
            },
            "size": size,
            "_source": {
                "excludes": ["description_embedding", "insight", "stocks_json"]
            }
        }

        es_response = es_client.search(index=INDEX_NAME, body=body)

        concepts = []
        for hit in es_response['hits']['hits']:
            doc = hit['_source']
            name = doc.get('concept', '')
            concepts.append({
                "concept_id": doc.get('concept_id', ''),
                "concept": name,
                "description": doc.get('description'),
                "tags": doc.get('tags', []),
                "stock_count": len(doc.get('stocks', [])),
                "hierarchy": get_concept_hierarchy(name),
                "outbreak_dates": doc.get('outbreak_dates', [])
            })

        price_data = await get_concept_price_data([c["concept_id"] for c in concepts], trade_date)
        for item in concepts:
            item['price_info'] = price_data.get(item['concept_id'])

        return {
            "stock_code": stock_code,
            "total": es_response['hits']['total']['value'],
            "concepts": concepts
        }

    except Exception as e:
        logger.error(f"查询股票概念失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/statistics/hierarchy", tags=["Statistics"])
async def get_hierarchy_statistics():
    """Summarise concept counts per lv1 category (and per lv2 child).

    Categories are ordered by concept count, largest first; the sort is
    stable, so ties keep file order.
    """
    try:
        summary = []

        for lv1_entry in hierarchy_data.get('hierarchy', []):
            lv2_entries = lv1_entry.get('children', [])

            lv2_rows = []
            for lv2_entry in lv2_entries:
                if 'children' in lv2_entry:
                    count = sum(len(lv3_entry.get('concepts', [])) for lv3_entry in lv2_entry.get('children', []))
                else:
                    count = len(lv2_entry.get('concepts', []))
                lv2_rows.append({
                    "lv2": lv2_entry.get('lv2'),
                    "lv2_id": lv2_entry.get('lv2_id'),
                    "concept_count": count
                })

            summary.append({
                "lv1": lv1_entry.get('lv1'),
                "lv1_id": lv1_entry.get('lv1_id'),
                "lv2_count": len(lv2_entries),
                "concept_count": sum(row["concept_count"] for row in lv2_rows),
                "children": lv2_rows
            })

        summary = sorted(summary, key=lambda item: item['concept_count'], reverse=True)

        return {
            "total_lv1": len(summary),
            "total_concepts": sum(item['concept_count'] for item in summary),
            "statistics": summary
        }

    except Exception as e:
        logger.error(f"获取层级统计失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# ==================== 股票搜索相关 ====================
|
||
|
||
class StockSearchResult(BaseModel):
    """One stock returned by GET /stock/search."""

    stock_code: str
    stock_name: str
    # Number of concepts this stock belongs to
    concept_count: int
|
||
|
||
|
||
class StockSearchResponse(BaseModel):
    """Response of GET /stock/search."""

    # Number of matching stocks found (before truncation to `size`)
    total: int
    stocks: List[StockSearchResult]
|
||
|
||
|
||
def _escape_es_wildcard(text: str) -> str:
    """Backslash-escape the backslash, asterisk and question mark for an ES wildcard pattern."""
    return text.replace('\\', '\\\\').replace('*', '\\*').replace('?', '\\?')


@app.get("/stock/search", response_model=StockSearchResponse, tags=["Stocks"])
async def search_stocks(
    keyword: str = Query(..., description="股票关键词(代码或名称)"),
    size: int = Query(20, ge=1, le=100, description="返回数量")
):
    """Search stocks by code substring or name and report their concept counts.

    The same stock-level condition (code contains the keyword, or name
    matches it) is used twice:
    - to select concept documents;
    - inside the nested aggregation, through a ``filter`` sub-aggregation, so
      the terms buckets only contain matching stocks. Without that filter,
      non-matching stocks of the same concepts could crowd the matches out of
      the top buckets.

    A final substring check on code/name keeps the original matching semantics.
    """
    try:
        stock_match = {
            "bool": {
                "should": [
                    {"wildcard": {"stocks.code": f"*{_escape_es_wildcard(keyword)}*"}},
                    {"match": {"stocks.name": keyword}}
                ]
            }
        }

        search_body = {
            "query": {
                "nested": {
                    "path": "stocks",
                    "query": stock_match
                }
            },
            "size": 0,
            "aggs": {
                "unique_stocks": {
                    "nested": {"path": "stocks"},
                    "aggs": {
                        "matched_stocks": {
                            "filter": stock_match,
                            "aggs": {
                                "stock_codes": {
                                    "terms": {
                                        "field": "stocks.code",
                                        "size": size * 2
                                    },
                                    "aggs": {
                                        "stock_names": {
                                            "terms": {
                                                "field": "stocks.name",
                                                "size": 1
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        es_response = es_client.search(index=INDEX_NAME, body=search_body)

        buckets = es_response['aggregations']['unique_stocks']['matched_stocks']['stock_codes']['buckets']

        stocks = []
        for bucket in buckets:
            stock_code = bucket['key']
            name_buckets = bucket['stock_names']['buckets']
            stock_name = name_buckets[0]['key'] if name_buckets else stock_code

            if keyword in stock_code or keyword in stock_name:
                stocks.append(StockSearchResult(
                    stock_code=stock_code,
                    stock_name=stock_name,
                    concept_count=bucket['doc_count']
                ))

        stocks.sort(key=lambda x: x.concept_count, reverse=True)

        return StockSearchResponse(
            total=len(stocks),
            stocks=stocks[:size]
        )

    except Exception as e:
        logger.error(f"搜索股票失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# ==================== 价格相关接口 ====================
|
||
|
||
class PriceTimeSeriesItem(BaseModel):
    """A single trading day in a concept's price time series."""

    # Trading day this point describes.
    trade_date: date
    # Average percentage change for the day; None when the source value is NULL.
    avg_change_pct: Optional[float] = Field(default=None, description="平均涨跌幅(%)")
    # Number of stocks counted for the day; None when unavailable.
    stock_count: Optional[int] = Field(default=None, description="当日股票数量")
|
||
|
||
|
||
class PriceTimeSeriesResponse(BaseModel):
    """Envelope returned by the per-concept price time-series endpoint."""

    # Identifier and display name of the concept the series belongs to.
    concept_id: str
    concept_name: str
    # Requested (inclusive) date window.
    start_date: date
    end_date: date
    # Length of ``timeseries``.
    data_points: int
    # Daily points ordered by trade date.
    timeseries: List[PriceTimeSeriesItem]
|
||
|
||
|
||
async def get_latest_trade_date() -> Optional[date]:
    """Look up the newest ``trade_date`` stored in ``concept_daily_stats``.

    Returns:
        The latest trade date, or None when the MySQL pool is unavailable,
        the table has no rows, or the query fails (failures are logged and
        swallowed rather than raised).
    """
    if not mysql_pool:
        return None

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("SELECT MAX(trade_date) as latest_date FROM concept_daily_stats")
                first_row = await cursor.fetchone()
    except Exception as e:
        logger.error(f"获取最新交易日期失败: {e}")
        return None

    if not first_row:
        return None
    # MAX() over an empty table yields NULL, which becomes None here.
    return first_row[0] or None
|
||
|
||
|
||
@app.get("/price/latest", tags=["Price"])
async def get_latest_price_date():
    """Report the newest trade date that has concept change data.

    Returns:
        A dict with ``latest_trade_date`` (date or None) and ``has_data``
        (True when a date was found).

    Raises:
        HTTPException: 500 if the lookup raises.
    """
    try:
        newest = await get_latest_trade_date()
    except Exception as e:
        logger.error(f"获取最新价格日期失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    return {"latest_trade_date": newest, "has_data": newest is not None}
|
||
|
||
|
||
@app.get("/concept/{concept_id}/price-timeseries", response_model=PriceTimeSeriesResponse, tags=["Price"])
async def get_concept_price_timeseries(
    concept_id: str,
    start_date: date = Query(..., description="开始日期,格式:YYYY-MM-DD"),
    end_date: date = Query(..., description="结束日期,格式:YYYY-MM-DD")
):
    """Return one concept's daily average-change series for a date range.

    Args:
        concept_id: Value matched against ``concept_daily_stats.concept_id``.
        start_date: First trade date to include (inclusive).
        end_date: Last trade date to include (inclusive).

    Returns:
        PriceTimeSeriesResponse ordered by trade date ascending. If the MySQL
        pool is unavailable, an empty series is returned instead of an error.

    Raises:
        HTTPException: 400 if start_date is after end_date, 404 if the concept
            has no rows in the range, 500 on any other failure.
    """
    if not mysql_pool:
        logger.warning("[PriceTimeseries] MySQL 连接不可用")
        return PriceTimeSeriesResponse(
            concept_id=concept_id,
            concept_name=concept_id,
            start_date=start_date,
            end_date=end_date,
            data_points=0,
            timeseries=[],
        )

    if start_date > end_date:
        raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")

    sql = """
        SELECT trade_date, concept_name, avg_change_pct, stock_count
        FROM concept_daily_stats
        WHERE concept_id = %s
          AND trade_date >= %s
          AND trade_date <= %s
        ORDER BY trade_date ASC
    """

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql, (concept_id, start_date, end_date))
                rows = await cursor.fetchall()

                if not rows:
                    raise HTTPException(
                        status_code=404,
                        detail=f"未找到概念ID {concept_id} 在 {start_date} 至 {end_date} 期间的数据"
                    )

                # The first non-empty name wins; the id is used if none is found.
                display_name = next(
                    (row['concept_name'] for row in rows if row.get('concept_name')),
                    "",
                )
                points = [
                    PriceTimeSeriesItem(
                        trade_date=row['trade_date'],
                        avg_change_pct=None if row['avg_change_pct'] is None else float(row['avg_change_pct']),
                        stock_count=row.get('stock_count'),
                    )
                    for row in rows
                ]

                return PriceTimeSeriesResponse(
                    concept_id=concept_id,
                    concept_name=display_name or concept_id,
                    start_date=start_date,
                    end_date=end_date,
                    data_points=len(points),
                    timeseries=points,
                )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取价格时间序列失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取时间序列数据失败: {str(e)}")
|
||
|
||
|
||
# ==================== 统计接口 ====================
|
||
|
||
class ConceptStatItem(BaseModel):
    """One row of a /statistics leaderboard.

    Only ``name`` is required. Each leaderboard fills the fields relevant to
    it and leaves the rest as None.
    """

    name: str
    concept_id: Optional[str] = None
    # Filled by the hot/cold leaderboards (stock_count also by the active one).
    change_pct: Optional[float] = None
    stock_count: Optional[int] = None
    # Filled by the active leaderboard (news_count also by hot/cold).
    news_count: Optional[int] = None
    report_count: Optional[int] = None
    total_mentions: Optional[int] = None
    # Filled by the volatility leaderboard.
    volatility: Optional[float] = None
    avg_change: Optional[float] = None
    max_change: Optional[float] = None
    # Filled by the momentum leaderboard.
    consecutive_days: Optional[int] = None
    total_change: Optional[float] = None
    avg_daily: Optional[float] = None
    # Hierarchy placement of the concept, when a lookup finds one.
    hierarchy: Optional[HierarchyInfo] = None
|
||
|
||
|
||
class ConceptStatistics(BaseModel):
    """All /statistics leaderboards plus a summary for one date window."""

    hot_concepts: List[ConceptStatItem]        # highest average change first
    cold_concepts: List[ConceptStatItem]       # lowest average change first
    active_concepts: List[ConceptStatItem]     # most days with data first
    volatile_concepts: List[ConceptStatItem]   # largest std-dev of daily change first
    momentum_concepts: List[ConceptStatItem]   # most positive-change days first
    summary: Dict[str, Any]                    # counts, average change and the date window
|
||
|
||
|
||
class ConceptStatisticsResponse(BaseModel):
    """Envelope returned by the /statistics endpoint."""

    success: bool
    data: ConceptStatistics
    # Echo of the resolved request parameters (window, filters).
    params: Dict[str, Any]
    # Explanation attached when placeholder data is returned.
    note: Optional[str] = None
|
||
|
||
|
||
def get_fallback_statistics(start_date, end_date) -> ConceptStatistics:
    """Build hard-coded placeholder statistics for when real data is unavailable.

    Args:
        start_date: Window start echoed into the summary (may be None).
        end_date: Window end echoed into the summary (may be None).

    Returns:
        ConceptStatistics filled with fixed sample leaderboards. The summary's
        ``days`` is the inclusive span, or 7 if either date is missing.
    """
    hot_rows = (
        ("人工智能", 8.76, 45, 12),
        ("机器人", 7.54, 38, 10),
        ("半导体", 6.43, 52, 15),
        ("低空经济", 5.21, 28, 6),
        ("量子科技", 4.98, 15, 9),
    )
    cold_rows = (
        ("房地产", -5.76, 33, 5),
        ("煤炭", -4.32, 25, 3),
        ("钢铁", -3.21, 28, 4),
        ("传统零售", -2.98, 19, 2),
        ("纺织服装", -2.45, 15, 2),
    )
    active_rows = (
        ("人工智能", 45, 15, 60),
        ("半导体", 42, 12, 54),
        ("机器人", 38, 8, 46),
        ("新能源", 28, 6, 34),
        ("量子科技", 25, 5, 30),
    )
    volatile_rows = (
        ("数字货币", 25.6, 2.1, 15.2),
        ("元宇宙", 23.8, 1.8, 13.9),
        ("虚拟现实", 21.2, -0.5, 10.1),
        ("游戏", 19.7, 3.2, 12.8),
        ("短剧", 18.3, -1.1, 8.1),
    )
    momentum_rows = (
        ("AI应用", 6, 19.2, 3.2),
        ("算力", 5, 16.8, 3.36),
        ("光通信", 4, 13.1, 3.28),
        ("存储", 4, 12.4, 3.1),
        ("PCB", 3, 9.6, 3.2),
    )

    hot = [
        ConceptStatItem(name=label, change_pct=pct, stock_count=n_stocks, news_count=n_news)
        for label, pct, n_stocks, n_news in hot_rows
    ]
    cold = [
        ConceptStatItem(name=label, change_pct=pct, stock_count=n_stocks, news_count=n_news)
        for label, pct, n_stocks, n_news in cold_rows
    ]
    active = [
        ConceptStatItem(name=label, news_count=n_news, report_count=n_reports, total_mentions=n_total)
        for label, n_news, n_reports, n_total in active_rows
    ]
    volatile = [
        ConceptStatItem(name=label, volatility=vol, avg_change=mean, max_change=peak)
        for label, vol, mean, peak in volatile_rows
    ]
    momentum = [
        ConceptStatItem(name=label, consecutive_days=streak, total_change=total, avg_daily=per_day)
        for label, streak, total, per_day in momentum_rows
    ]

    if start_date and end_date:
        span_days = (end_date - start_date).days + 1
    else:
        span_days = 7

    summary = {
        'total_concepts': 500,
        'positive_count': 320,
        'negative_count': 180,
        'avg_change': 1.8,
        'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'date_range': f"{start_date} 至 {end_date}",
        'days': span_days,
        'start_date': str(start_date) if start_date else '',
        'end_date': str(end_date) if end_date else '',
    }

    return ConceptStatistics(
        hot_concepts=hot,
        cold_concepts=cold,
        active_concepts=active,
        volatile_concepts=volatile,
        momentum_concepts=momentum,
        summary=summary,
    )
|
||
|
||
|
||
_STATISTICS_CONCEPT_TYPES = ('leaf', 'lv1', 'lv2', 'lv3')
# Length (inclusive, in calendar days) of the window used when no range is given.
_STATISTICS_DEFAULT_WINDOW_DAYS = 7
# Largest allowed end_date - start_date difference.
_STATISTICS_MAX_SPAN_DAYS = 90

# Leaderboard queries; each takes (start_date, end_date, min_stock_count, concept_type).
_STATISTICS_RANKING_QUERIES = (
    ('hot', """
        SELECT concept_id, concept_name,
               AVG(avg_change_pct) as avg_change_pct,
               AVG(stock_count) as avg_stock_count,
               COUNT(*) as trading_days
        FROM concept_daily_stats
        WHERE trade_date >= %s AND trade_date <= %s
          AND avg_change_pct IS NOT NULL AND stock_count >= %s
          AND concept_type = %s
        GROUP BY concept_id, concept_name
        HAVING COUNT(*) >= 2
        ORDER BY AVG(avg_change_pct) DESC
        LIMIT 10
    """),
    ('cold', """
        SELECT concept_id, concept_name,
               AVG(avg_change_pct) as avg_change_pct,
               AVG(stock_count) as avg_stock_count,
               COUNT(*) as trading_days
        FROM concept_daily_stats
        WHERE trade_date >= %s AND trade_date <= %s
          AND avg_change_pct IS NOT NULL AND stock_count >= %s
          AND concept_type = %s
        GROUP BY concept_id, concept_name
        HAVING COUNT(*) >= 2
        ORDER BY AVG(avg_change_pct) ASC
        LIMIT 10
    """),
    ('active', """
        SELECT concept_id, concept_name,
               COUNT(*) as trading_days,
               AVG(stock_count) as avg_stock_count
        FROM concept_daily_stats
        WHERE trade_date >= %s AND trade_date <= %s AND stock_count >= %s
          AND concept_type = %s
        GROUP BY concept_id, concept_name
        ORDER BY COUNT(*) DESC, AVG(stock_count) DESC
        LIMIT 10
    """),
    ('volatile', """
        SELECT concept_id, concept_name,
               STDDEV(avg_change_pct) as volatility,
               AVG(avg_change_pct) as avg_change_pct,
               MAX(avg_change_pct) as max_change_pct
        FROM concept_daily_stats
        WHERE trade_date >= %s AND trade_date <= %s
          AND avg_change_pct IS NOT NULL AND stock_count >= %s
          AND concept_type = %s
        GROUP BY concept_id, concept_name
        HAVING COUNT(*) >= 3 AND STDDEV(avg_change_pct) IS NOT NULL
        ORDER BY STDDEV(avg_change_pct) DESC
        LIMIT 10
    """),
    ('momentum', """
        SELECT concept_id, concept_name,
               COUNT(*) as positive_days,
               SUM(avg_change_pct) as total_change,
               AVG(avg_change_pct) as avg_change_pct
        FROM concept_daily_stats
        WHERE trade_date >= %s AND trade_date <= %s
          AND avg_change_pct > 0 AND stock_count >= %s
          AND concept_type = %s
        GROUP BY concept_id, concept_name
        HAVING COUNT(*) >= 2
        ORDER BY COUNT(*) DESC, AVG(avg_change_pct) DESC
        LIMIT 10
    """),
)

# Overall summary; takes (start_date, end_date, concept_type).
_STATISTICS_TOTAL_QUERY = """
    SELECT COUNT(DISTINCT concept_id) as total_concepts,
           COUNT(DISTINCT CASE WHEN avg_change_pct > 0 THEN concept_id END) as positive_concepts,
           COUNT(DISTINCT CASE WHEN avg_change_pct < 0 THEN concept_id END) as negative_concepts,
           AVG(avg_change_pct) as overall_avg_change
    FROM concept_daily_stats
    WHERE trade_date >= %s AND trade_date <= %s AND avg_change_pct IS NOT NULL
      AND concept_type = %s
"""


def _resolve_statistics_range(days, start_date, end_date):
    """Turn the /statistics query parameters into an inclusive (start, end) window.

    ``days=N`` means the N calendar days ending today, today included. If only
    one bound is given, the other is filled in: end defaults to today, and
    start defaults to a _STATISTICS_DEFAULT_WINDOW_DAYS window ending at end.

    Raises:
        HTTPException: 400 if ``days`` is combined with explicit dates, if
            start is after end, or if the span exceeds _STATISTICS_MAX_SPAN_DAYS.
    """
    from datetime import timedelta

    if days is not None and (start_date is not None or end_date is not None):
        raise HTTPException(status_code=400, detail="days参数与start_date/end_date参数不能同时使用")

    today = datetime.now().date()
    if days is not None:
        # An inclusive N-day window differs by N-1 days between its bounds.
        end_date = today
        start_date = end_date - timedelta(days=days - 1)
    else:
        if end_date is None:
            end_date = today
        if start_date is None:
            start_date = end_date - timedelta(days=_STATISTICS_DEFAULT_WINDOW_DAYS - 1)

    # Validate every resolved window, including the "start_date only" case.
    if start_date > end_date:
        raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")
    if (end_date - start_date).days > _STATISTICS_MAX_SPAN_DAYS:
        raise HTTPException(status_code=400, detail="日期范围不能超过90天")
    return start_date, end_date


def _round2_or_zero(value):
    """Round a numeric SQL value to 2 decimals; NULL or 0 become 0.0."""
    return round(float(value), 2) if value else 0.0


def _int_or_zero(value):
    """Convert a numeric SQL value to int; NULL or 0 become 0."""
    return int(value) if value else 0


def _build_stat_items(rows, item_type):
    """Convert aggregated rows for one leaderboard into ConceptStatItem objects."""
    items = []
    for row in rows:
        concept_name = row['concept_name']
        hierarchy_info = get_concept_hierarchy(concept_name)
        item = ConceptStatItem(
            name=concept_name,
            concept_id=row.get('concept_id'),
            hierarchy=HierarchyInfo(**hierarchy_info) if hierarchy_info else None
        )

        if item_type in ('hot', 'cold'):
            item.change_pct = _round2_or_zero(row['avg_change_pct'])
            item.stock_count = _int_or_zero(row['avg_stock_count'])
            # NOTE(review): news_count carries the number of days with data, not news.
            item.news_count = _int_or_zero(row['trading_days'])
        elif item_type == 'active':
            item.news_count = _int_or_zero(row['trading_days'])
            item.stock_count = _int_or_zero(row['avg_stock_count'])
            # NOTE(review): heuristic (30% of days with data, minimum 1), not a real report count.
            item.report_count = max(1, int(row['trading_days'] * 0.3)) if row['trading_days'] else 1
            item.total_mentions = item.news_count + item.report_count
        elif item_type == 'volatile':
            item.volatility = _round2_or_zero(row['volatility'])
            item.avg_change = _round2_or_zero(row['avg_change_pct'])
            item.max_change = _round2_or_zero(row['max_change_pct'])
        elif item_type == 'momentum':
            # NOTE(review): this counts positive days in the window; they need not be consecutive.
            item.consecutive_days = _int_or_zero(row['positive_days'])
            item.total_change = _round2_or_zero(row['total_change'])
            item.avg_daily = _round2_or_zero(row['avg_change_pct'])

        items.append(item)
    return items


def _build_statistics_summary(total_row, start_date, end_date):
    """Build the summary dict from the overall-totals row and the date window."""
    total_row = total_row or {}
    return {
        'total_concepts': _int_or_zero(total_row.get('total_concepts')),
        'positive_count': _int_or_zero(total_row.get('positive_concepts')),
        'negative_count': _int_or_zero(total_row.get('negative_concepts')),
        'avg_change': _round2_or_zero(total_row.get('overall_avg_change')),
        'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'date_range': f"{start_date} 至 {end_date}",
        'days': (end_date - start_date).days + 1,
        'start_date': str(start_date),
        'end_date': str(end_date)
    }


@app.get("/statistics", response_model=ConceptStatisticsResponse, tags=["Statistics"])
async def get_concept_statistics(
    days: Optional[int] = Query(None, ge=1, le=90, description="统计天数范围"),
    start_date: Optional[date] = Query(None, description="开始日期"),
    end_date: Optional[date] = Query(None, description="结束日期"),
    min_stock_count: int = Query(3, ge=1, description="最少股票数量过滤"),
    concept_type: Optional[str] = Query(None, description="概念类型过滤: leaf/lv1/lv2/lv3,默认为leaf")
):
    """Return concept leaderboards (gainers, losers, active, volatile, momentum).

    Args:
        days: Length of a window ending today (inclusive). Cannot be combined
            with explicit dates.
        start_date: Window start (inclusive).
        end_date: Window end (inclusive).
        min_stock_count: Daily rows with fewer stocks are ignored by the leaderboards.
        concept_type: leaf/lv1/lv2/lv3. Any other value falls back to leaf.

    Returns:
        ConceptStatisticsResponse. If MySQL is unavailable or a query fails,
        placeholder data is returned with an explanatory ``note``. Empty
        gainer/loser boards are also backfilled with placeholders and flagged
        in ``note``.

    Raises:
        HTTPException: 400 for an invalid date-parameter combination or window.
    """
    start_date, end_date = _resolve_statistics_range(days, start_date, end_date)
    type_filter = concept_type if concept_type in _STATISTICS_CONCEPT_TYPES else 'leaf'
    params = {
        'days': (end_date - start_date).days + 1,
        'min_stock_count': min_stock_count,
        'start_date': str(start_date),
        'end_date': str(end_date),
        'concept_type': type_filter
    }

    if not mysql_pool:
        logger.warning("[Statistics] MySQL 连接不可用,使用示例数据")
        return ConceptStatisticsResponse(
            success=True,
            data=get_fallback_statistics(start_date, end_date),
            params=params,
            note="MySQL 连接不可用,使用示例数据"
        )

    try:
        rows_by_type = {}
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                for item_type, sql in _STATISTICS_RANKING_QUERIES:
                    await cursor.execute(sql, (start_date, end_date, min_stock_count, type_filter))
                    rows_by_type[item_type] = await cursor.fetchall()

                await cursor.execute(_STATISTICS_TOTAL_QUERY, (start_date, end_date, type_filter))
                total_row = await cursor.fetchone()

        statistics = ConceptStatistics(
            hot_concepts=_build_stat_items(rows_by_type['hot'], 'hot'),
            cold_concepts=_build_stat_items(rows_by_type['cold'], 'cold'),
            active_concepts=_build_stat_items(rows_by_type['active'], 'active'),
            volatile_concepts=_build_stat_items(rows_by_type['volatile'], 'volatile'),
            momentum_concepts=_build_stat_items(rows_by_type['momentum'], 'momentum'),
            summary=_build_statistics_summary(total_row, start_date, end_date)
        )

        note = None
        if not statistics.hot_concepts or not statistics.cold_concepts:
            fallback = get_fallback_statistics(start_date, end_date)
            if not statistics.hot_concepts:
                statistics.hot_concepts = fallback.hot_concepts
            if not statistics.cold_concepts:
                statistics.cold_concepts = fallback.cold_concepts
            # Tell callers which parts are placeholders rather than real data.
            note = "涨幅榜/跌幅榜无数据,已使用示例数据补充"

        return ConceptStatisticsResponse(
            success=True,
            data=statistics,
            params=params,
            note=note
        )

    except HTTPException:
        raise
    except Exception as e:
        # Deliberate best effort: log the failure and serve placeholder data.
        logger.exception(f"获取概念统计失败: {e}")
        return ConceptStatisticsResponse(
            success=True,
            data=get_fallback_statistics(start_date, end_date),
            params=params,
            note=f"使用示例数据,原因: {str(e)}"
        )
|
||
|
||
|
||
# ==================== 层级概念涨跌幅接口 ====================
|
||
|
||
class HierarchyPriceItem(BaseModel):
    """Daily change figures for one hierarchy-level concept."""

    concept_id: str
    concept_name: str
    # Hierarchy level: lv1 / lv2 / lv3.
    concept_type: str
    trade_date: date
    # Average percentage change and stock count; None when unavailable.
    avg_change_pct: Optional[float] = None
    stock_count: Optional[int] = None
|
||
|
||
|
||
class HierarchyPriceResponse(BaseModel):
    """Hierarchy-level concept prices for one trade date, grouped by level."""

    trade_date: date
    # One list per hierarchy level.
    lv1_concepts: List[HierarchyPriceItem]
    lv2_concepts: List[HierarchyPriceItem]
    lv3_concepts: List[HierarchyPriceItem]
    total_count: int
|
||
|
||
|
||
class ConceptPriceItem(BaseModel):
    """One concept's change figures for a single trade date."""

    concept_id: str
    concept_name: str
    # Concept kind: leaf / lv1 / lv2 / lv3.
    concept_type: str
    trade_date: date
    # Average percentage change and stock count; None when unavailable.
    avg_change_pct: Optional[float] = None
    stock_count: Optional[int] = None
    # Hierarchy placement; /price/list only looks this up for leaf concepts.
    hierarchy: Optional[HierarchyInfo] = None
|
||
|
||
|
||
class ConceptPriceListResponse(BaseModel):
    """Paginated list of concept change figures for one trade date."""

    trade_date: date
    # Number of matching rows before LIMIT/OFFSET pagination is applied.
    total: int
    # The requested page of concepts.
    concepts: List[ConceptPriceItem]
|
||
|
||
|
||
# Whitelisted ORDER BY expressions for /price/list, keyed by the ``sort_by`` value.
_PRICE_LIST_ORDER_CLAUSES = {
    "change_desc": "avg_change_pct DESC",
    "change_asc": "avg_change_pct ASC",
    "stock_count": "stock_count DESC",
}
_PRICE_LIST_CONCEPT_TYPES = ('leaf', 'lv1', 'lv2', 'lv3')


@app.get("/price/list", response_model=ConceptPriceListResponse, tags=["Price"])
async def get_concept_price_list(
    trade_date: Optional[date] = Query(None, description="交易日期,默认最新"),
    concept_type: Optional[str] = Query(None, description="概念类型: leaf/lv1/lv2/lv3,默认全部"),
    sort_by: str = Query("change_desc", description="排序: change_desc(涨幅降序), change_asc(涨幅升序), stock_count(股票数)"),
    limit: int = Query(100, ge=1, le=1000, description="返回数量"),
    offset: int = Query(0, ge=0, description="偏移量")
):
    """Return a page of concept change figures for one trade date.

    Args:
        trade_date: Date to list. Defaults to the newest date in the table.
        concept_type: Restrict to leaf/lv1/lv2/lv3. Any other value (or None)
            lists every type.
        sort_by: change_desc, change_asc or stock_count. Unknown values fall
            back to change_desc.
        limit: Page size.
        offset: Rows to skip.

    Returns:
        ConceptPriceListResponse. ``total`` counts all matching rows, not just
        the page.

    Raises:
        HTTPException: 503 without a MySQL pool, 404 when the table is empty,
            500 on any other failure.
    """
    if not mysql_pool:
        raise HTTPException(status_code=503, detail="MySQL连接不可用")

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                if trade_date is None:
                    await cursor.execute("SELECT MAX(trade_date) as max_date FROM concept_daily_stats")
                    result = await cursor.fetchone()
                    if not result or not result['max_date']:
                        raise HTTPException(status_code=404, detail="无涨跌幅数据")
                    trade_date = result['max_date']

                where_conditions = ["trade_date = %s"]
                params = [trade_date]
                if concept_type in _PRICE_LIST_CONCEPT_TYPES:
                    where_conditions.append("concept_type = %s")
                    params.append(concept_type)
                where_sql = ' AND '.join(where_conditions)

                # sort_by only reaches the SQL through the whitelist above, so
                # interpolating the clause is safe.
                order_clause = _PRICE_LIST_ORDER_CLAUSES.get(sort_by, _PRICE_LIST_ORDER_CLAUSES["change_desc"])

                await cursor.execute(f"SELECT COUNT(*) as cnt FROM concept_daily_stats WHERE {where_sql}", params)
                count_row = await cursor.fetchone()
                total = count_row['cnt'] if count_row else 0

                # concept_id breaks ties so LIMIT/OFFSET pages stay stable and
                # rows are not repeated or skipped between pages.
                query = f"""
                    SELECT concept_id, concept_name, concept_type, trade_date, avg_change_pct, stock_count
                    FROM concept_daily_stats
                    WHERE {where_sql}
                    ORDER BY {order_clause}, concept_id ASC
                    LIMIT %s OFFSET %s
                """
                await cursor.execute(query, params + [limit, offset])
                rows = await cursor.fetchall()

        concepts = []
        for row in rows:
            concept_name = row['concept_name']
            # Hierarchy lookup only applies to leaf concepts.
            hierarchy = None
            if row['concept_type'] == 'leaf':
                hierarchy_info = get_concept_hierarchy(concept_name)
                if hierarchy_info:
                    hierarchy = HierarchyInfo(**hierarchy_info)

            change = row['avg_change_pct']
            concepts.append(ConceptPriceItem(
                concept_id=row['concept_id'],
                concept_name=concept_name,
                concept_type=row['concept_type'],
                trade_date=row['trade_date'],
                # Compare against None so a flat (0.00%) day stays 0.0, not null.
                avg_change_pct=float(change) if change is not None else None,
                stock_count=row['stock_count'],
                hierarchy=hierarchy
            ))

        return ConceptPriceListResponse(
            trade_date=trade_date,
            total=total,
            concepts=concepts
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取概念涨跌幅列表失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/hierarchy/{hierarchy_id}/price-timeseries", tags=["Hierarchy"])
async def get_hierarchy_price_timeseries(
    hierarchy_id: str,
    start_date: date = Query(..., description="开始日期"),
    end_date: date = Query(..., description="结束日期")
):
    """Return the daily change series of a hierarchy-level (lv1/lv2/lv3) concept.

    Args:
        hierarchy_id: Value matched against ``concept_daily_stats.concept_id``.
        start_date: First trade date to include (inclusive).
        end_date: Last trade date to include (inclusive).

    Returns:
        Dict with the concept's id, name and type, the window, ``data_points``
        and a ``timeseries`` list ordered by trade date ascending.

    Raises:
        HTTPException: 503 without a MySQL pool, 400 if start_date is after
            end_date, 404 if no rows exist, 500 on any other failure.
    """
    if not mysql_pool:
        raise HTTPException(status_code=503, detail="MySQL连接不可用")

    if start_date > end_date:
        raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                query = """
                    SELECT concept_id, concept_name, concept_type, trade_date, avg_change_pct, stock_count
                    FROM concept_daily_stats
                    WHERE concept_id = %s
                      AND trade_date >= %s AND trade_date <= %s
                    ORDER BY trade_date ASC
                """
                await cursor.execute(query, (hierarchy_id, start_date, end_date))
                rows = await cursor.fetchall()

        if not rows:
            raise HTTPException(status_code=404, detail=f"未找到层级概念 {hierarchy_id} 的数据")

        timeseries = []
        concept_name = ""
        concept_type = ""
        for row in rows:
            # Take name and type from the first row that has a name.
            if not concept_name:
                concept_name = row['concept_name']
                concept_type = row['concept_type']

            change = row['avg_change_pct']
            timeseries.append({
                'trade_date': row['trade_date'],
                # Compare against None so a flat (0.00%) day stays 0.0, not null.
                'avg_change_pct': float(change) if change is not None else None,
                'stock_count': row['stock_count']
            })

        return {
            'concept_id': hierarchy_id,
            'concept_name': concept_name,
            'concept_type': concept_type,
            'start_date': start_date,
            'end_date': end_date,
            'data_points': len(timeseries),
            'timeseries': timeseries
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取层级价格时间序列失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# ==================== 主函数 ====================
|
||
|
||
if __name__ == "__main__":
    import uvicorn

    # The app is passed as an import string so that reload=True can re-import it.
    server_options = {
        "host": "0.0.0.0",
        "port": 6801,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("concept_api_v2:app", **server_options)
|