#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fast hybrid (rule + ML) concept-anomaly backtest script.

Optimization strategy:
1. Build all sequences up front (vectorized) instead of re-slicing in a loop.
2. Batched ML inference (all candidates of a day are scored together).
3. NumPy vectorized operations instead of Python loops.

Performance notes from the original author:
- original version: 5 minutes/day
- optimized version: expected 10-30 seconds/day
"""

import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sqlalchemy import create_engine, text

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


# ==================== Configuration ====================

# NOTE(review): the connection URL (including credentials) is hard-coded in
# source. It can be overridden through the STOCK_MYSQL_URL environment
# variable; the literal below is kept only as the backward-compatible default.
MYSQL_ENGINE = create_engine(
    os.environ.get(
        "STOCK_MYSQL_URL",
        "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    ),
    echo=False,
)

FEATURES = ['alpha', 'alpha_delta', 'amt_ratio', 'amt_delta', 'rank_pct', 'limit_up_ratio']

CONFIG = {
    'seq_len': 15,               # sequence length (cross-day windows allow detection from 9:30)
    'min_alpha_abs': 0.3,        # minimum |alpha| for a row to be considered
    'cooldown_minutes': 8,
    'max_alerts_per_minute': 20,
    'clip_value': 10.0,
    # === fusion weights: balanced ===
    'rule_weight': 0.5,
    'ml_weight': 0.5,
    # === trigger thresholds ===
    'rule_trigger': 65,          # 60 -> 65, slightly stricter rule gate
    'ml_trigger': 70,            # 75 -> 70, slightly looser ML gate
    'fusion_trigger': 45,
}


# ==================== Rule scoring (vectorized) ====================

def get_size_adjusted_thresholds(stock_count: np.ndarray) -> np.ndarray:
    """
    Compute a per-row threshold multiplier from the concept's stock count.

    Design:
    - small concepts are naturally volatile, so their thresholds are raised
      (factor > 1, harder to trigger);
    - large concepts moving visibly is a real signal, so their thresholds are
      lowered (factor < 1, easier to trigger).

    Buckets (stock count -> factor):
        < 5 -> 1.8, [5, 10) -> 1.4, [10, 20) -> 1.2,
        [20, 50) -> 1.0, [50, 100) -> 0.85, >= 100 -> 0.7

    Args:
        stock_count: 1-D array of stock counts, one per row.

    Returns:
        Float array of the same length holding the multiplier to apply to the
        base thresholds. Rows that match no bucket (e.g. NaN) get 1.0.
    """
    stock_count = np.asarray(stock_count)
    # np.select takes the first matching condition, so the ladder of upper
    # bounds below is equivalent to the explicit half-open intervals.
    conditions = [
        stock_count < 5,
        stock_count < 10,
        stock_count < 20,
        stock_count < 50,
        stock_count < 100,
        stock_count >= 100,
    ]
    factors = [1.8, 1.4, 1.2, 1.0, 0.85, 0.7]
    return np.select(conditions, factors, default=1.0)


def _apply_rule(
    scores: np.ndarray,
    triggered: List[List[str]],
    mask: np.ndarray,
    points: float,
    name: str,
) -> None:
    """Add `points` to every masked row and record `name` as triggered there."""
    scores[mask] += points
    for i in np.where(mask)[0]:
        triggered[i].append(name)


def score_rules_batch(df: pd.DataFrame) -> Tuple[np.ndarray, List[List[str]]]:
    """
    Compute rule scores for all rows at once, taking concept size into account.

    Design principles:
    - rules are an auxiliary signal and should not dominate the decision alone;
    - thresholds are scaled by concept size (see get_size_adjusted_thresholds);
    - moves in large concepts are more valuable, small concepts need a bigger
      move to count as an anomaly.

    Args:
        df: DataFrame with the feature columns 'alpha', 'alpha_delta',
            'amt_ratio', 'amt_delta', 'rank_pct', 'limit_up_ratio'. If
            'stock_count' is absent, every row is treated as having 20 stocks.

    Returns:
        scores: (n,) float array of rule scores clipped to [0, 100].
        triggered_rules: for each row, the names of the rules that fired, in
            evaluation order.
    """
    n = len(df)
    scores = np.zeros(n)
    triggered: List[List[str]] = [[] for _ in range(n)]

    alpha_abs = np.abs(df['alpha'].values)
    alpha_delta_abs = np.abs(df['alpha_delta'].values)
    amt_ratio = df['amt_ratio'].values
    rank_pct = df['rank_pct'].values
    limit_up_ratio = df['limit_up_ratio'].values
    stock_count = df['stock_count'].values if 'stock_count' in df.columns else np.full(n, 20)

    # Size-dependent multiplier applied to the base thresholds.
    size_factor = get_size_adjusted_thresholds(stock_count)

    # ----- Alpha rules (dynamic thresholds) -----
    # Base thresholds: extreme 5, strong 4, medium 3; actual = base * size_factor.
    alpha_extreme = 5.0 * size_factor
    alpha_strong = 4.0 * size_factor
    alpha_medium = 3.0 * size_factor
    _apply_rule(scores, triggered, alpha_abs >= alpha_extreme, 20, 'alpha_extreme')
    _apply_rule(
        scores, triggered,
        (alpha_abs >= alpha_strong) & (alpha_abs < alpha_extreme),
        15, 'alpha_strong',
    )
    _apply_rule(
        scores, triggered,
        (alpha_abs >= alpha_medium) & (alpha_abs < alpha_strong),
        10, 'alpha_medium',
    )

    # ----- Alpha acceleration rules (dynamic thresholds) -----
    delta_strong = 2.0 * size_factor
    delta_medium = 1.5 * size_factor
    _apply_rule(scores, triggered, alpha_delta_abs >= delta_strong, 15, 'alpha_delta_strong')
    _apply_rule(
        scores, triggered,
        (alpha_delta_abs >= delta_medium) & (alpha_delta_abs < delta_strong),
        10, 'alpha_delta_medium',
    )

    # ----- Turnover rules (not size-adjusted: a volume surge is a volume surge) -----
    _apply_rule(scores, triggered, amt_ratio >= 10.0, 20, 'volume_extreme')
    _apply_rule(scores, triggered, (amt_ratio >= 6.0) & (amt_ratio < 10.0), 12, 'volume_strong')

    # ----- Ranking rules -----
    _apply_rule(scores, triggered, rank_pct >= 0.98, 15, 'rank_top')
    _apply_rule(scores, triggered, rank_pct <= 0.02, 15, 'rank_bottom')

    # ----- Limit-up rules (dynamic thresholds) -----
    limit_high = 0.30 * size_factor
    limit_medium = 0.20 * size_factor
    _apply_rule(scores, triggered, limit_up_ratio >= limit_high, 20, 'limit_up_high')
    _apply_rule(
        scores, triggered,
        (limit_up_ratio >= limit_medium) & (limit_up_ratio < limit_high),
        12, 'limit_up_medium',
    )

    # ----- Concept-size bonus (moves in large concepts are more valuable) -----
    # has_signal is evaluated once, before either bonus is added, so both
    # bonuses require that a "real" rule fired.
    has_signal = scores > 0
    _apply_rule(scores, triggered, (stock_count >= 50) & has_signal, 10, 'large_concept_bonus')
    _apply_rule(scores, triggered, (stock_count >= 100) & has_signal, 10, 'xlarge_concept_bonus')

    # ----- Combination rules (dynamic thresholds) -----
    combo_alpha = 3.0 * size_factor
    # Alpha + volume + ranking (triple confirmation).
    _apply_rule(
        scores, triggered,
        (alpha_abs >= combo_alpha)
        & (amt_ratio >= 5.0)
        & ((rank_pct >= 0.95) | (rank_pct <= 0.05)),
        20, 'triple_signal',
    )
    # Alpha + limit-up (strong combination).
    _apply_rule(
        scores, triggered,
        (alpha_abs >= combo_alpha) & (limit_up_ratio >= 0.15 * size_factor),
        15, 'alpha_with_limit',
    )

    # ----- Small-concept penalty (noise filter) -----
    # Tiny concepts (< 5 stocks) with at most one triggered rule get halved.
    # dtype=bool matters: without it an empty input yields a float64 array and
    # the `&` below raises TypeError.
    single_rule = np.array([len(t) <= 1 for t in triggered], dtype=bool)
    penalty_mask = (stock_count < 5) & single_rule & (scores > 0)
    scores[penalty_mask] *= 0.5
    for i in np.where(penalty_mask)[0]:
        triggered[i].append('tiny_concept_penalty')

    scores = np.clip(scores, 0, 100)
    return scores, triggered


# ==================== ML scorer ====================

class FastMLScorer:
    """Batched ML scorer backed by an LSTM autoencoder checkpoint.

    The score is derived from the reconstruction error of the last time step
    of each sequence. If the checkpoint cannot be loaded the scorer stays
    "not ready" and returns all-zero scores.
    """

    def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
        """
        Args:
            checkpoint_dir: directory holding 'best_model.pt' and optionally
                'thresholds.json' and 'config.json'.
            device: 'auto' (CUDA when available, else CPU), or any string
                accepted by torch.device. 'cuda' falls back to CPU with a
                warning when CUDA is unavailable.
        """
        self.checkpoint_dir = Path(checkpoint_dir)

        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        elif device == 'cuda' and not torch.cuda.is_available():
            print("警告: CUDA 不可用,使用 CPU")
            self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)

        self.model = None
        self.thresholds = None
        self._load_model()

    def _load_model(self):
        """Load model weights, model config and score thresholds from disk.

        Failures are reported on stdout and leave `self.model` as None
        (best-effort: the backtest can still run on rule scores alone).
        """
        model_path = self.checkpoint_dir / 'best_model.pt'
        thresholds_path = self.checkpoint_dir / 'thresholds.json'
        config_path = self.checkpoint_dir / 'config.json'

        if not model_path.exists():
            print(f"警告: 模型不存在 {model_path}")
            return

        try:
            # Project-local module, resolved through the sys.path tweak at the
            # top of the file.
            from model import LSTMAutoencoder

            config = {}
            if config_path.exists():
                with open(config_path, encoding='utf-8') as f:
                    config = json.load(f).get('model', {})

            # Translate legacy config keys into what LSTMAutoencoder accepts.
            # NOTE(review): the nesting of this cleanup relative to the
            # config-file check is ambiguous in the collapsed source; it is
            # placed at the `try` level here. Confirm against the original.
            if 'd_model' in config:
                config['hidden_dim'] = config.pop('d_model') // 2
            for key in ['num_encoder_layers', 'num_decoder_layers', 'nhead',
                        'dim_feedforward', 'max_seq_len', 'use_instance_norm']:
                config.pop(key, None)
            if 'num_layers' not in config:
                config['num_layers'] = 1

            checkpoint = torch.load(model_path, map_location='cpu')
            self.model = LSTMAutoencoder(**config)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.to(self.device)
            self.model.eval()

            if thresholds_path.exists():
                with open(thresholds_path, encoding='utf-8') as f:
                    self.thresholds = json.load(f)

            print(f"ML模型加载成功 (设备: {self.device})")
        except Exception as e:
            print(f"ML模型加载失败: {e}")
            self.model = None

    def is_ready(self):
        """Return True when a model has been loaded successfully."""
        return self.model is not None

    @torch.no_grad()
    def score_batch(self, sequences: np.ndarray) -> np.ndarray:
        """
        Score a batch of sequences.

        Args:
            sequences: array of shape (batch, seq_len, n_features).

        Returns:
            (batch,) array of scores in [0, 100]. All zeros when the model is
            not loaded or the batch is empty.
        """
        if not self.is_ready() or len(sequences) == 0:
            return np.zeros(len(sequences))

        x = torch.FloatTensor(sequences).to(self.device)
        output, _ = self.model(x)
        # Per-time-step reconstruction error; only the last step is scored.
        mse = ((output - x) ** 2).mean(dim=-1)
        errors = mse[:, -1].cpu().numpy()

        p95 = self.thresholds.get('p95', 0.1) if self.thresholds else 0.1
        # A missing/zero/negative p95 would turn the division below into
        # inf/NaN scores; fall back to the default scale instead.
        if not p95 or p95 <= 0:
            p95 = 0.1
        # An error equal to p95 maps to a score of 50.
        return np.clip(errors / p95 * 50, 0, 100)


# ==================== Fast backtest ====================

# Columns copied into the per-sequence info frame. The value is the default
# used when the source data does not have that column at all.
# NOTE(review): 'stock_count' defaults to 0 here while score_rules_batch
# assumes 20 when the column is missing; because this function always emits
# the column, data without 'stock_count' is scored as "tiny concept". Confirm
# which default is intended.
_INFO_DEFAULTS = {
    'alpha_delta': 0,
    'amt_ratio': 1,
    'amt_delta': 0,
    'rank_pct': 0.5,
    'limit_up_ratio': 0,
    'stock_count': 0,
    'total_amt': 0,
}
_INFO_COLUMNS = ['concept_id', 'timestamp', 'alpha', *_INFO_DEFAULTS]


def _clean_feature_matrix(frame: pd.DataFrame) -> np.ndarray:
    """Return the FEATURES columns as an array with NaN/inf zeroed and values clipped."""
    matrix = np.nan_to_num(frame[FEATURES].values, nan=0.0, posinf=0.0, neginf=0.0)
    return np.clip(matrix, -CONFIG['clip_value'], CONFIG['clip_value'])


def build_sequences_fast(
    df: pd.DataFrame,
    seq_len: int = 30,
    prev_df: Optional[pd.DataFrame] = None
) -> Tuple[np.ndarray, pd.DataFrame]:
    """
    Build every valid sliding-window sequence for one day.

    Cross-day windows are supported: the last `seq_len - 1` rows of the
    previous day are prepended per concept, so a window ending on the first
    row of the current day can already be formed.

    Args:
        df: current-day data; needs 'concept_id', 'timestamp' and FEATURES.
        seq_len: window length.
        prev_df: previous-day data (optional), same schema as `df`.

    Returns:
        sequences: (n_valid, seq_len, n_features) array; empty 1-D array when
            no window could be built.
        info_df: one row per sequence describing its LAST time step (always a
            current-day row), with the columns listed in _INFO_COLUMNS; an
            empty DataFrame when no window could be built.
    """
    # Sort so that rows of a concept are contiguous and chronological.
    df = df.sort_values(['concept_id', 'timestamp']).reset_index(drop=True)

    # Per-concept tail of the previous day (at most seq_len - 1 rows).
    prev_cache = {}
    if prev_df is not None and len(prev_df) > 0:
        prev_df = prev_df.sort_values(['concept_id', 'timestamp'])
        for concept_id, gdf in prev_df.groupby('concept_id'):
            tail_data = gdf.tail(seq_len - 1)
            if len(tail_data) > 0:
                prev_cache[concept_id] = _clean_feature_matrix(tail_data)

    present_cols = [c for c in _INFO_COLUMNS if c in df.columns]
    step_offsets = np.arange(seq_len)

    seq_parts = []
    info_parts = []
    for concept_id, gdf in df.groupby('concept_id'):
        feat_matrix = _clean_feature_matrix(gdf)

        if concept_id in prev_cache:
            prev_data = prev_cache[concept_id]
            combined = np.vstack([prev_data, feat_matrix])
            offset = len(prev_data)  # number of previous-day rows in front
        else:
            combined = feat_matrix
            offset = 0

        n_windows = len(combined) - seq_len + 1
        if n_windows <= 0:
            continue

        # Row i of window_idx holds the indices i .. i + seq_len - 1, so the
        # fancy index below yields all windows as (n_windows, seq_len, n_feat).
        window_idx = np.arange(n_windows)[:, None] + step_offsets[None, :]
        seq_parts.append(combined[window_idx])

        # Window i ends at combined row i + seq_len - 1, i.e. current-day row
        # i + seq_len - 1 - offset. Since offset <= seq_len - 1 the first
        # window already ends on a current-day row, so no window is dropped.
        first_today_idx = seq_len - 1 - offset
        info_parts.append(gdf.iloc[first_today_idx:][present_cols])

    if not seq_parts:
        return np.array([]), pd.DataFrame()

    info_df = pd.concat(info_parts, ignore_index=True)
    for col, default in _INFO_DEFAULTS.items():
        if col not in info_df.columns:
            info_df[col] = default
    info_df = info_df[_INFO_COLUMNS]

    return np.concatenate(seq_parts, axis=0), info_df


def _apply_cooldown(anomaly_df: pd.DataFrame, cooldown_minutes: float) -> pd.DataFrame:
    """Drop anomalies of a concept that fall within the cooldown of its last kept one."""
    anomaly_df = anomaly_df.sort_values(['concept_id', 'timestamp'])

    last_kept = {}
    keep_mask = []
    for cid, ts in zip(anomaly_df['concept_id'], anomaly_df['timestamp']):
        if cid in last_kept:
            try:
                diff = (ts - last_kept[cid]).total_seconds() / 60
            except (TypeError, AttributeError):
                # Timestamps that are not datetime-like (e.g. strings or
                # ints) cannot be differenced; treat as outside the cooldown.
                diff = cooldown_minutes + 1
            if diff < cooldown_minutes:
                keep_mask.append(False)
                continue
        last_kept[cid] = ts
        keep_mask.append(True)

    return anomaly_df[keep_mask]


def _classify_alert(alpha: float, amt_ratio: float) -> str:
    """Map the last-step alpha / turnover ratio to an alert type label."""
    if alpha >= 1.5:
        return 'surge_up'
    if alpha <= -1.5:
        return 'surge_down'
    if amt_ratio >= 3.0:
        return 'volume_spike'
    return 'unknown'


def _describe_trigger(rule_score: float, ml_score: float, final_score: float, config: Dict) -> str:
    """Build the human-readable trigger reason (rule gate first, then ML, then fusion)."""
    if rule_score >= config['rule_trigger']:
        return f'规则强信号({rule_score:.0f}分)'
    if ml_score >= config['ml_trigger']:
        return f'ML强信号({ml_score:.0f}分)'
    return f'融合触发({final_score:.0f}分)'


def backtest_single_day_fast(
    ml_scorer: FastMLScorer,
    df: pd.DataFrame,
    date: str,
    config: Dict,
    prev_df: Optional[pd.DataFrame] = None
) -> List[Dict]:
    """
    Backtest a single day (vectorized).

    Args:
        ml_scorer: ML scorer; when it is not ready its scores are all zero.
        df: current-day data.
        date: trade date, copied verbatim into each alert's 'trade_date'.
        config: settings dict; see CONFIG for the keys that are read.
        prev_df: previous-day data (optional), enabling detection from 9:30.

    Returns:
        List of alert dicts, grouped by timestamp and capped per timestamp at
        config['max_alerts_per_minute'] by descending final score.
    """
    seq_len = config.get('seq_len', 30)

    # 1. Build all sequences (cross-day aware).
    sequences, info_df = build_sequences_fast(df, seq_len, prev_df)
    if len(sequences) == 0:
        return []

    # 2. Drop rows whose |alpha| is too small to matter.
    valid_mask = np.abs(info_df['alpha'].values) >= config['min_alpha_abs']
    sequences = sequences[valid_mask]
    info_df = info_df[valid_mask].reset_index(drop=True)
    if len(sequences) == 0:
        return []

    # 3. Rule scores for all rows at once.
    rule_scores, triggered_rules = score_rules_batch(info_df)

    # 4. ML scores, chunked to bound device memory use.
    batch_size = 2048
    ml_chunks = [
        ml_scorer.score_batch(sequences[start:start + batch_size])
        for start in range(0, len(sequences), batch_size)
    ]
    ml_scores = np.concatenate(ml_chunks) if ml_chunks else np.zeros(len(sequences))

    # 5. Weighted fusion.
    final_scores = config['rule_weight'] * rule_scores + config['ml_weight'] * ml_scores

    # 6. A row is an anomaly when any of the three gates fires.
    is_anomaly = (
        (rule_scores >= config['rule_trigger'])
        | (ml_scores >= config['ml_trigger'])
        | (final_scores >= config['fusion_trigger'])
    )

    info_df['rule_score'] = rule_scores
    info_df['ml_score'] = ml_scores
    info_df['final_score'] = final_scores
    info_df['is_anomaly'] = is_anomaly
    info_df['triggered_rules'] = triggered_rules

    anomaly_df = info_df[info_df['is_anomaly']].copy()
    if len(anomaly_df) == 0:
        return []

    # 7. Per-concept cooldown.
    anomaly_df = _apply_cooldown(anomaly_df, config['cooldown_minutes'])

    # 8. Cap the number of alerts per timestamp.
    alerts = []
    for _, group in anomaly_df.groupby('timestamp'):
        group = group.nlargest(config['max_alerts_per_minute'], 'final_score')
        for _, row in group.iterrows():
            alpha = row['alpha']
            rule_score = row['rule_score']
            ml_score = row['ml_score']
            final_score = row['final_score']

            alerts.append({
                'concept_id': row['concept_id'],
                'alert_time': row['timestamp'],
                'trade_date': date,
                'alert_type': _classify_alert(alpha, row['amt_ratio']),
                'final_score': final_score,
                'rule_score': rule_score,
                'ml_score': ml_score,
                'trigger_reason': _describe_trigger(rule_score, ml_score, final_score, config),
                'triggered_rules': row['triggered_rules'],
                'alpha': alpha,
                'alpha_delta': row['alpha_delta'],
                'amt_ratio': row['amt_ratio'],
                'amt_delta': row['amt_delta'],
                'rank_pct': row['rank_pct'],
                'limit_up_ratio': row['limit_up_ratio'],
                'stock_count': row['stock_count'],
                'total_amt': row['total_amt'],
            })

    return alerts


# ==================== Data loading ====================

def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
    """Read `features_{date}.parquet` from `data_dir`; return None if the file is missing."""
    file_path = Path(data_dir) / f"features_{date}.parquet"
    if not file_path.exists():
        return None
    return pd.read_parquet(file_path)


def get_available_dates(data_dir: str, start: str, end: str) -> List[str]:
    """List dates with a features file in `data_dir`, sorted, within [start, end].

    Dates are compared as strings, so `start`/`end` must use the same format
    as the date part of the file names.
    """
    dates = []
    for f in sorted(Path(data_dir).glob("features_*.parquet")):
        d = f.stem.replace('features_', '')
        if start <= d <= end:
            dates.append(d)
    return dates


def get_prev_trading_day(data_dir: str, date: str) -> Optional[str]:
    """Return the latest date with a features file strictly before `date`.

    Works whether or not `date` itself has a file. Returns None when no
    earlier date exists. Dates are compared as strings.
    """
    earlier = sorted(
        d
        for d in (f.stem.replace('features_', '') for f in Path(data_dir).glob("features_*.parquet"))
        if d < date
    )
    return earlier[-1] if earlier else None


def export_to_csv(alerts: List[Dict], path: str):
    """Write alerts to `path` as UTF-8-with-BOM CSV; does nothing for an empty list."""
    if alerts:
        pd.DataFrame(alerts).to_csv(path, index=False, encoding='utf-8-sig')
        print(f"已导出: {path}")


# ==================== Database writes ====================

def init_db_table():
    """
    Create the `concept_anomaly_hybrid` table if it does not exist.

    Columns:
    - concept_id: concept ID
    - alert_time: anomaly time
    - trade_date: trade date
    - alert_type: surge_up / surge_down / volume_spike / unknown
    - final_score / rule_score / ml_score: fused, rule and ML scores
    - trigger_reason: human-readable trigger reason
    - alpha, alpha_delta, amt_ratio, amt_delta, rank_pct, limit_up_ratio,
      stock_count, total_amt: feature values at the alert's time step
    - triggered_rules: JSON list of triggered rule names
    """
    create_sql = text("""
        CREATE TABLE IF NOT EXISTS concept_anomaly_hybrid (
            id INT AUTO_INCREMENT PRIMARY KEY,
            concept_id VARCHAR(64) NOT NULL,
            alert_time DATETIME NOT NULL,
            trade_date DATE NOT NULL,
            alert_type VARCHAR(32) NOT NULL,
            final_score FLOAT NOT NULL,
            rule_score FLOAT NOT NULL,
            ml_score FLOAT NOT NULL,
            trigger_reason VARCHAR(64),
            alpha FLOAT,
            alpha_delta FLOAT,
            amt_ratio FLOAT,
            amt_delta FLOAT,
            rank_pct FLOAT,
            limit_up_ratio FLOAT,
            stock_count INT,
            total_amt FLOAT,
            triggered_rules JSON,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
            INDEX idx_trade_date (trade_date),
            INDEX idx_concept_id (concept_id),
            INDEX idx_final_score (final_score),
            INDEX idx_alert_type (alert_type)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动检测结果(融合版)'
    """)

    with MYSQL_ENGINE.begin() as conn:
        conn.execute(create_sql)
    print("数据库表已就绪: concept_anomaly_hybrid")


def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """
    Save alerts to MySQL.

    Rows that collide with the table's unique key are skipped via
    INSERT IGNORE. A failure on one alert is printed and does not stop the
    remaining inserts.

    Args:
        alerts: list of alert dicts as produced by backtest_single_day_fast.
        dry_run: when True nothing is written and len(alerts) is returned.

    Returns:
        Number of rows actually inserted (or len(alerts) for a dry run).
    """
    if not alerts:
        return 0

    if dry_run:
        print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
        return len(alerts)

    # Built once: the statement is identical for every alert.
    insert_sql = text("""
        INSERT IGNORE INTO concept_anomaly_hybrid
            (concept_id, alert_time, trade_date, alert_type,
             final_score, rule_score, ml_score, trigger_reason,
             alpha, alpha_delta, amt_ratio, amt_delta,
             rank_pct, limit_up_ratio, stock_count, total_amt,
             triggered_rules)
        VALUES
            (:concept_id, :alert_time, :trade_date, :alert_type,
             :final_score, :rule_score, :ml_score, :trigger_reason,
             :alpha, :alpha_delta, :amt_ratio, :amt_delta,
             :rank_pct, :limit_up_ratio, :stock_count, :total_amt,
             :triggered_rules)
    """)

    saved = 0
    skipped = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                result = conn.execute(insert_sql, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                    'alert_type': alert['alert_type'],
                    'final_score': alert['final_score'],
                    'rule_score': alert['rule_score'],
                    'ml_score': alert['ml_score'],
                    'trigger_reason': alert['trigger_reason'],
                    'alpha': alert.get('alpha', 0),
                    'alpha_delta': alert.get('alpha_delta', 0),
                    'amt_ratio': alert.get('amt_ratio', 1),
                    'amt_delta': alert.get('amt_delta', 0),
                    'rank_pct': alert.get('rank_pct', 0.5),
                    'limit_up_ratio': alert.get('limit_up_ratio', 0),
                    'stock_count': alert.get('stock_count', 0),
                    'total_amt': alert.get('total_amt', 0),
                    'triggered_rules': json.dumps(
                        alert.get('triggered_rules', []), ensure_ascii=False
                    ),
                })
                # INSERT IGNORE reports rowcount 0 for a duplicate.
                if result.rowcount > 0:
                    saved += 1
                else:
                    skipped += 1
            except Exception as e:
                # Deliberate best effort: report and continue with the rest.
                print(f" 保存失败: {alert['concept_id']} @ {alert['alert_time']} - {e}")

    if skipped > 0:
        print(f" 跳过 {skipped} 条重复记录")

    return saved


def clear_alerts_by_date(trade_date: str) -> int:
    """Delete all stored alerts of `trade_date` (used before re-running a backtest).

    Returns:
        Number of rows deleted.
    """
    with MYSQL_ENGINE.begin() as conn:
        result = conn.execute(
            text("DELETE FROM concept_anomaly_hybrid WHERE trade_date = :trade_date"),
            {'trade_date': trade_date}
        )
        return result.rowcount


def analyze_alerts(alerts: List[Dict]):
    """Print summary statistics (counts, type distribution, scores, trigger source)."""
    if not alerts:
        print("无异动")
        return

    df = pd.DataFrame(alerts)
    print(f"\n总异动: {len(alerts)}")
    print(f"\n类型分布:\n{df['alert_type'].value_counts()}")

    print("\n得分统计:")
    print(f" 最终: {df['final_score'].mean():.1f} (max: {df['final_score'].max():.1f})")
    print(f" 规则: {df['rule_score'].mean():.1f} (max: {df['rule_score'].max():.1f})")
    print(f" ML: {df['ml_score'].mean():.1f} (max: {df['ml_score'].max():.1f})")

    # The trigger source is recovered from the text of the trigger reason.
    trigger_type = df['trigger_reason'].apply(
        lambda x: '规则' if '规则' in x else ('ML' if 'ML' in x else '融合')
    )
    print(f"\n触发来源:\n{trigger_type.value_counts()}")


# ==================== Entry point ====================

def main():
    """Parse CLI arguments, backtest every available date and report/save the alerts."""
    parser = argparse.ArgumentParser(description='快速融合异动回测')
    parser.add_argument('--data_dir', default='ml/data')
    parser.add_argument('--checkpoint_dir', default='ml/checkpoints')
    parser.add_argument('--start', required=True)
    parser.add_argument('--end', default=None)
    parser.add_argument('--dry-run', action='store_true', help='模拟运行,不写入数据库')
    parser.add_argument('--export-csv', default=None, help='导出 CSV 文件路径')
    parser.add_argument('--save-db', action='store_true', help='保存结果到数据库')
    parser.add_argument('--clear-first', action='store_true', help='写入前先清除该日期的旧数据')
    parser.add_argument('--device', default='auto')
    args = parser.parse_args()

    if args.end is None:
        args.end = args.start

    print("=" * 60)
    print("快速融合异动回测")
    print("=" * 60)
    print(f"日期: {args.start} ~ {args.end}")
    print(f"设备: {args.device}")
    print(f"保存数据库: {args.save_db}")
    print("=" * 60)

    # The table is only needed when results are really written.
    if args.save_db and not args.dry_run:
        init_db_table()

    ml_scorer = FastMLScorer(args.checkpoint_dir, args.device)

    dates = get_available_dates(args.data_dir, args.start, args.end)
    if not dates:
        print("无数据")
        return

    print(f"找到 {len(dates)} 天数据\n")

    all_alerts = []
    total_saved = 0
    prev_df = None  # previous day's data, reused for cross-day windows

    for i, date in enumerate(tqdm(dates, desc="回测")):
        df = load_daily_features(args.data_dir, date)
        if df is None or df.empty:
            prev_df = None  # no data today: do not bridge over the gap
            continue

        # Only the first date needs its predecessor loaded from disk; later
        # dates reuse the frame kept from the previous iteration.
        if i == 0 and prev_df is None:
            prev_date = get_prev_trading_day(args.data_dir, date)
            if prev_date:
                prev_df = load_daily_features(args.data_dir, prev_date)
                if prev_df is not None:
                    tqdm.write(f" 加载前一天数据: {prev_date}")

        alerts = backtest_single_day_fast(ml_scorer, df, date, CONFIG, prev_df)
        all_alerts.extend(alerts)

        if args.save_db:
            # Clear even when this run produced no alerts; otherwise a
            # re-backtest would leave the date's stale rows in place.
            if args.clear_first and not args.dry_run:
                cleared = clear_alerts_by_date(date)
                if cleared > 0:
                    tqdm.write(f" 清除 {date} 旧数据: {cleared} 条")

            if alerts:
                saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
                total_saved += saved
                tqdm.write(f" {date}: {len(alerts)} 个异动, 保存 {saved} 条")
        elif alerts:
            tqdm.write(f" {date}: {len(alerts)} 个异动")

        # Today's data becomes the next date's previous-day data.
        prev_df = df

    if args.export_csv:
        export_to_csv(all_alerts, args.export_csv)

    analyze_alerts(all_alerts)

    print(f"\n总计: {len(all_alerts)} 个异动")
    if args.save_db:
        print(f"已保存到数据库: {total_saved} 条")


if __name__ == "__main__":
    main()