From e501ac38193089da76088de88551af1fd23e9bbe Mon Sep 17 00:00:00 2001 From: zzlgreat Date: Wed, 10 Dec 2025 11:02:09 +0800 Subject: [PATCH] update pay ui --- app.py | 198 ++--- ml/backtest_v2.py | 294 +++++++ ml/checkpoints_v2/config.json | 31 + ml/checkpoints_v2/thresholds.json | 8 + ml/detector_v2.py | 716 +++++++++++++++++ ml/model.py | 5 +- ml/prepare_data.py | 126 +-- ml/prepare_data_v2.py | 715 +++++++++++++++++ ml/realtime_detector.py | 8 +- ml/realtime_detector_v2.py | 729 ++++++++++++++++++ ml/train_v2.py | 622 +++++++++++++++ ml/update_baseline.py | 132 ++++ .../components/MiniTimelineChart.js | 267 +++++++ .../FlexScreen/components/OrderBookPanel.js | 275 +++++++ .../FlexScreen/components/QuoteTile.js | 270 +++++++ .../components/FlexScreen/components/index.js | 3 + .../components/FlexScreen/hooks/index.js | 1 + .../FlexScreen/hooks/useRealtimeQuote.js | 692 +++++++++++++++++ .../components/FlexScreen/index.js | 463 +++++++++++ .../components/ConceptAlertList.js | 104 ++- src/views/StockOverview/index.js | 6 + 21 files changed, 5514 insertions(+), 151 deletions(-) create mode 100644 ml/backtest_v2.py create mode 100644 ml/checkpoints_v2/config.json create mode 100644 ml/checkpoints_v2/thresholds.json create mode 100644 ml/detector_v2.py create mode 100644 ml/prepare_data_v2.py create mode 100644 ml/realtime_detector_v2.py create mode 100644 ml/train_v2.py create mode 100644 ml/update_baseline.py create mode 100644 src/views/StockOverview/components/FlexScreen/components/MiniTimelineChart.js create mode 100644 src/views/StockOverview/components/FlexScreen/components/OrderBookPanel.js create mode 100644 src/views/StockOverview/components/FlexScreen/components/QuoteTile.js create mode 100644 src/views/StockOverview/components/FlexScreen/components/index.js create mode 100644 src/views/StockOverview/components/FlexScreen/hooks/index.js create mode 100644 src/views/StockOverview/components/FlexScreen/hooks/useRealtimeQuote.js create mode 100644 src/views/StockOverview/components/FlexScreen/index.js diff --git a/app.py b/app.py index 07769534..0980cf9f 100755 --- a/app.py +++ b/app.py @@ -12536,113 +12536,113 @@ def get_hotspot_overview(): 'change_pct': change_pct }) - # 2. 获取概念异动数据(从 concept_anomaly_hybrid 表) + # 2. 
获取概念异动数据(优先从 V2 表,fallback 到旧表) alerts = [] - - # 首先确保表存在(使用 begin() 来自动提交) - try: - with engine.begin() as conn: - conn.execute(text(""" - CREATE TABLE IF NOT EXISTS concept_anomaly_hybrid ( - id INT AUTO_INCREMENT PRIMARY KEY, - concept_id VARCHAR(64) NOT NULL, - alert_time DATETIME NOT NULL, - trade_date DATE NOT NULL, - alert_type VARCHAR(32) NOT NULL, - final_score FLOAT NOT NULL, - rule_score FLOAT NOT NULL, - ml_score FLOAT NOT NULL, - trigger_reason VARCHAR(64), - alpha FLOAT, - alpha_delta FLOAT, - amt_ratio FLOAT, - amt_delta FLOAT, - rank_pct FLOAT, - limit_up_ratio FLOAT, - stock_count INT, - total_amt FLOAT, - triggered_rules JSON, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date), - INDEX idx_trade_date (trade_date), - INDEX idx_concept_id (concept_id), - INDEX idx_final_score (final_score), - INDEX idx_alert_type (alert_type) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动检测结果(融合版)' - """)) - except Exception as create_err: - logger.debug(f"创建表检查: {create_err}") + use_v2 = False with engine.connect() as conn: - # 查询 concept_anomaly_hybrid 表 - alert_result = conn.execute(text(""" - SELECT - a.concept_id, - a.alert_time, - a.trade_date, - a.alert_type, - a.final_score, - a.rule_score, - a.ml_score, - a.trigger_reason, - a.alpha, - a.alpha_delta, - a.amt_ratio, - a.amt_delta, - a.rank_pct, - a.limit_up_ratio, - a.stock_count, - a.total_amt, - a.triggered_rules - FROM concept_anomaly_hybrid a - WHERE a.trade_date = :trade_date - ORDER BY a.alert_time - """), {'trade_date': trade_date}) + # 尝试查询 V2 表(时间片对齐 + 持续确认版本) + try: + v2_result = conn.execute(text(""" + SELECT + concept_id, alert_time, trade_date, alert_type, + final_score, rule_score, ml_score, trigger_reason, confirm_ratio, + alpha, alpha_zscore, amt_zscore, rank_zscore, + momentum_3m, momentum_5m, limit_up_ratio, triggered_rules + FROM concept_anomaly_v2 + WHERE trade_date = :trade_date + ORDER BY alert_time + """), {'trade_date': trade_date}) + v2_rows = v2_result.fetchall() + if v2_rows: + use_v2 = True + for row in v2_rows: + triggered_rules = None + if row[16]: + try: + triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16] + except: + pass - # 获取概念名称映射(从 ES 或缓存) - concept_names = {} + alerts.append({ + 'concept_id': row[0], + 'concept_name': row[0], # 后面会填充 + 'time': row[1].strftime('%H:%M') if row[1] else None, + 'timestamp': row[1].isoformat() if row[1] else None, + 'alert_type': row[3], + 'final_score': float(row[4]) if row[4] else None, + 'rule_score': float(row[5]) if row[5] else None, + 'ml_score': float(row[6]) if row[6] else None, + 'trigger_reason': row[7], + # V2 新增字段 + 'confirm_ratio': float(row[8]) if row[8] else None, + 'alpha': float(row[9]) if row[9] else None, + 'alpha_zscore': float(row[10]) if row[10] else None, + 'amt_zscore': float(row[11]) if row[11] else None, + 'rank_zscore': float(row[12]) if row[12] else None, + 'momentum_3m': float(row[13]) if row[13] else None, + 'momentum_5m': float(row[14]) if row[14] else None, + 'limit_up_ratio': float(row[15]) if row[15] else 0, + 'triggered_rules': triggered_rules, + # 兼容字段 + 'importance_score': float(row[4]) / 100 if row[4] else None, + 'is_v2': True, + }) + except Exception as v2_err: + logger.debug(f"V2 表查询失败,使用旧表: {v2_err}") - for row in alert_result: - concept_id = row[0] - alert_time = row[1] - triggered_rules = None - if row[16]: - try: - triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16] - except: - pass + # Fallback: 
查询旧表 + if not use_v2: + try: + alert_result = conn.execute(text(""" + SELECT + a.concept_id, a.alert_time, a.trade_date, a.alert_type, + a.final_score, a.rule_score, a.ml_score, a.trigger_reason, + a.alpha, a.alpha_delta, a.amt_ratio, a.amt_delta, + a.rank_pct, a.limit_up_ratio, a.stock_count, a.total_amt, + a.triggered_rules + FROM concept_anomaly_hybrid a + WHERE a.trade_date = :trade_date + ORDER BY a.alert_time + """), {'trade_date': trade_date}) - # 获取概念名称(优先从缓存,否则使用 concept_id) - concept_name = concept_names.get(concept_id) or concept_id + for row in alert_result: + triggered_rules = None + if row[16]: + try: + triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16] + except: + pass - # 计算涨停数量(从 limit_up_ratio 和 stock_count 估算) - limit_up_ratio = float(row[13]) if row[13] else 0 - stock_count = int(row[14]) if row[14] else 0 - limit_up_count = int(limit_up_ratio * stock_count) if stock_count > 0 else 0 + limit_up_ratio = float(row[13]) if row[13] else 0 + stock_count = int(row[14]) if row[14] else 0 + limit_up_count = int(limit_up_ratio * stock_count) if stock_count > 0 else 0 - alerts.append({ - 'concept_id': concept_id, - 'concept_name': concept_name, - 'time': alert_time.strftime('%H:%M') if alert_time else None, - 'timestamp': alert_time.isoformat() if alert_time else None, - 'alert_type': row[3], - 'final_score': float(row[4]) if row[4] else None, - 'rule_score': float(row[5]) if row[5] else None, - 'ml_score': float(row[6]) if row[6] else None, - 'trigger_reason': row[7], - 'alpha': float(row[8]) if row[8] else None, - 'alpha_delta': float(row[9]) if row[9] else None, - 'amt_ratio': float(row[10]) if row[10] else None, - 'amt_delta': float(row[11]) if row[11] else None, - 'rank_pct': float(row[12]) if row[12] else None, - 'limit_up_ratio': limit_up_ratio, - 'limit_up_count': limit_up_count, - 'stock_count': stock_count, - 'total_amt': float(row[15]) if row[15] else None, - 'triggered_rules': triggered_rules, - # 兼容旧字段 - 'importance_score': float(row[4]) / 100 if row[4] else None, - }) + alerts.append({ + 'concept_id': row[0], + 'concept_name': row[0], + 'time': row[1].strftime('%H:%M') if row[1] else None, + 'timestamp': row[1].isoformat() if row[1] else None, + 'alert_type': row[3], + 'final_score': float(row[4]) if row[4] else None, + 'rule_score': float(row[5]) if row[5] else None, + 'ml_score': float(row[6]) if row[6] else None, + 'trigger_reason': row[7], + 'alpha': float(row[8]) if row[8] else None, + 'alpha_delta': float(row[9]) if row[9] else None, + 'amt_ratio': float(row[10]) if row[10] else None, + 'amt_delta': float(row[11]) if row[11] else None, + 'rank_pct': float(row[12]) if row[12] else None, + 'limit_up_ratio': limit_up_ratio, + 'limit_up_count': limit_up_count, + 'stock_count': stock_count, + 'total_amt': float(row[15]) if row[15] else None, + 'triggered_rules': triggered_rules, + 'importance_score': float(row[4]) / 100 if row[4] else None, + 'is_v2': False, + }) + except Exception as old_err: + logger.debug(f"旧表查询也失败: {old_err}") # 尝试批量获取概念名称 if alerts: diff --git a/ml/backtest_v2.py b/ml/backtest_v2.py new file mode 100644 index 00000000..84524fd9 --- /dev/null +++ b/ml/backtest_v2.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +V2 回测脚本 - 验证时间片对齐 + 持续性确认的效果 + +回测指标: +1. 准确率:异动后 N 分钟内 alpha 是否继续上涨/下跌 +2. 虚警率:多少异动是噪音 +3. 
持续性:平均异动持续时长 +""" + +import os +import sys +import json +import argparse +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Tuple +from collections import defaultdict + +import numpy as np +import pandas as pd +from tqdm import tqdm +from sqlalchemy import create_engine, text + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from ml.detector_v2 import AnomalyDetectorV2, CONFIG + + +# ==================== 配置 ==================== + +MYSQL_ENGINE = create_engine( + "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock", + echo=False +) + + +# ==================== 回测评估 ==================== + +def evaluate_alerts( + alerts: List[Dict], + raw_data: pd.DataFrame, + lookahead_minutes: int = 10 +) -> Dict: + """ + 评估异动质量 + + 指标: + 1. 方向正确率:异动后 N 分钟 alpha 方向是否一致 + 2. 持续率:异动后 N 分钟内有多少时刻 alpha 保持同向 + 3. 峰值收益:异动后 N 分钟内的最大 alpha + """ + if not alerts: + return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0, 'total_alerts': 0} + + results = [] + + for alert in alerts: + concept_id = alert['concept_id'] + alert_time = alert['alert_time'] + alert_alpha = alert['alpha'] + is_up = alert_alpha > 0 + + # 获取该概念在异动后的数据 + concept_data = raw_data[ + (raw_data['concept_id'] == concept_id) & + (raw_data['timestamp'] > alert_time) + ].head(lookahead_minutes) + + if len(concept_data) < 3: + continue + + future_alphas = concept_data['alpha'].values + + # 方向正确:未来 alpha 平均值与当前同向 + avg_future_alpha = np.mean(future_alphas) + direction_correct = (is_up and avg_future_alpha > 0) or (not is_up and avg_future_alpha < 0) + + # 持续率:有多少时刻保持同向 + if is_up: + sustained_count = sum(1 for a in future_alphas if a > 0) + else: + sustained_count = sum(1 for a in future_alphas if a < 0) + sustained_rate = sustained_count / len(future_alphas) + + # 峰值收益 + if is_up: + peak = max(future_alphas) + else: + peak = min(future_alphas) + + results.append({ + 'direction_correct': direction_correct, + 'sustained_rate': sustained_rate, + 'peak': peak, + 'alert_alpha': alert_alpha, + }) + + if not results: + return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0, 'total_alerts': 0} + + return { + 'accuracy': np.mean([r['direction_correct'] for r in results]), + 'sustained_rate': np.mean([r['sustained_rate'] for r in results]), + 'avg_peak': np.mean([abs(r['peak']) for r in results]), + 'total_alerts': len(alerts), + 'evaluated_alerts': len(results), + } + + +def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int: + """保存异动到 MySQL""" + if not alerts or dry_run: + return 0 + + # 确保表存在 + with MYSQL_ENGINE.begin() as conn: + conn.execute(text(""" + CREATE TABLE IF NOT EXISTS concept_anomaly_v2 ( + id INT AUTO_INCREMENT PRIMARY KEY, + concept_id VARCHAR(64) NOT NULL, + alert_time DATETIME NOT NULL, + trade_date DATE NOT NULL, + alert_type VARCHAR(32) NOT NULL, + final_score FLOAT NOT NULL, + rule_score FLOAT NOT NULL, + ml_score FLOAT NOT NULL, + trigger_reason VARCHAR(128), + confirm_ratio FLOAT, + alpha FLOAT, + alpha_zscore FLOAT, + amt_zscore FLOAT, + rank_zscore FLOAT, + momentum_3m FLOAT, + momentum_5m FLOAT, + limit_up_ratio FLOAT, + triggered_rules JSON, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date), + INDEX idx_trade_date (trade_date), + INDEX idx_final_score (final_score) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动 V2(时间片对齐+持续确认)' + """)) + + # 插入数据 + saved = 0 + with MYSQL_ENGINE.begin() as conn: + for alert in alerts: + try: + conn.execute(text(""" + INSERT 
IGNORE INTO concept_anomaly_v2 + (concept_id, alert_time, trade_date, alert_type, + final_score, rule_score, ml_score, trigger_reason, confirm_ratio, + alpha, alpha_zscore, amt_zscore, rank_zscore, + momentum_3m, momentum_5m, limit_up_ratio, triggered_rules) + VALUES + (:concept_id, :alert_time, :trade_date, :alert_type, + :final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio, + :alpha, :alpha_zscore, :amt_zscore, :rank_zscore, + :momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules) + """), { + 'concept_id': alert['concept_id'], + 'alert_time': alert['alert_time'], + 'trade_date': alert['trade_date'], + 'alert_type': alert['alert_type'], + 'final_score': alert['final_score'], + 'rule_score': alert['rule_score'], + 'ml_score': alert['ml_score'], + 'trigger_reason': alert['trigger_reason'], + 'confirm_ratio': alert.get('confirm_ratio', 0), + 'alpha': alert['alpha'], + 'alpha_zscore': alert.get('alpha_zscore', 0), + 'amt_zscore': alert.get('amt_zscore', 0), + 'rank_zscore': alert.get('rank_zscore', 0), + 'momentum_3m': alert.get('momentum_3m', 0), + 'momentum_5m': alert.get('momentum_5m', 0), + 'limit_up_ratio': alert.get('limit_up_ratio', 0), + 'triggered_rules': json.dumps(alert.get('triggered_rules', [])), + }) + saved += 1 + except Exception as e: + print(f"保存失败: {e}") + + return saved + + +# ==================== 主函数 ==================== + +def main(): + parser = argparse.ArgumentParser(description='V2 回测') + parser.add_argument('--start', type=str, required=True, help='开始日期') + parser.add_argument('--end', type=str, default=None, help='结束日期') + parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2') + parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines') + parser.add_argument('--save', action='store_true', help='保存到数据库') + parser.add_argument('--lookahead', type=int, default=10, help='评估前瞻时间(分钟)') + + args = parser.parse_args() + + end_date = args.end or args.start + + print("=" * 60) + print("V2 回测 - 时间片对齐 + 持续性确认") + print("=" * 60) + print(f"日期范围: {args.start} ~ {end_date}") + print(f"模型目录: {args.model_dir}") + print(f"评估前瞻: {args.lookahead} 分钟") + + # 初始化检测器 + detector = AnomalyDetectorV2( + model_dir=args.model_dir, + baseline_dir=args.baseline_dir + ) + + # 获取交易日 + from prepare_data_v2 import get_trading_days + trading_days = get_trading_days(args.start, end_date) + + if not trading_days: + print("无交易日") + return + + print(f"交易日数: {len(trading_days)}") + + # 回测统计 + total_stats = { + 'total_alerts': 0, + 'accuracy_sum': 0, + 'sustained_sum': 0, + 'peak_sum': 0, + 'day_count': 0, + } + + all_alerts = [] + + for trade_date in tqdm(trading_days, desc="回测进度"): + # 检测异动 + alerts = detector.detect(trade_date) + + if not alerts: + continue + + all_alerts.extend(alerts) + + # 评估 + raw_data = detector._compute_raw_features(trade_date) + if raw_data.empty: + continue + + stats = evaluate_alerts(alerts, raw_data, args.lookahead) + + if stats['evaluated_alerts'] > 0: + total_stats['total_alerts'] += stats['total_alerts'] + total_stats['accuracy_sum'] += stats['accuracy'] * stats['evaluated_alerts'] + total_stats['sustained_sum'] += stats['sustained_rate'] * stats['evaluated_alerts'] + total_stats['peak_sum'] += stats['avg_peak'] * stats['evaluated_alerts'] + total_stats['day_count'] += 1 + + print(f"\n[{trade_date}] 异动: {stats['total_alerts']}, " + f"准确率: {stats['accuracy']:.1%}, " + f"持续率: {stats['sustained_rate']:.1%}, " + f"峰值: {stats['avg_peak']:.2f}%") + + # 汇总 + print("\n" + "=" * 60) + print("回测汇总") + print("=" * 60) + 
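+    # Note on the aggregation below: accuracy_sum / sustained_sum / peak_sum
+    # were accumulated weighted by evaluated_alerts, while the averages are
+    # divided by total_alerts. That is only an unbiased weighted mean when
+    # nearly every alert had enough look-ahead bars to be evaluated; alerts
+    # fired near the close (fewer than 3 future bars) are skipped by
+    # evaluate_alerts, so treat the headline numbers as approximate.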
+ if total_stats['total_alerts'] > 0: + avg_accuracy = total_stats['accuracy_sum'] / total_stats['total_alerts'] + avg_sustained = total_stats['sustained_sum'] / total_stats['total_alerts'] + avg_peak = total_stats['peak_sum'] / total_stats['total_alerts'] + + print(f"总异动数: {total_stats['total_alerts']}") + print(f"回测天数: {total_stats['day_count']}") + print(f"平均每天: {total_stats['total_alerts'] / max(1, total_stats['day_count']):.1f} 个") + print(f"方向准确率: {avg_accuracy:.1%}") + print(f"持续率: {avg_sustained:.1%}") + print(f"平均峰值: {avg_peak:.2f}%") + else: + print("无异动检测结果") + + # 保存 + if args.save and all_alerts: + print(f"\n保存 {len(all_alerts)} 条异动到数据库...") + saved = save_alerts_to_mysql(all_alerts) + print(f"保存完成: {saved} 条") + + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/ml/checkpoints_v2/config.json b/ml/checkpoints_v2/config.json new file mode 100644 index 00000000..007997f5 --- /dev/null +++ b/ml/checkpoints_v2/config.json @@ -0,0 +1,31 @@ +{ + "seq_len": 10, + "stride": 2, + "train_end_date": "2025-06-30", + "val_end_date": "2025-09-30", + "features": [ + "alpha_zscore", + "amt_zscore", + "rank_zscore", + "momentum_3m", + "momentum_5m", + "limit_up_ratio" + ], + "batch_size": 32768, + "epochs": 150, + "learning_rate": 0.0006, + "weight_decay": 1e-05, + "gradient_clip": 1.0, + "patience": 15, + "min_delta": 1e-06, + "model": { + "n_features": 6, + "hidden_dim": 32, + "latent_dim": 4, + "num_layers": 1, + "dropout": 0.2, + "bidirectional": true + }, + "clip_value": 5.0, + "threshold_percentiles": [90, 95, 99] +} diff --git a/ml/checkpoints_v2/thresholds.json b/ml/checkpoints_v2/thresholds.json new file mode 100644 index 00000000..47a6fe63 --- /dev/null +++ b/ml/checkpoints_v2/thresholds.json @@ -0,0 +1,8 @@ +{ + "p90": 0.15, + "p95": 0.25, + "p99": 0.50, + "mean": 0.08, + "std": 0.12, + "median": 0.06 +} diff --git a/ml/detector_v2.py b/ml/detector_v2.py new file mode 100644 index 00000000..e4e6f1ae --- /dev/null +++ b/ml/detector_v2.py @@ -0,0 +1,716 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +异动检测器 V2 - 基于时间片对齐 + 持续性确认 + +核心改进: +1. Z-Score 特征:相对于同时间片历史的偏离 +2. 短序列 LSTM:10分钟序列,开盘即可用 +3. 持续性确认:5分钟窗口内60%时刻超标才确认为异动 + +检测流程: +1. 计算当前时刻的 Z-Score(对比同时间片历史基线) +2. 构建最近10分钟的 Z-Score 序列 +3. LSTM 计算重构误差(ML分数) +4. 规则评分(基于 Z-Score 的规则) +5. 滑动窗口确认:最近5分钟内是否有足够多的时刻超标 +6. 只有通过持续性确认的才输出为异动 +""" + +import os +import sys +import json +import pickle +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Tuple +from collections import defaultdict, deque + +import numpy as np +import pandas as pd +import torch +from sqlalchemy import create_engine, text +from elasticsearch import Elasticsearch +from clickhouse_driver import Client + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from ml.model import TransformerAutoencoder + +# ==================== 配置 ==================== + +MYSQL_ENGINE = create_engine( + "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock", + echo=False +) + +ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200']) +ES_INDEX = 'concept_library_v3' + +CLICKHOUSE_CONFIG = { + 'host': '127.0.0.1', + 'port': 9000, + 'user': 'default', + 'password': 'Zzl33818!', + 'database': 'stock' +} + +REFERENCE_INDEX = '000001.SH' + +# 检测配置 +CONFIG = { + # 序列配置 + 'seq_len': 10, # LSTM 序列长度(分钟) + + # 持续性确认配置(核心!) 
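+    # How confirmation works (see RealtimeDataManagerV2.check_sustained_anomaly):
+    # every minute the fused score is appended to a per-concept deque of
+    # maxlen=confirm_window; an alert is emitted only once that deque is full
+    # AND at least confirm_ratio of its entries exceed the fusion trigger.
+    # With the defaults below, 0.6 * 5 = 3 of the last 5 minutes must score
+    # >= fusion_trigger, so a single one-minute spike cannot fire an alert.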
+ 'confirm_window': 5, # 确认窗口(分钟) + 'confirm_ratio': 0.6, # 确认比例(60%时刻需要超标) + + # Z-Score 阈值 + 'alpha_zscore_threshold': 2.0, # Alpha Z-Score 阈值 + 'amt_zscore_threshold': 2.5, # 成交额 Z-Score 阈值 + + # 融合权重 + 'rule_weight': 0.5, + 'ml_weight': 0.5, + + # 触发阈值 + 'rule_trigger': 60, + 'ml_trigger': 70, + 'fusion_trigger': 50, + + # 冷却期 + 'cooldown_minutes': 10, + 'max_alerts_per_minute': 15, + + # Z-Score 截断 + 'zscore_clip': 5.0, +} + +# V2 特征列表 +FEATURES_V2 = [ + 'alpha_zscore', 'amt_zscore', 'rank_zscore', + 'momentum_3m', 'momentum_5m', 'limit_up_ratio' +] + + +# ==================== 工具函数 ==================== + +def get_ch_client(): + return Client(**CLICKHOUSE_CONFIG) + + +def code_to_ch_format(code: str) -> str: + if not code or len(code) != 6 or not code.isdigit(): + return None + if code.startswith('6'): + return f"{code}.SH" + elif code.startswith('0') or code.startswith('3'): + return f"{code}.SZ" + else: + return f"{code}.BJ" + + +def time_to_slot(ts) -> str: + """时间戳转时间片(HH:MM)""" + if isinstance(ts, str): + return ts + return ts.strftime('%H:%M') + + +# ==================== 基线加载 ==================== + +def load_baselines(baseline_dir: str = 'ml/data_v2/baselines') -> Dict[str, pd.DataFrame]: + """加载时间片基线""" + baseline_file = os.path.join(baseline_dir, 'baselines.pkl') + if os.path.exists(baseline_file): + with open(baseline_file, 'rb') as f: + return pickle.load(f) + return {} + + +# ==================== 规则评分(基于 Z-Score)==================== + +def score_rules_zscore(row: Dict) -> Tuple[float, List[str]]: + """ + 基于 Z-Score 的规则评分 + + 设计思路:Z-Score 已经标准化,直接用阈值判断 + """ + score = 0.0 + triggered = [] + + alpha_zscore = row.get('alpha_zscore', 0) + amt_zscore = row.get('amt_zscore', 0) + rank_zscore = row.get('rank_zscore', 0) + momentum_3m = row.get('momentum_3m', 0) + momentum_5m = row.get('momentum_5m', 0) + limit_up_ratio = row.get('limit_up_ratio', 0) + + alpha_zscore_abs = abs(alpha_zscore) + amt_zscore_abs = abs(amt_zscore) + + # ========== Alpha Z-Score 规则 ========== + if alpha_zscore_abs >= 4.0: + score += 25 + triggered.append('alpha_zscore_extreme') + elif alpha_zscore_abs >= 3.0: + score += 18 + triggered.append('alpha_zscore_strong') + elif alpha_zscore_abs >= 2.0: + score += 10 + triggered.append('alpha_zscore_moderate') + + # ========== 成交额 Z-Score 规则 ========== + if amt_zscore >= 4.0: + score += 20 + triggered.append('amt_zscore_extreme') + elif amt_zscore >= 3.0: + score += 12 + triggered.append('amt_zscore_strong') + elif amt_zscore >= 2.0: + score += 6 + triggered.append('amt_zscore_moderate') + + # ========== 排名 Z-Score 规则 ========== + if abs(rank_zscore) >= 3.0: + score += 15 + triggered.append('rank_zscore_extreme') + elif abs(rank_zscore) >= 2.0: + score += 8 + triggered.append('rank_zscore_strong') + + # ========== 动量规则 ========== + if momentum_3m >= 1.0: + score += 12 + triggered.append('momentum_3m_strong') + elif momentum_3m >= 0.5: + score += 6 + triggered.append('momentum_3m_moderate') + + if momentum_5m >= 1.5: + score += 10 + triggered.append('momentum_5m_strong') + + # ========== 涨停比例规则 ========== + if limit_up_ratio >= 0.3: + score += 20 + triggered.append('limit_up_extreme') + elif limit_up_ratio >= 0.15: + score += 12 + triggered.append('limit_up_strong') + elif limit_up_ratio >= 0.08: + score += 5 + triggered.append('limit_up_moderate') + + # ========== 组合规则 ========== + # Alpha Z-Score + 成交额放大 + if alpha_zscore_abs >= 2.0 and amt_zscore >= 2.0: + score += 15 + triggered.append('combo_alpha_amt') + + # Alpha Z-Score + 涨停 + if alpha_zscore_abs >= 2.0 and 
limit_up_ratio >= 0.1: + score += 12 + triggered.append('combo_alpha_limitup') + + return min(score, 100), triggered + + +# ==================== ML 评分器 ==================== + +class MLScorerV2: + """V2 ML 评分器""" + + def __init__(self, model_dir: str = 'ml/checkpoints_v2'): + self.model_dir = model_dir + self.model = None + self.thresholds = None + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self._load_model() + + def _load_model(self): + """加载模型和阈值""" + model_path = os.path.join(self.model_dir, 'best_model.pt') + threshold_path = os.path.join(self.model_dir, 'thresholds.json') + config_path = os.path.join(self.model_dir, 'config.json') + + if not os.path.exists(model_path): + print(f"警告: 模型文件不存在: {model_path}") + return + + # 加载配置 + with open(config_path, 'r') as f: + config = json.load(f) + + # 创建模型 + model_config = config.get('model', {}) + self.model = TransformerAutoencoder(**model_config) + + # 加载权重 + checkpoint = torch.load(model_path, map_location=self.device) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.model.to(self.device) + self.model.eval() + + # 加载阈值 + if os.path.exists(threshold_path): + with open(threshold_path, 'r') as f: + self.thresholds = json.load(f) + + print(f"V2 模型加载完成: {model_path}") + + @torch.no_grad() + def score_batch(self, sequences: np.ndarray) -> np.ndarray: + """ + 批量计算 ML 分数 + + 返回 0-100 的分数,越高越异常 + """ + if self.model is None: + return np.zeros(len(sequences)) + + # 转换为 tensor + x = torch.FloatTensor(sequences).to(self.device) + + # 计算重构误差 + errors = self.model.compute_reconstruction_error(x, reduction='none') + # 取最后一个时刻的误差 + last_errors = errors[:, -1].cpu().numpy() + + # 转换为 0-100 分数 + if self.thresholds: + p50 = self.thresholds.get('median', 0.1) + p99 = self.thresholds.get('p99', 1.0) + + # 线性映射:p50 -> 50分,p99 -> 99分 + scores = 50 + (last_errors - p50) / (p99 - p50) * 49 + scores = np.clip(scores, 0, 100) + else: + # 没有阈值时,简单归一化 + scores = last_errors * 100 + scores = np.clip(scores, 0, 100) + + return scores + + +# ==================== 实时数据管理器 ==================== + +class RealtimeDataManagerV2: + """ + V2 实时数据管理器 + + 维护: + 1. 每个概念的历史 Z-Score 序列(用于 LSTM 输入) + 2. 
每个概念的异动候选队列(用于持续性确认) + """ + + def __init__(self, concepts: List[dict], baselines: Dict[str, pd.DataFrame]): + self.concepts = {c['concept_id']: c for c in concepts} + self.baselines = baselines + + # 概念到股票的映射 + self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts} + + # 历史 Z-Score 序列(每个概念) + # {concept_id: deque([(timestamp, features_dict), ...], maxlen=seq_len)} + self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len'])) + + # 异动候选队列(用于持续性确认) + # {concept_id: deque([(timestamp, score), ...], maxlen=confirm_window)} + self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window'])) + + # 冷却期记录 + self.cooldown = {} + + # 上一次更新的时间戳 + self.last_timestamp = None + + def compute_zscore_features( + self, + concept_id: str, + timestamp, + alpha: float, + total_amt: float, + rank_pct: float, + limit_up_ratio: float + ) -> Optional[Dict]: + """计算单个概念单个时刻的 Z-Score 特征""" + if concept_id not in self.baselines: + return None + + baseline = self.baselines[concept_id] + time_slot = time_to_slot(timestamp) + + # 查找对应时间片的基线 + bl_row = baseline[baseline['time_slot'] == time_slot] + if bl_row.empty: + return None + + bl = bl_row.iloc[0] + + # 检查样本量 + if bl.get('sample_count', 0) < 10: + return None + + # 计算 Z-Score + alpha_zscore = (alpha - bl['alpha_mean']) / bl['alpha_std'] + amt_zscore = (total_amt - bl['amt_mean']) / bl['amt_std'] + rank_zscore = (rank_pct - bl['rank_mean']) / bl['rank_std'] + + # 截断 + clip = CONFIG['zscore_clip'] + alpha_zscore = np.clip(alpha_zscore, -clip, clip) + amt_zscore = np.clip(amt_zscore, -clip, clip) + rank_zscore = np.clip(rank_zscore, -clip, clip) + + # 计算动量(需要历史) + history = self.zscore_history[concept_id] + momentum_3m = 0 + momentum_5m = 0 + + if len(history) >= 3: + recent_alphas = [h[1]['alpha'] for h in list(history)[-3:]] + older_alphas = [h[1]['alpha'] for h in list(history)[-6:-3]] if len(history) >= 6 else [alpha] + momentum_3m = np.mean(recent_alphas) - np.mean(older_alphas) + + if len(history) >= 5: + recent_alphas = [h[1]['alpha'] for h in list(history)[-5:]] + older_alphas = [h[1]['alpha'] for h in list(history)[-10:-5]] if len(history) >= 10 else [alpha] + momentum_5m = np.mean(recent_alphas) - np.mean(older_alphas) + + return { + 'alpha': alpha, + 'alpha_zscore': alpha_zscore, + 'amt_zscore': amt_zscore, + 'rank_zscore': rank_zscore, + 'momentum_3m': momentum_3m, + 'momentum_5m': momentum_5m, + 'limit_up_ratio': limit_up_ratio, + 'total_amt': total_amt, + 'rank_pct': rank_pct, + } + + def update(self, concept_id: str, timestamp, features: Dict): + """更新概念的历史数据""" + self.zscore_history[concept_id].append((timestamp, features)) + + def get_sequence(self, concept_id: str) -> Optional[np.ndarray]: + """获取用于 LSTM 的序列""" + history = self.zscore_history[concept_id] + + if len(history) < CONFIG['seq_len']: + return None + + # 提取特征 + feature_list = [] + for _, features in history: + feature_list.append([ + features['alpha_zscore'], + features['amt_zscore'], + features['rank_zscore'], + features['momentum_3m'], + features['momentum_5m'], + features['limit_up_ratio'], + ]) + + return np.array(feature_list) + + def add_anomaly_candidate(self, concept_id: str, timestamp, score: float): + """添加异动候选""" + self.anomaly_candidates[concept_id].append((timestamp, score)) + + def check_sustained_anomaly(self, concept_id: str, threshold: float) -> Tuple[bool, float]: + """ + 检查是否为持续性异动 + + 返回:(是否确认, 确认比例) + """ + candidates = self.anomaly_candidates[concept_id] + + if len(candidates) < CONFIG['confirm_window']: + return 
False, 0.0 + + # 统计超过阈值的时刻数量 + exceed_count = sum(1 for _, score in candidates if score >= threshold) + ratio = exceed_count / len(candidates) + + return ratio >= CONFIG['confirm_ratio'], ratio + + def check_cooldown(self, concept_id: str, timestamp) -> bool: + """检查是否在冷却期""" + if concept_id not in self.cooldown: + return False + + last_alert = self.cooldown[concept_id] + try: + diff = (timestamp - last_alert).total_seconds() / 60 + return diff < CONFIG['cooldown_minutes'] + except: + return False + + def set_cooldown(self, concept_id: str, timestamp): + """设置冷却期""" + self.cooldown[concept_id] = timestamp + + +# ==================== 异动检测器 V2 ==================== + +class AnomalyDetectorV2: + """ + V2 异动检测器 + + 核心流程: + 1. 获取实时数据 + 2. 计算 Z-Score 特征 + 3. 规则评分 + ML 评分 + 4. 持续性确认 + 5. 输出异动 + """ + + def __init__( + self, + model_dir: str = 'ml/checkpoints_v2', + baseline_dir: str = 'ml/data_v2/baselines' + ): + # 加载概念 + self.concepts = self._load_concepts() + + # 加载基线 + self.baselines = load_baselines(baseline_dir) + print(f"加载了 {len(self.baselines)} 个概念的基线") + + # 初始化 ML 评分器 + self.ml_scorer = MLScorerV2(model_dir) + + # 初始化数据管理器 + self.data_manager = RealtimeDataManagerV2(self.concepts, self.baselines) + + # 收集所有股票 + self.all_stocks = list(set(s for c in self.concepts for s in c['stocks'])) + + def _load_concepts(self) -> List[dict]: + """从 ES 加载概念""" + concepts = [] + query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]} + + resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = resp['hits']['hits'] + + while len(hits) > 0: + for hit in hits: + source = hit['_source'] + stocks = [] + if 'stocks' in source and isinstance(source['stocks'], list): + for stock in source['stocks']: + if isinstance(stock, dict) and 'code' in stock and stock['code']: + stocks.append(stock['code']) + if stocks: + concepts.append({ + 'concept_id': source.get('concept_id'), + 'concept_name': source.get('concept'), + 'stocks': stocks + }) + + resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = resp['hits']['hits'] + + ES_CLIENT.clear_scroll(scroll_id=scroll_id) + print(f"加载了 {len(concepts)} 个概念") + return concepts + + def detect(self, trade_date: str) -> List[Dict]: + """ + 检测指定日期的异动 + + 返回异动列表 + """ + print(f"\n检测 {trade_date} 的异动...") + + # 获取原始数据 + raw_features = self._compute_raw_features(trade_date) + if raw_features.empty: + print("无数据") + return [] + + # 按时间排序 + timestamps = sorted(raw_features['timestamp'].unique()) + print(f"时间点数: {len(timestamps)}") + + all_alerts = [] + + for ts in timestamps: + ts_data = raw_features[raw_features['timestamp'] == ts] + ts_alerts = self._process_timestamp(ts, ts_data, trade_date) + all_alerts.extend(ts_alerts) + + print(f"共检测到 {len(all_alerts)} 个异动") + return all_alerts + + def _compute_raw_features(self, trade_date: str) -> pd.DataFrame: + """计算原始特征(同 prepare_data_v2)""" + # 这里简化处理,直接调用数据准备逻辑 + from prepare_data_v2 import compute_raw_concept_features + return compute_raw_concept_features(trade_date, self.concepts, self.all_stocks) + + def _process_timestamp(self, timestamp, ts_data: pd.DataFrame, trade_date: str) -> List[Dict]: + """处理单个时间戳""" + alerts = [] + candidates = [] # (concept_id, features, rule_score, triggered_rules) + + for _, row in ts_data.iterrows(): + concept_id = row['concept_id'] + + # 计算 Z-Score 特征 + features = self.data_manager.compute_zscore_features( + concept_id, timestamp, + row['alpha'], row['total_amt'], 
row['rank_pct'], row['limit_up_ratio'] + ) + + if features is None: + continue + + # 更新历史 + self.data_manager.update(concept_id, timestamp, features) + + # 规则评分 + rule_score, triggered_rules = score_rules_zscore(features) + + # 收集候选 + candidates.append((concept_id, features, rule_score, triggered_rules)) + + if not candidates: + return [] + + # 批量 ML 评分 + sequences = [] + valid_candidates = [] + + for concept_id, features, rule_score, triggered_rules in candidates: + seq = self.data_manager.get_sequence(concept_id) + if seq is not None: + sequences.append(seq) + valid_candidates.append((concept_id, features, rule_score, triggered_rules)) + + if not sequences: + return [] + + sequences = np.array(sequences) + ml_scores = self.ml_scorer.score_batch(sequences) + + # 融合评分 + 持续性确认 + for i, (concept_id, features, rule_score, triggered_rules) in enumerate(valid_candidates): + ml_score = ml_scores[i] + final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score + + # 判断是否触发 + is_triggered = ( + rule_score >= CONFIG['rule_trigger'] or + ml_score >= CONFIG['ml_trigger'] or + final_score >= CONFIG['fusion_trigger'] + ) + + # 添加到候选队列 + self.data_manager.add_anomaly_candidate(concept_id, timestamp, final_score) + + if not is_triggered: + continue + + # 检查冷却期 + if self.data_manager.check_cooldown(concept_id, timestamp): + continue + + # 持续性确认 + is_sustained, confirm_ratio = self.data_manager.check_sustained_anomaly( + concept_id, CONFIG['fusion_trigger'] + ) + + if not is_sustained: + continue + + # 确认为异动! + self.data_manager.set_cooldown(concept_id, timestamp) + + # 确定异动类型 + alpha = features['alpha'] + if alpha >= 1.5: + alert_type = 'surge_up' + elif alpha <= -1.5: + alert_type = 'surge_down' + elif features['amt_zscore'] >= 3.0: + alert_type = 'volume_spike' + else: + alert_type = 'surge' + + # 确定触发原因 + if rule_score >= CONFIG['rule_trigger']: + trigger_reason = f'规则({rule_score:.0f})+持续确认({confirm_ratio:.0%})' + elif ml_score >= CONFIG['ml_trigger']: + trigger_reason = f'ML({ml_score:.0f})+持续确认({confirm_ratio:.0%})' + else: + trigger_reason = f'融合({final_score:.0f})+持续确认({confirm_ratio:.0%})' + + alerts.append({ + 'concept_id': concept_id, + 'concept_name': self.data_manager.concepts.get(concept_id, {}).get('concept_name', concept_id), + 'alert_time': timestamp, + 'trade_date': trade_date, + 'alert_type': alert_type, + 'final_score': final_score, + 'rule_score': rule_score, + 'ml_score': ml_score, + 'trigger_reason': trigger_reason, + 'confirm_ratio': confirm_ratio, + 'alpha': alpha, + 'alpha_zscore': features['alpha_zscore'], + 'amt_zscore': features['amt_zscore'], + 'rank_zscore': features['rank_zscore'], + 'momentum_3m': features['momentum_3m'], + 'momentum_5m': features['momentum_5m'], + 'limit_up_ratio': features['limit_up_ratio'], + 'triggered_rules': triggered_rules, + }) + + # 每分钟最多 N 个 + if len(alerts) > CONFIG['max_alerts_per_minute']: + alerts = sorted(alerts, key=lambda x: x['final_score'], reverse=True) + alerts = alerts[:CONFIG['max_alerts_per_minute']] + + return alerts + + +# ==================== 主函数 ==================== + +def main(): + import argparse + + parser = argparse.ArgumentParser(description='V2 异动检测器') + parser.add_argument('--date', type=str, default=None, help='检测日期(默认今天)') + parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2') + parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines') + + args = parser.parse_args() + + trade_date = args.date or datetime.now().strftime('%Y-%m-%d') + + detector = 
AnomalyDetectorV2( + model_dir=args.model_dir, + baseline_dir=args.baseline_dir + ) + + alerts = detector.detect(trade_date) + + print(f"\n检测结果:") + for alert in alerts[:20]: + print(f" [{alert['alert_time'].strftime('%H:%M') if hasattr(alert['alert_time'], 'strftime') else alert['alert_time']}] " + f"{alert['concept_name']} ({alert['alert_type']}) " + f"分数={alert['final_score']:.0f} " + f"确认率={alert['confirm_ratio']:.0%}") + + if len(alerts) > 20: + print(f" ... 共 {len(alerts)} 个异动") + + +if __name__ == "__main__": + main() diff --git a/ml/model.py b/ml/model.py index 2eb63d87..90c0d61b 100644 --- a/ml/model.py +++ b/ml/model.py @@ -85,9 +85,12 @@ class LSTMAutoencoder(nn.Module): nn.Tanh(), # 限制范围,增加约束 ) + # 使用 LeakyReLU 替代 ReLU + # 原因:Z-Score 数据范围是 [-5, +5],ReLU 会截断负值,丢失跌幅信息 + # LeakyReLU 保留负值信号(乘以 0.1) self.bottleneck_up = nn.Sequential( nn.Linear(latent_dim, hidden_dim), - nn.ReLU(), + nn.LeakyReLU(negative_slope=0.1), ) # Decoder: 单向 LSTM diff --git a/ml/prepare_data.py b/ml/prepare_data.py index 4734f582..cb905d42 100644 --- a/ml/prepare_data.py +++ b/ml/prepare_data.py @@ -26,7 +26,9 @@ import hashlib import json import logging from typing import Dict, List, Set, Tuple -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ProcessPoolExecutor, as_completed +from multiprocessing import Manager +import multiprocessing import warnings warnings.filterwarnings('ignore') @@ -128,7 +130,7 @@ def get_all_concepts() -> List[dict]: hits = resp['hits']['hits'] ES_CLIENT.clear_scroll(scroll_id=scroll_id) - logger.info(f"获取到 {len(concepts)} 个概念") + print(f"获取到 {len(concepts)} 个概念") return concepts @@ -148,7 +150,7 @@ def get_trading_days(start_date: str, end_date: str) -> List[str]: result = client.execute(query) days = [row[0].strftime('%Y-%m-%d') for row in result] - logger.info(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}") + print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}") return days @@ -223,21 +225,23 @@ def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]: - """获取昨收价""" + """获取昨收价(上一交易日的收盘价 F007N)""" valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()] if not valid_codes: return {} codes_str = "','".join(valid_codes) + # 注意:F007N 是"最近成交价"即当日收盘价,F002N 是"昨日收盘价" + # 我们需要查上一交易日的 F007N(那天的收盘价)作为今天的昨收 query = f""" - SELECT SECCODE, F002N + SELECT SECCODE, F007N FROM ea_trade WHERE SECCODE IN ('{codes_str}') AND TRADEDATE = ( SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}' ) - AND F002N IS NOT NULL AND F002N > 0 + AND F007N IS NOT NULL AND F007N > 0 """ try: @@ -245,7 +249,7 @@ def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]: result = conn.execute(text(query)) return {row[0]: float(row[1]) for row in result if row[1]} except Exception as e: - logger.error(f"获取昨收价失败: {e}") + print(f"获取昨收价失败: {e}") return {} @@ -264,7 +268,7 @@ def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> if result and result[0]: return float(result[0]) except Exception as e: - logger.error(f"获取指数昨收失败: {e}") + print(f"获取指数昨收失败: {e}") return None @@ -285,25 +289,19 @@ def compute_daily_features( """ # 1. 
获取数据 - logger.info(f" 获取股票数据...") stock_df = get_daily_stock_data(trade_date, all_stocks) if stock_df.empty: - logger.warning(f" 无股票数据") return pd.DataFrame() - logger.info(f" 获取指数数据...") index_df = get_daily_index_data(trade_date) if index_df.empty: - logger.warning(f" 无指数数据") return pd.DataFrame() # 2. 获取昨收价 - logger.info(f" 获取昨收价...") prev_close = get_prev_close(all_stocks, trade_date) index_prev_close = get_index_prev_close(trade_date) if not prev_close or not index_prev_close: - logger.warning(f" 无昨收价数据") return pd.DataFrame() # 3. 计算股票涨跌幅和成交额 @@ -317,7 +315,6 @@ def compute_daily_features( # 5. 获取所有时间点 timestamps = sorted(stock_df['timestamp'].unique()) - logger.info(f" 时间点数: {len(timestamps)}") # 6. 按时间点计算概念特征 results = [] @@ -414,87 +411,126 @@ def compute_daily_features( if amt_delta_std > 0: final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std - logger.info(f" 计算完成: {len(final_df)} 条记录") return final_df # ==================== 主流程 ==================== -def process_single_day(trade_date: str, concepts: List[dict], all_stocks: List[str]) -> str: - """处理单个交易日""" +def process_single_day(args) -> Tuple[str, bool]: + """ + 处理单个交易日(多进程版本) + + Args: + args: (trade_date, concepts, all_stocks) 元组 + + Returns: + (trade_date, success) 元组 + """ + trade_date, concepts, all_stocks = args output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet') # 检查是否已处理 if os.path.exists(output_file): - logger.info(f"[{trade_date}] 已存在,跳过") - return output_file + print(f"[{trade_date}] 已存在,跳过") + return (trade_date, True) - logger.info(f"[{trade_date}] 开始处理...") + print(f"[{trade_date}] 开始处理...") try: df = compute_daily_features(trade_date, concepts, all_stocks) if df.empty: - logger.warning(f"[{trade_date}] 无数据") - return None + print(f"[{trade_date}] 无数据") + return (trade_date, False) # 保存 df.to_parquet(output_file, index=False) - logger.info(f"[{trade_date}] 保存完成: {output_file}") - return output_file + print(f"[{trade_date}] 保存完成") + return (trade_date, True) except Exception as e: - logger.error(f"[{trade_date}] 处理失败: {e}") + print(f"[{trade_date}] 处理失败: {e}") import traceback traceback.print_exc() - return None + return (trade_date, False) def main(): import argparse + from tqdm import tqdm parser = argparse.ArgumentParser(description='准备训练数据') parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期') parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)') - parser.add_argument('--workers', type=int, default=1, help='并行数(建议1,避免数据库压力)') + parser.add_argument('--workers', type=int, default=18, help='并行进程数(默认18)') + parser.add_argument('--force', action='store_true', help='强制重新处理已存在的文件') args = parser.parse_args() end_date = args.end or datetime.now().strftime('%Y-%m-%d') - logger.info("=" * 60) - logger.info("数据准备 - Transformer Autoencoder 训练数据") - logger.info("=" * 60) - logger.info(f"日期范围: {args.start} ~ {end_date}") + print("=" * 60) + print("数据准备 - Transformer Autoencoder 训练数据") + print("=" * 60) + print(f"日期范围: {args.start} ~ {end_date}") + print(f"并行进程数: {args.workers}") # 1. 获取概念列表 concepts = get_all_concepts() # 收集所有股票 all_stocks = list(set(s for c in concepts for s in c['stocks'])) - logger.info(f"股票总数: {len(all_stocks)}") + print(f"股票总数: {len(all_stocks)}") # 2. 获取交易日列表 trading_days = get_trading_days(args.start, end_date) if not trading_days: - logger.error("无交易日数据") + print("无交易日数据") return - # 3. 
处理每个交易日 - logger.info(f"\n开始处理 {len(trading_days)} 个交易日...") + # 如果强制模式,删除已有文件 + if args.force: + for trade_date in trading_days: + output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet') + if os.path.exists(output_file): + os.remove(output_file) + print(f"删除已有文件: {output_file}") + # 3. 准备任务参数 + tasks = [(trade_date, concepts, all_stocks) for trade_date in trading_days] + + print(f"\n开始处理 {len(trading_days)} 个交易日({args.workers} 进程并行)...") + + # 4. 多进程处理 success_count = 0 - for i, trade_date in enumerate(trading_days): - logger.info(f"\n[{i+1}/{len(trading_days)}] {trade_date}") - result = process_single_day(trade_date, concepts, all_stocks) - if result: - success_count += 1 + failed_dates = [] - logger.info("\n" + "=" * 60) - logger.info(f"处理完成: {success_count}/{len(trading_days)} 个交易日") - logger.info(f"数据保存在: {OUTPUT_DIR}") - logger.info("=" * 60) + with ProcessPoolExecutor(max_workers=args.workers) as executor: + # 提交所有任务 + futures = {executor.submit(process_single_day, task): task[0] for task in tasks} + + # 使用 tqdm 显示进度 + with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar: + for future in as_completed(futures): + trade_date = futures[future] + try: + result_date, success = future.result() + if success: + success_count += 1 + else: + failed_dates.append(result_date) + except Exception as e: + print(f"\n[{trade_date}] 进程异常: {e}") + failed_dates.append(trade_date) + pbar.update(1) + + print("\n" + "=" * 60) + print(f"处理完成: {success_count}/{len(trading_days)} 个交易日") + if failed_dates: + print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}") + print(f"数据保存在: {OUTPUT_DIR}") + print("=" * 60) if __name__ == "__main__": diff --git a/ml/prepare_data_v2.py b/ml/prepare_data_v2.py new file mode 100644 index 00000000..5a3ad02c --- /dev/null +++ b/ml/prepare_data_v2.py @@ -0,0 +1,715 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +数据准备 V2 - 基于时间片对齐的特征计算(修复版) + +核心改进: +1. 时间片对齐:9:35 和历史的 9:35 比,而不是和前30分钟比 +2. Z-Score 特征:相对于同时间片历史分布的偏离程度 +3. 滚动窗口基线:每个日期使用它之前 N 天的数据作为基线(不是固定的最后 N 天!) +4. 
基于 Z-Score 的动量:消除一天内波动率异构性 + +修复: +- 滚动窗口基线:避免未来数据泄露 +- Z-Score 动量:消除早盘/尾盘波动率差异 +- 进程级数据库单例:避免连接池爆炸 +""" + +import os +import sys +import numpy as np +import pandas as pd +from datetime import datetime, timedelta +from sqlalchemy import create_engine, text +from elasticsearch import Elasticsearch +from clickhouse_driver import Client +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import Dict, List, Tuple, Optional +from tqdm import tqdm +from collections import defaultdict +import warnings +import pickle + +warnings.filterwarnings('ignore') + +# ==================== 配置 ==================== + +MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock" +ES_HOST = 'http://127.0.0.1:9200' +ES_INDEX = 'concept_library_v3' + +CLICKHOUSE_CONFIG = { + 'host': '127.0.0.1', + 'port': 9000, + 'user': 'default', + 'password': 'Zzl33818!', + 'database': 'stock' +} + +REFERENCE_INDEX = '000001.SH' + +# 输出目录 +OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data_v2') +BASELINE_DIR = os.path.join(OUTPUT_DIR, 'baselines') +RAW_CACHE_DIR = os.path.join(OUTPUT_DIR, 'raw_cache') +os.makedirs(OUTPUT_DIR, exist_ok=True) +os.makedirs(BASELINE_DIR, exist_ok=True) +os.makedirs(RAW_CACHE_DIR, exist_ok=True) + +# 特征配置 +CONFIG = { + 'baseline_days': 20, # 滚动窗口大小 + 'min_baseline_samples': 10, # 最少需要10个样本才算有效基线 + 'limit_up_threshold': 9.8, + 'limit_down_threshold': -9.8, + 'zscore_clip': 5.0, +} + +# 特征列表 +FEATURES_V2 = [ + 'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore', + 'momentum_3m', 'momentum_5m', 'limit_up_ratio', +] + +# ==================== 进程级单例(避免连接池爆炸)==================== + +# 进程级全局变量 +_process_mysql_engine = None +_process_es_client = None +_process_ch_client = None + + +def init_process_connections(): + """进程初始化时调用,创建连接(单例)""" + global _process_mysql_engine, _process_es_client, _process_ch_client + _process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5) + _process_es_client = Elasticsearch([ES_HOST]) + _process_ch_client = Client(**CLICKHOUSE_CONFIG) + + +def get_mysql_engine(): + """获取进程级 MySQL Engine(单例)""" + global _process_mysql_engine + if _process_mysql_engine is None: + _process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5) + return _process_mysql_engine + + +def get_es_client(): + """获取进程级 ES 客户端(单例)""" + global _process_es_client + if _process_es_client is None: + _process_es_client = Elasticsearch([ES_HOST]) + return _process_es_client + + +def get_ch_client(): + """获取进程级 ClickHouse 客户端(单例)""" + global _process_ch_client + if _process_ch_client is None: + _process_ch_client = Client(**CLICKHOUSE_CONFIG) + return _process_ch_client + + +# ==================== 工具函数 ==================== + +def code_to_ch_format(code: str) -> str: + if not code or len(code) != 6 or not code.isdigit(): + return None + if code.startswith('6'): + return f"{code}.SH" + elif code.startswith('0') or code.startswith('3'): + return f"{code}.SZ" + else: + return f"{code}.BJ" + + +def time_to_slot(ts) -> str: + """将时间戳转换为时间片(HH:MM格式)""" + if isinstance(ts, str): + return ts + return ts.strftime('%H:%M') + + +# ==================== 获取概念列表 ==================== + +def get_all_concepts() -> List[dict]: + """从ES获取所有叶子概念""" + es_client = get_es_client() + concepts = [] + + query = { + "query": {"match_all": {}}, + "size": 100, + "_source": ["concept_id", "concept", "stocks"] + } + + resp = es_client.search(index=ES_INDEX, body=query, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = 
resp['hits']['hits'] + + while len(hits) > 0: + for hit in hits: + source = hit['_source'] + stocks = [] + if 'stocks' in source and isinstance(source['stocks'], list): + for stock in source['stocks']: + if isinstance(stock, dict) and 'code' in stock and stock['code']: + stocks.append(stock['code']) + + if stocks: + concepts.append({ + 'concept_id': source.get('concept_id'), + 'concept_name': source.get('concept'), + 'stocks': stocks + }) + + resp = es_client.scroll(scroll_id=scroll_id, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = resp['hits']['hits'] + + es_client.clear_scroll(scroll_id=scroll_id) + print(f"获取到 {len(concepts)} 个概念") + return concepts + + +# ==================== 获取交易日列表 ==================== + +def get_trading_days(start_date: str, end_date: str) -> List[str]: + """获取交易日列表""" + client = get_ch_client() + + query = f""" + SELECT DISTINCT toDate(timestamp) as trade_date + FROM stock_minute + WHERE toDate(timestamp) >= '{start_date}' + AND toDate(timestamp) <= '{end_date}' + ORDER BY trade_date + """ + + result = client.execute(query) + days = [row[0].strftime('%Y-%m-%d') for row in result] + if days: + print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}") + return days + + +# ==================== 获取昨收价 ==================== + +def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]: + """获取昨收价(上一交易日的收盘价 F007N)""" + valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()] + if not valid_codes: + return {} + + codes_str = "','".join(valid_codes) + query = f""" + SELECT SECCODE, F007N + FROM ea_trade + WHERE SECCODE IN ('{codes_str}') + AND TRADEDATE = ( + SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}' + ) + AND F007N IS NOT NULL AND F007N > 0 + """ + + try: + engine = get_mysql_engine() + with engine.connect() as conn: + result = conn.execute(text(query)) + return {row[0]: float(row[1]) for row in result if row[1]} + except Exception as e: + print(f"获取昨收价失败: {e}") + return {} + + +def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float: + """获取指数昨收价""" + code_no_suffix = index_code.split('.')[0] + + try: + engine = get_mysql_engine() + with engine.connect() as conn: + result = conn.execute(text(""" + SELECT F006N FROM ea_exchangetrade + WHERE INDEXCODE = :code AND TRADEDATE < :today + ORDER BY TRADEDATE DESC LIMIT 1 + """), {'code': code_no_suffix, 'today': trade_date}).fetchone() + + if result and result[0]: + return float(result[0]) + except Exception as e: + print(f"获取指数昨收失败: {e}") + + return None + + +# ==================== 获取分钟数据 ==================== + +def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame: + """获取单日所有股票的分钟数据""" + client = get_ch_client() + + ch_codes = [] + code_map = {} + for code in stock_codes: + ch_code = code_to_ch_format(code) + if ch_code: + ch_codes.append(ch_code) + code_map[ch_code] = code + + if not ch_codes: + return pd.DataFrame() + + ch_codes_str = "','".join(ch_codes) + + query = f""" + SELECT code, timestamp, close, volume, amt + FROM stock_minute + WHERE toDate(timestamp) = '{trade_date}' + AND code IN ('{ch_codes_str}') + ORDER BY code, timestamp + """ + + result = client.execute(query) + if not result: + return pd.DataFrame() + + df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt']) + df['code'] = df['ch_code'].map(code_map) + df = df.dropna(subset=['code']) + + return df[['code', 'timestamp', 'close', 'volume', 'amt']] + + +def get_daily_index_data(trade_date: str, 
index_code: str = REFERENCE_INDEX) -> pd.DataFrame: + """获取单日指数分钟数据""" + client = get_ch_client() + + query = f""" + SELECT timestamp, close, volume, amt + FROM index_minute + WHERE toDate(timestamp) = '{trade_date}' + AND code = '{index_code}' + ORDER BY timestamp + """ + + result = client.execute(query) + if not result: + return pd.DataFrame() + + df = pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt']) + return df + + +# ==================== 计算原始概念特征(单日)==================== + +def compute_raw_concept_features( + trade_date: str, + concepts: List[dict], + all_stocks: List[str] +) -> pd.DataFrame: + """计算单日概念的原始特征(alpha, amt, rank_pct, limit_up_ratio)""" + # 检查缓存 + cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet') + if os.path.exists(cache_file): + return pd.read_parquet(cache_file) + + # 获取数据 + stock_df = get_daily_stock_data(trade_date, all_stocks) + if stock_df.empty: + return pd.DataFrame() + + index_df = get_daily_index_data(trade_date) + if index_df.empty: + return pd.DataFrame() + + # 获取昨收价 + prev_close = get_prev_close(all_stocks, trade_date) + index_prev_close = get_index_prev_close(trade_date) + + if not prev_close or not index_prev_close: + return pd.DataFrame() + + # 计算涨跌幅 + stock_df['prev_close'] = stock_df['code'].map(prev_close) + stock_df = stock_df.dropna(subset=['prev_close']) + stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100 + + index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100 + index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct'])) + + # 获取所有时间点 + timestamps = sorted(stock_df['timestamp'].unique()) + + # 概念到股票的映射 + concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts} + + results = [] + + for ts in timestamps: + ts_stock_data = stock_df[stock_df['timestamp'] == ts] + index_change = index_change_map.get(ts, 0) + + stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct'])) + stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt'])) + + concept_features = [] + + for concept_id, stocks in concept_stocks.items(): + concept_changes = [stock_change[s] for s in stocks if s in stock_change] + concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change] + + if not concept_changes: + continue + + avg_change = np.mean(concept_changes) + total_amt = sum(concept_amts) + alpha = avg_change - index_change + + limit_up_count = sum(1 for c in concept_changes if c >= CONFIG['limit_up_threshold']) + limit_up_ratio = limit_up_count / len(concept_changes) + + concept_features.append({ + 'concept_id': concept_id, + 'alpha': alpha, + 'total_amt': total_amt, + 'limit_up_ratio': limit_up_ratio, + 'stock_count': len(concept_changes), + }) + + if not concept_features: + continue + + concept_df = pd.DataFrame(concept_features) + concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True) + concept_df['timestamp'] = ts + concept_df['time_slot'] = time_to_slot(ts) + concept_df['trade_date'] = trade_date + + results.append(concept_df) + + if not results: + return pd.DataFrame() + + result_df = pd.concat(results, ignore_index=True) + + # 保存缓存 + result_df.to_parquet(cache_file, index=False) + + return result_df + + +# ==================== 滚动窗口基线计算 ==================== + +def compute_rolling_baseline( + historical_data: pd.DataFrame, + concept_id: str +) -> Dict[str, Dict]: + """ + 计算单个概念的滚动基线 + + 返回: {time_slot: {alpha_mean, alpha_std, amt_mean, amt_std, rank_mean, rank_std, sample_count}} + """ + 
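+    # Time-slot alignment: stats are grouped by HH:MM, so today's 09:35 bar
+    # is scored against the distribution of historical 09:35 bars from the
+    # rolling window (z = (x - mean_slot) / std_slot), not against the
+    # preceding 30 minutes of the same day. The std floors below (0.1 for
+    # alpha, 1.0 for total_amt, 0.05 for rank_pct) keep a flat history from
+    # producing explosive z-scores when the denominator is near zero.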
if historical_data.empty: + return {} + + concept_data = historical_data[historical_data['concept_id'] == concept_id] + if concept_data.empty: + return {} + + baseline_dict = {} + + for time_slot, group in concept_data.groupby('time_slot'): + if len(group) < CONFIG['min_baseline_samples']: + continue + + alpha_std = group['alpha'].std() + amt_std = group['total_amt'].std() + rank_std = group['rank_pct'].std() + + baseline_dict[time_slot] = { + 'alpha_mean': group['alpha'].mean(), + 'alpha_std': max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1), + 'amt_mean': group['total_amt'].mean(), + 'amt_std': max(amt_std if pd.notna(amt_std) else group['total_amt'].mean() * 0.5, 1.0), + 'rank_mean': group['rank_pct'].mean(), + 'rank_std': max(rank_std if pd.notna(rank_std) else 0.2, 0.05), + 'sample_count': len(group), + } + + return baseline_dict + + +# ==================== 计算单日 Z-Score 特征(带滚动基线)==================== + +def compute_zscore_features_rolling( + trade_date: str, + concepts: List[dict], + all_stocks: List[str], + historical_raw_data: pd.DataFrame # 该日期之前 N 天的原始数据 +) -> pd.DataFrame: + """ + 计算单日的 Z-Score 特征(使用滚动窗口基线) + + 关键改进: + 1. 基线只使用 trade_date 之前的数据(无未来泄露) + 2. 动量基于 Z-Score 计算(消除波动率异构性) + """ + # 计算当日原始特征 + raw_df = compute_raw_concept_features(trade_date, concepts, all_stocks) + + if raw_df.empty: + return pd.DataFrame() + + zscore_records = [] + + for concept_id, group in raw_df.groupby('concept_id'): + # 计算该概念的滚动基线(只用历史数据) + baseline_dict = compute_rolling_baseline(historical_raw_data, concept_id) + + if not baseline_dict: + continue + + # 按时间排序 + group = group.sort_values('timestamp').reset_index(drop=True) + + # Z-Score 历史(用于计算基于 Z-Score 的动量) + zscore_history = [] + + for idx, row in group.iterrows(): + time_slot = row['time_slot'] + + if time_slot not in baseline_dict: + continue + + bl = baseline_dict[time_slot] + + # 计算 Z-Score + alpha_zscore = (row['alpha'] - bl['alpha_mean']) / bl['alpha_std'] + amt_zscore = (row['total_amt'] - bl['amt_mean']) / bl['amt_std'] + rank_zscore = (row['rank_pct'] - bl['rank_mean']) / bl['rank_std'] + + # 截断极端值 + clip = CONFIG['zscore_clip'] + alpha_zscore = np.clip(alpha_zscore, -clip, clip) + amt_zscore = np.clip(amt_zscore, -clip, clip) + rank_zscore = np.clip(rank_zscore, -clip, clip) + + # 记录 Z-Score 历史 + zscore_history.append(alpha_zscore) + + # 基于 Z-Score 计算动量(消除波动率异构性) + momentum_3m = 0.0 + momentum_5m = 0.0 + + if len(zscore_history) >= 3: + recent_3 = zscore_history[-3:] + older_3 = zscore_history[-6:-3] if len(zscore_history) >= 6 else [zscore_history[0]] + momentum_3m = np.mean(recent_3) - np.mean(older_3) + + if len(zscore_history) >= 5: + recent_5 = zscore_history[-5:] + older_5 = zscore_history[-10:-5] if len(zscore_history) >= 10 else [zscore_history[0]] + momentum_5m = np.mean(recent_5) - np.mean(older_5) + + zscore_records.append({ + 'concept_id': concept_id, + 'timestamp': row['timestamp'], + 'time_slot': time_slot, + 'trade_date': trade_date, + # 原始特征 + 'alpha': row['alpha'], + 'total_amt': row['total_amt'], + 'limit_up_ratio': row['limit_up_ratio'], + 'stock_count': row['stock_count'], + 'rank_pct': row['rank_pct'], + # Z-Score 特征 + 'alpha_zscore': alpha_zscore, + 'amt_zscore': amt_zscore, + 'rank_zscore': rank_zscore, + # 基于 Z-Score 的动量 + 'momentum_3m': momentum_3m, + 'momentum_5m': momentum_5m, + }) + + if not zscore_records: + return pd.DataFrame() + + return pd.DataFrame(zscore_records) + + +# ==================== 多进程处理 ==================== + +def process_single_day_v2(args) -> Tuple[str, bool]: + """处理单个交易日(多进程版本)""" + 
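+    # Leakage guard: the baseline window for the day at index i is the
+    # half-open slice all_trading_days[max(0, i - baseline_days):i], which
+    # excludes the current day itself, so each date's z-scores are computed
+    # only from strictly earlier sessions.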
trade_date, day_index, concepts, all_stocks, all_trading_days = args + output_file = os.path.join(OUTPUT_DIR, f'features_v2_{trade_date}.parquet') + + if os.path.exists(output_file): + return (trade_date, True) + + try: + # 计算滚动窗口范围(该日期之前的 N 天) + baseline_days = CONFIG['baseline_days'] + + # 找出 trade_date 之前的交易日 + start_idx = max(0, day_index - baseline_days) + end_idx = day_index # 不包含当天 + + if end_idx <= start_idx: + # 没有足够的历史数据 + return (trade_date, False) + + historical_days = all_trading_days[start_idx:end_idx] + + # 加载历史原始数据 + historical_dfs = [] + for hist_date in historical_days: + cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{hist_date}.parquet') + if os.path.exists(cache_file): + historical_dfs.append(pd.read_parquet(cache_file)) + else: + # 需要计算 + hist_df = compute_raw_concept_features(hist_date, concepts, all_stocks) + if not hist_df.empty: + historical_dfs.append(hist_df) + + if not historical_dfs: + return (trade_date, False) + + historical_raw_data = pd.concat(historical_dfs, ignore_index=True) + + # 计算当日 Z-Score 特征(使用滚动基线) + df = compute_zscore_features_rolling(trade_date, concepts, all_stocks, historical_raw_data) + + if df.empty: + return (trade_date, False) + + df.to_parquet(output_file, index=False) + return (trade_date, True) + + except Exception as e: + print(f"[{trade_date}] 处理失败: {e}") + import traceback + traceback.print_exc() + return (trade_date, False) + + +# ==================== 主流程 ==================== + +def main(): + import argparse + + parser = argparse.ArgumentParser(description='准备训练数据 V2(滚动窗口基线 + Z-Score 动量)') + parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期') + parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)') + parser.add_argument('--workers', type=int, default=18, help='并行进程数') + parser.add_argument('--baseline-days', type=int, default=20, help='滚动基线窗口大小') + parser.add_argument('--force', action='store_true', help='强制重新计算(忽略缓存)') + + args = parser.parse_args() + + end_date = args.end or datetime.now().strftime('%Y-%m-%d') + CONFIG['baseline_days'] = args.baseline_days + + print("=" * 60) + print("数据准备 V2 - 滚动窗口基线 + Z-Score 动量") + print("=" * 60) + print(f"日期范围: {args.start} ~ {end_date}") + print(f"并行进程数: {args.workers}") + print(f"滚动基线窗口: {args.baseline_days} 天") + + # 初始化主进程连接 + init_process_connections() + + # 1. 获取概念列表 + concepts = get_all_concepts() + all_stocks = list(set(s for c in concepts for s in c['stocks'])) + print(f"股票总数: {len(all_stocks)}") + + # 2. 获取交易日列表 + trading_days = get_trading_days(args.start, end_date) + + if not trading_days: + print("无交易日数据") + return + + # 3. 第一阶段:预计算所有原始特征(用于缓存) + print(f"\n{'='*60}") + print("第一阶段:预计算原始特征(用于滚动基线)") + print(f"{'='*60}") + + # 如果强制重新计算,删除缓存 + if args.force: + import shutil + if os.path.exists(RAW_CACHE_DIR): + shutil.rmtree(RAW_CACHE_DIR) + os.makedirs(RAW_CACHE_DIR, exist_ok=True) + if os.path.exists(OUTPUT_DIR): + for f in os.listdir(OUTPUT_DIR): + if f.startswith('features_v2_'): + os.remove(os.path.join(OUTPUT_DIR, f)) + + # 单线程预计算原始特征(因为需要顺序缓存) + print(f"预计算 {len(trading_days)} 天的原始特征...") + for trade_date in tqdm(trading_days, desc="预计算原始特征"): + cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet') + if not os.path.exists(cache_file): + compute_raw_concept_features(trade_date, concepts, all_stocks) + + # 4. 
第二阶段:计算 Z-Score 特征(多进程) + print(f"\n{'='*60}") + print("第二阶段:计算 Z-Score 特征(滚动基线)") + print(f"{'='*60}") + + # 从第 baseline_days 天开始(前面的没有足够历史) + start_idx = args.baseline_days + processable_days = trading_days[start_idx:] + + if not processable_days: + print(f"错误:需要至少 {args.baseline_days + 1} 天的数据") + return + + print(f"可处理日期: {processable_days[0]} ~ {processable_days[-1]} ({len(processable_days)} 天)") + print(f"跳过前 {start_idx} 天(基线预热期)") + + # 构建任务 + tasks = [] + for i, trade_date in enumerate(trading_days): + if i >= start_idx: + tasks.append((trade_date, i, concepts, all_stocks, trading_days)) + + print(f"开始处理 {len(tasks)} 个交易日({args.workers} 进程并行)...") + + success_count = 0 + failed_dates = [] + + # 使用进程池初始化器 + with ProcessPoolExecutor(max_workers=args.workers, initializer=init_process_connections) as executor: + futures = {executor.submit(process_single_day_v2, task): task[0] for task in tasks} + + with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar: + for future in as_completed(futures): + trade_date = futures[future] + try: + result_date, success = future.result() + if success: + success_count += 1 + else: + failed_dates.append(result_date) + except Exception as e: + print(f"\n[{trade_date}] 进程异常: {e}") + failed_dates.append(trade_date) + pbar.update(1) + + print("\n" + "=" * 60) + print(f"处理完成: {success_count}/{len(tasks)} 个交易日") + if failed_dates: + print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}") + print(f"数据保存在: {OUTPUT_DIR}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/ml/realtime_detector.py b/ml/realtime_detector.py index 0144add3..3af09828 100644 --- a/ml/realtime_detector.py +++ b/ml/realtime_detector.py @@ -190,20 +190,22 @@ def get_all_concepts() -> List[dict]: def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]: - """获取昨收价""" + """获取昨收价(上一交易日的收盘价 F007N)""" valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()] if not valid_codes: return {} codes_str = "','".join(valid_codes) + # 注意:F007N 是"最近成交价"即当日收盘价,F002N 是"昨日收盘价" + # 我们需要查上一交易日的 F007N(那天的收盘价)作为今天的昨收 query = f""" - SELECT SECCODE, F002N + SELECT SECCODE, F007N FROM ea_trade WHERE SECCODE IN ('{codes_str}') AND TRADEDATE = ( SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}' ) - AND F002N IS NOT NULL AND F002N > 0 + AND F007N IS NOT NULL AND F007N > 0 """ try: diff --git a/ml/realtime_detector_v2.py b/ml/realtime_detector_v2.py new file mode 100644 index 00000000..e5ccc942 --- /dev/null +++ b/ml/realtime_detector_v2.py @@ -0,0 +1,729 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +V2 实时异动检测器 + +使用方法: + # 作为模块导入 + from ml.realtime_detector_v2 import RealtimeDetectorV2 + + detector = RealtimeDetectorV2() + alerts = detector.detect_realtime() # 检测当前时刻 + + # 或命令行测试 + python ml/realtime_detector_v2.py --date 2025-12-09 +""" + +import os +import sys +import json +import pickle +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Tuple +from collections import defaultdict, deque + +import numpy as np +import pandas as pd +import torch +from sqlalchemy import create_engine, text +from elasticsearch import Elasticsearch +from clickhouse_driver import Client + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from ml.model import TransformerAutoencoder + +# ==================== 配置 ==================== + +MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock" +ES_HOST = 'http://127.0.0.1:9200' +ES_INDEX = 
'concept_library_v3' + +CLICKHOUSE_CONFIG = { + 'host': '127.0.0.1', + 'port': 9000, + 'user': 'default', + 'password': 'Zzl33818!', + 'database': 'stock' +} + +REFERENCE_INDEX = '000001.SH' +BASELINE_FILE = 'ml/data_v2/baselines/realtime_baseline.pkl' +MODEL_DIR = 'ml/checkpoints_v2' + +# 检测配置 +CONFIG = { + 'seq_len': 10, # LSTM 序列长度 + 'confirm_window': 5, # 持续确认窗口 + 'confirm_ratio': 0.6, # 确认比例 + 'rule_weight': 0.5, + 'ml_weight': 0.5, + 'rule_trigger': 60, + 'ml_trigger': 70, + 'fusion_trigger': 50, + 'cooldown_minutes': 10, + 'max_alerts_per_minute': 15, + 'zscore_clip': 5.0, + 'limit_up_threshold': 9.8, +} + +FEATURES = ['alpha_zscore', 'amt_zscore', 'rank_zscore', 'momentum_3m', 'momentum_5m', 'limit_up_ratio'] + + +# ==================== 数据库连接 ==================== + +_mysql_engine = None +_es_client = None +_ch_client = None + + +def get_mysql_engine(): + global _mysql_engine + if _mysql_engine is None: + _mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True) + return _mysql_engine + + +def get_es_client(): + global _es_client + if _es_client is None: + _es_client = Elasticsearch([ES_HOST]) + return _es_client + + +def get_ch_client(): + global _ch_client + if _ch_client is None: + _ch_client = Client(**CLICKHOUSE_CONFIG) + return _ch_client + + +def code_to_ch_format(code: str) -> str: + if not code or len(code) != 6 or not code.isdigit(): + return None + if code.startswith('6'): + return f"{code}.SH" + elif code.startswith('0') or code.startswith('3'): + return f"{code}.SZ" + return f"{code}.BJ" + + +def time_to_slot(ts) -> str: + if isinstance(ts, str): + return ts + return ts.strftime('%H:%M') + + +# ==================== 规则评分 ==================== + +def score_rules_zscore(features: Dict) -> Tuple[float, List[str]]: + """基于 Z-Score 的规则评分""" + score = 0.0 + triggered = [] + + alpha_z = abs(features.get('alpha_zscore', 0)) + amt_z = features.get('amt_zscore', 0) + rank_z = abs(features.get('rank_zscore', 0)) + mom_3m = features.get('momentum_3m', 0) + mom_5m = features.get('momentum_5m', 0) + limit_up = features.get('limit_up_ratio', 0) + + # Alpha Z-Score + if alpha_z >= 4.0: + score += 25 + triggered.append('alpha_extreme') + elif alpha_z >= 3.0: + score += 18 + triggered.append('alpha_strong') + elif alpha_z >= 2.0: + score += 10 + triggered.append('alpha_moderate') + + # 成交额 Z-Score + if amt_z >= 4.0: + score += 20 + triggered.append('amt_extreme') + elif amt_z >= 3.0: + score += 12 + triggered.append('amt_strong') + elif amt_z >= 2.0: + score += 6 + triggered.append('amt_moderate') + + # 排名 Z-Score + if rank_z >= 3.0: + score += 15 + triggered.append('rank_extreme') + elif rank_z >= 2.0: + score += 8 + triggered.append('rank_strong') + + # 动量(基于 Z-Score 的) + if mom_3m >= 1.0: + score += 12 + triggered.append('momentum_3m_strong') + elif mom_3m >= 0.5: + score += 6 + triggered.append('momentum_3m_moderate') + + if mom_5m >= 1.5: + score += 10 + triggered.append('momentum_5m_strong') + + # 涨停比例 + if limit_up >= 0.3: + score += 20 + triggered.append('limit_up_extreme') + elif limit_up >= 0.15: + score += 12 + triggered.append('limit_up_strong') + elif limit_up >= 0.08: + score += 5 + triggered.append('limit_up_moderate') + + # 组合规则 + if alpha_z >= 2.0 and amt_z >= 2.0: + score += 15 + triggered.append('combo_alpha_amt') + + if alpha_z >= 2.0 and limit_up >= 0.1: + score += 12 + triggered.append('combo_alpha_limitup') + + return min(score, 100), triggered + + +# ==================== 实时检测器 ==================== + +class RealtimeDetectorV2: + """V2 实时异动检测器""" + + 
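+    # Per-concept state kept between minutes:
+    #   zscore_history[cid]      last seq_len z-score dicts (model input)
+    #   anomaly_candidates[cid]  last confirm_window (ts, final_score) pairs,
+    #                            used for persistence confirmation
+    #   cooldown[cid]            time of the last alert for that concept
+    # Confirmation example (illustrative): with confirm_window=5 and
+    # confirm_ratio=0.6, a trigger only fires if final_score clears
+    # fusion_trigger in at least 3 of the last 5 minutes.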
def __init__(self, model_dir: str = MODEL_DIR, baseline_file: str = BASELINE_FILE): + print("初始化 V2 实时检测器...") + + # 加载概念 + self.concepts = self._load_concepts() + self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in self.concepts} + self.all_stocks = list(set(s for c in self.concepts for s in c['stocks'])) + + # 加载基线 + self.baselines = self._load_baselines(baseline_file) + + # 加载模型 + self.model, self.thresholds, self.device = self._load_model(model_dir) + + # 状态管理 + self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len'])) + self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window'])) + self.cooldown = {} + + print(f"初始化完成: {len(self.concepts)} 概念, {len(self.baselines)} 基线") + + def _load_concepts(self) -> List[dict]: + """从 ES 加载概念""" + es = get_es_client() + concepts = [] + + query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]} + resp = es.search(index=ES_INDEX, body=query, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = resp['hits']['hits'] + + while hits: + for hit in hits: + src = hit['_source'] + stocks = [s['code'] for s in src.get('stocks', []) if isinstance(s, dict) and s.get('code')] + if stocks: + concepts.append({ + 'concept_id': src.get('concept_id'), + 'concept_name': src.get('concept'), + 'stocks': stocks + }) + resp = es.scroll(scroll_id=scroll_id, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = resp['hits']['hits'] + + es.clear_scroll(scroll_id=scroll_id) + return concepts + + def _load_baselines(self, baseline_file: str) -> Dict: + """加载基线""" + if not os.path.exists(baseline_file): + print(f"警告: 基线文件不存在: {baseline_file}") + print("请先运行: python ml/update_baseline.py") + return {} + + with open(baseline_file, 'rb') as f: + data = pickle.load(f) + + print(f"基线日期范围: {data.get('date_range', 'unknown')}") + print(f"更新时间: {data.get('update_time', 'unknown')}") + + return data.get('baselines', {}) + + def _load_model(self, model_dir: str): + """加载模型""" + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + config_path = os.path.join(model_dir, 'config.json') + model_path = os.path.join(model_dir, 'best_model.pt') + threshold_path = os.path.join(model_dir, 'thresholds.json') + + if not os.path.exists(model_path): + print(f"警告: 模型不存在: {model_path}") + return None, {}, device + + with open(config_path) as f: + config = json.load(f) + + model = TransformerAutoencoder(**config['model']) + checkpoint = torch.load(model_path, map_location=device) + model.load_state_dict(checkpoint['model_state_dict']) + model.to(device) + model.eval() + + thresholds = {} + if os.path.exists(threshold_path): + with open(threshold_path) as f: + thresholds = json.load(f) + + print(f"模型已加载: {model_path}") + return model, thresholds, device + + def _get_realtime_data(self, trade_date: str) -> pd.DataFrame: + """获取实时数据并计算原始特征""" + ch = get_ch_client() + + # 获取股票数据 + ch_codes = [code_to_ch_format(c) for c in self.all_stocks if code_to_ch_format(c)] + ch_codes_str = "','".join(ch_codes) + + stock_query = f""" + SELECT code, timestamp, close, amt + FROM stock_minute + WHERE toDate(timestamp) = '{trade_date}' + AND code IN ('{ch_codes_str}') + ORDER BY timestamp + """ + stock_result = ch.execute(stock_query) + if not stock_result: + return pd.DataFrame() + + stock_df = pd.DataFrame(stock_result, columns=['ch_code', 'timestamp', 'close', 'amt']) + + # 映射回原始代码 + ch_to_code = {code_to_ch_format(c): c for c in self.all_stocks if code_to_ch_format(c)} + stock_df['code'] = 
stock_df['ch_code'].map(ch_to_code) + stock_df = stock_df.dropna(subset=['code']) + + # 获取指数数据 + index_query = f""" + SELECT timestamp, close + FROM index_minute + WHERE toDate(timestamp) = '{trade_date}' + AND code = '{REFERENCE_INDEX}' + ORDER BY timestamp + """ + index_result = ch.execute(index_query) + if not index_result: + return pd.DataFrame() + + index_df = pd.DataFrame(index_result, columns=['timestamp', 'close']) + + # 获取昨收价 + engine = get_mysql_engine() + codes_str = "','".join([c for c in self.all_stocks if c and len(c) == 6]) + + with engine.connect() as conn: + prev_result = conn.execute(text(f""" + SELECT SECCODE, F007N FROM ea_trade + WHERE SECCODE IN ('{codes_str}') + AND TRADEDATE = (SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}') + AND F007N > 0 + """)) + prev_close = {row[0]: float(row[1]) for row in prev_result if row[1]} + + idx_result = conn.execute(text(""" + SELECT F006N FROM ea_exchangetrade + WHERE INDEXCODE = '000001' AND TRADEDATE < :today + ORDER BY TRADEDATE DESC LIMIT 1 + """), {'today': trade_date}).fetchone() + index_prev_close = float(idx_result[0]) if idx_result else None + + if not prev_close or not index_prev_close: + return pd.DataFrame() + + # 计算涨跌幅 + stock_df['prev_close'] = stock_df['code'].map(prev_close) + stock_df = stock_df.dropna(subset=['prev_close']) + stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100 + + index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100 + index_map = dict(zip(index_df['timestamp'], index_df['change_pct'])) + + # 按时间聚合概念特征 + results = [] + for ts in sorted(stock_df['timestamp'].unique()): + ts_data = stock_df[stock_df['timestamp'] == ts] + idx_chg = index_map.get(ts, 0) + + stock_chg = dict(zip(ts_data['code'], ts_data['change_pct'])) + stock_amt = dict(zip(ts_data['code'], ts_data['amt'])) + + for cid, stocks in self.concept_stocks.items(): + changes = [stock_chg[s] for s in stocks if s in stock_chg] + amts = [stock_amt.get(s, 0) for s in stocks if s in stock_chg] + + if not changes: + continue + + alpha = np.mean(changes) - idx_chg + total_amt = sum(amts) + limit_up_ratio = sum(1 for c in changes if c >= CONFIG['limit_up_threshold']) / len(changes) + + results.append({ + 'concept_id': cid, + 'timestamp': ts, + 'time_slot': time_to_slot(ts), + 'alpha': alpha, + 'total_amt': total_amt, + 'limit_up_ratio': limit_up_ratio, + 'stock_count': len(changes), + }) + + if not results: + return pd.DataFrame() + + df = pd.DataFrame(results) + + # 计算排名 + for ts in df['timestamp'].unique(): + mask = df['timestamp'] == ts + df.loc[mask, 'rank_pct'] = df.loc[mask, 'alpha'].rank(pct=True) + + return df + + def _compute_zscore(self, concept_id: str, time_slot: str, alpha: float, total_amt: float, rank_pct: float) -> Optional[Dict]: + """计算 Z-Score""" + if concept_id not in self.baselines: + return None + + baseline = self.baselines[concept_id] + if time_slot not in baseline: + return None + + bl = baseline[time_slot] + + alpha_z = np.clip((alpha - bl['alpha_mean']) / bl['alpha_std'], -5, 5) + amt_z = np.clip((total_amt - bl['amt_mean']) / bl['amt_std'], -5, 5) + rank_z = np.clip((rank_pct - bl['rank_mean']) / bl['rank_std'], -5, 5) + + # 动量(基于 Z-Score 历史) + history = list(self.zscore_history[concept_id]) + mom_3m = 0.0 + mom_5m = 0.0 + + if len(history) >= 3: + recent = [h['alpha_zscore'] for h in history[-3:]] + older = [h['alpha_zscore'] for h in history[-6:-3]] if len(history) >= 6 else [history[0]['alpha_zscore']] + mom_3m = 
np.mean(recent) - np.mean(older) + + if len(history) >= 5: + recent = [h['alpha_zscore'] for h in history[-5:]] + older = [h['alpha_zscore'] for h in history[-10:-5]] if len(history) >= 10 else [history[0]['alpha_zscore']] + mom_5m = np.mean(recent) - np.mean(older) + + return { + 'alpha_zscore': float(alpha_z), + 'amt_zscore': float(amt_z), + 'rank_zscore': float(rank_z), + 'momentum_3m': float(mom_3m), + 'momentum_5m': float(mom_5m), + } + + @torch.no_grad() + def _ml_score(self, sequences: np.ndarray) -> np.ndarray: + """批量 ML 评分""" + if self.model is None or len(sequences) == 0: + return np.zeros(len(sequences)) + + x = torch.FloatTensor(sequences).to(self.device) + errors = self.model.compute_reconstruction_error(x, reduction='none') + last_errors = errors[:, -1].cpu().numpy() + + # 转换为 0-100 分数 + if self.thresholds: + p50 = self.thresholds.get('median', 0.001) + p99 = self.thresholds.get('p99', 0.05) + scores = 50 + (last_errors - p50) / (p99 - p50 + 1e-6) * 49 + else: + scores = last_errors * 1000 + + return np.clip(scores, 0, 100) + + def detect(self, trade_date: str = None) -> List[Dict]: + """检测指定日期的异动""" + trade_date = trade_date or datetime.now().strftime('%Y-%m-%d') + print(f"\n检测 {trade_date} 的异动...") + + # 重置状态 + self.zscore_history.clear() + self.anomaly_candidates.clear() + self.cooldown.clear() + + # 获取数据 + raw_df = self._get_realtime_data(trade_date) + if raw_df.empty: + print("无数据") + return [] + + timestamps = sorted(raw_df['timestamp'].unique()) + print(f"时间点数: {len(timestamps)}") + + all_alerts = [] + + for ts in timestamps: + ts_data = raw_df[raw_df['timestamp'] == ts] + time_slot = time_to_slot(ts) + + candidates = [] + + # 计算每个概念的 Z-Score + for _, row in ts_data.iterrows(): + cid = row['concept_id'] + + zscore = self._compute_zscore( + cid, time_slot, + row['alpha'], row['total_amt'], row['rank_pct'] + ) + + if zscore is None: + continue + + # 完整特征 + features = { + **zscore, + 'alpha': row['alpha'], + 'limit_up_ratio': row['limit_up_ratio'], + 'total_amt': row['total_amt'], + } + + # 更新历史 + self.zscore_history[cid].append(zscore) + + # 规则评分 + rule_score, triggered = score_rules_zscore(features) + + candidates.append((cid, features, rule_score, triggered)) + + if not candidates: + continue + + # 批量 ML 评分 + sequences = [] + valid_candidates = [] + + for cid, features, rule_score, triggered in candidates: + history = list(self.zscore_history[cid]) + if len(history) >= CONFIG['seq_len']: + seq = np.array([[h['alpha_zscore'], h['amt_zscore'], h['rank_zscore'], + h['momentum_3m'], h['momentum_5m'], features['limit_up_ratio']] + for h in history]) + sequences.append(seq) + valid_candidates.append((cid, features, rule_score, triggered)) + + if not sequences: + continue + + ml_scores = self._ml_score(np.array(sequences)) + + # 融合 + 确认 + for i, (cid, features, rule_score, triggered) in enumerate(valid_candidates): + ml_score = ml_scores[i] + final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score + + # 判断触发 + is_triggered = ( + rule_score >= CONFIG['rule_trigger'] or + ml_score >= CONFIG['ml_trigger'] or + final_score >= CONFIG['fusion_trigger'] + ) + + self.anomaly_candidates[cid].append((ts, final_score)) + + if not is_triggered: + continue + + # 冷却期 + if cid in self.cooldown: + if (ts - self.cooldown[cid]).total_seconds() < CONFIG['cooldown_minutes'] * 60: + continue + + # 持续确认 + recent = list(self.anomaly_candidates[cid]) + if len(recent) < CONFIG['confirm_window']: + continue + + exceed = sum(1 for _, s in recent if s >= 
CONFIG['fusion_trigger']) + ratio = exceed / len(recent) + + if ratio < CONFIG['confirm_ratio']: + continue + + # 确认异动! + self.cooldown[cid] = ts + + alpha = features['alpha'] + alert_type = 'surge_up' if alpha >= 1.5 else 'surge_down' if alpha <= -1.5 else 'surge' + + concept_name = next((c['concept_name'] for c in self.concepts if c['concept_id'] == cid), cid) + + all_alerts.append({ + 'concept_id': cid, + 'concept_name': concept_name, + 'alert_time': ts, + 'trade_date': trade_date, + 'alert_type': alert_type, + 'final_score': float(final_score), + 'rule_score': float(rule_score), + 'ml_score': float(ml_score), + 'confirm_ratio': float(ratio), + 'alpha': float(alpha), + 'alpha_zscore': float(features['alpha_zscore']), + 'amt_zscore': float(features['amt_zscore']), + 'rank_zscore': float(features['rank_zscore']), + 'momentum_3m': float(features['momentum_3m']), + 'momentum_5m': float(features['momentum_5m']), + 'limit_up_ratio': float(features['limit_up_ratio']), + 'triggered_rules': triggered, + 'trigger_reason': f"融合({final_score:.0f})+确认({ratio:.0%})", + }) + + print(f"检测到 {len(all_alerts)} 个异动") + return all_alerts + + +# ==================== 数据库存储 ==================== + +def create_v2_table(): + """创建 V2 异动表(如果不存在)""" + engine = get_mysql_engine() + with engine.begin() as conn: + conn.execute(text(""" + CREATE TABLE IF NOT EXISTS concept_anomaly_v2 ( + id INT AUTO_INCREMENT PRIMARY KEY, + concept_id VARCHAR(50) NOT NULL, + alert_time DATETIME NOT NULL, + trade_date DATE NOT NULL, + alert_type VARCHAR(20) NOT NULL, + final_score FLOAT, + rule_score FLOAT, + ml_score FLOAT, + trigger_reason VARCHAR(200), + confirm_ratio FLOAT, + alpha FLOAT, + alpha_zscore FLOAT, + amt_zscore FLOAT, + rank_zscore FLOAT, + momentum_3m FLOAT, + momentum_5m FLOAT, + limit_up_ratio FLOAT, + triggered_rules TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY uk_concept_time (concept_id, alert_time), + INDEX idx_trade_date (trade_date), + INDEX idx_alert_type (alert_type) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 + """)) + print("concept_anomaly_v2 表已就绪") + + +def save_alerts_to_db(alerts: List[Dict]) -> int: + """保存异动到数据库""" + if not alerts: + return 0 + + engine = get_mysql_engine() + saved = 0 + + with engine.begin() as conn: + for alert in alerts: + try: + insert_sql = text(""" + INSERT IGNORE INTO concept_anomaly_v2 + (concept_id, alert_time, trade_date, alert_type, + final_score, rule_score, ml_score, trigger_reason, confirm_ratio, + alpha, alpha_zscore, amt_zscore, rank_zscore, + momentum_3m, momentum_5m, limit_up_ratio, triggered_rules) + VALUES + (:concept_id, :alert_time, :trade_date, :alert_type, + :final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio, + :alpha, :alpha_zscore, :amt_zscore, :rank_zscore, + :momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules) + """) + + result = conn.execute(insert_sql, { + 'concept_id': alert['concept_id'], + 'alert_time': alert['alert_time'], + 'trade_date': alert['trade_date'], + 'alert_type': alert['alert_type'], + 'final_score': alert['final_score'], + 'rule_score': alert['rule_score'], + 'ml_score': alert['ml_score'], + 'trigger_reason': alert['trigger_reason'], + 'confirm_ratio': alert['confirm_ratio'], + 'alpha': alert['alpha'], + 'alpha_zscore': alert['alpha_zscore'], + 'amt_zscore': alert['amt_zscore'], + 'rank_zscore': alert['rank_zscore'], + 'momentum_3m': alert['momentum_3m'], + 'momentum_5m': alert['momentum_5m'], + 'limit_up_ratio': alert['limit_up_ratio'], + 'triggered_rules': 
json.dumps(alert.get('triggered_rules', []), ensure_ascii=False), + }) + + if result.rowcount > 0: + saved += 1 + except Exception as e: + print(f"保存失败: {alert['concept_id']} - {e}") + + return saved + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--date', type=str, default=None) + parser.add_argument('--no-save', action='store_true', help='不保存到数据库,只打印') + args = parser.parse_args() + + # 确保表存在 + if not args.no_save: + create_v2_table() + + detector = RealtimeDetectorV2() + alerts = detector.detect(args.date) + + print(f"\n{'='*60}") + print(f"检测结果 ({len(alerts)} 个异动)") + print('='*60) + + for a in alerts[:20]: + print(f"[{a['alert_time'].strftime('%H:%M') if hasattr(a['alert_time'], 'strftime') else a['alert_time']}] " + f"{a['concept_name']} | {a['alert_type']} | " + f"分数={a['final_score']:.0f} 确认={a['confirm_ratio']:.0%} " + f"α={a['alpha']:.2f}% αZ={a['alpha_zscore']:.1f}") + + if len(alerts) > 20: + print(f"... 共 {len(alerts)} 个") + + # 保存到数据库 + if not args.no_save and alerts: + saved = save_alerts_to_db(alerts) + print(f"\n✅ 已保存 {saved}/{len(alerts)} 条到 concept_anomaly_v2 表") + elif args.no_save: + print(f"\n⚠️ --no-save 模式,未保存到数据库") + + +if __name__ == "__main__": + main() diff --git a/ml/train_v2.py b/ml/train_v2.py new file mode 100644 index 00000000..47e6cdb7 --- /dev/null +++ b/ml/train_v2.py @@ -0,0 +1,622 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +训练脚本 V2 - 基于 Z-Score 特征的 LSTM Autoencoder + +改进点: +1. 使用 Z-Score 特征(相对于同时间片历史的偏离) +2. 短序列:10分钟(不需要30分钟预热) +3. 开盘即可检测:9:30 直接有特征 + +模型输入: +- 过去10分钟的 Z-Score 特征序列 +- 特征:alpha_zscore, amt_zscore, rank_zscore, momentum_3m, momentum_5m, limit_up_ratio + +模型学习: +- 学习 Z-Score 序列的"正常演化模式" +- 异动 = Z-Score 序列的异常演化(重构误差大) +""" + +import os +import sys +import argparse +import json +from datetime import datetime +from pathlib import Path +from typing import List, Tuple, Dict + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts +from tqdm import tqdm + +from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters + +# 性能优化 +torch.backends.cudnn.benchmark = True +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + +try: + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + HAS_MATPLOTLIB = True +except ImportError: + HAS_MATPLOTLIB = False + + +# ==================== 配置 ==================== + +TRAIN_CONFIG = { + # 数据配置(改进!) + 'seq_len': 10, # 10分钟序列(不是30分钟!) 
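+    # Rough yield (illustrative): a full A-share session has ~240 one-minute
+    # bars, so seq_len=10 with stride=2 gives about (240 - 10) / 2 + 1 = 116
+    # overlapping windows per concept per day, before slots missing a
+    # baseline are dropped.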
+ 'stride': 2, # 步长2分钟 + + # 时间切分 + 'train_end_date': '2024-06-30', + 'val_end_date': '2024-09-30', + + # V2 特征(Z-Score 为主) + 'features': [ + 'alpha_zscore', # Alpha 的 Z-Score + 'amt_zscore', # 成交额的 Z-Score + 'rank_zscore', # 排名的 Z-Score + 'momentum_3m', # 3分钟动量 + 'momentum_5m', # 5分钟动量 + 'limit_up_ratio', # 涨停占比 + ], + + # 训练配置 + 'batch_size': 4096, + 'epochs': 100, + 'learning_rate': 3e-4, + 'weight_decay': 1e-5, + 'gradient_clip': 1.0, + + # 早停配置 + 'patience': 15, + 'min_delta': 1e-6, + + # 模型配置(小型 LSTM) + 'model': { + 'n_features': 6, + 'hidden_dim': 32, + 'latent_dim': 4, + 'num_layers': 1, + 'dropout': 0.2, + 'bidirectional': True, + }, + + # 标准化配置 + 'clip_value': 5.0, # Z-Score 已经标准化,clip 5.0 足够 + + # 阈值配置 + 'threshold_percentiles': [90, 95, 99], +} + + +# ==================== 数据加载 ==================== + +def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]: + """按日期加载 V2 数据""" + data_path = Path(data_dir) + parquet_files = sorted(data_path.glob("features_v2_*.parquet")) + + if not parquet_files: + raise FileNotFoundError(f"未找到 V2 数据文件: {data_dir}") + + print(f"找到 {len(parquet_files)} 个 V2 数据文件") + + date_data = {} + + for pf in tqdm(parquet_files, desc="加载数据"): + date = pf.stem.replace('features_v2_', '') + + df = pd.read_parquet(pf) + + required_cols = features + ['concept_id', 'timestamp'] + missing_cols = [c for c in required_cols if c not in df.columns] + if missing_cols: + print(f"警告: {date} 缺少列: {missing_cols}, 跳过") + continue + + date_data[date] = df + + print(f"成功加载 {len(date_data)} 天的数据") + return date_data + + +def split_data_by_date( + date_data: Dict[str, pd.DataFrame], + train_end: str, + val_end: str +) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]: + """按日期划分数据集""" + train_data = {} + val_data = {} + test_data = {} + + for date, df in date_data.items(): + if date <= train_end: + train_data[date] = df + elif date <= val_end: + val_data[date] = df + else: + test_data[date] = df + + print(f"数据集划分:") + print(f" 训练集: {len(train_data)} 天 (<= {train_end})") + print(f" 验证集: {len(val_data)} 天 ({train_end} ~ {val_end})") + print(f" 测试集: {len(test_data)} 天 (> {val_end})") + + return train_data, val_data, test_data + + +def build_sequences_by_concept( + date_data: Dict[str, pd.DataFrame], + features: List[str], + seq_len: int, + stride: int +) -> np.ndarray: + """按概念分组构建序列""" + all_dfs = [] + for date, df in sorted(date_data.items()): + df = df.copy() + df['date'] = date + all_dfs.append(df) + + if not all_dfs: + return np.array([]) + + combined = pd.concat(all_dfs, ignore_index=True) + combined = combined.sort_values(['concept_id', 'date', 'timestamp']) + + all_sequences = [] + grouped = combined.groupby('concept_id', sort=False) + n_concepts = len(grouped) + + for concept_id, concept_df in tqdm(grouped, desc="构建序列", total=n_concepts, leave=False): + feature_data = concept_df[features].values + feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0) + + n_points = len(feature_data) + for start in range(0, n_points - seq_len + 1, stride): + seq = feature_data[start:start + seq_len] + all_sequences.append(seq) + + if not all_sequences: + return np.array([]) + + sequences = np.array(all_sequences) + print(f" 构建序列: {len(sequences):,} 条 (来自 {n_concepts} 个概念)") + + return sequences + + +# ==================== 数据集 ==================== + +class SequenceDataset(Dataset): + def __init__(self, sequences: np.ndarray): + self.sequences = torch.FloatTensor(sequences) + + def __len__(self) -> int: + 
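+        # one dataset item per sliding window produced by
+        # build_sequences_by_concept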
return len(self.sequences) + + def __getitem__(self, idx: int) -> torch.Tensor: + return self.sequences[idx] + + +# ==================== 训练器 ==================== + +class EarlyStopping: + def __init__(self, patience: int = 10, min_delta: float = 1e-6): + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.best_loss = float('inf') + self.early_stop = False + + def __call__(self, val_loss: float) -> bool: + if val_loss < self.best_loss - self.min_delta: + self.best_loss = val_loss + self.counter = 0 + else: + self.counter += 1 + if self.counter >= self.patience: + self.early_stop = True + return self.early_stop + + +class Trainer: + def __init__( + self, + model: nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + config: Dict, + device: torch.device, + save_dir: str = 'ml/checkpoints_v2' + ): + self.model = model.to(device) + self.train_loader = train_loader + self.val_loader = val_loader + self.config = config + self.device = device + self.save_dir = Path(save_dir) + self.save_dir.mkdir(parents=True, exist_ok=True) + + self.optimizer = AdamW( + model.parameters(), + lr=config['learning_rate'], + weight_decay=config['weight_decay'] + ) + + self.scheduler = CosineAnnealingWarmRestarts( + self.optimizer, T_0=10, T_mult=2, eta_min=1e-6 + ) + + self.criterion = AnomalyDetectionLoss() + + self.early_stopping = EarlyStopping( + patience=config['patience'], + min_delta=config['min_delta'] + ) + + self.use_amp = torch.cuda.is_available() + self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None + if self.use_amp: + print(" ✓ 启用 AMP 混合精度训练") + + self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []} + self.best_val_loss = float('inf') + + def train_epoch(self) -> float: + self.model.train() + total_loss = 0.0 + n_batches = 0 + + pbar = tqdm(self.train_loader, desc="Training", leave=False) + for batch in pbar: + batch = batch.to(self.device, non_blocking=True) + self.optimizer.zero_grad(set_to_none=True) + + if self.use_amp: + with torch.cuda.amp.autocast(): + output, latent = self.model(batch) + loss, _ = self.criterion(output, batch, latent) + + self.scaler.scale(loss).backward() + self.scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip']) + self.scaler.step(self.optimizer) + self.scaler.update() + else: + output, latent = self.model(batch) + loss, _ = self.criterion(output, batch, latent) + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip']) + self.optimizer.step() + + total_loss += loss.item() + n_batches += 1 + pbar.set_postfix({'loss': f"{loss.item():.4f}"}) + + return total_loss / n_batches + + @torch.no_grad() + def validate(self) -> float: + self.model.eval() + total_loss = 0.0 + n_batches = 0 + + for batch in self.val_loader: + batch = batch.to(self.device, non_blocking=True) + + if self.use_amp: + with torch.cuda.amp.autocast(): + output, latent = self.model(batch) + loss, _ = self.criterion(output, batch, latent) + else: + output, latent = self.model(batch) + loss, _ = self.criterion(output, batch, latent) + + total_loss += loss.item() + n_batches += 1 + + return total_loss / n_batches + + def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False): + model_to_save = self.model.module if hasattr(self.model, 'module') else self.model + + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model_to_save.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 
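+            # saving optimizer/scheduler state lets a resumed run continue
+            # the cosine-restart schedule instead of starting it over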
'scheduler_state_dict': self.scheduler.state_dict(), + 'val_loss': val_loss, + 'config': self.config, + } + + torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt') + + if is_best: + torch.save(checkpoint, self.save_dir / 'best_model.pt') + print(f" ✓ 保存最佳模型 (val_loss: {val_loss:.6f})") + + def train(self, epochs: int): + print(f"\n开始训练 ({epochs} epochs)...") + print(f"设备: {self.device}") + print(f"模型参数量: {count_parameters(self.model):,}") + + for epoch in range(1, epochs + 1): + print(f"\nEpoch {epoch}/{epochs}") + + train_loss = self.train_epoch() + val_loss = self.validate() + + self.scheduler.step() + current_lr = self.optimizer.param_groups[0]['lr'] + + self.history['train_loss'].append(train_loss) + self.history['val_loss'].append(val_loss) + self.history['learning_rate'].append(current_lr) + + print(f" Train Loss: {train_loss:.6f}") + print(f" Val Loss: {val_loss:.6f}") + print(f" LR: {current_lr:.2e}") + + is_best = val_loss < self.best_val_loss + if is_best: + self.best_val_loss = val_loss + self.save_checkpoint(epoch, val_loss, is_best) + + if self.early_stopping(val_loss): + print(f"\n早停触发!") + break + + print(f"\n训练完成!最佳验证损失: {self.best_val_loss:.6f}") + self.save_history() + + return self.history + + def save_history(self): + history_path = self.save_dir / 'training_history.json' + with open(history_path, 'w') as f: + json.dump(self.history, f, indent=2) + print(f"训练历史已保存: {history_path}") + + if HAS_MATPLOTLIB: + self.plot_training_curves() + + def plot_training_curves(self): + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + epochs = range(1, len(self.history['train_loss']) + 1) + + ax1 = axes[0] + ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2) + ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2) + ax1.set_xlabel('Epoch') + ax1.set_ylabel('Loss') + ax1.set_title('Training & Validation Loss (V2)') + ax1.legend() + ax1.grid(True, alpha=0.3) + + best_epoch = np.argmin(self.history['val_loss']) + 1 + best_val_loss = min(self.history['val_loss']) + ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7) + ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5) + + ax2 = axes[1] + ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2) + ax2.set_xlabel('Epoch') + ax2.set_ylabel('Learning Rate') + ax2.set_title('Learning Rate Schedule') + ax2.set_yscale('log') + ax2.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(self.save_dir / 'training_curves.png', dpi=150, bbox_inches='tight') + plt.close() + print(f"训练曲线已保存") + + +# ==================== 阈值计算 ==================== + +@torch.no_grad() +def compute_thresholds( + model: nn.Module, + data_loader: DataLoader, + device: torch.device, + percentiles: List[float] = [90, 95, 99] +) -> Dict[str, float]: + """在验证集上计算阈值""" + model.eval() + all_errors = [] + + print("计算异动阈值...") + for batch in tqdm(data_loader, desc="Computing thresholds"): + batch = batch.to(device) + errors = model.compute_reconstruction_error(batch, reduction='none') + seq_errors = errors[:, -1] # 最后一个时刻 + all_errors.append(seq_errors.cpu().numpy()) + + all_errors = np.concatenate(all_errors) + + thresholds = {} + for p in percentiles: + threshold = np.percentile(all_errors, p) + thresholds[f'p{p}'] = float(threshold) + print(f" P{p}: {threshold:.6f}") + + thresholds['mean'] = float(np.mean(all_errors)) + thresholds['std'] = float(np.std(all_errors)) + thresholds['median'] = float(np.median(all_errors)) + + return thresholds + + +# ==================== 主函数 
==================== + +def main(): + parser = argparse.ArgumentParser(description='训练 V2 模型') + parser.add_argument('--data_dir', type=str, default='ml/data_v2', help='V2 数据目录') + parser.add_argument('--epochs', type=int, default=100) + parser.add_argument('--batch_size', type=int, default=4096) + parser.add_argument('--lr', type=float, default=3e-4) + parser.add_argument('--device', type=str, default='auto') + parser.add_argument('--save_dir', type=str, default='ml/checkpoints_v2') + parser.add_argument('--train_end', type=str, default='2024-06-30') + parser.add_argument('--val_end', type=str, default='2024-09-30') + parser.add_argument('--seq_len', type=int, default=10, help='序列长度(分钟)') + + args = parser.parse_args() + + config = TRAIN_CONFIG.copy() + config['batch_size'] = args.batch_size + config['epochs'] = args.epochs + config['learning_rate'] = args.lr + config['train_end_date'] = args.train_end + config['val_end_date'] = args.val_end + config['seq_len'] = args.seq_len + + if args.device == 'auto': + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + device = torch.device(args.device) + + print("=" * 60) + print("概念异动检测模型训练 V2(Z-Score 特征)") + print("=" * 60) + print(f"数据目录: {args.data_dir}") + print(f"设备: {device}") + print(f"序列长度: {config['seq_len']} 分钟") + print(f"批次大小: {config['batch_size']}") + print(f"特征: {config['features']}") + print("=" * 60) + + # 1. 加载数据 + print("\n[1/6] 加载 V2 数据...") + date_data = load_data_by_date(args.data_dir, config['features']) + + # 2. 划分数据集 + print("\n[2/6] 划分数据集...") + train_data, val_data, test_data = split_data_by_date( + date_data, config['train_end_date'], config['val_end_date'] + ) + + # 3. 构建序列 + print("\n[3/6] 构建序列...") + print("训练集:") + train_sequences = build_sequences_by_concept( + train_data, config['features'], config['seq_len'], config['stride'] + ) + print("验证集:") + val_sequences = build_sequences_by_concept( + val_data, config['features'], config['seq_len'], config['stride'] + ) + + if len(train_sequences) == 0: + print("错误: 训练集为空!") + return + + # 4. 预处理 + print("\n[4/6] 数据预处理...") + clip_value = config['clip_value'] + print(f" Z-Score 特征已标准化,截断: ±{clip_value}") + + train_sequences = np.clip(train_sequences, -clip_value, clip_value) + if len(val_sequences) > 0: + val_sequences = np.clip(val_sequences, -clip_value, clip_value) + + # 保存配置 + save_dir = Path(args.save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + with open(save_dir / 'config.json', 'w') as f: + json.dump(config, f, indent=2) + + # 5. 创建数据加载器 + print("\n[5/6] 创建数据加载器...") + train_dataset = SequenceDataset(train_sequences) + val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None + + print(f" 训练序列: {len(train_dataset):,}") + print(f" 验证序列: {len(val_dataset) if val_dataset else 0:,}") + + n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1 + num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0 + + train_loader = DataLoader( + train_dataset, + batch_size=config['batch_size'], + shuffle=True, + num_workers=num_workers, + pin_memory=True, + prefetch_factor=4 if num_workers > 0 else None, + persistent_workers=True if num_workers > 0 else False, + drop_last=True + ) + + val_loader = DataLoader( + val_dataset, + batch_size=config['batch_size'] * 2, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) if val_dataset else None + + # 6. 
训练 + print("\n[6/6] 训练模型...") + model = TransformerAutoencoder(**config['model']) + + if torch.cuda.device_count() > 1: + print(f" 使用 {torch.cuda.device_count()} 张 GPU 并行训练") + model = nn.DataParallel(model) + + if val_loader is None: + print("警告: 验证集为空,使用训练集的 10% 作为验证") + split_idx = int(len(train_dataset) * 0.9) + train_subset = torch.utils.data.Subset(train_dataset, range(split_idx)) + val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset))) + train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True) + val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True) + + trainer = Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + config=config, + device=device, + save_dir=args.save_dir + ) + + trainer.train(config['epochs']) + + # 计算阈值 + print("\n[额外] 计算异动阈值...") + best_checkpoint = torch.load(save_dir / 'best_model.pt', map_location=device) + + # 创建新的单 GPU 模型用于计算阈值(避免 DataParallel 问题) + threshold_model = TransformerAutoencoder(**config['model']) + threshold_model.load_state_dict(best_checkpoint['model_state_dict']) + threshold_model.to(device) + threshold_model.eval() + + thresholds = compute_thresholds(threshold_model, val_loader, device, config['threshold_percentiles']) + + with open(save_dir / 'thresholds.json', 'w') as f: + json.dump(thresholds, f, indent=2) + + print("\n" + "=" * 60) + print("训练完成!") + print(f"模型保存位置: {args.save_dir}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/ml/update_baseline.py b/ml/update_baseline.py new file mode 100644 index 00000000..7ee7e3cc --- /dev/null +++ b/ml/update_baseline.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +每日盘后运行:更新滚动基线 + +使用方法: + python ml/update_baseline.py + +建议加入 crontab,每天 15:30 后运行: + 30 15 * * 1-5 cd /path/to/project && python ml/update_baseline.py +""" + +import os +import sys +import pickle +import pandas as pd +import numpy as np +from datetime import datetime, timedelta +from pathlib import Path +from tqdm import tqdm + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from ml.prepare_data_v2 import ( + get_all_concepts, get_trading_days, compute_raw_concept_features, + init_process_connections, CONFIG, RAW_CACHE_DIR, BASELINE_DIR +) + + +def update_rolling_baseline(baseline_days: int = 20): + """ + 更新滚动基线(用于实盘检测) + + 基线 = 最近 N 个交易日每个时间片的统计量 + """ + print("=" * 60) + print("更新滚动基线(用于实盘)") + print("=" * 60) + + # 初始化连接 + init_process_connections() + + # 获取概念列表 + concepts = get_all_concepts() + all_stocks = list(set(s for c in concepts for s in c['stocks'])) + + # 获取最近的交易日 + today = datetime.now().strftime('%Y-%m-%d') + start_date = (datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d') # 多取一些 + + trading_days = get_trading_days(start_date, today) + + if len(trading_days) < baseline_days: + print(f"错误:交易日不足 {baseline_days} 天") + return + + # 只取最近 N 天 + recent_days = trading_days[-baseline_days:] + print(f"使用 {len(recent_days)} 天数据: {recent_days[0]} ~ {recent_days[-1]}") + + # 加载原始数据 + all_data = [] + for trade_date in tqdm(recent_days, desc="加载数据"): + cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet') + + if os.path.exists(cache_file): + df = pd.read_parquet(cache_file) + else: + df = compute_raw_concept_features(trade_date, concepts, all_stocks) + + if not df.empty: + all_data.append(df) + + if not all_data: + 
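+        # neither the raw parquet cache nor a fresh computation produced rows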
print("错误:无数据") + return + + combined = pd.concat(all_data, ignore_index=True) + print(f"总数据量: {len(combined):,} 条") + + # 按概念计算基线 + baselines = {} + + for concept_id, group in tqdm(combined.groupby('concept_id'), desc="计算基线"): + baseline_dict = {} + + for time_slot, slot_group in group.groupby('time_slot'): + if len(slot_group) < CONFIG['min_baseline_samples']: + continue + + alpha_std = slot_group['alpha'].std() + amt_std = slot_group['total_amt'].std() + rank_std = slot_group['rank_pct'].std() + + baseline_dict[time_slot] = { + 'alpha_mean': float(slot_group['alpha'].mean()), + 'alpha_std': float(max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1)), + 'amt_mean': float(slot_group['total_amt'].mean()), + 'amt_std': float(max(amt_std if pd.notna(amt_std) else slot_group['total_amt'].mean() * 0.5, 1.0)), + 'rank_mean': float(slot_group['rank_pct'].mean()), + 'rank_std': float(max(rank_std if pd.notna(rank_std) else 0.2, 0.05)), + 'sample_count': len(slot_group), + } + + if baseline_dict: + baselines[concept_id] = baseline_dict + + print(f"计算了 {len(baselines)} 个概念的基线") + + # 保存 + os.makedirs(BASELINE_DIR, exist_ok=True) + baseline_file = os.path.join(BASELINE_DIR, 'realtime_baseline.pkl') + + with open(baseline_file, 'wb') as f: + pickle.dump({ + 'baselines': baselines, + 'update_time': datetime.now().isoformat(), + 'date_range': [recent_days[0], recent_days[-1]], + 'baseline_days': baseline_days, + }, f) + + print(f"基线已保存: {baseline_file}") + print("=" * 60) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--days', type=int, default=20, help='基线天数') + args = parser.parse_args() + + update_rolling_baseline(args.days) diff --git a/src/views/StockOverview/components/FlexScreen/components/MiniTimelineChart.js b/src/views/StockOverview/components/FlexScreen/components/MiniTimelineChart.js new file mode 100644 index 00000000..46c15e58 --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/components/MiniTimelineChart.js @@ -0,0 +1,267 @@ +/** + * 迷你分时图组件 + * 用于灵活屏中显示证券的日内走势 + */ +import React, { useEffect, useRef, useState, useMemo } from 'react'; +import { Box, Spinner, Center, Text } from '@chakra-ui/react'; +import * as echarts from 'echarts'; + +/** + * 生成交易时间刻度(用于 X 轴) + * A股交易时间:9:30-11:30, 13:00-15:00 + */ +const generateTimeTicks = () => { + const ticks = []; + // 上午 + for (let h = 9; h <= 11; h++) { + for (let m = (h === 9 ? 
30 : 0); m < 60; m++) { + if (h === 11 && m > 30) break; + ticks.push(`${String(h).padStart(2, '0')}:${String(m).padStart(2, '0')}`); + } + } + // 下午 + for (let h = 13; h <= 15; h++) { + for (let m = 0; m < 60; m++) { + if (h === 15 && m > 0) break; + ticks.push(`${String(h).padStart(2, '0')}:${String(m).padStart(2, '0')}`); + } + } + return ticks; +}; + +const TIME_TICKS = generateTimeTicks(); + +/** + * MiniTimelineChart 组件 + * @param {Object} props + * @param {string} props.code - 证券代码 + * @param {boolean} props.isIndex - 是否为指数 + * @param {number} props.prevClose - 昨收价 + * @param {number} props.currentPrice - 当前价(实时) + * @param {number} props.height - 图表高度 + */ +const MiniTimelineChart = ({ + code, + isIndex = false, + prevClose, + currentPrice, + height = 120, +}) => { + const chartRef = useRef(null); + const chartInstance = useRef(null); + const [timelineData, setTimelineData] = useState([]); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + + // 获取分钟数据 + useEffect(() => { + if (!code) return; + + const fetchData = async () => { + setLoading(true); + setError(null); + + try { + const apiPath = isIndex + ? `/api/index/${code}/kline?type=minute` + : `/api/stock/${code}/kline?type=minute`; + + const response = await fetch(apiPath); + const result = await response.json(); + + if (result.success !== false && result.data) { + // 格式化数据 + const formatted = result.data.map(item => ({ + time: item.time || item.timestamp, + price: item.close || item.price, + })); + setTimelineData(formatted); + } else { + setError(result.error || '暂无数据'); + } + } catch (e) { + setError('加载失败'); + } finally { + setLoading(false); + } + }; + + fetchData(); + + // 交易时间内每分钟刷新 + const now = new Date(); + const hours = now.getHours(); + const minutes = now.getMinutes(); + const currentMinutes = hours * 60 + minutes; + const isTrading = (currentMinutes >= 570 && currentMinutes <= 690) || + (currentMinutes >= 780 && currentMinutes <= 900); + + let intervalId; + if (isTrading) { + intervalId = setInterval(fetchData, 60000); // 1分钟刷新 + } + + return () => { + if (intervalId) clearInterval(intervalId); + }; + }, [code, isIndex]); + + // 合并实时价格到数据中 + const chartData = useMemo(() => { + if (!timelineData.length) return []; + + const data = [...timelineData]; + + // 如果有实时价格,添加到最新点 + if (currentPrice && data.length > 0) { + const now = new Date(); + const timeStr = `${String(now.getHours()).padStart(2, '0')}:${String(now.getMinutes()).padStart(2, '0')}`; + const lastItem = data[data.length - 1]; + + // 如果实时价格的时间比最后一条数据新,添加新点 + if (lastItem.time !== timeStr) { + data.push({ time: timeStr, price: currentPrice }); + } else { + // 更新最后一条 + data[data.length - 1] = { ...lastItem, price: currentPrice }; + } + } + + return data; + }, [timelineData, currentPrice]); + + // 渲染图表 + useEffect(() => { + if (!chartRef.current || loading || !chartData.length) return; + + if (!chartInstance.current) { + chartInstance.current = echarts.init(chartRef.current); + } + + const baseLine = prevClose || chartData[0]?.price || 0; + + // 计算价格范围 + const prices = chartData.map(d => d.price).filter(p => p > 0); + const minPrice = Math.min(...prices, baseLine); + const maxPrice = Math.max(...prices, baseLine); + const range = Math.max(maxPrice - baseLine, baseLine - minPrice) * 1.1; + + // 准备数据 + const times = chartData.map(d => d.time); + const values = chartData.map(d => d.price); + + // 判断涨跌 + const lastPrice = values[values.length - 1] || baseLine; + const isUp = lastPrice >= baseLine; + + const option = { + 
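+      // Symmetric y-axis (illustrative numbers): with baseLine = 10.00 and
+      // prices spanning 9.80..10.50, range = max(0.50, 0.20) * 1.1 = 0.55,
+      // so the axis runs 9.45..10.55 and the dashed prev-close markLine sits
+      // exactly in the middle of the chart.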
grid: { + top: 5, + right: 5, + bottom: 5, + left: 5, + containLabel: false, + }, + xAxis: { + type: 'category', + data: times, + show: false, + boundaryGap: false, + }, + yAxis: { + type: 'value', + min: baseLine - range, + max: baseLine + range, + show: false, + }, + series: [ + { + type: 'line', + data: values, + smooth: false, + symbol: 'none', + lineStyle: { + width: 1.5, + color: isUp ? '#ef4444' : '#22c55e', + }, + areaStyle: { + color: { + type: 'linear', + x: 0, + y: 0, + x2: 0, + y2: 1, + colorStops: [ + { offset: 0, color: isUp ? 'rgba(239, 68, 68, 0.3)' : 'rgba(34, 197, 94, 0.3)' }, + { offset: 1, color: isUp ? 'rgba(239, 68, 68, 0.05)' : 'rgba(34, 197, 94, 0.05)' }, + ], + }, + }, + markLine: { + silent: true, + symbol: 'none', + data: [ + { + yAxis: baseLine, + lineStyle: { + color: '#666', + type: 'dashed', + width: 1, + }, + label: { show: false }, + }, + ], + }, + }, + ], + animation: false, + }; + + chartInstance.current.setOption(option); + + return () => { + // 不在这里销毁,只在组件卸载时销毁 + }; + }, [chartData, prevClose, loading]); + + // 组件卸载时销毁图表 + useEffect(() => { + return () => { + if (chartInstance.current) { + chartInstance.current.dispose(); + chartInstance.current = null; + } + }; + }, []); + + // 窗口 resize 处理 + useEffect(() => { + const handleResize = () => { + chartInstance.current?.resize(); + }; + window.addEventListener('resize', handleResize); + return () => window.removeEventListener('resize', handleResize); + }, []); + + if (loading) { + return ( +
+      <Center h={`${height}px`}>
+        <Spinner />
+      </Center>
+ ); + } + + if (error || !chartData.length) { + return ( +
+      <Center h={`${height}px`}>
+        <Text fontSize="xs" color="gray.500">
+          {error || '暂无数据'}
+        </Text>
+      </Center>
+ ); + } + + return ; +}; + +export default MiniTimelineChart; diff --git a/src/views/StockOverview/components/FlexScreen/components/OrderBookPanel.js b/src/views/StockOverview/components/FlexScreen/components/OrderBookPanel.js new file mode 100644 index 00000000..a3cf92fc --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/components/OrderBookPanel.js @@ -0,0 +1,275 @@ +/** + * 盘口行情面板组件 + * 支持显示 5 档或 10 档买卖盘数据 + * + * 上交所: 5 档行情 + * 深交所: 10 档行情 + */ +import React, { useState } from 'react'; +import { + Box, + VStack, + HStack, + Text, + Button, + ButtonGroup, + useColorModeValue, + Tooltip, + Badge, +} from '@chakra-ui/react'; + +/** + * 格式化成交量 + * @param {number} volume - 成交量(股) + * @returns {string} 格式化后的字符串 + */ +const formatVolume = (volume) => { + if (!volume || volume === 0) return '-'; + if (volume >= 10000) { + return `${(volume / 10000).toFixed(0)}万`; + } + if (volume >= 1000) { + return `${(volume / 1000).toFixed(1)}k`; + } + return String(volume); +}; + +/** + * 格式化价格 + * @param {number} price - 价格 + * @param {number} prevClose - 昨收价 + * @returns {Object} { text, color } + */ +const formatPrice = (price, prevClose) => { + if (!price || price === 0) { + return { text: '-', color: 'gray.400' }; + } + + const text = price.toFixed(2); + + if (!prevClose || prevClose === 0) { + return { text, color: 'gray.600' }; + } + + if (price > prevClose) { + return { text, color: 'red.500' }; + } + if (price < prevClose) { + return { text, color: 'green.500' }; + } + return { text, color: 'gray.600' }; +}; + +/** + * 单行盘口 + */ +const OrderRow = ({ label, price, volume, prevClose, isBid, maxVolume, isLimitPrice }) => { + const bgColor = useColorModeValue( + isBid ? 'red.50' : 'green.50', + isBid ? 'rgba(239, 68, 68, 0.1)' : 'rgba(34, 197, 94, 0.1)' + ); + const barColor = useColorModeValue( + isBid ? 'red.200' : 'green.200', + isBid ? 'rgba(239, 68, 68, 0.3)' : 'rgba(34, 197, 94, 0.3)' + ); + const limitColor = useColorModeValue('orange.500', 'orange.300'); + + const priceInfo = formatPrice(price, prevClose); + const volumeText = formatVolume(volume); + + // 计算成交量条宽度 + const barWidth = maxVolume > 0 ? Math.min((volume / maxVolume) * 100, 100) : 0; + + return ( + + {/* 成交量条 */} + + + {/* 内容 */} + + {label} + + + + {priceInfo.text} + + {isLimitPrice && ( + + + {isBid ? 
'跌' : '涨'} + + + )} + + + {volumeText} + + + ); +}; + +/** + * OrderBookPanel 组件 + * @param {Object} props + * @param {number[]} props.bidPrices - 买档价格(最多10档) + * @param {number[]} props.bidVolumes - 买档量 + * @param {number[]} props.askPrices - 卖档价格(最多10档) + * @param {number[]} props.askVolumes - 卖档量 + * @param {number} props.prevClose - 昨收价 + * @param {number} props.upperLimit - 涨停价 + * @param {number} props.lowerLimit - 跌停价 + * @param {number} props.defaultLevels - 默认显示档数(5 或 10) + */ +const OrderBookPanel = ({ + bidPrices = [], + bidVolumes = [], + askPrices = [], + askVolumes = [], + prevClose, + upperLimit, + lowerLimit, + defaultLevels = 5, +}) => { + const borderColor = useColorModeValue('gray.200', 'gray.700'); + const buttonBg = useColorModeValue('gray.100', 'gray.700'); + + // 可切换显示的档位数 + const maxAvailableLevels = Math.max(bidPrices.length, askPrices.length, 1); + const [showLevels, setShowLevels] = useState(Math.min(defaultLevels, maxAvailableLevels)); + + // 计算最大成交量(用于条形图比例) + const displayBidVolumes = bidVolumes.slice(0, showLevels); + const displayAskVolumes = askVolumes.slice(0, showLevels); + const allVolumes = [...displayBidVolumes, ...displayAskVolumes].filter(v => v > 0); + const maxVolume = allVolumes.length > 0 ? Math.max(...allVolumes) : 0; + + // 判断是否为涨跌停价 + const isUpperLimit = (price) => upperLimit && Math.abs(price - upperLimit) < 0.001; + const isLowerLimit = (price) => lowerLimit && Math.abs(price - lowerLimit) < 0.001; + + // 卖盘(从卖N到卖1,即价格从高到低) + const askRows = []; + for (let i = showLevels - 1; i >= 0; i--) { + askRows.push( + + ); + } + + // 买盘(从买1到买N,即价格从高到低) + const bidRows = []; + for (let i = 0; i < showLevels; i++) { + bidRows.push( + + ); + } + + // 没有数据时的提示 + const hasData = bidPrices.length > 0 || askPrices.length > 0; + + if (!hasData) { + return ( + + + 暂无盘口数据 + + + ); + } + + return ( + + {/* 档位切换(只有当有超过5档数据时才显示) */} + {maxAvailableLevels > 5 && ( + + + + + + + )} + + {/* 卖盘 */} + {askRows} + + {/* 分隔线 + 当前价信息 */} + + {prevClose && ( + + 昨收 {prevClose.toFixed(2)} + + )} + + + {/* 买盘 */} + {bidRows} + + {/* 涨跌停价信息 */} + {(upperLimit || lowerLimit) && ( + + {lowerLimit && 跌停 {lowerLimit.toFixed(2)}} + {upperLimit && 涨停 {upperLimit.toFixed(2)}} + + )} + + ); +}; + +export default OrderBookPanel; diff --git a/src/views/StockOverview/components/FlexScreen/components/QuoteTile.js b/src/views/StockOverview/components/FlexScreen/components/QuoteTile.js new file mode 100644 index 00000000..f5b5c389 --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/components/QuoteTile.js @@ -0,0 +1,270 @@ +/** + * 行情瓷砖组件 + * 单个证券的实时行情展示卡片,包含分时图和五档盘口 + */ +import React, { useState } from 'react'; +import { + Box, + VStack, + HStack, + Text, + IconButton, + Tooltip, + useColorModeValue, + Collapse, + Badge, + Flex, + Spacer, +} from '@chakra-ui/react'; +import { CloseIcon, ChevronDownIcon, ChevronUpIcon, ExternalLinkIcon } from '@chakra-ui/icons'; +import { useNavigate } from 'react-router-dom'; +import MiniTimelineChart from './MiniTimelineChart'; +import OrderBookPanel from './OrderBookPanel'; + +/** + * 格式化价格显示 + */ +const formatPrice = (price) => { + if (!price || isNaN(price)) return '-'; + return price.toFixed(2); +}; + +/** + * 格式化涨跌幅 + */ +const formatChangePct = (pct) => { + if (!pct || isNaN(pct)) return '0.00%'; + const sign = pct > 0 ? '+' : ''; + return `${sign}${pct.toFixed(2)}%`; +}; + +/** + * 格式化涨跌额 + */ +const formatChange = (change) => { + if (!change || isNaN(change)) return '-'; + const sign = change > 0 ? 
'+' : ''; + return `${sign}${change.toFixed(2)}`; +}; + +/** + * 格式化成交额 + */ +const formatAmount = (amount) => { + if (!amount || isNaN(amount)) return '-'; + if (amount >= 100000000) { + return `${(amount / 100000000).toFixed(2)}亿`; + } + if (amount >= 10000) { + return `${(amount / 10000).toFixed(0)}万`; + } + return amount.toFixed(0); +}; + +/** + * QuoteTile 组件 + * @param {Object} props + * @param {string} props.code - 证券代码 + * @param {string} props.name - 证券名称 + * @param {Object} props.quote - 实时行情数据 + * @param {boolean} props.isIndex - 是否为指数 + * @param {Function} props.onRemove - 移除回调 + */ +const QuoteTile = ({ + code, + name, + quote = {}, + isIndex = false, + onRemove, +}) => { + const navigate = useNavigate(); + const [expanded, setExpanded] = useState(true); + + // 颜色主题 + const cardBg = useColorModeValue('white', '#1a1a1a'); + const borderColor = useColorModeValue('gray.200', '#333'); + const hoverBorderColor = useColorModeValue('purple.300', '#666'); + const textColor = useColorModeValue('gray.800', 'white'); + const subTextColor = useColorModeValue('gray.500', 'gray.400'); + + // 涨跌色 + const { price, prevClose, change, changePct, amount } = quote; + const priceColor = useColorModeValue( + !prevClose || price === prevClose ? 'gray.800' : + price > prevClose ? 'red.500' : 'green.500', + !prevClose || price === prevClose ? 'gray.200' : + price > prevClose ? 'red.400' : 'green.400' + ); + + // 涨跌幅背景色 + const changeBgColor = useColorModeValue( + !changePct || changePct === 0 ? 'gray.100' : + changePct > 0 ? 'red.100' : 'green.100', + !changePct || changePct === 0 ? 'gray.700' : + changePct > 0 ? 'rgba(239, 68, 68, 0.2)' : 'rgba(34, 197, 94, 0.2)' + ); + + // 跳转到详情页 + const handleNavigate = () => { + if (isIndex) { + // 指数暂无详情页 + return; + } + navigate(`/company?scode=${code}`); + }; + + return ( + + {/* 头部 */} + setExpanded(!expanded)} + > + {/* 名称和代码 */} + + + { + e.stopPropagation(); + handleNavigate(); + }} + > + {name || code} + + {isIndex && ( + + 指数 + + )} + + + {code} + + + + {/* 价格信息 */} + + + {formatPrice(price)} + + + + {formatChangePct(changePct)} + + + {formatChange(change)} + + + + + {/* 操作按钮 */} + + : } + size="xs" + variant="ghost" + aria-label={expanded ? '收起' : '展开'} + onClick={(e) => { + e.stopPropagation(); + setExpanded(!expanded); + }} + /> + + } + size="xs" + variant="ghost" + colorScheme="red" + aria-label="移除" + onClick={(e) => { + e.stopPropagation(); + onRemove?.(code); + }} + /> + + + + + {/* 可折叠内容 */} + + + {/* 统计信息 */} + + + 昨收: + {formatPrice(prevClose)} + + + 今开: + {formatPrice(quote.open)} + + + 成交额: + {formatAmount(amount)} + + + + {/* 分时图 */} + + + + + {/* 盘口(指数没有盘口) */} + {!isIndex && ( + + + 盘口 {quote.bidPrices?.length > 5 ? 
'(10档)' : '(5档)'} + + + + )} + + + + ); +}; + +export default QuoteTile; diff --git a/src/views/StockOverview/components/FlexScreen/components/index.js b/src/views/StockOverview/components/FlexScreen/components/index.js new file mode 100644 index 00000000..46c2c9f5 --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/components/index.js @@ -0,0 +1,3 @@ +export { default as MiniTimelineChart } from './MiniTimelineChart'; +export { default as OrderBookPanel } from './OrderBookPanel'; +export { default as QuoteTile } from './QuoteTile'; diff --git a/src/views/StockOverview/components/FlexScreen/hooks/index.js b/src/views/StockOverview/components/FlexScreen/hooks/index.js new file mode 100644 index 00000000..492690a3 --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/hooks/index.js @@ -0,0 +1 @@ +export { useRealtimeQuote } from './useRealtimeQuote'; diff --git a/src/views/StockOverview/components/FlexScreen/hooks/useRealtimeQuote.js b/src/views/StockOverview/components/FlexScreen/hooks/useRealtimeQuote.js new file mode 100644 index 00000000..476563ed --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/hooks/useRealtimeQuote.js @@ -0,0 +1,692 @@ +/** + * 实时行情 Hook + * 管理上交所和深交所 WebSocket 连接,获取实时行情数据 + * + * 上交所 (SSE): ws://49.232.185.254:8765 - 需主动订阅,提供五档行情 + * 深交所 (SZSE): ws://222.128.1.157:8765 - 自动推送,提供十档行情 + * + * 深交所支持的数据类型 (category): + * - stock (300111): 股票快照,含10档买卖盘 + * - bond (300211): 债券快照 + * - afterhours_block (300611): 盘后定价大宗交易 + * - afterhours_trading (303711): 盘后定价交易 + * - hk_stock (306311): 港股快照(深港通) + * - index (309011): 指数快照 + * - volume_stats (309111): 成交量统计 + * - fund_nav (309211): 基金净值 + */ +import { useState, useEffect, useRef, useCallback } from 'react'; +import { logger } from '@utils/logger'; + +// WebSocket 地址配置 +const WS_CONFIG = { + SSE: 'ws://49.232.185.254:8765', // 上交所 + SZSE: 'ws://222.128.1.157:8765', // 深交所 +}; + +// 心跳间隔 (ms) +const HEARTBEAT_INTERVAL = 30000; + +// 重连间隔 (ms) +const RECONNECT_INTERVAL = 3000; + +/** + * 判断证券代码属于哪个交易所 + * @param {string} code - 证券代码(可带或不带后缀) + * @returns {'SSE'|'SZSE'} 交易所标识 + */ +const getExchange = (code) => { + const baseCode = code.split('.')[0]; + + // 6开头为上海股票 + if (baseCode.startsWith('6')) { + return 'SSE'; + } + + // 000开头的6位数可能是上证指数或深圳股票 + if (baseCode.startsWith('000') && baseCode.length === 6) { + // 000001-000999 是上证指数范围,但 000001 也是平安银行 + // 这里需要更精确的判断,暂时把 000 开头当深圳 + return 'SZSE'; + } + + // 399开头是深证指数 + if (baseCode.startsWith('399')) { + return 'SZSE'; + } + + // 0、3开头是深圳股票 + if (baseCode.startsWith('0') || baseCode.startsWith('3')) { + return 'SZSE'; + } + + // 5开头是上海 ETF + if (baseCode.startsWith('5')) { + return 'SSE'; + } + + // 1开头是深圳 ETF/债券 + if (baseCode.startsWith('1')) { + return 'SZSE'; + } + + // 默认上海 + return 'SSE'; +}; + +/** + * 标准化证券代码为无后缀格式 + */ +const normalizeCode = (code) => { + return code.split('.')[0]; +}; + +/** + * 从深交所 bids/asks 数组提取价格和量数组 + * @param {Array} orderBook - [{price, volume}, ...] 
+ * @returns {{ prices: number[], volumes: number[] }} + */ +const extractOrderBook = (orderBook) => { + if (!orderBook || !Array.isArray(orderBook)) { + return { prices: [], volumes: [] }; + } + const prices = orderBook.map(item => item.price || 0); + const volumes = orderBook.map(item => item.volume || 0); + return { prices, volumes }; +}; + +/** + * 实时行情 Hook + * @param {string[]} codes - 订阅的证券代码列表 + * @returns {Object} { quotes, connected, subscribe, unsubscribe } + */ +export const useRealtimeQuote = (codes = []) => { + // 行情数据 { [code]: QuoteData } + const [quotes, setQuotes] = useState({}); + // 连接状态 { SSE: boolean, SZSE: boolean } + const [connected, setConnected] = useState({ SSE: false, SZSE: false }); + + // WebSocket 实例引用 + const wsRefs = useRef({ SSE: null, SZSE: null }); + // 心跳定时器 + const heartbeatRefs = useRef({ SSE: null, SZSE: null }); + // 重连定时器 + const reconnectRefs = useRef({ SSE: null, SZSE: null }); + // 当前订阅的代码(按交易所分组) + const subscribedCodes = useRef({ SSE: new Set(), SZSE: new Set() }); + + /** + * 创建 WebSocket 连接 + */ + const createConnection = useCallback((exchange) => { + // 清理现有连接 + if (wsRefs.current[exchange]) { + wsRefs.current[exchange].close(); + } + + const ws = new WebSocket(WS_CONFIG[exchange]); + wsRefs.current[exchange] = ws; + + ws.onopen = () => { + logger.info('FlexScreen', `${exchange} WebSocket 已连接`); + setConnected(prev => ({ ...prev, [exchange]: true })); + + // 上交所需要主动订阅 + if (exchange === 'SSE') { + const codes = Array.from(subscribedCodes.current.SSE); + if (codes.length > 0) { + ws.send(JSON.stringify({ + action: 'subscribe', + channels: ['stock', 'index'], + codes: codes, + })); + } + } + + // 启动心跳 + startHeartbeat(exchange); + }; + + ws.onmessage = (event) => { + try { + const msg = JSON.parse(event.data); + handleMessage(exchange, msg); + } catch (e) { + logger.warn('FlexScreen', `${exchange} 消息解析失败`, e); + } + }; + + ws.onerror = (error) => { + logger.error('FlexScreen', `${exchange} WebSocket 错误`, error); + }; + + ws.onclose = () => { + logger.info('FlexScreen', `${exchange} WebSocket 断开`); + setConnected(prev => ({ ...prev, [exchange]: false })); + stopHeartbeat(exchange); + + // 自动重连 + scheduleReconnect(exchange); + }; + }, []); + + /** + * 处理 WebSocket 消息 + */ + const handleMessage = useCallback((exchange, msg) => { + // 处理 pong + if (msg.type === 'pong') { + return; + } + + if (exchange === 'SSE') { + // 上交所消息格式 + if (msg.type === 'stock' || msg.type === 'index') { + const data = msg.data || {}; + setQuotes(prev => { + const updated = { ...prev }; + Object.entries(data).forEach(([code, quote]) => { + // 只更新订阅的代码 + if (subscribedCodes.current.SSE.has(code)) { + updated[code] = { + code: quote.security_id, + name: quote.security_name, + price: quote.last_price, + prevClose: quote.prev_close, + open: quote.open_price, + high: quote.high_price, + low: quote.low_price, + volume: quote.volume, + amount: quote.amount, + change: quote.last_price - quote.prev_close, + changePct: quote.prev_close ? 
((quote.last_price - quote.prev_close) / quote.prev_close * 100) : 0, + bidPrices: quote.bid_prices || [], + bidVolumes: quote.bid_volumes || [], + askPrices: quote.ask_prices || [], + askVolumes: quote.ask_volumes || [], + updateTime: quote.trade_time, + exchange: 'SSE', + }; + } + }); + return updated; + }); + } + } else if (exchange === 'SZSE') { + // 深交所消息格式(更新后的 API) + if (msg.type === 'realtime') { + const { category, data } = msg; + const code = data.security_id; + + // 只更新订阅的代码 + if (!subscribedCodes.current.SZSE.has(code)) { + return; + } + + if (category === 'stock') { + // 股票行情 - 含 10 档买卖盘 + const { prices: bidPrices, volumes: bidVolumes } = extractOrderBook(data.bids); + const { prices: askPrices, volumes: askVolumes } = extractOrderBook(data.asks); + + setQuotes(prev => ({ + ...prev, + [code]: { + code: code, + name: prev[code]?.name || '', + price: data.last_px, + prevClose: data.prev_close, + open: data.open_px, + high: data.high_px, + low: data.low_px, + volume: data.volume, + amount: data.amount, + numTrades: data.num_trades, + upperLimit: data.upper_limit, // 涨停价 + lowerLimit: data.lower_limit, // 跌停价 + change: data.last_px - data.prev_close, + changePct: data.prev_close ? ((data.last_px - data.prev_close) / data.prev_close * 100) : 0, + bidPrices, + bidVolumes, + askPrices, + askVolumes, + tradingPhase: data.trading_phase, + updateTime: msg.timestamp, + exchange: 'SZSE', + }, + })); + } else if (category === 'index') { + // 指数行情 + setQuotes(prev => ({ + ...prev, + [code]: { + code: code, + name: prev[code]?.name || '', + price: data.current_index, + prevClose: data.prev_close, + open: data.open_index, + high: data.high_index, + low: data.low_index, + close: data.close_index, + volume: data.volume, + amount: data.amount, + numTrades: data.num_trades, + change: data.current_index - data.prev_close, + changePct: data.prev_close ? ((data.current_index - data.prev_close) / data.prev_close * 100) : 0, + bidPrices: [], + bidVolumes: [], + askPrices: [], + askVolumes: [], + tradingPhase: data.trading_phase, + updateTime: msg.timestamp, + exchange: 'SZSE', + }, + })); + } else if (category === 'bond') { + // 债券行情 + setQuotes(prev => ({ + ...prev, + [code]: { + code: code, + name: prev[code]?.name || '', + price: data.last_px, + prevClose: data.prev_close, + open: data.open_px, + high: data.high_px, + low: data.low_px, + volume: data.volume, + amount: data.amount, + numTrades: data.num_trades, + weightedAvgPx: data.weighted_avg_px, + change: data.last_px - data.prev_close, + changePct: data.prev_close ? ((data.last_px - data.prev_close) / data.prev_close * 100) : 0, + bidPrices: [], + bidVolumes: [], + askPrices: [], + askVolumes: [], + tradingPhase: data.trading_phase, + updateTime: msg.timestamp, + exchange: 'SZSE', + isBond: true, + }, + })); + } else if (category === 'hk_stock') { + // 港股行情(深港通) + const { prices: bidPrices, volumes: bidVolumes } = extractOrderBook(data.bids); + const { prices: askPrices, volumes: askVolumes } = extractOrderBook(data.asks); + + setQuotes(prev => ({ + ...prev, + [code]: { + code: code, + name: prev[code]?.name || '', + price: data.last_px, + prevClose: data.prev_close, + open: data.open_px, + high: data.high_px, + low: data.low_px, + volume: data.volume, + amount: data.amount, + numTrades: data.num_trades, + nominalPx: data.nominal_px, // 按盘价 + referencePx: data.reference_px, // 参考价 + change: data.last_px - data.prev_close, + changePct: data.prev_close ? 
((data.last_px - data.prev_close) / data.prev_close * 100) : 0, + bidPrices, + bidVolumes, + askPrices, + askVolumes, + tradingPhase: data.trading_phase, + updateTime: msg.timestamp, + exchange: 'SZSE', + isHK: true, + }, + })); + } else if (category === 'afterhours_block' || category === 'afterhours_trading') { + // 盘后交易 + setQuotes(prev => ({ + ...prev, + [code]: { + ...prev[code], + afterhours: { + bidPx: data.bid_px, + bidSize: data.bid_size, + offerPx: data.offer_px, + offerSize: data.offer_size, + volume: data.volume, + amount: data.amount, + numTrades: data.num_trades, + }, + updateTime: msg.timestamp, + }, + })); + } + // fund_nav 和 volume_stats 暂不处理 + } else if (msg.type === 'snapshot') { + // 深交所初始快照 + const { stocks = [], indexes = [], bonds = [] } = msg.data || {}; + setQuotes(prev => { + const updated = { ...prev }; + + stocks.forEach(s => { + if (subscribedCodes.current.SZSE.has(s.security_id)) { + const { prices: bidPrices, volumes: bidVolumes } = extractOrderBook(s.bids); + const { prices: askPrices, volumes: askVolumes } = extractOrderBook(s.asks); + + updated[s.security_id] = { + code: s.security_id, + name: s.security_name || '', + price: s.last_px, + prevClose: s.prev_close, + open: s.open_px, + high: s.high_px, + low: s.low_px, + volume: s.volume, + amount: s.amount, + numTrades: s.num_trades, + upperLimit: s.upper_limit, + lowerLimit: s.lower_limit, + change: s.last_px - s.prev_close, + changePct: s.prev_close ? ((s.last_px - s.prev_close) / s.prev_close * 100) : 0, + bidPrices, + bidVolumes, + askPrices, + askVolumes, + exchange: 'SZSE', + }; + } + }); + + indexes.forEach(i => { + if (subscribedCodes.current.SZSE.has(i.security_id)) { + updated[i.security_id] = { + code: i.security_id, + name: i.security_name || '', + price: i.current_index, + prevClose: i.prev_close, + open: i.open_index, + high: i.high_index, + low: i.low_index, + volume: i.volume, + amount: i.amount, + numTrades: i.num_trades, + change: i.current_index - i.prev_close, + changePct: i.prev_close ? ((i.current_index - i.prev_close) / i.prev_close * 100) : 0, + bidPrices: [], + bidVolumes: [], + askPrices: [], + askVolumes: [], + exchange: 'SZSE', + }; + } + }); + + bonds.forEach(b => { + if (subscribedCodes.current.SZSE.has(b.security_id)) { + updated[b.security_id] = { + code: b.security_id, + name: b.security_name || '', + price: b.last_px, + prevClose: b.prev_close, + open: b.open_px, + high: b.high_px, + low: b.low_px, + volume: b.volume, + amount: b.amount, + change: b.last_px - b.prev_close, + changePct: b.prev_close ? 
((b.last_px - b.prev_close) / b.prev_close * 100) : 0, + bidPrices: [], + bidVolumes: [], + askPrices: [], + askVolumes: [], + exchange: 'SZSE', + isBond: true, + }; + } + }); + + return updated; + }); + } + } + }, []); + + /** + * 启动心跳 + */ + const startHeartbeat = useCallback((exchange) => { + stopHeartbeat(exchange); + heartbeatRefs.current[exchange] = setInterval(() => { + const ws = wsRefs.current[exchange]; + if (ws && ws.readyState === WebSocket.OPEN) { + if (exchange === 'SSE') { + ws.send(JSON.stringify({ action: 'ping' })); + } else { + ws.send(JSON.stringify({ type: 'ping' })); + } + } + }, HEARTBEAT_INTERVAL); + }, []); + + /** + * 停止心跳 + */ + const stopHeartbeat = useCallback((exchange) => { + if (heartbeatRefs.current[exchange]) { + clearInterval(heartbeatRefs.current[exchange]); + heartbeatRefs.current[exchange] = null; + } + }, []); + + /** + * 安排重连 + */ + const scheduleReconnect = useCallback((exchange) => { + if (reconnectRefs.current[exchange]) { + return; // 已有重连计划 + } + + reconnectRefs.current[exchange] = setTimeout(() => { + reconnectRefs.current[exchange] = null; + // 只有还有订阅的代码才重连 + if (subscribedCodes.current[exchange].size > 0) { + createConnection(exchange); + } + }, RECONNECT_INTERVAL); + }, [createConnection]); + + /** + * 订阅证券 + */ + const subscribe = useCallback((code) => { + const baseCode = normalizeCode(code); + const exchange = getExchange(code); + + // 添加到订阅列表 + subscribedCodes.current[exchange].add(baseCode); + + // 如果连接已建立,发送订阅消息(仅上交所需要) + const ws = wsRefs.current[exchange]; + if (exchange === 'SSE' && ws && ws.readyState === WebSocket.OPEN) { + ws.send(JSON.stringify({ + action: 'subscribe', + channels: ['stock', 'index'], + codes: [baseCode], + })); + } + + // 如果连接未建立,创建连接 + if (!ws || ws.readyState !== WebSocket.OPEN) { + createConnection(exchange); + } + }, [createConnection]); + + /** + * 取消订阅 + */ + const unsubscribe = useCallback((code) => { + const baseCode = normalizeCode(code); + const exchange = getExchange(code); + + // 从订阅列表移除 + subscribedCodes.current[exchange].delete(baseCode); + + // 从 quotes 中移除 + setQuotes(prev => { + const updated = { ...prev }; + delete updated[baseCode]; + return updated; + }); + + // 如果该交易所没有订阅了,关闭连接 + if (subscribedCodes.current[exchange].size === 0) { + const ws = wsRefs.current[exchange]; + if (ws) { + ws.close(); + wsRefs.current[exchange] = null; + } + } + }, []); + + /** + * 初始化订阅 + */ + useEffect(() => { + if (!codes || codes.length === 0) { + return; + } + + // 按交易所分组 + const sseCodesSet = new Set(); + const szseCodesSet = new Set(); + + codes.forEach(code => { + const baseCode = normalizeCode(code); + const exchange = getExchange(code); + if (exchange === 'SSE') { + sseCodesSet.add(baseCode); + } else { + szseCodesSet.add(baseCode); + } + }); + + // 更新订阅列表 + subscribedCodes.current.SSE = sseCodesSet; + subscribedCodes.current.SZSE = szseCodesSet; + + // 建立连接 + if (sseCodesSet.size > 0) { + createConnection('SSE'); + } + if (szseCodesSet.size > 0) { + createConnection('SZSE'); + } + + // 清理 + return () => { + ['SSE', 'SZSE'].forEach(exchange => { + stopHeartbeat(exchange); + if (reconnectRefs.current[exchange]) { + clearTimeout(reconnectRefs.current[exchange]); + } + const ws = wsRefs.current[exchange]; + if (ws) { + ws.close(); + } + }); + }; + }, []); // 只在挂载时执行 + + /** + * 处理 codes 变化 + */ + useEffect(() => { + if (!codes) return; + + // 计算新的订阅列表 + const newSseCodes = new Set(); + const newSzseCodes = new Set(); + + codes.forEach(code => { + const baseCode = normalizeCode(code); + const exchange = 
getExchange(code); + if (exchange === 'SSE') { + newSseCodes.add(baseCode); + } else { + newSzseCodes.add(baseCode); + } + }); + + // 找出需要新增和删除的代码 + const oldSseCodes = subscribedCodes.current.SSE; + const oldSzseCodes = subscribedCodes.current.SZSE; + + // 更新上交所订阅 + const sseToAdd = [...newSseCodes].filter(c => !oldSseCodes.has(c)); + const sseToRemove = [...oldSseCodes].filter(c => !newSseCodes.has(c)); + + if (sseToAdd.length > 0 || sseToRemove.length > 0) { + subscribedCodes.current.SSE = newSseCodes; + + const ws = wsRefs.current.SSE; + if (ws && ws.readyState === WebSocket.OPEN && sseToAdd.length > 0) { + ws.send(JSON.stringify({ + action: 'subscribe', + channels: ['stock', 'index'], + codes: sseToAdd, + })); + } + + // 如果新增了代码但连接未建立 + if (sseToAdd.length > 0 && (!ws || ws.readyState !== WebSocket.OPEN)) { + createConnection('SSE'); + } + + // 如果没有订阅了,关闭连接 + if (newSseCodes.size === 0 && ws) { + ws.close(); + wsRefs.current.SSE = null; + } + } + + // 更新深交所订阅 + const szseToAdd = [...newSzseCodes].filter(c => !oldSzseCodes.has(c)); + const szseToRemove = [...oldSzseCodes].filter(c => !newSzseCodes.has(c)); + + if (szseToAdd.length > 0 || szseToRemove.length > 0) { + subscribedCodes.current.SZSE = newSzseCodes; + + // 深交所是自动推送,只需要管理连接 + const ws = wsRefs.current.SZSE; + + if (szseToAdd.length > 0 && (!ws || ws.readyState !== WebSocket.OPEN)) { + createConnection('SZSE'); + } + + if (newSzseCodes.size === 0 && ws) { + ws.close(); + wsRefs.current.SZSE = null; + } + } + + // 清理已取消订阅的 quotes + const removedCodes = [...sseToRemove, ...szseToRemove]; + if (removedCodes.length > 0) { + setQuotes(prev => { + const updated = { ...prev }; + removedCodes.forEach(code => { + delete updated[code]; + }); + return updated; + }); + } + }, [codes, createConnection]); + + return { + quotes, + connected, + subscribe, + unsubscribe, + }; +}; + +export default useRealtimeQuote; diff --git a/src/views/StockOverview/components/FlexScreen/index.js b/src/views/StockOverview/components/FlexScreen/index.js new file mode 100644 index 00000000..aeace485 --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/index.js @@ -0,0 +1,463 @@ +/** + * 灵活屏组件 + * 用户可自定义添加关注的指数/个股,实时显示行情 + * + * 功能: + * 1. 添加/删除自选证券 + * 2. 显示实时行情(通过 WebSocket) + * 3. 显示分时走势(结合 ClickHouse 历史数据) + * 4. 显示五档盘口(上交所完整五档,深交所买一卖一) + * 5. 
本地存储自选列表 + */ +import React, { useState, useEffect, useCallback, useMemo } from 'react'; +import { + Box, + Card, + CardBody, + VStack, + HStack, + Heading, + Text, + Input, + InputGroup, + InputLeftElement, + InputRightElement, + IconButton, + Button, + SimpleGrid, + Flex, + Spacer, + Icon, + useColorModeValue, + useToast, + Badge, + Tooltip, + Collapse, + List, + ListItem, + Spinner, + Center, + Menu, + MenuButton, + MenuList, + MenuItem, + Divider, + Tag, + TagLabel, +} from '@chakra-ui/react'; +import { + SearchIcon, + CloseIcon, + AddIcon, + ChevronDownIcon, + ChevronUpIcon, + SettingsIcon, +} from '@chakra-ui/icons'; +import { FaDesktop, FaPlus, FaTrash, FaSync, FaWifi, FaExclamationCircle } from 'react-icons/fa'; + +import { useRealtimeQuote } from './hooks'; +import { QuoteTile } from './components'; +import { logger } from '@utils/logger'; + +// 本地存储 key +const STORAGE_KEY = 'flexscreen_watchlist'; + +// 默认自选列表 +const DEFAULT_WATCHLIST = [ + { code: '000001', name: '上证指数', isIndex: true }, + { code: '399001', name: '深证成指', isIndex: true }, + { code: '399006', name: '创业板指', isIndex: true }, +]; + +// 热门推荐 +const HOT_RECOMMENDATIONS = [ + { code: '000001', name: '上证指数', isIndex: true }, + { code: '399001', name: '深证成指', isIndex: true }, + { code: '399006', name: '创业板指', isIndex: true }, + { code: '399300', name: '沪深300', isIndex: true }, + { code: '600519', name: '贵州茅台', isIndex: false }, + { code: '000858', name: '五粮液', isIndex: false }, + { code: '300750', name: '宁德时代', isIndex: false }, + { code: '002594', name: '比亚迪', isIndex: false }, +]; + +/** + * FlexScreen 组件 + */ +const FlexScreen = () => { + const toast = useToast(); + + // 自选列表 + const [watchlist, setWatchlist] = useState([]); + // 搜索状态 + const [searchQuery, setSearchQuery] = useState(''); + const [searchResults, setSearchResults] = useState([]); + const [isSearching, setIsSearching] = useState(false); + const [showResults, setShowResults] = useState(false); + // 面板状态 + const [isCollapsed, setIsCollapsed] = useState(false); + + // 颜色主题 + const cardBg = useColorModeValue('white', '#1a1a1a'); + const borderColor = useColorModeValue('gray.200', '#333'); + const textColor = useColorModeValue('gray.800', 'white'); + const subTextColor = useColorModeValue('gray.600', 'gray.400'); + const searchBg = useColorModeValue('gray.50', '#2a2a2a'); + const hoverBg = useColorModeValue('gray.100', '#333'); + + // 获取订阅的证券代码列表 + const subscribedCodes = useMemo(() => { + return watchlist.map(item => item.code); + }, [watchlist]); + + // WebSocket 实时行情 + const { quotes, connected } = useRealtimeQuote(subscribedCodes); + + // 从本地存储加载自选列表 + useEffect(() => { + try { + const saved = localStorage.getItem(STORAGE_KEY); + if (saved) { + const parsed = JSON.parse(saved); + if (Array.isArray(parsed) && parsed.length > 0) { + setWatchlist(parsed); + return; + } + } + } catch (e) { + logger.warn('FlexScreen', '加载自选列表失败', e); + } + // 使用默认列表 + setWatchlist(DEFAULT_WATCHLIST); + }, []); + + // 保存自选列表到本地存储 + useEffect(() => { + if (watchlist.length > 0) { + try { + localStorage.setItem(STORAGE_KEY, JSON.stringify(watchlist)); + } catch (e) { + logger.warn('FlexScreen', '保存自选列表失败', e); + } + } + }, [watchlist]); + + // 搜索证券 + const searchSecurities = useCallback(async (query) => { + if (!query.trim()) { + setSearchResults([]); + setShowResults(false); + return; + } + + setIsSearching(true); + try { + const response = await fetch(`/api/stocks/search?q=${encodeURIComponent(query)}&limit=10`); + const data = await response.json(); + + if (data.success) { + 
setSearchResults(data.data || []); + setShowResults(true); + } else { + setSearchResults([]); + } + } catch (e) { + logger.error('FlexScreen', '搜索失败', e); + setSearchResults([]); + } finally { + setIsSearching(false); + } + }, []); + + // 防抖搜索 + useEffect(() => { + const timer = setTimeout(() => { + searchSecurities(searchQuery); + }, 300); + return () => clearTimeout(timer); + }, [searchQuery, searchSecurities]); + + // 添加证券 + const addSecurity = useCallback((security) => { + const code = security.stock_code || security.code; + const name = security.stock_name || security.name; + const isIndex = security.isIndex || code.startsWith('000') || code.startsWith('399'); + + // 检查是否已存在 + if (watchlist.some(item => item.code === code)) { + toast({ + title: '已在自选列表中', + status: 'info', + duration: 2000, + isClosable: true, + }); + return; + } + + // 添加到列表 + setWatchlist(prev => [...prev, { code, name, isIndex }]); + + toast({ + title: `已添加 ${name}`, + status: 'success', + duration: 2000, + isClosable: true, + }); + + // 清空搜索 + setSearchQuery(''); + setShowResults(false); + }, [watchlist, toast]); + + // 移除证券 + const removeSecurity = useCallback((code) => { + setWatchlist(prev => prev.filter(item => item.code !== code)); + }, []); + + // 清空自选列表 + const clearWatchlist = useCallback(() => { + setWatchlist([]); + localStorage.removeItem(STORAGE_KEY); + toast({ + title: '已清空自选列表', + status: 'info', + duration: 2000, + isClosable: true, + }); + }, [toast]); + + // 重置为默认列表 + const resetWatchlist = useCallback(() => { + setWatchlist(DEFAULT_WATCHLIST); + toast({ + title: '已重置为默认列表', + status: 'success', + duration: 2000, + isClosable: true, + }); + }, [toast]); + + // 连接状态指示 + const isAnyConnected = connected.SSE || connected.SZSE; + const connectionStatus = useMemo(() => { + if (connected.SSE && connected.SZSE) { + return { color: 'green', text: '上交所/深交所 已连接' }; + } + if (connected.SSE) { + return { color: 'yellow', text: '上交所 已连接' }; + } + if (connected.SZSE) { + return { color: 'yellow', text: '深交所 已连接' }; + } + return { color: 'red', text: '未连接' }; + }, [connected]); + + return ( + + + {/* 头部 */} + + + + + 灵活屏 + + + + + {isAnyConnected ? '实时' : '离线'} + + + + + + {/* 操作菜单 */} + + } + size="sm" + variant="ghost" + aria-label="设置" + /> + + } onClick={resetWatchlist}> + 重置为默认 + + } onClick={clearWatchlist} color="red.500"> + 清空列表 + + + + {/* 折叠按钮 */} + : } + size="sm" + variant="ghost" + onClick={() => setIsCollapsed(!isCollapsed)} + aria-label={isCollapsed ? '展开' : '收起'} + /> + + + + {/* 可折叠内容 */} + + {/* 搜索框 */} + + + + + + setSearchQuery(e.target.value)} + bg={searchBg} + borderRadius="lg" + _focus={{ + borderColor: 'purple.400', + boxShadow: '0 0 0 1px var(--chakra-colors-purple-400)', + }} + /> + {searchQuery && ( + + } + variant="ghost" + onClick={() => { + setSearchQuery(''); + setShowResults(false); + }} + aria-label="清空" + /> + + )} + + + {/* 搜索结果下拉 */} + + + {isSearching ? ( +
+                    <Center py={4}>
+                      <Spinner size="sm" />
+                    </Center>
+ ) : searchResults.length > 0 ? ( + + {searchResults.map((stock, index) => ( + addSecurity(stock)} + borderBottomWidth={index < searchResults.length - 1 ? '1px' : '0'} + borderColor={borderColor} + > + + + + {stock.stock_name} + + + {stock.stock_code} + + + } + size="xs" + colorScheme="purple" + variant="ghost" + aria-label="添加" + /> + + + ))} + + ) : ( +
+                    <Text fontSize="sm" color={subTextColor} textAlign="center" py={4}>
+                      未找到相关证券
+                    </Text>
+ )} +
+
+
+ + {/* 快捷添加 */} + {watchlist.length === 0 && ( + + + 热门推荐(点击添加) + + + {HOT_RECOMMENDATIONS.map((item) => ( + addSecurity(item)} + > + {item.name} + + ))} + + + )} + + {/* 自选列表 */} + {watchlist.length > 0 ? ( + + {watchlist.map((item) => ( + + ))} + + ) : ( +
+              <Center py={8} flexDirection="column">
+                <Icon as={FaExclamationCircle} boxSize={6} color="gray.400" mb={2} />
+                <Text fontSize="sm" color={subTextColor}>
+                  自选列表为空,请搜索添加证券
+                </Text>
+              </Center>
+ )} +
+
+
+ ); +}; + +export default FlexScreen; diff --git a/src/views/StockOverview/components/HotspotOverview/components/ConceptAlertList.js b/src/views/StockOverview/components/HotspotOverview/components/ConceptAlertList.js index 9c846122..1c375bce 100644 --- a/src/views/StockOverview/components/HotspotOverview/components/ConceptAlertList.js +++ b/src/views/StockOverview/components/HotspotOverview/components/ConceptAlertList.js @@ -18,6 +18,69 @@ import { import { FaBolt, FaArrowUp, FaArrowDown, FaChartLine, FaFire, FaVolumeUp } from 'react-icons/fa'; import { getAlertTypeLabel, formatScore, getScoreColor } from '../utils/chartHelpers'; +/** + * Z-Score 指示器组件 + */ +const ZScoreIndicator = ({ value, label, tooltip }) => { + if (value === null || value === undefined) return null; + + // Z-Score 颜色:越大越红,越小越绿 + const getZScoreColor = (z) => { + const absZ = Math.abs(z); + if (absZ >= 3) return z > 0 ? 'red.600' : 'green.600'; + if (absZ >= 2) return z > 0 ? 'red.500' : 'green.500'; + if (absZ >= 1) return z > 0 ? 'orange.400' : 'teal.400'; + return 'gray.400'; + }; + + // Z-Score 强度条宽度(最大 5σ) + const barWidth = Math.min(Math.abs(value) / 5 * 100, 100); + + return ( + + + {label} + + = 0 ? '50%' : `${50 - barWidth / 2}%`} + w={`${barWidth / 2}%`} + h="100%" + bg={getZScoreColor(value)} + borderRadius="full" + /> + + + {value >= 0 ? '+' : ''}{value.toFixed(1)} + + + + ); +}; + +/** + * 持续确认率指示器 + */ +const ConfirmRatioIndicator = ({ ratio }) => { + if (ratio === null || ratio === undefined) return null; + + const percent = Math.round(ratio * 100); + const color = percent >= 80 ? 'green' : percent >= 60 ? 'orange' : 'red'; + + return ( + + + + + + + {percent}% + + + + ); +}; + /** * 单个异动项组件 */ @@ -29,6 +92,7 @@ const AlertItem = ({ alert, onClick, isSelected }) => { const isUp = alert.alert_type !== 'surge_down'; const typeColor = isUp ? 'red' : 'green'; + const isV2 = alert.is_v2; // 获取异动类型图标 const getTypeIcon = (type) => { @@ -69,13 +133,30 @@ const AlertItem = ({ alert, onClick, isSelected }) => { {alert.concept_name} + {isV2 && ( + + V2 + + )} {alert.time} {getAlertTypeLabel(alert.alert_type)} + {/* V2: 持续确认率 */} + {isV2 && alert.confirm_ratio !== undefined && ( + + )} + + {/* V2: Z-Score 指标行 */} + {isV2 && (alert.alpha_zscore !== undefined || alert.amt_zscore !== undefined) && ( + + + + + )} {/* 右侧:分数和关键指标 */} @@ -104,12 +185,29 @@ const AlertItem = ({ alert, onClick, isSelected }) => { )} - {/* 涨停数量 */} - {alert.limit_up_count > 0 && ( + {/* V2: 动量指标 */} + {isV2 && alert.momentum_5m !== undefined && Math.abs(alert.momentum_5m) > 0.3 && ( + + 0 ? FaArrowUp : FaArrowDown} + color={alert.momentum_5m > 0 ? 'red.400' : 'green.400'} + boxSize={3} + /> + 0 ? 'red.400' : 'green.400'}> + 动量 {alert.momentum_5m > 0 ? '+' : ''}{alert.momentum_5m.toFixed(2)} + + + )} + + {/* 涨停数量 / 涨停比例 */} + {(alert.limit_up_count > 0 || (alert.limit_up_ratio > 0.05 && isV2)) && ( - 涨停 {alert.limit_up_count} + {alert.limit_up_count > 0 + ? 
`涨停 ${alert.limit_up_count}` + : `涨停 ${Math.round(alert.limit_up_ratio * 100)}%` + } )} diff --git a/src/views/StockOverview/index.js b/src/views/StockOverview/index.js index a888ec09..dacae01b 100644 --- a/src/views/StockOverview/index.js +++ b/src/views/StockOverview/index.js @@ -54,6 +54,7 @@ import { FaChartLine, FaFire, FaRocket, FaBrain, FaCalendarAlt, FaChevronRight, import ConceptStocksModal from '@components/ConceptStocksModal'; import TradeDatePicker from '@components/TradeDatePicker'; import HotspotOverview from './components/HotspotOverview'; +import FlexScreen from './components/FlexScreen'; import { BsGraphUp, BsLightningFill } from 'react-icons/bs'; import * as echarts from 'echarts'; import { logger } from '../../utils/logger'; @@ -846,6 +847,11 @@ const StockOverview = () => {
+          {/* 灵活屏 - 实时行情监控 */}
+          <Box mb={6}>
+            <FlexScreen />
+          </Box>
+
           {/* 今日热门概念 */}
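
Some notes and standalone sketches on the new FlexScreen code follow; none of the snippets below are part of the patch.

getExchange in useRealtimeQuote.js is a pure prefix heuristic, and its own comment flags the ambiguous 000xxx range (000001 is both the SSE composite index and 平安银行; the hook currently routes it to SZSE, so the default 上证指数 tile is served by the Shenzhen socket). A minimal sketch of the rule, with a hypothetical isIndex flag showing one possible disambiguation:

// Sketch of getExchange's prefix routing, illustrative only. The isIndex
// flag is hypothetical (not in the patch) and shows one way to disambiguate
// the 000xxx range.
const routeExchange = (code, { isIndex = false } = {}) => {
  const base = code.split('.')[0];
  if (base.startsWith('6') || base.startsWith('5')) return 'SSE'; // SH stocks / SH ETFs
  if (isIndex && base.startsWith('000')) return 'SSE';            // SSE index family
  return 'SZSE';                                                  // 0/1/3/399: Shenzhen (hook's default for 000xxx)
};

console.log(routeExchange('600519'));                    // SSE  (stock)
console.log(routeExchange('000001'));                    // SZSE (matches getExchange today)
console.log(routeExchange('000001', { isIndex: true })); // SSE  (上证指数)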
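
The two venue sockets also speak different framings: SSE requires an explicit subscribe and pings with {action: 'ping'}, while SZSE pushes everything and pings with {type: 'ping'}. The lifecycle the hook implements per exchange condenses to roughly the following (payload shapes and intervals are copied from the hook; openFeed is an illustrative name, and a browser WebSocket global is assumed):

// Per-exchange message framing as used by the hook; openFeed is illustrative.
const FRAMES = {
  SSE: {
    ping: () => JSON.stringify({ action: 'ping' }),
    subscribe: (codes) =>
      JSON.stringify({ action: 'subscribe', channels: ['stock', 'index'], codes }),
  },
  SZSE: {
    ping: () => JSON.stringify({ type: 'ping' }),
    subscribe: () => null, // SZSE pushes automatically; nothing to send on open
  },
};

function openFeed(exchange, url, codes, onMessage) {
  const ws = new WebSocket(url);
  let heartbeat = null;
  ws.onopen = () => {
    const sub = FRAMES[exchange].subscribe(codes);
    if (sub) ws.send(sub); // only the SSE feed needs an explicit subscribe
    heartbeat = setInterval(() => {
      if (ws.readyState === WebSocket.OPEN) ws.send(FRAMES[exchange].ping());
    }, 30000); // same 30s cadence as HEARTBEAT_INTERVAL
  };
  ws.onmessage = (event) => onMessage(JSON.parse(event.data));
  ws.onclose = () => {
    clearInterval(heartbeat);
    // fixed 3s retry, mirroring RECONNECT_INTERVAL; the real hook also checks
    // that the venue still has subscribed codes before reconnecting
    setTimeout(() => openFeed(exchange, url, codes, onMessage), 3000);
  };
  return ws;
}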
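
Each venue-specific snapshot is normalized through the same two pure steps before it reaches React state, which makes them easy to test in isolation. These helpers mirror extractOrderBook and the changePct guard:

// Mirrors the hook's normalization: SZSE order books arrive as
// [{price, volume}, ...] and are flattened into parallel arrays; change
// percentage falls back to 0 when prev_close is missing or zero.
const extractOrderBook = (orderBook) => {
  if (!orderBook || !Array.isArray(orderBook)) return { prices: [], volumes: [] };
  return {
    prices: orderBook.map((l) => l.price || 0),
    volumes: orderBook.map((l) => l.volume || 0),
  };
};

const changePct = (last, prevClose) =>
  prevClose ? ((last - prevClose) / prevClose) * 100 : 0;

console.log(extractOrderBook([{ price: 10.01, volume: 1200 }, { price: 10.0, volume: 800 }]));
// { prices: [10.01, 10], volumes: [1200, 800] }
console.log(changePct(10.5, 10)); // 5
console.log(changePct(10.5, 0));  // 0, treated as flat rather than dividing by zero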
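
When the codes prop changes, the hook reconciles by set difference instead of resubscribing everything: additions produce an SSE subscribe frame (SZSE needs none), removals only prune local state, and a venue's socket is closed once its set is empty. The diff step itself reduces to:

// Set difference used by the codes-change effect (same logic, extracted).
const diffCodes = (oldSet, newSet) => ({
  toAdd: [...newSet].filter((c) => !oldSet.has(c)),
  toRemove: [...oldSet].filter((c) => !newSet.has(c)),
});

const { toAdd, toRemove } = diffCodes(
  new Set(['600519', '000001']),
  new Set(['600519', '601318'])
);
console.log(toAdd);    // ['601318'], send subscribe (SSE) or just track (SZSE)
console.log(toRemove); // ['000001'], drop from quotes state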
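
In OrderBookPanel.js, limit-up/limit-down detection compares floats, so it uses a 0.001 epsilon rather than strict equality, and it short-circuits when no limit price is published:

// Float-safe limit-price check from OrderBookPanel (extracted for clarity).
const isAtLimit = (price, limit) => Boolean(limit) && Math.abs(price - limit) < 0.001;

console.log(isAtLimit(11.0, 11.0000004)); // true: same tick despite float noise
console.log(isAtLimit(10.99, 11.0));      // false: one tick below the limit
console.log(isAtLimit(10.99, undefined)); // false: no limit price published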
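
Finally, the ZScoreIndicator added to ConceptAlertList.js anchors its bar at the 50% midpoint and clips magnitude at 5σ, extending right for positive z and left for negative z. The geometry reduces to:

// Bar geometry from ZScoreIndicator: |z| is scaled so that 5σ fills the track,
// and the filled segment is anchored at the 50% midpoint.
const zBar = (z) => {
  const barWidth = Math.min((Math.abs(z) / 5) * 100, 100);
  return {
    left: z >= 0 ? '50%' : `${50 - barWidth / 2}%`,
    w: `${barWidth / 2}%`,
  };
};

console.log(zBar(2.5));  // { left: '50%', w: '25%' } extends right
console.log(zBar(-2.5)); // { left: '25%', w: '25%' } extends left
console.log(zBar(9));    // { left: '50%', w: '50%' } clipped at 5σ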