#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Historical anomaly backtest script.

Runs the trained model over historical feature data, detects anomalies and
produces anomaly (alert) records.

Usage:
    # Backtest a date range
    python backtest.py --start 2024-01-01 --end 2024-12-01

    # Backtest a single day
    python backtest.py --start 2024-11-01 --end 2024-11-01

    # Compute results only, do not write to the database
    # (--end defaults to today when omitted)
    python backtest.py --start 2024-01-01 --dry-run

The MySQL connection URL can be overridden with the ``STOCK_MYSQL_URL``
environment variable.
"""

import os
import sys
import argparse
import json
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sqlalchemy import create_engine, text

# Put the parent directory on sys.path so the project-local `model` module imports.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model import TransformerAutoencoder


# ==================== Configuration ====================

# NOTE(review): the literal URL (with credentials) is kept only as a
# backward-compatible fallback; prefer setting STOCK_MYSQL_URL and removing it.
MYSQL_ENGINE = create_engine(
    os.environ.get(
        'STOCK_MYSQL_URL',
        "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    ),
    echo=False,
)

# Feature columns fed to the model (must match the training feature order).
FEATURES = [
    'alpha',
    'alpha_delta',
    'amt_ratio',
    'amt_delta',
    'rank_pct',
    'limit_up_ratio',
]

# Backtest settings.
BACKTEST_CONFIG = {
    'seq_len': 30,                 # sequence (window) length in time steps
    'threshold_key': 'p95',        # key into thresholds.json to use as threshold
    'min_alpha_abs': 0.5,          # minimum |alpha|; filters out tiny moves
    'cooldown_minutes': 8,         # per-concept cooldown between alerts
    'max_alerts_per_minute': 15,   # cap on alerts emitted per time step
    'clip_value': 10.0,            # features are clipped to [-clip, clip]
}


# ==================== Model loading ====================

class AnomalyDetector:
    """Wraps a trained TransformerAutoencoder plus its score threshold.

    Reads ``config.json``, ``best_model.pt`` and ``thresholds.json`` from
    ``checkpoint_dir``.
    """

    def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
        self.checkpoint_dir = Path(checkpoint_dir)

        # 'auto' picks CUDA when available, otherwise CPU.
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        self._load_config()
        self._load_model()
        self._load_thresholds()

        print(f"AnomalyDetector 初始化完成")
        print(f" 设备: {self.device}")
        print(f" 阈值 ({BACKTEST_CONFIG['threshold_key']}): {self.threshold:.6f}")

    def _load_config(self):
        """Load the training config JSON into ``self.config``."""
        config_path = self.checkpoint_dir / 'config.json'
        with open(config_path, 'r') as f:
            self.config = json.load(f)

    def _load_model(self):
        """Build the autoencoder from config and load the best checkpoint weights."""
        model_path = self.checkpoint_dir / 'best_model.pt'
        checkpoint = torch.load(model_path, map_location=self.device)

        model_config = self.config['model'].copy()
        # Instance norm flag lives at the top level of the config; default True.
        model_config['use_instance_norm'] = self.config.get('use_instance_norm', True)

        self.model = TransformerAutoencoder(**model_config)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

    def _load_thresholds(self):
        """Load the score threshold selected by BACKTEST_CONFIG['threshold_key']."""
        thresholds_path = self.checkpoint_dir / 'thresholds.json'
        with open(thresholds_path, 'r') as f:
            thresholds = json.load(f)
        self.threshold = thresholds[BACKTEST_CONFIG['threshold_key']]

    @torch.no_grad()
    def compute_anomaly_scores(self, sequences: np.ndarray) -> np.ndarray:
        """Compute anomaly scores for a batch of sequences.

        Args:
            sequences: array of shape (n_sequences, seq_len, n_features).

        Returns:
            Array of shape (n_sequences,): the reconstruction error at the
            last time step of each sequence. Empty input yields an empty array.
        """
        if len(sequences) == 0:
            return np.empty(0, dtype=np.float32)

        # Clip extreme values before they reach the model.
        clip = BACKTEST_CONFIG['clip_value']
        sequences = np.clip(sequences, -clip, clip)

        x = torch.FloatTensor(sequences).to(self.device)
        errors = self.model.compute_reconstruction_error(x, reduction='none')

        # assumes errors is (n_sequences, seq_len) -- TODO confirm against model.py
        return errors[:, -1].cpu().numpy()

    def is_anomaly(self, score: float) -> bool:
        """Return True when ``score`` strictly exceeds the loaded threshold."""
        return score > self.threshold


# ==================== Data loading ====================

def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
    """Load ``features_{date}.parquet`` from ``data_dir``; None if it does not exist."""
    file_path = Path(data_dir) / f"features_{date}.parquet"
    if not file_path.exists():
        return None
    return pd.read_parquet(file_path)


def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
    """List dates (sorted) with a feature file in ``data_dir`` within [start, end].

    Dates are compared as strings, so both bounds must be YYYY-MM-DD.
    """
    data_path = Path(data_dir)
    all_files = sorted(data_path.glob("features_*.parquet"))

    dates = []
    for f in all_files:
        date = f.stem.replace('features_', '')
        if start_date <= date <= end_date:
            dates.append(date)
    return dates


# ==================== Backtest logic ====================

def _normalize_timestamp(value):
    """Return ``value`` as a ``datetime`` when it is a ``numpy.datetime64``.

    ``Series.unique()`` on a datetime64 column yields ``numpy.datetime64``
    items, which are not ``datetime`` instances; without this conversion the
    cooldown arithmetic was skipped and the raw numpy value was bound as a
    SQL parameter. Any other value is returned unchanged.
    """
    if isinstance(value, np.datetime64):
        return pd.Timestamp(value).to_pydatetime()
    return value


def _minutes_between(later, earlier) -> Optional[float]:
    """Minutes from ``earlier`` to ``later``; None if either is not a datetime."""
    if isinstance(later, datetime) and isinstance(earlier, datetime):
        return (later - earlier).total_seconds() / 60
    return None


def _collect_windows(
    concept_frames: List[Tuple[object, pd.DataFrame]],
    window_start_time,
    current_time,
    alert_time,
    seq_len: int,
) -> Tuple[List[np.ndarray], List[Dict]]:
    """Build the model input window and current-row info for every concept.

    Concepts with fewer than ``seq_len`` rows inside
    [window_start_time, current_time] are skipped. ``concept_frames`` must
    already be sorted by timestamp.
    """
    sequences = []
    infos = []

    for concept_id, concept_df in concept_frames:
        ts = concept_df['timestamp']
        window_df = concept_df[(ts >= window_start_time) & (ts <= current_time)]
        if len(window_df) < seq_len:
            continue

        # Keep exactly the last seq_len rows; replace NaN/inf with 0.
        window_df = window_df.tail(seq_len)
        features = np.nan_to_num(
            window_df[FEATURES].values, nan=0.0, posinf=0.0, neginf=0.0
        )

        current_row = window_df.iloc[-1]
        sequences.append(features)
        infos.append({
            'concept_id': concept_id,
            'timestamp': alert_time,
            'alpha': current_row.get('alpha', 0),
            'alpha_delta': current_row.get('alpha_delta', 0),
            'amt_ratio': current_row.get('amt_ratio', 1),
            'limit_up_ratio': current_row.get('limit_up_ratio', 0),
            'limit_down_ratio': current_row.get('limit_down_ratio', 0),
            'rank_pct': current_row.get('rank_pct', 0.5),
            'stock_count': current_row.get('stock_count', 0),
            'total_amt': current_row.get('total_amt', 0),
        })

    return sequences, infos


def backtest_single_day(
    detector: AnomalyDetector,
    df: pd.DataFrame,
    date: str,
    seq_len: int = 30
) -> List[Dict]:
    """Run anomaly detection over one day of feature data.

    For every time step from the ``seq_len``-th onward, each concept's
    trailing window is scored. A concept alerts when |alpha| is at least
    ``min_alpha_abs``, it is outside its cooldown, and its score exceeds the
    detector threshold. At most ``max_alerts_per_minute`` alerts (highest
    score first) are kept per time step; only kept alerts start a cooldown.

    Args:
        detector: anomaly detector.
        df: the day's feature data (needs 'concept_id', 'timestamp' and FEATURES).
        date: trade date string recorded on each alert.
        seq_len: window length.

    Returns:
        List of alert dicts.
    """
    alerts: List[Dict] = []

    all_timestamps = sorted(df['timestamp'].unique())
    if len(all_timestamps) < seq_len:
        return alerts

    # Sort each concept once up front instead of re-sorting on every step.
    concept_frames = [
        (concept_id, concept_df.sort_values('timestamp'))
        for concept_id, concept_df in df.groupby('concept_id', sort=False)
    ]

    min_alpha_abs = BACKTEST_CONFIG['min_alpha_abs']
    cooldown_minutes = BACKTEST_CONFIG['cooldown_minutes']
    max_per_minute = BACKTEST_CONFIG['max_alerts_per_minute']

    # {concept_id: time of that concept's last emitted alert}
    cooldown: Dict = {}

    for t_idx in range(seq_len - 1, len(all_timestamps)):
        current_time = all_timestamps[t_idx]
        window_start_time = all_timestamps[t_idx - seq_len + 1]
        alert_time = _normalize_timestamp(current_time)

        sequences, infos = _collect_windows(
            concept_frames, window_start_time, current_time, alert_time, seq_len
        )
        if not sequences:
            continue

        scores = detector.compute_anomaly_scores(np.array(sequences))

        minute_alerts = []
        for info, score in zip(infos, scores):
            concept_id = info['concept_id']
            alpha = info['alpha']

            # Skip missing alpha and small moves.
            if pd.isna(alpha) or abs(alpha) < min_alpha_abs:
                continue

            # Cooldown; non-datetime timestamps cannot be compared, so skip the check.
            last_alert = cooldown.get(concept_id)
            if last_alert is not None:
                elapsed = _minutes_between(alert_time, last_alert)
                if elapsed is not None and elapsed < cooldown_minutes:
                    continue

            if not detector.is_anomaly(score):
                continue

            minute_alerts.append({
                'concept_id': concept_id,
                'alert_time': alert_time,
                'trade_date': date,
                'alert_type': 'surge_up' if alpha > 0 else 'surge_down',
                'anomaly_score': float(score),
                'threshold': detector.threshold,
                **info,
            })

        # Keep the top-scoring alerts; only those actually emitted start a cooldown.
        minute_alerts.sort(key=lambda a: a['anomaly_score'], reverse=True)
        kept = minute_alerts[:max_per_minute]
        for alert in kept:
            cooldown[alert['concept_id']] = alert_time
        alerts.extend(kept)

    return alerts


# ==================== Database writes ====================

def _as_float(value, default: float = 0.0) -> float:
    """Convert a numpy/pandas scalar to a plain float; None/NaN -> ``default``."""
    if value is None or pd.isna(value):
        return default
    return float(value)


_CHECK_ALERT_SQL = text("""
    SELECT id FROM concept_minute_alert
    WHERE concept_id = :concept_id
      AND alert_time = :alert_time
      AND trade_date = :trade_date
""")

_INSERT_ALERT_SQL = text("""
    INSERT INTO concept_minute_alert
    (concept_id, concept_name, alert_time, alert_type, trade_date,
     change_pct, zscore, importance_score, stock_count, extra_info)
    VALUES
    (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
     :change_pct, :zscore, :importance_score, :stock_count, :extra_info)
""")


def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Insert alerts into ``concept_minute_alert``, skipping existing rows.

    Rows are keyed by (concept_id, alert_time, trade_date). All inserts run
    in one transaction; a failing row is reported and skipped. In dry-run
    mode nothing is written and ``len(alerts)`` is returned.

    Returns:
        Number of rows inserted (or that would be inserted in dry-run mode).
    """
    if not alerts:
        return 0

    if dry_run:
        print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
        return len(alerts)

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                key = {
                    'concept_id': alert['concept_id'],
                    'alert_time': _normalize_timestamp(alert['alert_time']),
                    'trade_date': alert['trade_date'],
                }
                if conn.execute(_CHECK_ALERT_SQL, key).fetchone():
                    continue

                alpha = _as_float(alert.get('alpha', 0))
                score = float(alert['anomaly_score'])
                # Native Python numbers only: numpy float32 is not JSON-serializable
                # and NaN would produce invalid JSON.
                extra_info = json.dumps({
                    'detection_method': 'ml_autoencoder',
                    'threshold': _as_float(alert['threshold']),
                    'alpha': alpha,
                    'amt_ratio': _as_float(alert.get('amt_ratio', 1), 1.0),
                }, ensure_ascii=False)

                conn.execute(_INSERT_ALERT_SQL, {
                    **key,
                    'concept_name': alert.get('concept_name', ''),
                    'alert_type': alert['alert_type'],
                    'change_pct': alpha,
                    'zscore': score,
                    'importance_score': score,
                    'stock_count': int(_as_float(alert.get('stock_count', 0))),
                    'extra_info': extra_info,
                })
                saved += 1
            except Exception as e:
                # Best effort: report the failing row and continue with the rest.
                print(f" 保存失败: {alert['concept_id']} - {e}")

    return saved


def export_alerts_to_csv(alerts: List[Dict], output_path: str):
    """Write alerts to CSV (UTF-8 with BOM); does nothing when ``alerts`` is empty."""
    if not alerts:
        return

    df = pd.DataFrame(alerts)
    df.to_csv(output_path, index=False, encoding='utf-8-sig')
    print(f"已导出到: {output_path}")


# ==================== Entry point ====================

def main():
    """Parse CLI arguments, backtest each available day, save/export results."""
    parser = argparse.ArgumentParser(description='历史异动回测')
    parser.add_argument('--data_dir', type=str, default='ml/data',
                        help='特征数据目录')
    parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
                        help='模型检查点目录')
    parser.add_argument('--start', type=str, required=True,
                        help='开始日期 (YYYY-MM-DD)')
    # Optional so the documented `--start X --dry-run` usage works; defaults to today.
    parser.add_argument('--end', type=str, default=None,
                        help='结束日期 (YYYY-MM-DD),默认今天')
    parser.add_argument('--dry-run', action='store_true',
                        help='只计算,不写入数据库')
    parser.add_argument('--export-csv', type=str, default=None,
                        help='导出 CSV 文件路径')
    parser.add_argument('--device', type=str, default='auto',
                        help='设备 (auto/cuda/cpu)')

    args = parser.parse_args()

    if args.end is None:
        args.end = datetime.now().strftime('%Y-%m-%d')

    # Dates are compared as strings later, so reject anything not YYYY-MM-DD.
    for name in ('start', 'end'):
        value = getattr(args, name)
        try:
            datetime.strptime(value, '%Y-%m-%d')
        except ValueError:
            parser.error(f'--{name} 必须是 YYYY-MM-DD 格式: {value}')

    print("=" * 60)
    print("历史异动回测")
    print("=" * 60)
    print(f"日期范围: {args.start} ~ {args.end}")
    print(f"数据目录: {args.data_dir}")
    print(f"模型目录: {args.checkpoint_dir}")
    print(f"Dry Run: {args.dry_run}")
    print("=" * 60)

    detector = AnomalyDetector(args.checkpoint_dir, args.device)

    dates = get_available_dates(args.data_dir, args.start, args.end)
    if not dates:
        print(f"未找到 {args.start} ~ {args.end} 范围内的数据")
        return

    print(f"\n找到 {len(dates)} 天的数据")

    all_alerts = []
    total_saved = 0

    for date in tqdm(dates, desc="回测进度"):
        df = load_daily_features(args.data_dir, date)
        if df is None or df.empty:
            continue

        alerts = backtest_single_day(
            detector, df, date,
            seq_len=BACKTEST_CONFIG['seq_len']
        )

        if alerts:
            all_alerts.extend(alerts)

            saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
            total_saved += saved

            if not args.dry_run:
                tqdm.write(f" {date}: 检测到 {len(alerts)} 个异动,保存 {saved} 条")

    if args.export_csv and all_alerts:
        export_alerts_to_csv(all_alerts, args.export_csv)

    # Summary
    print("\n" + "=" * 60)
    print("回测完成!")
    print("=" * 60)
    print(f"总计检测到: {len(all_alerts)} 个异动")
    print(f"保存到数据库: {total_saved} 条")

    if all_alerts:
        df_alerts = pd.DataFrame(all_alerts)
        print(f"\n异动类型分布:")
        print(df_alerts['alert_type'].value_counts())

        print(f"\n异动分数统计:")
        print(f" Mean: {df_alerts['anomaly_score'].mean():.4f}")
        print(f" Max: {df_alerts['anomaly_score'].max():.4f}")
        print(f" Min: {df_alerts['anomaly_score'].min():.4f}")

    print("=" * 60)


if __name__ == "__main__":
    main()