#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Transformer Autoencoder training script (fixed version)

Fixes:
1. Build sequences grouped by concept, so windows never cross concept boundaries
2. Split the dataset by date (time-based) to avoid data leakage
3. Handle non-stationarity with value clipping plus in-model instance normalization
4. Compute anomaly thresholds on the validation set

Training pipeline:
1. Load preprocessed feature data (parquet files)
2. Group by concept and build sliding-window sequences within each concept
3. Split into train/validation/test sets by date
4. Train the autoencoder (minimize reconstruction error)
5. Save the model and thresholds

Usage:
    python train.py --data_dir ml/data --epochs 100 --batch_size 256
"""

import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm

from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters

# Performance: enable cuDNN benchmark (auto-selects the fastest kernels for fixed input sizes)
torch.backends.cudnn.benchmark = True
# Enable TF32 (available on Ampere and newer GPUs, e.g. RTX 30/40 series; large matmul speedup)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Visualization (optional)
try:
    import matplotlib
    matplotlib.use('Agg')  # headless backend, no display required
    import matplotlib.pyplot as plt
    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False


# ==================== Configuration ====================

TRAIN_CONFIG = {
    # Data
    'seq_len': 30,              # input sequence length (30 minutes)
    'stride': 5,                # sliding-window step

    # Time-based split (by date)
    'train_end_date': '2024-06-30',  # last date included in the training set
    'val_end_date': '2024-09-30',    # last date included in the validation set (later dates form the test set)

    # Features
    'features': [
        'alpha',            # excess return
        'alpha_delta',      # change in alpha
        'amt_ratio',        # turnover ratio
        'amt_delta',        # change in turnover
        'rank_pct',         # alpha rank percentile
        'limit_up_ratio',   # fraction of limit-up stocks
    ],

    # Training (tuned for 4x RTX 4090)
    'batch_size': 4096,         # 256 -> 4096 (much larger, to make full use of GPU memory)
    'epochs': 100,
    'learning_rate': 3e-4,      # 1e-4 -> 3e-4 (larger batches call for a larger learning rate)
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,

    # Early stopping
    'patience': 10,
    'min_delta': 1e-6,

    # Model (LSTM autoencoder: simple and effective)
    'model': {
        'n_features': 6,
        'hidden_dim': 32,       # LSTM hidden size (small)
        'latent_dim': 4,        # bottleneck size (very small -- this is the key constraint)
        'num_layers': 1,        # number of LSTM layers
        'dropout': 0.2,
        'bidirectional': True,  # bidirectional encoder
    },

    # Normalization
    'use_instance_norm': True,  # normalize inside the model with instance norm (recommended)
    'clip_value': 10.0,         # simple clipping of extreme values

    # Thresholds
    'threshold_percentiles': [90, 95, 99],
}


# ==================== Data loading (fixed) ====================

def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
    """
    Load data grouped by date and return a {date: DataFrame} dict.

    Each DataFrame holds every concept and every timestamp for that day.
    """
    data_path = Path(data_dir)
    parquet_files = sorted(data_path.glob("features_*.parquet"))

    if not parquet_files:
        raise FileNotFoundError(f"No parquet files found in: {data_dir}")

    print(f"Found {len(parquet_files)} data files")

    date_data = {}

    for pf in tqdm(parquet_files, desc="Loading data"):
        # Extract the date from the file name
        date = pf.stem.replace('features_', '')

        df = pd.read_parquet(pf)

        # Check required columns
        required_cols = features + ['concept_id', 'timestamp']
        missing_cols = [c for c in required_cols if c not in df.columns]
        if missing_cols:
            print(f"Warning: {date} is missing columns {missing_cols}, skipping")
            continue

        date_data[date] = df

    print(f"Loaded {len(date_data)} days of data")
    return date_data
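

# Illustrative sketch (not part of the original pipeline): load_data_by_date()
# expects one file per trading day named features_YYYY-MM-DD.parquet, with one row
# per (concept_id, timestamp) plus the six feature columns from
# TRAIN_CONFIG['features']. The concept id and values below are made up.
#
#     demo = pd.DataFrame({
#         'concept_id':     ['BK0001', 'BK0001', 'BK0001'],
#         'timestamp':      ['09:31:00', '09:32:00', '09:33:00'],
#         'alpha':          [0.001, -0.002, 0.004],
#         'alpha_delta':    [0.000, -0.003, 0.006],
#         'amt_ratio':      [1.20, 0.90, 1.50],
#         'amt_delta':      [0.10, -0.30, 0.60],
#         'rank_pct':       [0.55, 0.40, 0.72],
#         'limit_up_ratio': [0.00, 0.00, 0.10],
#     })
#     demo.to_parquet('ml/data/features_2024-01-02.parquet')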


def split_data_by_date(
    date_data: Dict[str, pd.DataFrame],
    train_end: str,
    val_end: str
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """
    Split the dataset strictly by date.

    - train: date <= train_end
    - val:   train_end < date <= val_end
    - test:  date > val_end
    """
    train_data = {}
    val_data = {}
    test_data = {}

    for date, df in date_data.items():
        if date <= train_end:
            train_data[date] = df
        elif date <= val_end:
            val_data[date] = df
        else:
            test_data[date] = df

    print("Dataset split (by date):")
    print(f"  train: {len(train_data)} days (<= {train_end})")
    print(f"  val:   {len(val_data)} days ({train_end} ~ {val_end})")
    print(f"  test:  {len(test_data)} days (> {val_end})")

    return train_data, val_data, test_data


def build_sequences_by_concept(
    date_data: Dict[str, pd.DataFrame],
    features: List[str],
    seq_len: int,
    stride: int
) -> np.ndarray:
    """
    Build sequences grouped by concept (performance-optimized version).

    A single groupby pass avoids repeatedly scanning the large combined array.

    1. Concatenate the data for all dates
    2. Group by concept_id
    3. Within each concept, sort by time and slide a window
    4. Stack all windows into one array
    """
    # Concatenate all dates
    all_dfs = []
    for date, df in sorted(date_data.items()):
        df = df.copy()
        df['date'] = date
        all_dfs.append(df)

    if not all_dfs:
        return np.array([])

    combined = pd.concat(all_dfs, ignore_index=True)

    # Sort once (by concept, date, time) so the groupby below is cheap
    combined = combined.sort_values(['concept_id', 'date', 'timestamp'])

    # Group once by concept_id (this is the performance-critical step)
    all_sequences = []
    grouped = combined.groupby('concept_id', sort=False)
    n_concepts = len(grouped)

    for concept_id, concept_df in tqdm(grouped, desc="Building sequences", total=n_concepts, leave=False):
        # Already sorted above; just pull out the feature matrix
        feature_data = concept_df[features].values

        # Handle missing values
        feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0)

        # Slide a window within this concept
        n_points = len(feature_data)
        for start in range(0, n_points - seq_len + 1, stride):
            seq = feature_data[start:start + seq_len]
            all_sequences.append(seq)

    if not all_sequences:
        return np.array([])

    sequences = np.array(all_sequences)
    print(f"  Built {len(sequences):,} sequences (from {n_concepts} concepts)")

    return sequences
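

# Window-count arithmetic (illustrative): for a concept with N rows after the
# concatenation above, the inner loop yields (N - seq_len) // stride + 1 windows.
# With seq_len=30 and stride=5, a single 240-minute trading day of one-minute bars
# (an assumption about the upstream data) gives (240 - 30) // 5 + 1 = 43 windows.
# Because all dates of a concept are concatenated, windows near a day boundary can
# span the close of one day and the open of the next.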


# ==================== Dataset ====================

class SequenceDataset(Dataset):
    """Dataset over pre-built sequences."""

    def __init__(self, sequences: np.ndarray):
        self.sequences = torch.FloatTensor(sequences)

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.sequences[idx]


# ==================== Trainer ====================

class EarlyStopping:
    """Early stopping helper."""

    def __init__(self, patience: int = 10, min_delta: float = 1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop


class Trainer:
    """Model trainer (with AMP mixed-precision acceleration)."""

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict,
        device: torch.device,
        save_dir: str = 'ml/checkpoints'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Optimizer
        self.optimizer = AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        # Learning-rate scheduler
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=10,
            T_mult=2,
            eta_min=1e-6
        )

        # Loss function (simplified: MSE reconstruction loss only)
        self.criterion = AnomalyDetectionLoss()

        # Early stopping
        self.early_stopping = EarlyStopping(
            patience=config['patience'],
            min_delta=config['min_delta']
        )

        # AMP mixed precision (large speedup, lower memory use)
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
        if self.use_amp:
            print("  ✓ AMP mixed-precision training enabled")

        # Training history
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'learning_rate': [],
        }

        self.best_val_loss = float('inf')

    def train_epoch(self) -> float:
        """Train one epoch (with AMP mixed precision)."""
        self.model.train()
        total_loss = 0.0
        n_batches = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for batch in pbar:
            batch = batch.to(self.device, non_blocking=True)  # async host-to-device copy

            self.optimizer.zero_grad(set_to_none=True)  # faster gradient reset

            # Forward pass under AMP autocast
            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, loss_dict = self.criterion(output, batch, latent)

                # Scaled backward pass
                self.scaler.scale(loss).backward()

                # Gradient clipping (requires unscaling first)
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config['gradient_clip']
                )

                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # Plain (non-AMP) path
                output, latent = self.model(batch)
                loss, loss_dict = self.criterion(output, batch, latent)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config['gradient_clip']
                )
                self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1

            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        return total_loss / n_batches

    @torch.no_grad()
    def validate(self) -> float:
        """Evaluate on the validation set (with AMP)."""
        self.model.eval()
        total_loss = 0.0
        n_batches = 0

        for batch in self.val_loader:
            batch = batch.to(self.device, non_blocking=True)

            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, _ = self.criterion(output, batch, latent)
            else:
                output, latent = self.model(batch)
                loss, _ = self.criterion(output, batch, latent)

            total_loss += loss.item()
            n_batches += 1

        return total_loss / n_batches

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Save a checkpoint."""
        # Unwrap DataParallel if present
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
        }

        # Always save the latest checkpoint
        torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')

        # Save the best model separately
        if is_best:
            torch.save(checkpoint, self.save_dir / 'best_model.pt')
            print(f"  ✓ Saved best model (val_loss: {val_loss:.6f})")

    def train(self, epochs: int):
        """Full training loop."""
        print(f"\nStarting training ({epochs} epochs)...")
        print(f"Device: {self.device}")
        print(f"Model parameters: {count_parameters(self.model):,}")

        for epoch in range(1, epochs + 1):
            print(f"\nEpoch {epoch}/{epochs}")

            # Train
            train_loss = self.train_epoch()

            # Validate
            val_loss = self.validate()

            # Update learning rate
            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['learning_rate'].append(current_lr)

            # Progress
            print(f"  Train Loss: {train_loss:.6f}")
            print(f"  Val Loss:   {val_loss:.6f}")
            print(f"  LR:         {current_lr:.2e}")

            # Checkpoint
            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss
            self.save_checkpoint(epoch, val_loss, is_best)

            # Early stopping
            if self.early_stopping(val_loss):
                print(f"\nEarly stopping: validation loss has not improved for {self.early_stopping.patience} epochs")
                break

        print(f"\nTraining finished. Best validation loss: {self.best_val_loss:.6f}")

        # Save training history
        self.save_history()

        return self.history

    def save_history(self):
        """Save training history to JSON and plot the curves."""
        history_path = self.save_dir / 'training_history.json'
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=2)
        print(f"Training history saved: {history_path}")

        # Plot training curves
        self.plot_training_curves()

    def plot_training_curves(self):
        """Plot loss and learning-rate curves."""
        if not HAS_MATPLOTLIB:
            print("matplotlib not installed, skipping plots")
            return

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        epochs = range(1, len(self.history['train_loss']) + 1)

        # 1. Loss curves
        ax1 = axes[0]
        ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        ax1.set_xlabel('Epoch', fontsize=12)
        ax1.set_ylabel('Loss', fontsize=12)
        ax1.set_title('Training & Validation Loss', fontsize=14)
        ax1.legend(fontsize=11)
        ax1.grid(True, alpha=0.3)

        # Mark the best epoch
        best_epoch = np.argmin(self.history['val_loss']) + 1
        best_val_loss = min(self.history['val_loss'])
        ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}')
        ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)
        ax1.annotate(f'Best: {best_val_loss:.6f}', xy=(best_epoch, best_val_loss),
                     xytext=(best_epoch + 2, best_val_loss + 0.0005),
                     fontsize=10, color='green')

        # 2. Learning-rate curve
        ax2 = axes[1]
        ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Learning Rate', fontsize=12)
        ax2.set_title('Learning Rate Schedule', fontsize=14)
        ax2.set_yscale('log')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()

        # Save the figure
        plot_path = self.save_dir / 'training_curves.png'
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.close()
        print(f"Training curves saved: {plot_path}")


# ==================== Threshold computation (on the validation set) ====================

@torch.no_grad()
def compute_thresholds(
    model: nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    percentiles: List[float] = [90, 95, 99]
) -> Dict[str, float]:
    """
    Compute percentile thresholds of the reconstruction error on the validation set.

    Note: the validation set is used rather than the test set, to avoid data leakage.
    """
    model.eval()
    all_errors = []

    print("Computing anomaly thresholds (on the validation set)...")
    for batch in tqdm(data_loader, desc="Computing thresholds"):
        batch = batch.to(device)
        errors = model.compute_reconstruction_error(batch, reduction='none')

        # Use the error at the last timestep of each sequence (the "current" minute)
        seq_errors = errors[:, -1]  # (batch,)
        all_errors.append(seq_errors.cpu().numpy())

    all_errors = np.concatenate(all_errors)

    thresholds = {}
    for p in percentiles:
        threshold = np.percentile(all_errors, p)
        thresholds[f'p{p}'] = float(threshold)
        print(f"  P{p}: {threshold:.6f}")

    # Extra statistics
    thresholds['mean'] = float(np.mean(all_errors))
    thresholds['std'] = float(np.std(all_errors))
    thresholds['median'] = float(np.median(all_errors))

    print(f"  Mean:   {thresholds['mean']:.6f}")
    print(f"  Median: {thresholds['median']:.6f}")
    print(f"  Std:    {thresholds['std']:.6f}")

    return thresholds
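

# Downstream usage sketch (assumption, not part of this script): an online detector
# would load best_model.pt and thresholds.json and flag a window as anomalous when
# its last-step reconstruction error exceeds the chosen percentile, e.g.:
#
#     ckpt = torch.load('ml/checkpoints/best_model.pt', map_location='cpu')
#     with open('ml/checkpoints/thresholds.json') as fh:
#         thresholds = json.load(fh)
#     net = TransformerAutoencoder(**ckpt['config']['model'])
#     net.load_state_dict(ckpt['model_state_dict'])
#     net.eval()
#     err = net.compute_reconstruction_error(window, reduction='none')[:, -1]
#     is_anomaly = err > thresholds['p99']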


# ==================== Main ====================

def main():
    parser = argparse.ArgumentParser(description='Train the concept anomaly detection model')
    parser.add_argument('--data_dir', type=str, default='ml/data',
                        help='Data directory')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=4096,
                        help='Batch size (4096-8192 recommended on 4x RTX 4090)')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='Learning rate (3e-4 recommended for large batches)')
    parser.add_argument('--device', type=str, default='auto',
                        help='Device (auto/cuda/cpu)')
    parser.add_argument('--save_dir', type=str, default='ml/checkpoints',
                        help='Directory for model checkpoints')
    parser.add_argument('--train_end', type=str, default='2024-06-30',
                        help='Last date included in the training set')
    parser.add_argument('--val_end', type=str, default='2024-09-30',
                        help='Last date included in the validation set')

    args = parser.parse_args()

    # Apply CLI overrides to the config
    config = TRAIN_CONFIG.copy()
    config['batch_size'] = args.batch_size
    config['epochs'] = args.epochs
    config['learning_rate'] = args.lr
    config['train_end_date'] = args.train_end
    config['val_end_date'] = args.val_end

    # Device selection
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    print("=" * 60)
    print("Concept anomaly detection model training (fixed version)")
    print("=" * 60)
    print("Config:")
    print(f"  Data directory: {args.data_dir}")
    print(f"  Device:         {device}")
    print(f"  Batch size:     {config['batch_size']}")
    print(f"  Learning rate:  {config['learning_rate']}")
    print(f"  Epochs:         {config['epochs']}")
    print(f"  Train end date: {config['train_end_date']}")
    print(f"  Val end date:   {config['val_end_date']}")
    print("=" * 60)

    # 1. Load data by date
    print("\n[1/6] Loading data...")
    date_data = load_data_by_date(args.data_dir, config['features'])

    # 2. Split by date
    print("\n[2/6] Splitting dataset by date...")
    train_data, val_data, test_data = split_data_by_date(
        date_data,
        config['train_end_date'],
        config['val_end_date']
    )

    # 3. Build sequences per concept
    print("\n[3/6] Building sequences per concept...")
    print("Train set:")
    train_sequences = build_sequences_by_concept(
        train_data, config['features'], config['seq_len'], config['stride']
    )
    print("Validation set:")
    val_sequences = build_sequences_by_concept(
        val_data, config['features'], config['seq_len'], config['stride']
    )
    print("Test set:")
    test_sequences = build_sequences_by_concept(
        test_data, config['features'], config['seq_len'], config['stride']
    )

    if len(train_sequences) == 0:
        print("Error: training set is empty. Check the data and the date ranges.")
        return

    # 4. Preprocessing (only clip extreme values; normalization is done inside the
    #    model via instance norm)
    print("\n[4/6] Preprocessing...")
    print("  Note: instance norm standardizes each sequence independently inside the model,")
    print("  which handles the volatility gap between concepts (e.g. banks vs. semiconductors)")

    clip_value = config['clip_value']
    print(f"  Clipping extreme values to ±{clip_value}")

    # Simple clipping of extreme values (keeps outliers from dominating training)
    train_sequences = np.clip(train_sequences, -clip_value, clip_value)
    if len(val_sequences) > 0:
        val_sequences = np.clip(val_sequences, -clip_value, clip_value)
    if len(test_sequences) > 0:
        test_sequences = np.clip(test_sequences, -clip_value, clip_value)

    # Save preprocessing parameters
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    preprocess_params = {
        'features': config['features'],
        'normalization': 'instance_norm',  # performed inside the model
        'clip_value': clip_value,
        'note': 'Normalization is done inside the model via InstanceNorm1d; no external scaler is needed'
    }

    with open(save_dir / 'normalization_stats.json', 'w') as f:
        json.dump(preprocess_params, f, indent=2)
    print("  Preprocessing parameters saved")
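
    # Illustrative note (assumption about model.py, which defines the network):
    # "instance norm" here amounts to standardizing each window with its own
    # statistics, roughly
    #     x_norm = (x - x.mean(dim=1, keepdim=True)) / (x.std(dim=1, keepdim=True) + eps)
    # for a (batch, seq_len, n_features) input, so every concept window is
    # reconstructed on a comparable scale regardless of its raw volatility.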

    # 5. Datasets and data loaders
    print("\n[5/6] Creating data loaders...")
    train_dataset = SequenceDataset(train_sequences)
    val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None
    test_dataset = SequenceDataset(test_sequences) if len(test_sequences) > 0 else None

    print(f"  Train sequences: {len(train_dataset):,}")
    print(f"  Val sequences:   {len(val_dataset) if val_dataset else 0:,}")
    print(f"  Test sequences:  {len(test_dataset) if test_dataset else 0:,}")

    # Scale num_workers with the GPU count (Linux can afford more workers)
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0
    print(f"  DataLoader workers: {num_workers}")
    print(f"  Batch size: {config['batch_size']}")

    # Large batches + multiple workers + prefetching for throughput
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,          # prefetch more batches per worker
        persistent_workers=True if num_workers > 0 else False,   # keep workers alive between epochs
        drop_last=True  # drop the ragged final batch so batch sizes stay constant
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'] * 2,  # larger batches are fine without gradients
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=True if num_workers > 0 else False,
    ) if val_dataset else None

    test_loader = DataLoader(
        test_dataset,
        batch_size=config['batch_size'] * 2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=True if num_workers > 0 else False,
    ) if test_dataset else None

    # 6. Train
    print("\n[6/6] Training model...")
    model_config = config['model'].copy()
    model = TransformerAutoencoder(**model_config)

    # Multi-GPU data parallelism
    if torch.cuda.device_count() > 1:
        print(f"  Training with {torch.cuda.device_count()} GPUs in parallel")
        model = nn.DataParallel(model)

    if val_loader is None:
        print("Warning: validation set is empty; holding out part of the training set instead")
        # Simple fallback: use the last 10% of the training set for validation
        split_idx = int(len(train_dataset) * 0.9)
        train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
        val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))

        train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
        val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device,
        save_dir=args.save_dir
    )

    history = trainer.train(config['epochs'])

    # 7. Compute anomaly thresholds (on the validation set)
    print("\n[Extra] Computing anomaly thresholds...")

    # Load the best model; unwrap DataParallel so the saved (unwrapped) state dict
    # matches and custom methods like compute_reconstruction_error are reachable
    best_checkpoint = torch.load(
        save_dir / 'best_model.pt',
        map_location=device
    )
    eval_model = model.module if hasattr(model, 'module') else model
    eval_model.load_state_dict(best_checkpoint['model_state_dict'])
    eval_model.to(device)

    # Use the validation set for thresholds (avoids leaking the test set)
    thresholds = compute_thresholds(
        eval_model,
        val_loader,
        device,
        config['threshold_percentiles']
    )

    # Save thresholds
    with open(save_dir / 'thresholds.json', 'w') as f:
        json.dump(thresholds, f, indent=2)
    print("Thresholds saved")

    # Save the full config
    with open(save_dir / 'config.json', 'w') as f:
        json.dump(config, f, indent=2)

    print("\n" + "=" * 60)
    print("Training complete!")
    print("=" * 60)
    print(f"Artifacts saved to: {args.save_dir}")
    print("  - best_model.pt: best model weights")
    print("  - thresholds.json: anomaly thresholds")
    print("  - normalization_stats.json: normalization parameters")
    print("  - config.json: training configuration")
    print("=" * 60)


if __name__ == "__main__":
    main()