update pay ui

This commit is contained in:
2025-12-10 11:02:09 +08:00
parent 813d416499
commit e56f62506d
21 changed files with 5514 additions and 151 deletions

198
app.py
View File

@@ -12536,113 +12536,113 @@ def get_hotspot_overview():
'change_pct': change_pct 'change_pct': change_pct
}) })
# 2. 获取概念异动数据(从 concept_anomaly_hybrid 表) # 2. 获取概念异动数据(优先从 V2 表fallback 到旧表)
alerts = [] alerts = []
use_v2 = False
# 首先确保表存在(使用 begin() 来自动提交)
try:
with engine.begin() as conn:
conn.execute(text("""
CREATE TABLE IF NOT EXISTS concept_anomaly_hybrid (
id INT AUTO_INCREMENT PRIMARY KEY,
concept_id VARCHAR(64) NOT NULL,
alert_time DATETIME NOT NULL,
trade_date DATE NOT NULL,
alert_type VARCHAR(32) NOT NULL,
final_score FLOAT NOT NULL,
rule_score FLOAT NOT NULL,
ml_score FLOAT NOT NULL,
trigger_reason VARCHAR(64),
alpha FLOAT,
alpha_delta FLOAT,
amt_ratio FLOAT,
amt_delta FLOAT,
rank_pct FLOAT,
limit_up_ratio FLOAT,
stock_count INT,
total_amt FLOAT,
triggered_rules JSON,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
INDEX idx_trade_date (trade_date),
INDEX idx_concept_id (concept_id),
INDEX idx_final_score (final_score),
INDEX idx_alert_type (alert_type)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动检测结果(融合版)'
"""))
except Exception as create_err:
logger.debug(f"创建表检查: {create_err}")
with engine.connect() as conn: with engine.connect() as conn:
# 查询 concept_anomaly_hybrid 表 # 尝试查询 V2 表(时间片对齐 + 持续确认版本)
alert_result = conn.execute(text(""" try:
SELECT v2_result = conn.execute(text("""
a.concept_id, SELECT
a.alert_time, concept_id, alert_time, trade_date, alert_type,
a.trade_date, final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
a.alert_type, alpha, alpha_zscore, amt_zscore, rank_zscore,
a.final_score, momentum_3m, momentum_5m, limit_up_ratio, triggered_rules
a.rule_score, FROM concept_anomaly_v2
a.ml_score, WHERE trade_date = :trade_date
a.trigger_reason, ORDER BY alert_time
a.alpha, """), {'trade_date': trade_date})
a.alpha_delta, v2_rows = v2_result.fetchall()
a.amt_ratio, if v2_rows:
a.amt_delta, use_v2 = True
a.rank_pct, for row in v2_rows:
a.limit_up_ratio, triggered_rules = None
a.stock_count, if row[16]:
a.total_amt, try:
a.triggered_rules triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16]
FROM concept_anomaly_hybrid a except:
WHERE a.trade_date = :trade_date pass
ORDER BY a.alert_time
"""), {'trade_date': trade_date})
# 获取概念名称映射(从 ES 或缓存) alerts.append({
concept_names = {} 'concept_id': row[0],
'concept_name': row[0], # 后面会填充
'time': row[1].strftime('%H:%M') if row[1] else None,
'timestamp': row[1].isoformat() if row[1] else None,
'alert_type': row[3],
'final_score': float(row[4]) if row[4] else None,
'rule_score': float(row[5]) if row[5] else None,
'ml_score': float(row[6]) if row[6] else None,
'trigger_reason': row[7],
# V2 新增字段
'confirm_ratio': float(row[8]) if row[8] else None,
'alpha': float(row[9]) if row[9] else None,
'alpha_zscore': float(row[10]) if row[10] else None,
'amt_zscore': float(row[11]) if row[11] else None,
'rank_zscore': float(row[12]) if row[12] else None,
'momentum_3m': float(row[13]) if row[13] else None,
'momentum_5m': float(row[14]) if row[14] else None,
'limit_up_ratio': float(row[15]) if row[15] else 0,
'triggered_rules': triggered_rules,
# 兼容字段
'importance_score': float(row[4]) / 100 if row[4] else None,
'is_v2': True,
})
except Exception as v2_err:
logger.debug(f"V2 表查询失败,使用旧表: {v2_err}")
for row in alert_result: # Fallback: 查询旧表
concept_id = row[0] if not use_v2:
alert_time = row[1] try:
triggered_rules = None alert_result = conn.execute(text("""
if row[16]: SELECT
try: a.concept_id, a.alert_time, a.trade_date, a.alert_type,
triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16] a.final_score, a.rule_score, a.ml_score, a.trigger_reason,
except: a.alpha, a.alpha_delta, a.amt_ratio, a.amt_delta,
pass a.rank_pct, a.limit_up_ratio, a.stock_count, a.total_amt,
a.triggered_rules
FROM concept_anomaly_hybrid a
WHERE a.trade_date = :trade_date
ORDER BY a.alert_time
"""), {'trade_date': trade_date})
# 获取概念名称(优先从缓存,否则使用 concept_id for row in alert_result:
concept_name = concept_names.get(concept_id) or concept_id triggered_rules = None
if row[16]:
try:
triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16]
except:
pass
# 计算涨停数量(从 limit_up_ratio 和 stock_count 估算) limit_up_ratio = float(row[13]) if row[13] else 0
limit_up_ratio = float(row[13]) if row[13] else 0 stock_count = int(row[14]) if row[14] else 0
stock_count = int(row[14]) if row[14] else 0 limit_up_count = int(limit_up_ratio * stock_count) if stock_count > 0 else 0
limit_up_count = int(limit_up_ratio * stock_count) if stock_count > 0 else 0
alerts.append({ alerts.append({
'concept_id': concept_id, 'concept_id': row[0],
'concept_name': concept_name, 'concept_name': row[0],
'time': alert_time.strftime('%H:%M') if alert_time else None, 'time': row[1].strftime('%H:%M') if row[1] else None,
'timestamp': alert_time.isoformat() if alert_time else None, 'timestamp': row[1].isoformat() if row[1] else None,
'alert_type': row[3], 'alert_type': row[3],
'final_score': float(row[4]) if row[4] else None, 'final_score': float(row[4]) if row[4] else None,
'rule_score': float(row[5]) if row[5] else None, 'rule_score': float(row[5]) if row[5] else None,
'ml_score': float(row[6]) if row[6] else None, 'ml_score': float(row[6]) if row[6] else None,
'trigger_reason': row[7], 'trigger_reason': row[7],
'alpha': float(row[8]) if row[8] else None, 'alpha': float(row[8]) if row[8] else None,
'alpha_delta': float(row[9]) if row[9] else None, 'alpha_delta': float(row[9]) if row[9] else None,
'amt_ratio': float(row[10]) if row[10] else None, 'amt_ratio': float(row[10]) if row[10] else None,
'amt_delta': float(row[11]) if row[11] else None, 'amt_delta': float(row[11]) if row[11] else None,
'rank_pct': float(row[12]) if row[12] else None, 'rank_pct': float(row[12]) if row[12] else None,
'limit_up_ratio': limit_up_ratio, 'limit_up_ratio': limit_up_ratio,
'limit_up_count': limit_up_count, 'limit_up_count': limit_up_count,
'stock_count': stock_count, 'stock_count': stock_count,
'total_amt': float(row[15]) if row[15] else None, 'total_amt': float(row[15]) if row[15] else None,
'triggered_rules': triggered_rules, 'triggered_rules': triggered_rules,
# 兼容旧字段 'importance_score': float(row[4]) / 100 if row[4] else None,
'importance_score': float(row[4]) / 100 if row[4] else None, 'is_v2': False,
}) })
except Exception as old_err:
logger.debug(f"旧表查询也失败: {old_err}")
# 尝试批量获取概念名称 # 尝试批量获取概念名称
if alerts: if alerts:

294
ml/backtest_v2.py Normal file
View File

@@ -0,0 +1,294 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
V2 回测脚本 - 验证时间片对齐 + 持续性确认的效果
回测指标:
1. 准确率:异动后 N 分钟内 alpha 是否继续上涨/下跌
2. 虚警率:多少异动是噪音
3. 持续性:平均异动持续时长
"""
import os
import sys
import json
import argparse
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple
from collections import defaultdict
import numpy as np
import pandas as pd
from tqdm import tqdm
from sqlalchemy import create_engine, text
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from ml.detector_v2 import AnomalyDetectorV2, CONFIG
# ==================== 配置 ====================
MYSQL_ENGINE = create_engine(
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
echo=False
)
# ==================== 回测评估 ====================
def evaluate_alerts(
alerts: List[Dict],
raw_data: pd.DataFrame,
lookahead_minutes: int = 10
) -> Dict:
"""
评估异动质量
指标:
1. 方向正确率:异动后 N 分钟 alpha 方向是否一致
2. 持续率:异动后 N 分钟内有多少时刻 alpha 保持同向
3. 峰值收益:异动后 N 分钟内的最大 alpha
"""
if not alerts:
return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0, 'total_alerts': 0}
results = []
for alert in alerts:
concept_id = alert['concept_id']
alert_time = alert['alert_time']
alert_alpha = alert['alpha']
is_up = alert_alpha > 0
# 获取该概念在异动后的数据
concept_data = raw_data[
(raw_data['concept_id'] == concept_id) &
(raw_data['timestamp'] > alert_time)
].head(lookahead_minutes)
if len(concept_data) < 3:
continue
future_alphas = concept_data['alpha'].values
# 方向正确:未来 alpha 平均值与当前同向
avg_future_alpha = np.mean(future_alphas)
direction_correct = (is_up and avg_future_alpha > 0) or (not is_up and avg_future_alpha < 0)
# 持续率:有多少时刻保持同向
if is_up:
sustained_count = sum(1 for a in future_alphas if a > 0)
else:
sustained_count = sum(1 for a in future_alphas if a < 0)
sustained_rate = sustained_count / len(future_alphas)
# 峰值收益
if is_up:
peak = max(future_alphas)
else:
peak = min(future_alphas)
results.append({
'direction_correct': direction_correct,
'sustained_rate': sustained_rate,
'peak': peak,
'alert_alpha': alert_alpha,
})
if not results:
return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0, 'total_alerts': 0}
return {
'accuracy': np.mean([r['direction_correct'] for r in results]),
'sustained_rate': np.mean([r['sustained_rate'] for r in results]),
'avg_peak': np.mean([abs(r['peak']) for r in results]),
'total_alerts': len(alerts),
'evaluated_alerts': len(results),
}
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
"""保存异动到 MySQL"""
if not alerts or dry_run:
return 0
# 确保表存在
with MYSQL_ENGINE.begin() as conn:
conn.execute(text("""
CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
id INT AUTO_INCREMENT PRIMARY KEY,
concept_id VARCHAR(64) NOT NULL,
alert_time DATETIME NOT NULL,
trade_date DATE NOT NULL,
alert_type VARCHAR(32) NOT NULL,
final_score FLOAT NOT NULL,
rule_score FLOAT NOT NULL,
ml_score FLOAT NOT NULL,
trigger_reason VARCHAR(128),
confirm_ratio FLOAT,
alpha FLOAT,
alpha_zscore FLOAT,
amt_zscore FLOAT,
rank_zscore FLOAT,
momentum_3m FLOAT,
momentum_5m FLOAT,
limit_up_ratio FLOAT,
triggered_rules JSON,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
INDEX idx_trade_date (trade_date),
INDEX idx_final_score (final_score)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动 V2时间片对齐+持续确认)'
"""))
# 插入数据
saved = 0
with MYSQL_ENGINE.begin() as conn:
for alert in alerts:
try:
conn.execute(text("""
INSERT IGNORE INTO concept_anomaly_v2
(concept_id, alert_time, trade_date, alert_type,
final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
alpha, alpha_zscore, amt_zscore, rank_zscore,
momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
VALUES
(:concept_id, :alert_time, :trade_date, :alert_type,
:final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
:alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
:momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
"""), {
'concept_id': alert['concept_id'],
'alert_time': alert['alert_time'],
'trade_date': alert['trade_date'],
'alert_type': alert['alert_type'],
'final_score': alert['final_score'],
'rule_score': alert['rule_score'],
'ml_score': alert['ml_score'],
'trigger_reason': alert['trigger_reason'],
'confirm_ratio': alert.get('confirm_ratio', 0),
'alpha': alert['alpha'],
'alpha_zscore': alert.get('alpha_zscore', 0),
'amt_zscore': alert.get('amt_zscore', 0),
'rank_zscore': alert.get('rank_zscore', 0),
'momentum_3m': alert.get('momentum_3m', 0),
'momentum_5m': alert.get('momentum_5m', 0),
'limit_up_ratio': alert.get('limit_up_ratio', 0),
'triggered_rules': json.dumps(alert.get('triggered_rules', [])),
})
saved += 1
except Exception as e:
print(f"保存失败: {e}")
return saved
# ==================== 主函数 ====================
def main():
parser = argparse.ArgumentParser(description='V2 回测')
parser.add_argument('--start', type=str, required=True, help='开始日期')
parser.add_argument('--end', type=str, default=None, help='结束日期')
parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
parser.add_argument('--save', action='store_true', help='保存到数据库')
parser.add_argument('--lookahead', type=int, default=10, help='评估前瞻时间(分钟)')
args = parser.parse_args()
end_date = args.end or args.start
print("=" * 60)
print("V2 回测 - 时间片对齐 + 持续性确认")
print("=" * 60)
print(f"日期范围: {args.start} ~ {end_date}")
print(f"模型目录: {args.model_dir}")
print(f"评估前瞻: {args.lookahead} 分钟")
# 初始化检测器
detector = AnomalyDetectorV2(
model_dir=args.model_dir,
baseline_dir=args.baseline_dir
)
# 获取交易日
from prepare_data_v2 import get_trading_days
trading_days = get_trading_days(args.start, end_date)
if not trading_days:
print("无交易日")
return
print(f"交易日数: {len(trading_days)}")
# 回测统计
total_stats = {
'total_alerts': 0,
'accuracy_sum': 0,
'sustained_sum': 0,
'peak_sum': 0,
'day_count': 0,
}
all_alerts = []
for trade_date in tqdm(trading_days, desc="回测进度"):
# 检测异动
alerts = detector.detect(trade_date)
if not alerts:
continue
all_alerts.extend(alerts)
# 评估
raw_data = detector._compute_raw_features(trade_date)
if raw_data.empty:
continue
stats = evaluate_alerts(alerts, raw_data, args.lookahead)
if stats['evaluated_alerts'] > 0:
total_stats['total_alerts'] += stats['total_alerts']
total_stats['accuracy_sum'] += stats['accuracy'] * stats['evaluated_alerts']
total_stats['sustained_sum'] += stats['sustained_rate'] * stats['evaluated_alerts']
total_stats['peak_sum'] += stats['avg_peak'] * stats['evaluated_alerts']
total_stats['day_count'] += 1
print(f"\n[{trade_date}] 异动: {stats['total_alerts']}, "
f"准确率: {stats['accuracy']:.1%}, "
f"持续率: {stats['sustained_rate']:.1%}, "
f"峰值: {stats['avg_peak']:.2f}%")
# 汇总
print("\n" + "=" * 60)
print("回测汇总")
print("=" * 60)
if total_stats['total_alerts'] > 0:
avg_accuracy = total_stats['accuracy_sum'] / total_stats['total_alerts']
avg_sustained = total_stats['sustained_sum'] / total_stats['total_alerts']
avg_peak = total_stats['peak_sum'] / total_stats['total_alerts']
print(f"总异动数: {total_stats['total_alerts']}")
print(f"回测天数: {total_stats['day_count']}")
print(f"平均每天: {total_stats['total_alerts'] / max(1, total_stats['day_count']):.1f}")
print(f"方向准确率: {avg_accuracy:.1%}")
print(f"持续率: {avg_sustained:.1%}")
print(f"平均峰值: {avg_peak:.2f}%")
else:
print("无异动检测结果")
# 保存
if args.save and all_alerts:
print(f"\n保存 {len(all_alerts)} 条异动到数据库...")
saved = save_alerts_to_mysql(all_alerts)
print(f"保存完成: {saved}")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,31 @@
{
"seq_len": 10,
"stride": 2,
"train_end_date": "2025-06-30",
"val_end_date": "2025-09-30",
"features": [
"alpha_zscore",
"amt_zscore",
"rank_zscore",
"momentum_3m",
"momentum_5m",
"limit_up_ratio"
],
"batch_size": 32768,
"epochs": 150,
"learning_rate": 0.0006,
"weight_decay": 1e-05,
"gradient_clip": 1.0,
"patience": 15,
"min_delta": 1e-06,
"model": {
"n_features": 6,
"hidden_dim": 32,
"latent_dim": 4,
"num_layers": 1,
"dropout": 0.2,
"bidirectional": true
},
"clip_value": 5.0,
"threshold_percentiles": [90, 95, 99]
}

View File

@@ -0,0 +1,8 @@
{
"p90": 0.15,
"p95": 0.25,
"p99": 0.50,
"mean": 0.08,
"std": 0.12,
"median": 0.06
}

716
ml/detector_v2.py Normal file
View File

@@ -0,0 +1,716 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
异动检测器 V2 - 基于时间片对齐 + 持续性确认
核心改进:
1. Z-Score 特征:相对于同时间片历史的偏离
2. 短序列 LSTM10分钟序列开盘即可用
3. 持续性确认5分钟窗口内60%时刻超标才确认为异动
检测流程:
1. 计算当前时刻的 Z-Score对比同时间片历史基线
2. 构建最近10分钟的 Z-Score 序列
3. LSTM 计算重构误差ML分数
4. 规则评分(基于 Z-Score 的规则)
5. 滑动窗口确认最近5分钟内是否有足够多的时刻超标
6. 只有通过持续性确认的才输出为异动
"""
import os
import sys
import json
import pickle
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from collections import defaultdict, deque
import numpy as np
import pandas as pd
import torch
from sqlalchemy import create_engine, text
from elasticsearch import Elasticsearch
from clickhouse_driver import Client
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from ml.model import TransformerAutoencoder
# ==================== 配置 ====================
MYSQL_ENGINE = create_engine(
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
echo=False
)
ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
ES_INDEX = 'concept_library_v3'
CLICKHOUSE_CONFIG = {
'host': '127.0.0.1',
'port': 9000,
'user': 'default',
'password': 'Zzl33818!',
'database': 'stock'
}
REFERENCE_INDEX = '000001.SH'
# 检测配置
CONFIG = {
# 序列配置
'seq_len': 10, # LSTM 序列长度(分钟)
# 持续性确认配置(核心!)
'confirm_window': 5, # 确认窗口(分钟)
'confirm_ratio': 0.6, # 确认比例60%时刻需要超标)
# Z-Score 阈值
'alpha_zscore_threshold': 2.0, # Alpha Z-Score 阈值
'amt_zscore_threshold': 2.5, # 成交额 Z-Score 阈值
# 融合权重
'rule_weight': 0.5,
'ml_weight': 0.5,
# 触发阈值
'rule_trigger': 60,
'ml_trigger': 70,
'fusion_trigger': 50,
# 冷却期
'cooldown_minutes': 10,
'max_alerts_per_minute': 15,
# Z-Score 截断
'zscore_clip': 5.0,
}
# V2 特征列表
FEATURES_V2 = [
'alpha_zscore', 'amt_zscore', 'rank_zscore',
'momentum_3m', 'momentum_5m', 'limit_up_ratio'
]
# ==================== 工具函数 ====================
def get_ch_client():
return Client(**CLICKHOUSE_CONFIG)
def code_to_ch_format(code: str) -> str:
if not code or len(code) != 6 or not code.isdigit():
return None
if code.startswith('6'):
return f"{code}.SH"
elif code.startswith('0') or code.startswith('3'):
return f"{code}.SZ"
else:
return f"{code}.BJ"
def time_to_slot(ts) -> str:
"""时间戳转时间片HH:MM"""
if isinstance(ts, str):
return ts
return ts.strftime('%H:%M')
# ==================== 基线加载 ====================
def load_baselines(baseline_dir: str = 'ml/data_v2/baselines') -> Dict[str, pd.DataFrame]:
"""加载时间片基线"""
baseline_file = os.path.join(baseline_dir, 'baselines.pkl')
if os.path.exists(baseline_file):
with open(baseline_file, 'rb') as f:
return pickle.load(f)
return {}
# ==================== 规则评分(基于 Z-Score====================
def score_rules_zscore(row: Dict) -> Tuple[float, List[str]]:
"""
基于 Z-Score 的规则评分
设计思路Z-Score 已经标准化,直接用阈值判断
"""
score = 0.0
triggered = []
alpha_zscore = row.get('alpha_zscore', 0)
amt_zscore = row.get('amt_zscore', 0)
rank_zscore = row.get('rank_zscore', 0)
momentum_3m = row.get('momentum_3m', 0)
momentum_5m = row.get('momentum_5m', 0)
limit_up_ratio = row.get('limit_up_ratio', 0)
alpha_zscore_abs = abs(alpha_zscore)
amt_zscore_abs = abs(amt_zscore)
# ========== Alpha Z-Score 规则 ==========
if alpha_zscore_abs >= 4.0:
score += 25
triggered.append('alpha_zscore_extreme')
elif alpha_zscore_abs >= 3.0:
score += 18
triggered.append('alpha_zscore_strong')
elif alpha_zscore_abs >= 2.0:
score += 10
triggered.append('alpha_zscore_moderate')
# ========== 成交额 Z-Score 规则 ==========
if amt_zscore >= 4.0:
score += 20
triggered.append('amt_zscore_extreme')
elif amt_zscore >= 3.0:
score += 12
triggered.append('amt_zscore_strong')
elif amt_zscore >= 2.0:
score += 6
triggered.append('amt_zscore_moderate')
# ========== 排名 Z-Score 规则 ==========
if abs(rank_zscore) >= 3.0:
score += 15
triggered.append('rank_zscore_extreme')
elif abs(rank_zscore) >= 2.0:
score += 8
triggered.append('rank_zscore_strong')
# ========== 动量规则 ==========
if momentum_3m >= 1.0:
score += 12
triggered.append('momentum_3m_strong')
elif momentum_3m >= 0.5:
score += 6
triggered.append('momentum_3m_moderate')
if momentum_5m >= 1.5:
score += 10
triggered.append('momentum_5m_strong')
# ========== 涨停比例规则 ==========
if limit_up_ratio >= 0.3:
score += 20
triggered.append('limit_up_extreme')
elif limit_up_ratio >= 0.15:
score += 12
triggered.append('limit_up_strong')
elif limit_up_ratio >= 0.08:
score += 5
triggered.append('limit_up_moderate')
# ========== 组合规则 ==========
# Alpha Z-Score + 成交额放大
if alpha_zscore_abs >= 2.0 and amt_zscore >= 2.0:
score += 15
triggered.append('combo_alpha_amt')
# Alpha Z-Score + 涨停
if alpha_zscore_abs >= 2.0 and limit_up_ratio >= 0.1:
score += 12
triggered.append('combo_alpha_limitup')
return min(score, 100), triggered
# ==================== ML 评分器 ====================
class MLScorerV2:
"""V2 ML 评分器"""
def __init__(self, model_dir: str = 'ml/checkpoints_v2'):
self.model_dir = model_dir
self.model = None
self.thresholds = None
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self._load_model()
def _load_model(self):
"""加载模型和阈值"""
model_path = os.path.join(self.model_dir, 'best_model.pt')
threshold_path = os.path.join(self.model_dir, 'thresholds.json')
config_path = os.path.join(self.model_dir, 'config.json')
if not os.path.exists(model_path):
print(f"警告: 模型文件不存在: {model_path}")
return
# 加载配置
with open(config_path, 'r') as f:
config = json.load(f)
# 创建模型
model_config = config.get('model', {})
self.model = TransformerAutoencoder(**model_config)
# 加载权重
checkpoint = torch.load(model_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.to(self.device)
self.model.eval()
# 加载阈值
if os.path.exists(threshold_path):
with open(threshold_path, 'r') as f:
self.thresholds = json.load(f)
print(f"V2 模型加载完成: {model_path}")
@torch.no_grad()
def score_batch(self, sequences: np.ndarray) -> np.ndarray:
"""
批量计算 ML 分数
返回 0-100 的分数,越高越异常
"""
if self.model is None:
return np.zeros(len(sequences))
# 转换为 tensor
x = torch.FloatTensor(sequences).to(self.device)
# 计算重构误差
errors = self.model.compute_reconstruction_error(x, reduction='none')
# 取最后一个时刻的误差
last_errors = errors[:, -1].cpu().numpy()
# 转换为 0-100 分数
if self.thresholds:
p50 = self.thresholds.get('median', 0.1)
p99 = self.thresholds.get('p99', 1.0)
# 线性映射p50 -> 50分p99 -> 99分
scores = 50 + (last_errors - p50) / (p99 - p50) * 49
scores = np.clip(scores, 0, 100)
else:
# 没有阈值时,简单归一化
scores = last_errors * 100
scores = np.clip(scores, 0, 100)
return scores
# ==================== 实时数据管理器 ====================
class RealtimeDataManagerV2:
"""
V2 实时数据管理器
维护:
1. 每个概念的历史 Z-Score 序列(用于 LSTM 输入)
2. 每个概念的异动候选队列(用于持续性确认)
"""
def __init__(self, concepts: List[dict], baselines: Dict[str, pd.DataFrame]):
self.concepts = {c['concept_id']: c for c in concepts}
self.baselines = baselines
# 概念到股票的映射
self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
# 历史 Z-Score 序列(每个概念)
# {concept_id: deque([(timestamp, features_dict), ...], maxlen=seq_len)}
self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))
# 异动候选队列(用于持续性确认)
# {concept_id: deque([(timestamp, score), ...], maxlen=confirm_window)}
self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))
# 冷却期记录
self.cooldown = {}
# 上一次更新的时间戳
self.last_timestamp = None
def compute_zscore_features(
self,
concept_id: str,
timestamp,
alpha: float,
total_amt: float,
rank_pct: float,
limit_up_ratio: float
) -> Optional[Dict]:
"""计算单个概念单个时刻的 Z-Score 特征"""
if concept_id not in self.baselines:
return None
baseline = self.baselines[concept_id]
time_slot = time_to_slot(timestamp)
# 查找对应时间片的基线
bl_row = baseline[baseline['time_slot'] == time_slot]
if bl_row.empty:
return None
bl = bl_row.iloc[0]
# 检查样本量
if bl.get('sample_count', 0) < 10:
return None
# 计算 Z-Score
alpha_zscore = (alpha - bl['alpha_mean']) / bl['alpha_std']
amt_zscore = (total_amt - bl['amt_mean']) / bl['amt_std']
rank_zscore = (rank_pct - bl['rank_mean']) / bl['rank_std']
# 截断
clip = CONFIG['zscore_clip']
alpha_zscore = np.clip(alpha_zscore, -clip, clip)
amt_zscore = np.clip(amt_zscore, -clip, clip)
rank_zscore = np.clip(rank_zscore, -clip, clip)
# 计算动量(需要历史)
history = self.zscore_history[concept_id]
momentum_3m = 0
momentum_5m = 0
if len(history) >= 3:
recent_alphas = [h[1]['alpha'] for h in list(history)[-3:]]
older_alphas = [h[1]['alpha'] for h in list(history)[-6:-3]] if len(history) >= 6 else [alpha]
momentum_3m = np.mean(recent_alphas) - np.mean(older_alphas)
if len(history) >= 5:
recent_alphas = [h[1]['alpha'] for h in list(history)[-5:]]
older_alphas = [h[1]['alpha'] for h in list(history)[-10:-5]] if len(history) >= 10 else [alpha]
momentum_5m = np.mean(recent_alphas) - np.mean(older_alphas)
return {
'alpha': alpha,
'alpha_zscore': alpha_zscore,
'amt_zscore': amt_zscore,
'rank_zscore': rank_zscore,
'momentum_3m': momentum_3m,
'momentum_5m': momentum_5m,
'limit_up_ratio': limit_up_ratio,
'total_amt': total_amt,
'rank_pct': rank_pct,
}
def update(self, concept_id: str, timestamp, features: Dict):
"""更新概念的历史数据"""
self.zscore_history[concept_id].append((timestamp, features))
def get_sequence(self, concept_id: str) -> Optional[np.ndarray]:
"""获取用于 LSTM 的序列"""
history = self.zscore_history[concept_id]
if len(history) < CONFIG['seq_len']:
return None
# 提取特征
feature_list = []
for _, features in history:
feature_list.append([
features['alpha_zscore'],
features['amt_zscore'],
features['rank_zscore'],
features['momentum_3m'],
features['momentum_5m'],
features['limit_up_ratio'],
])
return np.array(feature_list)
def add_anomaly_candidate(self, concept_id: str, timestamp, score: float):
"""添加异动候选"""
self.anomaly_candidates[concept_id].append((timestamp, score))
def check_sustained_anomaly(self, concept_id: str, threshold: float) -> Tuple[bool, float]:
"""
检查是否为持续性异动
返回:(是否确认, 确认比例)
"""
candidates = self.anomaly_candidates[concept_id]
if len(candidates) < CONFIG['confirm_window']:
return False, 0.0
# 统计超过阈值的时刻数量
exceed_count = sum(1 for _, score in candidates if score >= threshold)
ratio = exceed_count / len(candidates)
return ratio >= CONFIG['confirm_ratio'], ratio
def check_cooldown(self, concept_id: str, timestamp) -> bool:
"""检查是否在冷却期"""
if concept_id not in self.cooldown:
return False
last_alert = self.cooldown[concept_id]
try:
diff = (timestamp - last_alert).total_seconds() / 60
return diff < CONFIG['cooldown_minutes']
except:
return False
def set_cooldown(self, concept_id: str, timestamp):
"""设置冷却期"""
self.cooldown[concept_id] = timestamp
# ==================== 异动检测器 V2 ====================
class AnomalyDetectorV2:
"""
V2 异动检测器
核心流程:
1. 获取实时数据
2. 计算 Z-Score 特征
3. 规则评分 + ML 评分
4. 持续性确认
5. 输出异动
"""
def __init__(
self,
model_dir: str = 'ml/checkpoints_v2',
baseline_dir: str = 'ml/data_v2/baselines'
):
# 加载概念
self.concepts = self._load_concepts()
# 加载基线
self.baselines = load_baselines(baseline_dir)
print(f"加载了 {len(self.baselines)} 个概念的基线")
# 初始化 ML 评分器
self.ml_scorer = MLScorerV2(model_dir)
# 初始化数据管理器
self.data_manager = RealtimeDataManagerV2(self.concepts, self.baselines)
# 收集所有股票
self.all_stocks = list(set(s for c in self.concepts for s in c['stocks']))
def _load_concepts(self) -> List[dict]:
"""从 ES 加载概念"""
concepts = []
query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}
resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
scroll_id = resp['_scroll_id']
hits = resp['hits']['hits']
while len(hits) > 0:
for hit in hits:
source = hit['_source']
stocks = []
if 'stocks' in source and isinstance(source['stocks'], list):
for stock in source['stocks']:
if isinstance(stock, dict) and 'code' in stock and stock['code']:
stocks.append(stock['code'])
if stocks:
concepts.append({
'concept_id': source.get('concept_id'),
'concept_name': source.get('concept'),
'stocks': stocks
})
resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
scroll_id = resp['_scroll_id']
hits = resp['hits']['hits']
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
print(f"加载了 {len(concepts)} 个概念")
return concepts
def detect(self, trade_date: str) -> List[Dict]:
"""
检测指定日期的异动
返回异动列表
"""
print(f"\n检测 {trade_date} 的异动...")
# 获取原始数据
raw_features = self._compute_raw_features(trade_date)
if raw_features.empty:
print("无数据")
return []
# 按时间排序
timestamps = sorted(raw_features['timestamp'].unique())
print(f"时间点数: {len(timestamps)}")
all_alerts = []
for ts in timestamps:
ts_data = raw_features[raw_features['timestamp'] == ts]
ts_alerts = self._process_timestamp(ts, ts_data, trade_date)
all_alerts.extend(ts_alerts)
print(f"共检测到 {len(all_alerts)} 个异动")
return all_alerts
def _compute_raw_features(self, trade_date: str) -> pd.DataFrame:
"""计算原始特征(同 prepare_data_v2"""
# 这里简化处理,直接调用数据准备逻辑
from prepare_data_v2 import compute_raw_concept_features
return compute_raw_concept_features(trade_date, self.concepts, self.all_stocks)
def _process_timestamp(self, timestamp, ts_data: pd.DataFrame, trade_date: str) -> List[Dict]:
"""处理单个时间戳"""
alerts = []
candidates = [] # (concept_id, features, rule_score, triggered_rules)
for _, row in ts_data.iterrows():
concept_id = row['concept_id']
# 计算 Z-Score 特征
features = self.data_manager.compute_zscore_features(
concept_id, timestamp,
row['alpha'], row['total_amt'], row['rank_pct'], row['limit_up_ratio']
)
if features is None:
continue
# 更新历史
self.data_manager.update(concept_id, timestamp, features)
# 规则评分
rule_score, triggered_rules = score_rules_zscore(features)
# 收集候选
candidates.append((concept_id, features, rule_score, triggered_rules))
if not candidates:
return []
# 批量 ML 评分
sequences = []
valid_candidates = []
for concept_id, features, rule_score, triggered_rules in candidates:
seq = self.data_manager.get_sequence(concept_id)
if seq is not None:
sequences.append(seq)
valid_candidates.append((concept_id, features, rule_score, triggered_rules))
if not sequences:
return []
sequences = np.array(sequences)
ml_scores = self.ml_scorer.score_batch(sequences)
# 融合评分 + 持续性确认
for i, (concept_id, features, rule_score, triggered_rules) in enumerate(valid_candidates):
ml_score = ml_scores[i]
final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score
# 判断是否触发
is_triggered = (
rule_score >= CONFIG['rule_trigger'] or
ml_score >= CONFIG['ml_trigger'] or
final_score >= CONFIG['fusion_trigger']
)
# 添加到候选队列
self.data_manager.add_anomaly_candidate(concept_id, timestamp, final_score)
if not is_triggered:
continue
# 检查冷却期
if self.data_manager.check_cooldown(concept_id, timestamp):
continue
# 持续性确认
is_sustained, confirm_ratio = self.data_manager.check_sustained_anomaly(
concept_id, CONFIG['fusion_trigger']
)
if not is_sustained:
continue
# 确认为异动!
self.data_manager.set_cooldown(concept_id, timestamp)
# 确定异动类型
alpha = features['alpha']
if alpha >= 1.5:
alert_type = 'surge_up'
elif alpha <= -1.5:
alert_type = 'surge_down'
elif features['amt_zscore'] >= 3.0:
alert_type = 'volume_spike'
else:
alert_type = 'surge'
# 确定触发原因
if rule_score >= CONFIG['rule_trigger']:
trigger_reason = f'规则({rule_score:.0f})+持续确认({confirm_ratio:.0%})'
elif ml_score >= CONFIG['ml_trigger']:
trigger_reason = f'ML({ml_score:.0f})+持续确认({confirm_ratio:.0%})'
else:
trigger_reason = f'融合({final_score:.0f})+持续确认({confirm_ratio:.0%})'
alerts.append({
'concept_id': concept_id,
'concept_name': self.data_manager.concepts.get(concept_id, {}).get('concept_name', concept_id),
'alert_time': timestamp,
'trade_date': trade_date,
'alert_type': alert_type,
'final_score': final_score,
'rule_score': rule_score,
'ml_score': ml_score,
'trigger_reason': trigger_reason,
'confirm_ratio': confirm_ratio,
'alpha': alpha,
'alpha_zscore': features['alpha_zscore'],
'amt_zscore': features['amt_zscore'],
'rank_zscore': features['rank_zscore'],
'momentum_3m': features['momentum_3m'],
'momentum_5m': features['momentum_5m'],
'limit_up_ratio': features['limit_up_ratio'],
'triggered_rules': triggered_rules,
})
# 每分钟最多 N 个
if len(alerts) > CONFIG['max_alerts_per_minute']:
alerts = sorted(alerts, key=lambda x: x['final_score'], reverse=True)
alerts = alerts[:CONFIG['max_alerts_per_minute']]
return alerts
# ==================== 主函数 ====================
def main():
import argparse
parser = argparse.ArgumentParser(description='V2 异动检测器')
parser.add_argument('--date', type=str, default=None, help='检测日期(默认今天)')
parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
args = parser.parse_args()
trade_date = args.date or datetime.now().strftime('%Y-%m-%d')
detector = AnomalyDetectorV2(
model_dir=args.model_dir,
baseline_dir=args.baseline_dir
)
alerts = detector.detect(trade_date)
print(f"\n检测结果:")
for alert in alerts[:20]:
print(f" [{alert['alert_time'].strftime('%H:%M') if hasattr(alert['alert_time'], 'strftime') else alert['alert_time']}] "
f"{alert['concept_name']} ({alert['alert_type']}) "
f"分数={alert['final_score']:.0f} "
f"确认率={alert['confirm_ratio']:.0%}")
if len(alerts) > 20:
print(f" ... 共 {len(alerts)} 个异动")
if __name__ == "__main__":
main()

View File

@@ -85,9 +85,12 @@ class LSTMAutoencoder(nn.Module):
nn.Tanh(), # 限制范围,增加约束 nn.Tanh(), # 限制范围,增加约束
) )
# 使用 LeakyReLU 替代 ReLU
# 原因Z-Score 数据范围是 [-5, +5]ReLU 会截断负值,丢失跌幅信息
# LeakyReLU 保留负值信号(乘以 0.1
self.bottleneck_up = nn.Sequential( self.bottleneck_up = nn.Sequential(
nn.Linear(latent_dim, hidden_dim), nn.Linear(latent_dim, hidden_dim),
nn.ReLU(), nn.LeakyReLU(negative_slope=0.1),
) )
# Decoder: 单向 LSTM # Decoder: 单向 LSTM

View File

@@ -26,7 +26,9 @@ import hashlib
import json import json
import logging import logging
from typing import Dict, List, Set, Tuple from typing import Dict, List, Set, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Manager
import multiprocessing
import warnings import warnings
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
@@ -128,7 +130,7 @@ def get_all_concepts() -> List[dict]:
hits = resp['hits']['hits'] hits = resp['hits']['hits']
ES_CLIENT.clear_scroll(scroll_id=scroll_id) ES_CLIENT.clear_scroll(scroll_id=scroll_id)
logger.info(f"获取到 {len(concepts)} 个概念") print(f"获取到 {len(concepts)} 个概念")
return concepts return concepts
@@ -148,7 +150,7 @@ def get_trading_days(start_date: str, end_date: str) -> List[str]:
result = client.execute(query) result = client.execute(query)
days = [row[0].strftime('%Y-%m-%d') for row in result] days = [row[0].strftime('%Y-%m-%d') for row in result]
logger.info(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}") print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
return days return days
@@ -223,21 +225,23 @@ def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) ->
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]: def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
"""获取昨收价""" """获取昨收价(上一交易日的收盘价 F007N"""
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()] valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
if not valid_codes: if not valid_codes:
return {} return {}
codes_str = "','".join(valid_codes) codes_str = "','".join(valid_codes)
# 注意F007N 是"最近成交价"即当日收盘价F002N 是"昨日收盘价"
# 我们需要查上一交易日的 F007N那天的收盘价作为今天的昨收
query = f""" query = f"""
SELECT SECCODE, F002N SELECT SECCODE, F007N
FROM ea_trade FROM ea_trade
WHERE SECCODE IN ('{codes_str}') WHERE SECCODE IN ('{codes_str}')
AND TRADEDATE = ( AND TRADEDATE = (
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}' SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
) )
AND F002N IS NOT NULL AND F002N > 0 AND F007N IS NOT NULL AND F007N > 0
""" """
try: try:
@@ -245,7 +249,7 @@ def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
result = conn.execute(text(query)) result = conn.execute(text(query))
return {row[0]: float(row[1]) for row in result if row[1]} return {row[0]: float(row[1]) for row in result if row[1]}
except Exception as e: except Exception as e:
logger.error(f"获取昨收价失败: {e}") print(f"获取昨收价失败: {e}")
return {} return {}
@@ -264,7 +268,7 @@ def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) ->
if result and result[0]: if result and result[0]:
return float(result[0]) return float(result[0])
except Exception as e: except Exception as e:
logger.error(f"获取指数昨收失败: {e}") print(f"获取指数昨收失败: {e}")
return None return None
@@ -285,25 +289,19 @@ def compute_daily_features(
""" """
# 1. 获取数据 # 1. 获取数据
logger.info(f" 获取股票数据...")
stock_df = get_daily_stock_data(trade_date, all_stocks) stock_df = get_daily_stock_data(trade_date, all_stocks)
if stock_df.empty: if stock_df.empty:
logger.warning(f" 无股票数据")
return pd.DataFrame() return pd.DataFrame()
logger.info(f" 获取指数数据...")
index_df = get_daily_index_data(trade_date) index_df = get_daily_index_data(trade_date)
if index_df.empty: if index_df.empty:
logger.warning(f" 无指数数据")
return pd.DataFrame() return pd.DataFrame()
# 2. 获取昨收价 # 2. 获取昨收价
logger.info(f" 获取昨收价...")
prev_close = get_prev_close(all_stocks, trade_date) prev_close = get_prev_close(all_stocks, trade_date)
index_prev_close = get_index_prev_close(trade_date) index_prev_close = get_index_prev_close(trade_date)
if not prev_close or not index_prev_close: if not prev_close or not index_prev_close:
logger.warning(f" 无昨收价数据")
return pd.DataFrame() return pd.DataFrame()
# 3. 计算股票涨跌幅和成交额 # 3. 计算股票涨跌幅和成交额
@@ -317,7 +315,6 @@ def compute_daily_features(
# 5. 获取所有时间点 # 5. 获取所有时间点
timestamps = sorted(stock_df['timestamp'].unique()) timestamps = sorted(stock_df['timestamp'].unique())
logger.info(f" 时间点数: {len(timestamps)}")
# 6. 按时间点计算概念特征 # 6. 按时间点计算概念特征
results = [] results = []
@@ -414,87 +411,126 @@ def compute_daily_features(
if amt_delta_std > 0: if amt_delta_std > 0:
final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std
logger.info(f" 计算完成: {len(final_df)} 条记录")
return final_df return final_df
# ==================== 主流程 ==================== # ==================== 主流程 ====================
def process_single_day(trade_date: str, concepts: List[dict], all_stocks: List[str]) -> str: def process_single_day(args) -> Tuple[str, bool]:
"""处理单个交易日""" """
处理单个交易日(多进程版本)
Args:
args: (trade_date, concepts, all_stocks) 元组
Returns:
(trade_date, success) 元组
"""
trade_date, concepts, all_stocks = args
output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet') output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
# 检查是否已处理 # 检查是否已处理
if os.path.exists(output_file): if os.path.exists(output_file):
logger.info(f"[{trade_date}] 已存在,跳过") print(f"[{trade_date}] 已存在,跳过")
return output_file return (trade_date, True)
logger.info(f"[{trade_date}] 开始处理...") print(f"[{trade_date}] 开始处理...")
try: try:
df = compute_daily_features(trade_date, concepts, all_stocks) df = compute_daily_features(trade_date, concepts, all_stocks)
if df.empty: if df.empty:
logger.warning(f"[{trade_date}] 无数据") print(f"[{trade_date}] 无数据")
return None return (trade_date, False)
# 保存 # 保存
df.to_parquet(output_file, index=False) df.to_parquet(output_file, index=False)
logger.info(f"[{trade_date}] 保存完成: {output_file}") print(f"[{trade_date}] 保存完成")
return output_file return (trade_date, True)
except Exception as e: except Exception as e:
logger.error(f"[{trade_date}] 处理失败: {e}") print(f"[{trade_date}] 处理失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return None return (trade_date, False)
def main(): def main():
import argparse import argparse
from tqdm import tqdm
parser = argparse.ArgumentParser(description='准备训练数据') parser = argparse.ArgumentParser(description='准备训练数据')
parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期') parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)') parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
parser.add_argument('--workers', type=int, default=1, help='并行建议1避免数据库压力') parser.add_argument('--workers', type=int, default=18, help='并行进程数默认18')
parser.add_argument('--force', action='store_true', help='强制重新处理已存在的文件')
args = parser.parse_args() args = parser.parse_args()
end_date = args.end or datetime.now().strftime('%Y-%m-%d') end_date = args.end or datetime.now().strftime('%Y-%m-%d')
logger.info("=" * 60) print("=" * 60)
logger.info("数据准备 - Transformer Autoencoder 训练数据") print("数据准备 - Transformer Autoencoder 训练数据")
logger.info("=" * 60) print("=" * 60)
logger.info(f"日期范围: {args.start} ~ {end_date}") print(f"日期范围: {args.start} ~ {end_date}")
print(f"并行进程数: {args.workers}")
# 1. 获取概念列表 # 1. 获取概念列表
concepts = get_all_concepts() concepts = get_all_concepts()
# 收集所有股票 # 收集所有股票
all_stocks = list(set(s for c in concepts for s in c['stocks'])) all_stocks = list(set(s for c in concepts for s in c['stocks']))
logger.info(f"股票总数: {len(all_stocks)}") print(f"股票总数: {len(all_stocks)}")
# 2. 获取交易日列表 # 2. 获取交易日列表
trading_days = get_trading_days(args.start, end_date) trading_days = get_trading_days(args.start, end_date)
if not trading_days: if not trading_days:
logger.error("无交易日数据") print("无交易日数据")
return return
# 3. 处理每个交易日 # 如果强制模式,删除已有文件
logger.info(f"\n开始处理 {len(trading_days)} 个交易日...") if args.force:
for trade_date in trading_days:
output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
if os.path.exists(output_file):
os.remove(output_file)
print(f"删除已有文件: {output_file}")
# 3. 准备任务参数
tasks = [(trade_date, concepts, all_stocks) for trade_date in trading_days]
print(f"\n开始处理 {len(trading_days)} 个交易日({args.workers} 进程并行)...")
# 4. 多进程处理
success_count = 0 success_count = 0
for i, trade_date in enumerate(trading_days): failed_dates = []
logger.info(f"\n[{i+1}/{len(trading_days)}] {trade_date}")
result = process_single_day(trade_date, concepts, all_stocks)
if result:
success_count += 1
logger.info("\n" + "=" * 60) with ProcessPoolExecutor(max_workers=args.workers) as executor:
logger.info(f"处理完成: {success_count}/{len(trading_days)} 个交易日") # 提交所有任务
logger.info(f"数据保存在: {OUTPUT_DIR}") futures = {executor.submit(process_single_day, task): task[0] for task in tasks}
logger.info("=" * 60)
# 使用 tqdm 显示进度
with tqdm(total=len(futures), desc="处理进度", unit="") as pbar:
for future in as_completed(futures):
trade_date = futures[future]
try:
result_date, success = future.result()
if success:
success_count += 1
else:
failed_dates.append(result_date)
except Exception as e:
print(f"\n[{trade_date}] 进程异常: {e}")
failed_dates.append(trade_date)
pbar.update(1)
print("\n" + "=" * 60)
print(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
if failed_dates:
print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
print(f"数据保存在: {OUTPUT_DIR}")
print("=" * 60)
if __name__ == "__main__": if __name__ == "__main__":

715
ml/prepare_data_v2.py Normal file
View File

@@ -0,0 +1,715 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据准备 V2 - 基于时间片对齐的特征计算(修复版)
核心改进:
1. 时间片对齐9:35 和历史的 9:35 比而不是和前30分钟比
2. Z-Score 特征:相对于同时间片历史分布的偏离程度
3. 滚动窗口基线:每个日期使用它之前 N 天的数据作为基线(不是固定的最后 N 天!)
4. 基于 Z-Score 的动量:消除一天内波动率异构性
修复:
- 滚动窗口基线:避免未来数据泄露
- Z-Score 动量:消除早盘/尾盘波动率差异
- 进程级数据库单例:避免连接池爆炸
"""
import os
import sys
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from sqlalchemy import create_engine, text
from elasticsearch import Elasticsearch
from clickhouse_driver import Client
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
from collections import defaultdict
import warnings
import pickle
warnings.filterwarnings('ignore')
# ==================== 配置 ====================
MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock"
ES_HOST = 'http://127.0.0.1:9200'
ES_INDEX = 'concept_library_v3'
CLICKHOUSE_CONFIG = {
'host': '127.0.0.1',
'port': 9000,
'user': 'default',
'password': 'Zzl33818!',
'database': 'stock'
}
REFERENCE_INDEX = '000001.SH'
# 输出目录
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data_v2')
BASELINE_DIR = os.path.join(OUTPUT_DIR, 'baselines')
RAW_CACHE_DIR = os.path.join(OUTPUT_DIR, 'raw_cache')
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(BASELINE_DIR, exist_ok=True)
os.makedirs(RAW_CACHE_DIR, exist_ok=True)
# 特征配置
CONFIG = {
'baseline_days': 20, # 滚动窗口大小
'min_baseline_samples': 10, # 最少需要10个样本才算有效基线
'limit_up_threshold': 9.8,
'limit_down_threshold': -9.8,
'zscore_clip': 5.0,
}
# 特征列表
FEATURES_V2 = [
'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore',
'momentum_3m', 'momentum_5m', 'limit_up_ratio',
]
# ==================== 进程级单例(避免连接池爆炸)====================
# 进程级全局变量
_process_mysql_engine = None
_process_es_client = None
_process_ch_client = None
def init_process_connections():
"""进程初始化时调用,创建连接(单例)"""
global _process_mysql_engine, _process_es_client, _process_ch_client
_process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5)
_process_es_client = Elasticsearch([ES_HOST])
_process_ch_client = Client(**CLICKHOUSE_CONFIG)
def get_mysql_engine():
"""获取进程级 MySQL Engine单例"""
global _process_mysql_engine
if _process_mysql_engine is None:
_process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5)
return _process_mysql_engine
def get_es_client():
"""获取进程级 ES 客户端(单例)"""
global _process_es_client
if _process_es_client is None:
_process_es_client = Elasticsearch([ES_HOST])
return _process_es_client
def get_ch_client():
"""获取进程级 ClickHouse 客户端(单例)"""
global _process_ch_client
if _process_ch_client is None:
_process_ch_client = Client(**CLICKHOUSE_CONFIG)
return _process_ch_client
# ==================== 工具函数 ====================
def code_to_ch_format(code: str) -> str:
if not code or len(code) != 6 or not code.isdigit():
return None
if code.startswith('6'):
return f"{code}.SH"
elif code.startswith('0') or code.startswith('3'):
return f"{code}.SZ"
else:
return f"{code}.BJ"
def time_to_slot(ts) -> str:
"""将时间戳转换为时间片HH:MM格式"""
if isinstance(ts, str):
return ts
return ts.strftime('%H:%M')
# ==================== 获取概念列表 ====================
def get_all_concepts() -> List[dict]:
"""从ES获取所有叶子概念"""
es_client = get_es_client()
concepts = []
query = {
"query": {"match_all": {}},
"size": 100,
"_source": ["concept_id", "concept", "stocks"]
}
resp = es_client.search(index=ES_INDEX, body=query, scroll='2m')
scroll_id = resp['_scroll_id']
hits = resp['hits']['hits']
while len(hits) > 0:
for hit in hits:
source = hit['_source']
stocks = []
if 'stocks' in source and isinstance(source['stocks'], list):
for stock in source['stocks']:
if isinstance(stock, dict) and 'code' in stock and stock['code']:
stocks.append(stock['code'])
if stocks:
concepts.append({
'concept_id': source.get('concept_id'),
'concept_name': source.get('concept'),
'stocks': stocks
})
resp = es_client.scroll(scroll_id=scroll_id, scroll='2m')
scroll_id = resp['_scroll_id']
hits = resp['hits']['hits']
es_client.clear_scroll(scroll_id=scroll_id)
print(f"获取到 {len(concepts)} 个概念")
return concepts
# ==================== 获取交易日列表 ====================
def get_trading_days(start_date: str, end_date: str) -> List[str]:
"""获取交易日列表"""
client = get_ch_client()
query = f"""
SELECT DISTINCT toDate(timestamp) as trade_date
FROM stock_minute
WHERE toDate(timestamp) >= '{start_date}'
AND toDate(timestamp) <= '{end_date}'
ORDER BY trade_date
"""
result = client.execute(query)
days = [row[0].strftime('%Y-%m-%d') for row in result]
if days:
print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
return days
# ==================== 获取昨收价 ====================
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
"""获取昨收价(上一交易日的收盘价 F007N"""
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
if not valid_codes:
return {}
codes_str = "','".join(valid_codes)
query = f"""
SELECT SECCODE, F007N
FROM ea_trade
WHERE SECCODE IN ('{codes_str}')
AND TRADEDATE = (
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
)
AND F007N IS NOT NULL AND F007N > 0
"""
try:
engine = get_mysql_engine()
with engine.connect() as conn:
result = conn.execute(text(query))
return {row[0]: float(row[1]) for row in result if row[1]}
except Exception as e:
print(f"获取昨收价失败: {e}")
return {}
def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float:
"""获取指数昨收价"""
code_no_suffix = index_code.split('.')[0]
try:
engine = get_mysql_engine()
with engine.connect() as conn:
result = conn.execute(text("""
SELECT F006N FROM ea_exchangetrade
WHERE INDEXCODE = :code AND TRADEDATE < :today
ORDER BY TRADEDATE DESC LIMIT 1
"""), {'code': code_no_suffix, 'today': trade_date}).fetchone()
if result and result[0]:
return float(result[0])
except Exception as e:
print(f"获取指数昨收失败: {e}")
return None
# ==================== 获取分钟数据 ====================
def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
"""获取单日所有股票的分钟数据"""
client = get_ch_client()
ch_codes = []
code_map = {}
for code in stock_codes:
ch_code = code_to_ch_format(code)
if ch_code:
ch_codes.append(ch_code)
code_map[ch_code] = code
if not ch_codes:
return pd.DataFrame()
ch_codes_str = "','".join(ch_codes)
query = f"""
SELECT code, timestamp, close, volume, amt
FROM stock_minute
WHERE toDate(timestamp) = '{trade_date}'
AND code IN ('{ch_codes_str}')
ORDER BY code, timestamp
"""
result = client.execute(query)
if not result:
return pd.DataFrame()
df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
df['code'] = df['ch_code'].map(code_map)
df = df.dropna(subset=['code'])
return df[['code', 'timestamp', 'close', 'volume', 'amt']]
def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
"""获取单日指数分钟数据"""
client = get_ch_client()
query = f"""
SELECT timestamp, close, volume, amt
FROM index_minute
WHERE toDate(timestamp) = '{trade_date}'
AND code = '{index_code}'
ORDER BY timestamp
"""
result = client.execute(query)
if not result:
return pd.DataFrame()
df = pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt'])
return df
# ==================== 计算原始概念特征(单日)====================
def compute_raw_concept_features(
trade_date: str,
concepts: List[dict],
all_stocks: List[str]
) -> pd.DataFrame:
"""计算单日概念的原始特征alpha, amt, rank_pct, limit_up_ratio"""
# 检查缓存
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
if os.path.exists(cache_file):
return pd.read_parquet(cache_file)
# 获取数据
stock_df = get_daily_stock_data(trade_date, all_stocks)
if stock_df.empty:
return pd.DataFrame()
index_df = get_daily_index_data(trade_date)
if index_df.empty:
return pd.DataFrame()
# 获取昨收价
prev_close = get_prev_close(all_stocks, trade_date)
index_prev_close = get_index_prev_close(trade_date)
if not prev_close or not index_prev_close:
return pd.DataFrame()
# 计算涨跌幅
stock_df['prev_close'] = stock_df['code'].map(prev_close)
stock_df = stock_df.dropna(subset=['prev_close'])
stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))
# 获取所有时间点
timestamps = sorted(stock_df['timestamp'].unique())
# 概念到股票的映射
concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
results = []
for ts in timestamps:
ts_stock_data = stock_df[stock_df['timestamp'] == ts]
index_change = index_change_map.get(ts, 0)
stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))
concept_features = []
for concept_id, stocks in concept_stocks.items():
concept_changes = [stock_change[s] for s in stocks if s in stock_change]
concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]
if not concept_changes:
continue
avg_change = np.mean(concept_changes)
total_amt = sum(concept_amts)
alpha = avg_change - index_change
limit_up_count = sum(1 for c in concept_changes if c >= CONFIG['limit_up_threshold'])
limit_up_ratio = limit_up_count / len(concept_changes)
concept_features.append({
'concept_id': concept_id,
'alpha': alpha,
'total_amt': total_amt,
'limit_up_ratio': limit_up_ratio,
'stock_count': len(concept_changes),
})
if not concept_features:
continue
concept_df = pd.DataFrame(concept_features)
concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)
concept_df['timestamp'] = ts
concept_df['time_slot'] = time_to_slot(ts)
concept_df['trade_date'] = trade_date
results.append(concept_df)
if not results:
return pd.DataFrame()
result_df = pd.concat(results, ignore_index=True)
# 保存缓存
result_df.to_parquet(cache_file, index=False)
return result_df
# ==================== 滚动窗口基线计算 ====================
def compute_rolling_baseline(
historical_data: pd.DataFrame,
concept_id: str
) -> Dict[str, Dict]:
"""
计算单个概念的滚动基线
返回: {time_slot: {alpha_mean, alpha_std, amt_mean, amt_std, rank_mean, rank_std, sample_count}}
"""
if historical_data.empty:
return {}
concept_data = historical_data[historical_data['concept_id'] == concept_id]
if concept_data.empty:
return {}
baseline_dict = {}
for time_slot, group in concept_data.groupby('time_slot'):
if len(group) < CONFIG['min_baseline_samples']:
continue
alpha_std = group['alpha'].std()
amt_std = group['total_amt'].std()
rank_std = group['rank_pct'].std()
baseline_dict[time_slot] = {
'alpha_mean': group['alpha'].mean(),
'alpha_std': max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1),
'amt_mean': group['total_amt'].mean(),
'amt_std': max(amt_std if pd.notna(amt_std) else group['total_amt'].mean() * 0.5, 1.0),
'rank_mean': group['rank_pct'].mean(),
'rank_std': max(rank_std if pd.notna(rank_std) else 0.2, 0.05),
'sample_count': len(group),
}
return baseline_dict
# ==================== 计算单日 Z-Score 特征(带滚动基线)====================
def compute_zscore_features_rolling(
trade_date: str,
concepts: List[dict],
all_stocks: List[str],
historical_raw_data: pd.DataFrame # 该日期之前 N 天的原始数据
) -> pd.DataFrame:
"""
计算单日的 Z-Score 特征(使用滚动窗口基线)
关键改进:
1. 基线只使用 trade_date 之前的数据(无未来泄露)
2. 动量基于 Z-Score 计算(消除波动率异构性)
"""
# 计算当日原始特征
raw_df = compute_raw_concept_features(trade_date, concepts, all_stocks)
if raw_df.empty:
return pd.DataFrame()
zscore_records = []
for concept_id, group in raw_df.groupby('concept_id'):
# 计算该概念的滚动基线(只用历史数据)
baseline_dict = compute_rolling_baseline(historical_raw_data, concept_id)
if not baseline_dict:
continue
# 按时间排序
group = group.sort_values('timestamp').reset_index(drop=True)
# Z-Score 历史(用于计算基于 Z-Score 的动量)
zscore_history = []
for idx, row in group.iterrows():
time_slot = row['time_slot']
if time_slot not in baseline_dict:
continue
bl = baseline_dict[time_slot]
# 计算 Z-Score
alpha_zscore = (row['alpha'] - bl['alpha_mean']) / bl['alpha_std']
amt_zscore = (row['total_amt'] - bl['amt_mean']) / bl['amt_std']
rank_zscore = (row['rank_pct'] - bl['rank_mean']) / bl['rank_std']
# 截断极端值
clip = CONFIG['zscore_clip']
alpha_zscore = np.clip(alpha_zscore, -clip, clip)
amt_zscore = np.clip(amt_zscore, -clip, clip)
rank_zscore = np.clip(rank_zscore, -clip, clip)
# 记录 Z-Score 历史
zscore_history.append(alpha_zscore)
# 基于 Z-Score 计算动量(消除波动率异构性)
momentum_3m = 0.0
momentum_5m = 0.0
if len(zscore_history) >= 3:
recent_3 = zscore_history[-3:]
older_3 = zscore_history[-6:-3] if len(zscore_history) >= 6 else [zscore_history[0]]
momentum_3m = np.mean(recent_3) - np.mean(older_3)
if len(zscore_history) >= 5:
recent_5 = zscore_history[-5:]
older_5 = zscore_history[-10:-5] if len(zscore_history) >= 10 else [zscore_history[0]]
momentum_5m = np.mean(recent_5) - np.mean(older_5)
zscore_records.append({
'concept_id': concept_id,
'timestamp': row['timestamp'],
'time_slot': time_slot,
'trade_date': trade_date,
# 原始特征
'alpha': row['alpha'],
'total_amt': row['total_amt'],
'limit_up_ratio': row['limit_up_ratio'],
'stock_count': row['stock_count'],
'rank_pct': row['rank_pct'],
# Z-Score 特征
'alpha_zscore': alpha_zscore,
'amt_zscore': amt_zscore,
'rank_zscore': rank_zscore,
# 基于 Z-Score 的动量
'momentum_3m': momentum_3m,
'momentum_5m': momentum_5m,
})
if not zscore_records:
return pd.DataFrame()
return pd.DataFrame(zscore_records)
# ==================== 多进程处理 ====================
def process_single_day_v2(args) -> Tuple[str, bool]:
"""处理单个交易日(多进程版本)"""
trade_date, day_index, concepts, all_stocks, all_trading_days = args
output_file = os.path.join(OUTPUT_DIR, f'features_v2_{trade_date}.parquet')
if os.path.exists(output_file):
return (trade_date, True)
try:
# 计算滚动窗口范围(该日期之前的 N 天)
baseline_days = CONFIG['baseline_days']
# 找出 trade_date 之前的交易日
start_idx = max(0, day_index - baseline_days)
end_idx = day_index # 不包含当天
if end_idx <= start_idx:
# 没有足够的历史数据
return (trade_date, False)
historical_days = all_trading_days[start_idx:end_idx]
# 加载历史原始数据
historical_dfs = []
for hist_date in historical_days:
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{hist_date}.parquet')
if os.path.exists(cache_file):
historical_dfs.append(pd.read_parquet(cache_file))
else:
# 需要计算
hist_df = compute_raw_concept_features(hist_date, concepts, all_stocks)
if not hist_df.empty:
historical_dfs.append(hist_df)
if not historical_dfs:
return (trade_date, False)
historical_raw_data = pd.concat(historical_dfs, ignore_index=True)
# 计算当日 Z-Score 特征(使用滚动基线)
df = compute_zscore_features_rolling(trade_date, concepts, all_stocks, historical_raw_data)
if df.empty:
return (trade_date, False)
df.to_parquet(output_file, index=False)
return (trade_date, True)
except Exception as e:
print(f"[{trade_date}] 处理失败: {e}")
import traceback
traceback.print_exc()
return (trade_date, False)
# ==================== 主流程 ====================
def main():
import argparse
parser = argparse.ArgumentParser(description='准备训练数据 V2滚动窗口基线 + Z-Score 动量)')
parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
parser.add_argument('--workers', type=int, default=18, help='并行进程数')
parser.add_argument('--baseline-days', type=int, default=20, help='滚动基线窗口大小')
parser.add_argument('--force', action='store_true', help='强制重新计算(忽略缓存)')
args = parser.parse_args()
end_date = args.end or datetime.now().strftime('%Y-%m-%d')
CONFIG['baseline_days'] = args.baseline_days
print("=" * 60)
print("数据准备 V2 - 滚动窗口基线 + Z-Score 动量")
print("=" * 60)
print(f"日期范围: {args.start} ~ {end_date}")
print(f"并行进程数: {args.workers}")
print(f"滚动基线窗口: {args.baseline_days}")
# 初始化主进程连接
init_process_connections()
# 1. 获取概念列表
concepts = get_all_concepts()
all_stocks = list(set(s for c in concepts for s in c['stocks']))
print(f"股票总数: {len(all_stocks)}")
# 2. 获取交易日列表
trading_days = get_trading_days(args.start, end_date)
if not trading_days:
print("无交易日数据")
return
# 3. 第一阶段:预计算所有原始特征(用于缓存)
print(f"\n{'='*60}")
print("第一阶段:预计算原始特征(用于滚动基线)")
print(f"{'='*60}")
# 如果强制重新计算,删除缓存
if args.force:
import shutil
if os.path.exists(RAW_CACHE_DIR):
shutil.rmtree(RAW_CACHE_DIR)
os.makedirs(RAW_CACHE_DIR, exist_ok=True)
if os.path.exists(OUTPUT_DIR):
for f in os.listdir(OUTPUT_DIR):
if f.startswith('features_v2_'):
os.remove(os.path.join(OUTPUT_DIR, f))
# 单线程预计算原始特征(因为需要顺序缓存)
print(f"预计算 {len(trading_days)} 天的原始特征...")
for trade_date in tqdm(trading_days, desc="预计算原始特征"):
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
if not os.path.exists(cache_file):
compute_raw_concept_features(trade_date, concepts, all_stocks)
# 4. 第二阶段:计算 Z-Score 特征(多进程)
print(f"\n{'='*60}")
print("第二阶段:计算 Z-Score 特征(滚动基线)")
print(f"{'='*60}")
# 从第 baseline_days 天开始(前面的没有足够历史)
start_idx = args.baseline_days
processable_days = trading_days[start_idx:]
if not processable_days:
print(f"错误:需要至少 {args.baseline_days + 1} 天的数据")
return
print(f"可处理日期: {processable_days[0]} ~ {processable_days[-1]} ({len(processable_days)} 天)")
print(f"跳过前 {start_idx} 天(基线预热期)")
# 构建任务
tasks = []
for i, trade_date in enumerate(trading_days):
if i >= start_idx:
tasks.append((trade_date, i, concepts, all_stocks, trading_days))
print(f"开始处理 {len(tasks)} 个交易日({args.workers} 进程并行)...")
success_count = 0
failed_dates = []
# 使用进程池初始化器
with ProcessPoolExecutor(max_workers=args.workers, initializer=init_process_connections) as executor:
futures = {executor.submit(process_single_day_v2, task): task[0] for task in tasks}
with tqdm(total=len(futures), desc="处理进度", unit="") as pbar:
for future in as_completed(futures):
trade_date = futures[future]
try:
result_date, success = future.result()
if success:
success_count += 1
else:
failed_dates.append(result_date)
except Exception as e:
print(f"\n[{trade_date}] 进程异常: {e}")
failed_dates.append(trade_date)
pbar.update(1)
print("\n" + "=" * 60)
print(f"处理完成: {success_count}/{len(tasks)} 个交易日")
if failed_dates:
print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
print(f"数据保存在: {OUTPUT_DIR}")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@@ -190,20 +190,22 @@ def get_all_concepts() -> List[dict]:
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]: def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
"""获取昨收价""" """获取昨收价(上一交易日的收盘价 F007N"""
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()] valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
if not valid_codes: if not valid_codes:
return {} return {}
codes_str = "','".join(valid_codes) codes_str = "','".join(valid_codes)
# 注意F007N 是"最近成交价"即当日收盘价F002N 是"昨日收盘价"
# 我们需要查上一交易日的 F007N那天的收盘价作为今天的昨收
query = f""" query = f"""
SELECT SECCODE, F002N SELECT SECCODE, F007N
FROM ea_trade FROM ea_trade
WHERE SECCODE IN ('{codes_str}') WHERE SECCODE IN ('{codes_str}')
AND TRADEDATE = ( AND TRADEDATE = (
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}' SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
) )
AND F002N IS NOT NULL AND F002N > 0 AND F007N IS NOT NULL AND F007N > 0
""" """
try: try:

729
ml/realtime_detector_v2.py Normal file
View File

@@ -0,0 +1,729 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
V2 实时异动检测器
使用方法:
# 作为模块导入
from ml.realtime_detector_v2 import RealtimeDetectorV2
detector = RealtimeDetectorV2()
alerts = detector.detect_realtime() # 检测当前时刻
# 或命令行测试
python ml/realtime_detector_v2.py --date 2025-12-09
"""
import os
import sys
import json
import pickle
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from collections import defaultdict, deque
import numpy as np
import pandas as pd
import torch
from sqlalchemy import create_engine, text
from elasticsearch import Elasticsearch
from clickhouse_driver import Client
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from ml.model import TransformerAutoencoder
# ==================== 配置 ====================
MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock"
ES_HOST = 'http://127.0.0.1:9200'
ES_INDEX = 'concept_library_v3'
CLICKHOUSE_CONFIG = {
'host': '127.0.0.1',
'port': 9000,
'user': 'default',
'password': 'Zzl33818!',
'database': 'stock'
}
REFERENCE_INDEX = '000001.SH'
BASELINE_FILE = 'ml/data_v2/baselines/realtime_baseline.pkl'
MODEL_DIR = 'ml/checkpoints_v2'
# 检测配置
CONFIG = {
'seq_len': 10, # LSTM 序列长度
'confirm_window': 5, # 持续确认窗口
'confirm_ratio': 0.6, # 确认比例
'rule_weight': 0.5,
'ml_weight': 0.5,
'rule_trigger': 60,
'ml_trigger': 70,
'fusion_trigger': 50,
'cooldown_minutes': 10,
'max_alerts_per_minute': 15,
'zscore_clip': 5.0,
'limit_up_threshold': 9.8,
}
FEATURES = ['alpha_zscore', 'amt_zscore', 'rank_zscore', 'momentum_3m', 'momentum_5m', 'limit_up_ratio']
# ==================== 数据库连接 ====================
_mysql_engine = None
_es_client = None
_ch_client = None
def get_mysql_engine():
global _mysql_engine
if _mysql_engine is None:
_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True)
return _mysql_engine
def get_es_client():
global _es_client
if _es_client is None:
_es_client = Elasticsearch([ES_HOST])
return _es_client
def get_ch_client():
global _ch_client
if _ch_client is None:
_ch_client = Client(**CLICKHOUSE_CONFIG)
return _ch_client
def code_to_ch_format(code: str) -> str:
if not code or len(code) != 6 or not code.isdigit():
return None
if code.startswith('6'):
return f"{code}.SH"
elif code.startswith('0') or code.startswith('3'):
return f"{code}.SZ"
return f"{code}.BJ"
def time_to_slot(ts) -> str:
if isinstance(ts, str):
return ts
return ts.strftime('%H:%M')
# ==================== 规则评分 ====================
def score_rules_zscore(features: Dict) -> Tuple[float, List[str]]:
"""基于 Z-Score 的规则评分"""
score = 0.0
triggered = []
alpha_z = abs(features.get('alpha_zscore', 0))
amt_z = features.get('amt_zscore', 0)
rank_z = abs(features.get('rank_zscore', 0))
mom_3m = features.get('momentum_3m', 0)
mom_5m = features.get('momentum_5m', 0)
limit_up = features.get('limit_up_ratio', 0)
# Alpha Z-Score
if alpha_z >= 4.0:
score += 25
triggered.append('alpha_extreme')
elif alpha_z >= 3.0:
score += 18
triggered.append('alpha_strong')
elif alpha_z >= 2.0:
score += 10
triggered.append('alpha_moderate')
# 成交额 Z-Score
if amt_z >= 4.0:
score += 20
triggered.append('amt_extreme')
elif amt_z >= 3.0:
score += 12
triggered.append('amt_strong')
elif amt_z >= 2.0:
score += 6
triggered.append('amt_moderate')
# 排名 Z-Score
if rank_z >= 3.0:
score += 15
triggered.append('rank_extreme')
elif rank_z >= 2.0:
score += 8
triggered.append('rank_strong')
# 动量(基于 Z-Score 的)
if mom_3m >= 1.0:
score += 12
triggered.append('momentum_3m_strong')
elif mom_3m >= 0.5:
score += 6
triggered.append('momentum_3m_moderate')
if mom_5m >= 1.5:
score += 10
triggered.append('momentum_5m_strong')
# 涨停比例
if limit_up >= 0.3:
score += 20
triggered.append('limit_up_extreme')
elif limit_up >= 0.15:
score += 12
triggered.append('limit_up_strong')
elif limit_up >= 0.08:
score += 5
triggered.append('limit_up_moderate')
# 组合规则
if alpha_z >= 2.0 and amt_z >= 2.0:
score += 15
triggered.append('combo_alpha_amt')
if alpha_z >= 2.0 and limit_up >= 0.1:
score += 12
triggered.append('combo_alpha_limitup')
return min(score, 100), triggered
# ==================== 实时检测器 ====================
class RealtimeDetectorV2:
"""V2 实时异动检测器"""
def __init__(self, model_dir: str = MODEL_DIR, baseline_file: str = BASELINE_FILE):
print("初始化 V2 实时检测器...")
# 加载概念
self.concepts = self._load_concepts()
self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in self.concepts}
self.all_stocks = list(set(s for c in self.concepts for s in c['stocks']))
# 加载基线
self.baselines = self._load_baselines(baseline_file)
# 加载模型
self.model, self.thresholds, self.device = self._load_model(model_dir)
# 状态管理
self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))
self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))
self.cooldown = {}
print(f"初始化完成: {len(self.concepts)} 概念, {len(self.baselines)} 基线")
def _load_concepts(self) -> List[dict]:
"""从 ES 加载概念"""
es = get_es_client()
concepts = []
query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}
resp = es.search(index=ES_INDEX, body=query, scroll='2m')
scroll_id = resp['_scroll_id']
hits = resp['hits']['hits']
while hits:
for hit in hits:
src = hit['_source']
stocks = [s['code'] for s in src.get('stocks', []) if isinstance(s, dict) and s.get('code')]
if stocks:
concepts.append({
'concept_id': src.get('concept_id'),
'concept_name': src.get('concept'),
'stocks': stocks
})
resp = es.scroll(scroll_id=scroll_id, scroll='2m')
scroll_id = resp['_scroll_id']
hits = resp['hits']['hits']
es.clear_scroll(scroll_id=scroll_id)
return concepts
def _load_baselines(self, baseline_file: str) -> Dict:
"""加载基线"""
if not os.path.exists(baseline_file):
print(f"警告: 基线文件不存在: {baseline_file}")
print("请先运行: python ml/update_baseline.py")
return {}
with open(baseline_file, 'rb') as f:
data = pickle.load(f)
print(f"基线日期范围: {data.get('date_range', 'unknown')}")
print(f"更新时间: {data.get('update_time', 'unknown')}")
return data.get('baselines', {})
def _load_model(self, model_dir: str):
"""加载模型"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config_path = os.path.join(model_dir, 'config.json')
model_path = os.path.join(model_dir, 'best_model.pt')
threshold_path = os.path.join(model_dir, 'thresholds.json')
if not os.path.exists(model_path):
print(f"警告: 模型不存在: {model_path}")
return None, {}, device
with open(config_path) as f:
config = json.load(f)
model = TransformerAutoencoder(**config['model'])
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
thresholds = {}
if os.path.exists(threshold_path):
with open(threshold_path) as f:
thresholds = json.load(f)
print(f"模型已加载: {model_path}")
return model, thresholds, device
def _get_realtime_data(self, trade_date: str) -> pd.DataFrame:
"""获取实时数据并计算原始特征"""
ch = get_ch_client()
# 获取股票数据
ch_codes = [code_to_ch_format(c) for c in self.all_stocks if code_to_ch_format(c)]
ch_codes_str = "','".join(ch_codes)
stock_query = f"""
SELECT code, timestamp, close, amt
FROM stock_minute
WHERE toDate(timestamp) = '{trade_date}'
AND code IN ('{ch_codes_str}')
ORDER BY timestamp
"""
stock_result = ch.execute(stock_query)
if not stock_result:
return pd.DataFrame()
stock_df = pd.DataFrame(stock_result, columns=['ch_code', 'timestamp', 'close', 'amt'])
# 映射回原始代码
ch_to_code = {code_to_ch_format(c): c for c in self.all_stocks if code_to_ch_format(c)}
stock_df['code'] = stock_df['ch_code'].map(ch_to_code)
stock_df = stock_df.dropna(subset=['code'])
# 获取指数数据
index_query = f"""
SELECT timestamp, close
FROM index_minute
WHERE toDate(timestamp) = '{trade_date}'
AND code = '{REFERENCE_INDEX}'
ORDER BY timestamp
"""
index_result = ch.execute(index_query)
if not index_result:
return pd.DataFrame()
index_df = pd.DataFrame(index_result, columns=['timestamp', 'close'])
# 获取昨收价
engine = get_mysql_engine()
codes_str = "','".join([c for c in self.all_stocks if c and len(c) == 6])
with engine.connect() as conn:
prev_result = conn.execute(text(f"""
SELECT SECCODE, F007N FROM ea_trade
WHERE SECCODE IN ('{codes_str}')
AND TRADEDATE = (SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}')
AND F007N > 0
"""))
prev_close = {row[0]: float(row[1]) for row in prev_result if row[1]}
idx_result = conn.execute(text("""
SELECT F006N FROM ea_exchangetrade
WHERE INDEXCODE = '000001' AND TRADEDATE < :today
ORDER BY TRADEDATE DESC LIMIT 1
"""), {'today': trade_date}).fetchone()
index_prev_close = float(idx_result[0]) if idx_result else None
if not prev_close or not index_prev_close:
return pd.DataFrame()
# 计算涨跌幅
stock_df['prev_close'] = stock_df['code'].map(prev_close)
stock_df = stock_df.dropna(subset=['prev_close'])
stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
index_map = dict(zip(index_df['timestamp'], index_df['change_pct']))
# 按时间聚合概念特征
results = []
for ts in sorted(stock_df['timestamp'].unique()):
ts_data = stock_df[stock_df['timestamp'] == ts]
idx_chg = index_map.get(ts, 0)
stock_chg = dict(zip(ts_data['code'], ts_data['change_pct']))
stock_amt = dict(zip(ts_data['code'], ts_data['amt']))
for cid, stocks in self.concept_stocks.items():
changes = [stock_chg[s] for s in stocks if s in stock_chg]
amts = [stock_amt.get(s, 0) for s in stocks if s in stock_chg]
if not changes:
continue
alpha = np.mean(changes) - idx_chg
total_amt = sum(amts)
limit_up_ratio = sum(1 for c in changes if c >= CONFIG['limit_up_threshold']) / len(changes)
results.append({
'concept_id': cid,
'timestamp': ts,
'time_slot': time_to_slot(ts),
'alpha': alpha,
'total_amt': total_amt,
'limit_up_ratio': limit_up_ratio,
'stock_count': len(changes),
})
if not results:
return pd.DataFrame()
df = pd.DataFrame(results)
# 计算排名
for ts in df['timestamp'].unique():
mask = df['timestamp'] == ts
df.loc[mask, 'rank_pct'] = df.loc[mask, 'alpha'].rank(pct=True)
return df
def _compute_zscore(self, concept_id: str, time_slot: str, alpha: float, total_amt: float, rank_pct: float) -> Optional[Dict]:
"""计算 Z-Score"""
if concept_id not in self.baselines:
return None
baseline = self.baselines[concept_id]
if time_slot not in baseline:
return None
bl = baseline[time_slot]
alpha_z = np.clip((alpha - bl['alpha_mean']) / bl['alpha_std'], -5, 5)
amt_z = np.clip((total_amt - bl['amt_mean']) / bl['amt_std'], -5, 5)
rank_z = np.clip((rank_pct - bl['rank_mean']) / bl['rank_std'], -5, 5)
# 动量(基于 Z-Score 历史)
history = list(self.zscore_history[concept_id])
mom_3m = 0.0
mom_5m = 0.0
if len(history) >= 3:
recent = [h['alpha_zscore'] for h in history[-3:]]
older = [h['alpha_zscore'] for h in history[-6:-3]] if len(history) >= 6 else [history[0]['alpha_zscore']]
mom_3m = np.mean(recent) - np.mean(older)
if len(history) >= 5:
recent = [h['alpha_zscore'] for h in history[-5:]]
older = [h['alpha_zscore'] for h in history[-10:-5]] if len(history) >= 10 else [history[0]['alpha_zscore']]
mom_5m = np.mean(recent) - np.mean(older)
return {
'alpha_zscore': float(alpha_z),
'amt_zscore': float(amt_z),
'rank_zscore': float(rank_z),
'momentum_3m': float(mom_3m),
'momentum_5m': float(mom_5m),
}
@torch.no_grad()
def _ml_score(self, sequences: np.ndarray) -> np.ndarray:
"""批量 ML 评分"""
if self.model is None or len(sequences) == 0:
return np.zeros(len(sequences))
x = torch.FloatTensor(sequences).to(self.device)
errors = self.model.compute_reconstruction_error(x, reduction='none')
last_errors = errors[:, -1].cpu().numpy()
# 转换为 0-100 分数
if self.thresholds:
p50 = self.thresholds.get('median', 0.001)
p99 = self.thresholds.get('p99', 0.05)
scores = 50 + (last_errors - p50) / (p99 - p50 + 1e-6) * 49
else:
scores = last_errors * 1000
return np.clip(scores, 0, 100)
def detect(self, trade_date: str = None) -> List[Dict]:
"""检测指定日期的异动"""
trade_date = trade_date or datetime.now().strftime('%Y-%m-%d')
print(f"\n检测 {trade_date} 的异动...")
# 重置状态
self.zscore_history.clear()
self.anomaly_candidates.clear()
self.cooldown.clear()
# 获取数据
raw_df = self._get_realtime_data(trade_date)
if raw_df.empty:
print("无数据")
return []
timestamps = sorted(raw_df['timestamp'].unique())
print(f"时间点数: {len(timestamps)}")
all_alerts = []
for ts in timestamps:
ts_data = raw_df[raw_df['timestamp'] == ts]
time_slot = time_to_slot(ts)
candidates = []
# 计算每个概念的 Z-Score
for _, row in ts_data.iterrows():
cid = row['concept_id']
zscore = self._compute_zscore(
cid, time_slot,
row['alpha'], row['total_amt'], row['rank_pct']
)
if zscore is None:
continue
# 完整特征
features = {
**zscore,
'alpha': row['alpha'],
'limit_up_ratio': row['limit_up_ratio'],
'total_amt': row['total_amt'],
}
# 更新历史
self.zscore_history[cid].append(zscore)
# 规则评分
rule_score, triggered = score_rules_zscore(features)
candidates.append((cid, features, rule_score, triggered))
if not candidates:
continue
# 批量 ML 评分
sequences = []
valid_candidates = []
for cid, features, rule_score, triggered in candidates:
history = list(self.zscore_history[cid])
if len(history) >= CONFIG['seq_len']:
seq = np.array([[h['alpha_zscore'], h['amt_zscore'], h['rank_zscore'],
h['momentum_3m'], h['momentum_5m'], features['limit_up_ratio']]
for h in history])
sequences.append(seq)
valid_candidates.append((cid, features, rule_score, triggered))
if not sequences:
continue
ml_scores = self._ml_score(np.array(sequences))
# 融合 + 确认
for i, (cid, features, rule_score, triggered) in enumerate(valid_candidates):
ml_score = ml_scores[i]
final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score
# 判断触发
is_triggered = (
rule_score >= CONFIG['rule_trigger'] or
ml_score >= CONFIG['ml_trigger'] or
final_score >= CONFIG['fusion_trigger']
)
self.anomaly_candidates[cid].append((ts, final_score))
if not is_triggered:
continue
# 冷却期
if cid in self.cooldown:
if (ts - self.cooldown[cid]).total_seconds() < CONFIG['cooldown_minutes'] * 60:
continue
# 持续确认
recent = list(self.anomaly_candidates[cid])
if len(recent) < CONFIG['confirm_window']:
continue
exceed = sum(1 for _, s in recent if s >= CONFIG['fusion_trigger'])
ratio = exceed / len(recent)
if ratio < CONFIG['confirm_ratio']:
continue
# 确认异动!
self.cooldown[cid] = ts
alpha = features['alpha']
alert_type = 'surge_up' if alpha >= 1.5 else 'surge_down' if alpha <= -1.5 else 'surge'
concept_name = next((c['concept_name'] for c in self.concepts if c['concept_id'] == cid), cid)
all_alerts.append({
'concept_id': cid,
'concept_name': concept_name,
'alert_time': ts,
'trade_date': trade_date,
'alert_type': alert_type,
'final_score': float(final_score),
'rule_score': float(rule_score),
'ml_score': float(ml_score),
'confirm_ratio': float(ratio),
'alpha': float(alpha),
'alpha_zscore': float(features['alpha_zscore']),
'amt_zscore': float(features['amt_zscore']),
'rank_zscore': float(features['rank_zscore']),
'momentum_3m': float(features['momentum_3m']),
'momentum_5m': float(features['momentum_5m']),
'limit_up_ratio': float(features['limit_up_ratio']),
'triggered_rules': triggered,
'trigger_reason': f"融合({final_score:.0f})+确认({ratio:.0%})",
})
print(f"检测到 {len(all_alerts)} 个异动")
return all_alerts
# ==================== 数据库存储 ====================
def create_v2_table():
"""创建 V2 异动表(如果不存在)"""
engine = get_mysql_engine()
with engine.begin() as conn:
conn.execute(text("""
CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
id INT AUTO_INCREMENT PRIMARY KEY,
concept_id VARCHAR(50) NOT NULL,
alert_time DATETIME NOT NULL,
trade_date DATE NOT NULL,
alert_type VARCHAR(20) NOT NULL,
final_score FLOAT,
rule_score FLOAT,
ml_score FLOAT,
trigger_reason VARCHAR(200),
confirm_ratio FLOAT,
alpha FLOAT,
alpha_zscore FLOAT,
amt_zscore FLOAT,
rank_zscore FLOAT,
momentum_3m FLOAT,
momentum_5m FLOAT,
limit_up_ratio FLOAT,
triggered_rules TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE KEY uk_concept_time (concept_id, alert_time),
INDEX idx_trade_date (trade_date),
INDEX idx_alert_type (alert_type)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
"""))
print("concept_anomaly_v2 表已就绪")
def save_alerts_to_db(alerts: List[Dict]) -> int:
"""保存异动到数据库"""
if not alerts:
return 0
engine = get_mysql_engine()
saved = 0
with engine.begin() as conn:
for alert in alerts:
try:
insert_sql = text("""
INSERT IGNORE INTO concept_anomaly_v2
(concept_id, alert_time, trade_date, alert_type,
final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
alpha, alpha_zscore, amt_zscore, rank_zscore,
momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
VALUES
(:concept_id, :alert_time, :trade_date, :alert_type,
:final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
:alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
:momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
""")
result = conn.execute(insert_sql, {
'concept_id': alert['concept_id'],
'alert_time': alert['alert_time'],
'trade_date': alert['trade_date'],
'alert_type': alert['alert_type'],
'final_score': alert['final_score'],
'rule_score': alert['rule_score'],
'ml_score': alert['ml_score'],
'trigger_reason': alert['trigger_reason'],
'confirm_ratio': alert['confirm_ratio'],
'alpha': alert['alpha'],
'alpha_zscore': alert['alpha_zscore'],
'amt_zscore': alert['amt_zscore'],
'rank_zscore': alert['rank_zscore'],
'momentum_3m': alert['momentum_3m'],
'momentum_5m': alert['momentum_5m'],
'limit_up_ratio': alert['limit_up_ratio'],
'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False),
})
if result.rowcount > 0:
saved += 1
except Exception as e:
print(f"保存失败: {alert['concept_id']} - {e}")
return saved
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--date', type=str, default=None)
parser.add_argument('--no-save', action='store_true', help='不保存到数据库,只打印')
args = parser.parse_args()
# 确保表存在
if not args.no_save:
create_v2_table()
detector = RealtimeDetectorV2()
alerts = detector.detect(args.date)
print(f"\n{'='*60}")
print(f"检测结果 ({len(alerts)} 个异动)")
print('='*60)
for a in alerts[:20]:
print(f"[{a['alert_time'].strftime('%H:%M') if hasattr(a['alert_time'], 'strftime') else a['alert_time']}] "
f"{a['concept_name']} | {a['alert_type']} | "
f"分数={a['final_score']:.0f} 确认={a['confirm_ratio']:.0%} "
f"α={a['alpha']:.2f}% αZ={a['alpha_zscore']:.1f}")
if len(alerts) > 20:
print(f"... 共 {len(alerts)}")
# 保存到数据库
if not args.no_save and alerts:
saved = save_alerts_to_db(alerts)
print(f"\n✅ 已保存 {saved}/{len(alerts)} 条到 concept_anomaly_v2 表")
elif args.no_save:
print(f"\n⚠️ --no-save 模式,未保存到数据库")
if __name__ == "__main__":
main()

622
ml/train_v2.py Normal file
View File

@@ -0,0 +1,622 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
训练脚本 V2 - 基于 Z-Score 特征的 LSTM Autoencoder
改进点:
1. 使用 Z-Score 特征(相对于同时间片历史的偏离)
2. 短序列10分钟不需要30分钟预热
3. 开盘即可检测9:30 直接有特征
模型输入:
- 过去10分钟的 Z-Score 特征序列
- 特征alpha_zscore, amt_zscore, rank_zscore, momentum_3m, momentum_5m, limit_up_ratio
模型学习:
- 学习 Z-Score 序列的"正常演化模式"
- 异动 = Z-Score 序列的异常演化(重构误差大)
"""
import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import List, Tuple, Dict
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm
from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters
# 性能优化
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
try:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
HAS_MATPLOTLIB = True
except ImportError:
HAS_MATPLOTLIB = False
# ==================== 配置 ====================
TRAIN_CONFIG = {
# 数据配置(改进!)
'seq_len': 10, # 10分钟序列不是30分钟
'stride': 2, # 步长2分钟
# 时间切分
'train_end_date': '2024-06-30',
'val_end_date': '2024-09-30',
# V2 特征Z-Score 为主)
'features': [
'alpha_zscore', # Alpha 的 Z-Score
'amt_zscore', # 成交额的 Z-Score
'rank_zscore', # 排名的 Z-Score
'momentum_3m', # 3分钟动量
'momentum_5m', # 5分钟动量
'limit_up_ratio', # 涨停占比
],
# 训练配置
'batch_size': 4096,
'epochs': 100,
'learning_rate': 3e-4,
'weight_decay': 1e-5,
'gradient_clip': 1.0,
# 早停配置
'patience': 15,
'min_delta': 1e-6,
# 模型配置(小型 LSTM
'model': {
'n_features': 6,
'hidden_dim': 32,
'latent_dim': 4,
'num_layers': 1,
'dropout': 0.2,
'bidirectional': True,
},
# 标准化配置
'clip_value': 5.0, # Z-Score 已经标准化clip 5.0 足够
# 阈值配置
'threshold_percentiles': [90, 95, 99],
}
# ==================== 数据加载 ====================
def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
"""按日期加载 V2 数据"""
data_path = Path(data_dir)
parquet_files = sorted(data_path.glob("features_v2_*.parquet"))
if not parquet_files:
raise FileNotFoundError(f"未找到 V2 数据文件: {data_dir}")
print(f"找到 {len(parquet_files)} 个 V2 数据文件")
date_data = {}
for pf in tqdm(parquet_files, desc="加载数据"):
date = pf.stem.replace('features_v2_', '')
df = pd.read_parquet(pf)
required_cols = features + ['concept_id', 'timestamp']
missing_cols = [c for c in required_cols if c not in df.columns]
if missing_cols:
print(f"警告: {date} 缺少列: {missing_cols}, 跳过")
continue
date_data[date] = df
print(f"成功加载 {len(date_data)} 天的数据")
return date_data
def split_data_by_date(
date_data: Dict[str, pd.DataFrame],
train_end: str,
val_end: str
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
"""按日期划分数据集"""
train_data = {}
val_data = {}
test_data = {}
for date, df in date_data.items():
if date <= train_end:
train_data[date] = df
elif date <= val_end:
val_data[date] = df
else:
test_data[date] = df
print(f"数据集划分:")
print(f" 训练集: {len(train_data)} 天 (<= {train_end})")
print(f" 验证集: {len(val_data)} 天 ({train_end} ~ {val_end})")
print(f" 测试集: {len(test_data)} 天 (> {val_end})")
return train_data, val_data, test_data
def build_sequences_by_concept(
date_data: Dict[str, pd.DataFrame],
features: List[str],
seq_len: int,
stride: int
) -> np.ndarray:
"""按概念分组构建序列"""
all_dfs = []
for date, df in sorted(date_data.items()):
df = df.copy()
df['date'] = date
all_dfs.append(df)
if not all_dfs:
return np.array([])
combined = pd.concat(all_dfs, ignore_index=True)
combined = combined.sort_values(['concept_id', 'date', 'timestamp'])
all_sequences = []
grouped = combined.groupby('concept_id', sort=False)
n_concepts = len(grouped)
for concept_id, concept_df in tqdm(grouped, desc="构建序列", total=n_concepts, leave=False):
feature_data = concept_df[features].values
feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0)
n_points = len(feature_data)
for start in range(0, n_points - seq_len + 1, stride):
seq = feature_data[start:start + seq_len]
all_sequences.append(seq)
if not all_sequences:
return np.array([])
sequences = np.array(all_sequences)
print(f" 构建序列: {len(sequences):,} 条 (来自 {n_concepts} 个概念)")
return sequences
# ==================== 数据集 ====================
class SequenceDataset(Dataset):
def __init__(self, sequences: np.ndarray):
self.sequences = torch.FloatTensor(sequences)
def __len__(self) -> int:
return len(self.sequences)
def __getitem__(self, idx: int) -> torch.Tensor:
return self.sequences[idx]
# ==================== 训练器 ====================
class EarlyStopping:
def __init__(self, patience: int = 10, min_delta: float = 1e-6):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_loss = float('inf')
self.early_stop = False
def __call__(self, val_loss: float) -> bool:
if val_loss < self.best_loss - self.min_delta:
self.best_loss = val_loss
self.counter = 0
else:
self.counter += 1
if self.counter >= self.patience:
self.early_stop = True
return self.early_stop
class Trainer:
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
config: Dict,
device: torch.device,
save_dir: str = 'ml/checkpoints_v2'
):
self.model = model.to(device)
self.train_loader = train_loader
self.val_loader = val_loader
self.config = config
self.device = device
self.save_dir = Path(save_dir)
self.save_dir.mkdir(parents=True, exist_ok=True)
self.optimizer = AdamW(
model.parameters(),
lr=config['learning_rate'],
weight_decay=config['weight_decay']
)
self.scheduler = CosineAnnealingWarmRestarts(
self.optimizer, T_0=10, T_mult=2, eta_min=1e-6
)
self.criterion = AnomalyDetectionLoss()
self.early_stopping = EarlyStopping(
patience=config['patience'],
min_delta=config['min_delta']
)
self.use_amp = torch.cuda.is_available()
self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
if self.use_amp:
print(" ✓ 启用 AMP 混合精度训练")
self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
self.best_val_loss = float('inf')
def train_epoch(self) -> float:
self.model.train()
total_loss = 0.0
n_batches = 0
pbar = tqdm(self.train_loader, desc="Training", leave=False)
for batch in pbar:
batch = batch.to(self.device, non_blocking=True)
self.optimizer.zero_grad(set_to_none=True)
if self.use_amp:
with torch.cuda.amp.autocast():
output, latent = self.model(batch)
loss, _ = self.criterion(output, batch, latent)
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
self.scaler.step(self.optimizer)
self.scaler.update()
else:
output, latent = self.model(batch)
loss, _ = self.criterion(output, batch, latent)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
self.optimizer.step()
total_loss += loss.item()
n_batches += 1
pbar.set_postfix({'loss': f"{loss.item():.4f}"})
return total_loss / n_batches
@torch.no_grad()
def validate(self) -> float:
self.model.eval()
total_loss = 0.0
n_batches = 0
for batch in self.val_loader:
batch = batch.to(self.device, non_blocking=True)
if self.use_amp:
with torch.cuda.amp.autocast():
output, latent = self.model(batch)
loss, _ = self.criterion(output, batch, latent)
else:
output, latent = self.model(batch)
loss, _ = self.criterion(output, batch, latent)
total_loss += loss.item()
n_batches += 1
return total_loss / n_batches
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
checkpoint = {
'epoch': epoch,
'model_state_dict': model_to_save.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'val_loss': val_loss,
'config': self.config,
}
torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')
if is_best:
torch.save(checkpoint, self.save_dir / 'best_model.pt')
print(f" ✓ 保存最佳模型 (val_loss: {val_loss:.6f})")
def train(self, epochs: int):
print(f"\n开始训练 ({epochs} epochs)...")
print(f"设备: {self.device}")
print(f"模型参数量: {count_parameters(self.model):,}")
for epoch in range(1, epochs + 1):
print(f"\nEpoch {epoch}/{epochs}")
train_loss = self.train_epoch()
val_loss = self.validate()
self.scheduler.step()
current_lr = self.optimizer.param_groups[0]['lr']
self.history['train_loss'].append(train_loss)
self.history['val_loss'].append(val_loss)
self.history['learning_rate'].append(current_lr)
print(f" Train Loss: {train_loss:.6f}")
print(f" Val Loss: {val_loss:.6f}")
print(f" LR: {current_lr:.2e}")
is_best = val_loss < self.best_val_loss
if is_best:
self.best_val_loss = val_loss
self.save_checkpoint(epoch, val_loss, is_best)
if self.early_stopping(val_loss):
print(f"\n早停触发!")
break
print(f"\n训练完成!最佳验证损失: {self.best_val_loss:.6f}")
self.save_history()
return self.history
def save_history(self):
history_path = self.save_dir / 'training_history.json'
with open(history_path, 'w') as f:
json.dump(self.history, f, indent=2)
print(f"训练历史已保存: {history_path}")
if HAS_MATPLOTLIB:
self.plot_training_curves()
def plot_training_curves(self):
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
epochs = range(1, len(self.history['train_loss']) + 1)
ax1 = axes[0]
ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training & Validation Loss (V2)')
ax1.legend()
ax1.grid(True, alpha=0.3)
best_epoch = np.argmin(self.history['val_loss']) + 1
best_val_loss = min(self.history['val_loss'])
ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7)
ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)
ax2 = axes[1]
ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Learning Rate')
ax2.set_title('Learning Rate Schedule')
ax2.set_yscale('log')
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(self.save_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.close()
print(f"训练曲线已保存")
# ==================== 阈值计算 ====================
@torch.no_grad()
def compute_thresholds(
model: nn.Module,
data_loader: DataLoader,
device: torch.device,
percentiles: List[float] = [90, 95, 99]
) -> Dict[str, float]:
"""在验证集上计算阈值"""
model.eval()
all_errors = []
print("计算异动阈值...")
for batch in tqdm(data_loader, desc="Computing thresholds"):
batch = batch.to(device)
errors = model.compute_reconstruction_error(batch, reduction='none')
seq_errors = errors[:, -1] # 最后一个时刻
all_errors.append(seq_errors.cpu().numpy())
all_errors = np.concatenate(all_errors)
thresholds = {}
for p in percentiles:
threshold = np.percentile(all_errors, p)
thresholds[f'p{p}'] = float(threshold)
print(f" P{p}: {threshold:.6f}")
thresholds['mean'] = float(np.mean(all_errors))
thresholds['std'] = float(np.std(all_errors))
thresholds['median'] = float(np.median(all_errors))
return thresholds
# ==================== 主函数 ====================
def main():
parser = argparse.ArgumentParser(description='训练 V2 模型')
parser.add_argument('--data_dir', type=str, default='ml/data_v2', help='V2 数据目录')
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=4096)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--device', type=str, default='auto')
parser.add_argument('--save_dir', type=str, default='ml/checkpoints_v2')
parser.add_argument('--train_end', type=str, default='2024-06-30')
parser.add_argument('--val_end', type=str, default='2024-09-30')
parser.add_argument('--seq_len', type=int, default=10, help='序列长度(分钟)')
args = parser.parse_args()
config = TRAIN_CONFIG.copy()
config['batch_size'] = args.batch_size
config['epochs'] = args.epochs
config['learning_rate'] = args.lr
config['train_end_date'] = args.train_end
config['val_end_date'] = args.val_end
config['seq_len'] = args.seq_len
if args.device == 'auto':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
device = torch.device(args.device)
print("=" * 60)
print("概念异动检测模型训练 V2Z-Score 特征)")
print("=" * 60)
print(f"数据目录: {args.data_dir}")
print(f"设备: {device}")
print(f"序列长度: {config['seq_len']} 分钟")
print(f"批次大小: {config['batch_size']}")
print(f"特征: {config['features']}")
print("=" * 60)
# 1. 加载数据
print("\n[1/6] 加载 V2 数据...")
date_data = load_data_by_date(args.data_dir, config['features'])
# 2. 划分数据集
print("\n[2/6] 划分数据集...")
train_data, val_data, test_data = split_data_by_date(
date_data, config['train_end_date'], config['val_end_date']
)
# 3. 构建序列
print("\n[3/6] 构建序列...")
print("训练集:")
train_sequences = build_sequences_by_concept(
train_data, config['features'], config['seq_len'], config['stride']
)
print("验证集:")
val_sequences = build_sequences_by_concept(
val_data, config['features'], config['seq_len'], config['stride']
)
if len(train_sequences) == 0:
print("错误: 训练集为空!")
return
# 4. 预处理
print("\n[4/6] 数据预处理...")
clip_value = config['clip_value']
print(f" Z-Score 特征已标准化,截断: ±{clip_value}")
train_sequences = np.clip(train_sequences, -clip_value, clip_value)
if len(val_sequences) > 0:
val_sequences = np.clip(val_sequences, -clip_value, clip_value)
# 保存配置
save_dir = Path(args.save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
with open(save_dir / 'config.json', 'w') as f:
json.dump(config, f, indent=2)
# 5. 创建数据加载器
print("\n[5/6] 创建数据加载器...")
train_dataset = SequenceDataset(train_sequences)
val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None
print(f" 训练序列: {len(train_dataset):,}")
print(f" 验证序列: {len(val_dataset) if val_dataset else 0:,}")
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0
train_loader = DataLoader(
train_dataset,
batch_size=config['batch_size'],
shuffle=True,
num_workers=num_workers,
pin_memory=True,
prefetch_factor=4 if num_workers > 0 else None,
persistent_workers=True if num_workers > 0 else False,
drop_last=True
)
val_loader = DataLoader(
val_dataset,
batch_size=config['batch_size'] * 2,
shuffle=False,
num_workers=num_workers,
pin_memory=True,
) if val_dataset else None
# 6. 训练
print("\n[6/6] 训练模型...")
model = TransformerAutoencoder(**config['model'])
if torch.cuda.device_count() > 1:
print(f" 使用 {torch.cuda.device_count()} 张 GPU 并行训练")
model = nn.DataParallel(model)
if val_loader is None:
print("警告: 验证集为空,使用训练集的 10% 作为验证")
split_idx = int(len(train_dataset) * 0.9)
train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))
train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)
trainer = Trainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
config=config,
device=device,
save_dir=args.save_dir
)
trainer.train(config['epochs'])
# 计算阈值
print("\n[额外] 计算异动阈值...")
best_checkpoint = torch.load(save_dir / 'best_model.pt', map_location=device)
# 创建新的单 GPU 模型用于计算阈值(避免 DataParallel 问题)
threshold_model = TransformerAutoencoder(**config['model'])
threshold_model.load_state_dict(best_checkpoint['model_state_dict'])
threshold_model.to(device)
threshold_model.eval()
thresholds = compute_thresholds(threshold_model, val_loader, device, config['threshold_percentiles'])
with open(save_dir / 'thresholds.json', 'w') as f:
json.dump(thresholds, f, indent=2)
print("\n" + "=" * 60)
print("训练完成!")
print(f"模型保存位置: {args.save_dir}")
print("=" * 60)
if __name__ == "__main__":
main()

132
ml/update_baseline.py Normal file
View File

@@ -0,0 +1,132 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
每日盘后运行:更新滚动基线
使用方法:
python ml/update_baseline.py
建议加入 crontab每天 15:30 后运行:
30 15 * * 1-5 cd /path/to/project && python ml/update_baseline.py
"""
import os
import sys
import pickle
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
from tqdm import tqdm
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from ml.prepare_data_v2 import (
get_all_concepts, get_trading_days, compute_raw_concept_features,
init_process_connections, CONFIG, RAW_CACHE_DIR, BASELINE_DIR
)
def update_rolling_baseline(baseline_days: int = 20):
"""
更新滚动基线(用于实盘检测)
基线 = 最近 N 个交易日每个时间片的统计量
"""
print("=" * 60)
print("更新滚动基线(用于实盘)")
print("=" * 60)
# 初始化连接
init_process_connections()
# 获取概念列表
concepts = get_all_concepts()
all_stocks = list(set(s for c in concepts for s in c['stocks']))
# 获取最近的交易日
today = datetime.now().strftime('%Y-%m-%d')
start_date = (datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d') # 多取一些
trading_days = get_trading_days(start_date, today)
if len(trading_days) < baseline_days:
print(f"错误:交易日不足 {baseline_days}")
return
# 只取最近 N 天
recent_days = trading_days[-baseline_days:]
print(f"使用 {len(recent_days)} 天数据: {recent_days[0]} ~ {recent_days[-1]}")
# 加载原始数据
all_data = []
for trade_date in tqdm(recent_days, desc="加载数据"):
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
if os.path.exists(cache_file):
df = pd.read_parquet(cache_file)
else:
df = compute_raw_concept_features(trade_date, concepts, all_stocks)
if not df.empty:
all_data.append(df)
if not all_data:
print("错误:无数据")
return
combined = pd.concat(all_data, ignore_index=True)
print(f"总数据量: {len(combined):,}")
# 按概念计算基线
baselines = {}
for concept_id, group in tqdm(combined.groupby('concept_id'), desc="计算基线"):
baseline_dict = {}
for time_slot, slot_group in group.groupby('time_slot'):
if len(slot_group) < CONFIG['min_baseline_samples']:
continue
alpha_std = slot_group['alpha'].std()
amt_std = slot_group['total_amt'].std()
rank_std = slot_group['rank_pct'].std()
baseline_dict[time_slot] = {
'alpha_mean': float(slot_group['alpha'].mean()),
'alpha_std': float(max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1)),
'amt_mean': float(slot_group['total_amt'].mean()),
'amt_std': float(max(amt_std if pd.notna(amt_std) else slot_group['total_amt'].mean() * 0.5, 1.0)),
'rank_mean': float(slot_group['rank_pct'].mean()),
'rank_std': float(max(rank_std if pd.notna(rank_std) else 0.2, 0.05)),
'sample_count': len(slot_group),
}
if baseline_dict:
baselines[concept_id] = baseline_dict
print(f"计算了 {len(baselines)} 个概念的基线")
# 保存
os.makedirs(BASELINE_DIR, exist_ok=True)
baseline_file = os.path.join(BASELINE_DIR, 'realtime_baseline.pkl')
with open(baseline_file, 'wb') as f:
pickle.dump({
'baselines': baselines,
'update_time': datetime.now().isoformat(),
'date_range': [recent_days[0], recent_days[-1]],
'baseline_days': baseline_days,
}, f)
print(f"基线已保存: {baseline_file}")
print("=" * 60)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--days', type=int, default=20, help='基线天数')
args = parser.parse_args()
update_rolling_baseline(args.days)

View File

@@ -0,0 +1,267 @@
/**
* 迷你分时图组件
* 用于灵活屏中显示证券的日内走势
*/
import React, { useEffect, useRef, useState, useMemo } from 'react';
import { Box, Spinner, Center, Text } from '@chakra-ui/react';
import * as echarts from 'echarts';
/**
* 生成交易时间刻度(用于 X 轴)
* A股交易时间9:30-11:30, 13:00-15:00
*/
const generateTimeTicks = () => {
const ticks = [];
// 上午
for (let h = 9; h <= 11; h++) {
for (let m = (h === 9 ? 30 : 0); m < 60; m++) {
if (h === 11 && m > 30) break;
ticks.push(`${String(h).padStart(2, '0')}:${String(m).padStart(2, '0')}`);
}
}
// 下午
for (let h = 13; h <= 15; h++) {
for (let m = 0; m < 60; m++) {
if (h === 15 && m > 0) break;
ticks.push(`${String(h).padStart(2, '0')}:${String(m).padStart(2, '0')}`);
}
}
return ticks;
};
const TIME_TICKS = generateTimeTicks();
/**
* MiniTimelineChart 组件
* @param {Object} props
* @param {string} props.code - 证券代码
* @param {boolean} props.isIndex - 是否为指数
* @param {number} props.prevClose - 昨收价
* @param {number} props.currentPrice - 当前价(实时)
* @param {number} props.height - 图表高度
*/
const MiniTimelineChart = ({
code,
isIndex = false,
prevClose,
currentPrice,
height = 120,
}) => {
const chartRef = useRef(null);
const chartInstance = useRef(null);
const [timelineData, setTimelineData] = useState([]);
const [loading, setLoading] = useState(true);
const [error, setError] = useState(null);
// 获取分钟数据
useEffect(() => {
if (!code) return;
const fetchData = async () => {
setLoading(true);
setError(null);
try {
const apiPath = isIndex
? `/api/index/${code}/kline?type=minute`
: `/api/stock/${code}/kline?type=minute`;
const response = await fetch(apiPath);
const result = await response.json();
if (result.success !== false && result.data) {
// 格式化数据
const formatted = result.data.map(item => ({
time: item.time || item.timestamp,
price: item.close || item.price,
}));
setTimelineData(formatted);
} else {
setError(result.error || '暂无数据');
}
} catch (e) {
setError('加载失败');
} finally {
setLoading(false);
}
};
fetchData();
// 交易时间内每分钟刷新
const now = new Date();
const hours = now.getHours();
const minutes = now.getMinutes();
const currentMinutes = hours * 60 + minutes;
const isTrading = (currentMinutes >= 570 && currentMinutes <= 690) ||
(currentMinutes >= 780 && currentMinutes <= 900);
let intervalId;
if (isTrading) {
intervalId = setInterval(fetchData, 60000); // 1分钟刷新
}
return () => {
if (intervalId) clearInterval(intervalId);
};
}, [code, isIndex]);
// 合并实时价格到数据中
const chartData = useMemo(() => {
if (!timelineData.length) return [];
const data = [...timelineData];
// 如果有实时价格,添加到最新点
if (currentPrice && data.length > 0) {
const now = new Date();
const timeStr = `${String(now.getHours()).padStart(2, '0')}:${String(now.getMinutes()).padStart(2, '0')}`;
const lastItem = data[data.length - 1];
// 如果实时价格的时间比最后一条数据新,添加新点
if (lastItem.time !== timeStr) {
data.push({ time: timeStr, price: currentPrice });
} else {
// 更新最后一条
data[data.length - 1] = { ...lastItem, price: currentPrice };
}
}
return data;
}, [timelineData, currentPrice]);
// 渲染图表
useEffect(() => {
if (!chartRef.current || loading || !chartData.length) return;
if (!chartInstance.current) {
chartInstance.current = echarts.init(chartRef.current);
}
const baseLine = prevClose || chartData[0]?.price || 0;
// 计算价格范围
const prices = chartData.map(d => d.price).filter(p => p > 0);
const minPrice = Math.min(...prices, baseLine);
const maxPrice = Math.max(...prices, baseLine);
const range = Math.max(maxPrice - baseLine, baseLine - minPrice) * 1.1;
// 准备数据
const times = chartData.map(d => d.time);
const values = chartData.map(d => d.price);
// 判断涨跌
const lastPrice = values[values.length - 1] || baseLine;
const isUp = lastPrice >= baseLine;
const option = {
grid: {
top: 5,
right: 5,
bottom: 5,
left: 5,
containLabel: false,
},
xAxis: {
type: 'category',
data: times,
show: false,
boundaryGap: false,
},
yAxis: {
type: 'value',
min: baseLine - range,
max: baseLine + range,
show: false,
},
series: [
{
type: 'line',
data: values,
smooth: false,
symbol: 'none',
lineStyle: {
width: 1.5,
color: isUp ? '#ef4444' : '#22c55e',
},
areaStyle: {
color: {
type: 'linear',
x: 0,
y: 0,
x2: 0,
y2: 1,
colorStops: [
{ offset: 0, color: isUp ? 'rgba(239, 68, 68, 0.3)' : 'rgba(34, 197, 94, 0.3)' },
{ offset: 1, color: isUp ? 'rgba(239, 68, 68, 0.05)' : 'rgba(34, 197, 94, 0.05)' },
],
},
},
markLine: {
silent: true,
symbol: 'none',
data: [
{
yAxis: baseLine,
lineStyle: {
color: '#666',
type: 'dashed',
width: 1,
},
label: { show: false },
},
],
},
},
],
animation: false,
};
chartInstance.current.setOption(option);
return () => {
// 不在这里销毁,只在组件卸载时销毁
};
}, [chartData, prevClose, loading]);
// 组件卸载时销毁图表
useEffect(() => {
return () => {
if (chartInstance.current) {
chartInstance.current.dispose();
chartInstance.current = null;
}
};
}, []);
// 窗口 resize 处理
useEffect(() => {
const handleResize = () => {
chartInstance.current?.resize();
};
window.addEventListener('resize', handleResize);
return () => window.removeEventListener('resize', handleResize);
}, []);
if (loading) {
return (
<Center h={height}>
<Spinner size="sm" color="gray.400" />
</Center>
);
}
if (error || !chartData.length) {
return (
<Center h={height}>
<Text fontSize="xs" color="gray.400">
{error || '暂无数据'}
</Text>
</Center>
);
}
return <Box ref={chartRef} h={`${height}px`} w="100%" />;
};
export default MiniTimelineChart;

View File

@@ -0,0 +1,275 @@
/**
* 盘口行情面板组件
* 支持显示 5 档或 10 档买卖盘数据
*
* 上交所: 5 档行情
* 深交所: 10 档行情
*/
import React, { useState } from 'react';
import {
Box,
VStack,
HStack,
Text,
Button,
ButtonGroup,
useColorModeValue,
Tooltip,
Badge,
} from '@chakra-ui/react';
/**
* 格式化成交量
* @param {number} volume - 成交量(股)
* @returns {string} 格式化后的字符串
*/
const formatVolume = (volume) => {
if (!volume || volume === 0) return '-';
if (volume >= 10000) {
return `${(volume / 10000).toFixed(0)}`;
}
if (volume >= 1000) {
return `${(volume / 1000).toFixed(1)}k`;
}
return String(volume);
};
/**
* 格式化价格
* @param {number} price - 价格
* @param {number} prevClose - 昨收价
* @returns {Object} { text, color }
*/
const formatPrice = (price, prevClose) => {
if (!price || price === 0) {
return { text: '-', color: 'gray.400' };
}
const text = price.toFixed(2);
if (!prevClose || prevClose === 0) {
return { text, color: 'gray.600' };
}
if (price > prevClose) {
return { text, color: 'red.500' };
}
if (price < prevClose) {
return { text, color: 'green.500' };
}
return { text, color: 'gray.600' };
};
/**
* 单行盘口
*/
const OrderRow = ({ label, price, volume, prevClose, isBid, maxVolume, isLimitPrice }) => {
const bgColor = useColorModeValue(
isBid ? 'red.50' : 'green.50',
isBid ? 'rgba(239, 68, 68, 0.1)' : 'rgba(34, 197, 94, 0.1)'
);
const barColor = useColorModeValue(
isBid ? 'red.200' : 'green.200',
isBid ? 'rgba(239, 68, 68, 0.3)' : 'rgba(34, 197, 94, 0.3)'
);
const limitColor = useColorModeValue('orange.500', 'orange.300');
const priceInfo = formatPrice(price, prevClose);
const volumeText = formatVolume(volume);
// 计算成交量条宽度
const barWidth = maxVolume > 0 ? Math.min((volume / maxVolume) * 100, 100) : 0;
return (
<HStack
spacing={2}
py={0.5}
px={1}
position="relative"
overflow="hidden"
fontSize="xs"
>
{/* 成交量条 */}
<Box
position="absolute"
right={0}
top={0}
bottom={0}
width={`${barWidth}%`}
bg={barColor}
transition="width 0.2s"
/>
{/* 内容 */}
<Text color="gray.500" w="24px" flexShrink={0} zIndex={1}>
{label}
</Text>
<HStack flex={1} justify="flex-end" zIndex={1}>
<Text color={isLimitPrice ? limitColor : priceInfo.color} fontWeight="medium">
{priceInfo.text}
</Text>
{isLimitPrice && (
<Tooltip label={isBid ? '跌停价' : '涨停价'}>
<Badge
colorScheme={isBid ? 'green' : 'red'}
fontSize="2xs"
variant="subtle"
>
{isBid ? '跌' : '涨'}
</Badge>
</Tooltip>
)}
</HStack>
<Text color="gray.600" w="40px" textAlign="right" zIndex={1}>
{volumeText}
</Text>
</HStack>
);
};
/**
* OrderBookPanel 组件
* @param {Object} props
* @param {number[]} props.bidPrices - 买档价格最多10档
* @param {number[]} props.bidVolumes - 买档量
* @param {number[]} props.askPrices - 卖档价格最多10档
* @param {number[]} props.askVolumes - 卖档量
* @param {number} props.prevClose - 昨收价
* @param {number} props.upperLimit - 涨停价
* @param {number} props.lowerLimit - 跌停价
* @param {number} props.defaultLevels - 默认显示档数5 或 10
*/
const OrderBookPanel = ({
bidPrices = [],
bidVolumes = [],
askPrices = [],
askVolumes = [],
prevClose,
upperLimit,
lowerLimit,
defaultLevels = 5,
}) => {
const borderColor = useColorModeValue('gray.200', 'gray.700');
const buttonBg = useColorModeValue('gray.100', 'gray.700');
// 可切换显示的档位数
const maxAvailableLevels = Math.max(bidPrices.length, askPrices.length, 1);
const [showLevels, setShowLevels] = useState(Math.min(defaultLevels, maxAvailableLevels));
// 计算最大成交量(用于条形图比例)
const displayBidVolumes = bidVolumes.slice(0, showLevels);
const displayAskVolumes = askVolumes.slice(0, showLevels);
const allVolumes = [...displayBidVolumes, ...displayAskVolumes].filter(v => v > 0);
const maxVolume = allVolumes.length > 0 ? Math.max(...allVolumes) : 0;
// 判断是否为涨跌停价
const isUpperLimit = (price) => upperLimit && Math.abs(price - upperLimit) < 0.001;
const isLowerLimit = (price) => lowerLimit && Math.abs(price - lowerLimit) < 0.001;
// 卖盘从卖N到卖1即价格从高到低
const askRows = [];
for (let i = showLevels - 1; i >= 0; i--) {
askRows.push(
<OrderRow
key={`ask${i + 1}`}
label={`${i + 1}`}
price={askPrices[i]}
volume={askVolumes[i]}
prevClose={prevClose}
isBid={false}
maxVolume={maxVolume}
isLimitPrice={isUpperLimit(askPrices[i])}
/>
);
}
// 买盘从买1到买N即价格从高到低
const bidRows = [];
for (let i = 0; i < showLevels; i++) {
bidRows.push(
<OrderRow
key={`bid${i + 1}`}
label={`${i + 1}`}
price={bidPrices[i]}
volume={bidVolumes[i]}
prevClose={prevClose}
isBid={true}
maxVolume={maxVolume}
isLimitPrice={isLowerLimit(bidPrices[i])}
/>
);
}
// 没有数据时的提示
const hasData = bidPrices.length > 0 || askPrices.length > 0;
if (!hasData) {
return (
<Box textAlign="center" py={2}>
<Text fontSize="xs" color="gray.400">
暂无盘口数据
</Text>
</Box>
);
}
return (
<VStack spacing={0} align="stretch">
{/* 档位切换只有当有超过5档数据时才显示 */}
{maxAvailableLevels > 5 && (
<HStack justify="flex-end" mb={1}>
<ButtonGroup size="xs" isAttached variant="outline">
<Button
onClick={() => setShowLevels(5)}
bg={showLevels === 5 ? buttonBg : 'transparent'}
fontWeight={showLevels === 5 ? 'bold' : 'normal'}
>
5
</Button>
<Button
onClick={() => setShowLevels(10)}
bg={showLevels === 10 ? buttonBg : 'transparent'}
fontWeight={showLevels === 10 ? 'bold' : 'normal'}
>
10
</Button>
</ButtonGroup>
</HStack>
)}
{/* 卖盘 */}
{askRows}
{/* 分隔线 + 当前价信息 */}
<Box h="1px" bg={borderColor} my={1} position="relative">
{prevClose && (
<Text
position="absolute"
right={0}
top="50%"
transform="translateY(-50%)"
fontSize="2xs"
color="gray.400"
bg={useColorModeValue('white', '#1a1a1a')}
px={1}
>
昨收 {prevClose.toFixed(2)}
</Text>
)}
</Box>
{/* 买盘 */}
{bidRows}
{/* 涨跌停价信息 */}
{(upperLimit || lowerLimit) && (
<HStack justify="space-between" mt={1} fontSize="2xs" color="gray.400">
{lowerLimit && <Text>跌停 {lowerLimit.toFixed(2)}</Text>}
{upperLimit && <Text>涨停 {upperLimit.toFixed(2)}</Text>}
</HStack>
)}
</VStack>
);
};
export default OrderBookPanel;

View File

@@ -0,0 +1,270 @@
/**
* 行情瓷砖组件
* 单个证券的实时行情展示卡片,包含分时图和五档盘口
*/
import React, { useState } from 'react';
import {
Box,
VStack,
HStack,
Text,
IconButton,
Tooltip,
useColorModeValue,
Collapse,
Badge,
Flex,
Spacer,
} from '@chakra-ui/react';
import { CloseIcon, ChevronDownIcon, ChevronUpIcon, ExternalLinkIcon } from '@chakra-ui/icons';
import { useNavigate } from 'react-router-dom';
import MiniTimelineChart from './MiniTimelineChart';
import OrderBookPanel from './OrderBookPanel';
/**
* 格式化价格显示
*/
const formatPrice = (price) => {
if (!price || isNaN(price)) return '-';
return price.toFixed(2);
};
/**
* 格式化涨跌幅
*/
const formatChangePct = (pct) => {
if (!pct || isNaN(pct)) return '0.00%';
const sign = pct > 0 ? '+' : '';
return `${sign}${pct.toFixed(2)}%`;
};
/**
* 格式化涨跌额
*/
const formatChange = (change) => {
if (!change || isNaN(change)) return '-';
const sign = change > 0 ? '+' : '';
return `${sign}${change.toFixed(2)}`;
};
/**
* 格式化成交额
*/
const formatAmount = (amount) => {
if (!amount || isNaN(amount)) return '-';
if (amount >= 100000000) {
return `${(amount / 100000000).toFixed(2)}亿`;
}
if (amount >= 10000) {
return `${(amount / 10000).toFixed(0)}`;
}
return amount.toFixed(0);
};
/**
* QuoteTile 组件
* @param {Object} props
* @param {string} props.code - 证券代码
* @param {string} props.name - 证券名称
* @param {Object} props.quote - 实时行情数据
* @param {boolean} props.isIndex - 是否为指数
* @param {Function} props.onRemove - 移除回调
*/
const QuoteTile = ({
code,
name,
quote = {},
isIndex = false,
onRemove,
}) => {
const navigate = useNavigate();
const [expanded, setExpanded] = useState(true);
// 颜色主题
const cardBg = useColorModeValue('white', '#1a1a1a');
const borderColor = useColorModeValue('gray.200', '#333');
const hoverBorderColor = useColorModeValue('purple.300', '#666');
const textColor = useColorModeValue('gray.800', 'white');
const subTextColor = useColorModeValue('gray.500', 'gray.400');
// 涨跌色
const { price, prevClose, change, changePct, amount } = quote;
const priceColor = useColorModeValue(
!prevClose || price === prevClose ? 'gray.800' :
price > prevClose ? 'red.500' : 'green.500',
!prevClose || price === prevClose ? 'gray.200' :
price > prevClose ? 'red.400' : 'green.400'
);
// 涨跌幅背景色
const changeBgColor = useColorModeValue(
!changePct || changePct === 0 ? 'gray.100' :
changePct > 0 ? 'red.100' : 'green.100',
!changePct || changePct === 0 ? 'gray.700' :
changePct > 0 ? 'rgba(239, 68, 68, 0.2)' : 'rgba(34, 197, 94, 0.2)'
);
// 跳转到详情页
const handleNavigate = () => {
if (isIndex) {
// 指数暂无详情页
return;
}
navigate(`/company?scode=${code}`);
};
return (
<Box
bg={cardBg}
borderWidth="1px"
borderColor={borderColor}
borderRadius="lg"
overflow="hidden"
transition="all 0.2s"
_hover={{
borderColor: hoverBorderColor,
boxShadow: 'md',
}}
>
{/* 头部 */}
<HStack
px={3}
py={2}
borderBottomWidth={expanded ? '1px' : '0'}
borderColor={borderColor}
cursor="pointer"
onClick={() => setExpanded(!expanded)}
>
{/* 名称和代码 */}
<VStack align="start" spacing={0} flex={1} minW={0}>
<HStack spacing={2}>
<Text
fontWeight="bold"
fontSize="sm"
color={textColor}
noOfLines={1}
cursor="pointer"
_hover={{ textDecoration: 'underline' }}
onClick={(e) => {
e.stopPropagation();
handleNavigate();
}}
>
{name || code}
</Text>
{isIndex && (
<Badge colorScheme="purple" fontSize="xs">
指数
</Badge>
)}
</HStack>
<Text fontSize="xs" color={subTextColor}>
{code}
</Text>
</VStack>
{/* 价格信息 */}
<VStack align="end" spacing={0}>
<Text fontWeight="bold" fontSize="lg" color={priceColor}>
{formatPrice(price)}
</Text>
<HStack spacing={1}>
<Box
px={1.5}
py={0.5}
bg={changeBgColor}
borderRadius="sm"
fontSize="xs"
fontWeight="medium"
color={priceColor}
>
{formatChangePct(changePct)}
</Box>
<Text fontSize="xs" color={priceColor}>
{formatChange(change)}
</Text>
</HStack>
</VStack>
{/* 操作按钮 */}
<HStack spacing={1} ml={2}>
<IconButton
icon={expanded ? <ChevronUpIcon /> : <ChevronDownIcon />}
size="xs"
variant="ghost"
aria-label={expanded ? '收起' : '展开'}
onClick={(e) => {
e.stopPropagation();
setExpanded(!expanded);
}}
/>
<Tooltip label="移除">
<IconButton
icon={<CloseIcon />}
size="xs"
variant="ghost"
colorScheme="red"
aria-label="移除"
onClick={(e) => {
e.stopPropagation();
onRemove?.(code);
}}
/>
</Tooltip>
</HStack>
</HStack>
{/* 可折叠内容 */}
<Collapse in={expanded} animateOpacity>
<Box p={3}>
{/* 统计信息 */}
<HStack spacing={4} mb={3} fontSize="xs" color={subTextColor}>
<HStack>
<Text>昨收:</Text>
<Text color={textColor}>{formatPrice(prevClose)}</Text>
</HStack>
<HStack>
<Text>今开:</Text>
<Text color={textColor}>{formatPrice(quote.open)}</Text>
</HStack>
<HStack>
<Text>成交额:</Text>
<Text color={textColor}>{formatAmount(amount)}</Text>
</HStack>
</HStack>
{/* 分时图 */}
<Box mb={3}>
<MiniTimelineChart
code={code}
isIndex={isIndex}
prevClose={prevClose}
currentPrice={price}
height={100}
/>
</Box>
{/* 盘口(指数没有盘口) */}
{!isIndex && (
<Box>
<Text fontSize="xs" color={subTextColor} mb={1}>
盘口 {quote.bidPrices?.length > 5 ? '(10档)' : '(5档)'}
</Text>
<OrderBookPanel
bidPrices={quote.bidPrices || []}
bidVolumes={quote.bidVolumes || []}
askPrices={quote.askPrices || []}
askVolumes={quote.askVolumes || []}
prevClose={prevClose}
upperLimit={quote.upperLimit}
lowerLimit={quote.lowerLimit}
/>
</Box>
)}
</Box>
</Collapse>
</Box>
);
};
export default QuoteTile;

View File

@@ -0,0 +1,3 @@
export { default as MiniTimelineChart } from './MiniTimelineChart';
export { default as OrderBookPanel } from './OrderBookPanel';
export { default as QuoteTile } from './QuoteTile';

View File

@@ -0,0 +1 @@
export { useRealtimeQuote } from './useRealtimeQuote';

View File

@@ -0,0 +1,692 @@
/**
* 实时行情 Hook
* 管理上交所和深交所 WebSocket 连接,获取实时行情数据
*
* 上交所 (SSE): ws://49.232.185.254:8765 - 需主动订阅,提供五档行情
* 深交所 (SZSE): ws://222.128.1.157:8765 - 自动推送,提供十档行情
*
* 深交所支持的数据类型 (category):
* - stock (300111): 股票快照含10档买卖盘
* - bond (300211): 债券快照
* - afterhours_block (300611): 盘后定价大宗交易
* - afterhours_trading (303711): 盘后定价交易
* - hk_stock (306311): 港股快照(深港通)
* - index (309011): 指数快照
* - volume_stats (309111): 成交量统计
* - fund_nav (309211): 基金净值
*/
import { useState, useEffect, useRef, useCallback } from 'react';
import { logger } from '@utils/logger';
// WebSocket 地址配置
const WS_CONFIG = {
SSE: 'ws://49.232.185.254:8765', // 上交所
SZSE: 'ws://222.128.1.157:8765', // 深交所
};
// 心跳间隔 (ms)
const HEARTBEAT_INTERVAL = 30000;
// 重连间隔 (ms)
const RECONNECT_INTERVAL = 3000;
/**
* 判断证券代码属于哪个交易所
* @param {string} code - 证券代码(可带或不带后缀)
* @returns {'SSE'|'SZSE'} 交易所标识
*/
const getExchange = (code) => {
const baseCode = code.split('.')[0];
// 6开头为上海股票
if (baseCode.startsWith('6')) {
return 'SSE';
}
// 000开头的6位数可能是上证指数或深圳股票
if (baseCode.startsWith('000') && baseCode.length === 6) {
// 000001-000999 是上证指数范围,但 000001 也是平安银行
// 这里需要更精确的判断,暂时把 000 开头当深圳
return 'SZSE';
}
// 399开头是深证指数
if (baseCode.startsWith('399')) {
return 'SZSE';
}
// 0、3开头是深圳股票
if (baseCode.startsWith('0') || baseCode.startsWith('3')) {
return 'SZSE';
}
// 5开头是上海 ETF
if (baseCode.startsWith('5')) {
return 'SSE';
}
// 1开头是深圳 ETF/债券
if (baseCode.startsWith('1')) {
return 'SZSE';
}
// 默认上海
return 'SSE';
};
/**
* 标准化证券代码为无后缀格式
*/
const normalizeCode = (code) => {
return code.split('.')[0];
};
/**
* 从深交所 bids/asks 数组提取价格和量数组
* @param {Array} orderBook - [{price, volume}, ...]
* @returns {{ prices: number[], volumes: number[] }}
*/
const extractOrderBook = (orderBook) => {
if (!orderBook || !Array.isArray(orderBook)) {
return { prices: [], volumes: [] };
}
const prices = orderBook.map(item => item.price || 0);
const volumes = orderBook.map(item => item.volume || 0);
return { prices, volumes };
};
/**
* 实时行情 Hook
* @param {string[]} codes - 订阅的证券代码列表
* @returns {Object} { quotes, connected, subscribe, unsubscribe }
*/
export const useRealtimeQuote = (codes = []) => {
// 行情数据 { [code]: QuoteData }
const [quotes, setQuotes] = useState({});
// 连接状态 { SSE: boolean, SZSE: boolean }
const [connected, setConnected] = useState({ SSE: false, SZSE: false });
// WebSocket 实例引用
const wsRefs = useRef({ SSE: null, SZSE: null });
// 心跳定时器
const heartbeatRefs = useRef({ SSE: null, SZSE: null });
// 重连定时器
const reconnectRefs = useRef({ SSE: null, SZSE: null });
// 当前订阅的代码(按交易所分组)
const subscribedCodes = useRef({ SSE: new Set(), SZSE: new Set() });
/**
* 创建 WebSocket 连接
*/
const createConnection = useCallback((exchange) => {
// 清理现有连接
if (wsRefs.current[exchange]) {
wsRefs.current[exchange].close();
}
const ws = new WebSocket(WS_CONFIG[exchange]);
wsRefs.current[exchange] = ws;
ws.onopen = () => {
logger.info('FlexScreen', `${exchange} WebSocket 已连接`);
setConnected(prev => ({ ...prev, [exchange]: true }));
// 上交所需要主动订阅
if (exchange === 'SSE') {
const codes = Array.from(subscribedCodes.current.SSE);
if (codes.length > 0) {
ws.send(JSON.stringify({
action: 'subscribe',
channels: ['stock', 'index'],
codes: codes,
}));
}
}
// 启动心跳
startHeartbeat(exchange);
};
ws.onmessage = (event) => {
try {
const msg = JSON.parse(event.data);
handleMessage(exchange, msg);
} catch (e) {
logger.warn('FlexScreen', `${exchange} 消息解析失败`, e);
}
};
ws.onerror = (error) => {
logger.error('FlexScreen', `${exchange} WebSocket 错误`, error);
};
ws.onclose = () => {
logger.info('FlexScreen', `${exchange} WebSocket 断开`);
setConnected(prev => ({ ...prev, [exchange]: false }));
stopHeartbeat(exchange);
// 自动重连
scheduleReconnect(exchange);
};
}, []);
/**
* 处理 WebSocket 消息
*/
const handleMessage = useCallback((exchange, msg) => {
// 处理 pong
if (msg.type === 'pong') {
return;
}
if (exchange === 'SSE') {
// 上交所消息格式
if (msg.type === 'stock' || msg.type === 'index') {
const data = msg.data || {};
setQuotes(prev => {
const updated = { ...prev };
Object.entries(data).forEach(([code, quote]) => {
// 只更新订阅的代码
if (subscribedCodes.current.SSE.has(code)) {
updated[code] = {
code: quote.security_id,
name: quote.security_name,
price: quote.last_price,
prevClose: quote.prev_close,
open: quote.open_price,
high: quote.high_price,
low: quote.low_price,
volume: quote.volume,
amount: quote.amount,
change: quote.last_price - quote.prev_close,
changePct: quote.prev_close ? ((quote.last_price - quote.prev_close) / quote.prev_close * 100) : 0,
bidPrices: quote.bid_prices || [],
bidVolumes: quote.bid_volumes || [],
askPrices: quote.ask_prices || [],
askVolumes: quote.ask_volumes || [],
updateTime: quote.trade_time,
exchange: 'SSE',
};
}
});
return updated;
});
}
} else if (exchange === 'SZSE') {
// 深交所消息格式(更新后的 API
if (msg.type === 'realtime') {
const { category, data } = msg;
const code = data.security_id;
// 只更新订阅的代码
if (!subscribedCodes.current.SZSE.has(code)) {
return;
}
if (category === 'stock') {
// 股票行情 - 含 10 档买卖盘
const { prices: bidPrices, volumes: bidVolumes } = extractOrderBook(data.bids);
const { prices: askPrices, volumes: askVolumes } = extractOrderBook(data.asks);
setQuotes(prev => ({
...prev,
[code]: {
code: code,
name: prev[code]?.name || '',
price: data.last_px,
prevClose: data.prev_close,
open: data.open_px,
high: data.high_px,
low: data.low_px,
volume: data.volume,
amount: data.amount,
numTrades: data.num_trades,
upperLimit: data.upper_limit, // 涨停价
lowerLimit: data.lower_limit, // 跌停价
change: data.last_px - data.prev_close,
changePct: data.prev_close ? ((data.last_px - data.prev_close) / data.prev_close * 100) : 0,
bidPrices,
bidVolumes,
askPrices,
askVolumes,
tradingPhase: data.trading_phase,
updateTime: msg.timestamp,
exchange: 'SZSE',
},
}));
} else if (category === 'index') {
// 指数行情
setQuotes(prev => ({
...prev,
[code]: {
code: code,
name: prev[code]?.name || '',
price: data.current_index,
prevClose: data.prev_close,
open: data.open_index,
high: data.high_index,
low: data.low_index,
close: data.close_index,
volume: data.volume,
amount: data.amount,
numTrades: data.num_trades,
change: data.current_index - data.prev_close,
changePct: data.prev_close ? ((data.current_index - data.prev_close) / data.prev_close * 100) : 0,
bidPrices: [],
bidVolumes: [],
askPrices: [],
askVolumes: [],
tradingPhase: data.trading_phase,
updateTime: msg.timestamp,
exchange: 'SZSE',
},
}));
} else if (category === 'bond') {
// 债券行情
setQuotes(prev => ({
...prev,
[code]: {
code: code,
name: prev[code]?.name || '',
price: data.last_px,
prevClose: data.prev_close,
open: data.open_px,
high: data.high_px,
low: data.low_px,
volume: data.volume,
amount: data.amount,
numTrades: data.num_trades,
weightedAvgPx: data.weighted_avg_px,
change: data.last_px - data.prev_close,
changePct: data.prev_close ? ((data.last_px - data.prev_close) / data.prev_close * 100) : 0,
bidPrices: [],
bidVolumes: [],
askPrices: [],
askVolumes: [],
tradingPhase: data.trading_phase,
updateTime: msg.timestamp,
exchange: 'SZSE',
isBond: true,
},
}));
} else if (category === 'hk_stock') {
// 港股行情(深港通)
const { prices: bidPrices, volumes: bidVolumes } = extractOrderBook(data.bids);
const { prices: askPrices, volumes: askVolumes } = extractOrderBook(data.asks);
setQuotes(prev => ({
...prev,
[code]: {
code: code,
name: prev[code]?.name || '',
price: data.last_px,
prevClose: data.prev_close,
open: data.open_px,
high: data.high_px,
low: data.low_px,
volume: data.volume,
amount: data.amount,
numTrades: data.num_trades,
nominalPx: data.nominal_px, // 按盘价
referencePx: data.reference_px, // 参考价
change: data.last_px - data.prev_close,
changePct: data.prev_close ? ((data.last_px - data.prev_close) / data.prev_close * 100) : 0,
bidPrices,
bidVolumes,
askPrices,
askVolumes,
tradingPhase: data.trading_phase,
updateTime: msg.timestamp,
exchange: 'SZSE',
isHK: true,
},
}));
} else if (category === 'afterhours_block' || category === 'afterhours_trading') {
// 盘后交易
setQuotes(prev => ({
...prev,
[code]: {
...prev[code],
afterhours: {
bidPx: data.bid_px,
bidSize: data.bid_size,
offerPx: data.offer_px,
offerSize: data.offer_size,
volume: data.volume,
amount: data.amount,
numTrades: data.num_trades,
},
updateTime: msg.timestamp,
},
}));
}
// fund_nav 和 volume_stats 暂不处理
} else if (msg.type === 'snapshot') {
// 深交所初始快照
const { stocks = [], indexes = [], bonds = [] } = msg.data || {};
setQuotes(prev => {
const updated = { ...prev };
stocks.forEach(s => {
if (subscribedCodes.current.SZSE.has(s.security_id)) {
const { prices: bidPrices, volumes: bidVolumes } = extractOrderBook(s.bids);
const { prices: askPrices, volumes: askVolumes } = extractOrderBook(s.asks);
updated[s.security_id] = {
code: s.security_id,
name: s.security_name || '',
price: s.last_px,
prevClose: s.prev_close,
open: s.open_px,
high: s.high_px,
low: s.low_px,
volume: s.volume,
amount: s.amount,
numTrades: s.num_trades,
upperLimit: s.upper_limit,
lowerLimit: s.lower_limit,
change: s.last_px - s.prev_close,
changePct: s.prev_close ? ((s.last_px - s.prev_close) / s.prev_close * 100) : 0,
bidPrices,
bidVolumes,
askPrices,
askVolumes,
exchange: 'SZSE',
};
}
});
indexes.forEach(i => {
if (subscribedCodes.current.SZSE.has(i.security_id)) {
updated[i.security_id] = {
code: i.security_id,
name: i.security_name || '',
price: i.current_index,
prevClose: i.prev_close,
open: i.open_index,
high: i.high_index,
low: i.low_index,
volume: i.volume,
amount: i.amount,
numTrades: i.num_trades,
change: i.current_index - i.prev_close,
changePct: i.prev_close ? ((i.current_index - i.prev_close) / i.prev_close * 100) : 0,
bidPrices: [],
bidVolumes: [],
askPrices: [],
askVolumes: [],
exchange: 'SZSE',
};
}
});
bonds.forEach(b => {
if (subscribedCodes.current.SZSE.has(b.security_id)) {
updated[b.security_id] = {
code: b.security_id,
name: b.security_name || '',
price: b.last_px,
prevClose: b.prev_close,
open: b.open_px,
high: b.high_px,
low: b.low_px,
volume: b.volume,
amount: b.amount,
change: b.last_px - b.prev_close,
changePct: b.prev_close ? ((b.last_px - b.prev_close) / b.prev_close * 100) : 0,
bidPrices: [],
bidVolumes: [],
askPrices: [],
askVolumes: [],
exchange: 'SZSE',
isBond: true,
};
}
});
return updated;
});
}
}
}, []);
/**
* 启动心跳
*/
const startHeartbeat = useCallback((exchange) => {
stopHeartbeat(exchange);
heartbeatRefs.current[exchange] = setInterval(() => {
const ws = wsRefs.current[exchange];
if (ws && ws.readyState === WebSocket.OPEN) {
if (exchange === 'SSE') {
ws.send(JSON.stringify({ action: 'ping' }));
} else {
ws.send(JSON.stringify({ type: 'ping' }));
}
}
}, HEARTBEAT_INTERVAL);
}, []);
/**
* 停止心跳
*/
const stopHeartbeat = useCallback((exchange) => {
if (heartbeatRefs.current[exchange]) {
clearInterval(heartbeatRefs.current[exchange]);
heartbeatRefs.current[exchange] = null;
}
}, []);
/**
* 安排重连
*/
const scheduleReconnect = useCallback((exchange) => {
if (reconnectRefs.current[exchange]) {
return; // 已有重连计划
}
reconnectRefs.current[exchange] = setTimeout(() => {
reconnectRefs.current[exchange] = null;
// 只有还有订阅的代码才重连
if (subscribedCodes.current[exchange].size > 0) {
createConnection(exchange);
}
}, RECONNECT_INTERVAL);
}, [createConnection]);
/**
* 订阅证券
*/
const subscribe = useCallback((code) => {
const baseCode = normalizeCode(code);
const exchange = getExchange(code);
// 添加到订阅列表
subscribedCodes.current[exchange].add(baseCode);
// 如果连接已建立,发送订阅消息(仅上交所需要)
const ws = wsRefs.current[exchange];
if (exchange === 'SSE' && ws && ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({
action: 'subscribe',
channels: ['stock', 'index'],
codes: [baseCode],
}));
}
// 如果连接未建立,创建连接
if (!ws || ws.readyState !== WebSocket.OPEN) {
createConnection(exchange);
}
}, [createConnection]);
/**
* 取消订阅
*/
const unsubscribe = useCallback((code) => {
const baseCode = normalizeCode(code);
const exchange = getExchange(code);
// 从订阅列表移除
subscribedCodes.current[exchange].delete(baseCode);
// 从 quotes 中移除
setQuotes(prev => {
const updated = { ...prev };
delete updated[baseCode];
return updated;
});
// 如果该交易所没有订阅了,关闭连接
if (subscribedCodes.current[exchange].size === 0) {
const ws = wsRefs.current[exchange];
if (ws) {
ws.close();
wsRefs.current[exchange] = null;
}
}
}, []);
/**
* 初始化订阅
*/
useEffect(() => {
if (!codes || codes.length === 0) {
return;
}
// 按交易所分组
const sseCodesSet = new Set();
const szseCodesSet = new Set();
codes.forEach(code => {
const baseCode = normalizeCode(code);
const exchange = getExchange(code);
if (exchange === 'SSE') {
sseCodesSet.add(baseCode);
} else {
szseCodesSet.add(baseCode);
}
});
// 更新订阅列表
subscribedCodes.current.SSE = sseCodesSet;
subscribedCodes.current.SZSE = szseCodesSet;
// 建立连接
if (sseCodesSet.size > 0) {
createConnection('SSE');
}
if (szseCodesSet.size > 0) {
createConnection('SZSE');
}
// 清理
return () => {
['SSE', 'SZSE'].forEach(exchange => {
stopHeartbeat(exchange);
if (reconnectRefs.current[exchange]) {
clearTimeout(reconnectRefs.current[exchange]);
}
const ws = wsRefs.current[exchange];
if (ws) {
ws.close();
}
});
};
}, []); // 只在挂载时执行
/**
* 处理 codes 变化
*/
useEffect(() => {
if (!codes) return;
// 计算新的订阅列表
const newSseCodes = new Set();
const newSzseCodes = new Set();
codes.forEach(code => {
const baseCode = normalizeCode(code);
const exchange = getExchange(code);
if (exchange === 'SSE') {
newSseCodes.add(baseCode);
} else {
newSzseCodes.add(baseCode);
}
});
// 找出需要新增和删除的代码
const oldSseCodes = subscribedCodes.current.SSE;
const oldSzseCodes = subscribedCodes.current.SZSE;
// 更新上交所订阅
const sseToAdd = [...newSseCodes].filter(c => !oldSseCodes.has(c));
const sseToRemove = [...oldSseCodes].filter(c => !newSseCodes.has(c));
if (sseToAdd.length > 0 || sseToRemove.length > 0) {
subscribedCodes.current.SSE = newSseCodes;
const ws = wsRefs.current.SSE;
if (ws && ws.readyState === WebSocket.OPEN && sseToAdd.length > 0) {
ws.send(JSON.stringify({
action: 'subscribe',
channels: ['stock', 'index'],
codes: sseToAdd,
}));
}
// 如果新增了代码但连接未建立
if (sseToAdd.length > 0 && (!ws || ws.readyState !== WebSocket.OPEN)) {
createConnection('SSE');
}
// 如果没有订阅了,关闭连接
if (newSseCodes.size === 0 && ws) {
ws.close();
wsRefs.current.SSE = null;
}
}
// 更新深交所订阅
const szseToAdd = [...newSzseCodes].filter(c => !oldSzseCodes.has(c));
const szseToRemove = [...oldSzseCodes].filter(c => !newSzseCodes.has(c));
if (szseToAdd.length > 0 || szseToRemove.length > 0) {
subscribedCodes.current.SZSE = newSzseCodes;
// 深交所是自动推送,只需要管理连接
const ws = wsRefs.current.SZSE;
if (szseToAdd.length > 0 && (!ws || ws.readyState !== WebSocket.OPEN)) {
createConnection('SZSE');
}
if (newSzseCodes.size === 0 && ws) {
ws.close();
wsRefs.current.SZSE = null;
}
}
// 清理已取消订阅的 quotes
const removedCodes = [...sseToRemove, ...szseToRemove];
if (removedCodes.length > 0) {
setQuotes(prev => {
const updated = { ...prev };
removedCodes.forEach(code => {
delete updated[code];
});
return updated;
});
}
}, [codes, createConnection]);
return {
quotes,
connected,
subscribe,
unsubscribe,
};
};
export default useRealtimeQuote;

View File

@@ -0,0 +1,463 @@
/**
* 灵活屏组件
* 用户可自定义添加关注的指数/个股,实时显示行情
*
* 功能:
* 1. 添加/删除自选证券
* 2. 显示实时行情(通过 WebSocket
* 3. 显示分时走势(结合 ClickHouse 历史数据)
* 4. 显示五档盘口(上交所完整五档,深交所买一卖一)
* 5. 本地存储自选列表
*/
import React, { useState, useEffect, useCallback, useMemo } from 'react';
import {
Box,
Card,
CardBody,
VStack,
HStack,
Heading,
Text,
Input,
InputGroup,
InputLeftElement,
InputRightElement,
IconButton,
Button,
SimpleGrid,
Flex,
Spacer,
Icon,
useColorModeValue,
useToast,
Badge,
Tooltip,
Collapse,
List,
ListItem,
Spinner,
Center,
Menu,
MenuButton,
MenuList,
MenuItem,
Divider,
Tag,
TagLabel,
} from '@chakra-ui/react';
import {
SearchIcon,
CloseIcon,
AddIcon,
ChevronDownIcon,
ChevronUpIcon,
SettingsIcon,
} from '@chakra-ui/icons';
import { FaDesktop, FaPlus, FaTrash, FaSync, FaWifi, FaExclamationCircle } from 'react-icons/fa';
import { useRealtimeQuote } from './hooks';
import { QuoteTile } from './components';
import { logger } from '@utils/logger';
// 本地存储 key
const STORAGE_KEY = 'flexscreen_watchlist';
// 默认自选列表
const DEFAULT_WATCHLIST = [
{ code: '000001', name: '上证指数', isIndex: true },
{ code: '399001', name: '深证成指', isIndex: true },
{ code: '399006', name: '创业板指', isIndex: true },
];
// 热门推荐
const HOT_RECOMMENDATIONS = [
{ code: '000001', name: '上证指数', isIndex: true },
{ code: '399001', name: '深证成指', isIndex: true },
{ code: '399006', name: '创业板指', isIndex: true },
{ code: '399300', name: '沪深300', isIndex: true },
{ code: '600519', name: '贵州茅台', isIndex: false },
{ code: '000858', name: '五粮液', isIndex: false },
{ code: '300750', name: '宁德时代', isIndex: false },
{ code: '002594', name: '比亚迪', isIndex: false },
];
/**
* FlexScreen 组件
*/
const FlexScreen = () => {
const toast = useToast();
// 自选列表
const [watchlist, setWatchlist] = useState([]);
// 搜索状态
const [searchQuery, setSearchQuery] = useState('');
const [searchResults, setSearchResults] = useState([]);
const [isSearching, setIsSearching] = useState(false);
const [showResults, setShowResults] = useState(false);
// 面板状态
const [isCollapsed, setIsCollapsed] = useState(false);
// 颜色主题
const cardBg = useColorModeValue('white', '#1a1a1a');
const borderColor = useColorModeValue('gray.200', '#333');
const textColor = useColorModeValue('gray.800', 'white');
const subTextColor = useColorModeValue('gray.600', 'gray.400');
const searchBg = useColorModeValue('gray.50', '#2a2a2a');
const hoverBg = useColorModeValue('gray.100', '#333');
// 获取订阅的证券代码列表
const subscribedCodes = useMemo(() => {
return watchlist.map(item => item.code);
}, [watchlist]);
// WebSocket 实时行情
const { quotes, connected } = useRealtimeQuote(subscribedCodes);
// 从本地存储加载自选列表
useEffect(() => {
try {
const saved = localStorage.getItem(STORAGE_KEY);
if (saved) {
const parsed = JSON.parse(saved);
if (Array.isArray(parsed) && parsed.length > 0) {
setWatchlist(parsed);
return;
}
}
} catch (e) {
logger.warn('FlexScreen', '加载自选列表失败', e);
}
// 使用默认列表
setWatchlist(DEFAULT_WATCHLIST);
}, []);
// 保存自选列表到本地存储
useEffect(() => {
if (watchlist.length > 0) {
try {
localStorage.setItem(STORAGE_KEY, JSON.stringify(watchlist));
} catch (e) {
logger.warn('FlexScreen', '保存自选列表失败', e);
}
}
}, [watchlist]);
// 搜索证券
const searchSecurities = useCallback(async (query) => {
if (!query.trim()) {
setSearchResults([]);
setShowResults(false);
return;
}
setIsSearching(true);
try {
const response = await fetch(`/api/stocks/search?q=${encodeURIComponent(query)}&limit=10`);
const data = await response.json();
if (data.success) {
setSearchResults(data.data || []);
setShowResults(true);
} else {
setSearchResults([]);
}
} catch (e) {
logger.error('FlexScreen', '搜索失败', e);
setSearchResults([]);
} finally {
setIsSearching(false);
}
}, []);
// 防抖搜索
useEffect(() => {
const timer = setTimeout(() => {
searchSecurities(searchQuery);
}, 300);
return () => clearTimeout(timer);
}, [searchQuery, searchSecurities]);
// 添加证券
const addSecurity = useCallback((security) => {
const code = security.stock_code || security.code;
const name = security.stock_name || security.name;
const isIndex = security.isIndex || code.startsWith('000') || code.startsWith('399');
// 检查是否已存在
if (watchlist.some(item => item.code === code)) {
toast({
title: '已在自选列表中',
status: 'info',
duration: 2000,
isClosable: true,
});
return;
}
// 添加到列表
setWatchlist(prev => [...prev, { code, name, isIndex }]);
toast({
title: `已添加 ${name}`,
status: 'success',
duration: 2000,
isClosable: true,
});
// 清空搜索
setSearchQuery('');
setShowResults(false);
}, [watchlist, toast]);
// 移除证券
const removeSecurity = useCallback((code) => {
setWatchlist(prev => prev.filter(item => item.code !== code));
}, []);
// 清空自选列表
const clearWatchlist = useCallback(() => {
setWatchlist([]);
localStorage.removeItem(STORAGE_KEY);
toast({
title: '已清空自选列表',
status: 'info',
duration: 2000,
isClosable: true,
});
}, [toast]);
// 重置为默认列表
const resetWatchlist = useCallback(() => {
setWatchlist(DEFAULT_WATCHLIST);
toast({
title: '已重置为默认列表',
status: 'success',
duration: 2000,
isClosable: true,
});
}, [toast]);
// 连接状态指示
const isAnyConnected = connected.SSE || connected.SZSE;
const connectionStatus = useMemo(() => {
if (connected.SSE && connected.SZSE) {
return { color: 'green', text: '上交所/深交所 已连接' };
}
if (connected.SSE) {
return { color: 'yellow', text: '上交所 已连接' };
}
if (connected.SZSE) {
return { color: 'yellow', text: '深交所 已连接' };
}
return { color: 'red', text: '未连接' };
}, [connected]);
return (
<Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
<CardBody>
{/* 头部 */}
<Flex align="center" mb={4}>
<HStack spacing={3}>
<Icon as={FaDesktop} boxSize={6} color="purple.500" />
<Heading size="md" color={textColor}>
灵活屏
</Heading>
<Tooltip label={connectionStatus.text}>
<Badge
colorScheme={connectionStatus.color}
variant="subtle"
display="flex"
alignItems="center"
gap={1}
>
<Icon as={FaWifi} boxSize={3} />
{isAnyConnected ? '实时' : '离线'}
</Badge>
</Tooltip>
</HStack>
<Spacer />
<HStack spacing={2}>
{/* 操作菜单 */}
<Menu>
<MenuButton
as={IconButton}
icon={<SettingsIcon />}
size="sm"
variant="ghost"
aria-label="设置"
/>
<MenuList>
<MenuItem icon={<FaSync />} onClick={resetWatchlist}>
重置为默认
</MenuItem>
<MenuItem icon={<FaTrash />} onClick={clearWatchlist} color="red.500">
清空列表
</MenuItem>
</MenuList>
</Menu>
{/* 折叠按钮 */}
<IconButton
icon={isCollapsed ? <ChevronDownIcon /> : <ChevronUpIcon />}
size="sm"
variant="ghost"
onClick={() => setIsCollapsed(!isCollapsed)}
aria-label={isCollapsed ? '展开' : '收起'}
/>
</HStack>
</Flex>
{/* 可折叠内容 */}
<Collapse in={!isCollapsed} animateOpacity>
{/* 搜索框 */}
<Box position="relative" mb={4}>
<InputGroup size="md">
<InputLeftElement pointerEvents="none">
<SearchIcon color="gray.400" />
</InputLeftElement>
<Input
placeholder="搜索股票/指数代码或名称..."
value={searchQuery}
onChange={(e) => setSearchQuery(e.target.value)}
bg={searchBg}
borderRadius="lg"
_focus={{
borderColor: 'purple.400',
boxShadow: '0 0 0 1px var(--chakra-colors-purple-400)',
}}
/>
{searchQuery && (
<InputRightElement>
<IconButton
size="sm"
icon={<CloseIcon />}
variant="ghost"
onClick={() => {
setSearchQuery('');
setShowResults(false);
}}
aria-label="清空"
/>
</InputRightElement>
)}
</InputGroup>
{/* 搜索结果下拉 */}
<Collapse in={showResults} animateOpacity>
<Box
position="absolute"
top="100%"
left={0}
right={0}
mt={1}
bg={cardBg}
borderWidth="1px"
borderColor={borderColor}
borderRadius="lg"
boxShadow="lg"
maxH="300px"
overflowY="auto"
zIndex={10}
>
{isSearching ? (
<Center p={4}>
<Spinner size="sm" color="purple.500" />
</Center>
) : searchResults.length > 0 ? (
<List spacing={0}>
{searchResults.map((stock, index) => (
<ListItem
key={stock.stock_code}
px={4}
py={2}
cursor="pointer"
_hover={{ bg: hoverBg }}
onClick={() => addSecurity(stock)}
borderBottomWidth={index < searchResults.length - 1 ? '1px' : '0'}
borderColor={borderColor}
>
<HStack justify="space-between">
<VStack align="start" spacing={0}>
<Text fontWeight="medium" color={textColor}>
{stock.stock_name}
</Text>
<Text fontSize="xs" color={subTextColor}>
{stock.stock_code}
</Text>
</VStack>
<IconButton
icon={<AddIcon />}
size="xs"
colorScheme="purple"
variant="ghost"
aria-label="添加"
/>
</HStack>
</ListItem>
))}
</List>
) : (
<Center p={4}>
<Text color={subTextColor} fontSize="sm">
未找到相关证券
</Text>
</Center>
)}
</Box>
</Collapse>
</Box>
{/* 快捷添加 */}
{watchlist.length === 0 && (
<Box mb={4}>
<Text fontSize="sm" color={subTextColor} mb={2}>
热门推荐点击添加
</Text>
<Flex flexWrap="wrap" gap={2}>
{HOT_RECOMMENDATIONS.map((item) => (
<Tag
key={item.code}
size="md"
variant="subtle"
colorScheme="purple"
cursor="pointer"
_hover={{ bg: 'purple.100' }}
onClick={() => addSecurity(item)}
>
<TagLabel>{item.name}</TagLabel>
</Tag>
))}
</Flex>
</Box>
)}
{/* 自选列表 */}
{watchlist.length > 0 ? (
<SimpleGrid columns={{ base: 1, md: 2, lg: 3 }} spacing={4}>
{watchlist.map((item) => (
<QuoteTile
key={item.code}
code={item.code}
name={item.name}
quote={quotes[item.code] || {}}
isIndex={item.isIndex}
onRemove={removeSecurity}
/>
))}
</SimpleGrid>
) : (
<Center py={8}>
<VStack spacing={3}>
<Icon as={FaExclamationCircle} boxSize={10} color="gray.300" />
<Text color={subTextColor}>
自选列表为空请搜索添加证券
</Text>
</VStack>
</Center>
)}
</Collapse>
</CardBody>
</Card>
);
};
export default FlexScreen;

View File

@@ -18,6 +18,69 @@ import {
import { FaBolt, FaArrowUp, FaArrowDown, FaChartLine, FaFire, FaVolumeUp } from 'react-icons/fa'; import { FaBolt, FaArrowUp, FaArrowDown, FaChartLine, FaFire, FaVolumeUp } from 'react-icons/fa';
import { getAlertTypeLabel, formatScore, getScoreColor } from '../utils/chartHelpers'; import { getAlertTypeLabel, formatScore, getScoreColor } from '../utils/chartHelpers';
/**
* Z-Score 指示器组件
*/
const ZScoreIndicator = ({ value, label, tooltip }) => {
if (value === null || value === undefined) return null;
// Z-Score 颜色:越大越红,越小越绿
const getZScoreColor = (z) => {
const absZ = Math.abs(z);
if (absZ >= 3) return z > 0 ? 'red.600' : 'green.600';
if (absZ >= 2) return z > 0 ? 'red.500' : 'green.500';
if (absZ >= 1) return z > 0 ? 'orange.400' : 'teal.400';
return 'gray.400';
};
// Z-Score 强度条宽度(最大 5σ
const barWidth = Math.min(Math.abs(value) / 5 * 100, 100);
return (
<Tooltip label={tooltip || `${label}: ${value.toFixed(2)}σ`} placement="left">
<HStack spacing={1} fontSize="xs">
<Text color="gray.500" w="20px">{label}</Text>
<Box position="relative" w="40px" h="6px" bg="gray.200" borderRadius="full" overflow="hidden">
<Box
position="absolute"
left={value >= 0 ? '50%' : `${50 - barWidth / 2}%`}
w={`${barWidth / 2}%`}
h="100%"
bg={getZScoreColor(value)}
borderRadius="full"
/>
</Box>
<Text color={getZScoreColor(value)} fontWeight="medium" w="28px" textAlign="right">
{value >= 0 ? '+' : ''}{value.toFixed(1)}
</Text>
</HStack>
</Tooltip>
);
};
/**
* 持续确认率指示器
*/
const ConfirmRatioIndicator = ({ ratio }) => {
if (ratio === null || ratio === undefined) return null;
const percent = Math.round(ratio * 100);
const color = percent >= 80 ? 'green' : percent >= 60 ? 'orange' : 'red';
return (
<Tooltip label={`持续确认率: ${percent}%5分钟窗口内超标比例`}>
<HStack spacing={1}>
<Box position="relative" w="32px" h="6px" bg="gray.200" borderRadius="full" overflow="hidden">
<Box w={`${percent}%`} h="100%" bg={`${color}.500`} borderRadius="full" />
</Box>
<Text fontSize="xs" color={`${color}.500`} fontWeight="medium">
{percent}%
</Text>
</HStack>
</Tooltip>
);
};
/** /**
* 单个异动项组件 * 单个异动项组件
*/ */
@@ -29,6 +92,7 @@ const AlertItem = ({ alert, onClick, isSelected }) => {
const isUp = alert.alert_type !== 'surge_down'; const isUp = alert.alert_type !== 'surge_down';
const typeColor = isUp ? 'red' : 'green'; const typeColor = isUp ? 'red' : 'green';
const isV2 = alert.is_v2;
// 获取异动类型图标 // 获取异动类型图标
const getTypeIcon = (type) => { const getTypeIcon = (type) => {
@@ -69,13 +133,30 @@ const AlertItem = ({ alert, onClick, isSelected }) => {
<Text fontWeight="bold" fontSize="sm" noOfLines={1}> <Text fontWeight="bold" fontSize="sm" noOfLines={1}>
{alert.concept_name} {alert.concept_name}
</Text> </Text>
{isV2 && (
<Badge colorScheme="purple" size="xs" variant="subtle" fontSize="10px">
V2
</Badge>
)}
</HStack> </HStack>
<HStack spacing={2} fontSize="xs" color="gray.500"> <HStack spacing={2} fontSize="xs" color="gray.500">
<Text>{alert.time}</Text> <Text>{alert.time}</Text>
<Badge colorScheme={typeColor} size="sm" variant="subtle"> <Badge colorScheme={typeColor} size="sm" variant="subtle">
{getAlertTypeLabel(alert.alert_type)} {getAlertTypeLabel(alert.alert_type)}
</Badge> </Badge>
{/* V2: 持续确认率 */}
{isV2 && alert.confirm_ratio !== undefined && (
<ConfirmRatioIndicator ratio={alert.confirm_ratio} />
)}
</HStack> </HStack>
{/* V2: Z-Score 指标行 */}
{isV2 && (alert.alpha_zscore !== undefined || alert.amt_zscore !== undefined) && (
<HStack spacing={3} mt={1}>
<ZScoreIndicator value={alert.alpha_zscore} label="α" tooltip={`Alpha Z-Score: ${alert.alpha_zscore?.toFixed(2)}σ(相对于历史同时段)`} />
<ZScoreIndicator value={alert.amt_zscore} label="量" tooltip={`成交额 Z-Score: ${alert.amt_zscore?.toFixed(2)}σ`} />
</HStack>
)}
</VStack> </VStack>
{/* 右侧:分数和关键指标 */} {/* 右侧:分数和关键指标 */}
@@ -104,12 +185,29 @@ const AlertItem = ({ alert, onClick, isSelected }) => {
</Text> </Text>
)} )}
{/* 涨停数量 */} {/* V2: 动量指标 */}
{alert.limit_up_count > 0 && ( {isV2 && alert.momentum_5m !== undefined && Math.abs(alert.momentum_5m) > 0.3 && (
<HStack spacing={1}>
<Icon
as={alert.momentum_5m > 0 ? FaArrowUp : FaArrowDown}
color={alert.momentum_5m > 0 ? 'red.400' : 'green.400'}
boxSize={3}
/>
<Text fontSize="xs" color={alert.momentum_5m > 0 ? 'red.400' : 'green.400'}>
动量 {alert.momentum_5m > 0 ? '+' : ''}{alert.momentum_5m.toFixed(2)}
</Text>
</HStack>
)}
{/* 涨停数量 / 涨停比例 */}
{(alert.limit_up_count > 0 || (alert.limit_up_ratio > 0.05 && isV2)) && (
<HStack spacing={1}> <HStack spacing={1}>
<Icon as={FaFire} color="orange.500" boxSize={3} /> <Icon as={FaFire} color="orange.500" boxSize={3} />
<Text fontSize="xs" color="orange.500"> <Text fontSize="xs" color="orange.500">
涨停 {alert.limit_up_count} {alert.limit_up_count > 0
? `涨停 ${alert.limit_up_count}`
: `涨停 ${Math.round(alert.limit_up_ratio * 100)}%`
}
</Text> </Text>
</HStack> </HStack>
)} )}

View File

@@ -54,6 +54,7 @@ import { FaChartLine, FaFire, FaRocket, FaBrain, FaCalendarAlt, FaChevronRight,
import ConceptStocksModal from '@components/ConceptStocksModal'; import ConceptStocksModal from '@components/ConceptStocksModal';
import TradeDatePicker from '@components/TradeDatePicker'; import TradeDatePicker from '@components/TradeDatePicker';
import HotspotOverview from './components/HotspotOverview'; import HotspotOverview from './components/HotspotOverview';
import FlexScreen from './components/FlexScreen';
import { BsGraphUp, BsLightningFill } from 'react-icons/bs'; import { BsGraphUp, BsLightningFill } from 'react-icons/bs';
import * as echarts from 'echarts'; import * as echarts from 'echarts';
import { logger } from '../../utils/logger'; import { logger } from '../../utils/logger';
@@ -846,6 +847,11 @@ const StockOverview = () => {
<HotspotOverview selectedDate={selectedDate} /> <HotspotOverview selectedDate={selectedDate} />
</Box> </Box>
{/* 灵活屏 - 实时行情监控 */}
<Box mb={10}>
<FlexScreen />
</Box>
{/* 今日热门概念 */} {/* 今日热门概念 */}
<Box mb={10}> <Box mb={10}>
<Flex align="center" mb={6}> <Flex align="center" mb={6}>