update pay ui
This commit is contained in:
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,28 +0,0 @@
|
||||
2025-12-08 16:40:41,567 - INFO - ============================================================
|
||||
2025-12-08 16:40:41,567 - INFO - 🔄 回测: 2025-12-08 (Alpha Z-Score 方法)
|
||||
2025-12-08 16:40:41,569 - INFO - ============================================================
|
||||
2025-12-08 16:40:41,679 - INFO - 已清除 2025-12-08 的数据
|
||||
2025-12-08 16:40:41,903 - INFO - POST http://222.128.1.157:19200/concept_library_v3/_search?scroll=2m [status:200 duration:0.224s]
|
||||
2025-12-08 16:40:42,105 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.197s]
|
||||
2025-12-08 16:40:42,330 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.178s]
|
||||
2025-12-08 16:40:42,518 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.183s]
|
||||
2025-12-08 16:40:42,704 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.182s]
|
||||
2025-12-08 16:40:42,894 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.186s]
|
||||
2025-12-08 16:40:43,060 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.162s]
|
||||
2025-12-08 16:40:43,234 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.171s]
|
||||
2025-12-08 16:40:43,383 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.145s]
|
||||
2025-12-08 16:40:43,394 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.008s]
|
||||
2025-12-08 16:40:43,399 - INFO - DELETE http://222.128.1.157:19200/_search/scroll [status:200 duration:0.005s]
|
||||
2025-12-08 16:40:43,409 - INFO - 概念: 968, 股票: 5938
|
||||
2025-12-08 16:40:43,505 - INFO - 时间点: 241
|
||||
2025-12-08 16:41:02,028 - INFO - 进度: 30/241 (12%), 异动: 0
|
||||
2025-12-08 16:41:20,851 - INFO - 进度: 60/241 (24%), 异动: 0
|
||||
2025-12-08 16:41:39,396 - INFO - 进度: 90/241 (37%), 异动: 0
|
||||
2025-12-08 16:41:58,687 - INFO - 进度: 120/241 (49%), 异动: 0
|
||||
2025-12-08 16:43:08,124 - INFO - 进度: 150/241 (62%), 异动: 0
|
||||
2025-12-08 16:43:26,973 - INFO - 进度: 180/241 (74%), 异动: 0
|
||||
2025-12-08 16:43:45,746 - INFO - 进度: 210/241 (87%), 异动: 0
|
||||
2025-12-08 16:44:04,479 - INFO - 进度: 240/241 (99%), 异动: 0
|
||||
2025-12-08 16:44:05,123 - INFO - ============================================================
|
||||
2025-12-08 16:44:05,123 - INFO - ✅ 回测完成! 检测到 0 条异动
|
||||
2025-12-08 16:44:05,125 - INFO - ============================================================
|
||||
1625
concept_alert_ml.py
1625
concept_alert_ml.py
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,681 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
概念涨跌幅实时更新服务
|
||||
- 在交易时间段每分钟从ClickHouse获取最新分钟数据
|
||||
- 计算涨跌幅后更新MySQL的concept_daily_stats表
|
||||
- 支持叶子概念和母概念(lv1/lv2/lv3)
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import create_engine, text
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import hashlib
|
||||
import argparse
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
# MySQL配置
|
||||
MYSQL_ENGINE = create_engine(
|
||||
"mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock",
|
||||
echo=False
|
||||
)
|
||||
|
||||
# Elasticsearch配置
|
||||
ES_CLIENT = Elasticsearch(['http://222.128.1.157:19200'])
|
||||
INDEX_NAME = 'concept_library_v3'
|
||||
|
||||
# ClickHouse配置
|
||||
CLICKHOUSE_CONFIG = {
|
||||
'host': '222.128.1.157',
|
||||
'port': 18000,
|
||||
'user': 'default',
|
||||
'password': 'Zzl33818!',
|
||||
'database': 'stock'
|
||||
}
|
||||
|
||||
# 层级结构文件
|
||||
HIERARCHY_FILE = 'concept_hierarchy_v3.json'
|
||||
|
||||
# 交易时间配置
|
||||
TRADING_HOURS = {
|
||||
'morning_start': (9, 30),
|
||||
'morning_end': (11, 30),
|
||||
'afternoon_start': (13, 0),
|
||||
'afternoon_end': (15, 0),
|
||||
}
|
||||
|
||||
# ==================== 日志配置 ====================
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(f'concept_realtime_{datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ClickHouse客户端
|
||||
ch_client = None
|
||||
|
||||
|
||||
def get_ch_client():
|
||||
"""获取ClickHouse客户端"""
|
||||
global ch_client
|
||||
if ch_client is None:
|
||||
ch_client = Client(**CLICKHOUSE_CONFIG)
|
||||
return ch_client
|
||||
|
||||
|
||||
def generate_id(name: str) -> str:
|
||||
"""生成概念ID"""
|
||||
return hashlib.md5(name.encode('utf-8')).hexdigest()[:16]
|
||||
|
||||
|
||||
def code_to_ch_format(code: str) -> str:
|
||||
"""将6位股票代码转换为ClickHouse格式(带后缀)
|
||||
|
||||
规则:
|
||||
- 6开头 -> .SH(上海)
|
||||
- 0或3开头 -> .SZ(深圳)
|
||||
- 其他 -> .BJ(北京)
|
||||
- 非6位数字的忽略(可能是港股)
|
||||
"""
|
||||
if not code or len(code) != 6 or not code.isdigit():
|
||||
return None
|
||||
|
||||
if code.startswith('6'):
|
||||
return f"{code}.SH"
|
||||
elif code.startswith('0') or code.startswith('3'):
|
||||
return f"{code}.SZ"
|
||||
else:
|
||||
return f"{code}.BJ"
|
||||
|
||||
|
||||
def ch_code_to_pure(ch_code: str) -> str:
|
||||
"""将ClickHouse格式的股票代码转回纯6位代码"""
|
||||
if not ch_code:
|
||||
return None
|
||||
return ch_code.split('.')[0]
|
||||
|
||||
|
||||
# ==================== 概念数据获取 ====================
|
||||
|
||||
def get_all_concepts():
|
||||
"""从ES获取所有叶子概念及其股票列表"""
|
||||
concepts = []
|
||||
|
||||
query = {
|
||||
"query": {"match_all": {}},
|
||||
"size": 100,
|
||||
"_source": ["concept_id", "concept", "stocks"]
|
||||
}
|
||||
|
||||
resp = ES_CLIENT.search(index=INDEX_NAME, body=query, scroll='2m')
|
||||
scroll_id = resp['_scroll_id']
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
while len(hits) > 0:
|
||||
for hit in hits:
|
||||
source = hit['_source']
|
||||
concept_info = {
|
||||
'concept_id': source.get('concept_id'),
|
||||
'concept_name': source.get('concept'),
|
||||
'stocks': [],
|
||||
'concept_type': 'leaf'
|
||||
}
|
||||
|
||||
# v3索引的stocks字段是 [{name, code}, ...]
|
||||
if 'stocks' in source and isinstance(source['stocks'], list):
|
||||
for stock in source['stocks']:
|
||||
if isinstance(stock, dict) and 'code' in stock and stock['code']:
|
||||
concept_info['stocks'].append(stock['code'])
|
||||
|
||||
if concept_info['stocks']:
|
||||
concepts.append(concept_info)
|
||||
|
||||
resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
|
||||
scroll_id = resp['_scroll_id']
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
|
||||
return concepts
|
||||
|
||||
|
||||
def load_hierarchy_concepts(leaf_concepts: list) -> list:
|
||||
"""加载层级结构,生成母概念(lv1/lv2/lv3)"""
|
||||
hierarchy_path = os.path.join(os.path.dirname(__file__), HIERARCHY_FILE)
|
||||
if not os.path.exists(hierarchy_path):
|
||||
logger.warning(f"层级文件不存在: {hierarchy_path}")
|
||||
return []
|
||||
|
||||
with open(hierarchy_path, 'r', encoding='utf-8') as f:
|
||||
hierarchy_data = json.load(f)
|
||||
|
||||
# 建立概念名称到股票的映射
|
||||
concept_to_stocks = {}
|
||||
for c in leaf_concepts:
|
||||
concept_to_stocks[c['concept_name']] = set(c['stocks'])
|
||||
|
||||
parent_concepts = []
|
||||
|
||||
for lv1 in hierarchy_data.get('hierarchy', []):
|
||||
lv1_name = lv1.get('lv1', '')
|
||||
lv1_stocks = set()
|
||||
|
||||
for child in lv1.get('children', []):
|
||||
lv2_name = child.get('lv2', '')
|
||||
lv2_stocks = set()
|
||||
|
||||
if 'children' in child:
|
||||
for lv3_child in child.get('children', []):
|
||||
lv3_name = lv3_child.get('lv3', '')
|
||||
lv3_stocks = set()
|
||||
|
||||
for concept_name in lv3_child.get('concepts', []):
|
||||
if concept_name in concept_to_stocks:
|
||||
lv3_stocks.update(concept_to_stocks[concept_name])
|
||||
|
||||
if lv3_stocks:
|
||||
parent_concepts.append({
|
||||
'concept_id': generate_id(f"lv3_{lv3_name}"),
|
||||
'concept_name': f"[三级] {lv3_name}",
|
||||
'stocks': list(lv3_stocks),
|
||||
'concept_type': 'lv3'
|
||||
})
|
||||
|
||||
lv2_stocks.update(lv3_stocks)
|
||||
else:
|
||||
for concept_name in child.get('concepts', []):
|
||||
if concept_name in concept_to_stocks:
|
||||
lv2_stocks.update(concept_to_stocks[concept_name])
|
||||
|
||||
if lv2_stocks:
|
||||
parent_concepts.append({
|
||||
'concept_id': generate_id(f"lv2_{lv2_name}"),
|
||||
'concept_name': f"[二级] {lv2_name}",
|
||||
'stocks': list(lv2_stocks),
|
||||
'concept_type': 'lv2'
|
||||
})
|
||||
|
||||
lv1_stocks.update(lv2_stocks)
|
||||
|
||||
if lv1_stocks:
|
||||
parent_concepts.append({
|
||||
'concept_id': generate_id(f"lv1_{lv1_name}"),
|
||||
'concept_name': f"[一级] {lv1_name}",
|
||||
'stocks': list(lv1_stocks),
|
||||
'concept_type': 'lv1'
|
||||
})
|
||||
|
||||
return parent_concepts
|
||||
|
||||
|
||||
# ==================== 基准价格获取 ====================
|
||||
|
||||
def get_base_prices(stock_codes: list, current_date: str) -> dict:
|
||||
"""获取当日的昨收价作为基准(从ea_trade的F002N字段)
|
||||
|
||||
ea_trade表字段说明:
|
||||
- F002N: 昨日收盘价
|
||||
- F007N: 最近成交价(收盘价)
|
||||
- F010N: 涨跌幅
|
||||
"""
|
||||
if not stock_codes:
|
||||
return {}
|
||||
|
||||
# 过滤出有效的6位股票代码
|
||||
valid_codes = [code for code in stock_codes if code and len(code) == 6 and code.isdigit()]
|
||||
if not valid_codes:
|
||||
return {}
|
||||
|
||||
stock_codes_str = "','".join(valid_codes)
|
||||
|
||||
# 获取当日数据中的昨收价(F002N)
|
||||
query = f"""
|
||||
SELECT SECCODE, F002N
|
||||
FROM ea_trade
|
||||
WHERE SECCODE IN ('{stock_codes_str}')
|
||||
AND TRADEDATE = (
|
||||
SELECT MAX(TRADEDATE)
|
||||
FROM ea_trade
|
||||
WHERE TRADEDATE <= '{current_date}'
|
||||
)
|
||||
AND F002N IS NOT NULL AND F002N > 0
|
||||
"""
|
||||
|
||||
try:
|
||||
with MYSQL_ENGINE.connect() as conn:
|
||||
result = conn.execute(text(query))
|
||||
base_prices = {row[0]: float(row[1]) for row in result if row[1] and float(row[1]) > 0}
|
||||
logger.info(f"获取到 {len(base_prices)} 个基准价格")
|
||||
return base_prices
|
||||
except Exception as e:
|
||||
logger.error(f"获取基准价格失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
# ==================== 实时价格获取 ====================
|
||||
|
||||
def get_latest_prices(stock_codes: list) -> dict:
|
||||
"""从ClickHouse获取最新分钟数据的收盘价
|
||||
|
||||
Args:
|
||||
stock_codes: 纯6位股票代码列表(如 ['000001', '600000'])
|
||||
|
||||
Returns:
|
||||
dict: {纯6位代码: {'close': 价格, 'timestamp': 时间}}
|
||||
"""
|
||||
if not stock_codes:
|
||||
return {}
|
||||
|
||||
client = get_ch_client()
|
||||
|
||||
# 转换为ClickHouse格式的代码(带后缀)
|
||||
ch_codes = []
|
||||
code_mapping = {} # ch_code -> pure_code
|
||||
for code in stock_codes:
|
||||
ch_code = code_to_ch_format(code)
|
||||
if ch_code:
|
||||
ch_codes.append(ch_code)
|
||||
code_mapping[ch_code] = code
|
||||
|
||||
if not ch_codes:
|
||||
logger.warning("没有有效的股票代码可查询")
|
||||
return {}
|
||||
|
||||
ch_codes_str = "','".join(ch_codes)
|
||||
|
||||
# 获取今日最新的分钟数据
|
||||
query = f"""
|
||||
SELECT code, close, timestamp
|
||||
FROM (
|
||||
SELECT code, close, timestamp,
|
||||
ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn
|
||||
FROM stock_minute
|
||||
WHERE code IN ('{ch_codes_str}')
|
||||
AND toDate(timestamp) = today()
|
||||
)
|
||||
WHERE rn = 1
|
||||
"""
|
||||
|
||||
try:
|
||||
result = client.execute(query)
|
||||
if not result:
|
||||
return {}
|
||||
|
||||
latest_prices = {}
|
||||
for row in result:
|
||||
ch_code, close, ts = row
|
||||
if close and close > 0:
|
||||
# 转回纯6位代码
|
||||
pure_code = code_mapping.get(ch_code)
|
||||
if pure_code:
|
||||
latest_prices[pure_code] = {
|
||||
'close': float(close),
|
||||
'timestamp': ts
|
||||
}
|
||||
|
||||
return latest_prices
|
||||
except Exception as e:
|
||||
logger.error(f"获取最新价格失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
# ==================== 涨跌幅计算 ====================
|
||||
|
||||
def calculate_change_pct(base_prices: dict, latest_prices: dict) -> dict:
|
||||
"""计算涨跌幅"""
|
||||
changes = {}
|
||||
for code, latest in latest_prices.items():
|
||||
if code in base_prices and base_prices[code] > 0:
|
||||
base = base_prices[code]
|
||||
close = latest['close']
|
||||
change_pct = (close - base) / base * 100
|
||||
changes[code] = round(change_pct, 4)
|
||||
return changes
|
||||
|
||||
|
||||
def calculate_concept_stats(concepts: list, stock_changes: dict, trade_date: str) -> list:
|
||||
"""计算所有概念的涨跌幅统计"""
|
||||
stats = []
|
||||
|
||||
for concept in concepts:
|
||||
concept_id = concept['concept_id']
|
||||
concept_name = concept['concept_name']
|
||||
stock_codes = concept['stocks']
|
||||
concept_type = concept.get('concept_type', 'leaf')
|
||||
|
||||
# 获取该概念股票的涨跌幅
|
||||
changes = [stock_changes[code] for code in stock_codes if code in stock_changes]
|
||||
|
||||
if not changes:
|
||||
continue
|
||||
|
||||
avg_change_pct = round(np.mean(changes), 4)
|
||||
stock_count = len(changes)
|
||||
|
||||
stats.append({
|
||||
'concept_id': concept_id,
|
||||
'concept_name': concept_name,
|
||||
'trade_date': trade_date,
|
||||
'avg_change_pct': avg_change_pct,
|
||||
'stock_count': stock_count,
|
||||
'concept_type': concept_type
|
||||
})
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# ==================== MySQL更新 ====================
|
||||
|
||||
def update_mysql_stats(stats: list):
|
||||
"""更新MySQL的concept_daily_stats表"""
|
||||
if not stats:
|
||||
return 0
|
||||
|
||||
with MYSQL_ENGINE.begin() as conn:
|
||||
updated = 0
|
||||
for item in stats:
|
||||
upsert_sql = text("""
|
||||
REPLACE INTO concept_daily_stats
|
||||
(concept_id, concept_name, trade_date, avg_change_pct, stock_count, concept_type)
|
||||
VALUES (:concept_id, :concept_name, :trade_date, :avg_change_pct, :stock_count, :concept_type)
|
||||
""")
|
||||
conn.execute(upsert_sql, item)
|
||||
updated += 1
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
# ==================== 交易时间判断 ====================
|
||||
|
||||
def is_trading_time() -> bool:
|
||||
"""判断当前是否为交易时间"""
|
||||
now = datetime.now()
|
||||
weekday = now.weekday()
|
||||
|
||||
# 周末不交易
|
||||
if weekday >= 5:
|
||||
return False
|
||||
|
||||
hour, minute = now.hour, now.minute
|
||||
current_time = hour * 60 + minute
|
||||
|
||||
# 上午 9:30 - 11:30
|
||||
morning_start = 9 * 60 + 30
|
||||
morning_end = 11 * 60 + 30
|
||||
|
||||
# 下午 13:00 - 15:00
|
||||
afternoon_start = 13 * 60
|
||||
afternoon_end = 15 * 60
|
||||
|
||||
return (morning_start <= current_time <= morning_end) or \
|
||||
(afternoon_start <= current_time <= afternoon_end)
|
||||
|
||||
|
||||
def get_next_update_time() -> int:
|
||||
"""获取距离下次更新的秒数"""
|
||||
now = datetime.now()
|
||||
|
||||
if is_trading_time():
|
||||
# 交易时间内,等到下一分钟
|
||||
return 60 - now.second
|
||||
else:
|
||||
# 非交易时间
|
||||
hour, minute = now.hour, now.minute
|
||||
|
||||
# 计算距离下次交易开始的时间
|
||||
if hour < 9 or (hour == 9 and minute < 30):
|
||||
# 等到9:30
|
||||
target = now.replace(hour=9, minute=30, second=0, microsecond=0)
|
||||
elif (hour == 11 and minute >= 30) or hour == 12:
|
||||
# 等到13:00
|
||||
target = now.replace(hour=13, minute=0, second=0, microsecond=0)
|
||||
elif hour >= 15:
|
||||
# 等到明天9:30
|
||||
target = (now + timedelta(days=1)).replace(hour=9, minute=30, second=0, microsecond=0)
|
||||
else:
|
||||
target = now + timedelta(minutes=1)
|
||||
|
||||
wait_seconds = (target - now).total_seconds()
|
||||
return max(60, int(wait_seconds))
|
||||
|
||||
|
||||
# ==================== 主运行逻辑 ====================
|
||||
|
||||
def run_once(concepts: list, all_stocks: list) -> int:
|
||||
"""执行一次更新"""
|
||||
now = datetime.now()
|
||||
trade_date = now.strftime('%Y-%m-%d')
|
||||
|
||||
# 获取基准价格(昨日收盘价)
|
||||
base_prices = get_base_prices(all_stocks, trade_date)
|
||||
if not base_prices:
|
||||
logger.warning("无法获取基准价格")
|
||||
return 0
|
||||
|
||||
# 获取最新价格
|
||||
latest_prices = get_latest_prices(all_stocks)
|
||||
if not latest_prices:
|
||||
logger.warning("无法获取最新价格")
|
||||
return 0
|
||||
|
||||
# 计算涨跌幅
|
||||
stock_changes = calculate_change_pct(base_prices, latest_prices)
|
||||
if not stock_changes:
|
||||
logger.warning("无涨跌幅数据")
|
||||
return 0
|
||||
|
||||
logger.info(f"获取到 {len(stock_changes)} 只股票的涨跌幅")
|
||||
|
||||
# 计算概念统计
|
||||
stats = calculate_concept_stats(concepts, stock_changes, trade_date)
|
||||
logger.info(f"计算了 {len(stats)} 个概念的涨跌幅")
|
||||
|
||||
# 更新MySQL
|
||||
updated = update_mysql_stats(stats)
|
||||
logger.info(f"更新了 {updated} 条记录到MySQL")
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
def run_realtime():
|
||||
"""实时更新主循环"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("启动概念涨跌幅实时更新服务")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# 加载概念数据
|
||||
logger.info("加载概念数据...")
|
||||
leaf_concepts = get_all_concepts()
|
||||
logger.info(f"获取到 {len(leaf_concepts)} 个叶子概念")
|
||||
|
||||
parent_concepts = load_hierarchy_concepts(leaf_concepts)
|
||||
logger.info(f"生成了 {len(parent_concepts)} 个母概念")
|
||||
|
||||
all_concepts = leaf_concepts + parent_concepts
|
||||
logger.info(f"总计 {len(all_concepts)} 个概念")
|
||||
|
||||
# 收集所有股票代码
|
||||
all_stocks = set()
|
||||
for c in all_concepts:
|
||||
all_stocks.update(c['stocks'])
|
||||
all_stocks = list(all_stocks)
|
||||
logger.info(f"监控 {len(all_stocks)} 只股票")
|
||||
|
||||
last_concept_update = datetime.now()
|
||||
|
||||
while True:
|
||||
try:
|
||||
now = datetime.now()
|
||||
|
||||
# 每小时重新加载概念数据
|
||||
if (now - last_concept_update).total_seconds() > 3600:
|
||||
logger.info("重新加载概念数据...")
|
||||
leaf_concepts = get_all_concepts()
|
||||
parent_concepts = load_hierarchy_concepts(leaf_concepts)
|
||||
all_concepts = leaf_concepts + parent_concepts
|
||||
all_stocks = set()
|
||||
for c in all_concepts:
|
||||
all_stocks.update(c['stocks'])
|
||||
all_stocks = list(all_stocks)
|
||||
last_concept_update = now
|
||||
logger.info(f"更新完成: {len(all_concepts)} 个概念, {len(all_stocks)} 只股票")
|
||||
|
||||
# 检查是否交易时间
|
||||
if not is_trading_time():
|
||||
wait_sec = get_next_update_time()
|
||||
wait_min = wait_sec // 60
|
||||
logger.info(f"非交易时间,等待 {wait_min} 分钟后重试...")
|
||||
time.sleep(min(wait_sec, 300)) # 最多等5分钟再检查
|
||||
continue
|
||||
|
||||
# 执行更新
|
||||
logger.info(f"\n{'=' * 40}")
|
||||
logger.info(f"更新时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
updated = run_once(all_concepts, all_stocks)
|
||||
|
||||
# 等待下一分钟
|
||||
sleep_sec = 60 - datetime.now().second
|
||||
logger.info(f"完成,等待 {sleep_sec} 秒后继续...")
|
||||
time.sleep(sleep_sec)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n收到退出信号,停止服务...")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
time.sleep(60)
|
||||
|
||||
|
||||
def run_single():
|
||||
"""单次运行(不循环)"""
|
||||
logger.info("单次更新模式")
|
||||
|
||||
leaf_concepts = get_all_concepts()
|
||||
parent_concepts = load_hierarchy_concepts(leaf_concepts)
|
||||
all_concepts = leaf_concepts + parent_concepts
|
||||
|
||||
all_stocks = set()
|
||||
for c in all_concepts:
|
||||
all_stocks.update(c['stocks'])
|
||||
all_stocks = list(all_stocks)
|
||||
|
||||
logger.info(f"概念数: {len(all_concepts)}, 股票数: {len(all_stocks)}")
|
||||
|
||||
updated = run_once(all_concepts, all_stocks)
|
||||
logger.info(f"更新完成: {updated} 条记录")
|
||||
|
||||
|
||||
def show_status():
|
||||
"""显示当前状态"""
|
||||
print("\n" + "=" * 60)
|
||||
print("概念涨跌幅实时更新服务 - 状态")
|
||||
print("=" * 60)
|
||||
|
||||
# 当前时间
|
||||
now = datetime.now()
|
||||
print(f"\n当前时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f"是否交易时间: {'是' if is_trading_time() else '否'}")
|
||||
|
||||
# MySQL数据状态
|
||||
print("\nMySQL数据状态:")
|
||||
try:
|
||||
with MYSQL_ENGINE.connect() as conn:
|
||||
# 今日数据量
|
||||
result = conn.execute(text("""
|
||||
SELECT concept_type, COUNT(*) as cnt
|
||||
FROM concept_daily_stats
|
||||
WHERE trade_date = CURDATE()
|
||||
GROUP BY concept_type
|
||||
"""))
|
||||
rows = list(result)
|
||||
if rows:
|
||||
print(" 今日数据:")
|
||||
for row in rows:
|
||||
print(f" {row[0]}: {row[1]} 条")
|
||||
else:
|
||||
print(" 今日暂无数据")
|
||||
|
||||
# 最新更新时间
|
||||
result = conn.execute(text("""
|
||||
SELECT MAX(updated_at) FROM concept_daily_stats WHERE trade_date = CURDATE()
|
||||
"""))
|
||||
row = result.fetchone()
|
||||
if row and row[0]:
|
||||
print(f" 最后更新: {row[0]}")
|
||||
except Exception as e:
|
||||
print(f" 查询失败: {e}")
|
||||
|
||||
# ClickHouse数据状态
|
||||
print("\nClickHouse数据状态:")
|
||||
try:
|
||||
client = get_ch_client()
|
||||
result = client.execute("""
|
||||
SELECT COUNT(*), MAX(timestamp)
|
||||
FROM stock_minute
|
||||
WHERE toDate(timestamp) = today()
|
||||
""")
|
||||
if result:
|
||||
count, max_ts = result[0]
|
||||
print(f" 今日分钟数据: {count:,} 条")
|
||||
print(f" 最新时间戳: {max_ts}")
|
||||
except Exception as e:
|
||||
print(f" 查询失败: {e}")
|
||||
|
||||
# 今日涨跌幅TOP10
|
||||
print("\n今日涨跌幅 TOP10:")
|
||||
try:
|
||||
with MYSQL_ENGINE.connect() as conn:
|
||||
result = conn.execute(text("""
|
||||
SELECT concept_name, avg_change_pct, stock_count, concept_type
|
||||
FROM concept_daily_stats
|
||||
WHERE trade_date = CURDATE() AND concept_type = 'leaf'
|
||||
ORDER BY avg_change_pct DESC
|
||||
LIMIT 10
|
||||
"""))
|
||||
rows = list(result)
|
||||
if rows:
|
||||
print(f" {'概念':<25} | {'涨跌幅':>8} | {'股票数':>6}")
|
||||
print(" " + "-" * 50)
|
||||
for row in rows:
|
||||
name = row[0][:25] if len(row[0]) > 25 else row[0]
|
||||
print(f" {name:<25} | {row[1]:>7.2f}% | {row[2]:>6}")
|
||||
else:
|
||||
print(" 暂无数据")
|
||||
except Exception as e:
|
||||
print(f" 查询失败: {e}")
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='概念涨跌幅实时更新服务')
|
||||
parser.add_argument('command', nargs='?', default='realtime',
|
||||
choices=['realtime', 'once', 'status'],
|
||||
help='命令: realtime(实时运行), once(单次运行), status(状态查看)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == 'realtime':
|
||||
run_realtime()
|
||||
elif args.command == 'once':
|
||||
run_single()
|
||||
elif args.command == 'status':
|
||||
show_status()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,89 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""创建异动检测所需的数据库表"""
|
||||
|
||||
import sys
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
engine = create_engine('mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock', echo=False)
|
||||
|
||||
# 删除旧表
|
||||
drop_sql1 = 'DROP TABLE IF EXISTS concept_minute_alert'
|
||||
drop_sql2 = 'DROP TABLE IF EXISTS index_minute_snapshot'
|
||||
|
||||
# 创建 concept_minute_alert 表
|
||||
# 支持 Z-Score + SVM 智能检测
|
||||
sql1 = '''
|
||||
CREATE TABLE concept_minute_alert (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
concept_id VARCHAR(32) NOT NULL,
|
||||
concept_name VARCHAR(100) NOT NULL,
|
||||
alert_time DATETIME NOT NULL,
|
||||
alert_type VARCHAR(20) NOT NULL COMMENT 'surge_up=暴涨, surge_down=暴跌, limit_up=涨停增加, rank_jump=排名跃升',
|
||||
trade_date DATE NOT NULL,
|
||||
change_pct DECIMAL(10,4) COMMENT '当前涨跌幅',
|
||||
prev_change_pct DECIMAL(10,4) COMMENT '之前涨跌幅',
|
||||
change_delta DECIMAL(10,4) COMMENT '涨跌幅变化',
|
||||
limit_up_count INT DEFAULT 0 COMMENT '涨停数',
|
||||
prev_limit_up_count INT DEFAULT 0,
|
||||
limit_up_delta INT DEFAULT 0,
|
||||
limit_down_count INT DEFAULT 0 COMMENT '跌停数',
|
||||
rank_position INT COMMENT '当前排名',
|
||||
prev_rank_position INT COMMENT '之前排名',
|
||||
rank_delta INT COMMENT '排名变化(负数表示上升)',
|
||||
index_code VARCHAR(20) DEFAULT '000001.SH',
|
||||
index_price DECIMAL(12,4),
|
||||
index_change_pct DECIMAL(10,4),
|
||||
stock_count INT,
|
||||
concept_type VARCHAR(20) DEFAULT 'leaf',
|
||||
zscore DECIMAL(8,4) COMMENT 'Z-Score值',
|
||||
importance_score DECIMAL(6,4) COMMENT '重要性评分(0-1)',
|
||||
extra_info JSON COMMENT '扩展信息(包含zscore,svm_score等)',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
INDEX idx_trade_date (trade_date),
|
||||
INDEX idx_alert_time (alert_time),
|
||||
INDEX idx_concept_id (concept_id),
|
||||
INDEX idx_alert_type (alert_type),
|
||||
INDEX idx_trade_date_time (trade_date, alert_time),
|
||||
INDEX idx_importance (importance_score)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动记录表(智能版)'
|
||||
'''
|
||||
|
||||
# 创建 index_minute_snapshot 表
|
||||
sql2 = '''
|
||||
CREATE TABLE index_minute_snapshot (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
index_code VARCHAR(20) NOT NULL,
|
||||
trade_date DATE NOT NULL,
|
||||
snapshot_time DATETIME NOT NULL,
|
||||
price DECIMAL(12,4),
|
||||
open_price DECIMAL(12,4),
|
||||
high_price DECIMAL(12,4),
|
||||
low_price DECIMAL(12,4),
|
||||
prev_close DECIMAL(12,4),
|
||||
change_pct DECIMAL(10,4),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE KEY uk_index_time (index_code, snapshot_time),
|
||||
INDEX idx_trade_date (trade_date)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
'''
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('正在重建数据库表...\n')
|
||||
|
||||
with engine.begin() as conn:
|
||||
# 先删除旧表
|
||||
print('删除旧表...')
|
||||
conn.execute(text(drop_sql1))
|
||||
print(' - concept_minute_alert 已删除')
|
||||
conn.execute(text(drop_sql2))
|
||||
print(' - index_minute_snapshot 已删除')
|
||||
|
||||
# 创建新表
|
||||
print('\n创建新表...')
|
||||
conn.execute(text(sql1))
|
||||
print(' ✅ concept_minute_alert 表创建成功')
|
||||
conn.execute(text(sql2))
|
||||
print(' ✅ index_minute_snapshot 表创建成功')
|
||||
|
||||
print('\n✅ 所有表创建完成!')
|
||||
BIN
ml/__pycache__/realtime_detector.cpython-310.pyc
Normal file
BIN
ml/__pycache__/realtime_detector.cpython-310.pyc
Normal file
Binary file not shown.
859
ml/backtest_fast.py
Normal file
859
ml/backtest_fast.py
Normal file
@@ -0,0 +1,859 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
快速融合异动回测脚本
|
||||
|
||||
优化策略:
|
||||
1. 预先构建所有序列(向量化),避免循环内重复切片
|
||||
2. 批量 ML 推理(一次推理所有候选)
|
||||
3. 使用 NumPy 向量化操作替代 Python 循环
|
||||
|
||||
性能对比:
|
||||
- 原版:5分钟/天
|
||||
- 优化版:预计 10-30秒/天
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
MYSQL_ENGINE = create_engine(
|
||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
||||
echo=False
|
||||
)
|
||||
|
||||
FEATURES = ['alpha', 'alpha_delta', 'amt_ratio', 'amt_delta', 'rank_pct', 'limit_up_ratio']
|
||||
|
||||
CONFIG = {
|
||||
'seq_len': 15, # 序列长度(支持跨日后可从 9:30 检测)
|
||||
'min_alpha_abs': 0.3, # 最小 alpha 过滤
|
||||
'cooldown_minutes': 8,
|
||||
'max_alerts_per_minute': 20,
|
||||
'clip_value': 10.0,
|
||||
# === 融合权重:均衡 ===
|
||||
'rule_weight': 0.5,
|
||||
'ml_weight': 0.5,
|
||||
# === 触发阈值 ===
|
||||
'rule_trigger': 65, # 60 -> 65,略提高规则门槛
|
||||
'ml_trigger': 70, # 75 -> 70,略降低 ML 门槛
|
||||
'fusion_trigger': 45,
|
||||
}
|
||||
|
||||
|
||||
# ==================== 规则评分(向量化版)====================
|
||||
|
||||
def get_size_adjusted_thresholds(stock_count: np.ndarray) -> dict:
|
||||
"""
|
||||
根据概念股票数量计算动态阈值
|
||||
|
||||
设计思路:
|
||||
- 小概念(<10 只):波动大是正常的,需要更高阈值
|
||||
- 中概念(10-50 只):标准阈值
|
||||
- 大概念(>50 只):能有明显波动说明是真异动,降低阈值
|
||||
|
||||
返回各指标的调整系数(乘以基准阈值)
|
||||
"""
|
||||
n = len(stock_count)
|
||||
|
||||
# 基于股票数量的调整系数
|
||||
# 小概念:系数 > 1(提高阈值,更难触发)
|
||||
# 大概念:系数 < 1(降低阈值,更容易触发)
|
||||
size_factor = np.ones(n)
|
||||
|
||||
# 微型概念(<5 只):阈值 × 1.8
|
||||
tiny = stock_count < 5
|
||||
size_factor[tiny] = 1.8
|
||||
|
||||
# 小概念(5-10 只):阈值 × 1.4
|
||||
small = (stock_count >= 5) & (stock_count < 10)
|
||||
size_factor[small] = 1.4
|
||||
|
||||
# 中小概念(10-20 只):阈值 × 1.2
|
||||
medium_small = (stock_count >= 10) & (stock_count < 20)
|
||||
size_factor[medium_small] = 1.2
|
||||
|
||||
# 中概念(20-50 只):标准阈值 × 1.0
|
||||
medium = (stock_count >= 20) & (stock_count < 50)
|
||||
size_factor[medium] = 1.0
|
||||
|
||||
# 大概念(50-100 只):阈值 × 0.85
|
||||
large = (stock_count >= 50) & (stock_count < 100)
|
||||
size_factor[large] = 0.85
|
||||
|
||||
# 超大概念(>100 只):阈值 × 0.7
|
||||
xlarge = stock_count >= 100
|
||||
size_factor[xlarge] = 0.7
|
||||
|
||||
return size_factor
|
||||
|
||||
|
||||
def score_rules_batch(df: pd.DataFrame) -> Tuple[np.ndarray, List[List[str]]]:
|
||||
"""
|
||||
批量计算规则得分(向量化)- 考虑概念规模版
|
||||
|
||||
设计原则:
|
||||
- 规则作为辅助信号,不应单独主导决策
|
||||
- 根据概念股票数量动态调整阈值
|
||||
- 大概念异动更有价值,小概念需要更大波动才算异动
|
||||
|
||||
Args:
|
||||
df: DataFrame,包含所有特征列(必须包含 stock_count)
|
||||
Returns:
|
||||
scores: (n,) 规则得分数组
|
||||
triggered_rules: 每行触发的规则列表
|
||||
"""
|
||||
n = len(df)
|
||||
scores = np.zeros(n)
|
||||
triggered = [[] for _ in range(n)]
|
||||
|
||||
alpha = df['alpha'].values
|
||||
alpha_delta = df['alpha_delta'].values
|
||||
amt_ratio = df['amt_ratio'].values
|
||||
amt_delta = df['amt_delta'].values
|
||||
rank_pct = df['rank_pct'].values
|
||||
limit_up_ratio = df['limit_up_ratio'].values
|
||||
stock_count = df['stock_count'].values if 'stock_count' in df.columns else np.full(n, 20)
|
||||
|
||||
alpha_abs = np.abs(alpha)
|
||||
alpha_delta_abs = np.abs(alpha_delta)
|
||||
|
||||
# 获取基于规模的调整系数
|
||||
size_factor = get_size_adjusted_thresholds(stock_count)
|
||||
|
||||
# ========== Alpha 规则(动态阈值)==========
|
||||
# 基准阈值:极强 5%,强 4%,中等 3%
|
||||
# 实际阈值 = 基准 × size_factor
|
||||
|
||||
# 极强信号
|
||||
alpha_extreme_thresh = 5.0 * size_factor
|
||||
mask = alpha_abs >= alpha_extreme_thresh
|
||||
scores[mask] += 20
|
||||
for i in np.where(mask)[0]: triggered[i].append('alpha_extreme')
|
||||
|
||||
# 强信号
|
||||
alpha_strong_thresh = 4.0 * size_factor
|
||||
mask = (alpha_abs >= alpha_strong_thresh) & (alpha_abs < alpha_extreme_thresh)
|
||||
scores[mask] += 15
|
||||
for i in np.where(mask)[0]: triggered[i].append('alpha_strong')
|
||||
|
||||
# 中等信号
|
||||
alpha_medium_thresh = 3.0 * size_factor
|
||||
mask = (alpha_abs >= alpha_medium_thresh) & (alpha_abs < alpha_strong_thresh)
|
||||
scores[mask] += 10
|
||||
for i in np.where(mask)[0]: triggered[i].append('alpha_medium')
|
||||
|
||||
# ========== Alpha 加速度规则(动态阈值)==========
|
||||
delta_strong_thresh = 2.0 * size_factor
|
||||
mask = alpha_delta_abs >= delta_strong_thresh
|
||||
scores[mask] += 15
|
||||
for i in np.where(mask)[0]: triggered[i].append('alpha_delta_strong')
|
||||
|
||||
delta_medium_thresh = 1.5 * size_factor
|
||||
mask = (alpha_delta_abs >= delta_medium_thresh) & (alpha_delta_abs < delta_strong_thresh)
|
||||
scores[mask] += 10
|
||||
for i in np.where(mask)[0]: triggered[i].append('alpha_delta_medium')
|
||||
|
||||
# ========== 成交额规则(不受规模影响,放量就是放量)==========
|
||||
mask = amt_ratio >= 10.0
|
||||
scores[mask] += 20
|
||||
for i in np.where(mask)[0]: triggered[i].append('volume_extreme')
|
||||
|
||||
mask = (amt_ratio >= 6.0) & (amt_ratio < 10.0)
|
||||
scores[mask] += 12
|
||||
for i in np.where(mask)[0]: triggered[i].append('volume_strong')
|
||||
|
||||
# ========== 排名规则 ==========
|
||||
mask = rank_pct >= 0.98
|
||||
scores[mask] += 15
|
||||
for i in np.where(mask)[0]: triggered[i].append('rank_top')
|
||||
|
||||
mask = rank_pct <= 0.02
|
||||
scores[mask] += 15
|
||||
for i in np.where(mask)[0]: triggered[i].append('rank_bottom')
|
||||
|
||||
# ========== 涨停规则(动态阈值)==========
|
||||
# 大概念有涨停更有意义
|
||||
limit_high_thresh = 0.30 * size_factor
|
||||
mask = limit_up_ratio >= limit_high_thresh
|
||||
scores[mask] += 20
|
||||
for i in np.where(mask)[0]: triggered[i].append('limit_up_high')
|
||||
|
||||
limit_medium_thresh = 0.20 * size_factor
|
||||
mask = (limit_up_ratio >= limit_medium_thresh) & (limit_up_ratio < limit_high_thresh)
|
||||
scores[mask] += 12
|
||||
for i in np.where(mask)[0]: triggered[i].append('limit_up_medium')
|
||||
|
||||
# ========== 概念规模加分(大概念异动更有价值)==========
|
||||
# 大概念(50+)额外加分
|
||||
large_concept = stock_count >= 50
|
||||
has_signal = scores > 0 # 至少触发了某个规则
|
||||
mask = large_concept & has_signal
|
||||
scores[mask] += 10
|
||||
for i in np.where(mask)[0]: triggered[i].append('large_concept_bonus')
|
||||
|
||||
# 超大概念(100+)再加分
|
||||
xlarge_concept = stock_count >= 100
|
||||
mask = xlarge_concept & has_signal
|
||||
scores[mask] += 10
|
||||
for i in np.where(mask)[0]: triggered[i].append('xlarge_concept_bonus')
|
||||
|
||||
# ========== 组合规则(动态阈值)==========
|
||||
combo_alpha_thresh = 3.0 * size_factor
|
||||
|
||||
# Alpha + 放量 + 排名(三重验证)
|
||||
mask = (alpha_abs >= combo_alpha_thresh) & (amt_ratio >= 5.0) & ((rank_pct >= 0.95) | (rank_pct <= 0.05))
|
||||
scores[mask] += 20
|
||||
for i in np.where(mask)[0]: triggered[i].append('triple_signal')
|
||||
|
||||
# Alpha + 涨停(强组合)
|
||||
mask = (alpha_abs >= combo_alpha_thresh) & (limit_up_ratio >= 0.15 * size_factor)
|
||||
scores[mask] += 15
|
||||
for i in np.where(mask)[0]: triggered[i].append('alpha_with_limit')
|
||||
|
||||
# ========== 小概念惩罚(过滤噪音)==========
|
||||
# 微型概念(<5 只)如果只有单一信号,减分
|
||||
tiny_concept = stock_count < 5
|
||||
single_rule = np.array([len(t) <= 1 for t in triggered])
|
||||
mask = tiny_concept & single_rule & (scores > 0)
|
||||
scores[mask] *= 0.5 # 减半
|
||||
for i in np.where(mask)[0]: triggered[i].append('tiny_concept_penalty')
|
||||
|
||||
scores = np.clip(scores, 0, 100)
|
||||
return scores, triggered
|
||||
|
||||
|
||||
# ==================== ML 评分器 ====================
|
||||
|
||||
class FastMLScorer:
|
||||
"""快速 ML 评分器"""
|
||||
|
||||
def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
|
||||
self.checkpoint_dir = Path(checkpoint_dir)
|
||||
|
||||
if device == 'auto':
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
elif device == 'cuda' and not torch.cuda.is_available():
|
||||
print("警告: CUDA 不可用,使用 CPU")
|
||||
self.device = torch.device('cpu')
|
||||
else:
|
||||
self.device = torch.device(device)
|
||||
|
||||
self.model = None
|
||||
self.thresholds = None
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
model_path = self.checkpoint_dir / 'best_model.pt'
|
||||
thresholds_path = self.checkpoint_dir / 'thresholds.json'
|
||||
config_path = self.checkpoint_dir / 'config.json'
|
||||
|
||||
if not model_path.exists():
|
||||
print(f"警告: 模型不存在 {model_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
from model import LSTMAutoencoder
|
||||
|
||||
config = {}
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config = json.load(f).get('model', {})
|
||||
|
||||
# 处理旧配置键名
|
||||
if 'd_model' in config:
|
||||
config['hidden_dim'] = config.pop('d_model') // 2
|
||||
for key in ['num_encoder_layers', 'num_decoder_layers', 'nhead', 'dim_feedforward', 'max_seq_len', 'use_instance_norm']:
|
||||
config.pop(key, None)
|
||||
if 'num_layers' not in config:
|
||||
config['num_layers'] = 1
|
||||
|
||||
checkpoint = torch.load(model_path, map_location='cpu')
|
||||
self.model = LSTMAutoencoder(**config)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
if thresholds_path.exists():
|
||||
with open(thresholds_path) as f:
|
||||
self.thresholds = json.load(f)
|
||||
|
||||
print(f"ML模型加载成功 (设备: {self.device})")
|
||||
except Exception as e:
|
||||
print(f"ML模型加载失败: {e}")
|
||||
self.model = None
|
||||
|
||||
def is_ready(self):
|
||||
return self.model is not None
|
||||
|
||||
@torch.no_grad()
|
||||
def score_batch(self, sequences: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
批量计算 ML 得分
|
||||
|
||||
Args:
|
||||
sequences: (batch, seq_len, n_features)
|
||||
Returns:
|
||||
scores: (batch,) 0-100 分数
|
||||
"""
|
||||
if not self.is_ready() or len(sequences) == 0:
|
||||
return np.zeros(len(sequences))
|
||||
|
||||
x = torch.FloatTensor(sequences).to(self.device)
|
||||
output, _ = self.model(x)
|
||||
mse = ((output - x) ** 2).mean(dim=-1)
|
||||
errors = mse[:, -1].cpu().numpy()
|
||||
|
||||
p95 = self.thresholds.get('p95', 0.1) if self.thresholds else 0.1
|
||||
scores = np.clip(errors / p95 * 50, 0, 100)
|
||||
return scores
|
||||
|
||||
|
||||
# ==================== 快速回测 ====================
|
||||
|
||||
def build_sequences_fast(
|
||||
df: pd.DataFrame,
|
||||
seq_len: int = 30,
|
||||
prev_df: pd.DataFrame = None
|
||||
) -> Tuple[np.ndarray, pd.DataFrame]:
|
||||
"""
|
||||
快速构建所有有效序列
|
||||
|
||||
支持跨日序列:用前一天收盘数据 + 当天开盘数据拼接,实现 9:30 就能检测
|
||||
|
||||
Args:
|
||||
df: 当天数据
|
||||
seq_len: 序列长度
|
||||
prev_df: 前一天数据(可选,用于构建开盘时的序列)
|
||||
|
||||
返回:
|
||||
sequences: (n_valid, seq_len, n_features) 所有有效序列
|
||||
info_df: 对应的元信息 DataFrame
|
||||
"""
|
||||
# 确保按概念和时间排序
|
||||
df = df.sort_values(['concept_id', 'timestamp']).reset_index(drop=True)
|
||||
|
||||
# 如果有前一天数据,按概念构建尾部缓存(取每个概念最后 seq_len-1 条)
|
||||
prev_cache = {}
|
||||
if prev_df is not None and len(prev_df) > 0:
|
||||
prev_df = prev_df.sort_values(['concept_id', 'timestamp'])
|
||||
for concept_id, gdf in prev_df.groupby('concept_id'):
|
||||
tail_data = gdf.tail(seq_len - 1)
|
||||
if len(tail_data) > 0:
|
||||
feat_matrix = tail_data[FEATURES].values
|
||||
feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])
|
||||
prev_cache[concept_id] = feat_matrix
|
||||
|
||||
# 按概念分组
|
||||
groups = df.groupby('concept_id')
|
||||
|
||||
sequences = []
|
||||
infos = []
|
||||
|
||||
for concept_id, gdf in groups:
|
||||
gdf = gdf.reset_index(drop=True)
|
||||
|
||||
# 获取特征矩阵
|
||||
feat_matrix = gdf[FEATURES].values
|
||||
feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])
|
||||
|
||||
# 如果有前一天缓存,拼接到当天数据前面
|
||||
if concept_id in prev_cache:
|
||||
prev_data = prev_cache[concept_id]
|
||||
combined_matrix = np.vstack([prev_data, feat_matrix])
|
||||
# 计算偏移量:前一天数据的长度
|
||||
offset = len(prev_data)
|
||||
else:
|
||||
combined_matrix = feat_matrix
|
||||
offset = 0
|
||||
|
||||
# 滑动窗口构建序列
|
||||
n_total = len(combined_matrix)
|
||||
if n_total < seq_len:
|
||||
continue
|
||||
|
||||
for i in range(n_total - seq_len + 1):
|
||||
seq = combined_matrix[i:i + seq_len]
|
||||
|
||||
# 计算对应当天数据的索引
|
||||
# 序列最后一个点的位置 = i + seq_len - 1
|
||||
# 对应当天数据的索引 = (i + seq_len - 1) - offset
|
||||
today_idx = i + seq_len - 1 - offset
|
||||
|
||||
# 只要序列的最后一个点是当天的数据,就记录
|
||||
if today_idx < 0 or today_idx >= len(gdf):
|
||||
continue
|
||||
|
||||
sequences.append(seq)
|
||||
|
||||
# 记录最后一个时间步的信息(当天的)
|
||||
row = gdf.iloc[today_idx]
|
||||
infos.append({
|
||||
'concept_id': concept_id,
|
||||
'timestamp': row['timestamp'],
|
||||
'alpha': row['alpha'],
|
||||
'alpha_delta': row.get('alpha_delta', 0),
|
||||
'amt_ratio': row.get('amt_ratio', 1),
|
||||
'amt_delta': row.get('amt_delta', 0),
|
||||
'rank_pct': row.get('rank_pct', 0.5),
|
||||
'limit_up_ratio': row.get('limit_up_ratio', 0),
|
||||
'stock_count': row.get('stock_count', 0),
|
||||
'total_amt': row.get('total_amt', 0),
|
||||
})
|
||||
|
||||
if not sequences:
|
||||
return np.array([]), pd.DataFrame()
|
||||
|
||||
return np.array(sequences), pd.DataFrame(infos)
|
||||
|
||||
|
||||
def backtest_single_day_fast(
|
||||
ml_scorer: FastMLScorer,
|
||||
df: pd.DataFrame,
|
||||
date: str,
|
||||
config: Dict,
|
||||
prev_df: pd.DataFrame = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
快速回测单天(向量化版本)
|
||||
|
||||
Args:
|
||||
ml_scorer: ML 评分器
|
||||
df: 当天数据
|
||||
date: 日期
|
||||
config: 配置
|
||||
prev_df: 前一天数据(用于 9:30 开始检测)
|
||||
"""
|
||||
seq_len = config.get('seq_len', 30)
|
||||
|
||||
# 1. 构建所有序列(支持跨日)
|
||||
sequences, info_df = build_sequences_fast(df, seq_len, prev_df)
|
||||
|
||||
if len(sequences) == 0:
|
||||
return []
|
||||
|
||||
# 2. 过滤小波动
|
||||
alpha_abs = np.abs(info_df['alpha'].values)
|
||||
valid_mask = alpha_abs >= config['min_alpha_abs']
|
||||
|
||||
sequences = sequences[valid_mask]
|
||||
info_df = info_df[valid_mask].reset_index(drop=True)
|
||||
|
||||
if len(sequences) == 0:
|
||||
return []
|
||||
|
||||
# 3. 批量规则评分
|
||||
rule_scores, triggered_rules = score_rules_batch(info_df)
|
||||
|
||||
# 4. 批量 ML 评分(分批处理避免显存溢出)
|
||||
batch_size = 2048
|
||||
ml_scores = []
|
||||
for i in range(0, len(sequences), batch_size):
|
||||
batch_seq = sequences[i:i+batch_size]
|
||||
batch_scores = ml_scorer.score_batch(batch_seq)
|
||||
ml_scores.append(batch_scores)
|
||||
ml_scores = np.concatenate(ml_scores) if ml_scores else np.zeros(len(sequences))
|
||||
|
||||
# 5. 融合得分
|
||||
w1, w2 = config['rule_weight'], config['ml_weight']
|
||||
final_scores = w1 * rule_scores + w2 * ml_scores
|
||||
|
||||
# 6. 判断异动
|
||||
is_anomaly = (
|
||||
(rule_scores >= config['rule_trigger']) |
|
||||
(ml_scores >= config['ml_trigger']) |
|
||||
(final_scores >= config['fusion_trigger'])
|
||||
)
|
||||
|
||||
# 7. 应用冷却期(按概念+时间排序后处理)
|
||||
info_df['rule_score'] = rule_scores
|
||||
info_df['ml_score'] = ml_scores
|
||||
info_df['final_score'] = final_scores
|
||||
info_df['is_anomaly'] = is_anomaly
|
||||
info_df['triggered_rules'] = triggered_rules
|
||||
|
||||
# 只保留异动
|
||||
anomaly_df = info_df[info_df['is_anomaly']].copy()
|
||||
|
||||
if len(anomaly_df) == 0:
|
||||
return []
|
||||
|
||||
# 应用冷却期
|
||||
anomaly_df = anomaly_df.sort_values(['concept_id', 'timestamp'])
|
||||
cooldown = {}
|
||||
keep_mask = []
|
||||
|
||||
for _, row in anomaly_df.iterrows():
|
||||
cid = row['concept_id']
|
||||
ts = row['timestamp']
|
||||
|
||||
if cid in cooldown:
|
||||
try:
|
||||
diff = (ts - cooldown[cid]).total_seconds() / 60
|
||||
except:
|
||||
diff = config['cooldown_minutes'] + 1
|
||||
|
||||
if diff < config['cooldown_minutes']:
|
||||
keep_mask.append(False)
|
||||
continue
|
||||
|
||||
cooldown[cid] = ts
|
||||
keep_mask.append(True)
|
||||
|
||||
anomaly_df = anomaly_df[keep_mask]
|
||||
|
||||
# 8. 按时间分组,每分钟最多 max_alerts_per_minute 个
|
||||
alerts = []
|
||||
for ts, group in anomaly_df.groupby('timestamp'):
|
||||
group = group.nlargest(config['max_alerts_per_minute'], 'final_score')
|
||||
|
||||
for _, row in group.iterrows():
|
||||
alpha = row['alpha']
|
||||
if alpha >= 1.5:
|
||||
atype = 'surge_up'
|
||||
elif alpha <= -1.5:
|
||||
atype = 'surge_down'
|
||||
elif row['amt_ratio'] >= 3.0:
|
||||
atype = 'volume_spike'
|
||||
else:
|
||||
atype = 'unknown'
|
||||
|
||||
rule_score = row['rule_score']
|
||||
ml_score = row['ml_score']
|
||||
final_score = row['final_score']
|
||||
|
||||
if rule_score >= config['rule_trigger']:
|
||||
trigger = f'规则强信号({rule_score:.0f}分)'
|
||||
elif ml_score >= config['ml_trigger']:
|
||||
trigger = f'ML强信号({ml_score:.0f}分)'
|
||||
else:
|
||||
trigger = f'融合触发({final_score:.0f}分)'
|
||||
|
||||
alerts.append({
|
||||
'concept_id': row['concept_id'],
|
||||
'alert_time': row['timestamp'],
|
||||
'trade_date': date,
|
||||
'alert_type': atype,
|
||||
'final_score': final_score,
|
||||
'rule_score': rule_score,
|
||||
'ml_score': ml_score,
|
||||
'trigger_reason': trigger,
|
||||
'triggered_rules': row['triggered_rules'],
|
||||
'alpha': alpha,
|
||||
'alpha_delta': row['alpha_delta'],
|
||||
'amt_ratio': row['amt_ratio'],
|
||||
'amt_delta': row['amt_delta'],
|
||||
'rank_pct': row['rank_pct'],
|
||||
'limit_up_ratio': row['limit_up_ratio'],
|
||||
'stock_count': row['stock_count'],
|
||||
'total_amt': row['total_amt'],
|
||||
})
|
||||
|
||||
return alerts
|
||||
|
||||
|
||||
# ==================== 数据加载 ====================
|
||||
|
||||
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
|
||||
file_path = Path(data_dir) / f"features_{date}.parquet"
|
||||
if not file_path.exists():
|
||||
return None
|
||||
return pd.read_parquet(file_path)
|
||||
|
||||
|
||||
def get_available_dates(data_dir: str, start: str, end: str) -> List[str]:
|
||||
data_path = Path(data_dir)
|
||||
dates = []
|
||||
for f in sorted(data_path.glob("features_*.parquet")):
|
||||
d = f.stem.replace('features_', '')
|
||||
if start <= d <= end:
|
||||
dates.append(d)
|
||||
return dates
|
||||
|
||||
|
||||
def get_prev_trading_day(data_dir: str, date: str) -> Optional[str]:
|
||||
"""获取给定日期之前最近的有数据的交易日"""
|
||||
data_path = Path(data_dir)
|
||||
all_dates = sorted([f.stem.replace('features_', '') for f in data_path.glob("features_*.parquet")])
|
||||
|
||||
for i, d in enumerate(all_dates):
|
||||
if d == date and i > 0:
|
||||
return all_dates[i - 1]
|
||||
return None
|
||||
|
||||
|
||||
def export_to_csv(alerts: List[Dict], path: str):
|
||||
if alerts:
|
||||
pd.DataFrame(alerts).to_csv(path, index=False, encoding='utf-8-sig')
|
||||
print(f"已导出: {path}")
|
||||
|
||||
|
||||
# ==================== 数据库写入 ====================
|
||||
|
||||
def init_db_table():
|
||||
"""
|
||||
初始化数据库表(如果不存在则创建)
|
||||
|
||||
表结构说明:
|
||||
- concept_id: 概念ID
|
||||
- alert_time: 异动时间(精确到分钟)
|
||||
- trade_date: 交易日期
|
||||
- alert_type: 异动类型(surge_up/surge_down/volume_spike/unknown)
|
||||
- final_score: 最终得分(0-100)
|
||||
- rule_score: 规则得分(0-100)
|
||||
- ml_score: ML得分(0-100)
|
||||
- trigger_reason: 触发原因
|
||||
- alpha: 超额收益率
|
||||
- alpha_delta: alpha变化速度
|
||||
- amt_ratio: 成交额放大倍数
|
||||
- rank_pct: 排名百分位
|
||||
- stock_count: 概念股票数量
|
||||
- triggered_rules: 触发的规则列表(JSON)
|
||||
"""
|
||||
create_sql = text("""
|
||||
CREATE TABLE IF NOT EXISTS concept_anomaly_hybrid (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
concept_id VARCHAR(64) NOT NULL,
|
||||
alert_time DATETIME NOT NULL,
|
||||
trade_date DATE NOT NULL,
|
||||
alert_type VARCHAR(32) NOT NULL,
|
||||
final_score FLOAT NOT NULL,
|
||||
rule_score FLOAT NOT NULL,
|
||||
ml_score FLOAT NOT NULL,
|
||||
trigger_reason VARCHAR(64),
|
||||
alpha FLOAT,
|
||||
alpha_delta FLOAT,
|
||||
amt_ratio FLOAT,
|
||||
amt_delta FLOAT,
|
||||
rank_pct FLOAT,
|
||||
limit_up_ratio FLOAT,
|
||||
stock_count INT,
|
||||
total_amt FLOAT,
|
||||
triggered_rules JSON,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
|
||||
INDEX idx_trade_date (trade_date),
|
||||
INDEX idx_concept_id (concept_id),
|
||||
INDEX idx_final_score (final_score),
|
||||
INDEX idx_alert_type (alert_type)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动检测结果(融合版)'
|
||||
""")
|
||||
|
||||
with MYSQL_ENGINE.begin() as conn:
|
||||
conn.execute(create_sql)
|
||||
print("数据库表已就绪: concept_anomaly_hybrid")
|
||||
|
||||
|
||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
|
||||
"""
|
||||
保存异动到 MySQL
|
||||
|
||||
Args:
|
||||
alerts: 异动列表
|
||||
dry_run: 是否只模拟,不实际写入
|
||||
|
||||
Returns:
|
||||
实际保存的记录数
|
||||
"""
|
||||
if not alerts:
|
||||
return 0
|
||||
|
||||
if dry_run:
|
||||
print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
|
||||
return len(alerts)
|
||||
|
||||
saved = 0
|
||||
skipped = 0
|
||||
|
||||
with MYSQL_ENGINE.begin() as conn:
|
||||
for alert in alerts:
|
||||
try:
|
||||
# 检查是否已存在(使用 INSERT IGNORE 更高效)
|
||||
insert_sql = text("""
|
||||
INSERT IGNORE INTO concept_anomaly_hybrid
|
||||
(concept_id, alert_time, trade_date, alert_type,
|
||||
final_score, rule_score, ml_score, trigger_reason,
|
||||
alpha, alpha_delta, amt_ratio, amt_delta,
|
||||
rank_pct, limit_up_ratio, stock_count, total_amt,
|
||||
triggered_rules)
|
||||
VALUES
|
||||
(:concept_id, :alert_time, :trade_date, :alert_type,
|
||||
:final_score, :rule_score, :ml_score, :trigger_reason,
|
||||
:alpha, :alpha_delta, :amt_ratio, :amt_delta,
|
||||
:rank_pct, :limit_up_ratio, :stock_count, :total_amt,
|
||||
:triggered_rules)
|
||||
""")
|
||||
|
||||
result = conn.execute(insert_sql, {
|
||||
'concept_id': alert['concept_id'],
|
||||
'alert_time': alert['alert_time'],
|
||||
'trade_date': alert['trade_date'],
|
||||
'alert_type': alert['alert_type'],
|
||||
'final_score': alert['final_score'],
|
||||
'rule_score': alert['rule_score'],
|
||||
'ml_score': alert['ml_score'],
|
||||
'trigger_reason': alert['trigger_reason'],
|
||||
'alpha': alert.get('alpha', 0),
|
||||
'alpha_delta': alert.get('alpha_delta', 0),
|
||||
'amt_ratio': alert.get('amt_ratio', 1),
|
||||
'amt_delta': alert.get('amt_delta', 0),
|
||||
'rank_pct': alert.get('rank_pct', 0.5),
|
||||
'limit_up_ratio': alert.get('limit_up_ratio', 0),
|
||||
'stock_count': alert.get('stock_count', 0),
|
||||
'total_amt': alert.get('total_amt', 0),
|
||||
'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False),
|
||||
})
|
||||
|
||||
if result.rowcount > 0:
|
||||
saved += 1
|
||||
else:
|
||||
skipped += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" 保存失败: {alert['concept_id']} @ {alert['alert_time']} - {e}")
|
||||
|
||||
if skipped > 0:
|
||||
print(f" 跳过 {skipped} 条重复记录")
|
||||
|
||||
return saved
|
||||
|
||||
|
||||
def clear_alerts_by_date(trade_date: str) -> int:
|
||||
"""清除指定日期的异动记录(用于重新回测)"""
|
||||
with MYSQL_ENGINE.begin() as conn:
|
||||
result = conn.execute(
|
||||
text("DELETE FROM concept_anomaly_hybrid WHERE trade_date = :trade_date"),
|
||||
{'trade_date': trade_date}
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def analyze_alerts(alerts: List[Dict]):
|
||||
if not alerts:
|
||||
print("无异动")
|
||||
return
|
||||
|
||||
df = pd.DataFrame(alerts)
|
||||
print(f"\n总异动: {len(alerts)}")
|
||||
print(f"\n类型分布:\n{df['alert_type'].value_counts()}")
|
||||
print(f"\n得分统计:")
|
||||
print(f" 最终: {df['final_score'].mean():.1f} (max: {df['final_score'].max():.1f})")
|
||||
print(f" 规则: {df['rule_score'].mean():.1f} (max: {df['rule_score'].max():.1f})")
|
||||
print(f" ML: {df['ml_score'].mean():.1f} (max: {df['ml_score'].max():.1f})")
|
||||
|
||||
trigger_type = df['trigger_reason'].apply(
|
||||
lambda x: '规则' if '规则' in x else ('ML' if 'ML' in x else '融合')
|
||||
)
|
||||
print(f"\n触发来源:\n{trigger_type.value_counts()}")
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='快速融合异动回测')
|
||||
parser.add_argument('--data_dir', default='ml/data')
|
||||
parser.add_argument('--checkpoint_dir', default='ml/checkpoints')
|
||||
parser.add_argument('--start', required=True)
|
||||
parser.add_argument('--end', default=None)
|
||||
parser.add_argument('--dry-run', action='store_true', help='模拟运行,不写入数据库')
|
||||
parser.add_argument('--export-csv', default=None, help='导出 CSV 文件路径')
|
||||
parser.add_argument('--save-db', action='store_true', help='保存结果到数据库')
|
||||
parser.add_argument('--clear-first', action='store_true', help='写入前先清除该日期的旧数据')
|
||||
parser.add_argument('--device', default='auto')
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.end is None:
|
||||
args.end = args.start
|
||||
|
||||
print("=" * 60)
|
||||
print("快速融合异动回测")
|
||||
print("=" * 60)
|
||||
print(f"日期: {args.start} ~ {args.end}")
|
||||
print(f"设备: {args.device}")
|
||||
print(f"保存数据库: {args.save_db}")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化数据库表(如果需要保存)
|
||||
if args.save_db and not args.dry_run:
|
||||
init_db_table()
|
||||
|
||||
# 初始化 ML 评分器
|
||||
ml_scorer = FastMLScorer(args.checkpoint_dir, args.device)
|
||||
|
||||
# 获取日期
|
||||
dates = get_available_dates(args.data_dir, args.start, args.end)
|
||||
if not dates:
|
||||
print("无数据")
|
||||
return
|
||||
|
||||
print(f"找到 {len(dates)} 天数据\n")
|
||||
|
||||
# 回测(支持跨日序列)
|
||||
all_alerts = []
|
||||
total_saved = 0
|
||||
prev_df = None # 缓存前一天数据
|
||||
|
||||
for i, date in enumerate(tqdm(dates, desc="回测")):
|
||||
df = load_daily_features(args.data_dir, date)
|
||||
if df is None or df.empty:
|
||||
prev_df = None # 当天无数据,清空缓存
|
||||
continue
|
||||
|
||||
# 第一天需要加载前一天数据(如果存在)
|
||||
if i == 0 and prev_df is None:
|
||||
prev_date = get_prev_trading_day(args.data_dir, date)
|
||||
if prev_date:
|
||||
prev_df = load_daily_features(args.data_dir, prev_date)
|
||||
if prev_df is not None:
|
||||
tqdm.write(f" 加载前一天数据: {prev_date}")
|
||||
|
||||
alerts = backtest_single_day_fast(ml_scorer, df, date, CONFIG, prev_df)
|
||||
all_alerts.extend(alerts)
|
||||
|
||||
# 保存到数据库
|
||||
if args.save_db and alerts:
|
||||
if args.clear_first and not args.dry_run:
|
||||
cleared = clear_alerts_by_date(date)
|
||||
if cleared > 0:
|
||||
tqdm.write(f" 清除 {date} 旧数据: {cleared} 条")
|
||||
|
||||
saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
|
||||
total_saved += saved
|
||||
tqdm.write(f" {date}: {len(alerts)} 个异动, 保存 {saved} 条")
|
||||
elif alerts:
|
||||
tqdm.write(f" {date}: {len(alerts)} 个异动")
|
||||
|
||||
# 当天数据成为下一天的 prev_df
|
||||
prev_df = df
|
||||
|
||||
# 导出 CSV
|
||||
if args.export_csv:
|
||||
export_to_csv(all_alerts, args.export_csv)
|
||||
|
||||
# 分析
|
||||
analyze_alerts(all_alerts)
|
||||
|
||||
print(f"\n总计: {len(all_alerts)} 个异动")
|
||||
if args.save_db:
|
||||
print(f"已保存到数据库: {total_saved} 条")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -93,12 +93,12 @@ def backtest_single_day_hybrid(
|
||||
seq_len: int = 30
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
使用融合检测器回测单天数据
|
||||
使用融合检测器回测单天数据(批量优化版)
|
||||
"""
|
||||
alerts = []
|
||||
|
||||
# 按概念分组
|
||||
grouped = df.groupby('concept_id', sort=False)
|
||||
# 按概念分组,预先构建字典
|
||||
grouped_dict = {cid: cdf for cid, cdf in df.groupby('concept_id', sort=False)}
|
||||
|
||||
# 冷却记录
|
||||
cooldown = {}
|
||||
@@ -114,27 +114,46 @@ def backtest_single_day_hybrid(
|
||||
current_time = all_timestamps[t_idx]
|
||||
window_start_time = all_timestamps[t_idx - seq_len + 1]
|
||||
|
||||
minute_alerts = []
|
||||
# 批量收集该时刻所有候选概念
|
||||
batch_sequences = []
|
||||
batch_features = []
|
||||
batch_infos = []
|
||||
|
||||
for concept_id, concept_df in grouped_dict.items():
|
||||
# 检查冷却(提前过滤)
|
||||
if concept_id in cooldown:
|
||||
last_alert = cooldown[concept_id]
|
||||
if isinstance(current_time, datetime):
|
||||
time_diff = (current_time - last_alert).total_seconds() / 60
|
||||
else:
|
||||
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1
|
||||
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
|
||||
continue
|
||||
|
||||
for concept_id, concept_df in grouped:
|
||||
# 获取时间窗口内的数据
|
||||
mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
|
||||
window_df = concept_df[mask].sort_values('timestamp')
|
||||
window_df = concept_df.loc[mask]
|
||||
|
||||
if len(window_df) < seq_len:
|
||||
continue
|
||||
|
||||
window_df = window_df.tail(seq_len)
|
||||
window_df = window_df.sort_values('timestamp').tail(seq_len)
|
||||
|
||||
# 提取特征序列(给 ML 模型)
|
||||
# 当前时刻特征
|
||||
current_row = window_df.iloc[-1]
|
||||
alpha = current_row.get('alpha', 0)
|
||||
|
||||
# 过滤微小波动(提前过滤)
|
||||
if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']:
|
||||
continue
|
||||
|
||||
# 提取特征序列
|
||||
sequence = window_df[FEATURES].values
|
||||
sequence = np.nan_to_num(sequence, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
sequence = np.clip(sequence, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])
|
||||
|
||||
# 当前时刻特征(给规则系统)
|
||||
current_row = window_df.iloc[-1]
|
||||
current_features = {
|
||||
'alpha': current_row.get('alpha', 0),
|
||||
'alpha': alpha,
|
||||
'alpha_delta': current_row.get('alpha_delta', 0),
|
||||
'amt_ratio': current_row.get('amt_ratio', 1),
|
||||
'amt_delta': current_row.get('amt_delta', 0),
|
||||
@@ -142,41 +161,79 @@ def backtest_single_day_hybrid(
|
||||
'limit_up_ratio': current_row.get('limit_up_ratio', 0),
|
||||
}
|
||||
|
||||
# 过滤微小波动
|
||||
if abs(current_features['alpha']) < BACKTEST_CONFIG['min_alpha_abs']:
|
||||
batch_sequences.append(sequence)
|
||||
batch_features.append(current_features)
|
||||
batch_infos.append({
|
||||
'concept_id': concept_id,
|
||||
'stock_count': current_row.get('stock_count', 0),
|
||||
'total_amt': current_row.get('total_amt', 0),
|
||||
})
|
||||
|
||||
if not batch_sequences:
|
||||
continue
|
||||
|
||||
# 检查冷却
|
||||
if concept_id in cooldown:
|
||||
last_alert = cooldown[concept_id]
|
||||
if isinstance(current_time, datetime):
|
||||
time_diff = (current_time - last_alert).total_seconds() / 60
|
||||
# 批量 ML 推理
|
||||
sequences_array = np.array(batch_sequences)
|
||||
ml_scores = detector.ml_scorer.score(sequences_array) if detector.ml_scorer.is_ready() else [0.0] * len(batch_sequences)
|
||||
if isinstance(ml_scores, float):
|
||||
ml_scores = [ml_scores]
|
||||
|
||||
# 批量规则评分 + 融合
|
||||
minute_alerts = []
|
||||
for i, (features, info) in enumerate(zip(batch_features, batch_infos)):
|
||||
concept_id = info['concept_id']
|
||||
|
||||
# 规则评分
|
||||
rule_score, rule_details = detector.rule_scorer.score(features)
|
||||
|
||||
# ML 评分
|
||||
ml_score = ml_scores[i] if i < len(ml_scores) else 0.0
|
||||
|
||||
# 融合
|
||||
w1 = detector.config['rule_weight']
|
||||
w2 = detector.config['ml_weight']
|
||||
final_score = w1 * rule_score + w2 * ml_score
|
||||
|
||||
# 判断是否异动
|
||||
is_anomaly = False
|
||||
trigger_reason = ''
|
||||
|
||||
if rule_score >= detector.config['rule_trigger']:
|
||||
is_anomaly = True
|
||||
trigger_reason = f'规则强信号({rule_score:.0f}分)'
|
||||
elif ml_score >= detector.config['ml_trigger']:
|
||||
is_anomaly = True
|
||||
trigger_reason = f'ML强信号({ml_score:.0f}分)'
|
||||
elif final_score >= detector.config['fusion_trigger']:
|
||||
is_anomaly = True
|
||||
trigger_reason = f'融合触发({final_score:.0f}分)'
|
||||
|
||||
if not is_anomaly:
|
||||
continue
|
||||
|
||||
# 异动类型
|
||||
alpha = features.get('alpha', 0)
|
||||
if alpha >= 1.5:
|
||||
anomaly_type = 'surge_up'
|
||||
elif alpha <= -1.5:
|
||||
anomaly_type = 'surge_down'
|
||||
elif features.get('amt_ratio', 1) >= 3.0:
|
||||
anomaly_type = 'volume_spike'
|
||||
else:
|
||||
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1
|
||||
anomaly_type = 'unknown'
|
||||
|
||||
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
|
||||
continue
|
||||
|
||||
# 融合检测
|
||||
result = detector.detect(current_features, sequence)
|
||||
|
||||
if not result.is_anomaly:
|
||||
continue
|
||||
|
||||
# 记录异动
|
||||
alert = {
|
||||
'concept_id': concept_id,
|
||||
'alert_time': current_time,
|
||||
'trade_date': date,
|
||||
'alert_type': result.anomaly_type,
|
||||
'final_score': result.final_score,
|
||||
'rule_score': result.rule_score,
|
||||
'ml_score': result.ml_score,
|
||||
'trigger_reason': result.trigger_reason,
|
||||
'triggered_rules': list(result.rule_details.keys()),
|
||||
**current_features,
|
||||
'stock_count': current_row.get('stock_count', 0),
|
||||
'total_amt': current_row.get('total_amt', 0),
|
||||
'alert_type': anomaly_type,
|
||||
'final_score': final_score,
|
||||
'rule_score': rule_score,
|
||||
'ml_score': ml_score,
|
||||
'trigger_reason': trigger_reason,
|
||||
'triggered_rules': list(rule_details.keys()),
|
||||
**features,
|
||||
**info,
|
||||
}
|
||||
|
||||
minute_alerts.append(alert)
|
||||
@@ -341,6 +398,8 @@ def main():
|
||||
help='规则权重 (0-1)')
|
||||
parser.add_argument('--ml-weight', type=float, default=0.4,
|
||||
help='ML权重 (0-1)')
|
||||
parser.add_argument('--device', type=str, default='cuda',
|
||||
help='设备 (cuda/cpu),默认 cuda')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -355,15 +414,19 @@ def main():
|
||||
print(f"模型目录: {args.checkpoint_dir}")
|
||||
print(f"规则权重: {args.rule_weight}")
|
||||
print(f"ML权重: {args.ml_weight}")
|
||||
print(f"设备: {args.device}")
|
||||
print(f"Dry Run: {args.dry_run}")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化融合检测器
|
||||
# 初始化融合检测器(使用 GPU)
|
||||
config = {
|
||||
'rule_weight': args.rule_weight,
|
||||
'ml_weight': args.ml_weight,
|
||||
}
|
||||
detector = create_detector(args.checkpoint_dir, config)
|
||||
|
||||
# 修改 detector.py 中 MLScorer 的设备
|
||||
from detector import HybridAnomalyDetector
|
||||
detector = HybridAnomalyDetector(config, args.checkpoint_dir, device=args.device)
|
||||
|
||||
# 获取可用日期
|
||||
dates = get_available_dates(args.data_dir, args.start, args.end)
|
||||
|
||||
@@ -243,9 +243,12 @@ class MLScorer:
|
||||
):
|
||||
self.checkpoint_dir = Path(checkpoint_dir)
|
||||
|
||||
# 设备
|
||||
# 设备检测
|
||||
if device == 'auto':
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
elif device == 'cuda' and not torch.cuda.is_available():
|
||||
print("警告: CUDA 不可用,使用 CPU")
|
||||
self.device = torch.device('cpu')
|
||||
else:
|
||||
self.device = torch.device(device)
|
||||
|
||||
@@ -276,8 +279,8 @@ class MLScorer:
|
||||
with open(config_path, 'r') as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
# 加载模型
|
||||
checkpoint = torch.load(model_path, map_location=self.device)
|
||||
# 先用 CPU 加载模型(避免 CUDA 不可用问题),再移动到目标设备
|
||||
checkpoint = torch.load(model_path, map_location='cpu')
|
||||
|
||||
model_config = self.config.get('model', {}) if self.config else {}
|
||||
self.model = create_model(model_config)
|
||||
@@ -294,6 +297,8 @@ class MLScorer:
|
||||
|
||||
except Exception as e:
|
||||
print(f"警告: 模型加载失败 - {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
self.model = None
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
@@ -551,7 +556,8 @@ if __name__ == "__main__":
|
||||
},
|
||||
]
|
||||
|
||||
print("\n测试结果:")
|
||||
print("\n" + "-" * 60)
|
||||
print("测试1: 只用规则(无序列数据)")
|
||||
print("-" * 60)
|
||||
|
||||
for case in test_cases:
|
||||
@@ -567,5 +573,63 @@ if __name__ == "__main__":
|
||||
print(f" 异动类型: {result.anomaly_type}")
|
||||
print(f" 触发规则: {list(result.rule_details.keys())}")
|
||||
|
||||
# 测试2: 带序列数据的融合检测
|
||||
print("\n" + "-" * 60)
|
||||
print("测试2: 融合检测(规则 + ML)")
|
||||
print("-" * 60)
|
||||
|
||||
# 生成模拟序列数据
|
||||
seq_len = 30
|
||||
n_features = 6
|
||||
|
||||
# 正常序列:小幅波动
|
||||
normal_sequence = np.random.randn(seq_len, n_features) * 0.3
|
||||
normal_sequence[:, 0] = np.linspace(0, 0.5, seq_len) # alpha 缓慢上升
|
||||
normal_sequence[:, 2] = np.abs(normal_sequence[:, 2]) + 1 # amt_ratio > 0
|
||||
|
||||
# 异常序列:最后几个时间步突然变化
|
||||
anomaly_sequence = np.random.randn(seq_len, n_features) * 0.3
|
||||
anomaly_sequence[-5:, 0] = np.linspace(1, 4, 5) # alpha 突然飙升
|
||||
anomaly_sequence[-5:, 1] = np.linspace(0.2, 1.5, 5) # alpha_delta 加速
|
||||
anomaly_sequence[-5:, 2] = np.linspace(2, 6, 5) # amt_ratio 放量
|
||||
anomaly_sequence[:, 2] = np.abs(anomaly_sequence[:, 2]) + 1
|
||||
|
||||
# 测试正常序列
|
||||
normal_features = {
|
||||
'alpha': float(normal_sequence[-1, 0]),
|
||||
'alpha_delta': float(normal_sequence[-1, 1]),
|
||||
'amt_ratio': float(normal_sequence[-1, 2]),
|
||||
'amt_delta': float(normal_sequence[-1, 3]),
|
||||
'rank_pct': 0.5,
|
||||
'limit_up_ratio': 0.02
|
||||
}
|
||||
|
||||
result = detector.detect(normal_features, normal_sequence)
|
||||
print(f"\n正常序列:")
|
||||
print(f" 异动: {'是' if result.is_anomaly else '否'}")
|
||||
print(f" 最终得分: {result.final_score:.1f}")
|
||||
print(f" 规则得分: {result.rule_score:.1f}")
|
||||
print(f" ML得分: {result.ml_score:.1f}")
|
||||
|
||||
# 测试异常序列
|
||||
anomaly_features = {
|
||||
'alpha': float(anomaly_sequence[-1, 0]),
|
||||
'alpha_delta': float(anomaly_sequence[-1, 1]),
|
||||
'amt_ratio': float(anomaly_sequence[-1, 2]),
|
||||
'amt_delta': float(anomaly_sequence[-1, 3]),
|
||||
'rank_pct': 0.95,
|
||||
'limit_up_ratio': 0.15
|
||||
}
|
||||
|
||||
result = detector.detect(anomaly_features, anomaly_sequence)
|
||||
print(f"\n异常序列:")
|
||||
print(f" 异动: {'是' if result.is_anomaly else '否'}")
|
||||
print(f" 最终得分: {result.final_score:.1f}")
|
||||
print(f" 规则得分: {result.rule_score:.1f}")
|
||||
print(f" ML得分: {result.ml_score:.1f}")
|
||||
if result.is_anomaly:
|
||||
print(f" 触发原因: {result.trigger_reason}")
|
||||
print(f" 异动类型: {result.anomaly_type}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试完成!")
|
||||
|
||||
1518
ml/realtime_detector.py
Normal file
1518
ml/realtime_detector.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -346,7 +346,173 @@ export const marketHandlers = [
|
||||
});
|
||||
}),
|
||||
|
||||
// 11. 市场统计数据(个股中心页面使用)
|
||||
// 11. 热点概览数据(大盘分时 + 概念异动)
|
||||
http.get('/api/market/hotspot-overview', async ({ request }) => {
|
||||
await delay(300);
|
||||
const url = new URL(request.url);
|
||||
const date = url.searchParams.get('date');
|
||||
|
||||
const tradeDate = date || new Date().toISOString().split('T')[0];
|
||||
|
||||
// 生成分时数据(240个点,9:30-11:30 + 13:00-15:00)
|
||||
const timeline = [];
|
||||
const basePrice = 3900 + Math.random() * 100; // 基准价格 3900-4000
|
||||
const prevClose = basePrice;
|
||||
let currentPrice = basePrice;
|
||||
let cumulativeVolume = 0;
|
||||
|
||||
// 上午时段 9:30-11:30 (120分钟)
|
||||
for (let i = 0; i < 120; i++) {
|
||||
const hour = 9 + Math.floor((i + 30) / 60);
|
||||
const minute = (i + 30) % 60;
|
||||
const time = `${hour.toString().padStart(2, '0')}:${minute.toString().padStart(2, '0')}`;
|
||||
|
||||
// 模拟价格波动
|
||||
const volatility = 0.002; // 0.2%波动
|
||||
const drift = (Math.random() - 0.5) * 0.001; // 微小趋势
|
||||
currentPrice = currentPrice * (1 + (Math.random() - 0.5) * volatility + drift);
|
||||
|
||||
const volume = Math.floor(Math.random() * 500000 + 100000); // 成交量
|
||||
cumulativeVolume += volume;
|
||||
|
||||
timeline.push({
|
||||
time,
|
||||
price: parseFloat(currentPrice.toFixed(2)),
|
||||
volume: cumulativeVolume,
|
||||
change_pct: parseFloat(((currentPrice - prevClose) / prevClose * 100).toFixed(2))
|
||||
});
|
||||
}
|
||||
|
||||
// 下午时段 13:00-15:00 (120分钟)
|
||||
for (let i = 0; i < 120; i++) {
|
||||
const hour = 13 + Math.floor(i / 60);
|
||||
const minute = i % 60;
|
||||
const time = `${hour.toString().padStart(2, '0')}:${minute.toString().padStart(2, '0')}`;
|
||||
|
||||
// 下午波动略小
|
||||
const volatility = 0.0015;
|
||||
const drift = (Math.random() - 0.5) * 0.0008;
|
||||
currentPrice = currentPrice * (1 + (Math.random() - 0.5) * volatility + drift);
|
||||
|
||||
const volume = Math.floor(Math.random() * 400000 + 80000);
|
||||
cumulativeVolume += volume;
|
||||
|
||||
timeline.push({
|
||||
time,
|
||||
price: parseFloat(currentPrice.toFixed(2)),
|
||||
volume: cumulativeVolume,
|
||||
change_pct: parseFloat(((currentPrice - prevClose) / prevClose * 100).toFixed(2))
|
||||
});
|
||||
}
|
||||
|
||||
// 生成概念异动数据
|
||||
const conceptNames = [
|
||||
'人工智能', 'AI眼镜', '机器人', '核电', '国企', '卫星导航',
|
||||
'福建自贸区', '两岸融合', 'CRO', '三季报增长', '百货零售',
|
||||
'人形机器人', '央企', '数据中心', 'CPO', '新能源', '电网设备',
|
||||
'氢能源', '算力租赁', '厦门国资', '乳业', '低空安防', '创新药',
|
||||
'商业航天', '控制权变更', '文化传媒', '海峡两岸'
|
||||
];
|
||||
|
||||
const alertTypes = ['surge_up', 'surge_down', 'volume_spike', 'limit_up', 'rank_jump'];
|
||||
|
||||
// 生成 15-25 个异动
|
||||
const alertCount = Math.floor(Math.random() * 10) + 15;
|
||||
const alerts = [];
|
||||
const usedTimes = new Set();
|
||||
|
||||
for (let i = 0; i < alertCount; i++) {
|
||||
// 随机选择一个时间点
|
||||
let timeIdx;
|
||||
let attempts = 0;
|
||||
do {
|
||||
timeIdx = Math.floor(Math.random() * timeline.length);
|
||||
attempts++;
|
||||
} while (usedTimes.has(timeIdx) && attempts < 50);
|
||||
|
||||
if (attempts >= 50) continue;
|
||||
|
||||
// 同一时间可以有多个异动
|
||||
const time = timeline[timeIdx].time;
|
||||
const conceptName = conceptNames[Math.floor(Math.random() * conceptNames.length)];
|
||||
const alertType = alertTypes[Math.floor(Math.random() * alertTypes.length)];
|
||||
|
||||
// 根据类型生成 alpha
|
||||
let alpha;
|
||||
if (alertType === 'surge_up') {
|
||||
alpha = parseFloat((Math.random() * 3 + 2).toFixed(2)); // +2% ~ +5%
|
||||
} else if (alertType === 'surge_down') {
|
||||
alpha = parseFloat((-Math.random() * 3 - 1.5).toFixed(2)); // -1.5% ~ -4.5%
|
||||
} else {
|
||||
alpha = parseFloat((Math.random() * 4 - 1).toFixed(2)); // -1% ~ +3%
|
||||
}
|
||||
|
||||
const finalScore = Math.floor(Math.random() * 40 + 45); // 45-85分
|
||||
const ruleScore = Math.floor(Math.random() * 30 + 40);
|
||||
const mlScore = Math.floor(Math.random() * 30 + 40);
|
||||
|
||||
alerts.push({
|
||||
concept_id: `CONCEPT_${1000 + i}`,
|
||||
concept_name: conceptName,
|
||||
time,
|
||||
alert_type: alertType,
|
||||
alpha,
|
||||
alpha_delta: parseFloat((Math.random() * 2 - 0.5).toFixed(2)),
|
||||
amt_ratio: parseFloat((Math.random() * 5 + 1).toFixed(2)),
|
||||
limit_up_count: alertType === 'limit_up' ? Math.floor(Math.random() * 5 + 1) : 0,
|
||||
limit_up_ratio: parseFloat((Math.random() * 0.3).toFixed(3)),
|
||||
final_score: finalScore,
|
||||
rule_score: ruleScore,
|
||||
ml_score: mlScore,
|
||||
trigger_reason: finalScore >= 65 ? '规则强信号' : (mlScore >= 70 ? 'ML强信号' : '融合触发'),
|
||||
importance_score: parseFloat((finalScore / 100).toFixed(2)),
|
||||
index_price: timeline[timeIdx].price
|
||||
});
|
||||
}
|
||||
|
||||
// 按时间排序
|
||||
alerts.sort((a, b) => a.time.localeCompare(b.time));
|
||||
|
||||
// 统计异动类型
|
||||
const alertSummary = alerts.reduce((acc, alert) => {
|
||||
acc[alert.alert_type] = (acc[alert.alert_type] || 0) + 1;
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
// 计算指数统计
|
||||
const prices = timeline.map(t => t.price);
|
||||
const latestPrice = prices[prices.length - 1];
|
||||
const highPrice = Math.max(...prices);
|
||||
const lowPrice = Math.min(...prices);
|
||||
const changePct = ((latestPrice - prevClose) / prevClose * 100);
|
||||
|
||||
console.log('[Mock Market] 获取热点概览数据:', {
|
||||
date: tradeDate,
|
||||
timelinePoints: timeline.length,
|
||||
alertCount: alerts.length
|
||||
});
|
||||
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
index: {
|
||||
code: '000001.SH',
|
||||
name: '上证指数',
|
||||
latest_price: latestPrice,
|
||||
prev_close: prevClose,
|
||||
high: highPrice,
|
||||
low: lowPrice,
|
||||
change_pct: parseFloat(changePct.toFixed(2)),
|
||||
timeline
|
||||
},
|
||||
alerts,
|
||||
alert_summary: alertSummary
|
||||
},
|
||||
trade_date: tradeDate
|
||||
});
|
||||
}),
|
||||
|
||||
// 12. 市场统计数据(个股中心页面使用)
|
||||
http.get('/api/market/statistics', async ({ request }) => {
|
||||
await delay(200);
|
||||
const url = new URL(request.url);
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
/**
|
||||
* 异动统计摘要组件
|
||||
* 展示指数统计和异动类型统计
|
||||
*/
|
||||
import React from 'react';
|
||||
import {
|
||||
Box,
|
||||
HStack,
|
||||
VStack,
|
||||
Text,
|
||||
Badge,
|
||||
Icon,
|
||||
Stat,
|
||||
StatLabel,
|
||||
StatNumber,
|
||||
StatHelpText,
|
||||
StatArrow,
|
||||
SimpleGrid,
|
||||
useColorModeValue,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaBolt, FaArrowDown, FaRocket, FaChartLine, FaFire, FaVolumeUp } from 'react-icons/fa';
|
||||
|
||||
/**
|
||||
* 异动类型徽章
|
||||
*/
|
||||
const AlertTypeBadge = ({ type, count }) => {
|
||||
const config = {
|
||||
surge: { label: '急涨', color: 'red', icon: FaBolt },
|
||||
surge_up: { label: '暴涨', color: 'red', icon: FaBolt },
|
||||
surge_down: { label: '暴跌', color: 'green', icon: FaArrowDown },
|
||||
limit_up: { label: '涨停', color: 'orange', icon: FaRocket },
|
||||
rank_jump: { label: '排名跃升', color: 'blue', icon: FaChartLine },
|
||||
volume_spike: { label: '放量', color: 'purple', icon: FaVolumeUp },
|
||||
};
|
||||
|
||||
const cfg = config[type] || { label: type, color: 'gray', icon: FaFire };
|
||||
|
||||
return (
|
||||
<Badge colorScheme={cfg.color} variant="subtle" px={2} py={1} borderRadius="md">
|
||||
<HStack spacing={1}>
|
||||
<Icon as={cfg.icon} boxSize={3} />
|
||||
<Text>{cfg.label}</Text>
|
||||
<Text fontWeight="bold">{count}</Text>
|
||||
</HStack>
|
||||
</Badge>
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* 指数统计卡片
|
||||
*/
|
||||
const IndexStatCard = ({ indexData }) => {
|
||||
const cardBg = useColorModeValue('white', '#1a1a1a');
|
||||
const borderColor = useColorModeValue('gray.200', '#333');
|
||||
const subTextColor = useColorModeValue('gray.600', 'gray.400');
|
||||
|
||||
if (!indexData) return null;
|
||||
|
||||
const changePct = indexData.change_pct || 0;
|
||||
const isUp = changePct >= 0;
|
||||
|
||||
return (
|
||||
<SimpleGrid columns={{ base: 2, md: 4 }} spacing={4}>
|
||||
<Stat size="sm">
|
||||
<StatLabel color={subTextColor}>{indexData.name || '上证指数'}</StatLabel>
|
||||
<StatNumber fontSize="xl" color={isUp ? 'red.500' : 'green.500'}>
|
||||
{indexData.latest_price?.toFixed(2) || '-'}
|
||||
</StatNumber>
|
||||
<StatHelpText mb={0}>
|
||||
<StatArrow type={isUp ? 'increase' : 'decrease'} />
|
||||
{changePct?.toFixed(2)}%
|
||||
</StatHelpText>
|
||||
</Stat>
|
||||
|
||||
<Stat size="sm">
|
||||
<StatLabel color={subTextColor}>最高</StatLabel>
|
||||
<StatNumber fontSize="xl" color="red.500">
|
||||
{indexData.high?.toFixed(2) || '-'}
|
||||
</StatNumber>
|
||||
</Stat>
|
||||
|
||||
<Stat size="sm">
|
||||
<StatLabel color={subTextColor}>最低</StatLabel>
|
||||
<StatNumber fontSize="xl" color="green.500">
|
||||
{indexData.low?.toFixed(2) || '-'}
|
||||
</StatNumber>
|
||||
</Stat>
|
||||
|
||||
<Stat size="sm">
|
||||
<StatLabel color={subTextColor}>振幅</StatLabel>
|
||||
<StatNumber fontSize="xl" color="purple.500">
|
||||
{indexData.high && indexData.low && indexData.prev_close
|
||||
? (((indexData.high - indexData.low) / indexData.prev_close) * 100).toFixed(2) + '%'
|
||||
: '-'}
|
||||
</StatNumber>
|
||||
</Stat>
|
||||
</SimpleGrid>
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* 异动统计摘要
|
||||
* @param {Object} props
|
||||
* @param {Object} props.indexData - 指数数据
|
||||
* @param {Array} props.alerts - 异动数组
|
||||
* @param {Object} props.alertSummary - 异动类型统计
|
||||
*/
|
||||
const AlertSummary = ({ indexData, alerts = [], alertSummary = {} }) => {
|
||||
const cardBg = useColorModeValue('white', '#1a1a1a');
|
||||
const borderColor = useColorModeValue('gray.200', '#333');
|
||||
|
||||
// 如果没有 alertSummary,从 alerts 中统计
|
||||
const summary = alertSummary && Object.keys(alertSummary).length > 0
|
||||
? alertSummary
|
||||
: alerts.reduce((acc, alert) => {
|
||||
const type = alert.alert_type || 'unknown';
|
||||
acc[type] = (acc[type] || 0) + 1;
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
const totalAlerts = alerts.length;
|
||||
|
||||
return (
|
||||
<VStack spacing={4} align="stretch">
|
||||
{/* 指数统计 */}
|
||||
<IndexStatCard indexData={indexData} />
|
||||
|
||||
{/* 异动统计 */}
|
||||
{totalAlerts > 0 && (
|
||||
<HStack spacing={2} flexWrap="wrap">
|
||||
<Text fontSize="sm" color="gray.500" mr={2}>
|
||||
异动 {totalAlerts} 次:
|
||||
</Text>
|
||||
{(summary.surge_up > 0 || summary.surge > 0) && (
|
||||
<AlertTypeBadge type="surge_up" count={(summary.surge_up || 0) + (summary.surge || 0)} />
|
||||
)}
|
||||
{summary.surge_down > 0 && <AlertTypeBadge type="surge_down" count={summary.surge_down} />}
|
||||
{summary.limit_up > 0 && <AlertTypeBadge type="limit_up" count={summary.limit_up} />}
|
||||
{summary.volume_spike > 0 && <AlertTypeBadge type="volume_spike" count={summary.volume_spike} />}
|
||||
{summary.rank_jump > 0 && <AlertTypeBadge type="rank_jump" count={summary.rank_jump} />}
|
||||
</HStack>
|
||||
)}
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default AlertSummary;
|
||||
@@ -0,0 +1,194 @@
|
||||
/**
|
||||
* 概念异动列表组件
|
||||
* 展示当日的概念异动记录
|
||||
*/
|
||||
import React from 'react';
|
||||
import {
|
||||
Box,
|
||||
VStack,
|
||||
HStack,
|
||||
Text,
|
||||
Badge,
|
||||
Icon,
|
||||
Tooltip,
|
||||
useColorModeValue,
|
||||
Flex,
|
||||
Divider,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaBolt, FaArrowUp, FaArrowDown, FaChartLine, FaFire, FaVolumeUp } from 'react-icons/fa';
|
||||
import { getAlertTypeLabel, formatScore, getScoreColor } from '../utils/chartHelpers';
|
||||
|
||||
/**
|
||||
* 单个异动项组件
|
||||
*/
|
||||
const AlertItem = ({ alert, onClick, isSelected }) => {
|
||||
const bgColor = useColorModeValue('white', '#1a1a1a');
|
||||
const hoverBg = useColorModeValue('gray.50', '#2a2a2a');
|
||||
const borderColor = useColorModeValue('gray.200', '#333');
|
||||
const selectedBg = useColorModeValue('purple.50', '#2a2a3a');
|
||||
|
||||
const isUp = alert.alert_type !== 'surge_down';
|
||||
const typeColor = isUp ? 'red' : 'green';
|
||||
|
||||
// 获取异动类型图标
|
||||
const getTypeIcon = (type) => {
|
||||
switch (type) {
|
||||
case 'surge_up':
|
||||
case 'surge':
|
||||
return FaArrowUp;
|
||||
case 'surge_down':
|
||||
return FaArrowDown;
|
||||
case 'limit_up':
|
||||
return FaFire;
|
||||
case 'volume_spike':
|
||||
return FaVolumeUp;
|
||||
case 'rank_jump':
|
||||
return FaChartLine;
|
||||
default:
|
||||
return FaBolt;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Box
|
||||
p={3}
|
||||
bg={isSelected ? selectedBg : bgColor}
|
||||
borderRadius="md"
|
||||
borderWidth="1px"
|
||||
borderColor={isSelected ? 'purple.400' : borderColor}
|
||||
cursor="pointer"
|
||||
transition="all 0.2s"
|
||||
_hover={{ bg: hoverBg, transform: 'translateX(4px)' }}
|
||||
onClick={() => onClick?.(alert)}
|
||||
>
|
||||
<Flex justify="space-between" align="flex-start">
|
||||
{/* 左侧:概念名称和时间 */}
|
||||
<VStack align="start" spacing={1} flex={1}>
|
||||
<HStack spacing={2}>
|
||||
<Icon as={getTypeIcon(alert.alert_type)} color={`${typeColor}.500`} boxSize={4} />
|
||||
<Text fontWeight="bold" fontSize="sm" noOfLines={1}>
|
||||
{alert.concept_name}
|
||||
</Text>
|
||||
</HStack>
|
||||
<HStack spacing={2} fontSize="xs" color="gray.500">
|
||||
<Text>{alert.time}</Text>
|
||||
<Badge colorScheme={typeColor} size="sm" variant="subtle">
|
||||
{getAlertTypeLabel(alert.alert_type)}
|
||||
</Badge>
|
||||
</HStack>
|
||||
</VStack>
|
||||
|
||||
{/* 右侧:分数和关键指标 */}
|
||||
<VStack align="end" spacing={1}>
|
||||
{/* 综合得分 */}
|
||||
{alert.final_score !== undefined && (
|
||||
<Tooltip label={`规则: ${formatScore(alert.rule_score)} / ML: ${formatScore(alert.ml_score)}`}>
|
||||
<Badge
|
||||
px={2}
|
||||
py={1}
|
||||
borderRadius="full"
|
||||
bg={getScoreColor(alert.final_score)}
|
||||
color="white"
|
||||
fontSize="xs"
|
||||
fontWeight="bold"
|
||||
>
|
||||
{formatScore(alert.final_score)}分
|
||||
</Badge>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
{/* Alpha 值 */}
|
||||
{alert.alpha !== undefined && (
|
||||
<Text fontSize="xs" color={alert.alpha >= 0 ? 'red.500' : 'green.500'} fontWeight="medium">
|
||||
α {alert.alpha >= 0 ? '+' : ''}{alert.alpha.toFixed(2)}%
|
||||
</Text>
|
||||
)}
|
||||
|
||||
{/* 涨停数量 */}
|
||||
{alert.limit_up_count > 0 && (
|
||||
<HStack spacing={1}>
|
||||
<Icon as={FaFire} color="orange.500" boxSize={3} />
|
||||
<Text fontSize="xs" color="orange.500">
|
||||
涨停 {alert.limit_up_count}
|
||||
</Text>
|
||||
</HStack>
|
||||
)}
|
||||
</VStack>
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* 概念异动列表
|
||||
* @param {Object} props
|
||||
* @param {Array} props.alerts - 异动数据数组
|
||||
* @param {Function} props.onAlertClick - 点击异动的回调
|
||||
* @param {Object} props.selectedAlert - 当前选中的异动
|
||||
* @param {number} props.maxHeight - 最大高度
|
||||
*/
|
||||
const ConceptAlertList = ({ alerts = [], onAlertClick, selectedAlert, maxHeight = '400px' }) => {
|
||||
const textColor = useColorModeValue('gray.800', 'white');
|
||||
const subTextColor = useColorModeValue('gray.500', 'gray.400');
|
||||
|
||||
if (!alerts || alerts.length === 0) {
|
||||
return (
|
||||
<Box p={4} textAlign="center">
|
||||
<Text color={subTextColor} fontSize="sm">
|
||||
当日暂无概念异动
|
||||
</Text>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// 按时间分组
|
||||
const groupedAlerts = alerts.reduce((acc, alert) => {
|
||||
const time = alert.time || '未知时间';
|
||||
if (!acc[time]) {
|
||||
acc[time] = [];
|
||||
}
|
||||
acc[time].push(alert);
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
// 按时间排序
|
||||
const sortedTimes = Object.keys(groupedAlerts).sort();
|
||||
|
||||
return (
|
||||
<Box maxH={maxHeight} overflowY="auto" pr={2}>
|
||||
<VStack spacing={3} align="stretch">
|
||||
{sortedTimes.map((time, timeIndex) => (
|
||||
<Box key={time}>
|
||||
{/* 时间分隔线 */}
|
||||
{timeIndex > 0 && <Divider my={2} />}
|
||||
|
||||
{/* 时间标签 */}
|
||||
<HStack spacing={2} mb={2}>
|
||||
<Box w={2} h={2} borderRadius="full" bg="purple.500" />
|
||||
<Text fontSize="xs" fontWeight="bold" color={subTextColor}>
|
||||
{time}
|
||||
</Text>
|
||||
<Text fontSize="xs" color={subTextColor}>
|
||||
({groupedAlerts[time].length}个异动)
|
||||
</Text>
|
||||
</HStack>
|
||||
|
||||
{/* 该时间点的异动 */}
|
||||
<VStack spacing={2} align="stretch" pl={4}>
|
||||
{groupedAlerts[time].map((alert, idx) => (
|
||||
<AlertItem
|
||||
key={`${alert.concept_id || alert.concept_name}-${idx}`}
|
||||
alert={alert}
|
||||
onClick={onAlertClick}
|
||||
isSelected={selectedAlert?.concept_id === alert.concept_id && selectedAlert?.time === alert.time}
|
||||
/>
|
||||
))}
|
||||
</VStack>
|
||||
</Box>
|
||||
))}
|
||||
</VStack>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default ConceptAlertList;
|
||||
@@ -0,0 +1,264 @@
|
||||
/**
|
||||
* 指数分时图组件
|
||||
* 展示大盘分时走势,支持概念异动标注
|
||||
*/
|
||||
import React, { useRef, useEffect, useCallback, useMemo } from 'react';
|
||||
import { Box, useColorModeValue } from '@chakra-ui/react';
|
||||
import * as echarts from 'echarts';
|
||||
import { getAlertMarkPoints } from '../utils/chartHelpers';
|
||||
|
||||
/**
|
||||
* @param {Object} props
|
||||
* @param {Object} props.indexData - 指数数据 { timeline, prev_close, name, ... }
|
||||
* @param {Array} props.alerts - 异动数据数组
|
||||
* @param {Function} props.onAlertClick - 点击异动标注的回调
|
||||
* @param {string} props.height - 图表高度
|
||||
*/
|
||||
const IndexMinuteChart = ({ indexData, alerts = [], onAlertClick, height = '350px' }) => {
|
||||
const chartRef = useRef(null);
|
||||
const chartInstance = useRef(null);
|
||||
|
||||
const textColor = useColorModeValue('gray.800', 'white');
|
||||
const subTextColor = useColorModeValue('gray.600', 'gray.400');
|
||||
const gridLineColor = useColorModeValue('#eee', '#333');
|
||||
|
||||
// 计算图表配置
|
||||
const chartOption = useMemo(() => {
|
||||
if (!indexData || !indexData.timeline || indexData.timeline.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const timeline = indexData.timeline || [];
|
||||
const times = timeline.map((d) => d.time);
|
||||
const prices = timeline.map((d) => d.price);
|
||||
const volumes = timeline.map((d) => d.volume);
|
||||
const changePcts = timeline.map((d) => d.change_pct);
|
||||
|
||||
// 计算Y轴范围
|
||||
const validPrices = prices.filter(Boolean);
|
||||
if (validPrices.length === 0) return null;
|
||||
|
||||
const priceMin = Math.min(...validPrices);
|
||||
const priceMax = Math.max(...validPrices);
|
||||
const priceRange = priceMax - priceMin;
|
||||
const yAxisMin = priceMin - priceRange * 0.1;
|
||||
const yAxisMax = priceMax + priceRange * 0.25; // 上方留更多空间给标注
|
||||
|
||||
// 准备异动标注
|
||||
const markPoints = getAlertMarkPoints(alerts, times, prices, priceMax);
|
||||
|
||||
// 渐变色 - 根据涨跌
|
||||
const latestChangePct = changePcts[changePcts.length - 1] || 0;
|
||||
const isUp = latestChangePct >= 0;
|
||||
const lineColor = isUp ? '#ff4d4d' : '#22c55e';
|
||||
const areaColorStops = isUp
|
||||
? [
|
||||
{ offset: 0, color: 'rgba(255, 77, 77, 0.4)' },
|
||||
{ offset: 1, color: 'rgba(255, 77, 77, 0.05)' },
|
||||
]
|
||||
: [
|
||||
{ offset: 0, color: 'rgba(34, 197, 94, 0.4)' },
|
||||
{ offset: 1, color: 'rgba(34, 197, 94, 0.05)' },
|
||||
];
|
||||
|
||||
return {
|
||||
backgroundColor: 'transparent',
|
||||
tooltip: {
|
||||
trigger: 'axis',
|
||||
axisPointer: {
|
||||
type: 'cross',
|
||||
crossStyle: { color: '#999' },
|
||||
},
|
||||
formatter: (params) => {
|
||||
if (!params || params.length === 0) return '';
|
||||
|
||||
const dataIndex = params[0].dataIndex;
|
||||
const time = times[dataIndex];
|
||||
const price = prices[dataIndex];
|
||||
const changePct = changePcts[dataIndex];
|
||||
const volume = volumes[dataIndex];
|
||||
|
||||
let html = `
|
||||
<div style="padding: 8px;">
|
||||
<div style="font-weight: bold; margin-bottom: 4px;">${time}</div>
|
||||
<div>指数: <span style="color: ${changePct >= 0 ? '#ff4d4d' : '#22c55e'}; font-weight: bold;">${price?.toFixed(2)}</span></div>
|
||||
<div>涨跌: <span style="color: ${changePct >= 0 ? '#ff4d4d' : '#22c55e'};">${changePct >= 0 ? '+' : ''}${changePct?.toFixed(2)}%</span></div>
|
||||
<div>成交量: ${(volume / 10000).toFixed(0)}万手</div>
|
||||
</div>
|
||||
`;
|
||||
|
||||
// 检查是否有异动
|
||||
const alertsAtTime = alerts.filter((a) => a.time === time);
|
||||
if (alertsAtTime.length > 0) {
|
||||
html += '<div style="border-top: 1px solid #eee; margin-top: 4px; padding-top: 4px;">';
|
||||
html += '<div style="font-weight: bold; color: #ff6b6b;">概念异动:</div>';
|
||||
alertsAtTime.forEach((alert) => {
|
||||
const typeLabel = {
|
||||
surge: '急涨',
|
||||
surge_up: '暴涨',
|
||||
surge_down: '暴跌',
|
||||
limit_up: '涨停增加',
|
||||
rank_jump: '排名跃升',
|
||||
volume_spike: '放量',
|
||||
}[alert.alert_type] || alert.alert_type;
|
||||
const typeColor = alert.alert_type === 'surge_down' ? '#2ed573' : '#ff6b6b';
|
||||
const alpha = alert.alpha ? ` (α${alert.alpha > 0 ? '+' : ''}${alert.alpha.toFixed(2)}%)` : '';
|
||||
html += `<div style="color: ${typeColor}">• ${alert.concept_name} (${typeLabel}${alpha})</div>`;
|
||||
});
|
||||
html += '</div>';
|
||||
}
|
||||
|
||||
return html;
|
||||
},
|
||||
},
|
||||
legend: { show: false },
|
||||
grid: [
|
||||
{ left: '8%', right: '3%', top: '8%', height: '58%' },
|
||||
{ left: '8%', right: '3%', top: '72%', height: '18%' },
|
||||
],
|
||||
xAxis: [
|
||||
{
|
||||
type: 'category',
|
||||
data: times,
|
||||
axisLine: { lineStyle: { color: gridLineColor } },
|
||||
axisLabel: {
|
||||
color: subTextColor,
|
||||
fontSize: 10,
|
||||
interval: Math.floor(times.length / 6),
|
||||
},
|
||||
axisTick: { show: false },
|
||||
splitLine: { show: false },
|
||||
},
|
||||
{
|
||||
type: 'category',
|
||||
gridIndex: 1,
|
||||
data: times,
|
||||
axisLine: { lineStyle: { color: gridLineColor } },
|
||||
axisLabel: { show: false },
|
||||
axisTick: { show: false },
|
||||
splitLine: { show: false },
|
||||
},
|
||||
],
|
||||
yAxis: [
|
||||
{
|
||||
type: 'value',
|
||||
min: yAxisMin,
|
||||
max: yAxisMax,
|
||||
axisLine: { show: false },
|
||||
axisLabel: {
|
||||
color: subTextColor,
|
||||
fontSize: 10,
|
||||
formatter: (val) => val.toFixed(0),
|
||||
},
|
||||
splitLine: { lineStyle: { color: gridLineColor, type: 'dashed' } },
|
||||
axisPointer: {
|
||||
label: {
|
||||
formatter: (params) => {
|
||||
if (!indexData.prev_close) return params.value.toFixed(2);
|
||||
const pct = ((params.value - indexData.prev_close) / indexData.prev_close) * 100;
|
||||
return `${params.value.toFixed(2)} (${pct >= 0 ? '+' : ''}${pct.toFixed(2)}%)`;
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'value',
|
||||
gridIndex: 1,
|
||||
axisLine: { show: false },
|
||||
axisLabel: { show: false },
|
||||
splitLine: { show: false },
|
||||
},
|
||||
],
|
||||
series: [
|
||||
// 分时线
|
||||
{
|
||||
name: indexData.name || '上证指数',
|
||||
type: 'line',
|
||||
data: prices,
|
||||
smooth: true,
|
||||
symbol: 'none',
|
||||
lineStyle: { color: lineColor, width: 1.5 },
|
||||
areaStyle: {
|
||||
color: new echarts.graphic.LinearGradient(0, 0, 0, 1, areaColorStops),
|
||||
},
|
||||
markPoint: {
|
||||
symbol: 'pin',
|
||||
symbolSize: 40,
|
||||
data: markPoints,
|
||||
animation: true,
|
||||
},
|
||||
},
|
||||
// 成交量
|
||||
{
|
||||
name: '成交量',
|
||||
type: 'bar',
|
||||
xAxisIndex: 1,
|
||||
yAxisIndex: 1,
|
||||
data: volumes.map((v, i) => ({
|
||||
value: v,
|
||||
itemStyle: {
|
||||
color: changePcts[i] >= 0 ? 'rgba(255, 77, 77, 0.6)' : 'rgba(34, 197, 94, 0.6)',
|
||||
},
|
||||
})),
|
||||
barWidth: '60%',
|
||||
},
|
||||
],
|
||||
};
|
||||
}, [indexData, alerts, subTextColor, gridLineColor]);
|
||||
|
||||
// 渲染图表
|
||||
const renderChart = useCallback(() => {
|
||||
if (!chartRef.current || !chartOption) return;
|
||||
|
||||
if (!chartInstance.current) {
|
||||
chartInstance.current = echarts.init(chartRef.current);
|
||||
}
|
||||
|
||||
chartInstance.current.setOption(chartOption, true);
|
||||
|
||||
// 点击事件
|
||||
if (onAlertClick) {
|
||||
chartInstance.current.off('click');
|
||||
chartInstance.current.on('click', 'series.line.markPoint', (params) => {
|
||||
if (params.data && params.data.alertData) {
|
||||
onAlertClick(params.data.alertData);
|
||||
}
|
||||
});
|
||||
}
|
||||
}, [chartOption, onAlertClick]);
|
||||
|
||||
// 数据变化时重新渲染
|
||||
useEffect(() => {
|
||||
renderChart();
|
||||
}, [renderChart]);
|
||||
|
||||
// 窗口大小变化时重新渲染
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
if (chartInstance.current) {
|
||||
chartInstance.current.resize();
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('resize', handleResize);
|
||||
return () => {
|
||||
window.removeEventListener('resize', handleResize);
|
||||
if (chartInstance.current) {
|
||||
chartInstance.current.dispose();
|
||||
chartInstance.current = null;
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
if (!chartOption) {
|
||||
return (
|
||||
<Box h={height} display="flex" alignItems="center" justifyContent="center" color={subTextColor}>
|
||||
暂无数据
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
return <Box ref={chartRef} h={height} w="100%" />;
|
||||
};
|
||||
|
||||
export default IndexMinuteChart;
|
||||
@@ -0,0 +1,3 @@
|
||||
export { default as IndexMinuteChart } from './IndexMinuteChart';
|
||||
export { default as ConceptAlertList } from './ConceptAlertList';
|
||||
export { default as AlertSummary } from './AlertSummary';
|
||||
@@ -0,0 +1 @@
|
||||
export { useHotspotData } from './useHotspotData';
|
||||
@@ -0,0 +1,53 @@
|
||||
/**
|
||||
* 热点概览数据获取 Hook
|
||||
* 负责获取指数分时数据和概念异动数据
|
||||
*/
|
||||
import { useState, useEffect, useCallback } from 'react';
|
||||
import { logger } from '@utils/logger';
|
||||
|
||||
/**
|
||||
* @param {Date|null} selectedDate - 选中的交易日期
|
||||
* @returns {Object} 数据和状态
|
||||
*/
|
||||
export const useHotspotData = (selectedDate) => {
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState(null);
|
||||
const [data, setData] = useState(null);
|
||||
|
||||
const fetchData = useCallback(async () => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const dateParam = selectedDate
|
||||
? `?date=${selectedDate.toISOString().split('T')[0]}`
|
||||
: '';
|
||||
const response = await fetch(`/api/market/hotspot-overview${dateParam}`);
|
||||
const result = await response.json();
|
||||
|
||||
if (result.success) {
|
||||
setData(result.data);
|
||||
} else {
|
||||
setError(result.error || '获取数据失败');
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('useHotspotData', 'fetchData', err);
|
||||
setError('网络请求失败');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [selectedDate]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchData();
|
||||
}, [fetchData]);
|
||||
|
||||
return {
|
||||
loading,
|
||||
error,
|
||||
data,
|
||||
refetch: fetchData,
|
||||
};
|
||||
};
|
||||
|
||||
export default useHotspotData;
|
||||
@@ -1,8 +1,15 @@
|
||||
/**
|
||||
* 热点概览组件
|
||||
* 展示大盘分时走势 + 概念异动标注
|
||||
*
|
||||
* 模块化结构:
|
||||
* - hooks/useHotspotData.js - 数据获取
|
||||
* - components/IndexMinuteChart.js - 分时图
|
||||
* - components/ConceptAlertList.js - 异动列表
|
||||
* - components/AlertSummary.js - 统计摘要
|
||||
* - utils/chartHelpers.js - 图表辅助函数
|
||||
*/
|
||||
import React, { useState, useEffect, useRef, useCallback } from 'react';
|
||||
import React, { useState, useCallback } from 'react';
|
||||
import {
|
||||
Box,
|
||||
Card,
|
||||
@@ -11,7 +18,6 @@ import {
|
||||
Text,
|
||||
HStack,
|
||||
VStack,
|
||||
Badge,
|
||||
Spinner,
|
||||
Center,
|
||||
Icon,
|
||||
@@ -19,24 +25,29 @@ import {
|
||||
Spacer,
|
||||
Tooltip,
|
||||
useColorModeValue,
|
||||
Stat,
|
||||
StatLabel,
|
||||
StatNumber,
|
||||
StatHelpText,
|
||||
StatArrow,
|
||||
SimpleGrid,
|
||||
Grid,
|
||||
GridItem,
|
||||
Divider,
|
||||
IconButton,
|
||||
Collapse,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaFire, FaRocket, FaChartLine, FaBolt, FaArrowDown } from 'react-icons/fa';
|
||||
import { FaFire, FaList, FaChartArea, FaChevronDown, FaChevronUp } from 'react-icons/fa';
|
||||
import { InfoIcon } from '@chakra-ui/icons';
|
||||
import * as echarts from 'echarts';
|
||||
import { logger } from '@utils/logger';
|
||||
|
||||
import { useHotspotData } from './hooks';
|
||||
import { IndexMinuteChart, ConceptAlertList, AlertSummary } from './components';
|
||||
|
||||
/**
|
||||
* 热点概览主组件
|
||||
* @param {Object} props
|
||||
* @param {Date|null} props.selectedDate - 选中的交易日期
|
||||
*/
|
||||
const HotspotOverview = ({ selectedDate }) => {
|
||||
const chartRef = useRef(null);
|
||||
const chartInstance = useRef(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [data, setData] = useState(null);
|
||||
const [error, setError] = useState(null);
|
||||
const [selectedAlert, setSelectedAlert] = useState(null);
|
||||
const [showAlertList, setShowAlertList] = useState(true);
|
||||
|
||||
// 获取数据
|
||||
const { loading, error, data } = useHotspotData(selectedDate);
|
||||
|
||||
// 颜色主题
|
||||
const cardBg = useColorModeValue('white', '#1a1a1a');
|
||||
@@ -44,373 +55,13 @@ const HotspotOverview = ({ selectedDate }) => {
|
||||
const textColor = useColorModeValue('gray.800', 'white');
|
||||
const subTextColor = useColorModeValue('gray.600', 'gray.400');
|
||||
|
||||
// 获取数据
|
||||
const fetchData = useCallback(async () => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const dateParam = selectedDate
|
||||
? `?date=${selectedDate.toISOString().split('T')[0]}`
|
||||
: '';
|
||||
const response = await fetch(`/api/market/hotspot-overview${dateParam}`);
|
||||
const result = await response.json();
|
||||
|
||||
if (result.success) {
|
||||
setData(result.data);
|
||||
} else {
|
||||
setError(result.error || '获取数据失败');
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('HotspotOverview', 'fetchData', err);
|
||||
setError('网络请求失败');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [selectedDate]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchData();
|
||||
}, [fetchData]);
|
||||
|
||||
// 渲染图表
|
||||
const renderChart = useCallback(() => {
|
||||
if (!chartRef.current || !data) return;
|
||||
|
||||
if (!chartInstance.current) {
|
||||
chartInstance.current = echarts.init(chartRef.current);
|
||||
}
|
||||
|
||||
const { index, alerts } = data;
|
||||
const timeline = index.timeline || [];
|
||||
|
||||
// 准备数据
|
||||
const times = timeline.map((d) => d.time);
|
||||
const prices = timeline.map((d) => d.price);
|
||||
const volumes = timeline.map((d) => d.volume);
|
||||
const changePcts = timeline.map((d) => d.change_pct);
|
||||
|
||||
// 计算Y轴范围
|
||||
const priceMin = Math.min(...prices.filter(Boolean));
|
||||
const priceMax = Math.max(...prices.filter(Boolean));
|
||||
const priceRange = priceMax - priceMin;
|
||||
const yAxisMin = priceMin - priceRange * 0.1;
|
||||
const yAxisMax = priceMax + priceRange * 0.2; // 上方留更多空间给标注
|
||||
|
||||
// 准备异动标注 - 按重要性排序,限制显示数量
|
||||
const sortedAlerts = [...alerts]
|
||||
.sort((a, b) => (b.importance_score || 0) - (a.importance_score || 0))
|
||||
.slice(0, 15); // 最多显示15个标注,避免图表过于密集
|
||||
|
||||
const markPoints = sortedAlerts.map((alert) => {
|
||||
// 找到对应时间的价格
|
||||
const timeIndex = times.indexOf(alert.time);
|
||||
const price = timeIndex >= 0 ? prices[timeIndex] : (alert.index_price || priceMax);
|
||||
|
||||
// 根据异动类型设置颜色和符号
|
||||
let color = '#ff6b6b';
|
||||
let symbol = 'pin';
|
||||
let symbolSize = 35;
|
||||
|
||||
// 暴涨
|
||||
if (alert.alert_type === 'surge_up' || alert.alert_type === 'surge') {
|
||||
color = '#ff4757';
|
||||
symbol = 'triangle';
|
||||
symbolSize = 30 + Math.min((alert.importance_score || 0.5) * 20, 15); // 根据重要性调整大小
|
||||
}
|
||||
// 暴跌
|
||||
else if (alert.alert_type === 'surge_down') {
|
||||
color = '#2ed573';
|
||||
symbol = 'path://M0,0 L10,0 L5,10 Z'; // 向下三角形
|
||||
symbolSize = 30 + Math.min((alert.importance_score || 0.5) * 20, 15);
|
||||
}
|
||||
// 涨停增加
|
||||
else if (alert.alert_type === 'limit_up') {
|
||||
color = '#ff6348';
|
||||
symbol = 'diamond';
|
||||
symbolSize = 28;
|
||||
}
|
||||
// 排名跃升
|
||||
else if (alert.alert_type === 'rank_jump') {
|
||||
color = '#3742fa';
|
||||
symbol = 'circle';
|
||||
symbolSize = 25;
|
||||
}
|
||||
|
||||
// 格式化标签 - 简化显示
|
||||
let label = alert.concept_name;
|
||||
// 截断过长的名称
|
||||
if (label.length > 8) {
|
||||
label = label.substring(0, 7) + '...';
|
||||
}
|
||||
|
||||
// 添加变化信息
|
||||
const changeDelta = alert.change_delta;
|
||||
if (changeDelta) {
|
||||
const sign = changeDelta > 0 ? '+' : '';
|
||||
label += `\n${sign}${changeDelta.toFixed(1)}%`;
|
||||
}
|
||||
|
||||
return {
|
||||
name: alert.concept_name,
|
||||
coord: [alert.time, price],
|
||||
value: label,
|
||||
symbol: symbol,
|
||||
symbolSize: symbolSize,
|
||||
itemStyle: {
|
||||
color: color,
|
||||
borderColor: '#fff',
|
||||
borderWidth: 1,
|
||||
shadowBlur: 3,
|
||||
shadowColor: 'rgba(0,0,0,0.2)',
|
||||
},
|
||||
label: {
|
||||
show: true,
|
||||
position: alert.alert_type === 'surge_down' ? 'bottom' : 'top', // 暴跌标签在下方
|
||||
formatter: '{b}',
|
||||
fontSize: 9,
|
||||
color: textColor,
|
||||
backgroundColor: alert.alert_type === 'surge_down'
|
||||
? 'rgba(46, 213, 115, 0.9)'
|
||||
: 'rgba(255,255,255,0.9)',
|
||||
padding: [2, 4],
|
||||
borderRadius: 2,
|
||||
borderColor: color,
|
||||
borderWidth: 1,
|
||||
},
|
||||
// 存储额外信息用于 tooltip
|
||||
alertData: alert,
|
||||
};
|
||||
});
|
||||
|
||||
// 渐变色 - 根据涨跌
|
||||
const latestChangePct = changePcts[changePcts.length - 1] || 0;
|
||||
const areaColorStops = latestChangePct >= 0
|
||||
? [
|
||||
{ offset: 0, color: 'rgba(255, 77, 77, 0.4)' },
|
||||
{ offset: 1, color: 'rgba(255, 77, 77, 0.05)' },
|
||||
]
|
||||
: [
|
||||
{ offset: 0, color: 'rgba(34, 197, 94, 0.4)' },
|
||||
{ offset: 1, color: 'rgba(34, 197, 94, 0.05)' },
|
||||
];
|
||||
|
||||
const lineColor = latestChangePct >= 0 ? '#ff4d4d' : '#22c55e';
|
||||
|
||||
const option = {
|
||||
backgroundColor: 'transparent',
|
||||
tooltip: {
|
||||
trigger: 'axis',
|
||||
axisPointer: {
|
||||
type: 'cross',
|
||||
crossStyle: {
|
||||
color: '#999',
|
||||
},
|
||||
},
|
||||
formatter: function (params) {
|
||||
if (!params || params.length === 0) return '';
|
||||
|
||||
const dataIndex = params[0].dataIndex;
|
||||
const time = times[dataIndex];
|
||||
const price = prices[dataIndex];
|
||||
const changePct = changePcts[dataIndex];
|
||||
const volume = volumes[dataIndex];
|
||||
|
||||
let html = `
|
||||
<div style="padding: 8px;">
|
||||
<div style="font-weight: bold; margin-bottom: 4px;">${time}</div>
|
||||
<div>指数: <span style="color: ${changePct >= 0 ? '#ff4d4d' : '#22c55e'}; font-weight: bold;">${price?.toFixed(2)}</span></div>
|
||||
<div>涨跌: <span style="color: ${changePct >= 0 ? '#ff4d4d' : '#22c55e'};">${changePct >= 0 ? '+' : ''}${changePct?.toFixed(2)}%</span></div>
|
||||
<div>成交量: ${(volume / 10000).toFixed(0)}万手</div>
|
||||
</div>
|
||||
`;
|
||||
|
||||
// 检查是否有异动
|
||||
const alertsAtTime = alerts.filter((a) => a.time === time);
|
||||
if (alertsAtTime.length > 0) {
|
||||
html += '<div style="border-top: 1px solid #eee; margin-top: 4px; padding-top: 4px;">';
|
||||
html += '<div style="font-weight: bold; color: #ff6b6b;">概念异动:</div>';
|
||||
alertsAtTime.forEach((alert) => {
|
||||
const typeLabel = {
|
||||
surge: '急涨',
|
||||
surge_up: '暴涨',
|
||||
surge_down: '暴跌',
|
||||
limit_up: '涨停增加',
|
||||
rank_jump: '排名跃升',
|
||||
}[alert.alert_type] || alert.alert_type;
|
||||
const typeColor = alert.alert_type === 'surge_down' ? '#2ed573' : '#ff6b6b';
|
||||
const delta = alert.change_delta ? ` (${alert.change_delta > 0 ? '+' : ''}${alert.change_delta.toFixed(2)}%)` : '';
|
||||
const zscore = alert.zscore ? ` Z=${alert.zscore.toFixed(1)}` : '';
|
||||
html += `<div style="color: ${typeColor}">• ${alert.concept_name} (${typeLabel}${delta}${zscore})</div>`;
|
||||
});
|
||||
html += '</div>';
|
||||
}
|
||||
|
||||
return html;
|
||||
},
|
||||
},
|
||||
legend: {
|
||||
show: false,
|
||||
},
|
||||
grid: [
|
||||
{
|
||||
left: '8%',
|
||||
right: '3%',
|
||||
top: '8%',
|
||||
height: '55%',
|
||||
},
|
||||
{
|
||||
left: '8%',
|
||||
right: '3%',
|
||||
top: '70%',
|
||||
height: '20%',
|
||||
},
|
||||
],
|
||||
xAxis: [
|
||||
{
|
||||
type: 'category',
|
||||
data: times,
|
||||
axisLine: { lineStyle: { color: '#ddd' } },
|
||||
axisLabel: {
|
||||
color: subTextColor,
|
||||
fontSize: 10,
|
||||
interval: Math.floor(times.length / 6),
|
||||
},
|
||||
axisTick: { show: false },
|
||||
splitLine: { show: false },
|
||||
},
|
||||
{
|
||||
type: 'category',
|
||||
gridIndex: 1,
|
||||
data: times,
|
||||
axisLine: { lineStyle: { color: '#ddd' } },
|
||||
axisLabel: { show: false },
|
||||
axisTick: { show: false },
|
||||
splitLine: { show: false },
|
||||
},
|
||||
],
|
||||
yAxis: [
|
||||
{
|
||||
type: 'value',
|
||||
min: yAxisMin,
|
||||
max: yAxisMax,
|
||||
axisLine: { show: false },
|
||||
axisLabel: {
|
||||
color: subTextColor,
|
||||
fontSize: 10,
|
||||
formatter: (val) => val.toFixed(0),
|
||||
},
|
||||
splitLine: {
|
||||
lineStyle: { color: '#eee', type: 'dashed' },
|
||||
},
|
||||
// 右侧显示涨跌幅
|
||||
axisPointer: {
|
||||
label: {
|
||||
formatter: function (params) {
|
||||
const pct = ((params.value - index.prev_close) / index.prev_close) * 100;
|
||||
return `${params.value.toFixed(2)} (${pct >= 0 ? '+' : ''}${pct.toFixed(2)}%)`;
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'value',
|
||||
gridIndex: 1,
|
||||
axisLine: { show: false },
|
||||
axisLabel: { show: false },
|
||||
splitLine: { show: false },
|
||||
},
|
||||
],
|
||||
series: [
|
||||
// 分时线
|
||||
{
|
||||
name: '上证指数',
|
||||
type: 'line',
|
||||
data: prices,
|
||||
smooth: true,
|
||||
symbol: 'none',
|
||||
lineStyle: {
|
||||
color: lineColor,
|
||||
width: 1.5,
|
||||
},
|
||||
areaStyle: {
|
||||
color: new echarts.graphic.LinearGradient(0, 0, 0, 1, areaColorStops),
|
||||
},
|
||||
markPoint: {
|
||||
symbol: 'pin',
|
||||
symbolSize: 40,
|
||||
data: markPoints,
|
||||
animation: true,
|
||||
},
|
||||
},
|
||||
// 成交量
|
||||
{
|
||||
name: '成交量',
|
||||
type: 'bar',
|
||||
xAxisIndex: 1,
|
||||
yAxisIndex: 1,
|
||||
data: volumes.map((v, i) => ({
|
||||
value: v,
|
||||
itemStyle: {
|
||||
color: changePcts[i] >= 0 ? 'rgba(255, 77, 77, 0.6)' : 'rgba(34, 197, 94, 0.6)',
|
||||
},
|
||||
})),
|
||||
barWidth: '60%',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
chartInstance.current.setOption(option, true);
|
||||
}, [data, textColor, subTextColor]);
|
||||
|
||||
// 数据变化时重新渲染
|
||||
useEffect(() => {
|
||||
if (data) {
|
||||
renderChart();
|
||||
}
|
||||
}, [data, renderChart]);
|
||||
|
||||
// 窗口大小变化时重新渲染
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
if (chartInstance.current) {
|
||||
chartInstance.current.resize();
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('resize', handleResize);
|
||||
return () => {
|
||||
window.removeEventListener('resize', handleResize);
|
||||
if (chartInstance.current) {
|
||||
chartInstance.current.dispose();
|
||||
chartInstance.current = null;
|
||||
}
|
||||
};
|
||||
// 点击异动标注
|
||||
const handleAlertClick = useCallback((alert) => {
|
||||
setSelectedAlert(alert);
|
||||
// 可以在这里添加滚动到对应位置的逻辑
|
||||
}, []);
|
||||
|
||||
// 异动类型标签
|
||||
const AlertTypeBadge = ({ type, count }) => {
|
||||
const config = {
|
||||
surge: { label: '急涨', color: 'red', icon: FaBolt },
|
||||
surge_up: { label: '暴涨', color: 'red', icon: FaBolt },
|
||||
surge_down: { label: '暴跌', color: 'green', icon: FaArrowDown },
|
||||
limit_up: { label: '涨停', color: 'orange', icon: FaRocket },
|
||||
rank_jump: { label: '排名跃升', color: 'blue', icon: FaChartLine },
|
||||
};
|
||||
|
||||
const cfg = config[type] || { label: type, color: 'gray', icon: FaFire };
|
||||
|
||||
return (
|
||||
<Badge colorScheme={cfg.color} variant="subtle" px={2} py={1} borderRadius="md">
|
||||
<HStack spacing={1}>
|
||||
<Icon as={cfg.icon} boxSize={3} />
|
||||
<Text>{cfg.label}</Text>
|
||||
<Text fontWeight="bold">{count}</Text>
|
||||
</HStack>
|
||||
</Badge>
|
||||
);
|
||||
};
|
||||
|
||||
// 渲染加载状态
|
||||
if (loading) {
|
||||
return (
|
||||
<Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
|
||||
@@ -426,6 +77,7 @@ const HotspotOverview = ({ selectedDate }) => {
|
||||
);
|
||||
}
|
||||
|
||||
// 渲染错误状态
|
||||
if (error) {
|
||||
return (
|
||||
<Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
|
||||
@@ -441,6 +93,7 @@ const HotspotOverview = ({ selectedDate }) => {
|
||||
);
|
||||
}
|
||||
|
||||
// 无数据
|
||||
if (!data) {
|
||||
return null;
|
||||
}
|
||||
@@ -450,7 +103,7 @@ const HotspotOverview = ({ selectedDate }) => {
|
||||
return (
|
||||
<Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
|
||||
<CardBody>
|
||||
{/* 头部信息 */}
|
||||
{/* 头部 */}
|
||||
<Flex align="center" mb={4}>
|
||||
<HStack spacing={3}>
|
||||
<Icon as={FaFire} boxSize={6} color="orange.500" />
|
||||
@@ -459,69 +112,75 @@ const HotspotOverview = ({ selectedDate }) => {
|
||||
</Heading>
|
||||
</HStack>
|
||||
<Spacer />
|
||||
<HStack spacing={2}>
|
||||
<Tooltip label={showAlertList ? '收起异动列表' : '展开异动列表'}>
|
||||
<IconButton
|
||||
icon={showAlertList ? <FaChevronUp /> : <FaList />}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
onClick={() => setShowAlertList(!showAlertList)}
|
||||
aria-label="切换异动列表"
|
||||
/>
|
||||
</Tooltip>
|
||||
<Tooltip label="展示大盘走势与概念异动的关联">
|
||||
<Icon as={InfoIcon} color={subTextColor} />
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
</Flex>
|
||||
|
||||
{/* 指数统计 */}
|
||||
<SimpleGrid columns={{ base: 2, md: 4 }} spacing={4} mb={4}>
|
||||
<Stat size="sm">
|
||||
<StatLabel color={subTextColor}>{index.name}</StatLabel>
|
||||
<StatNumber
|
||||
fontSize="xl"
|
||||
color={index.change_pct >= 0 ? 'red.500' : 'green.500'}
|
||||
{/* 统计摘要 */}
|
||||
<Box mb={4}>
|
||||
<AlertSummary indexData={index} alerts={alerts} alertSummary={alert_summary} />
|
||||
</Box>
|
||||
|
||||
<Divider mb={4} />
|
||||
|
||||
{/* 主体内容:图表 + 异动列表 */}
|
||||
<Grid
|
||||
templateColumns={{ base: '1fr', lg: showAlertList ? '1fr 300px' : '1fr' }}
|
||||
gap={4}
|
||||
>
|
||||
{index.latest_price?.toFixed(2)}
|
||||
</StatNumber>
|
||||
<StatHelpText mb={0}>
|
||||
<StatArrow type={index.change_pct >= 0 ? 'increase' : 'decrease'} />
|
||||
{index.change_pct?.toFixed(2)}%
|
||||
</StatHelpText>
|
||||
</Stat>
|
||||
|
||||
<Stat size="sm">
|
||||
<StatLabel color={subTextColor}>最高</StatLabel>
|
||||
<StatNumber fontSize="xl" color="red.500">
|
||||
{index.high?.toFixed(2)}
|
||||
</StatNumber>
|
||||
</Stat>
|
||||
|
||||
<Stat size="sm">
|
||||
<StatLabel color={subTextColor}>最低</StatLabel>
|
||||
<StatNumber fontSize="xl" color="green.500">
|
||||
{index.low?.toFixed(2)}
|
||||
</StatNumber>
|
||||
</Stat>
|
||||
|
||||
<Stat size="sm">
|
||||
<StatLabel color={subTextColor}>异动次数</StatLabel>
|
||||
<StatNumber fontSize="xl" color="orange.500">
|
||||
{alerts.length}
|
||||
</StatNumber>
|
||||
</Stat>
|
||||
</SimpleGrid>
|
||||
|
||||
{/* 异动类型统计 */}
|
||||
{alerts.length > 0 && (
|
||||
<HStack spacing={2} mb={4} flexWrap="wrap">
|
||||
{(alert_summary.surge_up > 0 || alert_summary.surge > 0) && (
|
||||
<AlertTypeBadge type="surge_up" count={(alert_summary.surge_up || 0) + (alert_summary.surge || 0)} />
|
||||
)}
|
||||
{alert_summary.surge_down > 0 && (
|
||||
<AlertTypeBadge type="surge_down" count={alert_summary.surge_down} />
|
||||
)}
|
||||
{alert_summary.limit_up > 0 && (
|
||||
<AlertTypeBadge type="limit_up" count={alert_summary.limit_up} />
|
||||
)}
|
||||
{alert_summary.rank_jump > 0 && (
|
||||
<AlertTypeBadge type="rank_jump" count={alert_summary.rank_jump} />
|
||||
)}
|
||||
{/* 分时图 */}
|
||||
<GridItem>
|
||||
<Box>
|
||||
<HStack spacing={2} mb={2}>
|
||||
<Icon as={FaChartArea} color="purple.500" boxSize={4} />
|
||||
<Text fontSize="sm" fontWeight="medium" color={textColor}>
|
||||
大盘分时走势
|
||||
</Text>
|
||||
</HStack>
|
||||
)}
|
||||
<IndexMinuteChart
|
||||
indexData={index}
|
||||
alerts={alerts}
|
||||
onAlertClick={handleAlertClick}
|
||||
height="350px"
|
||||
/>
|
||||
</Box>
|
||||
</GridItem>
|
||||
|
||||
{/* 图表 */}
|
||||
<Box ref={chartRef} h="400px" w="100%" />
|
||||
{/* 异动列表(可收起) */}
|
||||
<Collapse in={showAlertList} animateOpacity>
|
||||
<GridItem>
|
||||
<Box>
|
||||
<HStack spacing={2} mb={2}>
|
||||
<Icon as={FaList} color="orange.500" boxSize={4} />
|
||||
<Text fontSize="sm" fontWeight="medium" color={textColor}>
|
||||
异动记录
|
||||
</Text>
|
||||
<Text fontSize="xs" color={subTextColor}>
|
||||
({alerts.length})
|
||||
</Text>
|
||||
</HStack>
|
||||
<ConceptAlertList
|
||||
alerts={alerts}
|
||||
onAlertClick={handleAlertClick}
|
||||
selectedAlert={selectedAlert}
|
||||
maxHeight="350px"
|
||||
/>
|
||||
</Box>
|
||||
</GridItem>
|
||||
</Collapse>
|
||||
</Grid>
|
||||
|
||||
{/* 无异动提示 */}
|
||||
{alerts.length === 0 && (
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
/**
|
||||
* 图表辅助函数
|
||||
* 用于处理异动标注等图表相关逻辑
|
||||
*/
|
||||
|
||||
/**
|
||||
* 获取异动标注的配色和符号
|
||||
* @param {string} alertType - 异动类型
|
||||
* @param {number} importanceScore - 重要性得分
|
||||
* @returns {Object} { color, symbol, symbolSize }
|
||||
*/
|
||||
export const getAlertStyle = (alertType, importanceScore = 0.5) => {
|
||||
let color = '#ff6b6b';
|
||||
let symbol = 'pin';
|
||||
let symbolSize = 35;
|
||||
|
||||
switch (alertType) {
|
||||
case 'surge_up':
|
||||
case 'surge':
|
||||
color = '#ff4757';
|
||||
symbol = 'triangle';
|
||||
symbolSize = 30 + Math.min(importanceScore * 20, 15);
|
||||
break;
|
||||
case 'surge_down':
|
||||
color = '#2ed573';
|
||||
symbol = 'path://M0,0 L10,0 L5,10 Z'; // 向下三角形
|
||||
symbolSize = 30 + Math.min(importanceScore * 20, 15);
|
||||
break;
|
||||
case 'limit_up':
|
||||
color = '#ff6348';
|
||||
symbol = 'diamond';
|
||||
symbolSize = 28;
|
||||
break;
|
||||
case 'rank_jump':
|
||||
color = '#3742fa';
|
||||
symbol = 'circle';
|
||||
symbolSize = 25;
|
||||
break;
|
||||
case 'volume_spike':
|
||||
color = '#ffa502';
|
||||
symbol = 'rect';
|
||||
symbolSize = 25;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return { color, symbol, symbolSize };
|
||||
};
|
||||
|
||||
/**
|
||||
* 获取异动类型的显示标签
|
||||
* @param {string} alertType - 异动类型
|
||||
* @returns {string} 显示标签
|
||||
*/
|
||||
export const getAlertTypeLabel = (alertType) => {
|
||||
const labels = {
|
||||
surge: '急涨',
|
||||
surge_up: '暴涨',
|
||||
surge_down: '暴跌',
|
||||
limit_up: '涨停增加',
|
||||
rank_jump: '排名跃升',
|
||||
volume_spike: '放量',
|
||||
unknown: '异动',
|
||||
};
|
||||
return labels[alertType] || alertType;
|
||||
};
|
||||
|
||||
/**
|
||||
* 生成图表标注点数据
|
||||
* @param {Array} alerts - 异动数据数组
|
||||
* @param {Array} times - 时间数组
|
||||
* @param {Array} prices - 价格数组
|
||||
* @param {number} priceMax - 最高价格(用于无法匹配时间时的默认位置)
|
||||
* @param {number} maxCount - 最大显示数量
|
||||
* @returns {Array} ECharts markPoint data
|
||||
*/
|
||||
export const getAlertMarkPoints = (alerts, times, prices, priceMax, maxCount = 15) => {
|
||||
if (!alerts || alerts.length === 0) return [];
|
||||
|
||||
// 按重要性排序,限制显示数量
|
||||
const sortedAlerts = [...alerts]
|
||||
.sort((a, b) => (b.final_score || b.importance_score || 0) - (a.final_score || a.importance_score || 0))
|
||||
.slice(0, maxCount);
|
||||
|
||||
return sortedAlerts.map((alert) => {
|
||||
// 找到对应时间的价格
|
||||
const timeIndex = times.indexOf(alert.time);
|
||||
const price = timeIndex >= 0 ? prices[timeIndex] : (alert.index_price || priceMax);
|
||||
|
||||
const { color, symbol, symbolSize } = getAlertStyle(
|
||||
alert.alert_type,
|
||||
alert.final_score / 100 || alert.importance_score || 0.5
|
||||
);
|
||||
|
||||
// 格式化标签
|
||||
let label = alert.concept_name || '';
|
||||
if (label.length > 6) {
|
||||
label = label.substring(0, 5) + '...';
|
||||
}
|
||||
|
||||
// 添加涨停数量(如果有)
|
||||
if (alert.limit_up_count > 0) {
|
||||
label += `\n涨停: ${alert.limit_up_count}`;
|
||||
}
|
||||
|
||||
const isDown = alert.alert_type === 'surge_down';
|
||||
|
||||
return {
|
||||
name: alert.concept_name,
|
||||
coord: [alert.time, price],
|
||||
value: label,
|
||||
symbol,
|
||||
symbolSize,
|
||||
itemStyle: {
|
||||
color,
|
||||
borderColor: '#fff',
|
||||
borderWidth: 1,
|
||||
shadowBlur: 3,
|
||||
shadowColor: 'rgba(0,0,0,0.2)',
|
||||
},
|
||||
label: {
|
||||
show: true,
|
||||
position: isDown ? 'bottom' : 'top',
|
||||
formatter: '{b}',
|
||||
fontSize: 9,
|
||||
color: '#333',
|
||||
backgroundColor: isDown ? 'rgba(46, 213, 115, 0.9)' : 'rgba(255,255,255,0.9)',
|
||||
padding: [2, 4],
|
||||
borderRadius: 2,
|
||||
borderColor: color,
|
||||
borderWidth: 1,
|
||||
},
|
||||
alertData: alert, // 存储原始数据
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* 格式化分数显示
|
||||
* @param {number} score - 分数
|
||||
* @returns {string} 格式化后的分数
|
||||
*/
|
||||
export const formatScore = (score) => {
|
||||
if (score === null || score === undefined) return '-';
|
||||
return Math.round(score).toString();
|
||||
};
|
||||
|
||||
/**
|
||||
* 获取分数对应的颜色
|
||||
* @param {number} score - 分数 (0-100)
|
||||
* @returns {string} 颜色代码
|
||||
*/
|
||||
export const getScoreColor = (score) => {
|
||||
if (score >= 80) return '#ff4757';
|
||||
if (score >= 60) return '#ff6348';
|
||||
if (score >= 40) return '#ffa502';
|
||||
return '#747d8c';
|
||||
};
|
||||
@@ -0,0 +1 @@
|
||||
export * from './chartHelpers';
|
||||
Reference in New Issue
Block a user