update pay ui
This commit is contained in:
198
app.py
198
app.py
@@ -12536,113 +12536,113 @@ def get_hotspot_overview():
|
|||||||
'change_pct': change_pct
|
'change_pct': change_pct
|
||||||
})
|
})
|
||||||
|
|
||||||
# 2. 获取概念异动数据(从 concept_anomaly_hybrid 表)
|
# 2. 获取概念异动数据(优先从 V2 表,fallback 到旧表)
|
||||||
alerts = []
|
alerts = []
|
||||||
|
use_v2 = False
|
||||||
# 首先确保表存在(使用 begin() 来自动提交)
|
|
||||||
try:
|
|
||||||
with engine.begin() as conn:
|
|
||||||
conn.execute(text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS concept_anomaly_hybrid (
|
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
|
||||||
concept_id VARCHAR(64) NOT NULL,
|
|
||||||
alert_time DATETIME NOT NULL,
|
|
||||||
trade_date DATE NOT NULL,
|
|
||||||
alert_type VARCHAR(32) NOT NULL,
|
|
||||||
final_score FLOAT NOT NULL,
|
|
||||||
rule_score FLOAT NOT NULL,
|
|
||||||
ml_score FLOAT NOT NULL,
|
|
||||||
trigger_reason VARCHAR(64),
|
|
||||||
alpha FLOAT,
|
|
||||||
alpha_delta FLOAT,
|
|
||||||
amt_ratio FLOAT,
|
|
||||||
amt_delta FLOAT,
|
|
||||||
rank_pct FLOAT,
|
|
||||||
limit_up_ratio FLOAT,
|
|
||||||
stock_count INT,
|
|
||||||
total_amt FLOAT,
|
|
||||||
triggered_rules JSON,
|
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
|
|
||||||
INDEX idx_trade_date (trade_date),
|
|
||||||
INDEX idx_concept_id (concept_id),
|
|
||||||
INDEX idx_final_score (final_score),
|
|
||||||
INDEX idx_alert_type (alert_type)
|
|
||||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动检测结果(融合版)'
|
|
||||||
"""))
|
|
||||||
except Exception as create_err:
|
|
||||||
logger.debug(f"创建表检查: {create_err}")
|
|
||||||
|
|
||||||
with engine.connect() as conn:
|
with engine.connect() as conn:
|
||||||
# 查询 concept_anomaly_hybrid 表
|
# 尝试查询 V2 表(时间片对齐 + 持续确认版本)
|
||||||
alert_result = conn.execute(text("""
|
try:
|
||||||
SELECT
|
v2_result = conn.execute(text("""
|
||||||
a.concept_id,
|
SELECT
|
||||||
a.alert_time,
|
concept_id, alert_time, trade_date, alert_type,
|
||||||
a.trade_date,
|
final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
|
||||||
a.alert_type,
|
alpha, alpha_zscore, amt_zscore, rank_zscore,
|
||||||
a.final_score,
|
momentum_3m, momentum_5m, limit_up_ratio, triggered_rules
|
||||||
a.rule_score,
|
FROM concept_anomaly_v2
|
||||||
a.ml_score,
|
WHERE trade_date = :trade_date
|
||||||
a.trigger_reason,
|
ORDER BY alert_time
|
||||||
a.alpha,
|
"""), {'trade_date': trade_date})
|
||||||
a.alpha_delta,
|
v2_rows = v2_result.fetchall()
|
||||||
a.amt_ratio,
|
if v2_rows:
|
||||||
a.amt_delta,
|
use_v2 = True
|
||||||
a.rank_pct,
|
for row in v2_rows:
|
||||||
a.limit_up_ratio,
|
triggered_rules = None
|
||||||
a.stock_count,
|
if row[16]:
|
||||||
a.total_amt,
|
try:
|
||||||
a.triggered_rules
|
triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16]
|
||||||
FROM concept_anomaly_hybrid a
|
except:
|
||||||
WHERE a.trade_date = :trade_date
|
pass
|
||||||
ORDER BY a.alert_time
|
|
||||||
"""), {'trade_date': trade_date})
|
|
||||||
|
|
||||||
# 获取概念名称映射(从 ES 或缓存)
|
alerts.append({
|
||||||
concept_names = {}
|
'concept_id': row[0],
|
||||||
|
'concept_name': row[0], # 后面会填充
|
||||||
|
'time': row[1].strftime('%H:%M') if row[1] else None,
|
||||||
|
'timestamp': row[1].isoformat() if row[1] else None,
|
||||||
|
'alert_type': row[3],
|
||||||
|
'final_score': float(row[4]) if row[4] else None,
|
||||||
|
'rule_score': float(row[5]) if row[5] else None,
|
||||||
|
'ml_score': float(row[6]) if row[6] else None,
|
||||||
|
'trigger_reason': row[7],
|
||||||
|
# V2 新增字段
|
||||||
|
'confirm_ratio': float(row[8]) if row[8] else None,
|
||||||
|
'alpha': float(row[9]) if row[9] else None,
|
||||||
|
'alpha_zscore': float(row[10]) if row[10] else None,
|
||||||
|
'amt_zscore': float(row[11]) if row[11] else None,
|
||||||
|
'rank_zscore': float(row[12]) if row[12] else None,
|
||||||
|
'momentum_3m': float(row[13]) if row[13] else None,
|
||||||
|
'momentum_5m': float(row[14]) if row[14] else None,
|
||||||
|
'limit_up_ratio': float(row[15]) if row[15] else 0,
|
||||||
|
'triggered_rules': triggered_rules,
|
||||||
|
# 兼容字段
|
||||||
|
'importance_score': float(row[4]) / 100 if row[4] else None,
|
||||||
|
'is_v2': True,
|
||||||
|
})
|
||||||
|
except Exception as v2_err:
|
||||||
|
logger.debug(f"V2 表查询失败,使用旧表: {v2_err}")
|
||||||
|
|
||||||
for row in alert_result:
|
# Fallback: 查询旧表
|
||||||
concept_id = row[0]
|
if not use_v2:
|
||||||
alert_time = row[1]
|
try:
|
||||||
triggered_rules = None
|
alert_result = conn.execute(text("""
|
||||||
if row[16]:
|
SELECT
|
||||||
try:
|
a.concept_id, a.alert_time, a.trade_date, a.alert_type,
|
||||||
triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16]
|
a.final_score, a.rule_score, a.ml_score, a.trigger_reason,
|
||||||
except:
|
a.alpha, a.alpha_delta, a.amt_ratio, a.amt_delta,
|
||||||
pass
|
a.rank_pct, a.limit_up_ratio, a.stock_count, a.total_amt,
|
||||||
|
a.triggered_rules
|
||||||
|
FROM concept_anomaly_hybrid a
|
||||||
|
WHERE a.trade_date = :trade_date
|
||||||
|
ORDER BY a.alert_time
|
||||||
|
"""), {'trade_date': trade_date})
|
||||||
|
|
||||||
# 获取概念名称(优先从缓存,否则使用 concept_id)
|
for row in alert_result:
|
||||||
concept_name = concept_names.get(concept_id) or concept_id
|
triggered_rules = None
|
||||||
|
if row[16]:
|
||||||
|
try:
|
||||||
|
triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
# 计算涨停数量(从 limit_up_ratio 和 stock_count 估算)
|
limit_up_ratio = float(row[13]) if row[13] else 0
|
||||||
limit_up_ratio = float(row[13]) if row[13] else 0
|
stock_count = int(row[14]) if row[14] else 0
|
||||||
stock_count = int(row[14]) if row[14] else 0
|
limit_up_count = int(limit_up_ratio * stock_count) if stock_count > 0 else 0
|
||||||
limit_up_count = int(limit_up_ratio * stock_count) if stock_count > 0 else 0
|
|
||||||
|
|
||||||
alerts.append({
|
alerts.append({
|
||||||
'concept_id': concept_id,
|
'concept_id': row[0],
|
||||||
'concept_name': concept_name,
|
'concept_name': row[0],
|
||||||
'time': alert_time.strftime('%H:%M') if alert_time else None,
|
'time': row[1].strftime('%H:%M') if row[1] else None,
|
||||||
'timestamp': alert_time.isoformat() if alert_time else None,
|
'timestamp': row[1].isoformat() if row[1] else None,
|
||||||
'alert_type': row[3],
|
'alert_type': row[3],
|
||||||
'final_score': float(row[4]) if row[4] else None,
|
'final_score': float(row[4]) if row[4] else None,
|
||||||
'rule_score': float(row[5]) if row[5] else None,
|
'rule_score': float(row[5]) if row[5] else None,
|
||||||
'ml_score': float(row[6]) if row[6] else None,
|
'ml_score': float(row[6]) if row[6] else None,
|
||||||
'trigger_reason': row[7],
|
'trigger_reason': row[7],
|
||||||
'alpha': float(row[8]) if row[8] else None,
|
'alpha': float(row[8]) if row[8] else None,
|
||||||
'alpha_delta': float(row[9]) if row[9] else None,
|
'alpha_delta': float(row[9]) if row[9] else None,
|
||||||
'amt_ratio': float(row[10]) if row[10] else None,
|
'amt_ratio': float(row[10]) if row[10] else None,
|
||||||
'amt_delta': float(row[11]) if row[11] else None,
|
'amt_delta': float(row[11]) if row[11] else None,
|
||||||
'rank_pct': float(row[12]) if row[12] else None,
|
'rank_pct': float(row[12]) if row[12] else None,
|
||||||
'limit_up_ratio': limit_up_ratio,
|
'limit_up_ratio': limit_up_ratio,
|
||||||
'limit_up_count': limit_up_count,
|
'limit_up_count': limit_up_count,
|
||||||
'stock_count': stock_count,
|
'stock_count': stock_count,
|
||||||
'total_amt': float(row[15]) if row[15] else None,
|
'total_amt': float(row[15]) if row[15] else None,
|
||||||
'triggered_rules': triggered_rules,
|
'triggered_rules': triggered_rules,
|
||||||
# 兼容旧字段
|
'importance_score': float(row[4]) / 100 if row[4] else None,
|
||||||
'importance_score': float(row[4]) / 100 if row[4] else None,
|
'is_v2': False,
|
||||||
})
|
})
|
||||||
|
except Exception as old_err:
|
||||||
|
logger.debug(f"旧表查询也失败: {old_err}")
|
||||||
|
|
||||||
# 尝试批量获取概念名称
|
# 尝试批量获取概念名称
|
||||||
if alerts:
|
if alerts:
|
||||||
|
|||||||
294
ml/backtest_v2.py
Normal file
294
ml/backtest_v2.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
V2 回测脚本 - 验证时间片对齐 + 持续性确认的效果
|
||||||
|
|
||||||
|
回测指标:
|
||||||
|
1. 准确率:异动后 N 分钟内 alpha 是否继续上涨/下跌
|
||||||
|
2. 虚警率:多少异动是噪音
|
||||||
|
3. 持续性:平均异动持续时长
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from ml.detector_v2 import AnomalyDetectorV2, CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 配置 ====================
|
||||||
|
|
||||||
|
MYSQL_ENGINE = create_engine(
|
||||||
|
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
||||||
|
echo=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 回测评估 ====================
|
||||||
|
|
||||||
|
def evaluate_alerts(
|
||||||
|
alerts: List[Dict],
|
||||||
|
raw_data: pd.DataFrame,
|
||||||
|
lookahead_minutes: int = 10
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
评估异动质量
|
||||||
|
|
||||||
|
指标:
|
||||||
|
1. 方向正确率:异动后 N 分钟 alpha 方向是否一致
|
||||||
|
2. 持续率:异动后 N 分钟内有多少时刻 alpha 保持同向
|
||||||
|
3. 峰值收益:异动后 N 分钟内的最大 alpha
|
||||||
|
"""
|
||||||
|
if not alerts:
|
||||||
|
return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0, 'total_alerts': 0}
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for alert in alerts:
|
||||||
|
concept_id = alert['concept_id']
|
||||||
|
alert_time = alert['alert_time']
|
||||||
|
alert_alpha = alert['alpha']
|
||||||
|
is_up = alert_alpha > 0
|
||||||
|
|
||||||
|
# 获取该概念在异动后的数据
|
||||||
|
concept_data = raw_data[
|
||||||
|
(raw_data['concept_id'] == concept_id) &
|
||||||
|
(raw_data['timestamp'] > alert_time)
|
||||||
|
].head(lookahead_minutes)
|
||||||
|
|
||||||
|
if len(concept_data) < 3:
|
||||||
|
continue
|
||||||
|
|
||||||
|
future_alphas = concept_data['alpha'].values
|
||||||
|
|
||||||
|
# 方向正确:未来 alpha 平均值与当前同向
|
||||||
|
avg_future_alpha = np.mean(future_alphas)
|
||||||
|
direction_correct = (is_up and avg_future_alpha > 0) or (not is_up and avg_future_alpha < 0)
|
||||||
|
|
||||||
|
# 持续率:有多少时刻保持同向
|
||||||
|
if is_up:
|
||||||
|
sustained_count = sum(1 for a in future_alphas if a > 0)
|
||||||
|
else:
|
||||||
|
sustained_count = sum(1 for a in future_alphas if a < 0)
|
||||||
|
sustained_rate = sustained_count / len(future_alphas)
|
||||||
|
|
||||||
|
# 峰值收益
|
||||||
|
if is_up:
|
||||||
|
peak = max(future_alphas)
|
||||||
|
else:
|
||||||
|
peak = min(future_alphas)
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
'direction_correct': direction_correct,
|
||||||
|
'sustained_rate': sustained_rate,
|
||||||
|
'peak': peak,
|
||||||
|
'alert_alpha': alert_alpha,
|
||||||
|
})
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0, 'total_alerts': 0}
|
||||||
|
|
||||||
|
return {
|
||||||
|
'accuracy': np.mean([r['direction_correct'] for r in results]),
|
||||||
|
'sustained_rate': np.mean([r['sustained_rate'] for r in results]),
|
||||||
|
'avg_peak': np.mean([abs(r['peak']) for r in results]),
|
||||||
|
'total_alerts': len(alerts),
|
||||||
|
'evaluated_alerts': len(results),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
|
||||||
|
"""保存异动到 MySQL"""
|
||||||
|
if not alerts or dry_run:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 确保表存在
|
||||||
|
with MYSQL_ENGINE.begin() as conn:
|
||||||
|
conn.execute(text("""
|
||||||
|
CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
|
||||||
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
concept_id VARCHAR(64) NOT NULL,
|
||||||
|
alert_time DATETIME NOT NULL,
|
||||||
|
trade_date DATE NOT NULL,
|
||||||
|
alert_type VARCHAR(32) NOT NULL,
|
||||||
|
final_score FLOAT NOT NULL,
|
||||||
|
rule_score FLOAT NOT NULL,
|
||||||
|
ml_score FLOAT NOT NULL,
|
||||||
|
trigger_reason VARCHAR(128),
|
||||||
|
confirm_ratio FLOAT,
|
||||||
|
alpha FLOAT,
|
||||||
|
alpha_zscore FLOAT,
|
||||||
|
amt_zscore FLOAT,
|
||||||
|
rank_zscore FLOAT,
|
||||||
|
momentum_3m FLOAT,
|
||||||
|
momentum_5m FLOAT,
|
||||||
|
limit_up_ratio FLOAT,
|
||||||
|
triggered_rules JSON,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
|
||||||
|
INDEX idx_trade_date (trade_date),
|
||||||
|
INDEX idx_final_score (final_score)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动 V2(时间片对齐+持续确认)'
|
||||||
|
"""))
|
||||||
|
|
||||||
|
# 插入数据
|
||||||
|
saved = 0
|
||||||
|
with MYSQL_ENGINE.begin() as conn:
|
||||||
|
for alert in alerts:
|
||||||
|
try:
|
||||||
|
conn.execute(text("""
|
||||||
|
INSERT IGNORE INTO concept_anomaly_v2
|
||||||
|
(concept_id, alert_time, trade_date, alert_type,
|
||||||
|
final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
|
||||||
|
alpha, alpha_zscore, amt_zscore, rank_zscore,
|
||||||
|
momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
|
||||||
|
VALUES
|
||||||
|
(:concept_id, :alert_time, :trade_date, :alert_type,
|
||||||
|
:final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
|
||||||
|
:alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
|
||||||
|
:momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
|
||||||
|
"""), {
|
||||||
|
'concept_id': alert['concept_id'],
|
||||||
|
'alert_time': alert['alert_time'],
|
||||||
|
'trade_date': alert['trade_date'],
|
||||||
|
'alert_type': alert['alert_type'],
|
||||||
|
'final_score': alert['final_score'],
|
||||||
|
'rule_score': alert['rule_score'],
|
||||||
|
'ml_score': alert['ml_score'],
|
||||||
|
'trigger_reason': alert['trigger_reason'],
|
||||||
|
'confirm_ratio': alert.get('confirm_ratio', 0),
|
||||||
|
'alpha': alert['alpha'],
|
||||||
|
'alpha_zscore': alert.get('alpha_zscore', 0),
|
||||||
|
'amt_zscore': alert.get('amt_zscore', 0),
|
||||||
|
'rank_zscore': alert.get('rank_zscore', 0),
|
||||||
|
'momentum_3m': alert.get('momentum_3m', 0),
|
||||||
|
'momentum_5m': alert.get('momentum_5m', 0),
|
||||||
|
'limit_up_ratio': alert.get('limit_up_ratio', 0),
|
||||||
|
'triggered_rules': json.dumps(alert.get('triggered_rules', [])),
|
||||||
|
})
|
||||||
|
saved += 1
|
||||||
|
except Exception as e:
|
||||||
|
print(f"保存失败: {e}")
|
||||||
|
|
||||||
|
return saved
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 主函数 ====================
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='V2 回测')
|
||||||
|
parser.add_argument('--start', type=str, required=True, help='开始日期')
|
||||||
|
parser.add_argument('--end', type=str, default=None, help='结束日期')
|
||||||
|
parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
|
||||||
|
parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
|
||||||
|
parser.add_argument('--save', action='store_true', help='保存到数据库')
|
||||||
|
parser.add_argument('--lookahead', type=int, default=10, help='评估前瞻时间(分钟)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
end_date = args.end or args.start
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("V2 回测 - 时间片对齐 + 持续性确认")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"日期范围: {args.start} ~ {end_date}")
|
||||||
|
print(f"模型目录: {args.model_dir}")
|
||||||
|
print(f"评估前瞻: {args.lookahead} 分钟")
|
||||||
|
|
||||||
|
# 初始化检测器
|
||||||
|
detector = AnomalyDetectorV2(
|
||||||
|
model_dir=args.model_dir,
|
||||||
|
baseline_dir=args.baseline_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取交易日
|
||||||
|
from prepare_data_v2 import get_trading_days
|
||||||
|
trading_days = get_trading_days(args.start, end_date)
|
||||||
|
|
||||||
|
if not trading_days:
|
||||||
|
print("无交易日")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"交易日数: {len(trading_days)}")
|
||||||
|
|
||||||
|
# 回测统计
|
||||||
|
total_stats = {
|
||||||
|
'total_alerts': 0,
|
||||||
|
'accuracy_sum': 0,
|
||||||
|
'sustained_sum': 0,
|
||||||
|
'peak_sum': 0,
|
||||||
|
'day_count': 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
all_alerts = []
|
||||||
|
|
||||||
|
for trade_date in tqdm(trading_days, desc="回测进度"):
|
||||||
|
# 检测异动
|
||||||
|
alerts = detector.detect(trade_date)
|
||||||
|
|
||||||
|
if not alerts:
|
||||||
|
continue
|
||||||
|
|
||||||
|
all_alerts.extend(alerts)
|
||||||
|
|
||||||
|
# 评估
|
||||||
|
raw_data = detector._compute_raw_features(trade_date)
|
||||||
|
if raw_data.empty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
stats = evaluate_alerts(alerts, raw_data, args.lookahead)
|
||||||
|
|
||||||
|
if stats['evaluated_alerts'] > 0:
|
||||||
|
total_stats['total_alerts'] += stats['total_alerts']
|
||||||
|
total_stats['accuracy_sum'] += stats['accuracy'] * stats['evaluated_alerts']
|
||||||
|
total_stats['sustained_sum'] += stats['sustained_rate'] * stats['evaluated_alerts']
|
||||||
|
total_stats['peak_sum'] += stats['avg_peak'] * stats['evaluated_alerts']
|
||||||
|
total_stats['day_count'] += 1
|
||||||
|
|
||||||
|
print(f"\n[{trade_date}] 异动: {stats['total_alerts']}, "
|
||||||
|
f"准确率: {stats['accuracy']:.1%}, "
|
||||||
|
f"持续率: {stats['sustained_rate']:.1%}, "
|
||||||
|
f"峰值: {stats['avg_peak']:.2f}%")
|
||||||
|
|
||||||
|
# 汇总
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("回测汇总")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if total_stats['total_alerts'] > 0:
|
||||||
|
avg_accuracy = total_stats['accuracy_sum'] / total_stats['total_alerts']
|
||||||
|
avg_sustained = total_stats['sustained_sum'] / total_stats['total_alerts']
|
||||||
|
avg_peak = total_stats['peak_sum'] / total_stats['total_alerts']
|
||||||
|
|
||||||
|
print(f"总异动数: {total_stats['total_alerts']}")
|
||||||
|
print(f"回测天数: {total_stats['day_count']}")
|
||||||
|
print(f"平均每天: {total_stats['total_alerts'] / max(1, total_stats['day_count']):.1f} 个")
|
||||||
|
print(f"方向准确率: {avg_accuracy:.1%}")
|
||||||
|
print(f"持续率: {avg_sustained:.1%}")
|
||||||
|
print(f"平均峰值: {avg_peak:.2f}%")
|
||||||
|
else:
|
||||||
|
print("无异动检测结果")
|
||||||
|
|
||||||
|
# 保存
|
||||||
|
if args.save and all_alerts:
|
||||||
|
print(f"\n保存 {len(all_alerts)} 条异动到数据库...")
|
||||||
|
saved = save_alerts_to_mysql(all_alerts)
|
||||||
|
print(f"保存完成: {saved} 条")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
31
ml/checkpoints_v2/config.json
Normal file
31
ml/checkpoints_v2/config.json
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"seq_len": 10,
|
||||||
|
"stride": 2,
|
||||||
|
"train_end_date": "2025-06-30",
|
||||||
|
"val_end_date": "2025-09-30",
|
||||||
|
"features": [
|
||||||
|
"alpha_zscore",
|
||||||
|
"amt_zscore",
|
||||||
|
"rank_zscore",
|
||||||
|
"momentum_3m",
|
||||||
|
"momentum_5m",
|
||||||
|
"limit_up_ratio"
|
||||||
|
],
|
||||||
|
"batch_size": 32768,
|
||||||
|
"epochs": 150,
|
||||||
|
"learning_rate": 0.0006,
|
||||||
|
"weight_decay": 1e-05,
|
||||||
|
"gradient_clip": 1.0,
|
||||||
|
"patience": 15,
|
||||||
|
"min_delta": 1e-06,
|
||||||
|
"model": {
|
||||||
|
"n_features": 6,
|
||||||
|
"hidden_dim": 32,
|
||||||
|
"latent_dim": 4,
|
||||||
|
"num_layers": 1,
|
||||||
|
"dropout": 0.2,
|
||||||
|
"bidirectional": true
|
||||||
|
},
|
||||||
|
"clip_value": 5.0,
|
||||||
|
"threshold_percentiles": [90, 95, 99]
|
||||||
|
}
|
||||||
8
ml/checkpoints_v2/thresholds.json
Normal file
8
ml/checkpoints_v2/thresholds.json
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"p90": 0.15,
|
||||||
|
"p95": 0.25,
|
||||||
|
"p99": 0.50,
|
||||||
|
"mean": 0.08,
|
||||||
|
"std": 0.12,
|
||||||
|
"median": 0.06
|
||||||
|
}
|
||||||
716
ml/detector_v2.py
Normal file
716
ml/detector_v2.py
Normal file
@@ -0,0 +1,716 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
异动检测器 V2 - 基于时间片对齐 + 持续性确认
|
||||||
|
|
||||||
|
核心改进:
|
||||||
|
1. Z-Score 特征:相对于同时间片历史的偏离
|
||||||
|
2. 短序列 LSTM:10分钟序列,开盘即可用
|
||||||
|
3. 持续性确认:5分钟窗口内60%时刻超标才确认为异动
|
||||||
|
|
||||||
|
检测流程:
|
||||||
|
1. 计算当前时刻的 Z-Score(对比同时间片历史基线)
|
||||||
|
2. 构建最近10分钟的 Z-Score 序列
|
||||||
|
3. LSTM 计算重构误差(ML分数)
|
||||||
|
4. 规则评分(基于 Z-Score 的规则)
|
||||||
|
5. 滑动窗口确认:最近5分钟内是否有足够多的时刻超标
|
||||||
|
6. 只有通过持续性确认的才输出为异动
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from elasticsearch import Elasticsearch
|
||||||
|
from clickhouse_driver import Client
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from ml.model import TransformerAutoencoder
|
||||||
|
|
||||||
|
# ==================== 配置 ====================
|
||||||
|
|
||||||
|
MYSQL_ENGINE = create_engine(
|
||||||
|
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
||||||
|
echo=False
|
||||||
|
)
|
||||||
|
|
||||||
|
ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
|
||||||
|
ES_INDEX = 'concept_library_v3'
|
||||||
|
|
||||||
|
CLICKHOUSE_CONFIG = {
|
||||||
|
'host': '127.0.0.1',
|
||||||
|
'port': 9000,
|
||||||
|
'user': 'default',
|
||||||
|
'password': 'Zzl33818!',
|
||||||
|
'database': 'stock'
|
||||||
|
}
|
||||||
|
|
||||||
|
REFERENCE_INDEX = '000001.SH'
|
||||||
|
|
||||||
|
# 检测配置
|
||||||
|
CONFIG = {
|
||||||
|
# 序列配置
|
||||||
|
'seq_len': 10, # LSTM 序列长度(分钟)
|
||||||
|
|
||||||
|
# 持续性确认配置(核心!)
|
||||||
|
'confirm_window': 5, # 确认窗口(分钟)
|
||||||
|
'confirm_ratio': 0.6, # 确认比例(60%时刻需要超标)
|
||||||
|
|
||||||
|
# Z-Score 阈值
|
||||||
|
'alpha_zscore_threshold': 2.0, # Alpha Z-Score 阈值
|
||||||
|
'amt_zscore_threshold': 2.5, # 成交额 Z-Score 阈值
|
||||||
|
|
||||||
|
# 融合权重
|
||||||
|
'rule_weight': 0.5,
|
||||||
|
'ml_weight': 0.5,
|
||||||
|
|
||||||
|
# 触发阈值
|
||||||
|
'rule_trigger': 60,
|
||||||
|
'ml_trigger': 70,
|
||||||
|
'fusion_trigger': 50,
|
||||||
|
|
||||||
|
# 冷却期
|
||||||
|
'cooldown_minutes': 10,
|
||||||
|
'max_alerts_per_minute': 15,
|
||||||
|
|
||||||
|
# Z-Score 截断
|
||||||
|
'zscore_clip': 5.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# V2 特征列表
|
||||||
|
FEATURES_V2 = [
|
||||||
|
'alpha_zscore', 'amt_zscore', 'rank_zscore',
|
||||||
|
'momentum_3m', 'momentum_5m', 'limit_up_ratio'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 工具函数 ====================
|
||||||
|
|
||||||
|
def get_ch_client():
|
||||||
|
return Client(**CLICKHOUSE_CONFIG)
|
||||||
|
|
||||||
|
|
||||||
|
def code_to_ch_format(code: str) -> str:
|
||||||
|
if not code or len(code) != 6 or not code.isdigit():
|
||||||
|
return None
|
||||||
|
if code.startswith('6'):
|
||||||
|
return f"{code}.SH"
|
||||||
|
elif code.startswith('0') or code.startswith('3'):
|
||||||
|
return f"{code}.SZ"
|
||||||
|
else:
|
||||||
|
return f"{code}.BJ"
|
||||||
|
|
||||||
|
|
||||||
|
def time_to_slot(ts) -> str:
|
||||||
|
"""时间戳转时间片(HH:MM)"""
|
||||||
|
if isinstance(ts, str):
|
||||||
|
return ts
|
||||||
|
return ts.strftime('%H:%M')
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 基线加载 ====================
|
||||||
|
|
||||||
|
def load_baselines(baseline_dir: str = 'ml/data_v2/baselines') -> Dict[str, pd.DataFrame]:
|
||||||
|
"""加载时间片基线"""
|
||||||
|
baseline_file = os.path.join(baseline_dir, 'baselines.pkl')
|
||||||
|
if os.path.exists(baseline_file):
|
||||||
|
with open(baseline_file, 'rb') as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 规则评分(基于 Z-Score)====================
|
||||||
|
|
||||||
|
def score_rules_zscore(row: Dict) -> Tuple[float, List[str]]:
|
||||||
|
"""
|
||||||
|
基于 Z-Score 的规则评分
|
||||||
|
|
||||||
|
设计思路:Z-Score 已经标准化,直接用阈值判断
|
||||||
|
"""
|
||||||
|
score = 0.0
|
||||||
|
triggered = []
|
||||||
|
|
||||||
|
alpha_zscore = row.get('alpha_zscore', 0)
|
||||||
|
amt_zscore = row.get('amt_zscore', 0)
|
||||||
|
rank_zscore = row.get('rank_zscore', 0)
|
||||||
|
momentum_3m = row.get('momentum_3m', 0)
|
||||||
|
momentum_5m = row.get('momentum_5m', 0)
|
||||||
|
limit_up_ratio = row.get('limit_up_ratio', 0)
|
||||||
|
|
||||||
|
alpha_zscore_abs = abs(alpha_zscore)
|
||||||
|
amt_zscore_abs = abs(amt_zscore)
|
||||||
|
|
||||||
|
# ========== Alpha Z-Score 规则 ==========
|
||||||
|
if alpha_zscore_abs >= 4.0:
|
||||||
|
score += 25
|
||||||
|
triggered.append('alpha_zscore_extreme')
|
||||||
|
elif alpha_zscore_abs >= 3.0:
|
||||||
|
score += 18
|
||||||
|
triggered.append('alpha_zscore_strong')
|
||||||
|
elif alpha_zscore_abs >= 2.0:
|
||||||
|
score += 10
|
||||||
|
triggered.append('alpha_zscore_moderate')
|
||||||
|
|
||||||
|
# ========== 成交额 Z-Score 规则 ==========
|
||||||
|
if amt_zscore >= 4.0:
|
||||||
|
score += 20
|
||||||
|
triggered.append('amt_zscore_extreme')
|
||||||
|
elif amt_zscore >= 3.0:
|
||||||
|
score += 12
|
||||||
|
triggered.append('amt_zscore_strong')
|
||||||
|
elif amt_zscore >= 2.0:
|
||||||
|
score += 6
|
||||||
|
triggered.append('amt_zscore_moderate')
|
||||||
|
|
||||||
|
# ========== 排名 Z-Score 规则 ==========
|
||||||
|
if abs(rank_zscore) >= 3.0:
|
||||||
|
score += 15
|
||||||
|
triggered.append('rank_zscore_extreme')
|
||||||
|
elif abs(rank_zscore) >= 2.0:
|
||||||
|
score += 8
|
||||||
|
triggered.append('rank_zscore_strong')
|
||||||
|
|
||||||
|
# ========== 动量规则 ==========
|
||||||
|
if momentum_3m >= 1.0:
|
||||||
|
score += 12
|
||||||
|
triggered.append('momentum_3m_strong')
|
||||||
|
elif momentum_3m >= 0.5:
|
||||||
|
score += 6
|
||||||
|
triggered.append('momentum_3m_moderate')
|
||||||
|
|
||||||
|
if momentum_5m >= 1.5:
|
||||||
|
score += 10
|
||||||
|
triggered.append('momentum_5m_strong')
|
||||||
|
|
||||||
|
# ========== 涨停比例规则 ==========
|
||||||
|
if limit_up_ratio >= 0.3:
|
||||||
|
score += 20
|
||||||
|
triggered.append('limit_up_extreme')
|
||||||
|
elif limit_up_ratio >= 0.15:
|
||||||
|
score += 12
|
||||||
|
triggered.append('limit_up_strong')
|
||||||
|
elif limit_up_ratio >= 0.08:
|
||||||
|
score += 5
|
||||||
|
triggered.append('limit_up_moderate')
|
||||||
|
|
||||||
|
# ========== 组合规则 ==========
|
||||||
|
# Alpha Z-Score + 成交额放大
|
||||||
|
if alpha_zscore_abs >= 2.0 and amt_zscore >= 2.0:
|
||||||
|
score += 15
|
||||||
|
triggered.append('combo_alpha_amt')
|
||||||
|
|
||||||
|
# Alpha Z-Score + 涨停
|
||||||
|
if alpha_zscore_abs >= 2.0 and limit_up_ratio >= 0.1:
|
||||||
|
score += 12
|
||||||
|
triggered.append('combo_alpha_limitup')
|
||||||
|
|
||||||
|
return min(score, 100), triggered
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== ML 评分器 ====================
|
||||||
|
|
||||||
|
class MLScorerV2:
|
||||||
|
"""V2 ML 评分器"""
|
||||||
|
|
||||||
|
def __init__(self, model_dir: str = 'ml/checkpoints_v2'):
|
||||||
|
self.model_dir = model_dir
|
||||||
|
self.model = None
|
||||||
|
self.thresholds = None
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
def _load_model(self):
|
||||||
|
"""加载模型和阈值"""
|
||||||
|
model_path = os.path.join(self.model_dir, 'best_model.pt')
|
||||||
|
threshold_path = os.path.join(self.model_dir, 'thresholds.json')
|
||||||
|
config_path = os.path.join(self.model_dir, 'config.json')
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
print(f"警告: 模型文件不存在: {model_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 加载配置
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
# 创建模型
|
||||||
|
model_config = config.get('model', {})
|
||||||
|
self.model = TransformerAutoencoder(**model_config)
|
||||||
|
|
||||||
|
# 加载权重
|
||||||
|
checkpoint = torch.load(model_path, map_location=self.device)
|
||||||
|
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
# 加载阈值
|
||||||
|
if os.path.exists(threshold_path):
|
||||||
|
with open(threshold_path, 'r') as f:
|
||||||
|
self.thresholds = json.load(f)
|
||||||
|
|
||||||
|
print(f"V2 模型加载完成: {model_path}")
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def score_batch(self, sequences: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
批量计算 ML 分数
|
||||||
|
|
||||||
|
返回 0-100 的分数,越高越异常
|
||||||
|
"""
|
||||||
|
if self.model is None:
|
||||||
|
return np.zeros(len(sequences))
|
||||||
|
|
||||||
|
# 转换为 tensor
|
||||||
|
x = torch.FloatTensor(sequences).to(self.device)
|
||||||
|
|
||||||
|
# 计算重构误差
|
||||||
|
errors = self.model.compute_reconstruction_error(x, reduction='none')
|
||||||
|
# 取最后一个时刻的误差
|
||||||
|
last_errors = errors[:, -1].cpu().numpy()
|
||||||
|
|
||||||
|
# 转换为 0-100 分数
|
||||||
|
if self.thresholds:
|
||||||
|
p50 = self.thresholds.get('median', 0.1)
|
||||||
|
p99 = self.thresholds.get('p99', 1.0)
|
||||||
|
|
||||||
|
# 线性映射:p50 -> 50分,p99 -> 99分
|
||||||
|
scores = 50 + (last_errors - p50) / (p99 - p50) * 49
|
||||||
|
scores = np.clip(scores, 0, 100)
|
||||||
|
else:
|
||||||
|
# 没有阈值时,简单归一化
|
||||||
|
scores = last_errors * 100
|
||||||
|
scores = np.clip(scores, 0, 100)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 实时数据管理器 ====================
|
||||||
|
|
||||||
|
class RealtimeDataManagerV2:
|
||||||
|
"""
|
||||||
|
V2 实时数据管理器
|
||||||
|
|
||||||
|
维护:
|
||||||
|
1. 每个概念的历史 Z-Score 序列(用于 LSTM 输入)
|
||||||
|
2. 每个概念的异动候选队列(用于持续性确认)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, concepts: List[dict], baselines: Dict[str, pd.DataFrame]):
|
||||||
|
self.concepts = {c['concept_id']: c for c in concepts}
|
||||||
|
self.baselines = baselines
|
||||||
|
|
||||||
|
# 概念到股票的映射
|
||||||
|
self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
|
||||||
|
|
||||||
|
# 历史 Z-Score 序列(每个概念)
|
||||||
|
# {concept_id: deque([(timestamp, features_dict), ...], maxlen=seq_len)}
|
||||||
|
self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))
|
||||||
|
|
||||||
|
# 异动候选队列(用于持续性确认)
|
||||||
|
# {concept_id: deque([(timestamp, score), ...], maxlen=confirm_window)}
|
||||||
|
self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))
|
||||||
|
|
||||||
|
# 冷却期记录
|
||||||
|
self.cooldown = {}
|
||||||
|
|
||||||
|
# 上一次更新的时间戳
|
||||||
|
self.last_timestamp = None
|
||||||
|
|
||||||
|
def compute_zscore_features(
|
||||||
|
self,
|
||||||
|
concept_id: str,
|
||||||
|
timestamp,
|
||||||
|
alpha: float,
|
||||||
|
total_amt: float,
|
||||||
|
rank_pct: float,
|
||||||
|
limit_up_ratio: float
|
||||||
|
) -> Optional[Dict]:
|
||||||
|
"""计算单个概念单个时刻的 Z-Score 特征"""
|
||||||
|
if concept_id not in self.baselines:
|
||||||
|
return None
|
||||||
|
|
||||||
|
baseline = self.baselines[concept_id]
|
||||||
|
time_slot = time_to_slot(timestamp)
|
||||||
|
|
||||||
|
# 查找对应时间片的基线
|
||||||
|
bl_row = baseline[baseline['time_slot'] == time_slot]
|
||||||
|
if bl_row.empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
bl = bl_row.iloc[0]
|
||||||
|
|
||||||
|
# 检查样本量
|
||||||
|
if bl.get('sample_count', 0) < 10:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算 Z-Score
|
||||||
|
alpha_zscore = (alpha - bl['alpha_mean']) / bl['alpha_std']
|
||||||
|
amt_zscore = (total_amt - bl['amt_mean']) / bl['amt_std']
|
||||||
|
rank_zscore = (rank_pct - bl['rank_mean']) / bl['rank_std']
|
||||||
|
|
||||||
|
# 截断
|
||||||
|
clip = CONFIG['zscore_clip']
|
||||||
|
alpha_zscore = np.clip(alpha_zscore, -clip, clip)
|
||||||
|
amt_zscore = np.clip(amt_zscore, -clip, clip)
|
||||||
|
rank_zscore = np.clip(rank_zscore, -clip, clip)
|
||||||
|
|
||||||
|
# 计算动量(需要历史)
|
||||||
|
history = self.zscore_history[concept_id]
|
||||||
|
momentum_3m = 0
|
||||||
|
momentum_5m = 0
|
||||||
|
|
||||||
|
if len(history) >= 3:
|
||||||
|
recent_alphas = [h[1]['alpha'] for h in list(history)[-3:]]
|
||||||
|
older_alphas = [h[1]['alpha'] for h in list(history)[-6:-3]] if len(history) >= 6 else [alpha]
|
||||||
|
momentum_3m = np.mean(recent_alphas) - np.mean(older_alphas)
|
||||||
|
|
||||||
|
if len(history) >= 5:
|
||||||
|
recent_alphas = [h[1]['alpha'] for h in list(history)[-5:]]
|
||||||
|
older_alphas = [h[1]['alpha'] for h in list(history)[-10:-5]] if len(history) >= 10 else [alpha]
|
||||||
|
momentum_5m = np.mean(recent_alphas) - np.mean(older_alphas)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'alpha': alpha,
|
||||||
|
'alpha_zscore': alpha_zscore,
|
||||||
|
'amt_zscore': amt_zscore,
|
||||||
|
'rank_zscore': rank_zscore,
|
||||||
|
'momentum_3m': momentum_3m,
|
||||||
|
'momentum_5m': momentum_5m,
|
||||||
|
'limit_up_ratio': limit_up_ratio,
|
||||||
|
'total_amt': total_amt,
|
||||||
|
'rank_pct': rank_pct,
|
||||||
|
}
|
||||||
|
|
||||||
|
def update(self, concept_id: str, timestamp, features: Dict):
|
||||||
|
"""更新概念的历史数据"""
|
||||||
|
self.zscore_history[concept_id].append((timestamp, features))
|
||||||
|
|
||||||
|
def get_sequence(self, concept_id: str) -> Optional[np.ndarray]:
|
||||||
|
"""获取用于 LSTM 的序列"""
|
||||||
|
history = self.zscore_history[concept_id]
|
||||||
|
|
||||||
|
if len(history) < CONFIG['seq_len']:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 提取特征
|
||||||
|
feature_list = []
|
||||||
|
for _, features in history:
|
||||||
|
feature_list.append([
|
||||||
|
features['alpha_zscore'],
|
||||||
|
features['amt_zscore'],
|
||||||
|
features['rank_zscore'],
|
||||||
|
features['momentum_3m'],
|
||||||
|
features['momentum_5m'],
|
||||||
|
features['limit_up_ratio'],
|
||||||
|
])
|
||||||
|
|
||||||
|
return np.array(feature_list)
|
||||||
|
|
||||||
|
def add_anomaly_candidate(self, concept_id: str, timestamp, score: float):
|
||||||
|
"""添加异动候选"""
|
||||||
|
self.anomaly_candidates[concept_id].append((timestamp, score))
|
||||||
|
|
||||||
|
def check_sustained_anomaly(self, concept_id: str, threshold: float) -> Tuple[bool, float]:
|
||||||
|
"""
|
||||||
|
检查是否为持续性异动
|
||||||
|
|
||||||
|
返回:(是否确认, 确认比例)
|
||||||
|
"""
|
||||||
|
candidates = self.anomaly_candidates[concept_id]
|
||||||
|
|
||||||
|
if len(candidates) < CONFIG['confirm_window']:
|
||||||
|
return False, 0.0
|
||||||
|
|
||||||
|
# 统计超过阈值的时刻数量
|
||||||
|
exceed_count = sum(1 for _, score in candidates if score >= threshold)
|
||||||
|
ratio = exceed_count / len(candidates)
|
||||||
|
|
||||||
|
return ratio >= CONFIG['confirm_ratio'], ratio
|
||||||
|
|
||||||
|
def check_cooldown(self, concept_id: str, timestamp) -> bool:
|
||||||
|
"""检查是否在冷却期"""
|
||||||
|
if concept_id not in self.cooldown:
|
||||||
|
return False
|
||||||
|
|
||||||
|
last_alert = self.cooldown[concept_id]
|
||||||
|
try:
|
||||||
|
diff = (timestamp - last_alert).total_seconds() / 60
|
||||||
|
return diff < CONFIG['cooldown_minutes']
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def set_cooldown(self, concept_id: str, timestamp):
|
||||||
|
"""设置冷却期"""
|
||||||
|
self.cooldown[concept_id] = timestamp
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 异动检测器 V2 ====================
|
||||||
|
|
||||||
|
class AnomalyDetectorV2:
|
||||||
|
"""
|
||||||
|
V2 异动检测器
|
||||||
|
|
||||||
|
核心流程:
|
||||||
|
1. 获取实时数据
|
||||||
|
2. 计算 Z-Score 特征
|
||||||
|
3. 规则评分 + ML 评分
|
||||||
|
4. 持续性确认
|
||||||
|
5. 输出异动
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_dir: str = 'ml/checkpoints_v2',
|
||||||
|
baseline_dir: str = 'ml/data_v2/baselines'
|
||||||
|
):
|
||||||
|
# 加载概念
|
||||||
|
self.concepts = self._load_concepts()
|
||||||
|
|
||||||
|
# 加载基线
|
||||||
|
self.baselines = load_baselines(baseline_dir)
|
||||||
|
print(f"加载了 {len(self.baselines)} 个概念的基线")
|
||||||
|
|
||||||
|
# 初始化 ML 评分器
|
||||||
|
self.ml_scorer = MLScorerV2(model_dir)
|
||||||
|
|
||||||
|
# 初始化数据管理器
|
||||||
|
self.data_manager = RealtimeDataManagerV2(self.concepts, self.baselines)
|
||||||
|
|
||||||
|
# 收集所有股票
|
||||||
|
self.all_stocks = list(set(s for c in self.concepts for s in c['stocks']))
|
||||||
|
|
||||||
|
def _load_concepts(self) -> List[dict]:
|
||||||
|
"""从 ES 加载概念"""
|
||||||
|
concepts = []
|
||||||
|
query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}
|
||||||
|
|
||||||
|
resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
|
||||||
|
scroll_id = resp['_scroll_id']
|
||||||
|
hits = resp['hits']['hits']
|
||||||
|
|
||||||
|
while len(hits) > 0:
|
||||||
|
for hit in hits:
|
||||||
|
source = hit['_source']
|
||||||
|
stocks = []
|
||||||
|
if 'stocks' in source and isinstance(source['stocks'], list):
|
||||||
|
for stock in source['stocks']:
|
||||||
|
if isinstance(stock, dict) and 'code' in stock and stock['code']:
|
||||||
|
stocks.append(stock['code'])
|
||||||
|
if stocks:
|
||||||
|
concepts.append({
|
||||||
|
'concept_id': source.get('concept_id'),
|
||||||
|
'concept_name': source.get('concept'),
|
||||||
|
'stocks': stocks
|
||||||
|
})
|
||||||
|
|
||||||
|
resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
|
||||||
|
scroll_id = resp['_scroll_id']
|
||||||
|
hits = resp['hits']['hits']
|
||||||
|
|
||||||
|
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
|
||||||
|
print(f"加载了 {len(concepts)} 个概念")
|
||||||
|
return concepts
|
||||||
|
|
||||||
|
def detect(self, trade_date: str) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
检测指定日期的异动
|
||||||
|
|
||||||
|
返回异动列表
|
||||||
|
"""
|
||||||
|
print(f"\n检测 {trade_date} 的异动...")
|
||||||
|
|
||||||
|
# 获取原始数据
|
||||||
|
raw_features = self._compute_raw_features(trade_date)
|
||||||
|
if raw_features.empty:
|
||||||
|
print("无数据")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 按时间排序
|
||||||
|
timestamps = sorted(raw_features['timestamp'].unique())
|
||||||
|
print(f"时间点数: {len(timestamps)}")
|
||||||
|
|
||||||
|
all_alerts = []
|
||||||
|
|
||||||
|
for ts in timestamps:
|
||||||
|
ts_data = raw_features[raw_features['timestamp'] == ts]
|
||||||
|
ts_alerts = self._process_timestamp(ts, ts_data, trade_date)
|
||||||
|
all_alerts.extend(ts_alerts)
|
||||||
|
|
||||||
|
print(f"共检测到 {len(all_alerts)} 个异动")
|
||||||
|
return all_alerts
|
||||||
|
|
||||||
|
def _compute_raw_features(self, trade_date: str) -> pd.DataFrame:
|
||||||
|
"""计算原始特征(同 prepare_data_v2)"""
|
||||||
|
# 这里简化处理,直接调用数据准备逻辑
|
||||||
|
from prepare_data_v2 import compute_raw_concept_features
|
||||||
|
return compute_raw_concept_features(trade_date, self.concepts, self.all_stocks)
|
||||||
|
|
||||||
|
def _process_timestamp(self, timestamp, ts_data: pd.DataFrame, trade_date: str) -> List[Dict]:
|
||||||
|
"""处理单个时间戳"""
|
||||||
|
alerts = []
|
||||||
|
candidates = [] # (concept_id, features, rule_score, triggered_rules)
|
||||||
|
|
||||||
|
for _, row in ts_data.iterrows():
|
||||||
|
concept_id = row['concept_id']
|
||||||
|
|
||||||
|
# 计算 Z-Score 特征
|
||||||
|
features = self.data_manager.compute_zscore_features(
|
||||||
|
concept_id, timestamp,
|
||||||
|
row['alpha'], row['total_amt'], row['rank_pct'], row['limit_up_ratio']
|
||||||
|
)
|
||||||
|
|
||||||
|
if features is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 更新历史
|
||||||
|
self.data_manager.update(concept_id, timestamp, features)
|
||||||
|
|
||||||
|
# 规则评分
|
||||||
|
rule_score, triggered_rules = score_rules_zscore(features)
|
||||||
|
|
||||||
|
# 收集候选
|
||||||
|
candidates.append((concept_id, features, rule_score, triggered_rules))
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 批量 ML 评分
|
||||||
|
sequences = []
|
||||||
|
valid_candidates = []
|
||||||
|
|
||||||
|
for concept_id, features, rule_score, triggered_rules in candidates:
|
||||||
|
seq = self.data_manager.get_sequence(concept_id)
|
||||||
|
if seq is not None:
|
||||||
|
sequences.append(seq)
|
||||||
|
valid_candidates.append((concept_id, features, rule_score, triggered_rules))
|
||||||
|
|
||||||
|
if not sequences:
|
||||||
|
return []
|
||||||
|
|
||||||
|
sequences = np.array(sequences)
|
||||||
|
ml_scores = self.ml_scorer.score_batch(sequences)
|
||||||
|
|
||||||
|
# 融合评分 + 持续性确认
|
||||||
|
for i, (concept_id, features, rule_score, triggered_rules) in enumerate(valid_candidates):
|
||||||
|
ml_score = ml_scores[i]
|
||||||
|
final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score
|
||||||
|
|
||||||
|
# 判断是否触发
|
||||||
|
is_triggered = (
|
||||||
|
rule_score >= CONFIG['rule_trigger'] or
|
||||||
|
ml_score >= CONFIG['ml_trigger'] or
|
||||||
|
final_score >= CONFIG['fusion_trigger']
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加到候选队列
|
||||||
|
self.data_manager.add_anomaly_candidate(concept_id, timestamp, final_score)
|
||||||
|
|
||||||
|
if not is_triggered:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查冷却期
|
||||||
|
if self.data_manager.check_cooldown(concept_id, timestamp):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 持续性确认
|
||||||
|
is_sustained, confirm_ratio = self.data_manager.check_sustained_anomaly(
|
||||||
|
concept_id, CONFIG['fusion_trigger']
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_sustained:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 确认为异动!
|
||||||
|
self.data_manager.set_cooldown(concept_id, timestamp)
|
||||||
|
|
||||||
|
# 确定异动类型
|
||||||
|
alpha = features['alpha']
|
||||||
|
if alpha >= 1.5:
|
||||||
|
alert_type = 'surge_up'
|
||||||
|
elif alpha <= -1.5:
|
||||||
|
alert_type = 'surge_down'
|
||||||
|
elif features['amt_zscore'] >= 3.0:
|
||||||
|
alert_type = 'volume_spike'
|
||||||
|
else:
|
||||||
|
alert_type = 'surge'
|
||||||
|
|
||||||
|
# 确定触发原因
|
||||||
|
if rule_score >= CONFIG['rule_trigger']:
|
||||||
|
trigger_reason = f'规则({rule_score:.0f})+持续确认({confirm_ratio:.0%})'
|
||||||
|
elif ml_score >= CONFIG['ml_trigger']:
|
||||||
|
trigger_reason = f'ML({ml_score:.0f})+持续确认({confirm_ratio:.0%})'
|
||||||
|
else:
|
||||||
|
trigger_reason = f'融合({final_score:.0f})+持续确认({confirm_ratio:.0%})'
|
||||||
|
|
||||||
|
alerts.append({
|
||||||
|
'concept_id': concept_id,
|
||||||
|
'concept_name': self.data_manager.concepts.get(concept_id, {}).get('concept_name', concept_id),
|
||||||
|
'alert_time': timestamp,
|
||||||
|
'trade_date': trade_date,
|
||||||
|
'alert_type': alert_type,
|
||||||
|
'final_score': final_score,
|
||||||
|
'rule_score': rule_score,
|
||||||
|
'ml_score': ml_score,
|
||||||
|
'trigger_reason': trigger_reason,
|
||||||
|
'confirm_ratio': confirm_ratio,
|
||||||
|
'alpha': alpha,
|
||||||
|
'alpha_zscore': features['alpha_zscore'],
|
||||||
|
'amt_zscore': features['amt_zscore'],
|
||||||
|
'rank_zscore': features['rank_zscore'],
|
||||||
|
'momentum_3m': features['momentum_3m'],
|
||||||
|
'momentum_5m': features['momentum_5m'],
|
||||||
|
'limit_up_ratio': features['limit_up_ratio'],
|
||||||
|
'triggered_rules': triggered_rules,
|
||||||
|
})
|
||||||
|
|
||||||
|
# 每分钟最多 N 个
|
||||||
|
if len(alerts) > CONFIG['max_alerts_per_minute']:
|
||||||
|
alerts = sorted(alerts, key=lambda x: x['final_score'], reverse=True)
|
||||||
|
alerts = alerts[:CONFIG['max_alerts_per_minute']]
|
||||||
|
|
||||||
|
return alerts
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 主函数 ====================
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='V2 异动检测器')
|
||||||
|
parser.add_argument('--date', type=str, default=None, help='检测日期(默认今天)')
|
||||||
|
parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
|
||||||
|
parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
trade_date = args.date or datetime.now().strftime('%Y-%m-%d')
|
||||||
|
|
||||||
|
detector = AnomalyDetectorV2(
|
||||||
|
model_dir=args.model_dir,
|
||||||
|
baseline_dir=args.baseline_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
alerts = detector.detect(trade_date)
|
||||||
|
|
||||||
|
print(f"\n检测结果:")
|
||||||
|
for alert in alerts[:20]:
|
||||||
|
print(f" [{alert['alert_time'].strftime('%H:%M') if hasattr(alert['alert_time'], 'strftime') else alert['alert_time']}] "
|
||||||
|
f"{alert['concept_name']} ({alert['alert_type']}) "
|
||||||
|
f"分数={alert['final_score']:.0f} "
|
||||||
|
f"确认率={alert['confirm_ratio']:.0%}")
|
||||||
|
|
||||||
|
if len(alerts) > 20:
|
||||||
|
print(f" ... 共 {len(alerts)} 个异动")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -85,9 +85,12 @@ class LSTMAutoencoder(nn.Module):
|
|||||||
nn.Tanh(), # 限制范围,增加约束
|
nn.Tanh(), # 限制范围,增加约束
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 使用 LeakyReLU 替代 ReLU
|
||||||
|
# 原因:Z-Score 数据范围是 [-5, +5],ReLU 会截断负值,丢失跌幅信息
|
||||||
|
# LeakyReLU 保留负值信号(乘以 0.1)
|
||||||
self.bottleneck_up = nn.Sequential(
|
self.bottleneck_up = nn.Sequential(
|
||||||
nn.Linear(latent_dim, hidden_dim),
|
nn.Linear(latent_dim, hidden_dim),
|
||||||
nn.ReLU(),
|
nn.LeakyReLU(negative_slope=0.1),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Decoder: 单向 LSTM
|
# Decoder: 单向 LSTM
|
||||||
|
|||||||
@@ -26,7 +26,9 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Set, Tuple
|
from typing import Dict, List, Set, Tuple
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
|
from multiprocessing import Manager
|
||||||
|
import multiprocessing
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
@@ -128,7 +130,7 @@ def get_all_concepts() -> List[dict]:
|
|||||||
hits = resp['hits']['hits']
|
hits = resp['hits']['hits']
|
||||||
|
|
||||||
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
|
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
|
||||||
logger.info(f"获取到 {len(concepts)} 个概念")
|
print(f"获取到 {len(concepts)} 个概念")
|
||||||
return concepts
|
return concepts
|
||||||
|
|
||||||
|
|
||||||
@@ -148,7 +150,7 @@ def get_trading_days(start_date: str, end_date: str) -> List[str]:
|
|||||||
|
|
||||||
result = client.execute(query)
|
result = client.execute(query)
|
||||||
days = [row[0].strftime('%Y-%m-%d') for row in result]
|
days = [row[0].strftime('%Y-%m-%d') for row in result]
|
||||||
logger.info(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
|
print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
|
||||||
return days
|
return days
|
||||||
|
|
||||||
|
|
||||||
@@ -223,21 +225,23 @@ def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) ->
|
|||||||
|
|
||||||
|
|
||||||
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
||||||
"""获取昨收价"""
|
"""获取昨收价(上一交易日的收盘价 F007N)"""
|
||||||
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
|
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
|
||||||
if not valid_codes:
|
if not valid_codes:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
codes_str = "','".join(valid_codes)
|
codes_str = "','".join(valid_codes)
|
||||||
|
|
||||||
|
# 注意:F007N 是"最近成交价"即当日收盘价,F002N 是"昨日收盘价"
|
||||||
|
# 我们需要查上一交易日的 F007N(那天的收盘价)作为今天的昨收
|
||||||
query = f"""
|
query = f"""
|
||||||
SELECT SECCODE, F002N
|
SELECT SECCODE, F007N
|
||||||
FROM ea_trade
|
FROM ea_trade
|
||||||
WHERE SECCODE IN ('{codes_str}')
|
WHERE SECCODE IN ('{codes_str}')
|
||||||
AND TRADEDATE = (
|
AND TRADEDATE = (
|
||||||
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
|
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
|
||||||
)
|
)
|
||||||
AND F002N IS NOT NULL AND F002N > 0
|
AND F007N IS NOT NULL AND F007N > 0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -245,7 +249,7 @@ def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
|||||||
result = conn.execute(text(query))
|
result = conn.execute(text(query))
|
||||||
return {row[0]: float(row[1]) for row in result if row[1]}
|
return {row[0]: float(row[1]) for row in result if row[1]}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取昨收价失败: {e}")
|
print(f"获取昨收价失败: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@@ -264,7 +268,7 @@ def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) ->
|
|||||||
if result and result[0]:
|
if result and result[0]:
|
||||||
return float(result[0])
|
return float(result[0])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取指数昨收失败: {e}")
|
print(f"获取指数昨收失败: {e}")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -285,25 +289,19 @@ def compute_daily_features(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 1. 获取数据
|
# 1. 获取数据
|
||||||
logger.info(f" 获取股票数据...")
|
|
||||||
stock_df = get_daily_stock_data(trade_date, all_stocks)
|
stock_df = get_daily_stock_data(trade_date, all_stocks)
|
||||||
if stock_df.empty:
|
if stock_df.empty:
|
||||||
logger.warning(f" 无股票数据")
|
|
||||||
return pd.DataFrame()
|
return pd.DataFrame()
|
||||||
|
|
||||||
logger.info(f" 获取指数数据...")
|
|
||||||
index_df = get_daily_index_data(trade_date)
|
index_df = get_daily_index_data(trade_date)
|
||||||
if index_df.empty:
|
if index_df.empty:
|
||||||
logger.warning(f" 无指数数据")
|
|
||||||
return pd.DataFrame()
|
return pd.DataFrame()
|
||||||
|
|
||||||
# 2. 获取昨收价
|
# 2. 获取昨收价
|
||||||
logger.info(f" 获取昨收价...")
|
|
||||||
prev_close = get_prev_close(all_stocks, trade_date)
|
prev_close = get_prev_close(all_stocks, trade_date)
|
||||||
index_prev_close = get_index_prev_close(trade_date)
|
index_prev_close = get_index_prev_close(trade_date)
|
||||||
|
|
||||||
if not prev_close or not index_prev_close:
|
if not prev_close or not index_prev_close:
|
||||||
logger.warning(f" 无昨收价数据")
|
|
||||||
return pd.DataFrame()
|
return pd.DataFrame()
|
||||||
|
|
||||||
# 3. 计算股票涨跌幅和成交额
|
# 3. 计算股票涨跌幅和成交额
|
||||||
@@ -317,7 +315,6 @@ def compute_daily_features(
|
|||||||
|
|
||||||
# 5. 获取所有时间点
|
# 5. 获取所有时间点
|
||||||
timestamps = sorted(stock_df['timestamp'].unique())
|
timestamps = sorted(stock_df['timestamp'].unique())
|
||||||
logger.info(f" 时间点数: {len(timestamps)}")
|
|
||||||
|
|
||||||
# 6. 按时间点计算概念特征
|
# 6. 按时间点计算概念特征
|
||||||
results = []
|
results = []
|
||||||
@@ -414,87 +411,126 @@ def compute_daily_features(
|
|||||||
if amt_delta_std > 0:
|
if amt_delta_std > 0:
|
||||||
final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std
|
final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std
|
||||||
|
|
||||||
logger.info(f" 计算完成: {len(final_df)} 条记录")
|
|
||||||
return final_df
|
return final_df
|
||||||
|
|
||||||
|
|
||||||
# ==================== 主流程 ====================
|
# ==================== 主流程 ====================
|
||||||
|
|
||||||
def process_single_day(trade_date: str, concepts: List[dict], all_stocks: List[str]) -> str:
|
def process_single_day(args) -> Tuple[str, bool]:
|
||||||
"""处理单个交易日"""
|
"""
|
||||||
|
处理单个交易日(多进程版本)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: (trade_date, concepts, all_stocks) 元组
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(trade_date, success) 元组
|
||||||
|
"""
|
||||||
|
trade_date, concepts, all_stocks = args
|
||||||
output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
|
output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
|
||||||
|
|
||||||
# 检查是否已处理
|
# 检查是否已处理
|
||||||
if os.path.exists(output_file):
|
if os.path.exists(output_file):
|
||||||
logger.info(f"[{trade_date}] 已存在,跳过")
|
print(f"[{trade_date}] 已存在,跳过")
|
||||||
return output_file
|
return (trade_date, True)
|
||||||
|
|
||||||
logger.info(f"[{trade_date}] 开始处理...")
|
print(f"[{trade_date}] 开始处理...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
df = compute_daily_features(trade_date, concepts, all_stocks)
|
df = compute_daily_features(trade_date, concepts, all_stocks)
|
||||||
|
|
||||||
if df.empty:
|
if df.empty:
|
||||||
logger.warning(f"[{trade_date}] 无数据")
|
print(f"[{trade_date}] 无数据")
|
||||||
return None
|
return (trade_date, False)
|
||||||
|
|
||||||
# 保存
|
# 保存
|
||||||
df.to_parquet(output_file, index=False)
|
df.to_parquet(output_file, index=False)
|
||||||
logger.info(f"[{trade_date}] 保存完成: {output_file}")
|
print(f"[{trade_date}] 保存完成")
|
||||||
return output_file
|
return (trade_date, True)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{trade_date}] 处理失败: {e}")
|
print(f"[{trade_date}] 处理失败: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return None
|
return (trade_date, False)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import argparse
|
import argparse
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='准备训练数据')
|
parser = argparse.ArgumentParser(description='准备训练数据')
|
||||||
parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
|
parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
|
||||||
parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
|
parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
|
||||||
parser.add_argument('--workers', type=int, default=1, help='并行数(建议1,避免数据库压力)')
|
parser.add_argument('--workers', type=int, default=18, help='并行进程数(默认18)')
|
||||||
|
parser.add_argument('--force', action='store_true', help='强制重新处理已存在的文件')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
end_date = args.end or datetime.now().strftime('%Y-%m-%d')
|
end_date = args.end or datetime.now().strftime('%Y-%m-%d')
|
||||||
|
|
||||||
logger.info("=" * 60)
|
print("=" * 60)
|
||||||
logger.info("数据准备 - Transformer Autoencoder 训练数据")
|
print("数据准备 - Transformer Autoencoder 训练数据")
|
||||||
logger.info("=" * 60)
|
print("=" * 60)
|
||||||
logger.info(f"日期范围: {args.start} ~ {end_date}")
|
print(f"日期范围: {args.start} ~ {end_date}")
|
||||||
|
print(f"并行进程数: {args.workers}")
|
||||||
|
|
||||||
# 1. 获取概念列表
|
# 1. 获取概念列表
|
||||||
concepts = get_all_concepts()
|
concepts = get_all_concepts()
|
||||||
|
|
||||||
# 收集所有股票
|
# 收集所有股票
|
||||||
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
||||||
logger.info(f"股票总数: {len(all_stocks)}")
|
print(f"股票总数: {len(all_stocks)}")
|
||||||
|
|
||||||
# 2. 获取交易日列表
|
# 2. 获取交易日列表
|
||||||
trading_days = get_trading_days(args.start, end_date)
|
trading_days = get_trading_days(args.start, end_date)
|
||||||
|
|
||||||
if not trading_days:
|
if not trading_days:
|
||||||
logger.error("无交易日数据")
|
print("无交易日数据")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 3. 处理每个交易日
|
# 如果强制模式,删除已有文件
|
||||||
logger.info(f"\n开始处理 {len(trading_days)} 个交易日...")
|
if args.force:
|
||||||
|
for trade_date in trading_days:
|
||||||
|
output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
|
||||||
|
if os.path.exists(output_file):
|
||||||
|
os.remove(output_file)
|
||||||
|
print(f"删除已有文件: {output_file}")
|
||||||
|
|
||||||
|
# 3. 准备任务参数
|
||||||
|
tasks = [(trade_date, concepts, all_stocks) for trade_date in trading_days]
|
||||||
|
|
||||||
|
print(f"\n开始处理 {len(trading_days)} 个交易日({args.workers} 进程并行)...")
|
||||||
|
|
||||||
|
# 4. 多进程处理
|
||||||
success_count = 0
|
success_count = 0
|
||||||
for i, trade_date in enumerate(trading_days):
|
failed_dates = []
|
||||||
logger.info(f"\n[{i+1}/{len(trading_days)}] {trade_date}")
|
|
||||||
result = process_single_day(trade_date, concepts, all_stocks)
|
|
||||||
if result:
|
|
||||||
success_count += 1
|
|
||||||
|
|
||||||
logger.info("\n" + "=" * 60)
|
with ProcessPoolExecutor(max_workers=args.workers) as executor:
|
||||||
logger.info(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
|
# 提交所有任务
|
||||||
logger.info(f"数据保存在: {OUTPUT_DIR}")
|
futures = {executor.submit(process_single_day, task): task[0] for task in tasks}
|
||||||
logger.info("=" * 60)
|
|
||||||
|
# 使用 tqdm 显示进度
|
||||||
|
with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
|
||||||
|
for future in as_completed(futures):
|
||||||
|
trade_date = futures[future]
|
||||||
|
try:
|
||||||
|
result_date, success = future.result()
|
||||||
|
if success:
|
||||||
|
success_count += 1
|
||||||
|
else:
|
||||||
|
failed_dates.append(result_date)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n[{trade_date}] 进程异常: {e}")
|
||||||
|
failed_dates.append(trade_date)
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
|
||||||
|
if failed_dates:
|
||||||
|
print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
|
||||||
|
print(f"数据保存在: {OUTPUT_DIR}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
715
ml/prepare_data_v2.py
Normal file
715
ml/prepare_data_v2.py
Normal file
@@ -0,0 +1,715 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
数据准备 V2 - 基于时间片对齐的特征计算(修复版)
|
||||||
|
|
||||||
|
核心改进:
|
||||||
|
1. 时间片对齐:9:35 和历史的 9:35 比,而不是和前30分钟比
|
||||||
|
2. Z-Score 特征:相对于同时间片历史分布的偏离程度
|
||||||
|
3. 滚动窗口基线:每个日期使用它之前 N 天的数据作为基线(不是固定的最后 N 天!)
|
||||||
|
4. 基于 Z-Score 的动量:消除一天内波动率异构性
|
||||||
|
|
||||||
|
修复:
|
||||||
|
- 滚动窗口基线:避免未来数据泄露
|
||||||
|
- Z-Score 动量:消除早盘/尾盘波动率差异
|
||||||
|
- 进程级数据库单例:避免连接池爆炸
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from elasticsearch import Elasticsearch
|
||||||
|
from clickhouse_driver import Client
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
|
from typing import Dict, List, Tuple, Optional
|
||||||
|
from tqdm import tqdm
|
||||||
|
from collections import defaultdict
|
||||||
|
import warnings
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
# ==================== 配置 ====================
|
||||||
|
|
||||||
|
MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock"
|
||||||
|
ES_HOST = 'http://127.0.0.1:9200'
|
||||||
|
ES_INDEX = 'concept_library_v3'
|
||||||
|
|
||||||
|
CLICKHOUSE_CONFIG = {
|
||||||
|
'host': '127.0.0.1',
|
||||||
|
'port': 9000,
|
||||||
|
'user': 'default',
|
||||||
|
'password': 'Zzl33818!',
|
||||||
|
'database': 'stock'
|
||||||
|
}
|
||||||
|
|
||||||
|
REFERENCE_INDEX = '000001.SH'
|
||||||
|
|
||||||
|
# 输出目录
|
||||||
|
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data_v2')
|
||||||
|
BASELINE_DIR = os.path.join(OUTPUT_DIR, 'baselines')
|
||||||
|
RAW_CACHE_DIR = os.path.join(OUTPUT_DIR, 'raw_cache')
|
||||||
|
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||||
|
os.makedirs(BASELINE_DIR, exist_ok=True)
|
||||||
|
os.makedirs(RAW_CACHE_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# 特征配置
|
||||||
|
CONFIG = {
|
||||||
|
'baseline_days': 20, # 滚动窗口大小
|
||||||
|
'min_baseline_samples': 10, # 最少需要10个样本才算有效基线
|
||||||
|
'limit_up_threshold': 9.8,
|
||||||
|
'limit_down_threshold': -9.8,
|
||||||
|
'zscore_clip': 5.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 特征列表
|
||||||
|
FEATURES_V2 = [
|
||||||
|
'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore',
|
||||||
|
'momentum_3m', 'momentum_5m', 'limit_up_ratio',
|
||||||
|
]
|
||||||
|
|
||||||
|
# ==================== 进程级单例(避免连接池爆炸)====================
|
||||||
|
|
||||||
|
# 进程级全局变量
|
||||||
|
_process_mysql_engine = None
|
||||||
|
_process_es_client = None
|
||||||
|
_process_ch_client = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_process_connections():
|
||||||
|
"""进程初始化时调用,创建连接(单例)"""
|
||||||
|
global _process_mysql_engine, _process_es_client, _process_ch_client
|
||||||
|
_process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5)
|
||||||
|
_process_es_client = Elasticsearch([ES_HOST])
|
||||||
|
_process_ch_client = Client(**CLICKHOUSE_CONFIG)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mysql_engine():
|
||||||
|
"""获取进程级 MySQL Engine(单例)"""
|
||||||
|
global _process_mysql_engine
|
||||||
|
if _process_mysql_engine is None:
|
||||||
|
_process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5)
|
||||||
|
return _process_mysql_engine
|
||||||
|
|
||||||
|
|
||||||
|
def get_es_client():
|
||||||
|
"""获取进程级 ES 客户端(单例)"""
|
||||||
|
global _process_es_client
|
||||||
|
if _process_es_client is None:
|
||||||
|
_process_es_client = Elasticsearch([ES_HOST])
|
||||||
|
return _process_es_client
|
||||||
|
|
||||||
|
|
||||||
|
def get_ch_client():
|
||||||
|
"""获取进程级 ClickHouse 客户端(单例)"""
|
||||||
|
global _process_ch_client
|
||||||
|
if _process_ch_client is None:
|
||||||
|
_process_ch_client = Client(**CLICKHOUSE_CONFIG)
|
||||||
|
return _process_ch_client
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 工具函数 ====================
|
||||||
|
|
||||||
|
def code_to_ch_format(code: str) -> str:
|
||||||
|
if not code or len(code) != 6 or not code.isdigit():
|
||||||
|
return None
|
||||||
|
if code.startswith('6'):
|
||||||
|
return f"{code}.SH"
|
||||||
|
elif code.startswith('0') or code.startswith('3'):
|
||||||
|
return f"{code}.SZ"
|
||||||
|
else:
|
||||||
|
return f"{code}.BJ"
|
||||||
|
|
||||||
|
|
||||||
|
def time_to_slot(ts) -> str:
|
||||||
|
"""将时间戳转换为时间片(HH:MM格式)"""
|
||||||
|
if isinstance(ts, str):
|
||||||
|
return ts
|
||||||
|
return ts.strftime('%H:%M')
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 获取概念列表 ====================
|
||||||
|
|
||||||
|
def get_all_concepts() -> List[dict]:
|
||||||
|
"""从ES获取所有叶子概念"""
|
||||||
|
es_client = get_es_client()
|
||||||
|
concepts = []
|
||||||
|
|
||||||
|
query = {
|
||||||
|
"query": {"match_all": {}},
|
||||||
|
"size": 100,
|
||||||
|
"_source": ["concept_id", "concept", "stocks"]
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = es_client.search(index=ES_INDEX, body=query, scroll='2m')
|
||||||
|
scroll_id = resp['_scroll_id']
|
||||||
|
hits = resp['hits']['hits']
|
||||||
|
|
||||||
|
while len(hits) > 0:
|
||||||
|
for hit in hits:
|
||||||
|
source = hit['_source']
|
||||||
|
stocks = []
|
||||||
|
if 'stocks' in source and isinstance(source['stocks'], list):
|
||||||
|
for stock in source['stocks']:
|
||||||
|
if isinstance(stock, dict) and 'code' in stock and stock['code']:
|
||||||
|
stocks.append(stock['code'])
|
||||||
|
|
||||||
|
if stocks:
|
||||||
|
concepts.append({
|
||||||
|
'concept_id': source.get('concept_id'),
|
||||||
|
'concept_name': source.get('concept'),
|
||||||
|
'stocks': stocks
|
||||||
|
})
|
||||||
|
|
||||||
|
resp = es_client.scroll(scroll_id=scroll_id, scroll='2m')
|
||||||
|
scroll_id = resp['_scroll_id']
|
||||||
|
hits = resp['hits']['hits']
|
||||||
|
|
||||||
|
es_client.clear_scroll(scroll_id=scroll_id)
|
||||||
|
print(f"获取到 {len(concepts)} 个概念")
|
||||||
|
return concepts
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 获取交易日列表 ====================
|
||||||
|
|
||||||
|
def get_trading_days(start_date: str, end_date: str) -> List[str]:
|
||||||
|
"""获取交易日列表"""
|
||||||
|
client = get_ch_client()
|
||||||
|
|
||||||
|
query = f"""
|
||||||
|
SELECT DISTINCT toDate(timestamp) as trade_date
|
||||||
|
FROM stock_minute
|
||||||
|
WHERE toDate(timestamp) >= '{start_date}'
|
||||||
|
AND toDate(timestamp) <= '{end_date}'
|
||||||
|
ORDER BY trade_date
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = client.execute(query)
|
||||||
|
days = [row[0].strftime('%Y-%m-%d') for row in result]
|
||||||
|
if days:
|
||||||
|
print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
|
||||||
|
return days
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 获取昨收价 ====================
|
||||||
|
|
||||||
|
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
||||||
|
"""获取昨收价(上一交易日的收盘价 F007N)"""
|
||||||
|
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
|
||||||
|
if not valid_codes:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
codes_str = "','".join(valid_codes)
|
||||||
|
query = f"""
|
||||||
|
SELECT SECCODE, F007N
|
||||||
|
FROM ea_trade
|
||||||
|
WHERE SECCODE IN ('{codes_str}')
|
||||||
|
AND TRADEDATE = (
|
||||||
|
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
|
||||||
|
)
|
||||||
|
AND F007N IS NOT NULL AND F007N > 0
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
engine = get_mysql_engine()
|
||||||
|
with engine.connect() as conn:
|
||||||
|
result = conn.execute(text(query))
|
||||||
|
return {row[0]: float(row[1]) for row in result if row[1]}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取昨收价失败: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float:
|
||||||
|
"""获取指数昨收价"""
|
||||||
|
code_no_suffix = index_code.split('.')[0]
|
||||||
|
|
||||||
|
try:
|
||||||
|
engine = get_mysql_engine()
|
||||||
|
with engine.connect() as conn:
|
||||||
|
result = conn.execute(text("""
|
||||||
|
SELECT F006N FROM ea_exchangetrade
|
||||||
|
WHERE INDEXCODE = :code AND TRADEDATE < :today
|
||||||
|
ORDER BY TRADEDATE DESC LIMIT 1
|
||||||
|
"""), {'code': code_no_suffix, 'today': trade_date}).fetchone()
|
||||||
|
|
||||||
|
if result and result[0]:
|
||||||
|
return float(result[0])
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取指数昨收失败: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 获取分钟数据 ====================
|
||||||
|
|
||||||
|
def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
|
||||||
|
"""获取单日所有股票的分钟数据"""
|
||||||
|
client = get_ch_client()
|
||||||
|
|
||||||
|
ch_codes = []
|
||||||
|
code_map = {}
|
||||||
|
for code in stock_codes:
|
||||||
|
ch_code = code_to_ch_format(code)
|
||||||
|
if ch_code:
|
||||||
|
ch_codes.append(ch_code)
|
||||||
|
code_map[ch_code] = code
|
||||||
|
|
||||||
|
if not ch_codes:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
ch_codes_str = "','".join(ch_codes)
|
||||||
|
|
||||||
|
query = f"""
|
||||||
|
SELECT code, timestamp, close, volume, amt
|
||||||
|
FROM stock_minute
|
||||||
|
WHERE toDate(timestamp) = '{trade_date}'
|
||||||
|
AND code IN ('{ch_codes_str}')
|
||||||
|
ORDER BY code, timestamp
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = client.execute(query)
|
||||||
|
if not result:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
|
||||||
|
df['code'] = df['ch_code'].map(code_map)
|
||||||
|
df = df.dropna(subset=['code'])
|
||||||
|
|
||||||
|
return df[['code', 'timestamp', 'close', 'volume', 'amt']]
|
||||||
|
|
||||||
|
|
||||||
|
def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
|
||||||
|
"""获取单日指数分钟数据"""
|
||||||
|
client = get_ch_client()
|
||||||
|
|
||||||
|
query = f"""
|
||||||
|
SELECT timestamp, close, volume, amt
|
||||||
|
FROM index_minute
|
||||||
|
WHERE toDate(timestamp) = '{trade_date}'
|
||||||
|
AND code = '{index_code}'
|
||||||
|
ORDER BY timestamp
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = client.execute(query)
|
||||||
|
if not result:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
df = pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt'])
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 计算原始概念特征(单日)====================
|
||||||
|
|
||||||
|
def compute_raw_concept_features(
|
||||||
|
trade_date: str,
|
||||||
|
concepts: List[dict],
|
||||||
|
all_stocks: List[str]
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""计算单日概念的原始特征(alpha, amt, rank_pct, limit_up_ratio)"""
|
||||||
|
# 检查缓存
|
||||||
|
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
||||||
|
if os.path.exists(cache_file):
|
||||||
|
return pd.read_parquet(cache_file)
|
||||||
|
|
||||||
|
# 获取数据
|
||||||
|
stock_df = get_daily_stock_data(trade_date, all_stocks)
|
||||||
|
if stock_df.empty:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
index_df = get_daily_index_data(trade_date)
|
||||||
|
if index_df.empty:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
# 获取昨收价
|
||||||
|
prev_close = get_prev_close(all_stocks, trade_date)
|
||||||
|
index_prev_close = get_index_prev_close(trade_date)
|
||||||
|
|
||||||
|
if not prev_close or not index_prev_close:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
# 计算涨跌幅
|
||||||
|
stock_df['prev_close'] = stock_df['code'].map(prev_close)
|
||||||
|
stock_df = stock_df.dropna(subset=['prev_close'])
|
||||||
|
stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
|
||||||
|
|
||||||
|
index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
|
||||||
|
index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))
|
||||||
|
|
||||||
|
# 获取所有时间点
|
||||||
|
timestamps = sorted(stock_df['timestamp'].unique())
|
||||||
|
|
||||||
|
# 概念到股票的映射
|
||||||
|
concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for ts in timestamps:
|
||||||
|
ts_stock_data = stock_df[stock_df['timestamp'] == ts]
|
||||||
|
index_change = index_change_map.get(ts, 0)
|
||||||
|
|
||||||
|
stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
|
||||||
|
stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))
|
||||||
|
|
||||||
|
concept_features = []
|
||||||
|
|
||||||
|
for concept_id, stocks in concept_stocks.items():
|
||||||
|
concept_changes = [stock_change[s] for s in stocks if s in stock_change]
|
||||||
|
concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]
|
||||||
|
|
||||||
|
if not concept_changes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
avg_change = np.mean(concept_changes)
|
||||||
|
total_amt = sum(concept_amts)
|
||||||
|
alpha = avg_change - index_change
|
||||||
|
|
||||||
|
limit_up_count = sum(1 for c in concept_changes if c >= CONFIG['limit_up_threshold'])
|
||||||
|
limit_up_ratio = limit_up_count / len(concept_changes)
|
||||||
|
|
||||||
|
concept_features.append({
|
||||||
|
'concept_id': concept_id,
|
||||||
|
'alpha': alpha,
|
||||||
|
'total_amt': total_amt,
|
||||||
|
'limit_up_ratio': limit_up_ratio,
|
||||||
|
'stock_count': len(concept_changes),
|
||||||
|
})
|
||||||
|
|
||||||
|
if not concept_features:
|
||||||
|
continue
|
||||||
|
|
||||||
|
concept_df = pd.DataFrame(concept_features)
|
||||||
|
concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)
|
||||||
|
concept_df['timestamp'] = ts
|
||||||
|
concept_df['time_slot'] = time_to_slot(ts)
|
||||||
|
concept_df['trade_date'] = trade_date
|
||||||
|
|
||||||
|
results.append(concept_df)
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
result_df = pd.concat(results, ignore_index=True)
|
||||||
|
|
||||||
|
# 保存缓存
|
||||||
|
result_df.to_parquet(cache_file, index=False)
|
||||||
|
|
||||||
|
return result_df
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 滚动窗口基线计算 ====================
|
||||||
|
|
||||||
|
def compute_rolling_baseline(
|
||||||
|
historical_data: pd.DataFrame,
|
||||||
|
concept_id: str
|
||||||
|
) -> Dict[str, Dict]:
|
||||||
|
"""
|
||||||
|
计算单个概念的滚动基线
|
||||||
|
|
||||||
|
返回: {time_slot: {alpha_mean, alpha_std, amt_mean, amt_std, rank_mean, rank_std, sample_count}}
|
||||||
|
"""
|
||||||
|
if historical_data.empty:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
concept_data = historical_data[historical_data['concept_id'] == concept_id]
|
||||||
|
if concept_data.empty:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
baseline_dict = {}
|
||||||
|
|
||||||
|
for time_slot, group in concept_data.groupby('time_slot'):
|
||||||
|
if len(group) < CONFIG['min_baseline_samples']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
alpha_std = group['alpha'].std()
|
||||||
|
amt_std = group['total_amt'].std()
|
||||||
|
rank_std = group['rank_pct'].std()
|
||||||
|
|
||||||
|
baseline_dict[time_slot] = {
|
||||||
|
'alpha_mean': group['alpha'].mean(),
|
||||||
|
'alpha_std': max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1),
|
||||||
|
'amt_mean': group['total_amt'].mean(),
|
||||||
|
'amt_std': max(amt_std if pd.notna(amt_std) else group['total_amt'].mean() * 0.5, 1.0),
|
||||||
|
'rank_mean': group['rank_pct'].mean(),
|
||||||
|
'rank_std': max(rank_std if pd.notna(rank_std) else 0.2, 0.05),
|
||||||
|
'sample_count': len(group),
|
||||||
|
}
|
||||||
|
|
||||||
|
return baseline_dict
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 计算单日 Z-Score 特征(带滚动基线)====================
|
||||||
|
|
||||||
|
def compute_zscore_features_rolling(
|
||||||
|
trade_date: str,
|
||||||
|
concepts: List[dict],
|
||||||
|
all_stocks: List[str],
|
||||||
|
historical_raw_data: pd.DataFrame # 该日期之前 N 天的原始数据
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
计算单日的 Z-Score 特征(使用滚动窗口基线)
|
||||||
|
|
||||||
|
关键改进:
|
||||||
|
1. 基线只使用 trade_date 之前的数据(无未来泄露)
|
||||||
|
2. 动量基于 Z-Score 计算(消除波动率异构性)
|
||||||
|
"""
|
||||||
|
# 计算当日原始特征
|
||||||
|
raw_df = compute_raw_concept_features(trade_date, concepts, all_stocks)
|
||||||
|
|
||||||
|
if raw_df.empty:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
zscore_records = []
|
||||||
|
|
||||||
|
for concept_id, group in raw_df.groupby('concept_id'):
|
||||||
|
# 计算该概念的滚动基线(只用历史数据)
|
||||||
|
baseline_dict = compute_rolling_baseline(historical_raw_data, concept_id)
|
||||||
|
|
||||||
|
if not baseline_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 按时间排序
|
||||||
|
group = group.sort_values('timestamp').reset_index(drop=True)
|
||||||
|
|
||||||
|
# Z-Score 历史(用于计算基于 Z-Score 的动量)
|
||||||
|
zscore_history = []
|
||||||
|
|
||||||
|
for idx, row in group.iterrows():
|
||||||
|
time_slot = row['time_slot']
|
||||||
|
|
||||||
|
if time_slot not in baseline_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
bl = baseline_dict[time_slot]
|
||||||
|
|
||||||
|
# 计算 Z-Score
|
||||||
|
alpha_zscore = (row['alpha'] - bl['alpha_mean']) / bl['alpha_std']
|
||||||
|
amt_zscore = (row['total_amt'] - bl['amt_mean']) / bl['amt_std']
|
||||||
|
rank_zscore = (row['rank_pct'] - bl['rank_mean']) / bl['rank_std']
|
||||||
|
|
||||||
|
# 截断极端值
|
||||||
|
clip = CONFIG['zscore_clip']
|
||||||
|
alpha_zscore = np.clip(alpha_zscore, -clip, clip)
|
||||||
|
amt_zscore = np.clip(amt_zscore, -clip, clip)
|
||||||
|
rank_zscore = np.clip(rank_zscore, -clip, clip)
|
||||||
|
|
||||||
|
# 记录 Z-Score 历史
|
||||||
|
zscore_history.append(alpha_zscore)
|
||||||
|
|
||||||
|
# 基于 Z-Score 计算动量(消除波动率异构性)
|
||||||
|
momentum_3m = 0.0
|
||||||
|
momentum_5m = 0.0
|
||||||
|
|
||||||
|
if len(zscore_history) >= 3:
|
||||||
|
recent_3 = zscore_history[-3:]
|
||||||
|
older_3 = zscore_history[-6:-3] if len(zscore_history) >= 6 else [zscore_history[0]]
|
||||||
|
momentum_3m = np.mean(recent_3) - np.mean(older_3)
|
||||||
|
|
||||||
|
if len(zscore_history) >= 5:
|
||||||
|
recent_5 = zscore_history[-5:]
|
||||||
|
older_5 = zscore_history[-10:-5] if len(zscore_history) >= 10 else [zscore_history[0]]
|
||||||
|
momentum_5m = np.mean(recent_5) - np.mean(older_5)
|
||||||
|
|
||||||
|
zscore_records.append({
|
||||||
|
'concept_id': concept_id,
|
||||||
|
'timestamp': row['timestamp'],
|
||||||
|
'time_slot': time_slot,
|
||||||
|
'trade_date': trade_date,
|
||||||
|
# 原始特征
|
||||||
|
'alpha': row['alpha'],
|
||||||
|
'total_amt': row['total_amt'],
|
||||||
|
'limit_up_ratio': row['limit_up_ratio'],
|
||||||
|
'stock_count': row['stock_count'],
|
||||||
|
'rank_pct': row['rank_pct'],
|
||||||
|
# Z-Score 特征
|
||||||
|
'alpha_zscore': alpha_zscore,
|
||||||
|
'amt_zscore': amt_zscore,
|
||||||
|
'rank_zscore': rank_zscore,
|
||||||
|
# 基于 Z-Score 的动量
|
||||||
|
'momentum_3m': momentum_3m,
|
||||||
|
'momentum_5m': momentum_5m,
|
||||||
|
})
|
||||||
|
|
||||||
|
if not zscore_records:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
return pd.DataFrame(zscore_records)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 多进程处理 ====================
|
||||||
|
|
||||||
|
def process_single_day_v2(args) -> Tuple[str, bool]:
|
||||||
|
"""处理单个交易日(多进程版本)"""
|
||||||
|
trade_date, day_index, concepts, all_stocks, all_trading_days = args
|
||||||
|
output_file = os.path.join(OUTPUT_DIR, f'features_v2_{trade_date}.parquet')
|
||||||
|
|
||||||
|
if os.path.exists(output_file):
|
||||||
|
return (trade_date, True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 计算滚动窗口范围(该日期之前的 N 天)
|
||||||
|
baseline_days = CONFIG['baseline_days']
|
||||||
|
|
||||||
|
# 找出 trade_date 之前的交易日
|
||||||
|
start_idx = max(0, day_index - baseline_days)
|
||||||
|
end_idx = day_index # 不包含当天
|
||||||
|
|
||||||
|
if end_idx <= start_idx:
|
||||||
|
# 没有足够的历史数据
|
||||||
|
return (trade_date, False)
|
||||||
|
|
||||||
|
historical_days = all_trading_days[start_idx:end_idx]
|
||||||
|
|
||||||
|
# 加载历史原始数据
|
||||||
|
historical_dfs = []
|
||||||
|
for hist_date in historical_days:
|
||||||
|
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{hist_date}.parquet')
|
||||||
|
if os.path.exists(cache_file):
|
||||||
|
historical_dfs.append(pd.read_parquet(cache_file))
|
||||||
|
else:
|
||||||
|
# 需要计算
|
||||||
|
hist_df = compute_raw_concept_features(hist_date, concepts, all_stocks)
|
||||||
|
if not hist_df.empty:
|
||||||
|
historical_dfs.append(hist_df)
|
||||||
|
|
||||||
|
if not historical_dfs:
|
||||||
|
return (trade_date, False)
|
||||||
|
|
||||||
|
historical_raw_data = pd.concat(historical_dfs, ignore_index=True)
|
||||||
|
|
||||||
|
# 计算当日 Z-Score 特征(使用滚动基线)
|
||||||
|
df = compute_zscore_features_rolling(trade_date, concepts, all_stocks, historical_raw_data)
|
||||||
|
|
||||||
|
if df.empty:
|
||||||
|
return (trade_date, False)
|
||||||
|
|
||||||
|
df.to_parquet(output_file, index=False)
|
||||||
|
return (trade_date, True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[{trade_date}] 处理失败: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return (trade_date, False)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 主流程 ====================
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='准备训练数据 V2(滚动窗口基线 + Z-Score 动量)')
|
||||||
|
parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
|
||||||
|
parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
|
||||||
|
parser.add_argument('--workers', type=int, default=18, help='并行进程数')
|
||||||
|
parser.add_argument('--baseline-days', type=int, default=20, help='滚动基线窗口大小')
|
||||||
|
parser.add_argument('--force', action='store_true', help='强制重新计算(忽略缓存)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
end_date = args.end or datetime.now().strftime('%Y-%m-%d')
|
||||||
|
CONFIG['baseline_days'] = args.baseline_days
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("数据准备 V2 - 滚动窗口基线 + Z-Score 动量")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"日期范围: {args.start} ~ {end_date}")
|
||||||
|
print(f"并行进程数: {args.workers}")
|
||||||
|
print(f"滚动基线窗口: {args.baseline_days} 天")
|
||||||
|
|
||||||
|
# 初始化主进程连接
|
||||||
|
init_process_connections()
|
||||||
|
|
||||||
|
# 1. 获取概念列表
|
||||||
|
concepts = get_all_concepts()
|
||||||
|
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
||||||
|
print(f"股票总数: {len(all_stocks)}")
|
||||||
|
|
||||||
|
# 2. 获取交易日列表
|
||||||
|
trading_days = get_trading_days(args.start, end_date)
|
||||||
|
|
||||||
|
if not trading_days:
|
||||||
|
print("无交易日数据")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 3. 第一阶段:预计算所有原始特征(用于缓存)
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print("第一阶段:预计算原始特征(用于滚动基线)")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
|
||||||
|
# 如果强制重新计算,删除缓存
|
||||||
|
if args.force:
|
||||||
|
import shutil
|
||||||
|
if os.path.exists(RAW_CACHE_DIR):
|
||||||
|
shutil.rmtree(RAW_CACHE_DIR)
|
||||||
|
os.makedirs(RAW_CACHE_DIR, exist_ok=True)
|
||||||
|
if os.path.exists(OUTPUT_DIR):
|
||||||
|
for f in os.listdir(OUTPUT_DIR):
|
||||||
|
if f.startswith('features_v2_'):
|
||||||
|
os.remove(os.path.join(OUTPUT_DIR, f))
|
||||||
|
|
||||||
|
# 单线程预计算原始特征(因为需要顺序缓存)
|
||||||
|
print(f"预计算 {len(trading_days)} 天的原始特征...")
|
||||||
|
for trade_date in tqdm(trading_days, desc="预计算原始特征"):
|
||||||
|
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
||||||
|
if not os.path.exists(cache_file):
|
||||||
|
compute_raw_concept_features(trade_date, concepts, all_stocks)
|
||||||
|
|
||||||
|
# 4. 第二阶段:计算 Z-Score 特征(多进程)
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print("第二阶段:计算 Z-Score 特征(滚动基线)")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
|
||||||
|
# 从第 baseline_days 天开始(前面的没有足够历史)
|
||||||
|
start_idx = args.baseline_days
|
||||||
|
processable_days = trading_days[start_idx:]
|
||||||
|
|
||||||
|
if not processable_days:
|
||||||
|
print(f"错误:需要至少 {args.baseline_days + 1} 天的数据")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"可处理日期: {processable_days[0]} ~ {processable_days[-1]} ({len(processable_days)} 天)")
|
||||||
|
print(f"跳过前 {start_idx} 天(基线预热期)")
|
||||||
|
|
||||||
|
# 构建任务
|
||||||
|
tasks = []
|
||||||
|
for i, trade_date in enumerate(trading_days):
|
||||||
|
if i >= start_idx:
|
||||||
|
tasks.append((trade_date, i, concepts, all_stocks, trading_days))
|
||||||
|
|
||||||
|
print(f"开始处理 {len(tasks)} 个交易日({args.workers} 进程并行)...")
|
||||||
|
|
||||||
|
success_count = 0
|
||||||
|
failed_dates = []
|
||||||
|
|
||||||
|
# 使用进程池初始化器
|
||||||
|
with ProcessPoolExecutor(max_workers=args.workers, initializer=init_process_connections) as executor:
|
||||||
|
futures = {executor.submit(process_single_day_v2, task): task[0] for task in tasks}
|
||||||
|
|
||||||
|
with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
|
||||||
|
for future in as_completed(futures):
|
||||||
|
trade_date = futures[future]
|
||||||
|
try:
|
||||||
|
result_date, success = future.result()
|
||||||
|
if success:
|
||||||
|
success_count += 1
|
||||||
|
else:
|
||||||
|
failed_dates.append(result_date)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n[{trade_date}] 进程异常: {e}")
|
||||||
|
failed_dates.append(trade_date)
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print(f"处理完成: {success_count}/{len(tasks)} 个交易日")
|
||||||
|
if failed_dates:
|
||||||
|
print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
|
||||||
|
print(f"数据保存在: {OUTPUT_DIR}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -190,20 +190,22 @@ def get_all_concepts() -> List[dict]:
|
|||||||
|
|
||||||
|
|
||||||
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
||||||
"""获取昨收价"""
|
"""获取昨收价(上一交易日的收盘价 F007N)"""
|
||||||
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
|
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
|
||||||
if not valid_codes:
|
if not valid_codes:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
codes_str = "','".join(valid_codes)
|
codes_str = "','".join(valid_codes)
|
||||||
|
# 注意:F007N 是"最近成交价"即当日收盘价,F002N 是"昨日收盘价"
|
||||||
|
# 我们需要查上一交易日的 F007N(那天的收盘价)作为今天的昨收
|
||||||
query = f"""
|
query = f"""
|
||||||
SELECT SECCODE, F002N
|
SELECT SECCODE, F007N
|
||||||
FROM ea_trade
|
FROM ea_trade
|
||||||
WHERE SECCODE IN ('{codes_str}')
|
WHERE SECCODE IN ('{codes_str}')
|
||||||
AND TRADEDATE = (
|
AND TRADEDATE = (
|
||||||
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
|
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
|
||||||
)
|
)
|
||||||
AND F002N IS NOT NULL AND F002N > 0
|
AND F007N IS NOT NULL AND F007N > 0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
729
ml/realtime_detector_v2.py
Normal file
729
ml/realtime_detector_v2.py
Normal file
@@ -0,0 +1,729 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
V2 实时异动检测器
|
||||||
|
|
||||||
|
使用方法:
|
||||||
|
# 作为模块导入
|
||||||
|
from ml.realtime_detector_v2 import RealtimeDetectorV2
|
||||||
|
|
||||||
|
detector = RealtimeDetectorV2()
|
||||||
|
alerts = detector.detect_realtime() # 检测当前时刻
|
||||||
|
|
||||||
|
# 或命令行测试
|
||||||
|
python ml/realtime_detector_v2.py --date 2025-12-09
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from elasticsearch import Elasticsearch
|
||||||
|
from clickhouse_driver import Client
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from ml.model import TransformerAutoencoder
|
||||||
|
|
||||||
|
# ==================== 配置 ====================
|
||||||
|
|
||||||
|
MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock"
|
||||||
|
ES_HOST = 'http://127.0.0.1:9200'
|
||||||
|
ES_INDEX = 'concept_library_v3'
|
||||||
|
|
||||||
|
CLICKHOUSE_CONFIG = {
|
||||||
|
'host': '127.0.0.1',
|
||||||
|
'port': 9000,
|
||||||
|
'user': 'default',
|
||||||
|
'password': 'Zzl33818!',
|
||||||
|
'database': 'stock'
|
||||||
|
}
|
||||||
|
|
||||||
|
REFERENCE_INDEX = '000001.SH'
|
||||||
|
BASELINE_FILE = 'ml/data_v2/baselines/realtime_baseline.pkl'
|
||||||
|
MODEL_DIR = 'ml/checkpoints_v2'
|
||||||
|
|
||||||
|
# 检测配置
|
||||||
|
CONFIG = {
|
||||||
|
'seq_len': 10, # LSTM 序列长度
|
||||||
|
'confirm_window': 5, # 持续确认窗口
|
||||||
|
'confirm_ratio': 0.6, # 确认比例
|
||||||
|
'rule_weight': 0.5,
|
||||||
|
'ml_weight': 0.5,
|
||||||
|
'rule_trigger': 60,
|
||||||
|
'ml_trigger': 70,
|
||||||
|
'fusion_trigger': 50,
|
||||||
|
'cooldown_minutes': 10,
|
||||||
|
'max_alerts_per_minute': 15,
|
||||||
|
'zscore_clip': 5.0,
|
||||||
|
'limit_up_threshold': 9.8,
|
||||||
|
}
|
||||||
|
|
||||||
|
FEATURES = ['alpha_zscore', 'amt_zscore', 'rank_zscore', 'momentum_3m', 'momentum_5m', 'limit_up_ratio']
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 数据库连接 ====================
|
||||||
|
|
||||||
|
_mysql_engine = None
|
||||||
|
_es_client = None
|
||||||
|
_ch_client = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_mysql_engine():
|
||||||
|
global _mysql_engine
|
||||||
|
if _mysql_engine is None:
|
||||||
|
_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True)
|
||||||
|
return _mysql_engine
|
||||||
|
|
||||||
|
|
||||||
|
def get_es_client():
|
||||||
|
global _es_client
|
||||||
|
if _es_client is None:
|
||||||
|
_es_client = Elasticsearch([ES_HOST])
|
||||||
|
return _es_client
|
||||||
|
|
||||||
|
|
||||||
|
def get_ch_client():
|
||||||
|
global _ch_client
|
||||||
|
if _ch_client is None:
|
||||||
|
_ch_client = Client(**CLICKHOUSE_CONFIG)
|
||||||
|
return _ch_client
|
||||||
|
|
||||||
|
|
||||||
|
def code_to_ch_format(code: str) -> str:
|
||||||
|
if not code or len(code) != 6 or not code.isdigit():
|
||||||
|
return None
|
||||||
|
if code.startswith('6'):
|
||||||
|
return f"{code}.SH"
|
||||||
|
elif code.startswith('0') or code.startswith('3'):
|
||||||
|
return f"{code}.SZ"
|
||||||
|
return f"{code}.BJ"
|
||||||
|
|
||||||
|
|
||||||
|
def time_to_slot(ts) -> str:
|
||||||
|
if isinstance(ts, str):
|
||||||
|
return ts
|
||||||
|
return ts.strftime('%H:%M')
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 规则评分 ====================
|
||||||
|
|
||||||
|
def score_rules_zscore(features: Dict) -> Tuple[float, List[str]]:
|
||||||
|
"""基于 Z-Score 的规则评分"""
|
||||||
|
score = 0.0
|
||||||
|
triggered = []
|
||||||
|
|
||||||
|
alpha_z = abs(features.get('alpha_zscore', 0))
|
||||||
|
amt_z = features.get('amt_zscore', 0)
|
||||||
|
rank_z = abs(features.get('rank_zscore', 0))
|
||||||
|
mom_3m = features.get('momentum_3m', 0)
|
||||||
|
mom_5m = features.get('momentum_5m', 0)
|
||||||
|
limit_up = features.get('limit_up_ratio', 0)
|
||||||
|
|
||||||
|
# Alpha Z-Score
|
||||||
|
if alpha_z >= 4.0:
|
||||||
|
score += 25
|
||||||
|
triggered.append('alpha_extreme')
|
||||||
|
elif alpha_z >= 3.0:
|
||||||
|
score += 18
|
||||||
|
triggered.append('alpha_strong')
|
||||||
|
elif alpha_z >= 2.0:
|
||||||
|
score += 10
|
||||||
|
triggered.append('alpha_moderate')
|
||||||
|
|
||||||
|
# 成交额 Z-Score
|
||||||
|
if amt_z >= 4.0:
|
||||||
|
score += 20
|
||||||
|
triggered.append('amt_extreme')
|
||||||
|
elif amt_z >= 3.0:
|
||||||
|
score += 12
|
||||||
|
triggered.append('amt_strong')
|
||||||
|
elif amt_z >= 2.0:
|
||||||
|
score += 6
|
||||||
|
triggered.append('amt_moderate')
|
||||||
|
|
||||||
|
# 排名 Z-Score
|
||||||
|
if rank_z >= 3.0:
|
||||||
|
score += 15
|
||||||
|
triggered.append('rank_extreme')
|
||||||
|
elif rank_z >= 2.0:
|
||||||
|
score += 8
|
||||||
|
triggered.append('rank_strong')
|
||||||
|
|
||||||
|
# 动量(基于 Z-Score 的)
|
||||||
|
if mom_3m >= 1.0:
|
||||||
|
score += 12
|
||||||
|
triggered.append('momentum_3m_strong')
|
||||||
|
elif mom_3m >= 0.5:
|
||||||
|
score += 6
|
||||||
|
triggered.append('momentum_3m_moderate')
|
||||||
|
|
||||||
|
if mom_5m >= 1.5:
|
||||||
|
score += 10
|
||||||
|
triggered.append('momentum_5m_strong')
|
||||||
|
|
||||||
|
# 涨停比例
|
||||||
|
if limit_up >= 0.3:
|
||||||
|
score += 20
|
||||||
|
triggered.append('limit_up_extreme')
|
||||||
|
elif limit_up >= 0.15:
|
||||||
|
score += 12
|
||||||
|
triggered.append('limit_up_strong')
|
||||||
|
elif limit_up >= 0.08:
|
||||||
|
score += 5
|
||||||
|
triggered.append('limit_up_moderate')
|
||||||
|
|
||||||
|
# 组合规则
|
||||||
|
if alpha_z >= 2.0 and amt_z >= 2.0:
|
||||||
|
score += 15
|
||||||
|
triggered.append('combo_alpha_amt')
|
||||||
|
|
||||||
|
if alpha_z >= 2.0 and limit_up >= 0.1:
|
||||||
|
score += 12
|
||||||
|
triggered.append('combo_alpha_limitup')
|
||||||
|
|
||||||
|
return min(score, 100), triggered
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 实时检测器 ====================
|
||||||
|
|
||||||
|
class RealtimeDetectorV2:
|
||||||
|
"""V2 实时异动检测器"""
|
||||||
|
|
||||||
|
def __init__(self, model_dir: str = MODEL_DIR, baseline_file: str = BASELINE_FILE):
|
||||||
|
print("初始化 V2 实时检测器...")
|
||||||
|
|
||||||
|
# 加载概念
|
||||||
|
self.concepts = self._load_concepts()
|
||||||
|
self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in self.concepts}
|
||||||
|
self.all_stocks = list(set(s for c in self.concepts for s in c['stocks']))
|
||||||
|
|
||||||
|
# 加载基线
|
||||||
|
self.baselines = self._load_baselines(baseline_file)
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
self.model, self.thresholds, self.device = self._load_model(model_dir)
|
||||||
|
|
||||||
|
# 状态管理
|
||||||
|
self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))
|
||||||
|
self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))
|
||||||
|
self.cooldown = {}
|
||||||
|
|
||||||
|
print(f"初始化完成: {len(self.concepts)} 概念, {len(self.baselines)} 基线")
|
||||||
|
|
||||||
|
def _load_concepts(self) -> List[dict]:
|
||||||
|
"""从 ES 加载概念"""
|
||||||
|
es = get_es_client()
|
||||||
|
concepts = []
|
||||||
|
|
||||||
|
query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}
|
||||||
|
resp = es.search(index=ES_INDEX, body=query, scroll='2m')
|
||||||
|
scroll_id = resp['_scroll_id']
|
||||||
|
hits = resp['hits']['hits']
|
||||||
|
|
||||||
|
while hits:
|
||||||
|
for hit in hits:
|
||||||
|
src = hit['_source']
|
||||||
|
stocks = [s['code'] for s in src.get('stocks', []) if isinstance(s, dict) and s.get('code')]
|
||||||
|
if stocks:
|
||||||
|
concepts.append({
|
||||||
|
'concept_id': src.get('concept_id'),
|
||||||
|
'concept_name': src.get('concept'),
|
||||||
|
'stocks': stocks
|
||||||
|
})
|
||||||
|
resp = es.scroll(scroll_id=scroll_id, scroll='2m')
|
||||||
|
scroll_id = resp['_scroll_id']
|
||||||
|
hits = resp['hits']['hits']
|
||||||
|
|
||||||
|
es.clear_scroll(scroll_id=scroll_id)
|
||||||
|
return concepts
|
||||||
|
|
||||||
|
def _load_baselines(self, baseline_file: str) -> Dict:
|
||||||
|
"""加载基线"""
|
||||||
|
if not os.path.exists(baseline_file):
|
||||||
|
print(f"警告: 基线文件不存在: {baseline_file}")
|
||||||
|
print("请先运行: python ml/update_baseline.py")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
with open(baseline_file, 'rb') as f:
|
||||||
|
data = pickle.load(f)
|
||||||
|
|
||||||
|
print(f"基线日期范围: {data.get('date_range', 'unknown')}")
|
||||||
|
print(f"更新时间: {data.get('update_time', 'unknown')}")
|
||||||
|
|
||||||
|
return data.get('baselines', {})
|
||||||
|
|
||||||
|
def _load_model(self, model_dir: str):
|
||||||
|
"""加载模型"""
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
config_path = os.path.join(model_dir, 'config.json')
|
||||||
|
model_path = os.path.join(model_dir, 'best_model.pt')
|
||||||
|
threshold_path = os.path.join(model_dir, 'thresholds.json')
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
print(f"警告: 模型不存在: {model_path}")
|
||||||
|
return None, {}, device
|
||||||
|
|
||||||
|
with open(config_path) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
model = TransformerAutoencoder(**config['model'])
|
||||||
|
checkpoint = torch.load(model_path, map_location=device)
|
||||||
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
thresholds = {}
|
||||||
|
if os.path.exists(threshold_path):
|
||||||
|
with open(threshold_path) as f:
|
||||||
|
thresholds = json.load(f)
|
||||||
|
|
||||||
|
print(f"模型已加载: {model_path}")
|
||||||
|
return model, thresholds, device
|
||||||
|
|
||||||
|
def _get_realtime_data(self, trade_date: str) -> pd.DataFrame:
|
||||||
|
"""获取实时数据并计算原始特征"""
|
||||||
|
ch = get_ch_client()
|
||||||
|
|
||||||
|
# 获取股票数据
|
||||||
|
ch_codes = [code_to_ch_format(c) for c in self.all_stocks if code_to_ch_format(c)]
|
||||||
|
ch_codes_str = "','".join(ch_codes)
|
||||||
|
|
||||||
|
stock_query = f"""
|
||||||
|
SELECT code, timestamp, close, amt
|
||||||
|
FROM stock_minute
|
||||||
|
WHERE toDate(timestamp) = '{trade_date}'
|
||||||
|
AND code IN ('{ch_codes_str}')
|
||||||
|
ORDER BY timestamp
|
||||||
|
"""
|
||||||
|
stock_result = ch.execute(stock_query)
|
||||||
|
if not stock_result:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
stock_df = pd.DataFrame(stock_result, columns=['ch_code', 'timestamp', 'close', 'amt'])
|
||||||
|
|
||||||
|
# 映射回原始代码
|
||||||
|
ch_to_code = {code_to_ch_format(c): c for c in self.all_stocks if code_to_ch_format(c)}
|
||||||
|
stock_df['code'] = stock_df['ch_code'].map(ch_to_code)
|
||||||
|
stock_df = stock_df.dropna(subset=['code'])
|
||||||
|
|
||||||
|
# 获取指数数据
|
||||||
|
index_query = f"""
|
||||||
|
SELECT timestamp, close
|
||||||
|
FROM index_minute
|
||||||
|
WHERE toDate(timestamp) = '{trade_date}'
|
||||||
|
AND code = '{REFERENCE_INDEX}'
|
||||||
|
ORDER BY timestamp
|
||||||
|
"""
|
||||||
|
index_result = ch.execute(index_query)
|
||||||
|
if not index_result:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
index_df = pd.DataFrame(index_result, columns=['timestamp', 'close'])
|
||||||
|
|
||||||
|
# 获取昨收价
|
||||||
|
engine = get_mysql_engine()
|
||||||
|
codes_str = "','".join([c for c in self.all_stocks if c and len(c) == 6])
|
||||||
|
|
||||||
|
with engine.connect() as conn:
|
||||||
|
prev_result = conn.execute(text(f"""
|
||||||
|
SELECT SECCODE, F007N FROM ea_trade
|
||||||
|
WHERE SECCODE IN ('{codes_str}')
|
||||||
|
AND TRADEDATE = (SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}')
|
||||||
|
AND F007N > 0
|
||||||
|
"""))
|
||||||
|
prev_close = {row[0]: float(row[1]) for row in prev_result if row[1]}
|
||||||
|
|
||||||
|
idx_result = conn.execute(text("""
|
||||||
|
SELECT F006N FROM ea_exchangetrade
|
||||||
|
WHERE INDEXCODE = '000001' AND TRADEDATE < :today
|
||||||
|
ORDER BY TRADEDATE DESC LIMIT 1
|
||||||
|
"""), {'today': trade_date}).fetchone()
|
||||||
|
index_prev_close = float(idx_result[0]) if idx_result else None
|
||||||
|
|
||||||
|
if not prev_close or not index_prev_close:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
# 计算涨跌幅
|
||||||
|
stock_df['prev_close'] = stock_df['code'].map(prev_close)
|
||||||
|
stock_df = stock_df.dropna(subset=['prev_close'])
|
||||||
|
stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
|
||||||
|
|
||||||
|
index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
|
||||||
|
index_map = dict(zip(index_df['timestamp'], index_df['change_pct']))
|
||||||
|
|
||||||
|
# 按时间聚合概念特征
|
||||||
|
results = []
|
||||||
|
for ts in sorted(stock_df['timestamp'].unique()):
|
||||||
|
ts_data = stock_df[stock_df['timestamp'] == ts]
|
||||||
|
idx_chg = index_map.get(ts, 0)
|
||||||
|
|
||||||
|
stock_chg = dict(zip(ts_data['code'], ts_data['change_pct']))
|
||||||
|
stock_amt = dict(zip(ts_data['code'], ts_data['amt']))
|
||||||
|
|
||||||
|
for cid, stocks in self.concept_stocks.items():
|
||||||
|
changes = [stock_chg[s] for s in stocks if s in stock_chg]
|
||||||
|
amts = [stock_amt.get(s, 0) for s in stocks if s in stock_chg]
|
||||||
|
|
||||||
|
if not changes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
alpha = np.mean(changes) - idx_chg
|
||||||
|
total_amt = sum(amts)
|
||||||
|
limit_up_ratio = sum(1 for c in changes if c >= CONFIG['limit_up_threshold']) / len(changes)
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
'concept_id': cid,
|
||||||
|
'timestamp': ts,
|
||||||
|
'time_slot': time_to_slot(ts),
|
||||||
|
'alpha': alpha,
|
||||||
|
'total_amt': total_amt,
|
||||||
|
'limit_up_ratio': limit_up_ratio,
|
||||||
|
'stock_count': len(changes),
|
||||||
|
})
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
df = pd.DataFrame(results)
|
||||||
|
|
||||||
|
# 计算排名
|
||||||
|
for ts in df['timestamp'].unique():
|
||||||
|
mask = df['timestamp'] == ts
|
||||||
|
df.loc[mask, 'rank_pct'] = df.loc[mask, 'alpha'].rank(pct=True)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
def _compute_zscore(self, concept_id: str, time_slot: str, alpha: float, total_amt: float, rank_pct: float) -> Optional[Dict]:
|
||||||
|
"""计算 Z-Score"""
|
||||||
|
if concept_id not in self.baselines:
|
||||||
|
return None
|
||||||
|
|
||||||
|
baseline = self.baselines[concept_id]
|
||||||
|
if time_slot not in baseline:
|
||||||
|
return None
|
||||||
|
|
||||||
|
bl = baseline[time_slot]
|
||||||
|
|
||||||
|
alpha_z = np.clip((alpha - bl['alpha_mean']) / bl['alpha_std'], -5, 5)
|
||||||
|
amt_z = np.clip((total_amt - bl['amt_mean']) / bl['amt_std'], -5, 5)
|
||||||
|
rank_z = np.clip((rank_pct - bl['rank_mean']) / bl['rank_std'], -5, 5)
|
||||||
|
|
||||||
|
# 动量(基于 Z-Score 历史)
|
||||||
|
history = list(self.zscore_history[concept_id])
|
||||||
|
mom_3m = 0.0
|
||||||
|
mom_5m = 0.0
|
||||||
|
|
||||||
|
if len(history) >= 3:
|
||||||
|
recent = [h['alpha_zscore'] for h in history[-3:]]
|
||||||
|
older = [h['alpha_zscore'] for h in history[-6:-3]] if len(history) >= 6 else [history[0]['alpha_zscore']]
|
||||||
|
mom_3m = np.mean(recent) - np.mean(older)
|
||||||
|
|
||||||
|
if len(history) >= 5:
|
||||||
|
recent = [h['alpha_zscore'] for h in history[-5:]]
|
||||||
|
older = [h['alpha_zscore'] for h in history[-10:-5]] if len(history) >= 10 else [history[0]['alpha_zscore']]
|
||||||
|
mom_5m = np.mean(recent) - np.mean(older)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'alpha_zscore': float(alpha_z),
|
||||||
|
'amt_zscore': float(amt_z),
|
||||||
|
'rank_zscore': float(rank_z),
|
||||||
|
'momentum_3m': float(mom_3m),
|
||||||
|
'momentum_5m': float(mom_5m),
|
||||||
|
}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _ml_score(self, sequences: np.ndarray) -> np.ndarray:
|
||||||
|
"""批量 ML 评分"""
|
||||||
|
if self.model is None or len(sequences) == 0:
|
||||||
|
return np.zeros(len(sequences))
|
||||||
|
|
||||||
|
x = torch.FloatTensor(sequences).to(self.device)
|
||||||
|
errors = self.model.compute_reconstruction_error(x, reduction='none')
|
||||||
|
last_errors = errors[:, -1].cpu().numpy()
|
||||||
|
|
||||||
|
# 转换为 0-100 分数
|
||||||
|
if self.thresholds:
|
||||||
|
p50 = self.thresholds.get('median', 0.001)
|
||||||
|
p99 = self.thresholds.get('p99', 0.05)
|
||||||
|
scores = 50 + (last_errors - p50) / (p99 - p50 + 1e-6) * 49
|
||||||
|
else:
|
||||||
|
scores = last_errors * 1000
|
||||||
|
|
||||||
|
return np.clip(scores, 0, 100)
|
||||||
|
|
||||||
|
def detect(self, trade_date: str = None) -> List[Dict]:
|
||||||
|
"""检测指定日期的异动"""
|
||||||
|
trade_date = trade_date or datetime.now().strftime('%Y-%m-%d')
|
||||||
|
print(f"\n检测 {trade_date} 的异动...")
|
||||||
|
|
||||||
|
# 重置状态
|
||||||
|
self.zscore_history.clear()
|
||||||
|
self.anomaly_candidates.clear()
|
||||||
|
self.cooldown.clear()
|
||||||
|
|
||||||
|
# 获取数据
|
||||||
|
raw_df = self._get_realtime_data(trade_date)
|
||||||
|
if raw_df.empty:
|
||||||
|
print("无数据")
|
||||||
|
return []
|
||||||
|
|
||||||
|
timestamps = sorted(raw_df['timestamp'].unique())
|
||||||
|
print(f"时间点数: {len(timestamps)}")
|
||||||
|
|
||||||
|
all_alerts = []
|
||||||
|
|
||||||
|
for ts in timestamps:
|
||||||
|
ts_data = raw_df[raw_df['timestamp'] == ts]
|
||||||
|
time_slot = time_to_slot(ts)
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
|
||||||
|
# 计算每个概念的 Z-Score
|
||||||
|
for _, row in ts_data.iterrows():
|
||||||
|
cid = row['concept_id']
|
||||||
|
|
||||||
|
zscore = self._compute_zscore(
|
||||||
|
cid, time_slot,
|
||||||
|
row['alpha'], row['total_amt'], row['rank_pct']
|
||||||
|
)
|
||||||
|
|
||||||
|
if zscore is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 完整特征
|
||||||
|
features = {
|
||||||
|
**zscore,
|
||||||
|
'alpha': row['alpha'],
|
||||||
|
'limit_up_ratio': row['limit_up_ratio'],
|
||||||
|
'total_amt': row['total_amt'],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 更新历史
|
||||||
|
self.zscore_history[cid].append(zscore)
|
||||||
|
|
||||||
|
# 规则评分
|
||||||
|
rule_score, triggered = score_rules_zscore(features)
|
||||||
|
|
||||||
|
candidates.append((cid, features, rule_score, triggered))
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 批量 ML 评分
|
||||||
|
sequences = []
|
||||||
|
valid_candidates = []
|
||||||
|
|
||||||
|
for cid, features, rule_score, triggered in candidates:
|
||||||
|
history = list(self.zscore_history[cid])
|
||||||
|
if len(history) >= CONFIG['seq_len']:
|
||||||
|
seq = np.array([[h['alpha_zscore'], h['amt_zscore'], h['rank_zscore'],
|
||||||
|
h['momentum_3m'], h['momentum_5m'], features['limit_up_ratio']]
|
||||||
|
for h in history])
|
||||||
|
sequences.append(seq)
|
||||||
|
valid_candidates.append((cid, features, rule_score, triggered))
|
||||||
|
|
||||||
|
if not sequences:
|
||||||
|
continue
|
||||||
|
|
||||||
|
ml_scores = self._ml_score(np.array(sequences))
|
||||||
|
|
||||||
|
# 融合 + 确认
|
||||||
|
for i, (cid, features, rule_score, triggered) in enumerate(valid_candidates):
|
||||||
|
ml_score = ml_scores[i]
|
||||||
|
final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score
|
||||||
|
|
||||||
|
# 判断触发
|
||||||
|
is_triggered = (
|
||||||
|
rule_score >= CONFIG['rule_trigger'] or
|
||||||
|
ml_score >= CONFIG['ml_trigger'] or
|
||||||
|
final_score >= CONFIG['fusion_trigger']
|
||||||
|
)
|
||||||
|
|
||||||
|
self.anomaly_candidates[cid].append((ts, final_score))
|
||||||
|
|
||||||
|
if not is_triggered:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 冷却期
|
||||||
|
if cid in self.cooldown:
|
||||||
|
if (ts - self.cooldown[cid]).total_seconds() < CONFIG['cooldown_minutes'] * 60:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 持续确认
|
||||||
|
recent = list(self.anomaly_candidates[cid])
|
||||||
|
if len(recent) < CONFIG['confirm_window']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
exceed = sum(1 for _, s in recent if s >= CONFIG['fusion_trigger'])
|
||||||
|
ratio = exceed / len(recent)
|
||||||
|
|
||||||
|
if ratio < CONFIG['confirm_ratio']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 确认异动!
|
||||||
|
self.cooldown[cid] = ts
|
||||||
|
|
||||||
|
alpha = features['alpha']
|
||||||
|
alert_type = 'surge_up' if alpha >= 1.5 else 'surge_down' if alpha <= -1.5 else 'surge'
|
||||||
|
|
||||||
|
concept_name = next((c['concept_name'] for c in self.concepts if c['concept_id'] == cid), cid)
|
||||||
|
|
||||||
|
all_alerts.append({
|
||||||
|
'concept_id': cid,
|
||||||
|
'concept_name': concept_name,
|
||||||
|
'alert_time': ts,
|
||||||
|
'trade_date': trade_date,
|
||||||
|
'alert_type': alert_type,
|
||||||
|
'final_score': float(final_score),
|
||||||
|
'rule_score': float(rule_score),
|
||||||
|
'ml_score': float(ml_score),
|
||||||
|
'confirm_ratio': float(ratio),
|
||||||
|
'alpha': float(alpha),
|
||||||
|
'alpha_zscore': float(features['alpha_zscore']),
|
||||||
|
'amt_zscore': float(features['amt_zscore']),
|
||||||
|
'rank_zscore': float(features['rank_zscore']),
|
||||||
|
'momentum_3m': float(features['momentum_3m']),
|
||||||
|
'momentum_5m': float(features['momentum_5m']),
|
||||||
|
'limit_up_ratio': float(features['limit_up_ratio']),
|
||||||
|
'triggered_rules': triggered,
|
||||||
|
'trigger_reason': f"融合({final_score:.0f})+确认({ratio:.0%})",
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f"检测到 {len(all_alerts)} 个异动")
|
||||||
|
return all_alerts
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 数据库存储 ====================
|
||||||
|
|
||||||
|
def create_v2_table():
|
||||||
|
"""创建 V2 异动表(如果不存在)"""
|
||||||
|
engine = get_mysql_engine()
|
||||||
|
with engine.begin() as conn:
|
||||||
|
conn.execute(text("""
|
||||||
|
CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
|
||||||
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
concept_id VARCHAR(50) NOT NULL,
|
||||||
|
alert_time DATETIME NOT NULL,
|
||||||
|
trade_date DATE NOT NULL,
|
||||||
|
alert_type VARCHAR(20) NOT NULL,
|
||||||
|
final_score FLOAT,
|
||||||
|
rule_score FLOAT,
|
||||||
|
ml_score FLOAT,
|
||||||
|
trigger_reason VARCHAR(200),
|
||||||
|
confirm_ratio FLOAT,
|
||||||
|
alpha FLOAT,
|
||||||
|
alpha_zscore FLOAT,
|
||||||
|
amt_zscore FLOAT,
|
||||||
|
rank_zscore FLOAT,
|
||||||
|
momentum_3m FLOAT,
|
||||||
|
momentum_5m FLOAT,
|
||||||
|
limit_up_ratio FLOAT,
|
||||||
|
triggered_rules TEXT,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
UNIQUE KEY uk_concept_time (concept_id, alert_time),
|
||||||
|
INDEX idx_trade_date (trade_date),
|
||||||
|
INDEX idx_alert_type (alert_type)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||||
|
"""))
|
||||||
|
print("concept_anomaly_v2 表已就绪")
|
||||||
|
|
||||||
|
|
||||||
|
def save_alerts_to_db(alerts: List[Dict]) -> int:
|
||||||
|
"""保存异动到数据库"""
|
||||||
|
if not alerts:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
engine = get_mysql_engine()
|
||||||
|
saved = 0
|
||||||
|
|
||||||
|
with engine.begin() as conn:
|
||||||
|
for alert in alerts:
|
||||||
|
try:
|
||||||
|
insert_sql = text("""
|
||||||
|
INSERT IGNORE INTO concept_anomaly_v2
|
||||||
|
(concept_id, alert_time, trade_date, alert_type,
|
||||||
|
final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
|
||||||
|
alpha, alpha_zscore, amt_zscore, rank_zscore,
|
||||||
|
momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
|
||||||
|
VALUES
|
||||||
|
(:concept_id, :alert_time, :trade_date, :alert_type,
|
||||||
|
:final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
|
||||||
|
:alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
|
||||||
|
:momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
|
||||||
|
""")
|
||||||
|
|
||||||
|
result = conn.execute(insert_sql, {
|
||||||
|
'concept_id': alert['concept_id'],
|
||||||
|
'alert_time': alert['alert_time'],
|
||||||
|
'trade_date': alert['trade_date'],
|
||||||
|
'alert_type': alert['alert_type'],
|
||||||
|
'final_score': alert['final_score'],
|
||||||
|
'rule_score': alert['rule_score'],
|
||||||
|
'ml_score': alert['ml_score'],
|
||||||
|
'trigger_reason': alert['trigger_reason'],
|
||||||
|
'confirm_ratio': alert['confirm_ratio'],
|
||||||
|
'alpha': alert['alpha'],
|
||||||
|
'alpha_zscore': alert['alpha_zscore'],
|
||||||
|
'amt_zscore': alert['amt_zscore'],
|
||||||
|
'rank_zscore': alert['rank_zscore'],
|
||||||
|
'momentum_3m': alert['momentum_3m'],
|
||||||
|
'momentum_5m': alert['momentum_5m'],
|
||||||
|
'limit_up_ratio': alert['limit_up_ratio'],
|
||||||
|
'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False),
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.rowcount > 0:
|
||||||
|
saved += 1
|
||||||
|
except Exception as e:
|
||||||
|
print(f"保存失败: {alert['concept_id']} - {e}")
|
||||||
|
|
||||||
|
return saved
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--date', type=str, default=None)
|
||||||
|
parser.add_argument('--no-save', action='store_true', help='不保存到数据库,只打印')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 确保表存在
|
||||||
|
if not args.no_save:
|
||||||
|
create_v2_table()
|
||||||
|
|
||||||
|
detector = RealtimeDetectorV2()
|
||||||
|
alerts = detector.detect(args.date)
|
||||||
|
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"检测结果 ({len(alerts)} 个异动)")
|
||||||
|
print('='*60)
|
||||||
|
|
||||||
|
for a in alerts[:20]:
|
||||||
|
print(f"[{a['alert_time'].strftime('%H:%M') if hasattr(a['alert_time'], 'strftime') else a['alert_time']}] "
|
||||||
|
f"{a['concept_name']} | {a['alert_type']} | "
|
||||||
|
f"分数={a['final_score']:.0f} 确认={a['confirm_ratio']:.0%} "
|
||||||
|
f"α={a['alpha']:.2f}% αZ={a['alpha_zscore']:.1f}")
|
||||||
|
|
||||||
|
if len(alerts) > 20:
|
||||||
|
print(f"... 共 {len(alerts)} 个")
|
||||||
|
|
||||||
|
# 保存到数据库
|
||||||
|
if not args.no_save and alerts:
|
||||||
|
saved = save_alerts_to_db(alerts)
|
||||||
|
print(f"\n✅ 已保存 {saved}/{len(alerts)} 条到 concept_anomaly_v2 表")
|
||||||
|
elif args.no_save:
|
||||||
|
print(f"\n⚠️ --no-save 模式,未保存到数据库")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
622
ml/train_v2.py
Normal file
622
ml/train_v2.py
Normal file
@@ -0,0 +1,622 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
训练脚本 V2 - 基于 Z-Score 特征的 LSTM Autoencoder
|
||||||
|
|
||||||
|
改进点:
|
||||||
|
1. 使用 Z-Score 特征(相对于同时间片历史的偏离)
|
||||||
|
2. 短序列:10分钟(不需要30分钟预热)
|
||||||
|
3. 开盘即可检测:9:30 直接有特征
|
||||||
|
|
||||||
|
模型输入:
|
||||||
|
- 过去10分钟的 Z-Score 特征序列
|
||||||
|
- 特征:alpha_zscore, amt_zscore, rank_zscore, momentum_3m, momentum_5m, limit_up_ratio
|
||||||
|
|
||||||
|
模型学习:
|
||||||
|
- 学习 Z-Score 序列的"正常演化模式"
|
||||||
|
- 异动 = Z-Score 序列的异常演化(重构误差大)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple, Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torch.optim import AdamW
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters
|
||||||
|
|
||||||
|
# 性能优化
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
HAS_MATPLOTLIB = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_MATPLOTLIB = False
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 配置 ====================
|
||||||
|
|
||||||
|
TRAIN_CONFIG = {
|
||||||
|
# 数据配置(改进!)
|
||||||
|
'seq_len': 10, # 10分钟序列(不是30分钟!)
|
||||||
|
'stride': 2, # 步长2分钟
|
||||||
|
|
||||||
|
# 时间切分
|
||||||
|
'train_end_date': '2024-06-30',
|
||||||
|
'val_end_date': '2024-09-30',
|
||||||
|
|
||||||
|
# V2 特征(Z-Score 为主)
|
||||||
|
'features': [
|
||||||
|
'alpha_zscore', # Alpha 的 Z-Score
|
||||||
|
'amt_zscore', # 成交额的 Z-Score
|
||||||
|
'rank_zscore', # 排名的 Z-Score
|
||||||
|
'momentum_3m', # 3分钟动量
|
||||||
|
'momentum_5m', # 5分钟动量
|
||||||
|
'limit_up_ratio', # 涨停占比
|
||||||
|
],
|
||||||
|
|
||||||
|
# 训练配置
|
||||||
|
'batch_size': 4096,
|
||||||
|
'epochs': 100,
|
||||||
|
'learning_rate': 3e-4,
|
||||||
|
'weight_decay': 1e-5,
|
||||||
|
'gradient_clip': 1.0,
|
||||||
|
|
||||||
|
# 早停配置
|
||||||
|
'patience': 15,
|
||||||
|
'min_delta': 1e-6,
|
||||||
|
|
||||||
|
# 模型配置(小型 LSTM)
|
||||||
|
'model': {
|
||||||
|
'n_features': 6,
|
||||||
|
'hidden_dim': 32,
|
||||||
|
'latent_dim': 4,
|
||||||
|
'num_layers': 1,
|
||||||
|
'dropout': 0.2,
|
||||||
|
'bidirectional': True,
|
||||||
|
},
|
||||||
|
|
||||||
|
# 标准化配置
|
||||||
|
'clip_value': 5.0, # Z-Score 已经标准化,clip 5.0 足够
|
||||||
|
|
||||||
|
# 阈值配置
|
||||||
|
'threshold_percentiles': [90, 95, 99],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 数据加载 ====================
|
||||||
|
|
||||||
|
def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
|
||||||
|
"""按日期加载 V2 数据"""
|
||||||
|
data_path = Path(data_dir)
|
||||||
|
parquet_files = sorted(data_path.glob("features_v2_*.parquet"))
|
||||||
|
|
||||||
|
if not parquet_files:
|
||||||
|
raise FileNotFoundError(f"未找到 V2 数据文件: {data_dir}")
|
||||||
|
|
||||||
|
print(f"找到 {len(parquet_files)} 个 V2 数据文件")
|
||||||
|
|
||||||
|
date_data = {}
|
||||||
|
|
||||||
|
for pf in tqdm(parquet_files, desc="加载数据"):
|
||||||
|
date = pf.stem.replace('features_v2_', '')
|
||||||
|
|
||||||
|
df = pd.read_parquet(pf)
|
||||||
|
|
||||||
|
required_cols = features + ['concept_id', 'timestamp']
|
||||||
|
missing_cols = [c for c in required_cols if c not in df.columns]
|
||||||
|
if missing_cols:
|
||||||
|
print(f"警告: {date} 缺少列: {missing_cols}, 跳过")
|
||||||
|
continue
|
||||||
|
|
||||||
|
date_data[date] = df
|
||||||
|
|
||||||
|
print(f"成功加载 {len(date_data)} 天的数据")
|
||||||
|
return date_data
|
||||||
|
|
||||||
|
|
||||||
|
def split_data_by_date(
|
||||||
|
date_data: Dict[str, pd.DataFrame],
|
||||||
|
train_end: str,
|
||||||
|
val_end: str
|
||||||
|
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
|
||||||
|
"""按日期划分数据集"""
|
||||||
|
train_data = {}
|
||||||
|
val_data = {}
|
||||||
|
test_data = {}
|
||||||
|
|
||||||
|
for date, df in date_data.items():
|
||||||
|
if date <= train_end:
|
||||||
|
train_data[date] = df
|
||||||
|
elif date <= val_end:
|
||||||
|
val_data[date] = df
|
||||||
|
else:
|
||||||
|
test_data[date] = df
|
||||||
|
|
||||||
|
print(f"数据集划分:")
|
||||||
|
print(f" 训练集: {len(train_data)} 天 (<= {train_end})")
|
||||||
|
print(f" 验证集: {len(val_data)} 天 ({train_end} ~ {val_end})")
|
||||||
|
print(f" 测试集: {len(test_data)} 天 (> {val_end})")
|
||||||
|
|
||||||
|
return train_data, val_data, test_data
|
||||||
|
|
||||||
|
|
||||||
|
def build_sequences_by_concept(
|
||||||
|
date_data: Dict[str, pd.DataFrame],
|
||||||
|
features: List[str],
|
||||||
|
seq_len: int,
|
||||||
|
stride: int
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""按概念分组构建序列"""
|
||||||
|
all_dfs = []
|
||||||
|
for date, df in sorted(date_data.items()):
|
||||||
|
df = df.copy()
|
||||||
|
df['date'] = date
|
||||||
|
all_dfs.append(df)
|
||||||
|
|
||||||
|
if not all_dfs:
|
||||||
|
return np.array([])
|
||||||
|
|
||||||
|
combined = pd.concat(all_dfs, ignore_index=True)
|
||||||
|
combined = combined.sort_values(['concept_id', 'date', 'timestamp'])
|
||||||
|
|
||||||
|
all_sequences = []
|
||||||
|
grouped = combined.groupby('concept_id', sort=False)
|
||||||
|
n_concepts = len(grouped)
|
||||||
|
|
||||||
|
for concept_id, concept_df in tqdm(grouped, desc="构建序列", total=n_concepts, leave=False):
|
||||||
|
feature_data = concept_df[features].values
|
||||||
|
feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0)
|
||||||
|
|
||||||
|
n_points = len(feature_data)
|
||||||
|
for start in range(0, n_points - seq_len + 1, stride):
|
||||||
|
seq = feature_data[start:start + seq_len]
|
||||||
|
all_sequences.append(seq)
|
||||||
|
|
||||||
|
if not all_sequences:
|
||||||
|
return np.array([])
|
||||||
|
|
||||||
|
sequences = np.array(all_sequences)
|
||||||
|
print(f" 构建序列: {len(sequences):,} 条 (来自 {n_concepts} 个概念)")
|
||||||
|
|
||||||
|
return sequences
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 数据集 ====================
|
||||||
|
|
||||||
|
class SequenceDataset(Dataset):
|
||||||
|
def __init__(self, sequences: np.ndarray):
|
||||||
|
self.sequences = torch.FloatTensor(sequences)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.sequences)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> torch.Tensor:
|
||||||
|
return self.sequences[idx]
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 训练器 ====================
|
||||||
|
|
||||||
|
class EarlyStopping:
|
||||||
|
def __init__(self, patience: int = 10, min_delta: float = 1e-6):
|
||||||
|
self.patience = patience
|
||||||
|
self.min_delta = min_delta
|
||||||
|
self.counter = 0
|
||||||
|
self.best_loss = float('inf')
|
||||||
|
self.early_stop = False
|
||||||
|
|
||||||
|
def __call__(self, val_loss: float) -> bool:
|
||||||
|
if val_loss < self.best_loss - self.min_delta:
|
||||||
|
self.best_loss = val_loss
|
||||||
|
self.counter = 0
|
||||||
|
else:
|
||||||
|
self.counter += 1
|
||||||
|
if self.counter >= self.patience:
|
||||||
|
self.early_stop = True
|
||||||
|
return self.early_stop
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
train_loader: DataLoader,
|
||||||
|
val_loader: DataLoader,
|
||||||
|
config: Dict,
|
||||||
|
device: torch.device,
|
||||||
|
save_dir: str = 'ml/checkpoints_v2'
|
||||||
|
):
|
||||||
|
self.model = model.to(device)
|
||||||
|
self.train_loader = train_loader
|
||||||
|
self.val_loader = val_loader
|
||||||
|
self.config = config
|
||||||
|
self.device = device
|
||||||
|
self.save_dir = Path(save_dir)
|
||||||
|
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self.optimizer = AdamW(
|
||||||
|
model.parameters(),
|
||||||
|
lr=config['learning_rate'],
|
||||||
|
weight_decay=config['weight_decay']
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scheduler = CosineAnnealingWarmRestarts(
|
||||||
|
self.optimizer, T_0=10, T_mult=2, eta_min=1e-6
|
||||||
|
)
|
||||||
|
|
||||||
|
self.criterion = AnomalyDetectionLoss()
|
||||||
|
|
||||||
|
self.early_stopping = EarlyStopping(
|
||||||
|
patience=config['patience'],
|
||||||
|
min_delta=config['min_delta']
|
||||||
|
)
|
||||||
|
|
||||||
|
self.use_amp = torch.cuda.is_available()
|
||||||
|
self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
|
||||||
|
if self.use_amp:
|
||||||
|
print(" ✓ 启用 AMP 混合精度训练")
|
||||||
|
|
||||||
|
self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
|
||||||
|
self.best_val_loss = float('inf')
|
||||||
|
|
||||||
|
def train_epoch(self) -> float:
|
||||||
|
self.model.train()
|
||||||
|
total_loss = 0.0
|
||||||
|
n_batches = 0
|
||||||
|
|
||||||
|
pbar = tqdm(self.train_loader, desc="Training", leave=False)
|
||||||
|
for batch in pbar:
|
||||||
|
batch = batch.to(self.device, non_blocking=True)
|
||||||
|
self.optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
if self.use_amp:
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
output, latent = self.model(batch)
|
||||||
|
loss, _ = self.criterion(output, batch, latent)
|
||||||
|
|
||||||
|
self.scaler.scale(loss).backward()
|
||||||
|
self.scaler.unscale_(self.optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
|
||||||
|
self.scaler.step(self.optimizer)
|
||||||
|
self.scaler.update()
|
||||||
|
else:
|
||||||
|
output, latent = self.model(batch)
|
||||||
|
loss, _ = self.criterion(output, batch, latent)
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
n_batches += 1
|
||||||
|
pbar.set_postfix({'loss': f"{loss.item():.4f}"})
|
||||||
|
|
||||||
|
return total_loss / n_batches
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def validate(self) -> float:
|
||||||
|
self.model.eval()
|
||||||
|
total_loss = 0.0
|
||||||
|
n_batches = 0
|
||||||
|
|
||||||
|
for batch in self.val_loader:
|
||||||
|
batch = batch.to(self.device, non_blocking=True)
|
||||||
|
|
||||||
|
if self.use_amp:
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
output, latent = self.model(batch)
|
||||||
|
loss, _ = self.criterion(output, batch, latent)
|
||||||
|
else:
|
||||||
|
output, latent = self.model(batch)
|
||||||
|
loss, _ = self.criterion(output, batch, latent)
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
n_batches += 1
|
||||||
|
|
||||||
|
return total_loss / n_batches
|
||||||
|
|
||||||
|
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
|
||||||
|
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
|
||||||
|
|
||||||
|
checkpoint = {
|
||||||
|
'epoch': epoch,
|
||||||
|
'model_state_dict': model_to_save.state_dict(),
|
||||||
|
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||||
|
'scheduler_state_dict': self.scheduler.state_dict(),
|
||||||
|
'val_loss': val_loss,
|
||||||
|
'config': self.config,
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')
|
||||||
|
|
||||||
|
if is_best:
|
||||||
|
torch.save(checkpoint, self.save_dir / 'best_model.pt')
|
||||||
|
print(f" ✓ 保存最佳模型 (val_loss: {val_loss:.6f})")
|
||||||
|
|
||||||
|
def train(self, epochs: int):
|
||||||
|
print(f"\n开始训练 ({epochs} epochs)...")
|
||||||
|
print(f"设备: {self.device}")
|
||||||
|
print(f"模型参数量: {count_parameters(self.model):,}")
|
||||||
|
|
||||||
|
for epoch in range(1, epochs + 1):
|
||||||
|
print(f"\nEpoch {epoch}/{epochs}")
|
||||||
|
|
||||||
|
train_loss = self.train_epoch()
|
||||||
|
val_loss = self.validate()
|
||||||
|
|
||||||
|
self.scheduler.step()
|
||||||
|
current_lr = self.optimizer.param_groups[0]['lr']
|
||||||
|
|
||||||
|
self.history['train_loss'].append(train_loss)
|
||||||
|
self.history['val_loss'].append(val_loss)
|
||||||
|
self.history['learning_rate'].append(current_lr)
|
||||||
|
|
||||||
|
print(f" Train Loss: {train_loss:.6f}")
|
||||||
|
print(f" Val Loss: {val_loss:.6f}")
|
||||||
|
print(f" LR: {current_lr:.2e}")
|
||||||
|
|
||||||
|
is_best = val_loss < self.best_val_loss
|
||||||
|
if is_best:
|
||||||
|
self.best_val_loss = val_loss
|
||||||
|
self.save_checkpoint(epoch, val_loss, is_best)
|
||||||
|
|
||||||
|
if self.early_stopping(val_loss):
|
||||||
|
print(f"\n早停触发!")
|
||||||
|
break
|
||||||
|
|
||||||
|
print(f"\n训练完成!最佳验证损失: {self.best_val_loss:.6f}")
|
||||||
|
self.save_history()
|
||||||
|
|
||||||
|
return self.history
|
||||||
|
|
||||||
|
def save_history(self):
|
||||||
|
history_path = self.save_dir / 'training_history.json'
|
||||||
|
with open(history_path, 'w') as f:
|
||||||
|
json.dump(self.history, f, indent=2)
|
||||||
|
print(f"训练历史已保存: {history_path}")
|
||||||
|
|
||||||
|
if HAS_MATPLOTLIB:
|
||||||
|
self.plot_training_curves()
|
||||||
|
|
||||||
|
def plot_training_curves(self):
|
||||||
|
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||||||
|
epochs = range(1, len(self.history['train_loss']) + 1)
|
||||||
|
|
||||||
|
ax1 = axes[0]
|
||||||
|
ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
|
||||||
|
ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
|
||||||
|
ax1.set_xlabel('Epoch')
|
||||||
|
ax1.set_ylabel('Loss')
|
||||||
|
ax1.set_title('Training & Validation Loss (V2)')
|
||||||
|
ax1.legend()
|
||||||
|
ax1.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
best_epoch = np.argmin(self.history['val_loss']) + 1
|
||||||
|
best_val_loss = min(self.history['val_loss'])
|
||||||
|
ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7)
|
||||||
|
ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)
|
||||||
|
|
||||||
|
ax2 = axes[1]
|
||||||
|
ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
|
||||||
|
ax2.set_xlabel('Epoch')
|
||||||
|
ax2.set_ylabel('Learning Rate')
|
||||||
|
ax2.set_title('Learning Rate Schedule')
|
||||||
|
ax2.set_yscale('log')
|
||||||
|
ax2.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(self.save_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
print(f"训练曲线已保存")
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 阈值计算 ====================
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def compute_thresholds(
|
||||||
|
model: nn.Module,
|
||||||
|
data_loader: DataLoader,
|
||||||
|
device: torch.device,
|
||||||
|
percentiles: List[float] = [90, 95, 99]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""在验证集上计算阈值"""
|
||||||
|
model.eval()
|
||||||
|
all_errors = []
|
||||||
|
|
||||||
|
print("计算异动阈值...")
|
||||||
|
for batch in tqdm(data_loader, desc="Computing thresholds"):
|
||||||
|
batch = batch.to(device)
|
||||||
|
errors = model.compute_reconstruction_error(batch, reduction='none')
|
||||||
|
seq_errors = errors[:, -1] # 最后一个时刻
|
||||||
|
all_errors.append(seq_errors.cpu().numpy())
|
||||||
|
|
||||||
|
all_errors = np.concatenate(all_errors)
|
||||||
|
|
||||||
|
thresholds = {}
|
||||||
|
for p in percentiles:
|
||||||
|
threshold = np.percentile(all_errors, p)
|
||||||
|
thresholds[f'p{p}'] = float(threshold)
|
||||||
|
print(f" P{p}: {threshold:.6f}")
|
||||||
|
|
||||||
|
thresholds['mean'] = float(np.mean(all_errors))
|
||||||
|
thresholds['std'] = float(np.std(all_errors))
|
||||||
|
thresholds['median'] = float(np.median(all_errors))
|
||||||
|
|
||||||
|
return thresholds
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 主函数 ====================
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='训练 V2 模型')
|
||||||
|
parser.add_argument('--data_dir', type=str, default='ml/data_v2', help='V2 数据目录')
|
||||||
|
parser.add_argument('--epochs', type=int, default=100)
|
||||||
|
parser.add_argument('--batch_size', type=int, default=4096)
|
||||||
|
parser.add_argument('--lr', type=float, default=3e-4)
|
||||||
|
parser.add_argument('--device', type=str, default='auto')
|
||||||
|
parser.add_argument('--save_dir', type=str, default='ml/checkpoints_v2')
|
||||||
|
parser.add_argument('--train_end', type=str, default='2024-06-30')
|
||||||
|
parser.add_argument('--val_end', type=str, default='2024-09-30')
|
||||||
|
parser.add_argument('--seq_len', type=int, default=10, help='序列长度(分钟)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
config = TRAIN_CONFIG.copy()
|
||||||
|
config['batch_size'] = args.batch_size
|
||||||
|
config['epochs'] = args.epochs
|
||||||
|
config['learning_rate'] = args.lr
|
||||||
|
config['train_end_date'] = args.train_end
|
||||||
|
config['val_end_date'] = args.val_end
|
||||||
|
config['seq_len'] = args.seq_len
|
||||||
|
|
||||||
|
if args.device == 'auto':
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
else:
|
||||||
|
device = torch.device(args.device)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("概念异动检测模型训练 V2(Z-Score 特征)")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"数据目录: {args.data_dir}")
|
||||||
|
print(f"设备: {device}")
|
||||||
|
print(f"序列长度: {config['seq_len']} 分钟")
|
||||||
|
print(f"批次大小: {config['batch_size']}")
|
||||||
|
print(f"特征: {config['features']}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 1. 加载数据
|
||||||
|
print("\n[1/6] 加载 V2 数据...")
|
||||||
|
date_data = load_data_by_date(args.data_dir, config['features'])
|
||||||
|
|
||||||
|
# 2. 划分数据集
|
||||||
|
print("\n[2/6] 划分数据集...")
|
||||||
|
train_data, val_data, test_data = split_data_by_date(
|
||||||
|
date_data, config['train_end_date'], config['val_end_date']
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 构建序列
|
||||||
|
print("\n[3/6] 构建序列...")
|
||||||
|
print("训练集:")
|
||||||
|
train_sequences = build_sequences_by_concept(
|
||||||
|
train_data, config['features'], config['seq_len'], config['stride']
|
||||||
|
)
|
||||||
|
print("验证集:")
|
||||||
|
val_sequences = build_sequences_by_concept(
|
||||||
|
val_data, config['features'], config['seq_len'], config['stride']
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(train_sequences) == 0:
|
||||||
|
print("错误: 训练集为空!")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 4. 预处理
|
||||||
|
print("\n[4/6] 数据预处理...")
|
||||||
|
clip_value = config['clip_value']
|
||||||
|
print(f" Z-Score 特征已标准化,截断: ±{clip_value}")
|
||||||
|
|
||||||
|
train_sequences = np.clip(train_sequences, -clip_value, clip_value)
|
||||||
|
if len(val_sequences) > 0:
|
||||||
|
val_sequences = np.clip(val_sequences, -clip_value, clip_value)
|
||||||
|
|
||||||
|
# 保存配置
|
||||||
|
save_dir = Path(args.save_dir)
|
||||||
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with open(save_dir / 'config.json', 'w') as f:
|
||||||
|
json.dump(config, f, indent=2)
|
||||||
|
|
||||||
|
# 5. 创建数据加载器
|
||||||
|
print("\n[5/6] 创建数据加载器...")
|
||||||
|
train_dataset = SequenceDataset(train_sequences)
|
||||||
|
val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None
|
||||||
|
|
||||||
|
print(f" 训练序列: {len(train_dataset):,}")
|
||||||
|
print(f" 验证序列: {len(val_dataset) if val_dataset else 0:,}")
|
||||||
|
|
||||||
|
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
|
||||||
|
num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config['batch_size'],
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
prefetch_factor=4 if num_workers > 0 else None,
|
||||||
|
persistent_workers=True if num_workers > 0 else False,
|
||||||
|
drop_last=True
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=config['batch_size'] * 2,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=num_workers,
|
||||||
|
pin_memory=True,
|
||||||
|
) if val_dataset else None
|
||||||
|
|
||||||
|
# 6. 训练
|
||||||
|
print("\n[6/6] 训练模型...")
|
||||||
|
model = TransformerAutoencoder(**config['model'])
|
||||||
|
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
print(f" 使用 {torch.cuda.device_count()} 张 GPU 并行训练")
|
||||||
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
|
if val_loader is None:
|
||||||
|
print("警告: 验证集为空,使用训练集的 10% 作为验证")
|
||||||
|
split_idx = int(len(train_dataset) * 0.9)
|
||||||
|
train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
|
||||||
|
val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))
|
||||||
|
train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
|
||||||
|
val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
config=config,
|
||||||
|
device=device,
|
||||||
|
save_dir=args.save_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train(config['epochs'])
|
||||||
|
|
||||||
|
# 计算阈值
|
||||||
|
print("\n[额外] 计算异动阈值...")
|
||||||
|
best_checkpoint = torch.load(save_dir / 'best_model.pt', map_location=device)
|
||||||
|
|
||||||
|
# 创建新的单 GPU 模型用于计算阈值(避免 DataParallel 问题)
|
||||||
|
threshold_model = TransformerAutoencoder(**config['model'])
|
||||||
|
threshold_model.load_state_dict(best_checkpoint['model_state_dict'])
|
||||||
|
threshold_model.to(device)
|
||||||
|
threshold_model.eval()
|
||||||
|
|
||||||
|
thresholds = compute_thresholds(threshold_model, val_loader, device, config['threshold_percentiles'])
|
||||||
|
|
||||||
|
with open(save_dir / 'thresholds.json', 'w') as f:
|
||||||
|
json.dump(thresholds, f, indent=2)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("训练完成!")
|
||||||
|
print(f"模型保存位置: {args.save_dir}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
132
ml/update_baseline.py
Normal file
132
ml/update_baseline.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
每日盘后运行:更新滚动基线
|
||||||
|
|
||||||
|
使用方法:
|
||||||
|
python ml/update_baseline.py
|
||||||
|
|
||||||
|
建议加入 crontab,每天 15:30 后运行:
|
||||||
|
30 15 * * 1-5 cd /path/to/project && python ml/update_baseline.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pickle
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from ml.prepare_data_v2 import (
|
||||||
|
get_all_concepts, get_trading_days, compute_raw_concept_features,
|
||||||
|
init_process_connections, CONFIG, RAW_CACHE_DIR, BASELINE_DIR
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def update_rolling_baseline(baseline_days: int = 20):
|
||||||
|
"""
|
||||||
|
更新滚动基线(用于实盘检测)
|
||||||
|
|
||||||
|
基线 = 最近 N 个交易日每个时间片的统计量
|
||||||
|
"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("更新滚动基线(用于实盘)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 初始化连接
|
||||||
|
init_process_connections()
|
||||||
|
|
||||||
|
# 获取概念列表
|
||||||
|
concepts = get_all_concepts()
|
||||||
|
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
||||||
|
|
||||||
|
# 获取最近的交易日
|
||||||
|
today = datetime.now().strftime('%Y-%m-%d')
|
||||||
|
start_date = (datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d') # 多取一些
|
||||||
|
|
||||||
|
trading_days = get_trading_days(start_date, today)
|
||||||
|
|
||||||
|
if len(trading_days) < baseline_days:
|
||||||
|
print(f"错误:交易日不足 {baseline_days} 天")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 只取最近 N 天
|
||||||
|
recent_days = trading_days[-baseline_days:]
|
||||||
|
print(f"使用 {len(recent_days)} 天数据: {recent_days[0]} ~ {recent_days[-1]}")
|
||||||
|
|
||||||
|
# 加载原始数据
|
||||||
|
all_data = []
|
||||||
|
for trade_date in tqdm(recent_days, desc="加载数据"):
|
||||||
|
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
||||||
|
|
||||||
|
if os.path.exists(cache_file):
|
||||||
|
df = pd.read_parquet(cache_file)
|
||||||
|
else:
|
||||||
|
df = compute_raw_concept_features(trade_date, concepts, all_stocks)
|
||||||
|
|
||||||
|
if not df.empty:
|
||||||
|
all_data.append(df)
|
||||||
|
|
||||||
|
if not all_data:
|
||||||
|
print("错误:无数据")
|
||||||
|
return
|
||||||
|
|
||||||
|
combined = pd.concat(all_data, ignore_index=True)
|
||||||
|
print(f"总数据量: {len(combined):,} 条")
|
||||||
|
|
||||||
|
# 按概念计算基线
|
||||||
|
baselines = {}
|
||||||
|
|
||||||
|
for concept_id, group in tqdm(combined.groupby('concept_id'), desc="计算基线"):
|
||||||
|
baseline_dict = {}
|
||||||
|
|
||||||
|
for time_slot, slot_group in group.groupby('time_slot'):
|
||||||
|
if len(slot_group) < CONFIG['min_baseline_samples']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
alpha_std = slot_group['alpha'].std()
|
||||||
|
amt_std = slot_group['total_amt'].std()
|
||||||
|
rank_std = slot_group['rank_pct'].std()
|
||||||
|
|
||||||
|
baseline_dict[time_slot] = {
|
||||||
|
'alpha_mean': float(slot_group['alpha'].mean()),
|
||||||
|
'alpha_std': float(max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1)),
|
||||||
|
'amt_mean': float(slot_group['total_amt'].mean()),
|
||||||
|
'amt_std': float(max(amt_std if pd.notna(amt_std) else slot_group['total_amt'].mean() * 0.5, 1.0)),
|
||||||
|
'rank_mean': float(slot_group['rank_pct'].mean()),
|
||||||
|
'rank_std': float(max(rank_std if pd.notna(rank_std) else 0.2, 0.05)),
|
||||||
|
'sample_count': len(slot_group),
|
||||||
|
}
|
||||||
|
|
||||||
|
if baseline_dict:
|
||||||
|
baselines[concept_id] = baseline_dict
|
||||||
|
|
||||||
|
print(f"计算了 {len(baselines)} 个概念的基线")
|
||||||
|
|
||||||
|
# 保存
|
||||||
|
os.makedirs(BASELINE_DIR, exist_ok=True)
|
||||||
|
baseline_file = os.path.join(BASELINE_DIR, 'realtime_baseline.pkl')
|
||||||
|
|
||||||
|
with open(baseline_file, 'wb') as f:
|
||||||
|
pickle.dump({
|
||||||
|
'baselines': baselines,
|
||||||
|
'update_time': datetime.now().isoformat(),
|
||||||
|
'date_range': [recent_days[0], recent_days[-1]],
|
||||||
|
'baseline_days': baseline_days,
|
||||||
|
}, f)
|
||||||
|
|
||||||
|
print(f"基线已保存: {baseline_file}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--days', type=int, default=20, help='基线天数')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
update_rolling_baseline(args.days)
|
||||||
@@ -0,0 +1,267 @@
|
|||||||
|
/**
|
||||||
|
* 迷你分时图组件
|
||||||
|
* 用于灵活屏中显示证券的日内走势
|
||||||
|
*/
|
||||||
|
import React, { useEffect, useRef, useState, useMemo } from 'react';
|
||||||
|
import { Box, Spinner, Center, Text } from '@chakra-ui/react';
|
||||||
|
import * as echarts from 'echarts';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 生成交易时间刻度(用于 X 轴)
|
||||||
|
* A股交易时间:9:30-11:30, 13:00-15:00
|
||||||
|
*/
|
||||||
|
const generateTimeTicks = () => {
|
||||||
|
const ticks = [];
|
||||||
|
// 上午
|
||||||
|
for (let h = 9; h <= 11; h++) {
|
||||||
|
for (let m = (h === 9 ? 30 : 0); m < 60; m++) {
|
||||||
|
if (h === 11 && m > 30) break;
|
||||||
|
ticks.push(`${String(h).padStart(2, '0')}:${String(m).padStart(2, '0')}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 下午
|
||||||
|
for (let h = 13; h <= 15; h++) {
|
||||||
|
for (let m = 0; m < 60; m++) {
|
||||||
|
if (h === 15 && m > 0) break;
|
||||||
|
ticks.push(`${String(h).padStart(2, '0')}:${String(m).padStart(2, '0')}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ticks;
|
||||||
|
};
|
||||||
|
|
||||||
|
const TIME_TICKS = generateTimeTicks();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MiniTimelineChart 组件
|
||||||
|
* @param {Object} props
|
||||||
|
* @param {string} props.code - 证券代码
|
||||||
|
* @param {boolean} props.isIndex - 是否为指数
|
||||||
|
* @param {number} props.prevClose - 昨收价
|
||||||
|
* @param {number} props.currentPrice - 当前价(实时)
|
||||||
|
* @param {number} props.height - 图表高度
|
||||||
|
*/
|
||||||
|
const MiniTimelineChart = ({
|
||||||
|
code,
|
||||||
|
isIndex = false,
|
||||||
|
prevClose,
|
||||||
|
currentPrice,
|
||||||
|
height = 120,
|
||||||
|
}) => {
|
||||||
|
const chartRef = useRef(null);
|
||||||
|
const chartInstance = useRef(null);
|
||||||
|
const [timelineData, setTimelineData] = useState([]);
|
||||||
|
const [loading, setLoading] = useState(true);
|
||||||
|
const [error, setError] = useState(null);
|
||||||
|
|
||||||
|
// 获取分钟数据
|
||||||
|
useEffect(() => {
|
||||||
|
if (!code) return;
|
||||||
|
|
||||||
|
const fetchData = async () => {
|
||||||
|
setLoading(true);
|
||||||
|
setError(null);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const apiPath = isIndex
|
||||||
|
? `/api/index/${code}/kline?type=minute`
|
||||||
|
: `/api/stock/${code}/kline?type=minute`;
|
||||||
|
|
||||||
|
const response = await fetch(apiPath);
|
||||||
|
const result = await response.json();
|
||||||
|
|
||||||
|
if (result.success !== false && result.data) {
|
||||||
|
// 格式化数据
|
||||||
|
const formatted = result.data.map(item => ({
|
||||||
|
time: item.time || item.timestamp,
|
||||||
|
price: item.close || item.price,
|
||||||
|
}));
|
||||||
|
setTimelineData(formatted);
|
||||||
|
} else {
|
||||||
|
setError(result.error || '暂无数据');
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
setError('加载失败');
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fetchData();
|
||||||
|
|
||||||
|
// 交易时间内每分钟刷新
|
||||||
|
const now = new Date();
|
||||||
|
const hours = now.getHours();
|
||||||
|
const minutes = now.getMinutes();
|
||||||
|
const currentMinutes = hours * 60 + minutes;
|
||||||
|
const isTrading = (currentMinutes >= 570 && currentMinutes <= 690) ||
|
||||||
|
(currentMinutes >= 780 && currentMinutes <= 900);
|
||||||
|
|
||||||
|
let intervalId;
|
||||||
|
if (isTrading) {
|
||||||
|
intervalId = setInterval(fetchData, 60000); // 1分钟刷新
|
||||||
|
}
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
if (intervalId) clearInterval(intervalId);
|
||||||
|
};
|
||||||
|
}, [code, isIndex]);
|
||||||
|
|
||||||
|
// 合并实时价格到数据中
|
||||||
|
const chartData = useMemo(() => {
|
||||||
|
if (!timelineData.length) return [];
|
||||||
|
|
||||||
|
const data = [...timelineData];
|
||||||
|
|
||||||
|
// 如果有实时价格,添加到最新点
|
||||||
|
if (currentPrice && data.length > 0) {
|
||||||
|
const now = new Date();
|
||||||
|
const timeStr = `${String(now.getHours()).padStart(2, '0')}:${String(now.getMinutes()).padStart(2, '0')}`;
|
||||||
|
const lastItem = data[data.length - 1];
|
||||||
|
|
||||||
|
// 如果实时价格的时间比最后一条数据新,添加新点
|
||||||
|
if (lastItem.time !== timeStr) {
|
||||||
|
data.push({ time: timeStr, price: currentPrice });
|
||||||
|
} else {
|
||||||
|
// 更新最后一条
|
||||||
|
data[data.length - 1] = { ...lastItem, price: currentPrice };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [timelineData, currentPrice]);
|
||||||
|
|
||||||
|
// 渲染图表
|
||||||
|
useEffect(() => {
|
||||||
|
if (!chartRef.current || loading || !chartData.length) return;
|
||||||
|
|
||||||
|
if (!chartInstance.current) {
|
||||||
|
chartInstance.current = echarts.init(chartRef.current);
|
||||||
|
}
|
||||||
|
|
||||||
|
const baseLine = prevClose || chartData[0]?.price || 0;
|
||||||
|
|
||||||
|
// 计算价格范围
|
||||||
|
const prices = chartData.map(d => d.price).filter(p => p > 0);
|
||||||
|
const minPrice = Math.min(...prices, baseLine);
|
||||||
|
const maxPrice = Math.max(...prices, baseLine);
|
||||||
|
const range = Math.max(maxPrice - baseLine, baseLine - minPrice) * 1.1;
|
||||||
|
|
||||||
|
// 准备数据
|
||||||
|
const times = chartData.map(d => d.time);
|
||||||
|
const values = chartData.map(d => d.price);
|
||||||
|
|
||||||
|
// 判断涨跌
|
||||||
|
const lastPrice = values[values.length - 1] || baseLine;
|
||||||
|
const isUp = lastPrice >= baseLine;
|
||||||
|
|
||||||
|
const option = {
|
||||||
|
grid: {
|
||||||
|
top: 5,
|
||||||
|
right: 5,
|
||||||
|
bottom: 5,
|
||||||
|
left: 5,
|
||||||
|
containLabel: false,
|
||||||
|
},
|
||||||
|
xAxis: {
|
||||||
|
type: 'category',
|
||||||
|
data: times,
|
||||||
|
show: false,
|
||||||
|
boundaryGap: false,
|
||||||
|
},
|
||||||
|
yAxis: {
|
||||||
|
type: 'value',
|
||||||
|
min: baseLine - range,
|
||||||
|
max: baseLine + range,
|
||||||
|
show: false,
|
||||||
|
},
|
||||||
|
series: [
|
||||||
|
{
|
||||||
|
type: 'line',
|
||||||
|
data: values,
|
||||||
|
smooth: false,
|
||||||
|
symbol: 'none',
|
||||||
|
lineStyle: {
|
||||||
|
width: 1.5,
|
||||||
|
color: isUp ? '#ef4444' : '#22c55e',
|
||||||
|
},
|
||||||
|
areaStyle: {
|
||||||
|
color: {
|
||||||
|
type: 'linear',
|
||||||
|
x: 0,
|
||||||
|
y: 0,
|
||||||
|
x2: 0,
|
||||||
|
y2: 1,
|
||||||
|
colorStops: [
|
||||||
|
{ offset: 0, color: isUp ? 'rgba(239, 68, 68, 0.3)' : 'rgba(34, 197, 94, 0.3)' },
|
||||||
|
{ offset: 1, color: isUp ? 'rgba(239, 68, 68, 0.05)' : 'rgba(34, 197, 94, 0.05)' },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
markLine: {
|
||||||
|
silent: true,
|
||||||
|
symbol: 'none',
|
||||||
|
data: [
|
||||||
|
{
|
||||||
|
yAxis: baseLine,
|
||||||
|
lineStyle: {
|
||||||
|
color: '#666',
|
||||||
|
type: 'dashed',
|
||||||
|
width: 1,
|
||||||
|
},
|
||||||
|
label: { show: false },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
animation: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
chartInstance.current.setOption(option);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
// 不在这里销毁,只在组件卸载时销毁
|
||||||
|
};
|
||||||
|
}, [chartData, prevClose, loading]);
|
||||||
|
|
||||||
|
// 组件卸载时销毁图表
|
||||||
|
useEffect(() => {
|
||||||
|
return () => {
|
||||||
|
if (chartInstance.current) {
|
||||||
|
chartInstance.current.dispose();
|
||||||
|
chartInstance.current = null;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// 窗口 resize 处理
|
||||||
|
useEffect(() => {
|
||||||
|
const handleResize = () => {
|
||||||
|
chartInstance.current?.resize();
|
||||||
|
};
|
||||||
|
window.addEventListener('resize', handleResize);
|
||||||
|
return () => window.removeEventListener('resize', handleResize);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
if (loading) {
|
||||||
|
return (
|
||||||
|
<Center h={height}>
|
||||||
|
<Spinner size="sm" color="gray.400" />
|
||||||
|
</Center>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error || !chartData.length) {
|
||||||
|
return (
|
||||||
|
<Center h={height}>
|
||||||
|
<Text fontSize="xs" color="gray.400">
|
||||||
|
{error || '暂无数据'}
|
||||||
|
</Text>
|
||||||
|
</Center>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return <Box ref={chartRef} h={`${height}px`} w="100%" />;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default MiniTimelineChart;
|
||||||
@@ -0,0 +1,275 @@
|
|||||||
|
/**
|
||||||
|
* 盘口行情面板组件
|
||||||
|
* 支持显示 5 档或 10 档买卖盘数据
|
||||||
|
*
|
||||||
|
* 上交所: 5 档行情
|
||||||
|
* 深交所: 10 档行情
|
||||||
|
*/
|
||||||
|
import React, { useState } from 'react';
|
||||||
|
import {
|
||||||
|
Box,
|
||||||
|
VStack,
|
||||||
|
HStack,
|
||||||
|
Text,
|
||||||
|
Button,
|
||||||
|
ButtonGroup,
|
||||||
|
useColorModeValue,
|
||||||
|
Tooltip,
|
||||||
|
Badge,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 格式化成交量
|
||||||
|
* @param {number} volume - 成交量(股)
|
||||||
|
* @returns {string} 格式化后的字符串
|
||||||
|
*/
|
||||||
|
const formatVolume = (volume) => {
|
||||||
|
if (!volume || volume === 0) return '-';
|
||||||
|
if (volume >= 10000) {
|
||||||
|
return `${(volume / 10000).toFixed(0)}万`;
|
||||||
|
}
|
||||||
|
if (volume >= 1000) {
|
||||||
|
return `${(volume / 1000).toFixed(1)}k`;
|
||||||
|
}
|
||||||
|
return String(volume);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 格式化价格
|
||||||
|
* @param {number} price - 价格
|
||||||
|
* @param {number} prevClose - 昨收价
|
||||||
|
* @returns {Object} { text, color }
|
||||||
|
*/
|
||||||
|
const formatPrice = (price, prevClose) => {
|
||||||
|
if (!price || price === 0) {
|
||||||
|
return { text: '-', color: 'gray.400' };
|
||||||
|
}
|
||||||
|
|
||||||
|
const text = price.toFixed(2);
|
||||||
|
|
||||||
|
if (!prevClose || prevClose === 0) {
|
||||||
|
return { text, color: 'gray.600' };
|
||||||
|
}
|
||||||
|
|
||||||
|
if (price > prevClose) {
|
||||||
|
return { text, color: 'red.500' };
|
||||||
|
}
|
||||||
|
if (price < prevClose) {
|
||||||
|
return { text, color: 'green.500' };
|
||||||
|
}
|
||||||
|
return { text, color: 'gray.600' };
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 单行盘口
|
||||||
|
*/
|
||||||
|
const OrderRow = ({ label, price, volume, prevClose, isBid, maxVolume, isLimitPrice }) => {
|
||||||
|
const bgColor = useColorModeValue(
|
||||||
|
isBid ? 'red.50' : 'green.50',
|
||||||
|
isBid ? 'rgba(239, 68, 68, 0.1)' : 'rgba(34, 197, 94, 0.1)'
|
||||||
|
);
|
||||||
|
const barColor = useColorModeValue(
|
||||||
|
isBid ? 'red.200' : 'green.200',
|
||||||
|
isBid ? 'rgba(239, 68, 68, 0.3)' : 'rgba(34, 197, 94, 0.3)'
|
||||||
|
);
|
||||||
|
const limitColor = useColorModeValue('orange.500', 'orange.300');
|
||||||
|
|
||||||
|
const priceInfo = formatPrice(price, prevClose);
|
||||||
|
const volumeText = formatVolume(volume);
|
||||||
|
|
||||||
|
// 计算成交量条宽度
|
||||||
|
const barWidth = maxVolume > 0 ? Math.min((volume / maxVolume) * 100, 100) : 0;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<HStack
|
||||||
|
spacing={2}
|
||||||
|
py={0.5}
|
||||||
|
px={1}
|
||||||
|
position="relative"
|
||||||
|
overflow="hidden"
|
||||||
|
fontSize="xs"
|
||||||
|
>
|
||||||
|
{/* 成交量条 */}
|
||||||
|
<Box
|
||||||
|
position="absolute"
|
||||||
|
right={0}
|
||||||
|
top={0}
|
||||||
|
bottom={0}
|
||||||
|
width={`${barWidth}%`}
|
||||||
|
bg={barColor}
|
||||||
|
transition="width 0.2s"
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* 内容 */}
|
||||||
|
<Text color="gray.500" w="24px" flexShrink={0} zIndex={1}>
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
|
<HStack flex={1} justify="flex-end" zIndex={1}>
|
||||||
|
<Text color={isLimitPrice ? limitColor : priceInfo.color} fontWeight="medium">
|
||||||
|
{priceInfo.text}
|
||||||
|
</Text>
|
||||||
|
{isLimitPrice && (
|
||||||
|
<Tooltip label={isBid ? '跌停价' : '涨停价'}>
|
||||||
|
<Badge
|
||||||
|
colorScheme={isBid ? 'green' : 'red'}
|
||||||
|
fontSize="2xs"
|
||||||
|
variant="subtle"
|
||||||
|
>
|
||||||
|
{isBid ? '跌' : '涨'}
|
||||||
|
</Badge>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
|
</HStack>
|
||||||
|
<Text color="gray.600" w="40px" textAlign="right" zIndex={1}>
|
||||||
|
{volumeText}
|
||||||
|
</Text>
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OrderBookPanel 组件
|
||||||
|
* @param {Object} props
|
||||||
|
* @param {number[]} props.bidPrices - 买档价格(最多10档)
|
||||||
|
* @param {number[]} props.bidVolumes - 买档量
|
||||||
|
* @param {number[]} props.askPrices - 卖档价格(最多10档)
|
||||||
|
* @param {number[]} props.askVolumes - 卖档量
|
||||||
|
* @param {number} props.prevClose - 昨收价
|
||||||
|
* @param {number} props.upperLimit - 涨停价
|
||||||
|
* @param {number} props.lowerLimit - 跌停价
|
||||||
|
* @param {number} props.defaultLevels - 默认显示档数(5 或 10)
|
||||||
|
*/
|
||||||
|
const OrderBookPanel = ({
|
||||||
|
bidPrices = [],
|
||||||
|
bidVolumes = [],
|
||||||
|
askPrices = [],
|
||||||
|
askVolumes = [],
|
||||||
|
prevClose,
|
||||||
|
upperLimit,
|
||||||
|
lowerLimit,
|
||||||
|
defaultLevels = 5,
|
||||||
|
}) => {
|
||||||
|
const borderColor = useColorModeValue('gray.200', 'gray.700');
|
||||||
|
const buttonBg = useColorModeValue('gray.100', 'gray.700');
|
||||||
|
|
||||||
|
// 可切换显示的档位数
|
||||||
|
const maxAvailableLevels = Math.max(bidPrices.length, askPrices.length, 1);
|
||||||
|
const [showLevels, setShowLevels] = useState(Math.min(defaultLevels, maxAvailableLevels));
|
||||||
|
|
||||||
|
// 计算最大成交量(用于条形图比例)
|
||||||
|
const displayBidVolumes = bidVolumes.slice(0, showLevels);
|
||||||
|
const displayAskVolumes = askVolumes.slice(0, showLevels);
|
||||||
|
const allVolumes = [...displayBidVolumes, ...displayAskVolumes].filter(v => v > 0);
|
||||||
|
const maxVolume = allVolumes.length > 0 ? Math.max(...allVolumes) : 0;
|
||||||
|
|
||||||
|
// 判断是否为涨跌停价
|
||||||
|
const isUpperLimit = (price) => upperLimit && Math.abs(price - upperLimit) < 0.001;
|
||||||
|
const isLowerLimit = (price) => lowerLimit && Math.abs(price - lowerLimit) < 0.001;
|
||||||
|
|
||||||
|
// 卖盘(从卖N到卖1,即价格从高到低)
|
||||||
|
const askRows = [];
|
||||||
|
for (let i = showLevels - 1; i >= 0; i--) {
|
||||||
|
askRows.push(
|
||||||
|
<OrderRow
|
||||||
|
key={`ask${i + 1}`}
|
||||||
|
label={`卖${i + 1}`}
|
||||||
|
price={askPrices[i]}
|
||||||
|
volume={askVolumes[i]}
|
||||||
|
prevClose={prevClose}
|
||||||
|
isBid={false}
|
||||||
|
maxVolume={maxVolume}
|
||||||
|
isLimitPrice={isUpperLimit(askPrices[i])}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 买盘(从买1到买N,即价格从高到低)
|
||||||
|
const bidRows = [];
|
||||||
|
for (let i = 0; i < showLevels; i++) {
|
||||||
|
bidRows.push(
|
||||||
|
<OrderRow
|
||||||
|
key={`bid${i + 1}`}
|
||||||
|
label={`买${i + 1}`}
|
||||||
|
price={bidPrices[i]}
|
||||||
|
volume={bidVolumes[i]}
|
||||||
|
prevClose={prevClose}
|
||||||
|
isBid={true}
|
||||||
|
maxVolume={maxVolume}
|
||||||
|
isLimitPrice={isLowerLimit(bidPrices[i])}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 没有数据时的提示
|
||||||
|
const hasData = bidPrices.length > 0 || askPrices.length > 0;
|
||||||
|
|
||||||
|
if (!hasData) {
|
||||||
|
return (
|
||||||
|
<Box textAlign="center" py={2}>
|
||||||
|
<Text fontSize="xs" color="gray.400">
|
||||||
|
暂无盘口数据
|
||||||
|
</Text>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<VStack spacing={0} align="stretch">
|
||||||
|
{/* 档位切换(只有当有超过5档数据时才显示) */}
|
||||||
|
{maxAvailableLevels > 5 && (
|
||||||
|
<HStack justify="flex-end" mb={1}>
|
||||||
|
<ButtonGroup size="xs" isAttached variant="outline">
|
||||||
|
<Button
|
||||||
|
onClick={() => setShowLevels(5)}
|
||||||
|
bg={showLevels === 5 ? buttonBg : 'transparent'}
|
||||||
|
fontWeight={showLevels === 5 ? 'bold' : 'normal'}
|
||||||
|
>
|
||||||
|
5档
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
onClick={() => setShowLevels(10)}
|
||||||
|
bg={showLevels === 10 ? buttonBg : 'transparent'}
|
||||||
|
fontWeight={showLevels === 10 ? 'bold' : 'normal'}
|
||||||
|
>
|
||||||
|
10档
|
||||||
|
</Button>
|
||||||
|
</ButtonGroup>
|
||||||
|
</HStack>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 卖盘 */}
|
||||||
|
{askRows}
|
||||||
|
|
||||||
|
{/* 分隔线 + 当前价信息 */}
|
||||||
|
<Box h="1px" bg={borderColor} my={1} position="relative">
|
||||||
|
{prevClose && (
|
||||||
|
<Text
|
||||||
|
position="absolute"
|
||||||
|
right={0}
|
||||||
|
top="50%"
|
||||||
|
transform="translateY(-50%)"
|
||||||
|
fontSize="2xs"
|
||||||
|
color="gray.400"
|
||||||
|
bg={useColorModeValue('white', '#1a1a1a')}
|
||||||
|
px={1}
|
||||||
|
>
|
||||||
|
昨收 {prevClose.toFixed(2)}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
|
||||||
|
{/* 买盘 */}
|
||||||
|
{bidRows}
|
||||||
|
|
||||||
|
{/* 涨跌停价信息 */}
|
||||||
|
{(upperLimit || lowerLimit) && (
|
||||||
|
<HStack justify="space-between" mt={1} fontSize="2xs" color="gray.400">
|
||||||
|
{lowerLimit && <Text>跌停 {lowerLimit.toFixed(2)}</Text>}
|
||||||
|
{upperLimit && <Text>涨停 {upperLimit.toFixed(2)}</Text>}
|
||||||
|
</HStack>
|
||||||
|
)}
|
||||||
|
</VStack>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default OrderBookPanel;
|
||||||
@@ -0,0 +1,270 @@
|
|||||||
|
/**
|
||||||
|
* 行情瓷砖组件
|
||||||
|
* 单个证券的实时行情展示卡片,包含分时图和五档盘口
|
||||||
|
*/
|
||||||
|
import React, { useState } from 'react';
|
||||||
|
import {
|
||||||
|
Box,
|
||||||
|
VStack,
|
||||||
|
HStack,
|
||||||
|
Text,
|
||||||
|
IconButton,
|
||||||
|
Tooltip,
|
||||||
|
useColorModeValue,
|
||||||
|
Collapse,
|
||||||
|
Badge,
|
||||||
|
Flex,
|
||||||
|
Spacer,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { CloseIcon, ChevronDownIcon, ChevronUpIcon, ExternalLinkIcon } from '@chakra-ui/icons';
|
||||||
|
import { useNavigate } from 'react-router-dom';
|
||||||
|
import MiniTimelineChart from './MiniTimelineChart';
|
||||||
|
import OrderBookPanel from './OrderBookPanel';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 格式化价格显示
|
||||||
|
*/
|
||||||
|
const formatPrice = (price) => {
|
||||||
|
if (!price || isNaN(price)) return '-';
|
||||||
|
return price.toFixed(2);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 格式化涨跌幅
|
||||||
|
*/
|
||||||
|
const formatChangePct = (pct) => {
|
||||||
|
if (!pct || isNaN(pct)) return '0.00%';
|
||||||
|
const sign = pct > 0 ? '+' : '';
|
||||||
|
return `${sign}${pct.toFixed(2)}%`;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 格式化涨跌额
|
||||||
|
*/
|
||||||
|
const formatChange = (change) => {
|
||||||
|
if (!change || isNaN(change)) return '-';
|
||||||
|
const sign = change > 0 ? '+' : '';
|
||||||
|
return `${sign}${change.toFixed(2)}`;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 格式化成交额
|
||||||
|
*/
|
||||||
|
const formatAmount = (amount) => {
|
||||||
|
if (!amount || isNaN(amount)) return '-';
|
||||||
|
if (amount >= 100000000) {
|
||||||
|
return `${(amount / 100000000).toFixed(2)}亿`;
|
||||||
|
}
|
||||||
|
if (amount >= 10000) {
|
||||||
|
return `${(amount / 10000).toFixed(0)}万`;
|
||||||
|
}
|
||||||
|
return amount.toFixed(0);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* QuoteTile 组件
|
||||||
|
* @param {Object} props
|
||||||
|
* @param {string} props.code - 证券代码
|
||||||
|
* @param {string} props.name - 证券名称
|
||||||
|
* @param {Object} props.quote - 实时行情数据
|
||||||
|
* @param {boolean} props.isIndex - 是否为指数
|
||||||
|
* @param {Function} props.onRemove - 移除回调
|
||||||
|
*/
|
||||||
|
const QuoteTile = ({
|
||||||
|
code,
|
||||||
|
name,
|
||||||
|
quote = {},
|
||||||
|
isIndex = false,
|
||||||
|
onRemove,
|
||||||
|
}) => {
|
||||||
|
const navigate = useNavigate();
|
||||||
|
const [expanded, setExpanded] = useState(true);
|
||||||
|
|
||||||
|
// 颜色主题
|
||||||
|
const cardBg = useColorModeValue('white', '#1a1a1a');
|
||||||
|
const borderColor = useColorModeValue('gray.200', '#333');
|
||||||
|
const hoverBorderColor = useColorModeValue('purple.300', '#666');
|
||||||
|
const textColor = useColorModeValue('gray.800', 'white');
|
||||||
|
const subTextColor = useColorModeValue('gray.500', 'gray.400');
|
||||||
|
|
||||||
|
// 涨跌色
|
||||||
|
const { price, prevClose, change, changePct, amount } = quote;
|
||||||
|
const priceColor = useColorModeValue(
|
||||||
|
!prevClose || price === prevClose ? 'gray.800' :
|
||||||
|
price > prevClose ? 'red.500' : 'green.500',
|
||||||
|
!prevClose || price === prevClose ? 'gray.200' :
|
||||||
|
price > prevClose ? 'red.400' : 'green.400'
|
||||||
|
);
|
||||||
|
|
||||||
|
// 涨跌幅背景色
|
||||||
|
const changeBgColor = useColorModeValue(
|
||||||
|
!changePct || changePct === 0 ? 'gray.100' :
|
||||||
|
changePct > 0 ? 'red.100' : 'green.100',
|
||||||
|
!changePct || changePct === 0 ? 'gray.700' :
|
||||||
|
changePct > 0 ? 'rgba(239, 68, 68, 0.2)' : 'rgba(34, 197, 94, 0.2)'
|
||||||
|
);
|
||||||
|
|
||||||
|
// 跳转到详情页
|
||||||
|
const handleNavigate = () => {
|
||||||
|
if (isIndex) {
|
||||||
|
// 指数暂无详情页
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
navigate(`/company?scode=${code}`);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
bg={cardBg}
|
||||||
|
borderWidth="1px"
|
||||||
|
borderColor={borderColor}
|
||||||
|
borderRadius="lg"
|
||||||
|
overflow="hidden"
|
||||||
|
transition="all 0.2s"
|
||||||
|
_hover={{
|
||||||
|
borderColor: hoverBorderColor,
|
||||||
|
boxShadow: 'md',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{/* 头部 */}
|
||||||
|
<HStack
|
||||||
|
px={3}
|
||||||
|
py={2}
|
||||||
|
borderBottomWidth={expanded ? '1px' : '0'}
|
||||||
|
borderColor={borderColor}
|
||||||
|
cursor="pointer"
|
||||||
|
onClick={() => setExpanded(!expanded)}
|
||||||
|
>
|
||||||
|
{/* 名称和代码 */}
|
||||||
|
<VStack align="start" spacing={0} flex={1} minW={0}>
|
||||||
|
<HStack spacing={2}>
|
||||||
|
<Text
|
||||||
|
fontWeight="bold"
|
||||||
|
fontSize="sm"
|
||||||
|
color={textColor}
|
||||||
|
noOfLines={1}
|
||||||
|
cursor="pointer"
|
||||||
|
_hover={{ textDecoration: 'underline' }}
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
handleNavigate();
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{name || code}
|
||||||
|
</Text>
|
||||||
|
{isIndex && (
|
||||||
|
<Badge colorScheme="purple" fontSize="xs">
|
||||||
|
指数
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
</HStack>
|
||||||
|
<Text fontSize="xs" color={subTextColor}>
|
||||||
|
{code}
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
|
||||||
|
{/* 价格信息 */}
|
||||||
|
<VStack align="end" spacing={0}>
|
||||||
|
<Text fontWeight="bold" fontSize="lg" color={priceColor}>
|
||||||
|
{formatPrice(price)}
|
||||||
|
</Text>
|
||||||
|
<HStack spacing={1}>
|
||||||
|
<Box
|
||||||
|
px={1.5}
|
||||||
|
py={0.5}
|
||||||
|
bg={changeBgColor}
|
||||||
|
borderRadius="sm"
|
||||||
|
fontSize="xs"
|
||||||
|
fontWeight="medium"
|
||||||
|
color={priceColor}
|
||||||
|
>
|
||||||
|
{formatChangePct(changePct)}
|
||||||
|
</Box>
|
||||||
|
<Text fontSize="xs" color={priceColor}>
|
||||||
|
{formatChange(change)}
|
||||||
|
</Text>
|
||||||
|
</HStack>
|
||||||
|
</VStack>
|
||||||
|
|
||||||
|
{/* 操作按钮 */}
|
||||||
|
<HStack spacing={1} ml={2}>
|
||||||
|
<IconButton
|
||||||
|
icon={expanded ? <ChevronUpIcon /> : <ChevronDownIcon />}
|
||||||
|
size="xs"
|
||||||
|
variant="ghost"
|
||||||
|
aria-label={expanded ? '收起' : '展开'}
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
setExpanded(!expanded);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<Tooltip label="移除">
|
||||||
|
<IconButton
|
||||||
|
icon={<CloseIcon />}
|
||||||
|
size="xs"
|
||||||
|
variant="ghost"
|
||||||
|
colorScheme="red"
|
||||||
|
aria-label="移除"
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
onRemove?.(code);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
</HStack>
|
||||||
|
</HStack>
|
||||||
|
|
||||||
|
{/* 可折叠内容 */}
|
||||||
|
<Collapse in={expanded} animateOpacity>
|
||||||
|
<Box p={3}>
|
||||||
|
{/* 统计信息 */}
|
||||||
|
<HStack spacing={4} mb={3} fontSize="xs" color={subTextColor}>
|
||||||
|
<HStack>
|
||||||
|
<Text>昨收:</Text>
|
||||||
|
<Text color={textColor}>{formatPrice(prevClose)}</Text>
|
||||||
|
</HStack>
|
||||||
|
<HStack>
|
||||||
|
<Text>今开:</Text>
|
||||||
|
<Text color={textColor}>{formatPrice(quote.open)}</Text>
|
||||||
|
</HStack>
|
||||||
|
<HStack>
|
||||||
|
<Text>成交额:</Text>
|
||||||
|
<Text color={textColor}>{formatAmount(amount)}</Text>
|
||||||
|
</HStack>
|
||||||
|
</HStack>
|
||||||
|
|
||||||
|
{/* 分时图 */}
|
||||||
|
<Box mb={3}>
|
||||||
|
<MiniTimelineChart
|
||||||
|
code={code}
|
||||||
|
isIndex={isIndex}
|
||||||
|
prevClose={prevClose}
|
||||||
|
currentPrice={price}
|
||||||
|
height={100}
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
|
||||||
|
{/* 盘口(指数没有盘口) */}
|
||||||
|
{!isIndex && (
|
||||||
|
<Box>
|
||||||
|
<Text fontSize="xs" color={subTextColor} mb={1}>
|
||||||
|
盘口 {quote.bidPrices?.length > 5 ? '(10档)' : '(5档)'}
|
||||||
|
</Text>
|
||||||
|
<OrderBookPanel
|
||||||
|
bidPrices={quote.bidPrices || []}
|
||||||
|
bidVolumes={quote.bidVolumes || []}
|
||||||
|
askPrices={quote.askPrices || []}
|
||||||
|
askVolumes={quote.askVolumes || []}
|
||||||
|
prevClose={prevClose}
|
||||||
|
upperLimit={quote.upperLimit}
|
||||||
|
lowerLimit={quote.lowerLimit}
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
</Collapse>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default QuoteTile;
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
export { default as MiniTimelineChart } from './MiniTimelineChart';
|
||||||
|
export { default as OrderBookPanel } from './OrderBookPanel';
|
||||||
|
export { default as QuoteTile } from './QuoteTile';
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
export { useRealtimeQuote } from './useRealtimeQuote';
|
||||||
@@ -0,0 +1,692 @@
|
|||||||
|
/**
|
||||||
|
* 实时行情 Hook
|
||||||
|
* 管理上交所和深交所 WebSocket 连接,获取实时行情数据
|
||||||
|
*
|
||||||
|
* 上交所 (SSE): ws://49.232.185.254:8765 - 需主动订阅,提供五档行情
|
||||||
|
* 深交所 (SZSE): ws://222.128.1.157:8765 - 自动推送,提供十档行情
|
||||||
|
*
|
||||||
|
* 深交所支持的数据类型 (category):
|
||||||
|
* - stock (300111): 股票快照,含10档买卖盘
|
||||||
|
* - bond (300211): 债券快照
|
||||||
|
* - afterhours_block (300611): 盘后定价大宗交易
|
||||||
|
* - afterhours_trading (303711): 盘后定价交易
|
||||||
|
* - hk_stock (306311): 港股快照(深港通)
|
||||||
|
* - index (309011): 指数快照
|
||||||
|
* - volume_stats (309111): 成交量统计
|
||||||
|
* - fund_nav (309211): 基金净值
|
||||||
|
*/
|
||||||
|
import { useState, useEffect, useRef, useCallback } from 'react';
|
||||||
|
import { logger } from '@utils/logger';
|
||||||
|
|
||||||
|
// WebSocket 地址配置
|
||||||
|
const WS_CONFIG = {
|
||||||
|
SSE: 'ws://49.232.185.254:8765', // 上交所
|
||||||
|
SZSE: 'ws://222.128.1.157:8765', // 深交所
|
||||||
|
};
|
||||||
|
|
||||||
|
// 心跳间隔 (ms)
|
||||||
|
const HEARTBEAT_INTERVAL = 30000;
|
||||||
|
|
||||||
|
// 重连间隔 (ms)
|
||||||
|
const RECONNECT_INTERVAL = 3000;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 判断证券代码属于哪个交易所
|
||||||
|
* @param {string} code - 证券代码(可带或不带后缀)
|
||||||
|
* @returns {'SSE'|'SZSE'} 交易所标识
|
||||||
|
*/
|
||||||
|
const getExchange = (code) => {
|
||||||
|
const baseCode = code.split('.')[0];
|
||||||
|
|
||||||
|
// 6开头为上海股票
|
||||||
|
if (baseCode.startsWith('6')) {
|
||||||
|
return 'SSE';
|
||||||
|
}
|
||||||
|
|
||||||
|
// 000开头的6位数可能是上证指数或深圳股票
|
||||||
|
if (baseCode.startsWith('000') && baseCode.length === 6) {
|
||||||
|
// 000001-000999 是上证指数范围,但 000001 也是平安银行
|
||||||
|
// 这里需要更精确的判断,暂时把 000 开头当深圳
|
||||||
|
return 'SZSE';
|
||||||
|
}
|
||||||
|
|
||||||
|
// 399开头是深证指数
|
||||||
|
if (baseCode.startsWith('399')) {
|
||||||
|
return 'SZSE';
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0、3开头是深圳股票
|
||||||
|
if (baseCode.startsWith('0') || baseCode.startsWith('3')) {
|
||||||
|
return 'SZSE';
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5开头是上海 ETF
|
||||||
|
if (baseCode.startsWith('5')) {
|
||||||
|
return 'SSE';
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1开头是深圳 ETF/债券
|
||||||
|
if (baseCode.startsWith('1')) {
|
||||||
|
return 'SZSE';
|
||||||
|
}
|
||||||
|
|
||||||
|
// 默认上海
|
||||||
|
return 'SSE';
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 标准化证券代码为无后缀格式
|
||||||
|
*/
|
||||||
|
const normalizeCode = (code) => {
|
||||||
|
return code.split('.')[0];
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从深交所 bids/asks 数组提取价格和量数组
|
||||||
|
* @param {Array} orderBook - [{price, volume}, ...]
|
||||||
|
* @returns {{ prices: number[], volumes: number[] }}
|
||||||
|
*/
|
||||||
|
const extractOrderBook = (orderBook) => {
|
||||||
|
if (!orderBook || !Array.isArray(orderBook)) {
|
||||||
|
return { prices: [], volumes: [] };
|
||||||
|
}
|
||||||
|
const prices = orderBook.map(item => item.price || 0);
|
||||||
|
const volumes = orderBook.map(item => item.volume || 0);
|
||||||
|
return { prices, volumes };
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 实时行情 Hook
|
||||||
|
* @param {string[]} codes - 订阅的证券代码列表
|
||||||
|
* @returns {Object} { quotes, connected, subscribe, unsubscribe }
|
||||||
|
*/
|
||||||
|
export const useRealtimeQuote = (codes = []) => {
|
||||||
|
// 行情数据 { [code]: QuoteData }
|
||||||
|
const [quotes, setQuotes] = useState({});
|
||||||
|
// 连接状态 { SSE: boolean, SZSE: boolean }
|
||||||
|
const [connected, setConnected] = useState({ SSE: false, SZSE: false });
|
||||||
|
|
||||||
|
// WebSocket 实例引用
|
||||||
|
const wsRefs = useRef({ SSE: null, SZSE: null });
|
||||||
|
// 心跳定时器
|
||||||
|
const heartbeatRefs = useRef({ SSE: null, SZSE: null });
|
||||||
|
// 重连定时器
|
||||||
|
const reconnectRefs = useRef({ SSE: null, SZSE: null });
|
||||||
|
// 当前订阅的代码(按交易所分组)
|
||||||
|
const subscribedCodes = useRef({ SSE: new Set(), SZSE: new Set() });
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 WebSocket 连接
|
||||||
|
*/
|
||||||
|
const createConnection = useCallback((exchange) => {
|
||||||
|
// 清理现有连接
|
||||||
|
if (wsRefs.current[exchange]) {
|
||||||
|
wsRefs.current[exchange].close();
|
||||||
|
}
|
||||||
|
|
||||||
|
const ws = new WebSocket(WS_CONFIG[exchange]);
|
||||||
|
wsRefs.current[exchange] = ws;
|
||||||
|
|
||||||
|
ws.onopen = () => {
|
||||||
|
logger.info('FlexScreen', `${exchange} WebSocket 已连接`);
|
||||||
|
setConnected(prev => ({ ...prev, [exchange]: true }));
|
||||||
|
|
||||||
|
// 上交所需要主动订阅
|
||||||
|
if (exchange === 'SSE') {
|
||||||
|
const codes = Array.from(subscribedCodes.current.SSE);
|
||||||
|
if (codes.length > 0) {
|
||||||
|
ws.send(JSON.stringify({
|
||||||
|
action: 'subscribe',
|
||||||
|
channels: ['stock', 'index'],
|
||||||
|
codes: codes,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动心跳
|
||||||
|
startHeartbeat(exchange);
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onmessage = (event) => {
|
||||||
|
try {
|
||||||
|
const msg = JSON.parse(event.data);
|
||||||
|
handleMessage(exchange, msg);
|
||||||
|
} catch (e) {
|
||||||
|
logger.warn('FlexScreen', `${exchange} 消息解析失败`, e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onerror = (error) => {
|
||||||
|
logger.error('FlexScreen', `${exchange} WebSocket 错误`, error);
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onclose = () => {
|
||||||
|
logger.info('FlexScreen', `${exchange} WebSocket 断开`);
|
||||||
|
setConnected(prev => ({ ...prev, [exchange]: false }));
|
||||||
|
stopHeartbeat(exchange);
|
||||||
|
|
||||||
|
// 自动重连
|
||||||
|
scheduleReconnect(exchange);
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理 WebSocket 消息
|
||||||
|
*/
|
||||||
|
const handleMessage = useCallback((exchange, msg) => {
|
||||||
|
// 处理 pong
|
||||||
|
if (msg.type === 'pong') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (exchange === 'SSE') {
|
||||||
|
// 上交所消息格式
|
||||||
|
if (msg.type === 'stock' || msg.type === 'index') {
|
||||||
|
const data = msg.data || {};
|
||||||
|
setQuotes(prev => {
|
||||||
|
const updated = { ...prev };
|
||||||
|
Object.entries(data).forEach(([code, quote]) => {
|
||||||
|
// 只更新订阅的代码
|
||||||
|
if (subscribedCodes.current.SSE.has(code)) {
|
||||||
|
updated[code] = {
|
||||||
|
code: quote.security_id,
|
||||||
|
name: quote.security_name,
|
||||||
|
price: quote.last_price,
|
||||||
|
prevClose: quote.prev_close,
|
||||||
|
open: quote.open_price,
|
||||||
|
high: quote.high_price,
|
||||||
|
low: quote.low_price,
|
||||||
|
volume: quote.volume,
|
||||||
|
amount: quote.amount,
|
||||||
|
change: quote.last_price - quote.prev_close,
|
||||||
|
changePct: quote.prev_close ? ((quote.last_price - quote.prev_close) / quote.prev_close * 100) : 0,
|
||||||
|
bidPrices: quote.bid_prices || [],
|
||||||
|
bidVolumes: quote.bid_volumes || [],
|
||||||
|
askPrices: quote.ask_prices || [],
|
||||||
|
askVolumes: quote.ask_volumes || [],
|
||||||
|
updateTime: quote.trade_time,
|
||||||
|
exchange: 'SSE',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return updated;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else if (exchange === 'SZSE') {
|
||||||
|
// 深交所消息格式(更新后的 API)
|
||||||
|
if (msg.type === 'realtime') {
|
||||||
|
const { category, data } = msg;
|
||||||
|
const code = data.security_id;
|
||||||
|
|
||||||
|
// 只更新订阅的代码
|
||||||
|
if (!subscribedCodes.current.SZSE.has(code)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (category === 'stock') {
|
||||||
|
// 股票行情 - 含 10 档买卖盘
|
||||||
|
const { prices: bidPrices, volumes: bidVolumes } = extractOrderBook(data.bids);
|
||||||
|
const { prices: askPrices, volumes: askVolumes } = extractOrderBook(data.asks);
|
||||||
|
|
||||||
|
setQuotes(prev => ({
|
||||||
|
...prev,
|
||||||
|
[code]: {
|
||||||
|
code: code,
|
||||||
|
name: prev[code]?.name || '',
|
||||||
|
price: data.last_px,
|
||||||
|
prevClose: data.prev_close,
|
||||||
|
open: data.open_px,
|
||||||
|
high: data.high_px,
|
||||||
|
low: data.low_px,
|
||||||
|
volume: data.volume,
|
||||||
|
amount: data.amount,
|
||||||
|
numTrades: data.num_trades,
|
||||||
|
upperLimit: data.upper_limit, // 涨停价
|
||||||
|
lowerLimit: data.lower_limit, // 跌停价
|
||||||
|
change: data.last_px - data.prev_close,
|
||||||
|
changePct: data.prev_close ? ((data.last_px - data.prev_close) / data.prev_close * 100) : 0,
|
||||||
|
bidPrices,
|
||||||
|
bidVolumes,
|
||||||
|
askPrices,
|
||||||
|
askVolumes,
|
||||||
|
tradingPhase: data.trading_phase,
|
||||||
|
updateTime: msg.timestamp,
|
||||||
|
exchange: 'SZSE',
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
} else if (category === 'index') {
|
||||||
|
// 指数行情
|
||||||
|
setQuotes(prev => ({
|
||||||
|
...prev,
|
||||||
|
[code]: {
|
||||||
|
code: code,
|
||||||
|
name: prev[code]?.name || '',
|
||||||
|
price: data.current_index,
|
||||||
|
prevClose: data.prev_close,
|
||||||
|
open: data.open_index,
|
||||||
|
high: data.high_index,
|
||||||
|
low: data.low_index,
|
||||||
|
close: data.close_index,
|
||||||
|
volume: data.volume,
|
||||||
|
amount: data.amount,
|
||||||
|
numTrades: data.num_trades,
|
||||||
|
change: data.current_index - data.prev_close,
|
||||||
|
changePct: data.prev_close ? ((data.current_index - data.prev_close) / data.prev_close * 100) : 0,
|
||||||
|
bidPrices: [],
|
||||||
|
bidVolumes: [],
|
||||||
|
askPrices: [],
|
||||||
|
askVolumes: [],
|
||||||
|
tradingPhase: data.trading_phase,
|
||||||
|
updateTime: msg.timestamp,
|
||||||
|
exchange: 'SZSE',
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
} else if (category === 'bond') {
|
||||||
|
// 债券行情
|
||||||
|
setQuotes(prev => ({
|
||||||
|
...prev,
|
||||||
|
[code]: {
|
||||||
|
code: code,
|
||||||
|
name: prev[code]?.name || '',
|
||||||
|
price: data.last_px,
|
||||||
|
prevClose: data.prev_close,
|
||||||
|
open: data.open_px,
|
||||||
|
high: data.high_px,
|
||||||
|
low: data.low_px,
|
||||||
|
volume: data.volume,
|
||||||
|
amount: data.amount,
|
||||||
|
numTrades: data.num_trades,
|
||||||
|
weightedAvgPx: data.weighted_avg_px,
|
||||||
|
change: data.last_px - data.prev_close,
|
||||||
|
changePct: data.prev_close ? ((data.last_px - data.prev_close) / data.prev_close * 100) : 0,
|
||||||
|
bidPrices: [],
|
||||||
|
bidVolumes: [],
|
||||||
|
askPrices: [],
|
||||||
|
askVolumes: [],
|
||||||
|
tradingPhase: data.trading_phase,
|
||||||
|
updateTime: msg.timestamp,
|
||||||
|
exchange: 'SZSE',
|
||||||
|
isBond: true,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
} else if (category === 'hk_stock') {
|
||||||
|
// 港股行情(深港通)
|
||||||
|
const { prices: bidPrices, volumes: bidVolumes } = extractOrderBook(data.bids);
|
||||||
|
const { prices: askPrices, volumes: askVolumes } = extractOrderBook(data.asks);
|
||||||
|
|
||||||
|
setQuotes(prev => ({
|
||||||
|
...prev,
|
||||||
|
[code]: {
|
||||||
|
code: code,
|
||||||
|
name: prev[code]?.name || '',
|
||||||
|
price: data.last_px,
|
||||||
|
prevClose: data.prev_close,
|
||||||
|
open: data.open_px,
|
||||||
|
high: data.high_px,
|
||||||
|
low: data.low_px,
|
||||||
|
volume: data.volume,
|
||||||
|
amount: data.amount,
|
||||||
|
numTrades: data.num_trades,
|
||||||
|
nominalPx: data.nominal_px, // 按盘价
|
||||||
|
referencePx: data.reference_px, // 参考价
|
||||||
|
change: data.last_px - data.prev_close,
|
||||||
|
changePct: data.prev_close ? ((data.last_px - data.prev_close) / data.prev_close * 100) : 0,
|
||||||
|
bidPrices,
|
||||||
|
bidVolumes,
|
||||||
|
askPrices,
|
||||||
|
askVolumes,
|
||||||
|
tradingPhase: data.trading_phase,
|
||||||
|
updateTime: msg.timestamp,
|
||||||
|
exchange: 'SZSE',
|
||||||
|
isHK: true,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
} else if (category === 'afterhours_block' || category === 'afterhours_trading') {
|
||||||
|
// 盘后交易
|
||||||
|
setQuotes(prev => ({
|
||||||
|
...prev,
|
||||||
|
[code]: {
|
||||||
|
...prev[code],
|
||||||
|
afterhours: {
|
||||||
|
bidPx: data.bid_px,
|
||||||
|
bidSize: data.bid_size,
|
||||||
|
offerPx: data.offer_px,
|
||||||
|
offerSize: data.offer_size,
|
||||||
|
volume: data.volume,
|
||||||
|
amount: data.amount,
|
||||||
|
numTrades: data.num_trades,
|
||||||
|
},
|
||||||
|
updateTime: msg.timestamp,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
// fund_nav 和 volume_stats 暂不处理
|
||||||
|
} else if (msg.type === 'snapshot') {
|
||||||
|
// 深交所初始快照
|
||||||
|
const { stocks = [], indexes = [], bonds = [] } = msg.data || {};
|
||||||
|
setQuotes(prev => {
|
||||||
|
const updated = { ...prev };
|
||||||
|
|
||||||
|
stocks.forEach(s => {
|
||||||
|
if (subscribedCodes.current.SZSE.has(s.security_id)) {
|
||||||
|
const { prices: bidPrices, volumes: bidVolumes } = extractOrderBook(s.bids);
|
||||||
|
const { prices: askPrices, volumes: askVolumes } = extractOrderBook(s.asks);
|
||||||
|
|
||||||
|
updated[s.security_id] = {
|
||||||
|
code: s.security_id,
|
||||||
|
name: s.security_name || '',
|
||||||
|
price: s.last_px,
|
||||||
|
prevClose: s.prev_close,
|
||||||
|
open: s.open_px,
|
||||||
|
high: s.high_px,
|
||||||
|
low: s.low_px,
|
||||||
|
volume: s.volume,
|
||||||
|
amount: s.amount,
|
||||||
|
numTrades: s.num_trades,
|
||||||
|
upperLimit: s.upper_limit,
|
||||||
|
lowerLimit: s.lower_limit,
|
||||||
|
change: s.last_px - s.prev_close,
|
||||||
|
changePct: s.prev_close ? ((s.last_px - s.prev_close) / s.prev_close * 100) : 0,
|
||||||
|
bidPrices,
|
||||||
|
bidVolumes,
|
||||||
|
askPrices,
|
||||||
|
askVolumes,
|
||||||
|
exchange: 'SZSE',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
indexes.forEach(i => {
|
||||||
|
if (subscribedCodes.current.SZSE.has(i.security_id)) {
|
||||||
|
updated[i.security_id] = {
|
||||||
|
code: i.security_id,
|
||||||
|
name: i.security_name || '',
|
||||||
|
price: i.current_index,
|
||||||
|
prevClose: i.prev_close,
|
||||||
|
open: i.open_index,
|
||||||
|
high: i.high_index,
|
||||||
|
low: i.low_index,
|
||||||
|
volume: i.volume,
|
||||||
|
amount: i.amount,
|
||||||
|
numTrades: i.num_trades,
|
||||||
|
change: i.current_index - i.prev_close,
|
||||||
|
changePct: i.prev_close ? ((i.current_index - i.prev_close) / i.prev_close * 100) : 0,
|
||||||
|
bidPrices: [],
|
||||||
|
bidVolumes: [],
|
||||||
|
askPrices: [],
|
||||||
|
askVolumes: [],
|
||||||
|
exchange: 'SZSE',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
bonds.forEach(b => {
|
||||||
|
if (subscribedCodes.current.SZSE.has(b.security_id)) {
|
||||||
|
updated[b.security_id] = {
|
||||||
|
code: b.security_id,
|
||||||
|
name: b.security_name || '',
|
||||||
|
price: b.last_px,
|
||||||
|
prevClose: b.prev_close,
|
||||||
|
open: b.open_px,
|
||||||
|
high: b.high_px,
|
||||||
|
low: b.low_px,
|
||||||
|
volume: b.volume,
|
||||||
|
amount: b.amount,
|
||||||
|
change: b.last_px - b.prev_close,
|
||||||
|
changePct: b.prev_close ? ((b.last_px - b.prev_close) / b.prev_close * 100) : 0,
|
||||||
|
bidPrices: [],
|
||||||
|
bidVolumes: [],
|
||||||
|
askPrices: [],
|
||||||
|
askVolumes: [],
|
||||||
|
exchange: 'SZSE',
|
||||||
|
isBond: true,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return updated;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 启动心跳
|
||||||
|
*/
|
||||||
|
const startHeartbeat = useCallback((exchange) => {
|
||||||
|
stopHeartbeat(exchange);
|
||||||
|
heartbeatRefs.current[exchange] = setInterval(() => {
|
||||||
|
const ws = wsRefs.current[exchange];
|
||||||
|
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||||
|
if (exchange === 'SSE') {
|
||||||
|
ws.send(JSON.stringify({ action: 'ping' }));
|
||||||
|
} else {
|
||||||
|
ws.send(JSON.stringify({ type: 'ping' }));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, HEARTBEAT_INTERVAL);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 停止心跳
|
||||||
|
*/
|
||||||
|
const stopHeartbeat = useCallback((exchange) => {
|
||||||
|
if (heartbeatRefs.current[exchange]) {
|
||||||
|
clearInterval(heartbeatRefs.current[exchange]);
|
||||||
|
heartbeatRefs.current[exchange] = null;
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 安排重连
|
||||||
|
*/
|
||||||
|
const scheduleReconnect = useCallback((exchange) => {
|
||||||
|
if (reconnectRefs.current[exchange]) {
|
||||||
|
return; // 已有重连计划
|
||||||
|
}
|
||||||
|
|
||||||
|
reconnectRefs.current[exchange] = setTimeout(() => {
|
||||||
|
reconnectRefs.current[exchange] = null;
|
||||||
|
// 只有还有订阅的代码才重连
|
||||||
|
if (subscribedCodes.current[exchange].size > 0) {
|
||||||
|
createConnection(exchange);
|
||||||
|
}
|
||||||
|
}, RECONNECT_INTERVAL);
|
||||||
|
}, [createConnection]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 订阅证券
|
||||||
|
*/
|
||||||
|
const subscribe = useCallback((code) => {
|
||||||
|
const baseCode = normalizeCode(code);
|
||||||
|
const exchange = getExchange(code);
|
||||||
|
|
||||||
|
// 添加到订阅列表
|
||||||
|
subscribedCodes.current[exchange].add(baseCode);
|
||||||
|
|
||||||
|
// 如果连接已建立,发送订阅消息(仅上交所需要)
|
||||||
|
const ws = wsRefs.current[exchange];
|
||||||
|
if (exchange === 'SSE' && ws && ws.readyState === WebSocket.OPEN) {
|
||||||
|
ws.send(JSON.stringify({
|
||||||
|
action: 'subscribe',
|
||||||
|
channels: ['stock', 'index'],
|
||||||
|
codes: [baseCode],
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果连接未建立,创建连接
|
||||||
|
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
||||||
|
createConnection(exchange);
|
||||||
|
}
|
||||||
|
}, [createConnection]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 取消订阅
|
||||||
|
*/
|
||||||
|
const unsubscribe = useCallback((code) => {
|
||||||
|
const baseCode = normalizeCode(code);
|
||||||
|
const exchange = getExchange(code);
|
||||||
|
|
||||||
|
// 从订阅列表移除
|
||||||
|
subscribedCodes.current[exchange].delete(baseCode);
|
||||||
|
|
||||||
|
// 从 quotes 中移除
|
||||||
|
setQuotes(prev => {
|
||||||
|
const updated = { ...prev };
|
||||||
|
delete updated[baseCode];
|
||||||
|
return updated;
|
||||||
|
});
|
||||||
|
|
||||||
|
// 如果该交易所没有订阅了,关闭连接
|
||||||
|
if (subscribedCodes.current[exchange].size === 0) {
|
||||||
|
const ws = wsRefs.current[exchange];
|
||||||
|
if (ws) {
|
||||||
|
ws.close();
|
||||||
|
wsRefs.current[exchange] = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 初始化订阅
|
||||||
|
*/
|
||||||
|
useEffect(() => {
|
||||||
|
if (!codes || codes.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按交易所分组
|
||||||
|
const sseCodesSet = new Set();
|
||||||
|
const szseCodesSet = new Set();
|
||||||
|
|
||||||
|
codes.forEach(code => {
|
||||||
|
const baseCode = normalizeCode(code);
|
||||||
|
const exchange = getExchange(code);
|
||||||
|
if (exchange === 'SSE') {
|
||||||
|
sseCodesSet.add(baseCode);
|
||||||
|
} else {
|
||||||
|
szseCodesSet.add(baseCode);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 更新订阅列表
|
||||||
|
subscribedCodes.current.SSE = sseCodesSet;
|
||||||
|
subscribedCodes.current.SZSE = szseCodesSet;
|
||||||
|
|
||||||
|
// 建立连接
|
||||||
|
if (sseCodesSet.size > 0) {
|
||||||
|
createConnection('SSE');
|
||||||
|
}
|
||||||
|
if (szseCodesSet.size > 0) {
|
||||||
|
createConnection('SZSE');
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理
|
||||||
|
return () => {
|
||||||
|
['SSE', 'SZSE'].forEach(exchange => {
|
||||||
|
stopHeartbeat(exchange);
|
||||||
|
if (reconnectRefs.current[exchange]) {
|
||||||
|
clearTimeout(reconnectRefs.current[exchange]);
|
||||||
|
}
|
||||||
|
const ws = wsRefs.current[exchange];
|
||||||
|
if (ws) {
|
||||||
|
ws.close();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
}, []); // 只在挂载时执行
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理 codes 变化
|
||||||
|
*/
|
||||||
|
useEffect(() => {
|
||||||
|
if (!codes) return;
|
||||||
|
|
||||||
|
// 计算新的订阅列表
|
||||||
|
const newSseCodes = new Set();
|
||||||
|
const newSzseCodes = new Set();
|
||||||
|
|
||||||
|
codes.forEach(code => {
|
||||||
|
const baseCode = normalizeCode(code);
|
||||||
|
const exchange = getExchange(code);
|
||||||
|
if (exchange === 'SSE') {
|
||||||
|
newSseCodes.add(baseCode);
|
||||||
|
} else {
|
||||||
|
newSzseCodes.add(baseCode);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 找出需要新增和删除的代码
|
||||||
|
const oldSseCodes = subscribedCodes.current.SSE;
|
||||||
|
const oldSzseCodes = subscribedCodes.current.SZSE;
|
||||||
|
|
||||||
|
// 更新上交所订阅
|
||||||
|
const sseToAdd = [...newSseCodes].filter(c => !oldSseCodes.has(c));
|
||||||
|
const sseToRemove = [...oldSseCodes].filter(c => !newSseCodes.has(c));
|
||||||
|
|
||||||
|
if (sseToAdd.length > 0 || sseToRemove.length > 0) {
|
||||||
|
subscribedCodes.current.SSE = newSseCodes;
|
||||||
|
|
||||||
|
const ws = wsRefs.current.SSE;
|
||||||
|
if (ws && ws.readyState === WebSocket.OPEN && sseToAdd.length > 0) {
|
||||||
|
ws.send(JSON.stringify({
|
||||||
|
action: 'subscribe',
|
||||||
|
channels: ['stock', 'index'],
|
||||||
|
codes: sseToAdd,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果新增了代码但连接未建立
|
||||||
|
if (sseToAdd.length > 0 && (!ws || ws.readyState !== WebSocket.OPEN)) {
|
||||||
|
createConnection('SSE');
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有订阅了,关闭连接
|
||||||
|
if (newSseCodes.size === 0 && ws) {
|
||||||
|
ws.close();
|
||||||
|
wsRefs.current.SSE = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新深交所订阅
|
||||||
|
const szseToAdd = [...newSzseCodes].filter(c => !oldSzseCodes.has(c));
|
||||||
|
const szseToRemove = [...oldSzseCodes].filter(c => !newSzseCodes.has(c));
|
||||||
|
|
||||||
|
if (szseToAdd.length > 0 || szseToRemove.length > 0) {
|
||||||
|
subscribedCodes.current.SZSE = newSzseCodes;
|
||||||
|
|
||||||
|
// 深交所是自动推送,只需要管理连接
|
||||||
|
const ws = wsRefs.current.SZSE;
|
||||||
|
|
||||||
|
if (szseToAdd.length > 0 && (!ws || ws.readyState !== WebSocket.OPEN)) {
|
||||||
|
createConnection('SZSE');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (newSzseCodes.size === 0 && ws) {
|
||||||
|
ws.close();
|
||||||
|
wsRefs.current.SZSE = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理已取消订阅的 quotes
|
||||||
|
const removedCodes = [...sseToRemove, ...szseToRemove];
|
||||||
|
if (removedCodes.length > 0) {
|
||||||
|
setQuotes(prev => {
|
||||||
|
const updated = { ...prev };
|
||||||
|
removedCodes.forEach(code => {
|
||||||
|
delete updated[code];
|
||||||
|
});
|
||||||
|
return updated;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [codes, createConnection]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
quotes,
|
||||||
|
connected,
|
||||||
|
subscribe,
|
||||||
|
unsubscribe,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export default useRealtimeQuote;
|
||||||
463
src/views/StockOverview/components/FlexScreen/index.js
Normal file
463
src/views/StockOverview/components/FlexScreen/index.js
Normal file
@@ -0,0 +1,463 @@
|
|||||||
|
/**
|
||||||
|
* 灵活屏组件
|
||||||
|
* 用户可自定义添加关注的指数/个股,实时显示行情
|
||||||
|
*
|
||||||
|
* 功能:
|
||||||
|
* 1. 添加/删除自选证券
|
||||||
|
* 2. 显示实时行情(通过 WebSocket)
|
||||||
|
* 3. 显示分时走势(结合 ClickHouse 历史数据)
|
||||||
|
* 4. 显示五档盘口(上交所完整五档,深交所买一卖一)
|
||||||
|
* 5. 本地存储自选列表
|
||||||
|
*/
|
||||||
|
import React, { useState, useEffect, useCallback, useMemo } from 'react';
|
||||||
|
import {
|
||||||
|
Box,
|
||||||
|
Card,
|
||||||
|
CardBody,
|
||||||
|
VStack,
|
||||||
|
HStack,
|
||||||
|
Heading,
|
||||||
|
Text,
|
||||||
|
Input,
|
||||||
|
InputGroup,
|
||||||
|
InputLeftElement,
|
||||||
|
InputRightElement,
|
||||||
|
IconButton,
|
||||||
|
Button,
|
||||||
|
SimpleGrid,
|
||||||
|
Flex,
|
||||||
|
Spacer,
|
||||||
|
Icon,
|
||||||
|
useColorModeValue,
|
||||||
|
useToast,
|
||||||
|
Badge,
|
||||||
|
Tooltip,
|
||||||
|
Collapse,
|
||||||
|
List,
|
||||||
|
ListItem,
|
||||||
|
Spinner,
|
||||||
|
Center,
|
||||||
|
Menu,
|
||||||
|
MenuButton,
|
||||||
|
MenuList,
|
||||||
|
MenuItem,
|
||||||
|
Divider,
|
||||||
|
Tag,
|
||||||
|
TagLabel,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import {
|
||||||
|
SearchIcon,
|
||||||
|
CloseIcon,
|
||||||
|
AddIcon,
|
||||||
|
ChevronDownIcon,
|
||||||
|
ChevronUpIcon,
|
||||||
|
SettingsIcon,
|
||||||
|
} from '@chakra-ui/icons';
|
||||||
|
import { FaDesktop, FaPlus, FaTrash, FaSync, FaWifi, FaExclamationCircle } from 'react-icons/fa';
|
||||||
|
|
||||||
|
import { useRealtimeQuote } from './hooks';
|
||||||
|
import { QuoteTile } from './components';
|
||||||
|
import { logger } from '@utils/logger';
|
||||||
|
|
||||||
|
// 本地存储 key
|
||||||
|
const STORAGE_KEY = 'flexscreen_watchlist';
|
||||||
|
|
||||||
|
// 默认自选列表
|
||||||
|
const DEFAULT_WATCHLIST = [
|
||||||
|
{ code: '000001', name: '上证指数', isIndex: true },
|
||||||
|
{ code: '399001', name: '深证成指', isIndex: true },
|
||||||
|
{ code: '399006', name: '创业板指', isIndex: true },
|
||||||
|
];
|
||||||
|
|
||||||
|
// 热门推荐
|
||||||
|
const HOT_RECOMMENDATIONS = [
|
||||||
|
{ code: '000001', name: '上证指数', isIndex: true },
|
||||||
|
{ code: '399001', name: '深证成指', isIndex: true },
|
||||||
|
{ code: '399006', name: '创业板指', isIndex: true },
|
||||||
|
{ code: '399300', name: '沪深300', isIndex: true },
|
||||||
|
{ code: '600519', name: '贵州茅台', isIndex: false },
|
||||||
|
{ code: '000858', name: '五粮液', isIndex: false },
|
||||||
|
{ code: '300750', name: '宁德时代', isIndex: false },
|
||||||
|
{ code: '002594', name: '比亚迪', isIndex: false },
|
||||||
|
];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* FlexScreen 组件
|
||||||
|
*/
|
||||||
|
const FlexScreen = () => {
|
||||||
|
const toast = useToast();
|
||||||
|
|
||||||
|
// 自选列表
|
||||||
|
const [watchlist, setWatchlist] = useState([]);
|
||||||
|
// 搜索状态
|
||||||
|
const [searchQuery, setSearchQuery] = useState('');
|
||||||
|
const [searchResults, setSearchResults] = useState([]);
|
||||||
|
const [isSearching, setIsSearching] = useState(false);
|
||||||
|
const [showResults, setShowResults] = useState(false);
|
||||||
|
// 面板状态
|
||||||
|
const [isCollapsed, setIsCollapsed] = useState(false);
|
||||||
|
|
||||||
|
// 颜色主题
|
||||||
|
const cardBg = useColorModeValue('white', '#1a1a1a');
|
||||||
|
const borderColor = useColorModeValue('gray.200', '#333');
|
||||||
|
const textColor = useColorModeValue('gray.800', 'white');
|
||||||
|
const subTextColor = useColorModeValue('gray.600', 'gray.400');
|
||||||
|
const searchBg = useColorModeValue('gray.50', '#2a2a2a');
|
||||||
|
const hoverBg = useColorModeValue('gray.100', '#333');
|
||||||
|
|
||||||
|
// 获取订阅的证券代码列表
|
||||||
|
const subscribedCodes = useMemo(() => {
|
||||||
|
return watchlist.map(item => item.code);
|
||||||
|
}, [watchlist]);
|
||||||
|
|
||||||
|
// WebSocket 实时行情
|
||||||
|
const { quotes, connected } = useRealtimeQuote(subscribedCodes);
|
||||||
|
|
||||||
|
// 从本地存储加载自选列表
|
||||||
|
useEffect(() => {
|
||||||
|
try {
|
||||||
|
const saved = localStorage.getItem(STORAGE_KEY);
|
||||||
|
if (saved) {
|
||||||
|
const parsed = JSON.parse(saved);
|
||||||
|
if (Array.isArray(parsed) && parsed.length > 0) {
|
||||||
|
setWatchlist(parsed);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
logger.warn('FlexScreen', '加载自选列表失败', e);
|
||||||
|
}
|
||||||
|
// 使用默认列表
|
||||||
|
setWatchlist(DEFAULT_WATCHLIST);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// 保存自选列表到本地存储
|
||||||
|
useEffect(() => {
|
||||||
|
if (watchlist.length > 0) {
|
||||||
|
try {
|
||||||
|
localStorage.setItem(STORAGE_KEY, JSON.stringify(watchlist));
|
||||||
|
} catch (e) {
|
||||||
|
logger.warn('FlexScreen', '保存自选列表失败', e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [watchlist]);
|
||||||
|
|
||||||
|
// 搜索证券
|
||||||
|
const searchSecurities = useCallback(async (query) => {
|
||||||
|
if (!query.trim()) {
|
||||||
|
setSearchResults([]);
|
||||||
|
setShowResults(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setIsSearching(true);
|
||||||
|
try {
|
||||||
|
const response = await fetch(`/api/stocks/search?q=${encodeURIComponent(query)}&limit=10`);
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
setSearchResults(data.data || []);
|
||||||
|
setShowResults(true);
|
||||||
|
} else {
|
||||||
|
setSearchResults([]);
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
logger.error('FlexScreen', '搜索失败', e);
|
||||||
|
setSearchResults([]);
|
||||||
|
} finally {
|
||||||
|
setIsSearching(false);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// 防抖搜索
|
||||||
|
useEffect(() => {
|
||||||
|
const timer = setTimeout(() => {
|
||||||
|
searchSecurities(searchQuery);
|
||||||
|
}, 300);
|
||||||
|
return () => clearTimeout(timer);
|
||||||
|
}, [searchQuery, searchSecurities]);
|
||||||
|
|
||||||
|
// 添加证券
|
||||||
|
const addSecurity = useCallback((security) => {
|
||||||
|
const code = security.stock_code || security.code;
|
||||||
|
const name = security.stock_name || security.name;
|
||||||
|
const isIndex = security.isIndex || code.startsWith('000') || code.startsWith('399');
|
||||||
|
|
||||||
|
// 检查是否已存在
|
||||||
|
if (watchlist.some(item => item.code === code)) {
|
||||||
|
toast({
|
||||||
|
title: '已在自选列表中',
|
||||||
|
status: 'info',
|
||||||
|
duration: 2000,
|
||||||
|
isClosable: true,
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加到列表
|
||||||
|
setWatchlist(prev => [...prev, { code, name, isIndex }]);
|
||||||
|
|
||||||
|
toast({
|
||||||
|
title: `已添加 ${name}`,
|
||||||
|
status: 'success',
|
||||||
|
duration: 2000,
|
||||||
|
isClosable: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
// 清空搜索
|
||||||
|
setSearchQuery('');
|
||||||
|
setShowResults(false);
|
||||||
|
}, [watchlist, toast]);
|
||||||
|
|
||||||
|
// 移除证券
|
||||||
|
const removeSecurity = useCallback((code) => {
|
||||||
|
setWatchlist(prev => prev.filter(item => item.code !== code));
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// 清空自选列表
|
||||||
|
const clearWatchlist = useCallback(() => {
|
||||||
|
setWatchlist([]);
|
||||||
|
localStorage.removeItem(STORAGE_KEY);
|
||||||
|
toast({
|
||||||
|
title: '已清空自选列表',
|
||||||
|
status: 'info',
|
||||||
|
duration: 2000,
|
||||||
|
isClosable: true,
|
||||||
|
});
|
||||||
|
}, [toast]);
|
||||||
|
|
||||||
|
// 重置为默认列表
|
||||||
|
const resetWatchlist = useCallback(() => {
|
||||||
|
setWatchlist(DEFAULT_WATCHLIST);
|
||||||
|
toast({
|
||||||
|
title: '已重置为默认列表',
|
||||||
|
status: 'success',
|
||||||
|
duration: 2000,
|
||||||
|
isClosable: true,
|
||||||
|
});
|
||||||
|
}, [toast]);
|
||||||
|
|
||||||
|
// 连接状态指示
|
||||||
|
const isAnyConnected = connected.SSE || connected.SZSE;
|
||||||
|
const connectionStatus = useMemo(() => {
|
||||||
|
if (connected.SSE && connected.SZSE) {
|
||||||
|
return { color: 'green', text: '上交所/深交所 已连接' };
|
||||||
|
}
|
||||||
|
if (connected.SSE) {
|
||||||
|
return { color: 'yellow', text: '上交所 已连接' };
|
||||||
|
}
|
||||||
|
if (connected.SZSE) {
|
||||||
|
return { color: 'yellow', text: '深交所 已连接' };
|
||||||
|
}
|
||||||
|
return { color: 'red', text: '未连接' };
|
||||||
|
}, [connected]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
|
||||||
|
<CardBody>
|
||||||
|
{/* 头部 */}
|
||||||
|
<Flex align="center" mb={4}>
|
||||||
|
<HStack spacing={3}>
|
||||||
|
<Icon as={FaDesktop} boxSize={6} color="purple.500" />
|
||||||
|
<Heading size="md" color={textColor}>
|
||||||
|
灵活屏
|
||||||
|
</Heading>
|
||||||
|
<Tooltip label={connectionStatus.text}>
|
||||||
|
<Badge
|
||||||
|
colorScheme={connectionStatus.color}
|
||||||
|
variant="subtle"
|
||||||
|
display="flex"
|
||||||
|
alignItems="center"
|
||||||
|
gap={1}
|
||||||
|
>
|
||||||
|
<Icon as={FaWifi} boxSize={3} />
|
||||||
|
{isAnyConnected ? '实时' : '离线'}
|
||||||
|
</Badge>
|
||||||
|
</Tooltip>
|
||||||
|
</HStack>
|
||||||
|
<Spacer />
|
||||||
|
<HStack spacing={2}>
|
||||||
|
{/* 操作菜单 */}
|
||||||
|
<Menu>
|
||||||
|
<MenuButton
|
||||||
|
as={IconButton}
|
||||||
|
icon={<SettingsIcon />}
|
||||||
|
size="sm"
|
||||||
|
variant="ghost"
|
||||||
|
aria-label="设置"
|
||||||
|
/>
|
||||||
|
<MenuList>
|
||||||
|
<MenuItem icon={<FaSync />} onClick={resetWatchlist}>
|
||||||
|
重置为默认
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem icon={<FaTrash />} onClick={clearWatchlist} color="red.500">
|
||||||
|
清空列表
|
||||||
|
</MenuItem>
|
||||||
|
</MenuList>
|
||||||
|
</Menu>
|
||||||
|
{/* 折叠按钮 */}
|
||||||
|
<IconButton
|
||||||
|
icon={isCollapsed ? <ChevronDownIcon /> : <ChevronUpIcon />}
|
||||||
|
size="sm"
|
||||||
|
variant="ghost"
|
||||||
|
onClick={() => setIsCollapsed(!isCollapsed)}
|
||||||
|
aria-label={isCollapsed ? '展开' : '收起'}
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
{/* 可折叠内容 */}
|
||||||
|
<Collapse in={!isCollapsed} animateOpacity>
|
||||||
|
{/* 搜索框 */}
|
||||||
|
<Box position="relative" mb={4}>
|
||||||
|
<InputGroup size="md">
|
||||||
|
<InputLeftElement pointerEvents="none">
|
||||||
|
<SearchIcon color="gray.400" />
|
||||||
|
</InputLeftElement>
|
||||||
|
<Input
|
||||||
|
placeholder="搜索股票/指数代码或名称..."
|
||||||
|
value={searchQuery}
|
||||||
|
onChange={(e) => setSearchQuery(e.target.value)}
|
||||||
|
bg={searchBg}
|
||||||
|
borderRadius="lg"
|
||||||
|
_focus={{
|
||||||
|
borderColor: 'purple.400',
|
||||||
|
boxShadow: '0 0 0 1px var(--chakra-colors-purple-400)',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
{searchQuery && (
|
||||||
|
<InputRightElement>
|
||||||
|
<IconButton
|
||||||
|
size="sm"
|
||||||
|
icon={<CloseIcon />}
|
||||||
|
variant="ghost"
|
||||||
|
onClick={() => {
|
||||||
|
setSearchQuery('');
|
||||||
|
setShowResults(false);
|
||||||
|
}}
|
||||||
|
aria-label="清空"
|
||||||
|
/>
|
||||||
|
</InputRightElement>
|
||||||
|
)}
|
||||||
|
</InputGroup>
|
||||||
|
|
||||||
|
{/* 搜索结果下拉 */}
|
||||||
|
<Collapse in={showResults} animateOpacity>
|
||||||
|
<Box
|
||||||
|
position="absolute"
|
||||||
|
top="100%"
|
||||||
|
left={0}
|
||||||
|
right={0}
|
||||||
|
mt={1}
|
||||||
|
bg={cardBg}
|
||||||
|
borderWidth="1px"
|
||||||
|
borderColor={borderColor}
|
||||||
|
borderRadius="lg"
|
||||||
|
boxShadow="lg"
|
||||||
|
maxH="300px"
|
||||||
|
overflowY="auto"
|
||||||
|
zIndex={10}
|
||||||
|
>
|
||||||
|
{isSearching ? (
|
||||||
|
<Center p={4}>
|
||||||
|
<Spinner size="sm" color="purple.500" />
|
||||||
|
</Center>
|
||||||
|
) : searchResults.length > 0 ? (
|
||||||
|
<List spacing={0}>
|
||||||
|
{searchResults.map((stock, index) => (
|
||||||
|
<ListItem
|
||||||
|
key={stock.stock_code}
|
||||||
|
px={4}
|
||||||
|
py={2}
|
||||||
|
cursor="pointer"
|
||||||
|
_hover={{ bg: hoverBg }}
|
||||||
|
onClick={() => addSecurity(stock)}
|
||||||
|
borderBottomWidth={index < searchResults.length - 1 ? '1px' : '0'}
|
||||||
|
borderColor={borderColor}
|
||||||
|
>
|
||||||
|
<HStack justify="space-between">
|
||||||
|
<VStack align="start" spacing={0}>
|
||||||
|
<Text fontWeight="medium" color={textColor}>
|
||||||
|
{stock.stock_name}
|
||||||
|
</Text>
|
||||||
|
<Text fontSize="xs" color={subTextColor}>
|
||||||
|
{stock.stock_code}
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
<IconButton
|
||||||
|
icon={<AddIcon />}
|
||||||
|
size="xs"
|
||||||
|
colorScheme="purple"
|
||||||
|
variant="ghost"
|
||||||
|
aria-label="添加"
|
||||||
|
/>
|
||||||
|
</HStack>
|
||||||
|
</ListItem>
|
||||||
|
))}
|
||||||
|
</List>
|
||||||
|
) : (
|
||||||
|
<Center p={4}>
|
||||||
|
<Text color={subTextColor} fontSize="sm">
|
||||||
|
未找到相关证券
|
||||||
|
</Text>
|
||||||
|
</Center>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
</Collapse>
|
||||||
|
</Box>
|
||||||
|
|
||||||
|
{/* 快捷添加 */}
|
||||||
|
{watchlist.length === 0 && (
|
||||||
|
<Box mb={4}>
|
||||||
|
<Text fontSize="sm" color={subTextColor} mb={2}>
|
||||||
|
热门推荐(点击添加)
|
||||||
|
</Text>
|
||||||
|
<Flex flexWrap="wrap" gap={2}>
|
||||||
|
{HOT_RECOMMENDATIONS.map((item) => (
|
||||||
|
<Tag
|
||||||
|
key={item.code}
|
||||||
|
size="md"
|
||||||
|
variant="subtle"
|
||||||
|
colorScheme="purple"
|
||||||
|
cursor="pointer"
|
||||||
|
_hover={{ bg: 'purple.100' }}
|
||||||
|
onClick={() => addSecurity(item)}
|
||||||
|
>
|
||||||
|
<TagLabel>{item.name}</TagLabel>
|
||||||
|
</Tag>
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 自选列表 */}
|
||||||
|
{watchlist.length > 0 ? (
|
||||||
|
<SimpleGrid columns={{ base: 1, md: 2, lg: 3 }} spacing={4}>
|
||||||
|
{watchlist.map((item) => (
|
||||||
|
<QuoteTile
|
||||||
|
key={item.code}
|
||||||
|
code={item.code}
|
||||||
|
name={item.name}
|
||||||
|
quote={quotes[item.code] || {}}
|
||||||
|
isIndex={item.isIndex}
|
||||||
|
onRemove={removeSecurity}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</SimpleGrid>
|
||||||
|
) : (
|
||||||
|
<Center py={8}>
|
||||||
|
<VStack spacing={3}>
|
||||||
|
<Icon as={FaExclamationCircle} boxSize={10} color="gray.300" />
|
||||||
|
<Text color={subTextColor}>
|
||||||
|
自选列表为空,请搜索添加证券
|
||||||
|
</Text>
|
||||||
|
</VStack>
|
||||||
|
</Center>
|
||||||
|
)}
|
||||||
|
</Collapse>
|
||||||
|
</CardBody>
|
||||||
|
</Card>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default FlexScreen;
|
||||||
@@ -18,6 +18,69 @@ import {
|
|||||||
import { FaBolt, FaArrowUp, FaArrowDown, FaChartLine, FaFire, FaVolumeUp } from 'react-icons/fa';
|
import { FaBolt, FaArrowUp, FaArrowDown, FaChartLine, FaFire, FaVolumeUp } from 'react-icons/fa';
|
||||||
import { getAlertTypeLabel, formatScore, getScoreColor } from '../utils/chartHelpers';
|
import { getAlertTypeLabel, formatScore, getScoreColor } from '../utils/chartHelpers';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Z-Score 指示器组件
|
||||||
|
*/
|
||||||
|
const ZScoreIndicator = ({ value, label, tooltip }) => {
|
||||||
|
if (value === null || value === undefined) return null;
|
||||||
|
|
||||||
|
// Z-Score 颜色:越大越红,越小越绿
|
||||||
|
const getZScoreColor = (z) => {
|
||||||
|
const absZ = Math.abs(z);
|
||||||
|
if (absZ >= 3) return z > 0 ? 'red.600' : 'green.600';
|
||||||
|
if (absZ >= 2) return z > 0 ? 'red.500' : 'green.500';
|
||||||
|
if (absZ >= 1) return z > 0 ? 'orange.400' : 'teal.400';
|
||||||
|
return 'gray.400';
|
||||||
|
};
|
||||||
|
|
||||||
|
// Z-Score 强度条宽度(最大 5σ)
|
||||||
|
const barWidth = Math.min(Math.abs(value) / 5 * 100, 100);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Tooltip label={tooltip || `${label}: ${value.toFixed(2)}σ`} placement="left">
|
||||||
|
<HStack spacing={1} fontSize="xs">
|
||||||
|
<Text color="gray.500" w="20px">{label}</Text>
|
||||||
|
<Box position="relative" w="40px" h="6px" bg="gray.200" borderRadius="full" overflow="hidden">
|
||||||
|
<Box
|
||||||
|
position="absolute"
|
||||||
|
left={value >= 0 ? '50%' : `${50 - barWidth / 2}%`}
|
||||||
|
w={`${barWidth / 2}%`}
|
||||||
|
h="100%"
|
||||||
|
bg={getZScoreColor(value)}
|
||||||
|
borderRadius="full"
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
<Text color={getZScoreColor(value)} fontWeight="medium" w="28px" textAlign="right">
|
||||||
|
{value >= 0 ? '+' : ''}{value.toFixed(1)}
|
||||||
|
</Text>
|
||||||
|
</HStack>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 持续确认率指示器
|
||||||
|
*/
|
||||||
|
const ConfirmRatioIndicator = ({ ratio }) => {
|
||||||
|
if (ratio === null || ratio === undefined) return null;
|
||||||
|
|
||||||
|
const percent = Math.round(ratio * 100);
|
||||||
|
const color = percent >= 80 ? 'green' : percent >= 60 ? 'orange' : 'red';
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Tooltip label={`持续确认率: ${percent}%(5分钟窗口内超标比例)`}>
|
||||||
|
<HStack spacing={1}>
|
||||||
|
<Box position="relative" w="32px" h="6px" bg="gray.200" borderRadius="full" overflow="hidden">
|
||||||
|
<Box w={`${percent}%`} h="100%" bg={`${color}.500`} borderRadius="full" />
|
||||||
|
</Box>
|
||||||
|
<Text fontSize="xs" color={`${color}.500`} fontWeight="medium">
|
||||||
|
{percent}%
|
||||||
|
</Text>
|
||||||
|
</HStack>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 单个异动项组件
|
* 单个异动项组件
|
||||||
*/
|
*/
|
||||||
@@ -29,6 +92,7 @@ const AlertItem = ({ alert, onClick, isSelected }) => {
|
|||||||
|
|
||||||
const isUp = alert.alert_type !== 'surge_down';
|
const isUp = alert.alert_type !== 'surge_down';
|
||||||
const typeColor = isUp ? 'red' : 'green';
|
const typeColor = isUp ? 'red' : 'green';
|
||||||
|
const isV2 = alert.is_v2;
|
||||||
|
|
||||||
// 获取异动类型图标
|
// 获取异动类型图标
|
||||||
const getTypeIcon = (type) => {
|
const getTypeIcon = (type) => {
|
||||||
@@ -69,13 +133,30 @@ const AlertItem = ({ alert, onClick, isSelected }) => {
|
|||||||
<Text fontWeight="bold" fontSize="sm" noOfLines={1}>
|
<Text fontWeight="bold" fontSize="sm" noOfLines={1}>
|
||||||
{alert.concept_name}
|
{alert.concept_name}
|
||||||
</Text>
|
</Text>
|
||||||
|
{isV2 && (
|
||||||
|
<Badge colorScheme="purple" size="xs" variant="subtle" fontSize="10px">
|
||||||
|
V2
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
</HStack>
|
</HStack>
|
||||||
<HStack spacing={2} fontSize="xs" color="gray.500">
|
<HStack spacing={2} fontSize="xs" color="gray.500">
|
||||||
<Text>{alert.time}</Text>
|
<Text>{alert.time}</Text>
|
||||||
<Badge colorScheme={typeColor} size="sm" variant="subtle">
|
<Badge colorScheme={typeColor} size="sm" variant="subtle">
|
||||||
{getAlertTypeLabel(alert.alert_type)}
|
{getAlertTypeLabel(alert.alert_type)}
|
||||||
</Badge>
|
</Badge>
|
||||||
|
{/* V2: 持续确认率 */}
|
||||||
|
{isV2 && alert.confirm_ratio !== undefined && (
|
||||||
|
<ConfirmRatioIndicator ratio={alert.confirm_ratio} />
|
||||||
|
)}
|
||||||
</HStack>
|
</HStack>
|
||||||
|
|
||||||
|
{/* V2: Z-Score 指标行 */}
|
||||||
|
{isV2 && (alert.alpha_zscore !== undefined || alert.amt_zscore !== undefined) && (
|
||||||
|
<HStack spacing={3} mt={1}>
|
||||||
|
<ZScoreIndicator value={alert.alpha_zscore} label="α" tooltip={`Alpha Z-Score: ${alert.alpha_zscore?.toFixed(2)}σ(相对于历史同时段)`} />
|
||||||
|
<ZScoreIndicator value={alert.amt_zscore} label="量" tooltip={`成交额 Z-Score: ${alert.amt_zscore?.toFixed(2)}σ`} />
|
||||||
|
</HStack>
|
||||||
|
)}
|
||||||
</VStack>
|
</VStack>
|
||||||
|
|
||||||
{/* 右侧:分数和关键指标 */}
|
{/* 右侧:分数和关键指标 */}
|
||||||
@@ -104,12 +185,29 @@ const AlertItem = ({ alert, onClick, isSelected }) => {
|
|||||||
</Text>
|
</Text>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* 涨停数量 */}
|
{/* V2: 动量指标 */}
|
||||||
{alert.limit_up_count > 0 && (
|
{isV2 && alert.momentum_5m !== undefined && Math.abs(alert.momentum_5m) > 0.3 && (
|
||||||
|
<HStack spacing={1}>
|
||||||
|
<Icon
|
||||||
|
as={alert.momentum_5m > 0 ? FaArrowUp : FaArrowDown}
|
||||||
|
color={alert.momentum_5m > 0 ? 'red.400' : 'green.400'}
|
||||||
|
boxSize={3}
|
||||||
|
/>
|
||||||
|
<Text fontSize="xs" color={alert.momentum_5m > 0 ? 'red.400' : 'green.400'}>
|
||||||
|
动量 {alert.momentum_5m > 0 ? '+' : ''}{alert.momentum_5m.toFixed(2)}
|
||||||
|
</Text>
|
||||||
|
</HStack>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 涨停数量 / 涨停比例 */}
|
||||||
|
{(alert.limit_up_count > 0 || (alert.limit_up_ratio > 0.05 && isV2)) && (
|
||||||
<HStack spacing={1}>
|
<HStack spacing={1}>
|
||||||
<Icon as={FaFire} color="orange.500" boxSize={3} />
|
<Icon as={FaFire} color="orange.500" boxSize={3} />
|
||||||
<Text fontSize="xs" color="orange.500">
|
<Text fontSize="xs" color="orange.500">
|
||||||
涨停 {alert.limit_up_count}
|
{alert.limit_up_count > 0
|
||||||
|
? `涨停 ${alert.limit_up_count}`
|
||||||
|
: `涨停 ${Math.round(alert.limit_up_ratio * 100)}%`
|
||||||
|
}
|
||||||
</Text>
|
</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ import { FaChartLine, FaFire, FaRocket, FaBrain, FaCalendarAlt, FaChevronRight,
|
|||||||
import ConceptStocksModal from '@components/ConceptStocksModal';
|
import ConceptStocksModal from '@components/ConceptStocksModal';
|
||||||
import TradeDatePicker from '@components/TradeDatePicker';
|
import TradeDatePicker from '@components/TradeDatePicker';
|
||||||
import HotspotOverview from './components/HotspotOverview';
|
import HotspotOverview from './components/HotspotOverview';
|
||||||
|
import FlexScreen from './components/FlexScreen';
|
||||||
import { BsGraphUp, BsLightningFill } from 'react-icons/bs';
|
import { BsGraphUp, BsLightningFill } from 'react-icons/bs';
|
||||||
import * as echarts from 'echarts';
|
import * as echarts from 'echarts';
|
||||||
import { logger } from '../../utils/logger';
|
import { logger } from '../../utils/logger';
|
||||||
@@ -846,6 +847,11 @@ const StockOverview = () => {
|
|||||||
<HotspotOverview selectedDate={selectedDate} />
|
<HotspotOverview selectedDate={selectedDate} />
|
||||||
</Box>
|
</Box>
|
||||||
|
|
||||||
|
{/* 灵活屏 - 实时行情监控 */}
|
||||||
|
<Box mb={10}>
|
||||||
|
<FlexScreen />
|
||||||
|
</Box>
|
||||||
|
|
||||||
{/* 今日热门概念 */}
|
{/* 今日热门概念 */}
|
||||||
<Box mb={10}>
|
<Box mb={10}>
|
||||||
<Flex align="center" mb={6}>
|
<Flex align="center" mb={6}>
|
||||||
|
|||||||
Reference in New Issue
Block a user