update pay ui
622
ml/train_v2.py
Normal file
@@ -0,0 +1,622 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Training script V2 - autoencoder over Z-Score features
(the model is the TransformerAutoencoder imported from model.py)

Improvements:
1. Z-Score features (deviation from the same time slot's history)
2. Short sequences: 10 minutes (no 30-minute warm-up required)
3. Detection from the open: features are available right at 9:30

Model input:
- Sequence of Z-Score features over the past 10 minutes
- Features: alpha_zscore, amt_zscore, rank_zscore, momentum_3m, momentum_5m, limit_up_ratio

Model objective:
- Learn the "normal evolution pattern" of Z-Score sequences
- Anomaly = abnormal evolution of the Z-Score sequence (large reconstruction error)
"""

import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm

from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters

# Performance tuning: cuDNN autotuner plus TF32 matmuls on Ampere+ GPUs
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

try:
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend; figures are only saved to disk
    import matplotlib.pyplot as plt
    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False


# ==================== Configuration ====================

TRAIN_CONFIG = {
    # Data (improved!)
    'seq_len': 10,   # 10-minute sequences (not 30!)
    'stride': 2,     # 2-minute stride

    # Time-based split
    'train_end_date': '2024-06-30',
    'val_end_date': '2024-09-30',

    # V2 features (Z-Score based)
    'features': [
        'alpha_zscore',     # Z-Score of alpha
        'amt_zscore',       # Z-Score of turnover
        'rank_zscore',      # Z-Score of rank
        'momentum_3m',      # 3-minute momentum
        'momentum_5m',      # 5-minute momentum
        'limit_up_ratio',   # share of limit-up stocks
    ],

    # Training
    'batch_size': 4096,
    'epochs': 100,
    'learning_rate': 3e-4,
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,

    # Early stopping
    'patience': 15,
    'min_delta': 1e-6,

    # Model (small autoencoder; see TransformerAutoencoder in model.py)
    'model': {
        'n_features': 6,
        'hidden_dim': 32,
        'latent_dim': 4,
        'num_layers': 1,
        'dropout': 0.2,
        'bidirectional': True,
    },

    # Normalization
    'clip_value': 5.0,  # features are already Z-Scores, so clipping at ±5.0 is enough

    # Threshold percentiles
    'threshold_percentiles': [90, 95, 99],
}

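# With this config, each batch fed to the model has shape
# (batch_size, seq_len, n_features) = (4096, 10, 6).
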
# ==================== Data loading ====================

def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
    """Load V2 data, one DataFrame per date."""
    data_path = Path(data_dir)
    parquet_files = sorted(data_path.glob("features_v2_*.parquet"))

    if not parquet_files:
        raise FileNotFoundError(f"No V2 data files found in: {data_dir}")

    print(f"Found {len(parquet_files)} V2 data files")

    date_data = {}

    for pf in tqdm(parquet_files, desc="Loading data"):
        date = pf.stem.replace('features_v2_', '')

        df = pd.read_parquet(pf)

        required_cols = features + ['concept_id', 'timestamp']
        missing_cols = [c for c in required_cols if c not in df.columns]
        if missing_cols:
            print(f"Warning: {date} is missing columns {missing_cols}, skipping")
            continue

        date_data[date] = df

    print(f"Loaded {len(date_data)} days of data")
    return date_data

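# For example, a file named ml/data_v2/features_v2_2024-06-28.parquet is keyed
# under date '2024-06-28' (the stem with the 'features_v2_' prefix stripped).
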
def split_data_by_date(
    date_data: Dict[str, pd.DataFrame],
    train_end: str,
    val_end: str
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Split the dataset by date."""
    train_data = {}
    val_data = {}
    test_data = {}

    for date, df in date_data.items():
        if date <= train_end:
            train_data[date] = df
        elif date <= val_end:
            val_data[date] = df
        else:
            test_data[date] = df

    print("Dataset split:")
    print(f"  Train: {len(train_data)} days (<= {train_end})")
    print(f"  Val:   {len(val_data)} days ({train_end} ~ {val_end})")
    print(f"  Test:  {len(test_data)} days (> {val_end})")

    return train_data, val_data, test_data

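# Dates are ISO strings, so lexicographic comparison matches chronological order:
# '2024-06-28' -> train, '2024-08-15' -> val, '2024-10-10' -> test
# (with the default train_end='2024-06-30', val_end='2024-09-30').
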
def build_sequences_by_concept(
    date_data: Dict[str, pd.DataFrame],
    features: List[str],
    seq_len: int,
    stride: int
) -> np.ndarray:
    """Build sliding-window sequences, grouped per concept."""
    all_dfs = []
    for date, df in sorted(date_data.items()):
        df = df.copy()
        df['date'] = date
        all_dfs.append(df)

    if not all_dfs:
        return np.array([])

    # NOTE: days are concatenated per concept, so a window may span a day boundary.
    combined = pd.concat(all_dfs, ignore_index=True)
    combined = combined.sort_values(['concept_id', 'date', 'timestamp'])

    all_sequences = []
    grouped = combined.groupby('concept_id', sort=False)
    n_concepts = len(grouped)

    for concept_id, concept_df in tqdm(grouped, desc="Building sequences", total=n_concepts, leave=False):
        feature_data = concept_df[features].values
        feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0)

        n_points = len(feature_data)
        for start in range(0, n_points - seq_len + 1, stride):
            seq = feature_data[start:start + seq_len]
            all_sequences.append(seq)

    if not all_sequences:
        return np.array([])

    sequences = np.array(all_sequences)
    print(f"  Built {len(sequences):,} sequences (from {n_concepts} concepts)")

    return sequences

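# Windowing sanity check: with seq_len=10 and stride=2, a concept with 240
# one-minute bars (one full A-share session, assuming 1-minute data) yields
# (240 - 10) // 2 + 1 = 116 windows.
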
# ==================== Dataset ====================

class SequenceDataset(Dataset):
    """Wraps a (n_sequences, seq_len, n_features) array as float32 tensors."""

    def __init__(self, sequences: np.ndarray):
        self.sequences = torch.FloatTensor(sequences)

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.sequences[idx]


# ==================== Trainer ====================

class EarlyStopping:
    """Stop when val loss has not improved by min_delta for `patience` epochs."""

    def __init__(self, patience: int = 10, min_delta: float = 1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

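# Behavior sketch: EarlyStopping(patience=2) applied to val losses
# [1.0, 0.9, 0.95, 0.96] returns [False, False, False, True] -- two epochs
# without improvement trigger the stop.
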
class Trainer:
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict,
        device: torch.device,
        save_dir: str = 'ml/checkpoints_v2'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.optimizer = AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        # Cosine annealing with warm restarts: restarts after 10 epochs, then 20, 40, ...
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=2, eta_min=1e-6
        )

        self.criterion = AnomalyDetectionLoss()

        self.early_stopping = EarlyStopping(
            patience=config['patience'],
            min_delta=config['min_delta']
        )

        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
        if self.use_amp:
            print("  ✓ AMP mixed-precision training enabled")

        self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
        self.best_val_loss = float('inf')

    def train_epoch(self) -> float:
        self.model.train()
        total_loss = 0.0
        n_batches = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for batch in pbar:
            batch = batch.to(self.device, non_blocking=True)
            self.optimizer.zero_grad(set_to_none=True)

            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, _ = self.criterion(output, batch, latent)

                # Unscale before clipping so the threshold applies to true gradients
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                output, latent = self.model(batch)
                loss, _ = self.criterion(output, batch, latent)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
                self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1
            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        return total_loss / n_batches

    @torch.no_grad()
    def validate(self) -> float:
        self.model.eval()
        total_loss = 0.0
        n_batches = 0

        for batch in self.val_loader:
            batch = batch.to(self.device, non_blocking=True)

            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, _ = self.criterion(output, batch, latent)
            else:
                output, latent = self.model(batch)
                loss, _ = self.criterion(output, batch, latent)

            total_loss += loss.item()
            n_batches += 1

        return total_loss / n_batches

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        # Unwrap DataParallel so the checkpoint loads into a plain model
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
        }

        torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')

        if is_best:
            torch.save(checkpoint, self.save_dir / 'best_model.pt')
            print(f"  ✓ Saved best model (val_loss: {val_loss:.6f})")

    def train(self, epochs: int):
        print(f"\nStarting training ({epochs} epochs)...")
        print(f"Device: {self.device}")
        print(f"Model parameters: {count_parameters(self.model):,}")

        for epoch in range(1, epochs + 1):
            print(f"\nEpoch {epoch}/{epochs}")

            train_loss = self.train_epoch()
            val_loss = self.validate()

            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['learning_rate'].append(current_lr)

            print(f"  Train Loss: {train_loss:.6f}")
            print(f"  Val Loss:   {val_loss:.6f}")
            print(f"  LR: {current_lr:.2e}")

            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss
            self.save_checkpoint(epoch, val_loss, is_best)

            if self.early_stopping(val_loss):
                print("\nEarly stopping triggered!")
                break

        print(f"\nTraining finished! Best val loss: {self.best_val_loss:.6f}")
        self.save_history()

        return self.history

    def save_history(self):
        history_path = self.save_dir / 'training_history.json'
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=2)
        print(f"Training history saved: {history_path}")

        if HAS_MATPLOTLIB:
            self.plot_training_curves()

    def plot_training_curves(self):
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        epochs = range(1, len(self.history['train_loss']) + 1)

        # Loss curves, with the best epoch marked
        ax1 = axes[0]
        ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training & Validation Loss (V2)')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        best_epoch = np.argmin(self.history['val_loss']) + 1
        best_val_loss = min(self.history['val_loss'])
        ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7)
        ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)

        # Learning-rate schedule on a log scale
        ax2 = axes[1]
        ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Learning Rate')
        ax2.set_title('Learning Rate Schedule')
        ax2.set_yscale('log')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(self.save_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
        plt.close()
        print("Training curves saved")


# ==================== Threshold computation ====================

@torch.no_grad()
def compute_thresholds(
    model: nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    percentiles: List[float] = [90, 95, 99]
) -> Dict[str, float]:
    """Compute anomaly thresholds on the validation set."""
    model.eval()
    all_errors = []

    print("Computing anomaly thresholds...")
    for batch in tqdm(data_loader, desc="Computing thresholds"):
        batch = batch.to(device)
        errors = model.compute_reconstruction_error(batch, reduction='none')
        seq_errors = errors[:, -1]  # error at the last time step only
        all_errors.append(seq_errors.cpu().numpy())

    all_errors = np.concatenate(all_errors)

    thresholds = {}
    for p in percentiles:
        threshold = np.percentile(all_errors, p)
        thresholds[f'p{p}'] = float(threshold)
        print(f"  P{p}: {threshold:.6f}")

    thresholds['mean'] = float(np.mean(all_errors))
    thresholds['std'] = float(np.std(all_errors))
    thresholds['median'] = float(np.median(all_errors))

    return thresholds

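# A minimal sketch of consuming the saved thresholds at inference time
# (illustrative only; `model` and `batch` stand in for a loaded model and a
# feature batch, using the same compute_reconstruction_error API as above):
#
#   with open('ml/checkpoints_v2/thresholds.json') as f:
#       thresholds = json.load(f)
#   errors = model.compute_reconstruction_error(batch, reduction='none')[:, -1]
#   is_anomaly = errors.cpu().numpy() > thresholds['p99']  # flag top-1% deviations
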
# ==================== Main ====================

def main():
    parser = argparse.ArgumentParser(description='Train the V2 model')
    parser.add_argument('--data_dir', type=str, default='ml/data_v2', help='V2 data directory')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=4096)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--device', type=str, default='auto')
    parser.add_argument('--save_dir', type=str, default='ml/checkpoints_v2')
    parser.add_argument('--train_end', type=str, default='2024-06-30')
    parser.add_argument('--val_end', type=str, default='2024-09-30')
    parser.add_argument('--seq_len', type=int, default=10, help='sequence length in minutes')

    args = parser.parse_args()

    # CLI arguments override the defaults in TRAIN_CONFIG
    config = TRAIN_CONFIG.copy()
    config['batch_size'] = args.batch_size
    config['epochs'] = args.epochs
    config['learning_rate'] = args.lr
    config['train_end_date'] = args.train_end
    config['val_end_date'] = args.val_end
    config['seq_len'] = args.seq_len

    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    print("=" * 60)
    print("Concept anomaly detection - model training V2 (Z-Score features)")
    print("=" * 60)
    print(f"Data dir: {args.data_dir}")
    print(f"Device: {device}")
    print(f"Sequence length: {config['seq_len']} minutes")
    print(f"Batch size: {config['batch_size']}")
    print(f"Features: {config['features']}")
    print("=" * 60)

    # 1. Load data
    print("\n[1/6] Loading V2 data...")
    date_data = load_data_by_date(args.data_dir, config['features'])

    # 2. Split dataset
    print("\n[2/6] Splitting dataset...")
    train_data, val_data, test_data = split_data_by_date(
        date_data, config['train_end_date'], config['val_end_date']
    )

    # 3. Build sequences
    print("\n[3/6] Building sequences...")
    print("Train:")
    train_sequences = build_sequences_by_concept(
        train_data, config['features'], config['seq_len'], config['stride']
    )
    print("Val:")
    val_sequences = build_sequences_by_concept(
        val_data, config['features'], config['seq_len'], config['stride']
    )

    if len(train_sequences) == 0:
        print("Error: training set is empty!")
        return

    # 4. Preprocessing
    print("\n[4/6] Preprocessing...")
    clip_value = config['clip_value']
    print(f"  Z-Score features are pre-normalized; clipping at ±{clip_value}")

    train_sequences = np.clip(train_sequences, -clip_value, clip_value)
    if len(val_sequences) > 0:
        val_sequences = np.clip(val_sequences, -clip_value, clip_value)

    # Save config
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    with open(save_dir / 'config.json', 'w') as f:
        json.dump(config, f, indent=2)

    # 5. Create data loaders
    print("\n[5/6] Creating data loaders...")
    train_dataset = SequenceDataset(train_sequences)
    val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None

    print(f"  Train sequences: {len(train_dataset):,}")
    print(f"  Val sequences:   {len(val_dataset) if val_dataset else 0:,}")

    # Heuristic: up to 8 workers per GPU, capped at 32; 0 on Windows (spawn overhead)
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=True if num_workers > 0 else False,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'] * 2,  # no backward pass, so larger batches fit
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    ) if val_dataset else None

    # 6. Train
    print("\n[6/6] Training model...")
    model = TransformerAutoencoder(**config['model'])

    if torch.cuda.device_count() > 1:
        print(f"  Using {torch.cuda.device_count()} GPUs (DataParallel)")
        model = nn.DataParallel(model)

    if val_loader is None:
        print("Warning: validation set is empty; using 10% of the training set for validation")
        # Sequential split: the last 10% of windows (ordered by concept/time) become validation
        split_idx = int(len(train_dataset) * 0.9)
        train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
        val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))
        train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
        val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device,
        save_dir=args.save_dir
    )

    trainer.train(config['epochs'])

    # Compute thresholds
    print("\n[Extra] Computing anomaly thresholds...")
    best_checkpoint = torch.load(save_dir / 'best_model.pt', map_location=device)

    # Fresh single-GPU model for threshold computation (avoids DataParallel wrapping issues)
    threshold_model = TransformerAutoencoder(**config['model'])
    threshold_model.load_state_dict(best_checkpoint['model_state_dict'])
    threshold_model.to(device)
    threshold_model.eval()

    thresholds = compute_thresholds(threshold_model, val_loader, device, config['threshold_percentiles'])

    with open(save_dir / 'thresholds.json', 'w') as f:
        json.dump(thresholds, f, indent=2)

    print("\n" + "=" * 60)
    print("Training complete!")
    print(f"Model saved to: {args.save_dir}")
    print("=" * 60)


if __name__ == "__main__":
    main()