#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Transformer Autoencoder 训练脚本 (修复版) 修复问题: 1. 按概念分组构建序列,避免跨概念切片 2. 按时间(日期)切分数据集,避免数据泄露 3. 使用 RobustScaler + Clipping 处理非平稳性 4. 使用验证集计算阈值 训练流程: 1. 加载预处理好的特征数据(parquet 文件) 2. 按概念分组,在每个概念内部构建序列 3. 按日期划分训练/验证/测试集 4. 训练 Autoencoder(最小化重构误差) 5. 保存模型和阈值 使用方法: python train.py --data_dir ml/data --epochs 100 --batch_size 256 """ import os import sys import argparse import json from datetime import datetime from pathlib import Path from typing import List, Tuple, Dict import numpy as np import pandas as pd import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts from tqdm import tqdm from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters # 性能优化:启用 cuDNN benchmark(对固定输入尺寸自动选择最快算法) torch.backends.cudnn.benchmark = True # 启用 TF32(RTX 30/40 系列特有,提速约 3 倍) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True # 可视化(可选) try: import matplotlib matplotlib.use('Agg') # 无头模式,不需要显示器 import matplotlib.pyplot as plt HAS_MATPLOTLIB = True except ImportError: HAS_MATPLOTLIB = False # ==================== 配置 ==================== TRAIN_CONFIG = { # 数据配置 'seq_len': 30, # 输入序列长度(30分钟) 'stride': 5, # 滑动窗口步长 # 时间切分(按日期) 'train_end_date': '2024-06-30', # 训练集截止日期 'val_end_date': '2024-09-30', # 验证集截止日期(之后为测试集) # 特征配置 'features': [ 'alpha', # 超额收益 'alpha_delta', # Alpha 变化率 'amt_ratio', # 成交额比率 'amt_delta', # 成交额变化率 'rank_pct', # Alpha 排名百分位 'limit_up_ratio', # 涨停比例 ], # 训练配置(针对 4x RTX 4090 优化) 'batch_size': 4096, # 256 -> 4096(大幅增加,充分利用显存) 'epochs': 100, 'learning_rate': 3e-4, # 1e-4 -> 3e-4(大 batch 需要更大学习率) 'weight_decay': 1e-5, 'gradient_clip': 1.0, # 早停配置 'patience': 10, 'min_delta': 1e-6, # 模型配置(LSTM Autoencoder,简洁有效) 'model': { 'n_features': 6, 'hidden_dim': 32, # LSTM 隐藏维度(小) 'latent_dim': 4, # 瓶颈维度(非常小!关键) 'num_layers': 1, # LSTM 层数 'dropout': 0.2, 'bidirectional': True, # 双向编码器 }, # 标准化配置 'use_instance_norm': True, # 模型内部使用 Instance Norm(推荐) 'clip_value': 10.0, # 简单截断极端值 # 阈值配置 'threshold_percentiles': [90, 95, 99], } # ==================== 数据加载(修复版)==================== def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]: """ 按日期加载数据,返回 {date: DataFrame} 字典 每个 DataFrame 包含该日所有概念的所有时间点数据 """ data_path = Path(data_dir) parquet_files = sorted(data_path.glob("features_*.parquet")) if not parquet_files: raise FileNotFoundError(f"未找到 parquet 文件: {data_dir}") print(f"找到 {len(parquet_files)} 个数据文件") date_data = {} for pf in tqdm(parquet_files, desc="加载数据"): # 提取日期 date = pf.stem.replace('features_', '') df = pd.read_parquet(pf) # 检查必要列 required_cols = features + ['concept_id', 'timestamp'] missing_cols = [c for c in required_cols if c not in df.columns] if missing_cols: print(f"警告: {date} 缺少列: {missing_cols}, 跳过") continue date_data[date] = df print(f"成功加载 {len(date_data)} 天的数据") return date_data def split_data_by_date( date_data: Dict[str, pd.DataFrame], train_end: str, val_end: str ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]: """ 按日期严格划分数据集 - 训练集: <= train_end - 验证集: train_end < date <= val_end - 测试集: > val_end """ train_data = {} val_data = {} test_data = {} for date, df in date_data.items(): if date <= train_end: train_data[date] = df elif date <= val_end: val_data[date] = df else: test_data[date] = df print(f"数据集划分(按日期):") print(f" 训练集: {len(train_data)} 天 (<= {train_end})") print(f" 验证集: {len(val_data)} 天 ({train_end} ~ {val_end})") print(f" 测试集: {len(test_data)} 天 (> {val_end})") return train_data, val_data, test_data def build_sequences_by_concept( date_data: Dict[str, pd.DataFrame], features: List[str], seq_len: int, stride: int ) -> np.ndarray: """ 按概念分组构建序列(性能优化版) 使用 groupby 一次性分组,避免重复扫描大数组 1. 将所有日期的数据合并 2. 使用 groupby 按 concept_id 分组 3. 在每个概念内部,按时间排序并滑动窗口 4. 合并所有序列 """ # 合并所有日期的数据 all_dfs = [] for date, df in sorted(date_data.items()): df = df.copy() df['date'] = date all_dfs.append(df) if not all_dfs: return np.array([]) combined = pd.concat(all_dfs, ignore_index=True) # 预先排序(按概念、日期、时间),这样 groupby 会更快 combined = combined.sort_values(['concept_id', 'date', 'timestamp']) # 使用 groupby 一次性分组(性能关键!) all_sequences = [] grouped = combined.groupby('concept_id', sort=False) n_concepts = len(grouped) for concept_id, concept_df in tqdm(grouped, desc="构建序列", total=n_concepts, leave=False): # 已经排序过了,直接提取特征 feature_data = concept_df[features].values # 处理缺失值 feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0) # 在该概念内部滑动窗口 n_points = len(feature_data) for start in range(0, n_points - seq_len + 1, stride): seq = feature_data[start:start + seq_len] all_sequences.append(seq) if not all_sequences: return np.array([]) sequences = np.array(all_sequences) print(f" 构建序列: {len(sequences):,} 条 (来自 {n_concepts} 个概念)") return sequences # ==================== 数据集 ==================== class SequenceDataset(Dataset): """序列数据集(已经构建好的序列)""" def __init__(self, sequences: np.ndarray): self.sequences = torch.FloatTensor(sequences) def __len__(self) -> int: return len(self.sequences) def __getitem__(self, idx: int) -> torch.Tensor: return self.sequences[idx] # ==================== 训练器 ==================== class EarlyStopping: """早停机制""" def __init__(self, patience: int = 10, min_delta: float = 1e-6): self.patience = patience self.min_delta = min_delta self.counter = 0 self.best_loss = float('inf') self.early_stop = False def __call__(self, val_loss: float) -> bool: if val_loss < self.best_loss - self.min_delta: self.best_loss = val_loss self.counter = 0 else: self.counter += 1 if self.counter >= self.patience: self.early_stop = True return self.early_stop class Trainer: """模型训练器(支持 AMP 混合精度加速)""" def __init__( self, model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, config: Dict, device: torch.device, save_dir: str = 'ml/checkpoints' ): self.model = model.to(device) self.train_loader = train_loader self.val_loader = val_loader self.config = config self.device = device self.save_dir = Path(save_dir) self.save_dir.mkdir(parents=True, exist_ok=True) # 优化器 self.optimizer = AdamW( model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'] ) # 学习率调度器 self.scheduler = CosineAnnealingWarmRestarts( self.optimizer, T_0=10, T_mult=2, eta_min=1e-6 ) # 损失函数(简化版,只用 MSE) self.criterion = AnomalyDetectionLoss() # 早停 self.early_stopping = EarlyStopping( patience=config['patience'], min_delta=config['min_delta'] ) # AMP 混合精度训练(大幅提速 + 省显存) self.use_amp = torch.cuda.is_available() self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None if self.use_amp: print(" ✓ 启用 AMP 混合精度训练") # 训练历史 self.history = { 'train_loss': [], 'val_loss': [], 'learning_rate': [], } self.best_val_loss = float('inf') def train_epoch(self) -> float: """训练一个 epoch(使用 AMP 混合精度)""" self.model.train() total_loss = 0.0 n_batches = 0 pbar = tqdm(self.train_loader, desc="Training", leave=False) for batch in pbar: batch = batch.to(self.device, non_blocking=True) # 异步传输 self.optimizer.zero_grad(set_to_none=True) # 更快的梯度清零 # AMP 混合精度前向传播 if self.use_amp: with torch.cuda.amp.autocast(): output, latent = self.model(batch) loss, loss_dict = self.criterion(output, batch, latent) # AMP 反向传播 self.scaler.scale(loss).backward() # 梯度裁剪(需要 unscale) self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.config['gradient_clip'] ) self.scaler.step(self.optimizer) self.scaler.update() else: # 非 AMP 模式 output, latent = self.model(batch) loss, loss_dict = self.criterion(output, batch, latent) loss.backward() torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.config['gradient_clip'] ) self.optimizer.step() total_loss += loss.item() n_batches += 1 pbar.set_postfix({'loss': f"{loss.item():.4f}"}) return total_loss / n_batches @torch.no_grad() def validate(self) -> float: """验证(使用 AMP)""" self.model.eval() total_loss = 0.0 n_batches = 0 for batch in self.val_loader: batch = batch.to(self.device, non_blocking=True) if self.use_amp: with torch.cuda.amp.autocast(): output, latent = self.model(batch) loss, _ = self.criterion(output, batch, latent) else: output, latent = self.model(batch) loss, _ = self.criterion(output, batch, latent) total_loss += loss.item() n_batches += 1 return total_loss / n_batches def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False): """保存检查点""" # 处理 DataParallel 包装 model_to_save = self.model.module if hasattr(self.model, 'module') else self.model checkpoint = { 'epoch': epoch, 'model_state_dict': model_to_save.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': self.scheduler.state_dict(), 'val_loss': val_loss, 'config': self.config, } # 保存最新检查点 torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt') # 保存最佳模型 if is_best: torch.save(checkpoint, self.save_dir / 'best_model.pt') print(f" ✓ 保存最佳模型 (val_loss: {val_loss:.6f})") def train(self, epochs: int): """完整训练流程""" print(f"\n开始训练 ({epochs} epochs)...") print(f"设备: {self.device}") print(f"模型参数量: {count_parameters(self.model):,}") for epoch in range(1, epochs + 1): print(f"\nEpoch {epoch}/{epochs}") # 训练 train_loss = self.train_epoch() # 验证 val_loss = self.validate() # 更新学习率 self.scheduler.step() current_lr = self.optimizer.param_groups[0]['lr'] # 记录历史 self.history['train_loss'].append(train_loss) self.history['val_loss'].append(val_loss) self.history['learning_rate'].append(current_lr) # 打印进度 print(f" Train Loss: {train_loss:.6f}") print(f" Val Loss: {val_loss:.6f}") print(f" LR: {current_lr:.2e}") # 保存检查点 is_best = val_loss < self.best_val_loss if is_best: self.best_val_loss = val_loss self.save_checkpoint(epoch, val_loss, is_best) # 早停检查 if self.early_stopping(val_loss): print(f"\n早停触发!验证损失已 {self.early_stopping.patience} 个 epoch 未改善") break print(f"\n训练完成!最佳验证损失: {self.best_val_loss:.6f}") # 保存训练历史 self.save_history() return self.history def save_history(self): """保存训练历史""" history_path = self.save_dir / 'training_history.json' with open(history_path, 'w') as f: json.dump(self.history, f, indent=2) print(f"训练历史已保存: {history_path}") # 绘制训练曲线 self.plot_training_curves() def plot_training_curves(self): """绘制训练曲线""" if not HAS_MATPLOTLIB: print("matplotlib 未安装,跳过绘图") return fig, axes = plt.subplots(1, 2, figsize=(14, 5)) epochs = range(1, len(self.history['train_loss']) + 1) # 1. Loss 曲线 ax1 = axes[0] ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2) ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2) ax1.set_xlabel('Epoch', fontsize=12) ax1.set_ylabel('Loss', fontsize=12) ax1.set_title('Training & Validation Loss', fontsize=14) ax1.legend(fontsize=11) ax1.grid(True, alpha=0.3) # 标记最佳点 best_epoch = np.argmin(self.history['val_loss']) + 1 best_val_loss = min(self.history['val_loss']) ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}') ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5) ax1.annotate(f'Best: {best_val_loss:.6f}', xy=(best_epoch, best_val_loss), xytext=(best_epoch + 2, best_val_loss + 0.0005), fontsize=10, color='green') # 2. 学习率曲线 ax2 = axes[1] ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2) ax2.set_xlabel('Epoch', fontsize=12) ax2.set_ylabel('Learning Rate', fontsize=12) ax2.set_title('Learning Rate Schedule', fontsize=14) ax2.set_yscale('log') ax2.grid(True, alpha=0.3) plt.tight_layout() # 保存图片 plot_path = self.save_dir / 'training_curves.png' plt.savefig(plot_path, dpi=150, bbox_inches='tight') plt.close() print(f"训练曲线已保存: {plot_path}") # ==================== 阈值计算(使用验证集)==================== @torch.no_grad() def compute_thresholds( model: nn.Module, data_loader: DataLoader, device: torch.device, percentiles: List[float] = [90, 95, 99] ) -> Dict[str, float]: """ 在验证集上计算重构误差的百分位数阈值 注:使用验证集而非测试集,避免数据泄露 """ model.eval() all_errors = [] print("计算异动阈值(使用验证集)...") for batch in tqdm(data_loader, desc="Computing thresholds"): batch = batch.to(device) errors = model.compute_reconstruction_error(batch, reduction='none') # 取每个序列的最后一个时刻误差(预测当前时刻) seq_errors = errors[:, -1] # (batch,) all_errors.append(seq_errors.cpu().numpy()) all_errors = np.concatenate(all_errors) thresholds = {} for p in percentiles: threshold = np.percentile(all_errors, p) thresholds[f'p{p}'] = float(threshold) print(f" P{p}: {threshold:.6f}") # 额外统计 thresholds['mean'] = float(np.mean(all_errors)) thresholds['std'] = float(np.std(all_errors)) thresholds['median'] = float(np.median(all_errors)) print(f" Mean: {thresholds['mean']:.6f}") print(f" Median: {thresholds['median']:.6f}") print(f" Std: {thresholds['std']:.6f}") return thresholds # ==================== 主函数 ==================== def main(): parser = argparse.ArgumentParser(description='训练概念异动检测模型') parser.add_argument('--data_dir', type=str, default='ml/data', help='数据目录路径') parser.add_argument('--epochs', type=int, default=100, help='训练轮数') parser.add_argument('--batch_size', type=int, default=4096, help='批次大小(4x RTX 4090 推荐 4096~8192)') parser.add_argument('--lr', type=float, default=3e-4, help='学习率(大 batch 推荐 3e-4)') parser.add_argument('--device', type=str, default='auto', help='设备 (auto/cuda/cpu)') parser.add_argument('--save_dir', type=str, default='ml/checkpoints', help='模型保存目录') parser.add_argument('--train_end', type=str, default='2024-06-30', help='训练集截止日期') parser.add_argument('--val_end', type=str, default='2024-09-30', help='验证集截止日期') args = parser.parse_args() # 更新配置 config = TRAIN_CONFIG.copy() config['batch_size'] = args.batch_size config['epochs'] = args.epochs config['learning_rate'] = args.lr config['train_end_date'] = args.train_end config['val_end_date'] = args.val_end # 设备选择 if args.device == 'auto': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: device = torch.device(args.device) print("=" * 60) print("概念异动检测模型训练(修复版)") print("=" * 60) print(f"配置:") print(f" 数据目录: {args.data_dir}") print(f" 设备: {device}") print(f" 批次大小: {config['batch_size']}") print(f" 学习率: {config['learning_rate']}") print(f" 训练轮数: {config['epochs']}") print(f" 训练集截止: {config['train_end_date']}") print(f" 验证集截止: {config['val_end_date']}") print("=" * 60) # 1. 按日期加载数据 print("\n[1/6] 加载数据...") date_data = load_data_by_date(args.data_dir, config['features']) # 2. 按日期划分 print("\n[2/6] 按日期划分数据集...") train_data, val_data, test_data = split_data_by_date( date_data, config['train_end_date'], config['val_end_date'] ) # 3. 按概念构建序列 print("\n[3/6] 按概念构建序列...") print("训练集:") train_sequences = build_sequences_by_concept( train_data, config['features'], config['seq_len'], config['stride'] ) print("验证集:") val_sequences = build_sequences_by_concept( val_data, config['features'], config['seq_len'], config['stride'] ) print("测试集:") test_sequences = build_sequences_by_concept( test_data, config['features'], config['seq_len'], config['stride'] ) if len(train_sequences) == 0: print("错误: 训练集为空!请检查数据和日期范围") return # 4. 数据预处理(简单截断极端值,标准化在模型内部通过 Instance Norm 完成) print("\n[4/6] 数据预处理...") print(" 注意: 使用 Instance Norm,每个序列在模型内部单独标准化") print(" 这样可以处理不同概念波动率差异(银行 vs 半导体)") clip_value = config['clip_value'] print(f" 截断极端值: ±{clip_value}") # 简单截断极端值(防止异常数据影响训练) train_sequences = np.clip(train_sequences, -clip_value, clip_value) if len(val_sequences) > 0: val_sequences = np.clip(val_sequences, -clip_value, clip_value) if len(test_sequences) > 0: test_sequences = np.clip(test_sequences, -clip_value, clip_value) # 保存配置 save_dir = Path(args.save_dir) save_dir.mkdir(parents=True, exist_ok=True) preprocess_params = { 'features': config['features'], 'normalization': 'instance_norm', # 在模型内部完成 'clip_value': clip_value, 'note': '标准化在模型内部通过 InstanceNorm1d 完成,无需外部 Scaler' } with open(save_dir / 'normalization_stats.json', 'w') as f: json.dump(preprocess_params, f, indent=2) print(f" 预处理参数已保存") # 5. 创建数据集和加载器 print("\n[5/6] 创建数据加载器...") train_dataset = SequenceDataset(train_sequences) val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None test_dataset = SequenceDataset(test_sequences) if len(test_sequences) > 0 else None print(f" 训练序列: {len(train_dataset):,}") print(f" 验证序列: {len(val_dataset) if val_dataset else 0:,}") print(f" 测试序列: {len(test_dataset) if test_dataset else 0:,}") # 多卡时增加 num_workers(Linux 上可以用更多) n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1 num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0 print(f" DataLoader workers: {num_workers}") print(f" Batch size: {config['batch_size']}") # 大 batch + 多 worker + prefetch 提速 train_loader = DataLoader( train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True, prefetch_factor=4 if num_workers > 0 else None, # 预取更多 batch persistent_workers=True if num_workers > 0 else False, # 保持 worker 存活 drop_last=True # 丢弃不完整的最后一批,避免 batch 大小不一致 ) val_loader = DataLoader( val_dataset, batch_size=config['batch_size'] * 2, # 验证时可以用更大 batch(无梯度) shuffle=False, num_workers=num_workers, pin_memory=True, prefetch_factor=4 if num_workers > 0 else None, persistent_workers=True if num_workers > 0 else False, ) if val_dataset else None test_loader = DataLoader( test_dataset, batch_size=config['batch_size'] * 2, shuffle=False, num_workers=num_workers, pin_memory=True, prefetch_factor=4 if num_workers > 0 else None, persistent_workers=True if num_workers > 0 else False, ) if test_dataset else None # 6. 训练 print("\n[6/6] 训练模型...") model_config = config['model'].copy() model = TransformerAutoencoder(**model_config) # 多卡并行 if torch.cuda.device_count() > 1: print(f" 使用 {torch.cuda.device_count()} 张 GPU 并行训练") model = nn.DataParallel(model) if val_loader is None: print("警告: 验证集为空,将使用训练集的一部分作为验证") # 简单处理:用训练集的后 10% 作为验证 split_idx = int(len(train_dataset) * 0.9) train_subset = torch.utils.data.Subset(train_dataset, range(split_idx)) val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset))) train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True) val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True) trainer = Trainer( model=model, train_loader=train_loader, val_loader=val_loader, config=config, device=device, save_dir=args.save_dir ) history = trainer.train(config['epochs']) # 7. 计算阈值(使用验证集) print("\n[额外] 计算异动阈值...") # 加载最佳模型 best_checkpoint = torch.load( save_dir / 'best_model.pt', map_location=device ) model.load_state_dict(best_checkpoint['model_state_dict']) model.to(device) # 使用验证集计算阈值(避免数据泄露) thresholds = compute_thresholds( model, val_loader, device, config['threshold_percentiles'] ) # 保存阈值 with open(save_dir / 'thresholds.json', 'w') as f: json.dump(thresholds, f, indent=2) print(f"阈值已保存") # 保存完整配置 with open(save_dir / 'config.json', 'w') as f: json.dump(config, f, indent=2) print("\n" + "=" * 60) print("训练完成!") print("=" * 60) print(f"模型保存位置: {args.save_dir}") print(f" - best_model.pt: 最佳模型权重") print(f" - thresholds.json: 异动阈值") print(f" - normalization_stats.json: 标准化参数") print(f" - config.json: 训练配置") print("=" * 60) if __name__ == "__main__": main()