update pay ui

ml/train.py (new file, 808 lines)
@@ -0,0 +1,808 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Transformer Autoencoder training script (fixed version)

Fixes:
1. Build sequences per concept, so no window is sliced across concepts
2. Split datasets by date, to avoid data leakage
3. Handle non-stationarity with value clipping plus in-model Instance Norm
   (this version drops the external RobustScaler)
4. Compute anomaly thresholds on the validation set

Training pipeline:
1. Load preprocessed feature data (parquet files)
2. Group by concept and build sliding-window sequences within each concept
3. Split train/validation/test sets by date
4. Train the autoencoder (minimizing reconstruction error)
5. Save the model and thresholds

Usage:
    python train.py --data_dir ml/data --epochs 100 --batch_size 256
"""

import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm

from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters

# Performance: enable cuDNN benchmark (auto-selects the fastest algorithms for fixed input shapes)
torch.backends.cudnn.benchmark = True
# Enable TF32 on Ampere or newer GPUs (e.g. RTX 30/40 series) for much faster matmuls
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Visualization (optional)
try:
    import matplotlib
    matplotlib.use('Agg')  # headless backend, no display required
    import matplotlib.pyplot as plt
    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False


# ==================== Configuration ====================

TRAIN_CONFIG = {
    # Data
    'seq_len': 30,   # input sequence length (30 minutes)
    'stride': 5,     # sliding-window stride

    # Time-based split (by date)
    'train_end_date': '2024-06-30',  # last date in the training set
    'val_end_date': '2024-09-30',    # last date in the validation set (later dates are test)

    # Features
    'features': [
        'alpha',           # excess return
        'alpha_delta',     # rate of change of alpha
        'amt_ratio',       # turnover ratio
        'amt_delta',       # rate of change of turnover
        'rank_pct',        # alpha rank percentile
        'limit_up_ratio',  # fraction of limit-up stocks
    ],

    # Training (tuned for 4x RTX 4090)
    'batch_size': 4096,     # 256 -> 4096 (much larger, to saturate GPU memory)
    'epochs': 100,
    'learning_rate': 3e-4,  # 1e-4 -> 3e-4 (large batches need a larger learning rate)
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,

    # Early stopping
    'patience': 10,
    'min_delta': 1e-6,

    # Model (LSTM autoencoder: simple and effective)
    'model': {
        'n_features': 6,
        'hidden_dim': 32,       # LSTM hidden dimension (small)
        'latent_dim': 4,        # bottleneck dimension (very small; this capacity limit is the key)
        'num_layers': 1,        # number of LSTM layers
        'dropout': 0.2,
        'bidirectional': True,  # bidirectional encoder
    },

    # Normalization
    'use_instance_norm': True,  # normalize inside the model via Instance Norm (recommended)
    'clip_value': 10.0,         # simple clipping of extreme values

    # Thresholds
    'threshold_percentiles': [90, 95, 99],
}


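# Expected input layout (assumed from the loader below, not a separate spec): one
# parquet file per trading day, named features_YYYY-MM-DD.parquet, containing the
# columns concept_id, timestamp, and every feature in TRAIN_CONFIG['features'].

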
# ==================== Data loading (fixed) ====================

def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
    """
    Load data by date and return a {date: DataFrame} dict.

    Each DataFrame holds every concept's data for all timestamps on that day.
    """
    data_path = Path(data_dir)
    parquet_files = sorted(data_path.glob("features_*.parquet"))

    if not parquet_files:
        raise FileNotFoundError(f"No parquet files found in: {data_dir}")

    print(f"Found {len(parquet_files)} data files")

    date_data = {}

    for pf in tqdm(parquet_files, desc="Loading data"):
        # Extract the date from the filename
        date = pf.stem.replace('features_', '')

        df = pd.read_parquet(pf)

        # Check required columns
        required_cols = features + ['concept_id', 'timestamp']
        missing_cols = [c for c in required_cols if c not in df.columns]
        if missing_cols:
            print(f"Warning: {date} is missing columns {missing_cols}, skipping")
            continue

        date_data[date] = df

    print(f"Loaded {len(date_data)} days of data")
    return date_data


def split_data_by_date(
    date_data: Dict[str, pd.DataFrame],
    train_end: str,
    val_end: str
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """
    Split the dataset strictly by date.

    - Train: date <= train_end
    - Validation: train_end < date <= val_end
    - Test: date > val_end
    """
    train_data = {}
    val_data = {}
    test_data = {}

    for date, df in date_data.items():
        # ISO 'YYYY-MM-DD' strings compare correctly in lexicographic order
        if date <= train_end:
            train_data[date] = df
        elif date <= val_end:
            val_data[date] = df
        else:
            test_data[date] = df

    print("Dataset split (by date):")
    print(f"  Train: {len(train_data)} days (<= {train_end})")
    print(f"  Validation: {len(val_data)} days ({train_end} ~ {val_end})")
    print(f"  Test: {len(test_data)} days (> {val_end})")

    return train_data, val_data, test_data


def build_sequences_by_concept(
    date_data: Dict[str, pd.DataFrame],
    features: List[str],
    seq_len: int,
    stride: int
) -> np.ndarray:
    """
    Build sequences grouped by concept (performance-optimized version).

    Groups once with groupby instead of repeatedly scanning a large array:

    1. Concatenate the data from all dates
    2. Group by concept_id
    3. Within each concept, sort by time and slide a window
    4. Stack all sequences

    Returns an array of shape (n_sequences, seq_len, n_features). Note that
    windows may span adjacent days within a concept, since dates are concatenated.
    """
    # Concatenate all dates
    all_dfs = []
    for date, df in sorted(date_data.items()):
        df = df.copy()
        df['date'] = date
        all_dfs.append(df)

    if not all_dfs:
        return np.array([])

    combined = pd.concat(all_dfs, ignore_index=True)

    # Pre-sort by concept, date, and time so groupby is faster
    combined = combined.sort_values(['concept_id', 'date', 'timestamp'])

    # Group once with groupby (the key performance step)
    all_sequences = []
    grouped = combined.groupby('concept_id', sort=False)
    n_concepts = len(grouped)

    for concept_id, concept_df in tqdm(grouped, desc="Building sequences", total=n_concepts, leave=False):
        # Already sorted; extract the feature matrix directly
        feature_data = concept_df[features].values

        # Handle missing values
        feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0)

        # Slide a window within this concept only
        n_points = len(feature_data)
        for start in range(0, n_points - seq_len + 1, stride):
            seq = feature_data[start:start + seq_len]
            all_sequences.append(seq)

    if not all_sequences:
        return np.array([])

    sequences = np.array(all_sequences)
    print(f"  Built {len(sequences):,} sequences (from {n_concepts} concepts)")

    return sequences


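# Windowing arithmetic for the function above, with hypothetical numbers: a single
# concept with 100 rows, seq_len=30, and stride=5 yields 1 + (100 - 30) // 5 = 15
# windows, stacked into an array of shape (15, 30, len(features)).

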
# ==================== Dataset ====================

class SequenceDataset(Dataset):
    """Dataset over pre-built sequences."""

    def __init__(self, sequences: np.ndarray):
        self.sequences = torch.FloatTensor(sequences)

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.sequences[idx]


# ==================== Trainer ====================

class EarlyStopping:
    """Early-stopping helper."""

    def __init__(self, patience: int = 10, min_delta: float = 1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop


class Trainer:
    """Model trainer (with AMP mixed-precision support)."""

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict,
        device: torch.device,
        save_dir: str = 'ml/checkpoints'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Optimizer
        self.optimizer = AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        # Learning-rate scheduler: cosine annealing with warm restarts
        # (first restart after 10 epochs, then periods of 20, 40, ... epochs)
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=10,
            T_mult=2,
            eta_min=1e-6
        )

        # Loss function (simplified version: MSE only)
        self.criterion = AnomalyDetectionLoss()

        # Early stopping
        self.early_stopping = EarlyStopping(
            patience=config['patience'],
            min_delta=config['min_delta']
        )

        # AMP mixed precision (large speedup + lower memory use)
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
        if self.use_amp:
            print(" ✓ AMP mixed-precision training enabled")

        # Training history
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'learning_rate': [],
        }

        self.best_val_loss = float('inf')

    def train_epoch(self) -> float:
        """Train for one epoch (with AMP mixed precision)."""
        self.model.train()
        total_loss = 0.0
        n_batches = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for batch in pbar:
            batch = batch.to(self.device, non_blocking=True)  # async host-to-device transfer

            self.optimizer.zero_grad(set_to_none=True)  # faster gradient reset

            # AMP forward pass
            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, loss_dict = self.criterion(output, batch, latent)

                # AMP backward pass
                self.scaler.scale(loss).backward()

                # Gradient clipping (gradients must be unscaled first)
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config['gradient_clip']
                )

                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # Non-AMP path
                output, latent = self.model(batch)
                loss, loss_dict = self.criterion(output, batch, latent)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config['gradient_clip']
                )
                self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1

            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        return total_loss / n_batches

    @torch.no_grad()
    def validate(self) -> float:
        """Validate (with AMP)."""
        self.model.eval()
        total_loss = 0.0
        n_batches = 0

        for batch in self.val_loader:
            batch = batch.to(self.device, non_blocking=True)

            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, _ = self.criterion(output, batch, latent)
            else:
                output, latent = self.model(batch)
                loss, _ = self.criterion(output, batch, latent)

            total_loss += loss.item()
            n_batches += 1

        return total_loss / n_batches

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Save a checkpoint."""
        # Unwrap DataParallel so the saved state dict has plain (unprefixed) keys
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
        }

        # Always save the latest checkpoint
        torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')

        # Save the best model
        if is_best:
            torch.save(checkpoint, self.save_dir / 'best_model.pt')
            print(f" ✓ Saved best model (val_loss: {val_loss:.6f})")

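    # Resuming is not implemented in this script; a minimal sketch would mirror the
    # keys saved above (hypothetical, untested):
    #   ckpt = torch.load(self.save_dir / 'last_checkpoint.pt', map_location=self.device)
    #   model_to_load = self.model.module if hasattr(self.model, 'module') else self.model
    #   model_to_load.load_state_dict(ckpt['model_state_dict'])
    #   self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    #   self.scheduler.load_state_dict(ckpt['scheduler_state_dict'])
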
    def train(self, epochs: int):
        """Full training loop."""
        print(f"\nStarting training ({epochs} epochs)...")
        print(f"Device: {self.device}")
        print(f"Model parameters: {count_parameters(self.model):,}")

        for epoch in range(1, epochs + 1):
            print(f"\nEpoch {epoch}/{epochs}")

            # Train
            train_loss = self.train_epoch()

            # Validate
            val_loss = self.validate()

            # Update the learning rate
            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['learning_rate'].append(current_lr)

            # Report progress
            print(f"  Train Loss: {train_loss:.6f}")
            print(f"  Val Loss: {val_loss:.6f}")
            print(f"  LR: {current_lr:.2e}")

            # Checkpoint
            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss
            self.save_checkpoint(epoch, val_loss, is_best)

            # Early-stopping check
            if self.early_stopping(val_loss):
                print(f"\nEarly stopping: no improvement in val loss for {self.early_stopping.patience} epochs")
                break

        print(f"\nTraining finished. Best val loss: {self.best_val_loss:.6f}")

        # Save the training history
        self.save_history()

        return self.history

    def save_history(self):
        """Save the training history."""
        history_path = self.save_dir / 'training_history.json'
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=2)
        print(f"Training history saved: {history_path}")

        # Plot the training curves
        self.plot_training_curves()

    def plot_training_curves(self):
        """Plot training curves."""
        if not HAS_MATPLOTLIB:
            print("matplotlib not installed, skipping plots")
            return

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        epochs = range(1, len(self.history['train_loss']) + 1)

        # 1. Loss curves
        ax1 = axes[0]
        ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        ax1.set_xlabel('Epoch', fontsize=12)
        ax1.set_ylabel('Loss', fontsize=12)
        ax1.set_title('Training & Validation Loss', fontsize=14)
        ax1.legend(fontsize=11)
        ax1.grid(True, alpha=0.3)

        # Mark the best epoch
        best_epoch = np.argmin(self.history['val_loss']) + 1
        best_val_loss = min(self.history['val_loss'])
        ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}')
        ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)
        ax1.annotate(f'Best: {best_val_loss:.6f}', xy=(best_epoch, best_val_loss),
                     xytext=(best_epoch + 2, best_val_loss + 0.0005),
                     fontsize=10, color='green')

        # 2. Learning-rate curve
        ax2 = axes[1]
        ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Learning Rate', fontsize=12)
        ax2.set_title('Learning Rate Schedule', fontsize=14)
        ax2.set_yscale('log')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()

        # Save the figure
        plot_path = self.save_dir / 'training_curves.png'
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.close()
        print(f"Training curves saved: {plot_path}")


# ==================== Threshold computation (validation set) ====================

@torch.no_grad()
def compute_thresholds(
    model: nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    percentiles: List[float] = [90, 95, 99]
) -> Dict[str, float]:
    """
    Compute reconstruction-error percentile thresholds on the validation set.

    Note: the validation set (not the test set) is used, to avoid data leakage.
    """
    model.eval()
    all_errors = []

    print("Computing anomaly thresholds (on the validation set)...")
    for batch in tqdm(data_loader, desc="Computing thresholds"):
        batch = batch.to(device)
        errors = model.compute_reconstruction_error(batch, reduction='none')

        # Use the error at the last timestep of each sequence (the "current" timestep)
        seq_errors = errors[:, -1]  # (batch,)
        all_errors.append(seq_errors.cpu().numpy())

    all_errors = np.concatenate(all_errors)

    thresholds = {}
    for p in percentiles:
        threshold = np.percentile(all_errors, p)
        thresholds[f'p{p}'] = float(threshold)
        print(f"  P{p}: {threshold:.6f}")

    # Extra statistics
    thresholds['mean'] = float(np.mean(all_errors))
    thresholds['std'] = float(np.std(all_errors))
    thresholds['median'] = float(np.median(all_errors))

    print(f"  Mean: {thresholds['mean']:.6f}")
    print(f"  Median: {thresholds['median']:.6f}")
    print(f"  Std: {thresholds['std']:.6f}")

    return thresholds


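# Illustrative only: how the saved thresholds might be applied at inference time.
# A minimal sketch assuming the model exposes the same
# compute_reconstruction_error(batch, reduction='none') used above; the helper name
# flag_anomalies is hypothetical and is not called anywhere in this script.
@torch.no_grad()
def flag_anomalies(
    model: nn.Module,
    batch: torch.Tensor,
    thresholds: Dict[str, float],
    level: str = 'p99'
) -> torch.Tensor:
    """Return a boolean mask of sequences whose last-step error exceeds the threshold."""
    model.eval()
    errors = model.compute_reconstruction_error(batch, reduction='none')
    seq_errors = errors[:, -1]  # same last-timestep convention as compute_thresholds
    return seq_errors > thresholds[level]

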
# ==================== Main ====================

def main():
    parser = argparse.ArgumentParser(description='Train the concept anomaly-detection model')
    parser.add_argument('--data_dir', type=str, default='ml/data',
                        help='Path to the data directory')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=4096,
                        help='Batch size (4096~8192 recommended on 4x RTX 4090)')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='Learning rate (3e-4 recommended for large batches)')
    parser.add_argument('--device', type=str, default='auto',
                        help='Device (auto/cuda/cpu)')
    parser.add_argument('--save_dir', type=str, default='ml/checkpoints',
                        help='Directory for saved models')
    parser.add_argument('--train_end', type=str, default='2024-06-30',
                        help='Last date of the training set')
    parser.add_argument('--val_end', type=str, default='2024-09-30',
                        help='Last date of the validation set')

    args = parser.parse_args()

    # Apply CLI overrides to the config
    config = TRAIN_CONFIG.copy()
    config['batch_size'] = args.batch_size
    config['epochs'] = args.epochs
    config['learning_rate'] = args.lr
    config['train_end_date'] = args.train_end
    config['val_end_date'] = args.val_end

    # Device selection
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    print("=" * 60)
    print("Concept anomaly-detection model training (fixed version)")
    print("=" * 60)
    print("Config:")
    print(f"  Data dir: {args.data_dir}")
    print(f"  Device: {device}")
    print(f"  Batch size: {config['batch_size']}")
    print(f"  Learning rate: {config['learning_rate']}")
    print(f"  Epochs: {config['epochs']}")
    print(f"  Train end date: {config['train_end_date']}")
    print(f"  Val end date: {config['val_end_date']}")
    print("=" * 60)

    # 1. Load data by date
    print("\n[1/6] Loading data...")
    date_data = load_data_by_date(args.data_dir, config['features'])

    # 2. Split by date
    print("\n[2/6] Splitting datasets by date...")
    train_data, val_data, test_data = split_data_by_date(
        date_data,
        config['train_end_date'],
        config['val_end_date']
    )

    # 3. Build sequences per concept
    print("\n[3/6] Building sequences by concept...")
    print("Train set:")
    train_sequences = build_sequences_by_concept(
        train_data, config['features'], config['seq_len'], config['stride']
    )
    print("Validation set:")
    val_sequences = build_sequences_by_concept(
        val_data, config['features'], config['seq_len'], config['stride']
    )
    print("Test set:")
    test_sequences = build_sequences_by_concept(
        test_data, config['features'], config['seq_len'], config['stride']
    )

    if len(train_sequences) == 0:
        print("Error: the training set is empty! Check the data and date range")
        return

    # 4. Preprocessing (clip extremes only; normalization happens inside the model via Instance Norm)
    print("\n[4/6] Preprocessing data...")
    print("  Note: Instance Norm standardizes each sequence individually inside the model,")
    print("  which handles volatility differences across concepts (banks vs semiconductors)")

    clip_value = config['clip_value']
    print(f"  Clipping extreme values: ±{clip_value}")

    # Simple clipping so outliers cannot destabilize training
    train_sequences = np.clip(train_sequences, -clip_value, clip_value)
    if len(val_sequences) > 0:
        val_sequences = np.clip(val_sequences, -clip_value, clip_value)
    if len(test_sequences) > 0:
        test_sequences = np.clip(test_sequences, -clip_value, clip_value)

    # Save the preprocessing config
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    preprocess_params = {
        'features': config['features'],
        'normalization': 'instance_norm',  # done inside the model
        'clip_value': clip_value,
        'note': 'Normalization is done in-model via InstanceNorm1d; no external scaler is needed'
    }

    with open(save_dir / 'normalization_stats.json', 'w') as f:
        json.dump(preprocess_params, f, indent=2)
    print("  Preprocessing parameters saved")

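    # For intuition (a sketch, not used in training): with standard InstanceNorm1d
    # semantics, each sequence is normalized per feature channel over its own time
    # axis, roughly:
    #   seq_norm = (seq - seq.mean(axis=0)) / (seq.std(axis=0) + eps)
    # The exact behavior (eps, affine parameters) depends on the model implementation.
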
    # 5. Datasets and loaders
    print("\n[5/6] Creating data loaders...")
    train_dataset = SequenceDataset(train_sequences)
    val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None
    test_dataset = SequenceDataset(test_sequences) if len(test_sequences) > 0 else None

    print(f"  Train sequences: {len(train_dataset):,}")
    print(f"  Val sequences: {len(val_dataset) if val_dataset else 0:,}")
    print(f"  Test sequences: {len(test_dataset) if test_dataset else 0:,}")

    # Scale num_workers with GPU count (Linux can afford more; Windows needs 0)
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0
    print(f"  DataLoader workers: {num_workers}")
    print(f"  Batch size: {config['batch_size']}")

    # Large batches + many workers + prefetching for throughput
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,          # prefetch more batches
        persistent_workers=True if num_workers > 0 else False,   # keep workers alive between epochs
        drop_last=True  # drop the incomplete final batch to keep batch sizes uniform
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'] * 2,  # larger batches are fine without gradients
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=True if num_workers > 0 else False,
    ) if val_dataset else None

    # test_loader is built for completeness; it is not used further in this script
    # (thresholds below are computed on the validation set)
    test_loader = DataLoader(
        test_dataset,
        batch_size=config['batch_size'] * 2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=True if num_workers > 0 else False,
    ) if test_dataset else None

    # 6. Train
    print("\n[6/6] Training the model...")
    model_config = config['model'].copy()
    model = TransformerAutoencoder(**model_config)

    # Multi-GPU data parallelism
    if torch.cuda.device_count() > 1:
        print(f"  Training on {torch.cuda.device_count()} GPUs in parallel")
        model = nn.DataParallel(model)

    if val_loader is None:
        print("Warning: validation set is empty; holding out part of the training set instead")
        # Simple fallback: use the last 10% of the training sequences for validation.
        # Note this index-based split is not a time-based split, so prefer real
        # validation dates when they are available.
        split_idx = int(len(train_dataset) * 0.9)
        train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
        val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))

        train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
        val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device,
        save_dir=args.save_dir
    )

    history = trainer.train(config['epochs'])

    # 7. Compute thresholds (on the validation set)
    print("\n[Extra] Computing anomaly thresholds...")

    # Load the best model. The checkpoint stores the unwrapped module's state dict,
    # so unwrap DataParallel before loading and run inference on the bare module
    # (the wrapper would neither match the keys nor expose compute_reconstruction_error).
    best_checkpoint = torch.load(
        save_dir / 'best_model.pt',
        map_location=device
    )
    eval_model = model.module if hasattr(model, 'module') else model
    eval_model.load_state_dict(best_checkpoint['model_state_dict'])
    eval_model.to(device)

    # Compute thresholds on the validation set (avoids data leakage)
    thresholds = compute_thresholds(
        eval_model,
        val_loader,
        device,
        config['threshold_percentiles']
    )

    # Save the thresholds
    with open(save_dir / 'thresholds.json', 'w') as f:
        json.dump(thresholds, f, indent=2)
    print("Thresholds saved")

    # Save the full config
    with open(save_dir / 'config.json', 'w') as f:
        json.dump(config, f, indent=2)

print("\n" + "=" * 60)
|
||||
print("训练完成!")
|
||||
print("=" * 60)
|
||||
print(f"模型保存位置: {args.save_dir}")
|
||||
print(f" - best_model.pt: 最佳模型权重")
|
||||
print(f" - thresholds.json: 异动阈值")
|
||||
print(f" - normalization_stats.json: 标准化参数")
|
||||
print(f" - config.json: 训练配置")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||