#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
V2 backtest script - validates the effect of time-slice alignment + persistence confirmation.

Backtest metrics:
1. Accuracy: whether alpha keeps rising/falling within N minutes after an alert
2. False-alarm rate: how many alerts are just noise
3. Persistence: average duration of an alert
"""

import os
import sys
import json
import argparse
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple
from collections import defaultdict

import numpy as np
import pandas as pd
from tqdm import tqdm
from sqlalchemy import create_engine, text

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from ml.detector_v2 import AnomalyDetectorV2, CONFIG

# ==================== Configuration ====================

MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)

# ==================== Backtest evaluation ====================

def evaluate_alerts(
    alerts: List[Dict],
    raw_data: pd.DataFrame,
    lookahead_minutes: int = 10
) -> Dict:
    """
    Evaluate alert quality.

    Metrics:
    1. Direction accuracy: whether the alpha direction over the next N minutes matches the alert
    2. Sustained rate: fraction of the next N minutes in which alpha keeps the same sign
    3. Peak return: maximum alpha within the next N minutes
    """
    if not alerts:
        # include 'evaluated_alerts' so callers can check it without a KeyError
        return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0,
                'total_alerts': 0, 'evaluated_alerts': 0}

    results = []

    for alert in alerts:
        concept_id = alert['concept_id']
        alert_time = alert['alert_time']
        alert_alpha = alert['alpha']
        is_up = alert_alpha > 0

        # Data for this concept after the alert
        concept_data = raw_data[
            (raw_data['concept_id'] == concept_id) &
            (raw_data['timestamp'] > alert_time)
        ].head(lookahead_minutes)

        if len(concept_data) < 3:
            continue

        future_alphas = concept_data['alpha'].values

        # Direction correct: mean future alpha has the same sign as the alert
        avg_future_alpha = np.mean(future_alphas)
        direction_correct = (is_up and avg_future_alpha > 0) or (not is_up and avg_future_alpha < 0)

        # Sustained rate: how many future minutes keep the same sign
        if is_up:
            sustained_count = sum(1 for a in future_alphas if a > 0)
        else:
            sustained_count = sum(1 for a in future_alphas if a < 0)
        sustained_rate = sustained_count / len(future_alphas)

        # Peak return
        if is_up:
            peak = max(future_alphas)
        else:
            peak = min(future_alphas)

        results.append({
            'direction_correct': direction_correct,
            'sustained_rate': sustained_rate,
            'peak': peak,
            'alert_alpha': alert_alpha,
        })

    if not results:
        return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0,
                'total_alerts': 0, 'evaluated_alerts': 0}

    return {
        'accuracy': np.mean([r['direction_correct'] for r in results]),
        'sustained_rate': np.mean([r['sustained_rate'] for r in results]),
        'avg_peak': np.mean([abs(r['peak']) for r in results]),
        'total_alerts': len(alerts),
        'evaluated_alerts': len(results),
    }


def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Save alerts to MySQL."""
    if not alerts or dry_run:
        return 0

    # Ensure the table exists
    with MYSQL_ENGINE.begin() as conn:
        conn.execute(text("""
            CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
                id INT AUTO_INCREMENT PRIMARY KEY,
                concept_id VARCHAR(64) NOT NULL,
                alert_time DATETIME NOT NULL,
                trade_date DATE NOT NULL,
                alert_type VARCHAR(32) NOT NULL,
                final_score FLOAT NOT NULL,
                rule_score FLOAT NOT NULL,
                ml_score FLOAT NOT NULL,
                trigger_reason VARCHAR(128),
                confirm_ratio FLOAT,
                alpha FLOAT,
                alpha_zscore FLOAT,
                amt_zscore FLOAT,
                rank_zscore FLOAT,
                momentum_3m FLOAT,
                momentum_5m FLOAT,
                limit_up_ratio FLOAT,
                triggered_rules JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
                INDEX idx_trade_date (trade_date),
                INDEX idx_final_score (final_score)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
              COMMENT='Concept anomaly V2 (time-slice alignment + persistence confirmation)'
        """))

    # Insert rows
    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                conn.execute(text("""
                    INSERT IGNORE INTO concept_anomaly_v2
                    (concept_id, alert_time, trade_date, alert_type,
                     final_score, rule_score, ml_score, trigger_reason,
                     confirm_ratio, alpha, alpha_zscore, amt_zscore, rank_zscore,
                     momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
                    VALUES
                    (:concept_id, :alert_time, :trade_date, :alert_type,
                     :final_score, :rule_score, :ml_score, :trigger_reason,
                     :confirm_ratio, :alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
                     :momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
                """), {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                    'alert_type': alert['alert_type'],
                    'final_score': alert['final_score'],
                    'rule_score': alert['rule_score'],
                    'ml_score': alert['ml_score'],
                    'trigger_reason': alert['trigger_reason'],
                    'confirm_ratio': alert.get('confirm_ratio', 0),
                    'alpha': alert['alpha'],
                    'alpha_zscore': alert.get('alpha_zscore', 0),
                    'amt_zscore': alert.get('amt_zscore', 0),
                    'rank_zscore': alert.get('rank_zscore', 0),
                    'momentum_3m': alert.get('momentum_3m', 0),
                    'momentum_5m': alert.get('momentum_5m', 0),
                    'limit_up_ratio': alert.get('limit_up_ratio', 0),
                    'triggered_rules': json.dumps(alert.get('triggered_rules', [])),
                })
                saved += 1
            except Exception as e:
                print(f"Save failed: {e}")

    return saved


# ==================== Main ====================

def main():
    parser = argparse.ArgumentParser(description='V2 backtest')
    parser.add_argument('--start', type=str, required=True, help='start date')
    parser.add_argument('--end', type=str, default=None, help='end date')
    parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
    parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
    parser.add_argument('--save', action='store_true', help='save results to the database')
    parser.add_argument('--lookahead', type=int, default=10, help='evaluation lookahead (minutes)')
    args = parser.parse_args()

    end_date = args.end or args.start

    print("=" * 60)
    print("V2 backtest - time-slice alignment + persistence confirmation")
    print("=" * 60)
    print(f"Date range: {args.start} ~ {end_date}")
    print(f"Model dir: {args.model_dir}")
    print(f"Lookahead: {args.lookahead} minutes")

    # Initialize the detector
    detector = AnomalyDetectorV2(
        model_dir=args.model_dir,
        baseline_dir=args.baseline_dir
    )

    # Get trading days
    from prepare_data_v2 import get_trading_days
    trading_days = get_trading_days(args.start, end_date)
    if not trading_days:
        print("No trading days in range")
        return

    print(f"Trading days: {len(trading_days)}")

    # Backtest statistics; the *_sum fields are weighted by evaluated alerts
    total_stats = {
        'total_alerts': 0,
        'evaluated_alerts': 0,
        'accuracy_sum': 0,
        'sustained_sum': 0,
        'peak_sum': 0,
        'day_count': 0,
    }

    all_alerts = []

    for trade_date in tqdm(trading_days, desc="Backtest progress"):
        # Detect anomalies
        alerts = detector.detect(trade_date)
        if not alerts:
            continue

        all_alerts.extend(alerts)

        # Evaluate
        raw_data = detector._compute_raw_features(trade_date)
        if raw_data.empty:
            continue

        stats = evaluate_alerts(alerts, raw_data, args.lookahead)

        if stats['evaluated_alerts'] > 0:
            total_stats['total_alerts'] += stats['total_alerts']
            total_stats['evaluated_alerts'] += stats['evaluated_alerts']
            total_stats['accuracy_sum'] += stats['accuracy'] * stats['evaluated_alerts']
            total_stats['sustained_sum'] += stats['sustained_rate'] * stats['evaluated_alerts']
            total_stats['peak_sum'] += stats['avg_peak'] * stats['evaluated_alerts']
            total_stats['day_count'] += 1

            print(f"\n[{trade_date}] alerts: {stats['total_alerts']}, "
                  f"accuracy: {stats['accuracy']:.1%}, "
                  f"sustained: {stats['sustained_rate']:.1%}, "
                  f"peak: {stats['avg_peak']:.2f}%")

    # Summary
    print("\n" + "=" * 60)
    print("Backtest summary")
    print("=" * 60)

    if total_stats['evaluated_alerts'] > 0:
        # Divide the weighted sums by the number of evaluated alerts (the same
        # weight used when accumulating), not by the raw alert count
        avg_accuracy = total_stats['accuracy_sum'] / total_stats['evaluated_alerts']
        avg_sustained = total_stats['sustained_sum'] / total_stats['evaluated_alerts']
        avg_peak = total_stats['peak_sum'] / total_stats['evaluated_alerts']

        print(f"Total alerts: {total_stats['total_alerts']}")
        print(f"Days backtested: {total_stats['day_count']}")
        print(f"Avg per day: {total_stats['total_alerts'] / max(1, total_stats['day_count']):.1f}")
        print(f"Direction accuracy: {avg_accuracy:.1%}")
        print(f"Sustained rate: {avg_sustained:.1%}")
        print(f"Avg peak: {avg_peak:.2f}%")
    else:
        print("No anomalies detected")

    # Save
    if args.save and all_alerts:
        print(f"\nSaving {len(all_alerts)} alerts to the database...")
        saved = save_alerts_to_mysql(all_alerts)
        print(f"Saved: {saved} rows")

    print("=" * 60)


if __name__ == "__main__":
    main()
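# ---------------------------------------------------------------------------
# Usage sketch. The filename (backtest_v2.py) and the dates below are
# illustrative assumptions, not defined by this module; adjust to your setup.
#
#   # Backtest a single day without writing to MySQL
#   python backtest_v2.py --start 2024-06-03
#
#   # Backtest a date range with a 15-minute lookahead and persist alerts
#   python backtest_v2.py --start 2024-06-03 --end 2024-06-28 --lookahead 15 --save
# ---------------------------------------------------------------------------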