#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Hybrid anomaly backtest script.

Runs a backtest with HybridAnomalyDetector:
- fuses rule-based scoring with an LSTM autoencoder score
- emits richer per-alert information

Usage:
    python backtest_hybrid.py --start 2024-01-01 --end 2024-12-01
    python backtest_hybrid.py --start 2024-11-01 --dry-run
"""
import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
from tqdm import tqdm
from sqlalchemy import create_engine, text

# Add the parent directory to sys.path (presumably where the `detector`
# module lives) so the import below resolves when run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from detector import HybridAnomalyDetector, create_detector


# ==================== Configuration ====================

# NOTE(review): database credentials should not live in source control.
# Set BACKTEST_MYSQL_URL to override; the literal URL below is kept only as a
# backward-compatible fallback and should be removed once deployments set it.
MYSQL_ENGINE = create_engine(
    os.environ.get(
        "BACKTEST_MYSQL_URL",
        "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    ),
    echo=False,
)

# Feature columns stacked (in this order) into each ML input sequence.
FEATURES = [
    'alpha',
    'alpha_delta',
    'amt_ratio',
    'amt_delta',
    'rank_pct',
    'limit_up_ratio',
]

BACKTEST_CONFIG = {
    'seq_len': 30,                # rows per ML input window
    'min_alpha_abs': 0.3,         # lowered threshold so the rule scorer can also contribute
    'cooldown_minutes': 8,        # per-concept suppression after an alert
    'max_alerts_per_minute': 20,  # cap on alerts kept per timestamp
    'clip_value': 10.0,           # symmetric clip applied to ML input features
}


# ==================== Data loading ====================

def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
    """Load one day's feature file ``features_{date}.parquet``.

    Returns:
        The DataFrame, or None when the file does not exist.
    """
    file_path = Path(data_dir) / f"features_{date}.parquet"
    if not file_path.exists():
        return None
    return pd.read_parquet(file_path)


def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
    """List dates with a ``features_*.parquet`` file within [start_date, end_date].

    Dates are compared as strings, which orders correctly for the
    ``YYYY-MM-DD`` format used on the command line.
    """
    prefix = 'features_'
    dates = []
    for f in sorted(Path(data_dir).glob(f"{prefix}*.parquet")):
        date = f.stem[len(prefix):]  # glob guarantees the prefix is present
        if start_date <= date <= end_date:
            dates.append(date)
    return dates


# ==================== Hybrid backtest ====================

def _minutes_between(later, earlier) -> Optional[float]:
    """Return ``later - earlier`` in minutes, or None for non-datetime values.

    Accepts ``numpy.datetime64`` as well as ``datetime``/``pd.Timestamp``;
    ``Series.unique()`` on a datetime column yields ``numpy.datetime64``,
    which is not a ``datetime`` subclass.
    """
    datetime_types = (np.datetime64, datetime)
    if isinstance(later, datetime_types) and isinstance(earlier, datetime_types):
        return (pd.Timestamp(later) - pd.Timestamp(earlier)).total_seconds() / 60
    return None


def _to_python_datetime(value):
    """Convert datetime-like values to ``datetime.datetime``; pass others through."""
    if isinstance(value, (np.datetime64, datetime)):
        return pd.Timestamp(value).to_pydatetime()
    return value


def _classify_anomaly(features: Dict) -> str:
    """Map the current-minute features to an alert type label."""
    alpha = features.get('alpha', 0)
    if alpha >= 1.5:
        return 'surge_up'
    if alpha <= -1.5:
        return 'surge_down'
    if features.get('amt_ratio', 1) >= 3.0:
        return 'volume_spike'
    return 'unknown'


def _trigger_reason(config: Dict, rule_score, ml_score, final_score) -> Optional[str]:
    """Return the trigger description if any threshold fires, else None.

    Thresholds are checked in priority order: rule, ML, then the fused score.
    """
    if rule_score >= config['rule_trigger']:
        return f'规则强信号({rule_score:.0f}分)'
    if ml_score >= config['ml_trigger']:
        return f'ML强信号({ml_score:.0f}分)'
    if final_score >= config['fusion_trigger']:
        return f'融合触发({final_score:.0f}分)'
    return None


def backtest_single_day_hybrid(
    detector: HybridAnomalyDetector,
    df: pd.DataFrame,
    date: str,
    seq_len: int = 30,
) -> List[Dict]:
    """Backtest one day of data with the hybrid detector (batched per minute).

    For every timestamp that has ``seq_len - 1`` predecessors, collect each
    concept's last ``seq_len`` rows inside the global window. Skip concepts in
    cooldown or whose current |alpha| is below ``min_alpha_abs``. Score the
    survivors in one ML batch and fuse with the rule score. Keep at most
    ``max_alerts_per_minute`` alerts per minute, highest final score first.

    Args:
        detector: provides ``ml_scorer``, ``rule_scorer`` and ``config`` (weights/triggers).
        df: day's features; must contain ``concept_id``, ``timestamp`` and FEATURES.
        date: trade date string copied into each alert.
        seq_len: window length in rows; must be >= 1.

    Returns:
        List of alert dicts.

    Raises:
        ValueError: if ``seq_len`` < 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    alerts: List[Dict] = []

    all_timestamps = sorted(df['timestamp'].unique())
    if len(all_timestamps) < seq_len:
        return alerts

    # Sort each concept once so every window lookup is a binary search rather
    # than a full boolean-mask scan (plus re-sort) per timestamp.
    grouped = {}
    for cid, cdf in df.groupby('concept_id', sort=False):
        cdf = cdf.sort_values('timestamp', kind='mergesort')
        grouped[cid] = (cdf, cdf['timestamp'].to_numpy())

    cooldown: Dict = {}  # concept_id -> timestamp of its last kept alert
    cooldown_minutes = BACKTEST_CONFIG['cooldown_minutes']
    min_alpha_abs = BACKTEST_CONFIG['min_alpha_abs']
    clip_value = BACKTEST_CONFIG['clip_value']
    max_per_minute = BACKTEST_CONFIG['max_alerts_per_minute']

    for t_idx in range(seq_len - 1, len(all_timestamps)):
        current_time = all_timestamps[t_idx]
        window_start_time = all_timestamps[t_idx - seq_len + 1]

        # Collect every candidate concept for this minute.
        batch_sequences = []
        batch_features = []
        batch_infos = []

        for concept_id, (concept_df, ts_values) in grouped.items():
            # Cooldown filter (cheap, so do it first).
            if concept_id in cooldown:
                elapsed = _minutes_between(current_time, cooldown[concept_id])
                if elapsed is not None and elapsed < cooldown_minutes:
                    continue

            # Rows with window_start_time <= timestamp <= current_time.
            lo = np.searchsorted(ts_values, window_start_time, side='left')
            hi = np.searchsorted(ts_values, current_time, side='right')
            if hi - lo < seq_len:
                continue
            window_df = concept_df.iloc[hi - seq_len:hi]

            current_row = window_df.iloc[-1]
            alpha = current_row.get('alpha', 0)
            # Skip tiny moves. NaN alpha is skipped too: the ML sequence below
            # maps NaN to 0, which would also fall under the threshold.
            if pd.isna(alpha) or abs(alpha) < min_alpha_abs:
                continue

            sequence = np.nan_to_num(
                window_df[FEATURES].values, nan=0.0, posinf=0.0, neginf=0.0
            )
            sequence = np.clip(sequence, -clip_value, clip_value)

            batch_sequences.append(sequence)
            batch_features.append({
                'alpha': alpha,
                'alpha_delta': current_row.get('alpha_delta', 0),
                'amt_ratio': current_row.get('amt_ratio', 1),
                'amt_delta': current_row.get('amt_delta', 0),
                'rank_pct': current_row.get('rank_pct', 0.5),
                'limit_up_ratio': current_row.get('limit_up_ratio', 0),
            })
            batch_infos.append({
                'concept_id': concept_id,
                'stock_count': current_row.get('stock_count', 0),
                'total_amt': current_row.get('total_amt', 0),
            })

        if not batch_sequences:
            continue

        # One batched ML inference for the whole minute.
        if detector.ml_scorer.is_ready():
            ml_scores = detector.ml_scorer.score(np.array(batch_sequences))
        else:
            ml_scores = [0.0] * len(batch_sequences)
        if isinstance(ml_scores, (float, np.floating)):
            ml_scores = [ml_scores]

        w_rule = detector.config['rule_weight']
        w_ml = detector.config['ml_weight']

        minute_alerts = []
        for i, (features, info) in enumerate(zip(batch_features, batch_infos)):
            rule_score, rule_details = detector.rule_scorer.score(features)
            # float() normalises numpy/tensor scalars (e.g. float32, which
            # json.dumps rejects) into a plain Python float.
            ml_score = float(ml_scores[i]) if i < len(ml_scores) else 0.0
            final_score = w_rule * rule_score + w_ml * ml_score

            trigger_reason = _trigger_reason(
                detector.config, rule_score, ml_score, final_score
            )
            if trigger_reason is None:
                continue

            minute_alerts.append({
                'concept_id': info['concept_id'],
                'alert_time': current_time,
                'trade_date': date,
                'alert_type': _classify_anomaly(features),
                'final_score': final_score,
                'rule_score': rule_score,
                'ml_score': ml_score,
                'trigger_reason': trigger_reason,
                'triggered_rules': list(rule_details.keys()),
                **features,
                **info,
            })

        # Keep the top-scoring alerts. Only concepts that actually emit an
        # alert enter cooldown, so capped-out concepts are not silently muted.
        minute_alerts.sort(key=lambda a: a['final_score'], reverse=True)
        kept = minute_alerts[:max_per_minute]
        for alert in kept:
            cooldown[alert['concept_id']] = current_time
        alerts.extend(kept)

    return alerts


# ==================== Database writes ====================

_CHECK_ALERT_SQL = text("""
    SELECT id FROM concept_minute_alert
    WHERE concept_id = :concept_id
      AND alert_time = :alert_time
      AND trade_date = :trade_date
""")

_INSERT_ALERT_SQL = text("""
    INSERT INTO concept_minute_alert
        (concept_id, concept_name, alert_time, alert_type, trade_date,
         change_pct, zscore, importance_score, stock_count, extra_info)
    VALUES
        (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
         :change_pct, :zscore, :importance_score, :stock_count, :extra_info)
""")


def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Insert alerts into ``concept_minute_alert``, skipping existing rows.

    Uses one transaction for the batch. A row that fails is reported and
    skipped (best effort); the rest are still attempted.

    Returns:
        Number of rows inserted (or ``len(alerts)`` in dry-run mode).
    """
    if not alerts:
        return 0

    if dry_run:
        print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
        return len(alerts)

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            # Hand the driver a real datetime rather than numpy.datetime64.
            alert_time = _to_python_datetime(alert['alert_time'])
            try:
                exists = conn.execute(_CHECK_ALERT_SQL, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert_time,
                    'trade_date': alert['trade_date'],
                }).fetchone()
                if exists:
                    continue

                extra_info = {
                    'detection_method': 'hybrid',
                    'final_score': alert['final_score'],
                    'rule_score': alert['rule_score'],
                    'ml_score': alert['ml_score'],
                    'trigger_reason': alert['trigger_reason'],
                    'triggered_rules': alert['triggered_rules'],
                    'alpha': alert.get('alpha', 0),
                    'alpha_delta': alert.get('alpha_delta', 0),
                    'amt_ratio': alert.get('amt_ratio', 1),
                }

                conn.execute(_INSERT_ALERT_SQL, {
                    'concept_id': alert['concept_id'],
                    'concept_name': alert.get('concept_name', ''),
                    'alert_time': alert_time,
                    'alert_type': alert['alert_type'],
                    'trade_date': alert['trade_date'],
                    'change_pct': alert.get('alpha', 0),
                    'zscore': alert['final_score'],  # final score reused as zscore
                    'importance_score': alert['final_score'],
                    'stock_count': alert.get('stock_count', 0),
                    'extra_info': json.dumps(extra_info, ensure_ascii=False),
                })
                saved += 1
            except Exception as e:
                print(f" 保存失败: {alert['concept_id']} - {e}")

    return saved


def export_alerts_to_csv(alerts: List[Dict], output_path: str):
    """Write alerts to a UTF-8 (with BOM) CSV file; no-op when empty."""
    if not alerts:
        return
    pd.DataFrame(alerts).to_csv(output_path, index=False, encoding='utf-8-sig')
    print(f"已导出到: {output_path}")


# ==================== Analysis ====================

def analyze_alerts(alerts: List[Dict]):
    """Print summary statistics for the detected alerts."""
    if not alerts:
        print("无异动数据")
        return

    df = pd.DataFrame(alerts)
    print("\n" + "=" * 60)
    print("异动统计分析")
    print("=" * 60)

    # 1. Totals
    print(f"\n总异动数: {len(alerts)}")

    # 2. Distribution by alert type
    print(f"\n异动类型分布:")
    print(df['alert_type'].value_counts())

    # 3. Score statistics
    print(f"\n得分统计:")
    print(f" 最终得分 - Mean: {df['final_score'].mean():.1f}, Max: {df['final_score'].max():.1f}")
    print(f" 规则得分 - Mean: {df['rule_score'].mean():.1f}, Max: {df['rule_score'].max():.1f}")
    print(f" ML得分 - Mean: {df['ml_score'].mean():.1f}, Max: {df['ml_score'].max():.1f}")

    # 4. Trigger source, derived from the trigger_reason prefix
    print(f"\n触发来源分析:")
    trigger_counts = df['trigger_reason'].apply(
        lambda x: '规则' if '规则' in x else ('ML' if 'ML' in x else '融合')
    ).value_counts()
    print(trigger_counts)

    # 5. Most frequently triggered rules
    all_rules = [
        rule
        for rules in df['triggered_rules']
        if isinstance(rules, list)
        for rule in rules
    ]
    if all_rules:
        print(f"\n最常触发的规则 (Top 10):")
        for rule, count in Counter(all_rules).most_common(10):
            print(f" {rule}: {count}")


# ==================== Entry point ====================

def main():
    """Parse CLI arguments, run the backtest over the date range and report."""
    parser = argparse.ArgumentParser(description='融合异动回测')
    parser.add_argument('--data_dir', type=str, default='ml/data', help='特征数据目录')
    parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints', help='模型检查点目录')
    parser.add_argument('--start', type=str, required=True, help='开始日期 (YYYY-MM-DD)')
    parser.add_argument('--end', type=str, default=None, help='结束日期 (YYYY-MM-DD),默认=start')
    parser.add_argument('--dry-run', action='store_true', help='只计算,不写入数据库')
    parser.add_argument('--export-csv', type=str, default=None, help='导出 CSV 文件路径')
    parser.add_argument('--rule-weight', type=float, default=0.6, help='规则权重 (0-1)')
    parser.add_argument('--ml-weight', type=float, default=0.4, help='ML权重 (0-1)')
    parser.add_argument('--device', type=str, default='cuda', help='设备 (cuda/cpu),默认 cuda')
    args = parser.parse_args()

    if args.end is None:
        args.end = args.start

    print("=" * 60)
    print("融合异动回测 (规则 + LSTM)")
    print("=" * 60)
    print(f"日期范围: {args.start} ~ {args.end}")
    print(f"数据目录: {args.data_dir}")
    print(f"模型目录: {args.checkpoint_dir}")
    print(f"规则权重: {args.rule_weight}")
    print(f"ML权重: {args.ml_weight}")
    print(f"设备: {args.device}")
    print(f"Dry Run: {args.dry_run}")
    print("=" * 60)

    # Build the hybrid detector; the device is passed through to the detector.
    config = {
        'rule_weight': args.rule_weight,
        'ml_weight': args.ml_weight,
    }
    detector = HybridAnomalyDetector(config, args.checkpoint_dir, device=args.device)

    dates = get_available_dates(args.data_dir, args.start, args.end)
    if not dates:
        print(f"未找到 {args.start} ~ {args.end} 范围内的数据")
        return

    print(f"\n找到 {len(dates)} 天的数据")

    all_alerts = []
    total_saved = 0

    for date in tqdm(dates, desc="回测进度"):
        df = load_daily_features(args.data_dir, date)
        if df is None or df.empty:
            continue

        alerts = backtest_single_day_hybrid(
            detector, df, date,
            seq_len=BACKTEST_CONFIG['seq_len'],
        )

        if alerts:
            all_alerts.extend(alerts)
            saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
            total_saved += saved
            if not args.dry_run:
                tqdm.write(f" {date}: 检测到 {len(alerts)} 个异动,保存 {saved} 条")

    if args.export_csv and all_alerts:
        export_alerts_to_csv(all_alerts, args.export_csv)

    analyze_alerts(all_alerts)

    print("\n" + "=" * 60)
    print("回测完成!")
    print("=" * 60)
    print(f"总计检测到: {len(all_alerts)} 个异动")
    print(f"保存到数据库: {total_saved} 条")
    print("=" * 60)


if __name__ == "__main__":
    main()