1850 lines
71 KiB
Python
1850 lines
71 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
概念搜索API V4 - 适配 concept_library_v4 索引
|
||
新特性:
|
||
1. 支持股票原因(reason)搜索
|
||
2. 优化的股票信息结构
|
||
3. 保留层级结构(lv1/lv2/lv3)支持
|
||
4. 新增reason搜索接口
|
||
"""
|
||
import json
|
||
import openai
|
||
from typing import List, Dict, Optional, Union, Any
|
||
from fastapi import FastAPI, HTTPException, Query
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import PlainTextResponse
|
||
from pydantic import BaseModel, Field
|
||
from elasticsearch import Elasticsearch
|
||
from datetime import datetime, date
|
||
import logging
|
||
import re
|
||
from contextlib import asynccontextmanager
|
||
import aiomysql
|
||
import os
|
||
|
||
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ==================== Configuration ====================
# Every setting below can be overridden through an environment variable; the
# fallback values are the previously hard-coded ones, so existing deployments
# behave exactly as before.

# Elasticsearch
ES_HOST = os.getenv('CONCEPT_ES_HOST', 'http://127.0.0.1:9200')
# NOTE(review): the module docstring and the FastAPI description say this
# service targets the concept_library_v4 index, but the default here is v3 —
# confirm which index is intended.
INDEX_NAME = os.getenv('CONCEPT_INDEX_NAME', 'concept_library_v3')

# Embedding service (OpenAI-compatible API)
OPENAI_BASE_URL = os.getenv('CONCEPT_OPENAI_BASE_URL', "http://127.0.0.1:8000/v1")
OPENAI_API_KEY = os.getenv('CONCEPT_OPENAI_API_KEY', "dummy")
EMBEDDING_MODEL = os.getenv('CONCEPT_EMBEDDING_MODEL', "qwen3-embedding-8b")

# Concept hierarchy JSON file (joined onto this module's directory when loaded)
HIERARCHY_FILE = os.getenv('CONCEPT_HIERARCHY_FILE', 'concept_hierarchy_v3.json')

# MySQL connection-pool settings, passed as keyword arguments to
# aiomysql.create_pool().
# SECURITY(review): a real database password is committed to source control.
# Rotate it and provide it via CONCEPT_MYSQL_PASSWORD; the literal fallback is
# kept only so current deployments keep working until that is done.
MYSQL_CONFIG = {
    'host': os.getenv('CONCEPT_MYSQL_HOST', '110.42.32.207'),
    'port': int(os.getenv('CONCEPT_MYSQL_PORT', '3306')),
    'user': os.getenv('CONCEPT_MYSQL_USER', 'root'),
    'password': os.getenv('CONCEPT_MYSQL_PASSWORD', 'Zzl33818!'),
    'db': os.getenv('CONCEPT_MYSQL_DB', 'stock'),
    'charset': 'utf8mb4',
    'autocommit': True,
    'minsize': 1,
    'maxsize': 10
}

# ==================== Global state ====================
# Shared clients; created in the application lifespan and None until then.
es_client = None
openai_client = None
mysql_pool = None

# Hierarchy data: the raw JSON document and the flattened
# concept name -> {lv1, lv1_id, lv2, lv2_id, lv3, lv3_id} index.
hierarchy_data = {}
concept_to_hierarchy = {}
|
||
|
||
|
||
def _build_concept_hierarchy_map(data: Dict) -> Dict[str, Dict[str, Optional[str]]]:
    """Flatten the nested hierarchy document into a concept -> levels mapping.

    ``data['hierarchy']`` is a list of lv1 nodes. Each lv1 node has
    ``children`` (lv2 nodes). An lv2 node either has its own ``children``
    (lv3 nodes that carry ``concepts``) or carries ``concepts`` directly; in
    the latter case ``lv3``/``lv3_id`` are None. Missing or null lists are
    treated as empty.
    """
    mapping: Dict[str, Dict[str, Optional[str]]] = {}
    for lv1 in data.get('hierarchy') or []:
        lv1_name = lv1.get('lv1', '')
        lv1_id = lv1.get('lv1_id', '')

        for child in lv1.get('children') or []:
            lv2_name = child.get('lv2', '')
            lv2_id = child.get('lv2_id', '')

            if 'children' in child:
                for lv3_child in child.get('children') or []:
                    lv3_name = lv3_child.get('lv3', '')
                    lv3_id = lv3_child.get('lv3_id', '')
                    for concept in lv3_child.get('concepts') or []:
                        mapping[concept] = {
                            'lv1': lv1_name, 'lv1_id': lv1_id,
                            'lv2': lv2_name, 'lv2_id': lv2_id,
                            'lv3': lv3_name, 'lv3_id': lv3_id
                        }
            else:
                for concept in child.get('concepts') or []:
                    mapping[concept] = {
                        'lv1': lv1_name, 'lv1_id': lv1_id,
                        'lv2': lv2_name, 'lv2_id': lv2_id,
                        'lv3': None, 'lv3_id': None
                    }
    return mapping


def load_hierarchy():
    """Load the hierarchy JSON file and rebuild the concept -> hierarchy index.

    The file is HIERARCHY_FILE, joined onto this module's directory. When it
    is missing a warning is logged and nothing changes. The new index is
    published only after the whole file has been parsed and flattened, so a
    malformed file (logged as an error) leaves the previous state intact
    instead of a half-built index.
    """
    global hierarchy_data, concept_to_hierarchy

    hierarchy_path = os.path.join(os.path.dirname(__file__), HIERARCHY_FILE)
    if not os.path.exists(hierarchy_path):
        logger.warning(f"层级文件不存在: {hierarchy_path}")
        return

    try:
        with open(hierarchy_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        mapping = _build_concept_hierarchy_map(data)
    except Exception as e:
        logger.error(f"加载层级结构失败: {e}")
        return

    hierarchy_data = data
    # Update in place so any code holding a reference to the dict sees the
    # refreshed contents; clearing first drops concepts removed from the file.
    concept_to_hierarchy.clear()
    concept_to_hierarchy.update(mapping)

    logger.info(f"加载层级结构完成,共 {len(concept_to_hierarchy)} 个概念有层级信息")
|
||
|
||
|
||
def get_concept_hierarchy(concept_name: str) -> Optional[Dict]:
    """Return the hierarchy dict (lv1/lv1_id/lv2/lv2_id/lv3/lv3_id) for a concept.

    Returns None when the concept has no entry in the loaded hierarchy.
    """
    try:
        return concept_to_hierarchy[concept_name]
    except KeyError:
        return None
|
||
|
||
|
||
# ==================== 生命周期管理 ====================
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared clients at startup and release them at shutdown.

    Startup loads the concept hierarchy, creates the Elasticsearch client and
    the OpenAI-compatible embedding client, then tries to open the MySQL pool.
    A MySQL failure is logged and ``mysql_pool`` stays None. Shutdown runs in
    ``finally`` so the clients are closed even when the application leaves
    the ``yield`` with an exception.
    """
    global es_client, openai_client, mysql_pool

    load_hierarchy()

    # Constructing the client does not contact the cluster, so the log line
    # below only records the configured target.
    es_client = Elasticsearch(
        [ES_HOST],
        timeout=30,
        max_retries=3,
        retry_on_timeout=True
    )
    logger.info(f"Connected to Elasticsearch at {ES_HOST}, index: {INDEX_NAME}")

    openai_client = openai.OpenAI(
        api_key=OPENAI_API_KEY,
        base_url=OPENAI_BASE_URL,
        timeout=60,
    )
    logger.info("Initialized OpenAI client")

    try:
        mysql_pool = await aiomysql.create_pool(**MYSQL_CONFIG)
        logger.info(f"Connected to MySQL at {MYSQL_CONFIG['host']}")
    except Exception as e:
        # Non-fatal: helpers check mysql_pool and skip price data when None.
        logger.error(f"Failed to connect to MySQL: {e}")

    try:
        yield
    finally:
        if es_client is not None:
            es_client.close()
        if openai_client is not None:
            openai_client.close()
        if mysql_pool is not None:
            mysql_pool.close()
            await mysql_pool.wait_closed()
        logger.info("Cleanup completed")
|
||
|
||
|
||
# ==================== FastAPI 应用 ====================
|
||
|
||
# The FastAPI application; shared clients are managed by `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="概念搜索API V4",
    version="4.0.0",
    description="支持股票原因搜索的概念库API,适配 concept_library_v4 索引",
)

# CORS: every origin, method and header is allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# any site make credentialed cross-origin requests (Starlette can echo the
# caller's origin in that case). Restrict allow_origins if the API ever relies
# on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
||
|
||
|
||
# ==================== 微信小程序验证 ====================
|
||
|
||
@app.get("/DfASFmNQoo.txt", response_class=PlainTextResponse)
async def wechat_verification():
    # Serves the WeChat mini-program domain-verification file: a fixed token
    # returned as plain text at the path WeChat probes.
    verification_token = "ebd78eb22819b1393a34c6ae1e8fcce6"
    return verification_token
|
||
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
class HierarchyInfo(BaseModel):
    # Position of a concept in the lv1 -> lv2 -> lv3 hierarchy (names and ids).
    # Built from the per-concept dicts produced by load_hierarchy(); lv3 and
    # lv3_id are None for concepts attached directly to a level-2 node.
    lv1: Optional[str] = None
    lv1_id: Optional[str] = None
    lv2: Optional[str] = None
    lv2_id: Optional[str] = None
    lv3: Optional[str] = None
    lv3_id: Optional[str] = None
|
||
|
||
|
||
class StockInfo(BaseModel):
    """股票信息(包含原因)"""
    # A stock linked to a concept ("stock info, including reason"). The
    # docstring stays verbatim because pydantic/FastAPI publish it as the
    # schema description.
    name: str
    # Stock code; may be absent in the source document.
    code: Optional[str] = None
    # Why the stock belongs to the concept. None unless the caller asked for
    # reasons; in /search/reason results it may contain ES highlight markup.
    reason: Optional[str] = None
|
||
|
||
|
||
class ConceptPriceInfo(BaseModel):
    # One day's price snapshot for a concept, taken from the concept_daily_stats
    # table: the trading day and its avg_change_pct (None when the column is NULL).
    trade_date: date
    avg_change_pct: Optional[float] = None
|
||
|
||
|
||
class SearchRequest(BaseModel):
    # Body of POST /search.
    # Free-text query. An empty query with sort_by="change_pct" and no
    # filter_stocks is served by a MySQL fast path instead of Elasticsearch.
    query: str
    # Page size and 1-based page number.
    size: int = Field(10, ge=1, le=100)
    page: int = Field(1, ge=1)
    # Number of ES hits fetched before in-memory filtering, sorting and paging.
    search_size: int = Field(100, ge=10, le=1000)
    # Weight of the kNN (semantic) part; None derives it from the query length.
    semantic_weight: Optional[float] = Field(None, ge=0.0, le=1.0)
    # Stock codes or names that a concept must be linked to.
    filter_stocks: Optional[List[str]] = None
    # Hierarchy filters, matched against lv1/lv2 names.
    filter_lv1: Optional[str] = None
    filter_lv2: Optional[str] = None
    # Date for price info; None uses the latest available data.
    trade_date: Optional[date] = None
    # "_score" (default), "change_pct", "stock_count" or "outbreak_date".
    sort_by: str = Field("_score")
    # Allow the hybrid kNN query when an embedding is available.
    use_knn: bool = Field(True)
    search_reason: bool = Field(False, description="是否同时搜索股票原因")
    include_stock_reasons: bool = Field(False, description="是否在返回的股票列表中包含原因")
|
||
|
||
|
||
class ReasonSearchRequest(BaseModel):
    """股票原因搜索请求"""
    # Body of POST /search/reason ("stock reason search request"); the
    # docstring is kept verbatim because it is the published schema description.
    query: str = Field(..., description="搜索关键词")
    # Page size and 1-based page number.
    size: int = Field(20, ge=1, le=100)
    page: int = Field(1, ge=1)
    # When True, each result lists the individual stocks whose reason matched.
    include_stock_details: bool = Field(True, description="是否返回匹配的股票详情")
|
||
|
||
|
||
class ConceptResult(BaseModel):
    # One concept in a /search response.
    concept_id: str
    concept: str
    description: Optional[str] = None
    tags: Optional[List[str]] = None
    outbreak_dates: Optional[List[str]] = None
    stocks: List[StockInfo] = Field(default_factory=list, description="股票列表(是否包含原因取决于include_stock_reasons参数)")
    stock_count: int = 0
    hierarchy: Optional[HierarchyInfo] = None
    # ES relevance score (0.0 on the MySQL fast path).
    score: float = 0.0
    # "keyword", "hybrid_knn" or "mysql_optimized", depending on how it was found.
    match_type: str = "keyword"
    # ES highlight fragments keyed by field name.
    highlights: Optional[Dict[str, List[str]]] = None
    price_info: Optional[ConceptPriceInfo] = None
    matched_stocks: Optional[List[StockInfo]] = Field(None, description="匹配的股票(含原因,仅在search_reason时返回)")
|
||
|
||
|
||
class SearchResponse(BaseModel):
    # Response of POST /search.
    # Overall match count as reported by the backend that served the query.
    total: int
    took_ms: int
    # The requested page of results.
    results: List[ConceptResult]
    # Echo of the effective search parameters (query, weights, match type, ...).
    search_info: Dict[str, Any]
    # Trading day the price info refers to, when any was found.
    price_date: Optional[date] = None
    page: int = 1
    total_pages: int = 1
|
||
|
||
|
||
class ReasonSearchResult(BaseModel):
    """原因搜索结果"""
    # One concept matched by /search/reason ("reason search result").
    concept_id: str
    concept: str
    description: Optional[str] = None
    score: float
    # Stocks whose individual reason matched (filled only when details were requested).
    matched_stocks: List[StockInfo] = Field(default_factory=list, description="匹配的股票及原因")
    # Highlight fragments from the aggregated reason text.
    reason_highlights: Optional[List[str]] = None
|
||
|
||
|
||
class ReasonSearchResponse(BaseModel):
    """原因搜索响应"""
    # Response of POST /search/reason ("reason search response").
    total: int
    took_ms: int
    # The query that was searched, echoed back.
    query: str
    results: List[ReasonSearchResult]
    page: int = 1
    total_pages: int = 1
|
||
|
||
|
||
class ConceptDetailResponse(BaseModel):
    # Response of GET /concept/{concept_id}: the full concept document.
    concept_id: str
    concept: str
    description: Optional[str] = None
    insight: Optional[str] = None
    tags: Optional[List[str]] = None
    outbreak_dates: Optional[List[str]] = None
    # All linked stocks, each with its reason (no count limit).
    stocks: List[StockInfo] = Field(default_factory=list, description="股票列表(含原因)")
    stock_count: int = 0
    hierarchy: Optional[HierarchyInfo] = None
    folders: Optional[List[str]] = None
    # Copied as-is from the ES document's created_at field.
    created_at: Optional[str] = None
    price_info: Optional[ConceptPriceInfo] = None
|
||
|
||
|
||
class HierarchyLevel(BaseModel):
    # One node of the hierarchy tree, presumably used by a hierarchy endpoint
    # outside this view. Inner nodes carry `children`; nodes holding concepts
    # carry `concepts` — TODO confirm against the endpoint that builds it.
    # NOTE(review): `children` is a self-reference; under pydantic v1 this needs
    # HierarchyLevel.update_forward_refs() — confirm it is called elsewhere or
    # that pydantic v2 is in use.
    id: str
    name: str
    concept_count: int = 0
    children: Optional[List['HierarchyLevel']] = None
    concepts: Optional[List[str]] = None
|
||
|
||
|
||
class HierarchyResponse(BaseModel):
    # The whole hierarchy tree plus the number of concepts it covers
    # (presumably returned by an endpoint outside this view).
    hierarchy: List[HierarchyLevel]
    total_concepts: int
|
||
|
||
|
||
class PriceTimeSeriesItem(BaseModel):
    # One point of a concept's price time series: the day, its avg_change_pct
    # and stock_count (presumably read from concept_daily_stats — confirm).
    trade_date: date
    avg_change_pct: Optional[float] = None
    stock_count: Optional[int] = None
|
||
|
||
|
||
class PriceTimeSeriesResponse(BaseModel):
    # A concept's price series over [start_date, end_date]; data_points
    # presumably equals len(timeseries) — confirm in the producing endpoint.
    concept_id: str
    concept_name: str
    start_date: date
    end_date: date
    data_points: int
    timeseries: List[PriceTimeSeriesItem]
|
||
|
||
|
||
class ConceptStatItem(BaseModel):
    # One concept entry in the statistics lists. Which optional metrics are set
    # depends on the list (see get_fallback_statistics for the pattern):
    # hot/cold -> change_pct, stock_count, news_count;
    # active -> news_count, report_count, total_mentions;
    # volatile -> volatility, avg_change, max_change;
    # momentum -> consecutive_days, total_change, avg_daily.
    name: str
    concept_id: Optional[str] = None
    change_pct: Optional[float] = None
    stock_count: Optional[int] = None
    news_count: Optional[int] = None
    report_count: Optional[int] = None
    total_mentions: Optional[int] = None
    volatility: Optional[float] = None
    avg_change: Optional[float] = None
    max_change: Optional[float] = None
    consecutive_days: Optional[int] = None
    total_change: Optional[float] = None
    avg_daily: Optional[float] = None
    hierarchy: Optional[HierarchyInfo] = None
|
||
|
||
|
||
class ConceptStatistics(BaseModel):
    # Ranked concept lists by category plus a free-form summary dict
    # (totals, positive/negative counts, date range, ...).
    hot_concepts: List[ConceptStatItem]
    cold_concepts: List[ConceptStatItem]
    active_concepts: List[ConceptStatItem]
    volatile_concepts: List[ConceptStatItem]
    momentum_concepts: List[ConceptStatItem]
    summary: Dict[str, Any]
|
||
|
||
|
||
class ConceptStatisticsResponse(BaseModel):
    # Envelope for the statistics endpoint: the data, the request parameters
    # echoed back, and an optional note (presumably set when fallback/sample
    # data is returned — confirm in the endpoint).
    success: bool
    data: ConceptStatistics
    params: Dict[str, Any]
    note: Optional[str] = None
|
||
|
||
|
||
class ConceptPriceItem(BaseModel):
    # One concept's price row for a given day. concept_type mirrors the
    # concept_daily_stats column (queries in this file filter on 'leaf';
    # other values are not visible here).
    concept_id: str
    concept_name: str
    concept_type: str
    trade_date: date
    avg_change_pct: Optional[float] = None
    stock_count: Optional[int] = None
    hierarchy: Optional[HierarchyInfo] = None
|
||
|
||
|
||
class ConceptPriceListResponse(BaseModel):
    # All concept price rows for one trading day, with their count.
    trade_date: date
    total: int
    concepts: List[ConceptPriceItem]
|
||
|
||
|
||
class StockSearchResult(BaseModel):
    # A stock found by stock search, with the number of concepts it belongs
    # to (presumably produced by an endpoint outside this view).
    stock_code: str
    stock_name: str
    concept_count: int
|
||
|
||
|
||
class StockSearchResponse(BaseModel):
    # Stock search response: match count and the matched stocks.
    total: int
    stocks: List[StockSearchResult]
|
||
|
||
|
||
# ==================== 辅助函数 ====================
|
||
|
||
def generate_embedding(text: str) -> List[float]:
    """Embed ``text`` with the configured embedding model.

    The input is clipped to its first 8000 characters. Returns an empty list
    when the text is empty or whitespace-only, when the embedding client has
    not been created, or when the embedding call fails (logged as a warning).
    """
    try:
        if not text or not text.strip():
            return []
        clipped = text[:8000]
        client = openai_client
        if not client:
            return []
        reply = client.embeddings.create(model=EMBEDDING_MODEL, input=[clipped])
        first_item = reply.data[0]
        return first_item.embedding
    except Exception as e:
        logger.warning(f"Embedding生成失败: {e}")
        return []
|
||
|
||
|
||
def calculate_semantic_weight(query: str) -> float:
    """Pick the semantic (kNN) weight from the stripped query length.

    Empty or whitespace-only queries get 0.0. Otherwise the weight is 0.3
    below 10 characters, 0.5 below 50, 0.6 below 200 and 0.7 from there on.
    Longer queries lean more on semantic matching.
    """
    stripped = (query or "").strip()
    if not stripped:
        return 0.0
    length = len(stripped)
    for upper_bound, weight in ((10, 0.3), (50, 0.5), (200, 0.6)):
        if length < upper_bound:
            return weight
    return 0.7
|
||
|
||
|
||
async def get_concept_price_data(concept_ids: List[str], trade_date: Optional[date] = None) -> Dict[str, ConceptPriceInfo]:
    """Look up avg_change_pct for the given concepts in concept_daily_stats.

    With ``trade_date`` the row for exactly that day is used. Without it each
    concept's own most recent row is used, so dates can differ between
    concepts. Returns concept_id -> ConceptPriceInfo; concepts without a row
    are absent. Returns {} when the pool is unavailable, the id list is empty,
    or the query fails (the error is logged).
    """
    if not mysql_pool or not concept_ids:
        return {}
    try:
        marks = ','.join(['%s'] * len(concept_ids))
        if trade_date:
            sql = f"""
                SELECT concept_id, trade_date, avg_change_pct
                FROM concept_daily_stats
                WHERE concept_id IN ({marks}) AND trade_date = %s
            """
            args = (*concept_ids, trade_date)
        else:
            sql = f"""
                SELECT cds.concept_id, cds.trade_date, cds.avg_change_pct
                FROM concept_daily_stats cds
                INNER JOIN (
                    SELECT concept_id, MAX(trade_date) as max_date
                    FROM concept_daily_stats
                    WHERE concept_id IN ({marks})
                    GROUP BY concept_id
                ) latest ON cds.concept_id = latest.concept_id AND cds.trade_date = latest.max_date
            """
            args = concept_ids

        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql, args)
                rows = await cursor.fetchall()

        return {
            row['concept_id']: ConceptPriceInfo(
                trade_date=row['trade_date'],
                avg_change_pct=None if row['avg_change_pct'] is None else float(row['avg_change_pct'])
            )
            for row in rows
        }
    except Exception as e:
        logger.error(f"获取涨跌幅数据失败: {e}")
        return {}
|
||
|
||
|
||
async def get_top_concepts_by_change(
    trade_date: Optional[date],
    limit: int = 100,
    offset: int = 0,
    filter_lv1: Optional[str] = None,
    filter_lv2: Optional[str] = None
) -> tuple:
    """Page through leaf concepts ranked by avg_change_pct, straight from MySQL.

    This is the fast path for empty /search queries sorted by change.

    Args:
        trade_date: Day to rank. None means the latest day that has leaf rows
            in concept_daily_stats.
        limit: Page size applied in SQL.
        offset: Number of rows to skip, applied in SQL.
        filter_lv1: Optional level-1 hierarchy name; only concepts whose
            loaded hierarchy has this lv1 are considered.
        filter_lv2: Optional level-2 hierarchy name, applied the same way.

    Returns:
        ``(concepts, total, query_date)``.
        ``concepts`` is a list of dicts with concept_id, concept_name,
        avg_change_pct (float or None), stock_count and trade_date.
        ``total`` is the number of matching rows before paging.
        ``query_date`` is the day actually queried. It is None when no date
        could be determined.
        On any error the failure is logged and ``([], 0, None)`` is returned.
    """
    if not mysql_pool:
        return [], 0, None

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                query_date = trade_date
                if query_date is None:
                    await cursor.execute("SELECT MAX(trade_date) as max_date FROM concept_daily_stats WHERE concept_type = 'leaf'")
                    latest = await cursor.fetchone()
                    if not latest or not latest['max_date']:
                        return [], 0, None
                    query_date = latest['max_date']

                where_conditions = ["trade_date = %s", "concept_type = 'leaf'"]
                params = [query_date]

                if filter_lv1 or filter_lv2:
                    # The hierarchy lives in the JSON file, not in MySQL, so the
                    # level filter becomes an explicit concept-name IN list.
                    filtered_concepts = [
                        concept_name
                        for concept_name, hierarchy in concept_to_hierarchy.items()
                        if (not filter_lv1 or hierarchy.get('lv1') == filter_lv1)
                        and (not filter_lv2 or hierarchy.get('lv2') == filter_lv2)
                    ]
                    if not filtered_concepts:
                        return [], 0, query_date

                    placeholders = ','.join(['%s'] * len(filtered_concepts))
                    where_conditions.append(f"concept_name IN ({placeholders})")
                    params.extend(filtered_concepts)

                where_clause = " AND ".join(where_conditions)

                count_query = f"SELECT COUNT(*) as cnt FROM concept_daily_stats WHERE {where_clause}"
                await cursor.execute(count_query, params)
                total = (await cursor.fetchone())['cnt']

                # concept_id breaks ties so LIMIT/OFFSET pages stay stable and
                # no concept is skipped or repeated across pages.
                data_query = f"""
                    SELECT concept_id, concept_name, avg_change_pct, stock_count
                    FROM concept_daily_stats
                    WHERE {where_clause}
                    ORDER BY avg_change_pct DESC, concept_id ASC
                    LIMIT %s OFFSET %s
                """
                await cursor.execute(data_query, params + [limit, offset])
                rows = await cursor.fetchall()

        concepts = [
            {
                'concept_id': row['concept_id'],
                'concept_name': row['concept_name'],
                # Compare with None: a 0.00 change is a real value and must
                # not be reported as missing.
                'avg_change_pct': float(row['avg_change_pct']) if row['avg_change_pct'] is not None else None,
                'stock_count': row['stock_count'],
                'trade_date': query_date
            }
            for row in rows
        ]
        return concepts, total, query_date

    except Exception as e:
        logger.error(f"获取涨跌幅排序概念失败: {e}")
        return [], 0, None
|
||
|
||
|
||
async def get_latest_trade_date() -> Optional[date]:
    """Return the most recent trade_date in concept_daily_stats.

    Returns None when the pool is unavailable, the table has no dates, or the
    query fails (the error is logged).
    """
    if not mysql_pool:
        return None
    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("SELECT MAX(trade_date) as latest_date FROM concept_daily_stats")
                row = await cursor.fetchone()
        if row and row[0]:
            return row[0]
        return None
    except Exception as e:
        logger.error(f"获取最新交易日期失败: {e}")
        return None
|
||
|
||
|
||
def get_fallback_statistics(start_date, end_date) -> ConceptStatistics:
    """Return fixed sample statistics to use as placeholder data.

    None of the figures come from a data source; only the summary's date
    fields depend on the arguments. ``days`` is the inclusive span when both
    dates are given, otherwise 7. Missing dates render as ''.
    """
    hot = [
        ConceptStatItem(name="人工智能", change_pct=8.76, stock_count=45, news_count=12),
        ConceptStatItem(name="机器人", change_pct=7.54, stock_count=38, news_count=10),
        ConceptStatItem(name="半导体", change_pct=6.43, stock_count=52, news_count=15),
    ]
    cold = [
        ConceptStatItem(name="房地产", change_pct=-5.76, stock_count=33, news_count=5),
        ConceptStatItem(name="煤炭", change_pct=-4.32, stock_count=25, news_count=3),
    ]
    active = [ConceptStatItem(name="人工智能", news_count=45, report_count=15, total_mentions=60)]
    volatile = [ConceptStatItem(name="数字货币", volatility=25.6, avg_change=2.1, max_change=15.2)]
    momentum = [ConceptStatItem(name="AI应用", consecutive_days=6, total_change=19.2, avg_daily=3.2)]

    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    if start_date and end_date:
        span_days = (end_date - start_date).days + 1
    else:
        span_days = 7

    summary = {
        'total_concepts': 500,
        'positive_count': 320,
        'negative_count': 180,
        'avg_change': 1.8,
        'update_time': stamp,
        'date_range': f"{start_date} 至 {end_date}",
        'days': span_days,
        'start_date': str(start_date) if start_date else '',
        'end_date': str(end_date) if end_date else ''
    }

    return ConceptStatistics(
        hot_concepts=hot,
        cold_concepts=cold,
        active_concepts=active,
        volatile_concepts=volatile,
        momentum_concepts=momentum,
        summary=summary
    )
|
||
|
||
|
||
def build_keyword_query(query: str, stock_filters: List[str] = None, search_reason: bool = False) -> Dict:
    """Build the Elasticsearch bool query for keyword matching (index v4 layout).

    For a non-blank query, at least one of these text clauses must match:
    concept/description/tags via multi_match, stock names, and (when
    ``search_reason``) the aggregated stock reason text.
    ``stock_filters``, when given, additionally requires that the concept
    lists one of the values as a stock code or a stock name.
    With neither text nor filters, the result is match_all.
    """
    text_clauses: List[Dict] = []
    if query.strip():
        text_clauses.append({
            "multi_match": {
                "query": query,
                "fields": ["concept^3", "description^2", "tags.text^2", "tags^1.5"],
                "type": "best_fields",
                "analyzer": "ik_smart"
            }
        })
        text_clauses.append({"match": {"stock_names.text": {"query": query, "boost": 1.5}}})
        if search_reason:
            text_clauses.append({"match": {"stock_reasons_text": {"query": query, "boost": 1.0}}})

    required: List[Dict] = []
    if stock_filters:
        required.append({
            "bool": {
                "should": [
                    {"terms": {"stock_codes": stock_filters}},
                    {"terms": {"stock_names": stock_filters}}
                ]
            }
        })
    if text_clauses:
        required.append({"bool": {"should": text_clauses, "minimum_should_match": 1}})

    return {"bool": {"must": required or [{"match_all": {}}]}}
|
||
|
||
|
||
def parse_stocks_from_es(source: Dict, limit: int = 30, include_reasons: bool = False) -> List[StockInfo]:
    """Build StockInfo objects from the ``stocks`` array of an ES document.

    Args:
        source: ES ``_source`` dict. Its ``stocks`` entries are dicts with
            ``name``/``code``/``reason`` keys; a missing or null ``stocks``
            is treated as no stocks.
        limit: Maximum number of stocks returned (the first ``limit`` in
            document order); None returns all of them.
        include_reasons: When False, ``reason`` is left as None.

    Returns:
        List of StockInfo. A null stock name becomes '' so validation of the
        required ``name`` field does not fail.
    """
    stocks_data = source.get('stocks') or []
    return [
        StockInfo(
            name=s.get('name') or '',
            code=s.get('code'),
            reason=s.get('reason') if include_reasons else None
        )
        for s in stocks_data[:limit]
    ]
|
||
|
||
|
||
def parse_stocks_full(source: Dict) -> List[StockInfo]:
    """Build every StockInfo (with reason) from an ES document's ``stocks`` array.

    Unlike the list views there is no count limit. A missing or null
    ``stocks`` gives an empty list, and a null stock name becomes ''.
    """
    stocks_data = source.get('stocks') or []
    return [
        StockInfo(
            name=s.get('name') or '',
            code=s.get('code'),
            reason=s.get('reason')
        )
        for s in stocks_data
    ]
|
||
|
||
|
||
# ==================== API 端点 ====================
|
||
|
||
@app.get("/", tags=["Health"])
async def root():
    # Health check: reports service identity, the configured index and how
    # many concepts have hierarchy info loaded. It does not contact ES or MySQL.
    features = ["reason_search", "stock_search", "semantic_search"]
    payload = {
        "status": "healthy",
        "service": "概念搜索API V4",
        "version": "4.0.0",
        "index": INDEX_NAME,
        "hierarchy_concepts": len(concept_to_hierarchy),
        "features": features,
    }
    return payload
|
||
|
||
|
||
def _elapsed_ms(start_time: datetime) -> int:
    """Whole milliseconds elapsed since ``start_time``."""
    return int((datetime.now() - start_time).total_seconds() * 1000)


async def _search_by_change_fast_path(request: SearchRequest, start_time: datetime) -> SearchResponse:
    """Serve an empty query sorted by change_pct: rank and page in MySQL, enrich from ES."""
    logger.info(f"[Search优化] 使用 MySQL 快速路径: page={request.page}, size={request.size}")

    offset = (request.page - 1) * request.size
    top_concepts, total, actual_date = await get_top_concepts_by_change(
        trade_date=request.trade_date,
        limit=request.size,
        offset=offset,
        filter_lv1=request.filter_lv1,
        filter_lv2=request.filter_lv2
    )

    if not top_concepts:
        # Report the real total: a page past the end still has matches
        # elsewhere. When nothing matched, total is 0 and total_pages is 0.
        return SearchResponse(
            total=total, took_ms=_elapsed_ms(start_time), results=[],
            search_info={"query": "", "semantic_weight": 0, "match_type": "mysql_optimized"},
            price_date=actual_date, page=request.page,
            total_pages=(total + request.size - 1) // request.size
        )

    # MySQL decides order and paging; ES only supplies the descriptive fields.
    concept_ids = [c['concept_id'] for c in top_concepts]
    es_body = {
        "query": {"terms": {"concept_id": concept_ids}},
        "size": len(concept_ids),
        "_source": {"excludes": ["description_embedding", "insight", "stock_reasons_text"]}
    }
    es_response = es_client.search(index=INDEX_NAME, body=es_body, timeout="10s")
    es_details = {
        hit['_source'].get('concept_id', ''): hit['_source']
        for hit in es_response['hits']['hits']
    }

    results = []
    for mysql_concept in top_concepts:
        cid = mysql_concept['concept_id']
        es_data = es_details.get(cid, {})
        concept_name = mysql_concept['concept_name']
        hierarchy_info = get_concept_hierarchy(concept_name)

        results.append(ConceptResult(
            concept_id=cid,
            concept=concept_name,
            description=es_data.get('description'),
            tags=es_data.get('tags', []),
            outbreak_dates=es_data.get('outbreak_dates', []),
            stocks=parse_stocks_from_es(es_data, include_reasons=request.include_stock_reasons),
            stock_count=mysql_concept['stock_count'] or len(es_data.get('stocks') or []),
            hierarchy=HierarchyInfo(**hierarchy_info) if hierarchy_info else None,
            score=0.0,
            match_type="mysql_optimized",
            highlights=None,
            price_info=ConceptPriceInfo(
                trade_date=actual_date,
                avg_change_pct=mysql_concept['avg_change_pct']
            )
        ))

    return SearchResponse(
        total=total, took_ms=_elapsed_ms(start_time), results=results,
        search_info={
            "query": "", "semantic_weight": 0, "match_type": "mysql_optimized",
            "filter_lv1": request.filter_lv1, "filter_lv2": request.filter_lv2,
            "sort_by": request.sort_by
        },
        price_date=actual_date, page=request.page,
        total_pages=max(1, (total + request.size - 1) // request.size)
    )


def _build_concept_search_body(
    request: SearchRequest,
    semantic_weight: float,
    embedding: List[float],
    effective_search_size: int
) -> tuple:
    """Assemble the ES request body for /search.

    Returns ``(search_body, match_type)``. The body is a hybrid kNN plus
    keyword query when a semantic weight and an embedding are available and
    ``use_knn`` is set; otherwise it is a pure keyword query.
    """
    keyword_query = build_keyword_query(request.query, request.filter_stocks, request.search_reason)

    if semantic_weight != 0 and request.use_knn and embedding:
        knn = {
            "field": "description_embedding",
            "query_vector": embedding,
            "k": effective_search_size,
            "num_candidates": min(effective_search_size * 2, 10000),
            "boost": semantic_weight
        }
        if request.filter_stocks:
            # kNN hits are OR-ed with the keyword query. Without this filter a
            # vector match would bypass filter_stocks entirely.
            knn["filter"] = {
                "bool": {
                    "should": [
                        {"terms": {"stock_codes": request.filter_stocks}},
                        {"terms": {"stock_names": request.filter_stocks}}
                    ]
                }
            }
        search_body = {
            "knn": knn,
            "query": {
                "bool": {
                    "must": [keyword_query],
                    "boost": 1.0 - semantic_weight
                }
            },
            "size": effective_search_size
        }
        match_type = "hybrid_knn"
    else:
        search_body = {"query": keyword_query, "size": effective_search_size}
        match_type = "keyword"

    highlight_fields = {
        "concept": {},
        "description": {"fragment_size": 150},
        "tags": {}
    }
    if request.search_reason:
        # Reason text is only queried (and therefore only highlightable) in
        # reason mode.
        highlight_fields["stock_reasons_text"] = {"fragment_size": 150}

    search_body.update({
        "highlight": {"fields": highlight_fields},
        "_source": {"excludes": ["description_embedding", "insight", "stock_reasons_text"]},
        "track_total_hits": True
    })
    return search_body, match_type


def _sort_concept_results(results: List[ConceptResult], sort_by: str) -> None:
    """Sort in place for the in-memory sort modes; any other value keeps ES order."""
    if sort_by == "change_pct":
        # Concepts without price data sink to the bottom.
        results.sort(
            key=lambda x: x.price_info.avg_change_pct if x.price_info and x.price_info.avg_change_pct is not None else -999,
            reverse=True
        )
    elif sort_by == "stock_count":
        results.sort(key=lambda x: x.stock_count, reverse=True)
    elif sort_by == "outbreak_date":
        results.sort(
            key=lambda x: x.outbreak_dates[0] if x.outbreak_dates else '1900-01-01',
            reverse=True
        )


@app.post("/search", response_model=SearchResponse, tags=["Search"])
async def search_concepts(request: SearchRequest):
    """搜索概念 - 支持语义搜索和原因搜索"""
    # Search concepts (keyword, hybrid kNN, optional reason matching). The
    # docstring above stays in Chinese because FastAPI publishes it.
    # An empty query sorted by change_pct without stock filters is served
    # from MySQL. Everything else goes through Elasticsearch, followed by
    # in-memory hierarchy filtering, price lookup, sorting and paging.
    start_time = datetime.now()

    try:
        is_empty_query = not request.query or not request.query.strip()
        if is_empty_query and request.sort_by == "change_pct" and not request.filter_stocks:
            return await _search_by_change_fast_path(request, start_time)

        # Semantic weight: explicit value, or derived from the query length.
        if request.semantic_weight is not None:
            semantic_weight = request.semantic_weight
        else:
            semantic_weight = calculate_semantic_weight(request.query)

        # Fall back to a pure keyword search when no embedding is available.
        embedding = []
        if semantic_weight > 0:
            embedding = generate_embedding(request.query)
            if not embedding:
                semantic_weight = 0

        # Non-score sorts happen in memory after the ES call, so they need a
        # wider candidate pool.
        effective_search_size = request.search_size
        if request.sort_by in ["change_pct", "outbreak_date"]:
            effective_search_size = min(1000, request.search_size * 10)

        search_body, match_type = _build_concept_search_body(
            request, semantic_weight, embedding, effective_search_size
        )
        es_response = es_client.search(index=INDEX_NAME, body=search_body, timeout="30s")

        all_results = []
        for hit in es_response['hits']['hits']:
            source = hit['_source']
            concept_name = source.get('concept', '')
            hierarchy_info = get_concept_hierarchy(concept_name)

            # The hierarchy lives in the local JSON file, so level filters are
            # applied here rather than in ES.
            if request.filter_lv1 and (not hierarchy_info or hierarchy_info.get('lv1') != request.filter_lv1):
                continue
            if request.filter_lv2 and (not hierarchy_info or hierarchy_info.get('lv2') != request.filter_lv2):
                continue

            stock_count = source.get('stock_count')
            if stock_count is None:
                stock_count = len(source.get('stocks') or [])

            all_results.append(ConceptResult(
                concept_id=source.get('concept_id', ''),
                concept=concept_name,
                description=source.get('description'),
                tags=source.get('tags', []),
                outbreak_dates=source.get('outbreak_dates', []),
                stocks=parse_stocks_from_es(source, include_reasons=request.include_stock_reasons),
                stock_count=stock_count,
                hierarchy=HierarchyInfo(**hierarchy_info) if hierarchy_info else None,
                score=hit['_score'] or 0,
                match_type=match_type,
                highlights=hit.get('highlight', {}),
                price_info=None
            ))

        # Price data, only for concepts that survived the hierarchy filter.
        actual_price_date = None
        concept_ids = [r.concept_id for r in all_results]
        if concept_ids:
            price_data = await get_concept_price_data(concept_ids, request.trade_date)
            if price_data:
                actual_price_date = next(iter(price_data.values())).trade_date
            for result in all_results:
                result.price_info = price_data.get(result.concept_id)

        _sort_concept_results(all_results, request.sort_by)

        # Paging is applied in memory; out-of-range pages clamp to the last page.
        total_results = len(all_results)
        total_pages = max(1, (total_results + request.size - 1) // request.size)
        current_page = min(request.page, total_pages)
        start_idx = (current_page - 1) * request.size
        page_results = all_results[start_idx:start_idx + request.size]

        # ES's hit count ignores the in-memory hierarchy filter. When that
        # filter is active, report the filtered count so `total` agrees with
        # `total_pages`.
        if request.filter_lv1 or request.filter_lv2:
            total = total_results
        else:
            total = es_response['hits']['total']['value']

        return SearchResponse(
            total=total,
            took_ms=_elapsed_ms(start_time),
            results=page_results,
            search_info={
                "query": request.query,
                "semantic_weight": semantic_weight,
                "match_type": match_type,
                "search_reason": request.search_reason,
                "filter_lv1": request.filter_lv1,
                "filter_lv2": request.filter_lv2,
                "sort_by": request.sort_by,
                "has_embedding": bool(embedding)
            },
            price_date=actual_price_date,
            page=current_page,
            total_pages=total_pages
        )

    except Exception as e:
        logger.error(f"搜索失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.post("/search/reason", response_model=ReasonSearchResponse, tags=["Search"])
async def search_by_reason(request: ReasonSearchRequest):
    """
    按股票原因搜索概念

    搜索包含特定关键词的股票关联原因
    例如:搜索"AI芯片"可以找到所有reason中提到AI芯片的概念和股票
    """
    # Search concepts by the reasons linking stocks to them. The docstring
    # above stays in Chinese because FastAPI publishes it.
    # The query matches the aggregated stock_reasons_text field and, through a
    # nested query, individual stocks.reason values; the nested inner_hits
    # identify which stocks matched. Paging is done by ES with from/size.
    start_time = datetime.now()

    try:
        nested_clause = {
            "path": "stocks",
            "query": {"match": {"stocks.reason": request.query}}
        }
        if request.include_stock_details:
            # Ask for inner_hits only when they will be returned. An empty
            # inner_hits object still makes ES compute default inner hits.
            nested_clause["inner_hits"] = {
                "size": 5,
                "highlight": {"fields": {"stocks.reason": {"fragment_size": 200}}}
            }

        search_body = {
            "query": {
                "bool": {
                    "should": [
                        # Aggregated reason text (noted as the faster match).
                        {"match": {"stock_reasons_text": {"query": request.query, "boost": 2.0}}},
                        # Per-stock nested match (locates the specific stocks).
                        {"nested": nested_clause}
                    ]
                }
            },
            "highlight": {
                "fields": {
                    "stock_reasons_text": {"fragment_size": 200, "number_of_fragments": 3}
                }
            },
            "_source": ["concept_id", "concept", "description", "stocks"],
            # Let ES return just the requested page instead of fetching
            # page * size hits and slicing locally.
            "from": (request.page - 1) * request.size,
            "size": request.size,
            # Exact totals (ES otherwise caps the count at 10,000).
            "track_total_hits": True
        }

        es_response = es_client.search(index=INDEX_NAME, body=search_body, timeout="30s")

        results = []
        for hit in es_response['hits']['hits']:
            source = hit['_source']

            matched_stocks = []
            if request.include_stock_details and 'inner_hits' in hit and 'stocks' in hit['inner_hits']:
                for inner in hit['inner_hits']['stocks']['hits']['hits']:
                    stock_source = inner['_source']
                    reason = stock_source.get('reason', '')
                    # Prefer the highlighted fragment when ES returned one.
                    if 'highlight' in inner and 'stocks.reason' in inner['highlight']:
                        reason = inner['highlight']['stocks.reason'][0]
                    matched_stocks.append(StockInfo(
                        name=stock_source.get('name', ''),
                        code=stock_source.get('code'),
                        reason=reason
                    ))

            reason_highlights = None
            if 'highlight' in hit and 'stock_reasons_text' in hit['highlight']:
                reason_highlights = hit['highlight']['stock_reasons_text']

            results.append(ReasonSearchResult(
                concept_id=source.get('concept_id', ''),
                concept=source.get('concept', ''),
                description=source.get('description'),
                score=hit['_score'] or 0,
                matched_stocks=matched_stocks,
                reason_highlights=reason_highlights
            ))

        total = es_response['hits']['total']['value']
        total_pages = max(1, (total + request.size - 1) // request.size)
        took_ms = int((datetime.now() - start_time).total_seconds() * 1000)

        return ReasonSearchResponse(
            total=total,
            took_ms=took_ms,
            query=request.query,
            results=results,
            page=request.page,
            total_pages=total_pages
        )

    except Exception as e:
        logger.error(f"原因搜索失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/concept/{concept_id}", response_model=ConceptDetailResponse, tags=["Concepts"])
async def get_concept_detail(
    concept_id: str,
    trade_date: Optional[date] = Query(None)
):
    """获取概念详情(包含完整股票信息和原因)"""
    # Return one concept with all its stocks and reasons. The docstring above
    # stays in Chinese because FastAPI publishes it.
    # The ES document is fetched by id, then the hierarchy and price info are
    # attached (for trade_date, or each concept's latest row when omitted).
    # Responds 404 when the id is unknown and 500 on any other failure.
    from elasticsearch import NotFoundError

    try:
        result = es_client.get(index=INDEX_NAME, id=concept_id)
        source = result['_source']
        concept_name = source.get('concept', '')

        stocks_full = parse_stocks_full(source)

        hierarchy_info = get_concept_hierarchy(concept_name)
        hierarchy = HierarchyInfo(**hierarchy_info) if hierarchy_info else None

        price_data = await get_concept_price_data([concept_id], trade_date)
        price_info = price_data.get(concept_id)

        return ConceptDetailResponse(
            concept_id=concept_id,
            concept=concept_name,
            description=source.get('description'),
            insight=source.get('insight'),
            tags=source.get('tags', []),
            outbreak_dates=source.get('outbreak_dates', []),
            stocks=stocks_full,
            stock_count=len(stocks_full),
            hierarchy=hierarchy,
            folders=source.get('folders', []),
            created_at=source.get('created_at'),
            price_info=price_info
        )

    except NotFoundError:
        raise HTTPException(status_code=404, detail="概念不存在")
    except Exception as e:
        logger.error(f"获取概念详情失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/stock/{stock_code}/concepts", tags=["Stocks"])
async def get_stock_concepts(
    stock_code: str,
    size: int = Query(50, ge=1, le=200),
    trade_date: Optional[date] = Query(None),
    include_reason: bool = Query(True, description="是否返回关联原因")
):
    """List the concepts that contain a stock, looked up by stock code or name.

    Args:
        stock_code: A stock code or stock name.
        size: Maximum number of concept documents to return.
        trade_date: Trade date used when attaching price info to each concept.
        include_reason: When True, a nested query is used so that ``inner_hits``
            returns the matched stock and its reason for every concept.

    Returns:
        A dict with ``stock_code``, the Elasticsearch hit ``total`` and ``concepts``.
    """
    try:
        # Exact keyword matches on the flattened code/name fields are always included.
        should_clauses = [
            {"term": {"stock_codes": stock_code}},
            {"term": {"stock_names": stock_code}},
        ]
        if include_reason:
            # Nested query on "stocks" so inner_hits can return the matched stock's reason.
            should_clauses.append({
                "nested": {
                    "path": "stocks",
                    "query": {
                        "bool": {
                            "should": [
                                {"term": {"stocks.code": stock_code}},
                                {"match": {"stocks.name.text": stock_code}}
                            ]
                        }
                    },
                    "inner_hits": {
                        "size": 1,
                        "_source": ["name", "code", "reason"]
                    }
                }
            })
        else:
            should_clauses.append({"match": {"stock_names.text": stock_code}})

        search_body = {
            "query": {
                "bool": {
                    "should": should_clauses,
                    "minimum_should_match": 1
                }
            },
            "size": size,
            "_source": ["concept_id", "concept", "description", "tags", "stocks", "stock_count", "outbreak_dates"]
        }

        es_response = es_client.search(index=INDEX_NAME, body=search_body)

        concepts = []
        for hit in es_response['hits']['hits']:
            src = hit['_source']
            name = src.get('concept', '')
            entry = {
                "concept_id": src.get('concept_id', ''),
                "concept": name,
                "description": src.get('description'),
                "tags": src.get('tags', []),
                "stock_count": src.get('stock_count', len(src.get('stocks', []))),
                "hierarchy": get_concept_hierarchy(name),
                "outbreak_dates": src.get('outbreak_dates', [])
            }

            # Attach the matched stock and its reason when the nested query produced one.
            if include_reason and 'stocks' in hit.get('inner_hits', {}):
                matched_hits = hit['inner_hits']['stocks']['hits']['hits']
                if matched_hits:
                    matched_src = matched_hits[0]['_source']
                    entry['stock_reason'] = matched_src.get('reason')
                    entry['matched_stock'] = {
                        'name': matched_src.get('name'),
                        'code': matched_src.get('code')
                    }

            concepts.append(entry)

        # Attach the price change for each concept (None when no price row exists).
        price_data = await get_concept_price_data([entry['concept_id'] for entry in concepts], trade_date)
        for entry in concepts:
            info = price_data.get(entry['concept_id'])
            if info:
                entry['price_info'] = {
                    'trade_date': str(info.trade_date),
                    'avg_change_pct': info.avg_change_pct
                }
            else:
                entry['price_info'] = None

        return {
            "stock_code": stock_code,
            "total": es_response['hits']['total']['value'],
            "concepts": concepts
        }

    except Exception as e:
        logger.error(f"查询股票概念失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
def _escape_es_wildcard(text):
    """Backslash-escape ``\\``, ``*`` and ``?`` so *text* matches literally in an ES wildcard pattern."""
    return re.sub(r'([\\*?])', r'\\\1', text)


def _escape_lucene_regex(text):
    """Backslash-escape Lucene regular-expression operators so *text* matches literally."""
    return re.sub(r'([.?+*|{}\[\]()"\\#@&<>~])', r'\\\1', text)


@app.get("/stock/search", tags=["Stocks"])
async def search_stocks(
    keyword: str = Query(..., description="股票关键词(代码或名称)"),
    size: int = Query(20, ge=1, le=100)
):
    """Search for stocks whose code or name contains ``keyword``.

    Matching stocks are collected from ``terms`` aggregations over the flattened
    ``stock_codes`` and ``stock_names`` fields. ``concept_count`` is the number of
    concept documents in which that code/name occurs.

    Note: the flattened fields do not link a name to its code. Code hits are
    returned with ``stock_name`` set to the code, and name hits with an empty
    ``stock_code``. The same stock may therefore appear once per field.

    Args:
        keyword: Substring to look for in stock codes or names (matched literally).
        size: Maximum number of stocks to return.

    Returns:
        ``{"total": <number of merged matches>, "stocks": [...]}`` sorted by
        ``concept_count`` descending.
    """
    try:
        # Escape user input so '*' / '?' in the keyword are not treated as wildcards.
        wildcard_pattern = f"*{_escape_es_wildcard(keyword)}*"
        # Only bucket terms that actually contain the keyword. Otherwise stocks that
        # merely co-occur in the matched concepts can crowd the real matches out of
        # the top-N buckets before the Python-side filter below runs.
        include_pattern = f".*{_escape_lucene_regex(keyword)}.*"

        search_body = {
            "query": {
                "bool": {
                    "should": [
                        {"wildcard": {"stock_codes": wildcard_pattern}},
                        {"match": {"stock_names.text": keyword}},
                        {"wildcard": {"stock_names": wildcard_pattern}}
                    ]
                }
            },
            "size": 0,
            "aggs": {
                "by_stock_code": {
                    "terms": {
                        "field": "stock_codes",
                        "size": size * 3,
                        "include": include_pattern
                    }
                },
                "by_stock_name": {
                    "terms": {
                        "field": "stock_names",
                        "size": size * 3,
                        "include": include_pattern
                    }
                }
            }
        }

        es_response = es_client.search(index=INDEX_NAME, body=search_body)
        aggregations = es_response['aggregations']

        # Merge the code and name buckets; terms bucket keys are unique per aggregation.
        stocks_map = {}
        for bucket in aggregations['by_stock_code']['buckets']:
            code = bucket['key']
            if keyword in code:
                stocks_map[code] = {'code': code, 'name': None, 'count': bucket['doc_count']}

        for bucket in aggregations['by_stock_name']['buckets']:
            name = bucket['key']
            if keyword in name:
                # No code is available for a name bucket (see the docstring note).
                stocks_map[name] = {'code': None, 'name': name, 'count': bucket['doc_count']}

        stocks = [
            {
                'stock_code': data.get('code') or '',
                'stock_name': data.get('name') or key,
                'concept_count': data['count']
            }
            for key, data in stocks_map.items()
        ]
        stocks.sort(key=lambda x: x['concept_count'], reverse=True)

        return {
            "total": len(stocks),
            "stocks": stocks[:size]
        }

    except Exception as e:
        logger.error(f"搜索股票失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# ==================== 层级相关接口 ====================
|
||
|
||
@app.get("/hierarchy", response_model=HierarchyResponse, tags=["Hierarchy"])
async def get_hierarchy():
    """Build the full lv1 -> lv2 -> (optional lv3) category tree from ``hierarchy_data``.

    Every node's ``concept_count`` is the number of concepts listed beneath it.
    An lv2 entry either nests lv3 children, each carrying ``concepts``, or carries
    ``concepts`` directly.
    """
    try:
        tree = []
        for lv1_entry in hierarchy_data.get('hierarchy', []):
            lv1_node = HierarchyLevel(
                id=lv1_entry.get('lv1_id', ''),
                name=lv1_entry.get('lv1', ''),
                concept_count=0,
                children=[]
            )

            for lv2_entry in lv1_entry.get('children', []):
                lv2_node = HierarchyLevel(
                    id=lv2_entry.get('lv2_id', ''),
                    name=lv2_entry.get('lv2', ''),
                    concept_count=0,
                    children=[]
                )

                if 'children' not in lv2_entry:
                    # No lv3 level: the lv2 entry holds its concepts directly.
                    direct_concepts = lv2_entry.get('concepts', [])
                    lv2_node.concept_count = len(direct_concepts)
                    lv2_node.concepts = direct_concepts
                else:
                    for lv3_entry in lv2_entry.get('children', []):
                        lv3_concepts = lv3_entry.get('concepts', [])
                        lv2_node.children.append(HierarchyLevel(
                            id=lv3_entry.get('lv3_id', ''),
                            name=lv3_entry.get('lv3', ''),
                            concept_count=len(lv3_concepts),
                            concepts=lv3_concepts
                        ))
                        lv2_node.concept_count += len(lv3_concepts)

                lv1_node.children.append(lv2_node)
                lv1_node.concept_count += lv2_node.concept_count

            tree.append(lv1_node)

        return HierarchyResponse(
            hierarchy=tree,
            total_concepts=sum(node.concept_count for node in tree)
        )

    except Exception as e:
        logger.error(f"获取层级结构失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# NOTE: /hierarchy/price must be registered before /hierarchy/{lv1_id}; otherwise the
# path-parameter route is matched first and captures "price" as an lv1_id.
@app.get("/hierarchy/price", tags=["Hierarchy"])
async def get_hierarchy_price(
    trade_date: Optional[date] = Query(None, description="交易日期,默认最新"),
    lv1_filter: Optional[str] = Query(None, description="筛选特定一级分类")
):
    """Return price changes for every concept level (lv1/lv2/lv3) and for leaf concepts.

    Args:
        trade_date: Trade date to report; defaults to the latest date present in
            ``concept_daily_stats``.
        lv1_filter: Optional substring filter applied with ``LIKE`` to ``concept_name``.

    Returns:
        A dict of per-level lists, each sorted by ``avg_change_pct`` descending,
        plus counts. Leaf items also carry their ``hierarchy``.

    Raises:
        HTTPException: 503 without a MySQL pool, 404 when no data exists,
            500 on any other error.
    """
    if not mysql_pool:
        raise HTTPException(status_code=503, detail="MySQL连接不可用")

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                query_date = trade_date
                if query_date is None:
                    await cursor.execute("SELECT MAX(trade_date) as max_date FROM concept_daily_stats")
                    result = await cursor.fetchone()
                    if not result or not result['max_date']:
                        raise HTTPException(status_code=404, detail="无涨跌幅数据")
                    query_date = result['max_date']

                base_query = """
                    SELECT concept_id, concept_name, concept_type, trade_date, avg_change_pct, stock_count
                    FROM concept_daily_stats WHERE trade_date = %s AND concept_type = %s
                """
                if lv1_filter:
                    # NOTE(review): this is a substring match on concept_name at every level,
                    # not a strict "belongs to this lv1" filter — confirm that is intended.
                    base_query += " AND concept_name LIKE %s"
                base_query += " ORDER BY avg_change_pct DESC"

                lv1_concepts, lv2_concepts, lv3_concepts, leaf_concepts = [], [], [], []

                for ctype, target_list in [('lv1', lv1_concepts), ('lv2', lv2_concepts),
                                           ('lv3', lv3_concepts), ('leaf', leaf_concepts)]:
                    params = (query_date, ctype, f"%{lv1_filter}%") if lv1_filter else (query_date, ctype)
                    await cursor.execute(base_query, params)
                    for row in await cursor.fetchall():
                        change = row['avg_change_pct']
                        item = {
                            "concept_id": row['concept_id'],
                            "concept_name": row['concept_name'],
                            "concept_type": ctype,
                            "trade_date": str(row['trade_date']),
                            # Compare against None so a genuine 0% change is kept as 0.0.
                            "avg_change_pct": float(change) if change is not None else None,
                            "stock_count": row['stock_count']
                        }
                        if ctype == 'leaf':
                            item["hierarchy"] = get_concept_hierarchy(row['concept_name'])
                        target_list.append(item)

                return {
                    "trade_date": str(query_date),
                    "lv1_concepts": lv1_concepts,
                    "lv2_concepts": lv2_concepts,
                    "lv3_concepts": lv3_concepts,
                    "leaf_concepts": leaf_concepts,
                    "total_count": len(lv1_concepts) + len(lv2_concepts) + len(lv3_concepts) + len(leaf_concepts),
                    "summary": {
                        "lv1_count": len(lv1_concepts), "lv2_count": len(lv2_concepts),
                        "lv3_count": len(lv3_concepts), "leaf_count": len(leaf_concepts)
                    }
                }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取层级涨跌幅失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/hierarchy/{lv1_id}", tags=["Hierarchy"])
async def get_hierarchy_level(
    lv1_id: str,
    lv2_id: Optional[str] = Query(None, description="二级分类ID")
):
    """Return the concepts under one lv1 category, or under one of its lv2 children.

    Args:
        lv1_id: Id of the lv1 category.
        lv2_id: Optional lv2 id. When given, only that lv2's concepts are returned.

    Raises:
        HTTPException: 404 when the lv1 (or lv2) id is unknown, 500 on unexpected errors.
    """

    def _concepts_of(lv2_entry):
        # An lv2 entry either nests lv3 children (each with 'concepts') or has 'concepts' itself.
        if 'children' in lv2_entry:
            gathered = []
            for lv3_entry in lv2_entry.get('children', []):
                gathered.extend(lv3_entry.get('concepts', []))
            return gathered
        return lv2_entry.get('concepts', [])

    try:
        lv1 = next(
            (entry for entry in hierarchy_data.get('hierarchy', []) if entry.get('lv1_id') == lv1_id),
            None
        )
        if lv1 is None:
            raise HTTPException(status_code=404, detail=f"未找到一级分类: {lv1_id}")

        if lv2_id:
            child = next(
                (entry for entry in lv1.get('children', []) if entry.get('lv2_id') == lv2_id),
                None
            )
            if child is None:
                raise HTTPException(status_code=404, detail=f"未找到二级分类: {lv2_id}")
            lv2_concepts = _concepts_of(child)
            return {
                "lv1": lv1.get('lv1'), "lv1_id": lv1_id,
                "lv2": child.get('lv2'), "lv2_id": lv2_id,
                "concepts": lv2_concepts, "concept_count": len(lv2_concepts)
            }

        all_concepts = []
        for child in lv1.get('children', []):
            all_concepts.extend(_concepts_of(child))
        return {
            "lv1": lv1.get('lv1'), "lv1_id": lv1_id,
            "concepts": all_concepts, "concept_count": len(all_concepts),
            "children": lv1.get('children', [])
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取层级详情失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/statistics/hierarchy", tags=["Statistics"])
async def get_hierarchy_statistics():
    """Summarise concept counts per lv1 category, with a per-lv2 breakdown.

    Categories are sorted by ``concept_count`` descending.
    """
    try:
        stats = []
        for lv1 in hierarchy_data.get('hierarchy', []):
            lv2_entries = lv1.get('children', [])
            lv2_total = len(lv2_entries)

            breakdown = []
            for child in lv2_entries:
                if 'children' in child:
                    # Sum the concepts of every lv3 child.
                    count = sum(len(lv3.get('concepts', [])) for lv3 in child.get('children', []))
                else:
                    count = len(child.get('concepts', []))
                breakdown.append({
                    "lv2": child.get('lv2'), "lv2_id": child.get('lv2_id'), "concept_count": count
                })

            stats.append({
                "lv1": lv1.get('lv1'), "lv1_id": lv1.get('lv1_id'),
                "lv2_count": lv2_total,
                "concept_count": sum(item['concept_count'] for item in breakdown),
                "children": breakdown
            })

        stats.sort(key=lambda s: s['concept_count'], reverse=True)

        return {
            "total_lv1": len(stats),
            "total_concepts": sum(s['concept_count'] for s in stats),
            "statistics": stats
        }

    except Exception as e:
        logger.error(f"获取层级统计失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# ==================== 价格相关接口 ====================
|
||
|
||
@app.get("/price/latest", tags=["Price"])
async def get_latest_price_date():
    """Report the most recent trade date that has price-change data.

    ``has_data`` is False when ``get_latest_trade_date`` returns None.
    """
    try:
        latest = await get_latest_trade_date()
    except Exception as e:
        logger.error(f"获取最新价格日期失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    return {"latest_trade_date": latest, "has_data": latest is not None}
|
||
|
||
|
||
@app.get("/concept/{concept_id}/price-timeseries", response_model=PriceTimeSeriesResponse, tags=["Price"])
async def get_concept_price_timeseries(
    concept_id: str,
    start_date: date = Query(..., description="开始日期"),
    end_date: date = Query(..., description="结束日期")
):
    """Return a concept's daily price changes between ``start_date`` and ``end_date`` (inclusive).

    Without a MySQL pool, an empty series is returned rather than an error.

    Raises:
        HTTPException: 400 when ``start_date > end_date``, 404 when no rows exist,
            500 on any other error.
    """
    if not mysql_pool:
        return PriceTimeSeriesResponse(
            concept_id=concept_id, concept_name=concept_id,
            start_date=start_date, end_date=end_date, data_points=0, timeseries=[]
        )
    if start_date > end_date:
        raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                query = """
                    SELECT trade_date, concept_name, avg_change_pct, stock_count
                    FROM concept_daily_stats
                    WHERE concept_id = %s AND trade_date >= %s AND trade_date <= %s
                    ORDER BY trade_date ASC
                """
                await cursor.execute(query, (concept_id, start_date, end_date))
                rows = await cursor.fetchall()

                if not rows:
                    raise HTTPException(status_code=404, detail=f"未找到概念ID {concept_id} 的数据")

                # The first non-empty name in the series.
                concept_name = next((row['concept_name'] for row in rows if row['concept_name']), "")

                timeseries = []
                for row in rows:
                    change = row['avg_change_pct']
                    timeseries.append(PriceTimeSeriesItem(
                        trade_date=row['trade_date'],
                        # Compare against None so a genuine 0% change is kept as 0.0.
                        avg_change_pct=float(change) if change is not None else None,
                        stock_count=row.get('stock_count')
                    ))

                return PriceTimeSeriesResponse(
                    concept_id=concept_id, concept_name=concept_name or concept_id,
                    start_date=start_date, end_date=end_date,
                    data_points=len(timeseries), timeseries=timeseries
                )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取价格时间序列失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/price/list", response_model=ConceptPriceListResponse, tags=["Price"])
async def get_concept_price_list(
    trade_date: Optional[date] = Query(None, description="交易日期,默认最新"),
    concept_type: Optional[str] = Query(None, description="概念类型: leaf/lv1/lv2/lv3"),
    sort_by: str = Query("change_desc", description="排序方式"),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0)
):
    """Return a paginated list of concept price changes for one trade date.

    Args:
        trade_date: Trade date; defaults to the latest date in ``concept_daily_stats``.
        concept_type: One of leaf/lv1/lv2/lv3. Any other value is ignored (no filter).
        sort_by: ``change_asc``, ``stock_count``; anything else sorts by change descending.
        limit: Page size.
        offset: Number of rows to skip.

    Raises:
        HTTPException: 503 without a MySQL pool, 404 when no data exists,
            500 on any other error.
    """
    if not mysql_pool:
        raise HTTPException(status_code=503, detail="MySQL连接不可用")

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                if trade_date is None:
                    await cursor.execute("SELECT MAX(trade_date) as max_date FROM concept_daily_stats")
                    result = await cursor.fetchone()
                    if not result or not result['max_date']:
                        raise HTTPException(status_code=404, detail="无涨跌幅数据")
                    trade_date = result['max_date']

                where_conditions = ["trade_date = %s"]
                params = [trade_date]
                if concept_type and concept_type in ['leaf', 'lv1', 'lv2', 'lv3']:
                    where_conditions.append("concept_type = %s")
                    params.append(concept_type)
                where_sql = ' AND '.join(where_conditions)

                # Whitelisted ORDER BY fragments (never user text).
                order_clauses = {
                    "change_asc": "avg_change_pct ASC",
                    "stock_count": "stock_count DESC",
                }
                order_clause = order_clauses.get(sort_by, "avg_change_pct DESC")

                count_query = f"SELECT COUNT(*) as cnt FROM concept_daily_stats WHERE {where_sql}"
                await cursor.execute(count_query, params)
                total = (await cursor.fetchone())['cnt']

                # concept_id is a tie-breaker so LIMIT/OFFSET pages are deterministic
                # when many rows share the same sort value.
                query = f"""
                    SELECT concept_id, concept_name, concept_type, trade_date, avg_change_pct, stock_count
                    FROM concept_daily_stats WHERE {where_sql}
                    ORDER BY {order_clause}, concept_id ASC LIMIT %s OFFSET %s
                """
                await cursor.execute(query, params + [limit, offset])
                rows = await cursor.fetchall()

                concepts = []
                for row in rows:
                    hierarchy = None
                    if row['concept_type'] == 'leaf':
                        hi = get_concept_hierarchy(row['concept_name'])
                        if hi:
                            hierarchy = HierarchyInfo(**hi)
                    change = row['avg_change_pct']
                    concepts.append(ConceptPriceItem(
                        concept_id=row['concept_id'], concept_name=row['concept_name'],
                        concept_type=row['concept_type'], trade_date=row['trade_date'],
                        # Compare against None so a genuine 0% change is kept as 0.0.
                        avg_change_pct=float(change) if change is not None else None,
                        stock_count=row['stock_count'], hierarchy=hierarchy
                    ))

                return ConceptPriceListResponse(trade_date=trade_date, total=total, concepts=concepts)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取概念涨跌幅列表失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/hierarchy/{hierarchy_id}/price-timeseries", tags=["Hierarchy"])
async def get_hierarchy_price_timeseries(
    hierarchy_id: str,
    start_date: date = Query(..., description="开始日期"),
    end_date: date = Query(..., description="结束日期")
):
    """Return daily price changes of a hierarchy concept (lv1/lv2/lv3) over a date range.

    ``hierarchy_id`` is matched against ``concept_daily_stats.concept_id``; the range
    is inclusive on both ends.

    Raises:
        HTTPException: 503 without a MySQL pool, 400 when ``start_date > end_date``,
            404 when no rows exist, 500 on any other error.
    """
    if not mysql_pool:
        raise HTTPException(status_code=503, detail="MySQL连接不可用")
    if start_date > end_date:
        raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                query = """
                    SELECT concept_id, concept_name, concept_type, trade_date, avg_change_pct, stock_count
                    FROM concept_daily_stats
                    WHERE concept_id = %s AND trade_date >= %s AND trade_date <= %s
                    ORDER BY trade_date ASC
                """
                await cursor.execute(query, (hierarchy_id, start_date, end_date))
                rows = await cursor.fetchall()

                if not rows:
                    raise HTTPException(status_code=404, detail=f"未找到层级概念 {hierarchy_id} 的数据")

                timeseries = []
                concept_name, concept_type = "", ""
                for row in rows:
                    # Name/type come from the first row that has a non-empty name.
                    if not concept_name:
                        concept_name = row['concept_name']
                        concept_type = row['concept_type']
                    change = row['avg_change_pct']
                    timeseries.append({
                        'trade_date': row['trade_date'],
                        # Compare against None so a genuine 0% change is kept as 0.0.
                        'avg_change_pct': float(change) if change is not None else None,
                        'stock_count': row['stock_count']
                    })

                return {
                    'concept_id': hierarchy_id, 'concept_name': concept_name,
                    'concept_type': concept_type, 'start_date': start_date, 'end_date': end_date,
                    'data_points': len(timeseries), 'timeseries': timeseries
                }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取层级价格时间序列失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# ==================== 统计接口 ====================
|
||
|
||
@app.get("/statistics", response_model=ConceptStatisticsResponse, tags=["Statistics"])
async def get_concept_statistics(
    days: Optional[int] = Query(None, ge=1, le=90),
    start_date: Optional[date] = Query(None),
    end_date: Optional[date] = Query(None),
    min_stock_count: int = Query(3, ge=1),
    concept_type: Optional[str] = Query(None)
):
    """Return ranked concept statistics (hot/cold/active/volatile/momentum) for a date range.

    Date range resolution:
        * ``days``: end = today, start = today - ``days`` (cannot be combined with dates)
        * only ``start_date``: end = today
        * only ``end_date``: start = end - 7 days
        * neither: start = today - 7 days, end = today

    ``concept_type`` values other than leaf/lv1/lv2/lv3 fall back to ``leaf``. Without
    a MySQL pool, or if a query fails, sample data from ``get_fallback_statistics``
    is returned with ``success=True`` and an explanatory ``note``.

    Raises:
        HTTPException: 400 when ``days`` is combined with explicit dates, or when
            the resolved start date is after the end date.
    """
    from datetime import timedelta

    if days is not None and (start_date is not None or end_date is not None):
        raise HTTPException(status_code=400, detail="days参数与start_date/end_date参数不能同时使用")

    if days is not None:
        end_date = datetime.now().date()
        start_date = end_date - timedelta(days=days)
    else:
        if end_date is None:
            end_date = datetime.now().date()
        if start_date is None:
            start_date = end_date - timedelta(days=7)

    # Validate the *resolved* range: a lone start_date in the future would otherwise
    # silently produce an inverted range.
    if start_date > end_date:
        raise HTTPException(status_code=400, detail="开始日期不能晚于结束日期")

    if not mysql_pool:
        return ConceptStatisticsResponse(
            success=True, data=get_fallback_statistics(start_date, end_date),
            params={'start_date': str(start_date), 'end_date': str(end_date)},
            note="MySQL连接不可用"
        )

    type_filter = concept_type if concept_type in ['leaf', 'lv1', 'lv2', 'lv3'] else 'leaf'
    filter_params = (start_date, end_date, min_stock_count, type_filter)

    # Hot and cold rankings differ only in sort direction ({direction} is a literal below).
    ranking_sql = """
        SELECT concept_id, concept_name, AVG(avg_change_pct) as avg_change_pct,
               AVG(stock_count) as avg_stock_count, COUNT(*) as trading_days
        FROM concept_daily_stats
        WHERE trade_date >= %s AND trade_date <= %s AND avg_change_pct IS NOT NULL
          AND stock_count >= %s AND concept_type = %s
        GROUP BY concept_id, concept_name HAVING COUNT(*) >= 2
        ORDER BY AVG(avg_change_pct) {direction} LIMIT 10
    """

    def build_items(rows, item_type):
        """Convert aggregated SQL rows into ``ConceptStatItem`` objects for one ranking."""
        items = []
        for row in rows:
            hi = get_concept_hierarchy(row['concept_name'])
            item = ConceptStatItem(
                name=row['concept_name'], concept_id=row.get('concept_id'),
                hierarchy=HierarchyInfo(**hi) if hi else None
            )
            if item_type in ['hot', 'cold']:
                item.change_pct = round(float(row['avg_change_pct']), 2) if row['avg_change_pct'] else 0.0
                item.stock_count = int(row['avg_stock_count']) if row['avg_stock_count'] else 0
                item.news_count = int(row['trading_days']) if row['trading_days'] else 0
            elif item_type == 'active':
                item.news_count = int(row['trading_days']) if row['trading_days'] else 0
                item.stock_count = int(row['avg_stock_count']) if row['avg_stock_count'] else 0
                # NOTE(review): report_count is derived from trading_days (30%), not from real data.
                item.report_count = max(1, int(row['trading_days'] * 0.3)) if row['trading_days'] else 1
                item.total_mentions = item.news_count + item.report_count
            elif item_type == 'volatile':
                item.volatility = round(float(row['volatility']), 2) if row['volatility'] else 0.0
                item.avg_change = round(float(row['avg_change_pct']), 2) if row['avg_change_pct'] else 0.0
                item.max_change = round(float(row['max_change_pct']), 2) if row['max_change_pct'] else 0.0
            elif item_type == 'momentum':
                # NOTE(review): positive_days counts all up days in the range, not consecutive ones.
                item.consecutive_days = int(row['positive_days']) if row['positive_days'] else 0
                item.total_change = round(float(row['total_change']), 2) if row['total_change'] else 0.0
                item.avg_daily = round(float(row['avg_change_pct']), 2) if row['avg_change_pct'] else 0.0
            items.append(item)
        return items

    try:
        async with mysql_pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                # Top gainers
                await cursor.execute(ranking_sql.format(direction="DESC"), filter_params)
                hot_rows = await cursor.fetchall()

                # Top losers
                await cursor.execute(ranking_sql.format(direction="ASC"), filter_params)
                cold_rows = await cursor.fetchall()

                # Most trading days in range
                await cursor.execute("""
                    SELECT concept_id, concept_name, COUNT(*) as trading_days, AVG(stock_count) as avg_stock_count
                    FROM concept_daily_stats
                    WHERE trade_date >= %s AND trade_date <= %s AND stock_count >= %s AND concept_type = %s
                    GROUP BY concept_id, concept_name ORDER BY COUNT(*) DESC LIMIT 10
                """, filter_params)
                active_rows = await cursor.fetchall()

                # Highest standard deviation of daily change
                await cursor.execute("""
                    SELECT concept_id, concept_name, STDDEV(avg_change_pct) as volatility,
                           AVG(avg_change_pct) as avg_change_pct, MAX(avg_change_pct) as max_change_pct
                    FROM concept_daily_stats
                    WHERE trade_date >= %s AND trade_date <= %s AND avg_change_pct IS NOT NULL
                      AND stock_count >= %s AND concept_type = %s
                    GROUP BY concept_id, concept_name
                    HAVING COUNT(*) >= 3 AND STDDEV(avg_change_pct) IS NOT NULL
                    ORDER BY STDDEV(avg_change_pct) DESC LIMIT 10
                """, filter_params)
                volatile_rows = await cursor.fetchall()

                # Most days with a positive change
                await cursor.execute("""
                    SELECT concept_id, concept_name, COUNT(*) as positive_days,
                           SUM(avg_change_pct) as total_change, AVG(avg_change_pct) as avg_change_pct
                    FROM concept_daily_stats
                    WHERE trade_date >= %s AND trade_date <= %s AND avg_change_pct > 0
                      AND stock_count >= %s AND concept_type = %s
                    GROUP BY concept_id, concept_name HAVING COUNT(*) >= 2
                    ORDER BY COUNT(*) DESC LIMIT 10
                """, filter_params)
                momentum_rows = await cursor.fetchall()

                # Overall summary (aggregate without GROUP BY: always exactly one row)
                await cursor.execute("""
                    SELECT COUNT(DISTINCT concept_id) as total_concepts,
                           COUNT(DISTINCT CASE WHEN avg_change_pct > 0 THEN concept_id END) as positive_concepts,
                           COUNT(DISTINCT CASE WHEN avg_change_pct < 0 THEN concept_id END) as negative_concepts,
                           AVG(avg_change_pct) as overall_avg_change
                    FROM concept_daily_stats
                    WHERE trade_date >= %s AND trade_date <= %s AND avg_change_pct IS NOT NULL AND concept_type = %s
                """, (start_date, end_date, type_filter))
                total_row = await cursor.fetchone()

                statistics = ConceptStatistics(
                    hot_concepts=build_items(hot_rows, 'hot'),
                    cold_concepts=build_items(cold_rows, 'cold'),
                    active_concepts=build_items(active_rows, 'active'),
                    volatile_concepts=build_items(volatile_rows, 'volatile'),
                    momentum_concepts=build_items(momentum_rows, 'momentum'),
                    summary={
                        'total_concepts': int(total_row['total_concepts']) if total_row['total_concepts'] else 0,
                        'positive_count': int(total_row['positive_concepts']) if total_row['positive_concepts'] else 0,
                        'negative_count': int(total_row['negative_concepts']) if total_row['negative_concepts'] else 0,
                        'avg_change': round(float(total_row['overall_avg_change']), 2) if total_row['overall_avg_change'] else 0.0,
                        'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                        'date_range': f"{start_date} 至 {end_date}",
                        'days': (end_date - start_date).days + 1,
                        'start_date': str(start_date), 'end_date': str(end_date)
                    }
                )

                return ConceptStatisticsResponse(
                    success=True, data=statistics,
                    params={
                        'days': (end_date - start_date).days + 1, 'min_stock_count': min_stock_count,
                        'start_date': str(start_date), 'end_date': str(end_date), 'concept_type': type_filter
                    }
                )
    except HTTPException:
        raise
    except Exception as e:
        # Deliberate best-effort: fall back to sample data, but keep the traceback in the log.
        logger.error(f"获取概念统计失败: {e}", exc_info=True)
        return ConceptStatisticsResponse(
            success=True, data=get_fallback_statistics(start_date, end_date),
            params={'start_date': str(start_date), 'end_date': str(end_date)},
            note=f"使用示例数据: {str(e)}"
        )
|
||
|
||
|
||
# ==================== 主函数 ====================
|
||
|
||
if __name__ == "__main__":
    # Development entry point: serve this module's ``app`` via uvicorn with auto-reload.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 6801,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("concept_api_v4:app", **server_options)
|