#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Training script V2 - LSTM Autoencoder on Z-Score features

Improvements:
1. Z-Score features (deviation relative to the history of the same time-of-day slice)
2. Short sequences: 10 minutes (no 30-minute warm-up needed)
3. Detection from the open: features are available right at 9:30

Model input:
- The past 10 minutes of the Z-Score feature sequence
- Features: alpha_zscore, amt_zscore, rank_zscore, momentum_3m, momentum_5m, limit_up_ratio

What the model learns:
- The "normal evolution pattern" of Z-Score sequences
- An anomaly = abnormal evolution of the Z-Score sequence (large reconstruction error)
"""
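
# Illustrative note (an assumption about the upstream data, not asserted by this
# script): each per-day parquet file is expected to hold one row per
# (concept_id, timestamp) with the six feature columns listed above, e.g.
#
#   concept_id | timestamp | alpha_zscore | amt_zscore | rank_zscore | momentum_3m | momentum_5m | limit_up_ratio
#
# A training sample is a 10-minute sliding window over one concept, so every batch
# fed to the autoencoder has shape (batch_size, seq_len, n_features) = (B, 10, 6).
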
import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm

from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters

# Performance tuning: enable cuDNN autotuning and TF32 matmuls
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

try:
    import matplotlib
    matplotlib.use('Agg')  # headless backend so plotting works without a display
    import matplotlib.pyplot as plt
    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False


# ==================== Configuration ====================

TRAIN_CONFIG = {
    # Data (improved!)
    'seq_len': 10,        # 10-minute sequences (not 30 minutes!)
    'stride': 2,          # 2-minute stride

    # Time-based split
    'train_end_date': '2024-06-30',
    'val_end_date': '2024-09-30',

    # V2 features (Z-Score driven)
    'features': [
        'alpha_zscore',     # Z-Score of alpha
        'amt_zscore',       # Z-Score of turnover
        'rank_zscore',      # Z-Score of rank
        'momentum_3m',      # 3-minute momentum
        'momentum_5m',      # 5-minute momentum
        'limit_up_ratio',   # share of limit-up stocks
    ],

    # Training
    'batch_size': 4096,
    'epochs': 100,
    'learning_rate': 3e-4,
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,

    # Early stopping
    'patience': 15,
    'min_delta': 1e-6,

    # Model (small LSTM)
    'model': {
        'n_features': 6,
        'hidden_dim': 32,
        'latent_dim': 4,
        'num_layers': 1,
        'dropout': 0.2,
        'bidirectional': True,
    },

    # Normalization
    'clip_value': 5.0,    # Z-Scores are already standardized; clipping at 5.0 is enough

    # Thresholds
    'threshold_percentiles': [90, 95, 99],
}


# ==================== Data loading ====================

def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
    """Load the V2 data, one DataFrame per date."""
    data_path = Path(data_dir)
    parquet_files = sorted(data_path.glob("features_v2_*.parquet"))

    if not parquet_files:
        raise FileNotFoundError(f"No V2 data files found: {data_dir}")

    print(f"Found {len(parquet_files)} V2 data files")

    date_data = {}

    for pf in tqdm(parquet_files, desc="Loading data"):
        date = pf.stem.replace('features_v2_', '')

        df = pd.read_parquet(pf)

        required_cols = features + ['concept_id', 'timestamp']
        missing_cols = [c for c in required_cols if c not in df.columns]
        if missing_cols:
            print(f"Warning: {date} is missing columns {missing_cols}, skipping")
            continue

        date_data[date] = df

    print(f"Loaded {len(date_data)} days of data")
    return date_data


def split_data_by_date(
    date_data: Dict[str, pd.DataFrame],
    train_end: str,
    val_end: str
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Split the dataset by date (string comparison works for ISO dates)."""
    train_data = {}
    val_data = {}
    test_data = {}

    for date, df in date_data.items():
        if date <= train_end:
            train_data[date] = df
        elif date <= val_end:
            val_data[date] = df
        else:
            test_data[date] = df

    print("Dataset split:")
    print(f"  Train: {len(train_data)} days (<= {train_end})")
    print(f"  Val:   {len(val_data)} days ({train_end} ~ {val_end})")
    print(f"  Test:  {len(test_data)} days (> {val_end})")

    return train_data, val_data, test_data


def build_sequences_by_concept(
    date_data: Dict[str, pd.DataFrame],
    features: List[str],
    seq_len: int,
    stride: int
) -> np.ndarray:
    """Build sliding-window sequences, grouped by concept."""
    all_dfs = []
    for date, df in sorted(date_data.items()):
        df = df.copy()
        df['date'] = date
        all_dfs.append(df)

    if not all_dfs:
        return np.array([])

    # Days are concatenated per concept, so windows may cross day boundaries
    combined = pd.concat(all_dfs, ignore_index=True)
    combined = combined.sort_values(['concept_id', 'date', 'timestamp'])

    all_sequences = []
    grouped = combined.groupby('concept_id', sort=False)
    n_concepts = len(grouped)

    for concept_id, concept_df in tqdm(grouped, desc="Building sequences", total=n_concepts, leave=False):
        feature_data = concept_df[features].values
        feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0)

        n_points = len(feature_data)
        for start in range(0, n_points - seq_len + 1, stride):
            seq = feature_data[start:start + seq_len]
            all_sequences.append(seq)

    if not all_sequences:
        return np.array([])

    sequences = np.array(all_sequences)
    print(f"  Built {len(sequences):,} sequences (from {n_concepts} concepts)")

    return sequences
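
# Worked example of the window count (illustrative, assuming a 240-minute A-share
# session): with seq_len=10 and stride=2, one day in isolation yields
# (240 - 10) // 2 + 1 = 116 windows per concept. Because days are concatenated
# per concept before windowing, a few windows may also straddle adjacent days.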


# ==================== Dataset ====================

class SequenceDataset(Dataset):
    def __init__(self, sequences: np.ndarray):
        self.sequences = torch.FloatTensor(sequences)

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.sequences[idx]


# ==================== Trainer ====================

class EarlyStopping:
    def __init__(self, patience: int = 10, min_delta: float = 1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
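
# Minimal usage sketch (illustrative; mirrors how Trainer.train drives it below):
#
#   stopper = EarlyStopping(patience=15, min_delta=1e-6)
#   for epoch in range(epochs):
#       val_loss = run_validation()  # hypothetical helper returning a float
#       if stopper(val_loss):        # True once val_loss stalls for `patience` epochs
#           break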


class Trainer:
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict,
        device: torch.device,
        save_dir: str = 'ml/checkpoints_v2'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.optimizer = AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        # Cosine annealing with warm restarts: first restart after 10 epochs, then 20, 40, ...
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=2, eta_min=1e-6
        )

        self.criterion = AnomalyDetectionLoss()

        self.early_stopping = EarlyStopping(
            patience=config['patience'],
            min_delta=config['min_delta']
        )

        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
        if self.use_amp:
            print("  ✓ AMP mixed-precision training enabled")

        self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
        self.best_val_loss = float('inf')

    def train_epoch(self) -> float:
        self.model.train()
        total_loss = 0.0
        n_batches = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for batch in pbar:
            batch = batch.to(self.device, non_blocking=True)
            self.optimizer.zero_grad(set_to_none=True)

            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, _ = self.criterion(output, batch, latent)

                # Unscale before clipping so the threshold applies to the true gradients
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                output, latent = self.model(batch)
                loss, _ = self.criterion(output, batch, latent)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
                self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1
            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        return total_loss / n_batches

    @torch.no_grad()
    def validate(self) -> float:
        self.model.eval()
        total_loss = 0.0
        n_batches = 0

        for batch in self.val_loader:
            batch = batch.to(self.device, non_blocking=True)

            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, _ = self.criterion(output, batch, latent)
            else:
                output, latent = self.model(batch)
                loss, _ = self.criterion(output, batch, latent)

            total_loss += loss.item()
            n_batches += 1

        return total_loss / n_batches

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        # Unwrap DataParallel so the saved state_dict has plain key names
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
        }

        torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')

        if is_best:
            torch.save(checkpoint, self.save_dir / 'best_model.pt')
            print(f"  ✓ Saved best model (val_loss: {val_loss:.6f})")

    def train(self, epochs: int):
        print(f"\nStarting training ({epochs} epochs)...")
        print(f"Device: {self.device}")
        print(f"Model parameters: {count_parameters(self.model):,}")

        for epoch in range(1, epochs + 1):
            print(f"\nEpoch {epoch}/{epochs}")

            train_loss = self.train_epoch()
            val_loss = self.validate()

            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['learning_rate'].append(current_lr)

            print(f"  Train Loss: {train_loss:.6f}")
            print(f"  Val Loss: {val_loss:.6f}")
            print(f"  LR: {current_lr:.2e}")

            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss
            self.save_checkpoint(epoch, val_loss, is_best)

            if self.early_stopping(val_loss):
                print("\nEarly stopping triggered!")
                break

        print(f"\nTraining finished! Best validation loss: {self.best_val_loss:.6f}")
        self.save_history()

        return self.history

    def save_history(self):
        history_path = self.save_dir / 'training_history.json'
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=2)
        print(f"Training history saved: {history_path}")

        if HAS_MATPLOTLIB:
            self.plot_training_curves()

    def plot_training_curves(self):
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        epochs = range(1, len(self.history['train_loss']) + 1)

        ax1 = axes[0]
        ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training & Validation Loss (V2)')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Mark the epoch with the lowest validation loss
        best_epoch = np.argmin(self.history['val_loss']) + 1
        best_val_loss = min(self.history['val_loss'])
        ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7)
        ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)

        ax2 = axes[1]
        ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Learning Rate')
        ax2.set_title('Learning Rate Schedule')
        ax2.set_yscale('log')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(self.save_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
        plt.close()
        print("Training curves saved")


# ==================== Threshold computation ====================

@torch.no_grad()
def compute_thresholds(
    model: nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    percentiles: List[float] = [90, 95, 99]
) -> Dict[str, float]:
    """Compute anomaly thresholds on the validation set."""
    model.eval()
    all_errors = []

    print("Computing anomaly thresholds...")
    for batch in tqdm(data_loader, desc="Computing thresholds"):
        batch = batch.to(device)
        errors = model.compute_reconstruction_error(batch, reduction='none')
        seq_errors = errors[:, -1]  # error at the last time step of each window
        all_errors.append(seq_errors.cpu().numpy())

    all_errors = np.concatenate(all_errors)

    thresholds = {}
    for p in percentiles:
        threshold = np.percentile(all_errors, p)
        thresholds[f'p{p}'] = float(threshold)
        print(f"  P{p}: {threshold:.6f}")

    thresholds['mean'] = float(np.mean(all_errors))
    thresholds['std'] = float(np.std(all_errors))
    thresholds['median'] = float(np.median(all_errors))

    return thresholds
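
# Illustrative sketch of how the saved thresholds could be used at inference time
# (not part of this script): flag a window as anomalous when its last-step
# reconstruction error exceeds the chosen percentile cutoff. `errors` below is
# assumed to be a 1-D array of last-step errors computed exactly as in
# compute_thresholds above.
#
#   with open('ml/checkpoints_v2/thresholds.json') as f:
#       thresholds = json.load(f)
#   is_anomaly = errors > thresholds['p99']  # boolean mask over windows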


# ==================== Main ====================

def main():
    parser = argparse.ArgumentParser(description='Train the V2 model')
    parser.add_argument('--data_dir', type=str, default='ml/data_v2', help='V2 data directory')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=4096)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--device', type=str, default='auto')
    parser.add_argument('--save_dir', type=str, default='ml/checkpoints_v2')
    parser.add_argument('--train_end', type=str, default='2024-06-30')
    parser.add_argument('--val_end', type=str, default='2024-09-30')
    parser.add_argument('--seq_len', type=int, default=10, help='Sequence length (minutes)')

    args = parser.parse_args()

    config = TRAIN_CONFIG.copy()
    config['batch_size'] = args.batch_size
    config['epochs'] = args.epochs
    config['learning_rate'] = args.lr
    config['train_end_date'] = args.train_end
    config['val_end_date'] = args.val_end
    config['seq_len'] = args.seq_len

    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    print("=" * 60)
    print("Concept anomaly detection model training V2 (Z-Score features)")
    print("=" * 60)
    print(f"Data directory: {args.data_dir}")
    print(f"Device: {device}")
    print(f"Sequence length: {config['seq_len']} minutes")
    print(f"Batch size: {config['batch_size']}")
    print(f"Features: {config['features']}")
    print("=" * 60)

    # 1. Load data
    print("\n[1/6] Loading V2 data...")
    date_data = load_data_by_date(args.data_dir, config['features'])

    # 2. Split the dataset
    print("\n[2/6] Splitting dataset...")
    train_data, val_data, test_data = split_data_by_date(
        date_data, config['train_end_date'], config['val_end_date']
    )

    # 3. Build sequences
    print("\n[3/6] Building sequences...")
    print("Train set:")
    train_sequences = build_sequences_by_concept(
        train_data, config['features'], config['seq_len'], config['stride']
    )
    print("Validation set:")
    val_sequences = build_sequences_by_concept(
        val_data, config['features'], config['seq_len'], config['stride']
    )

    if len(train_sequences) == 0:
        print("Error: training set is empty!")
        return

    # 4. Preprocessing
    print("\n[4/6] Preprocessing data...")
    clip_value = config['clip_value']
    print(f"  Z-Score features are already standardized; clipping at ±{clip_value}")

    train_sequences = np.clip(train_sequences, -clip_value, clip_value)
    if len(val_sequences) > 0:
        val_sequences = np.clip(val_sequences, -clip_value, clip_value)

    # Save the config
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    with open(save_dir / 'config.json', 'w') as f:
        json.dump(config, f, indent=2)

    # 5. Create data loaders
    print("\n[5/6] Creating data loaders...")
    train_dataset = SequenceDataset(train_sequences)
    val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None

    print(f"  Train sequences: {len(train_dataset):,}")
    print(f"  Val sequences: {len(val_dataset) if val_dataset else 0:,}")

    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=num_workers > 0,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'] * 2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    ) if val_dataset else None

    # 6. Train
    print("\n[6/6] Training model...")
    model = TransformerAutoencoder(**config['model'])

    if torch.cuda.device_count() > 1:
        print(f"  Training on {torch.cuda.device_count()} GPUs in parallel")
        model = nn.DataParallel(model)

    if val_loader is None:
        print("Warning: validation set is empty; using 10% of the training set for validation")
        split_idx = int(len(train_dataset) * 0.9)
        train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
        val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))
        train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
        val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device,
        save_dir=args.save_dir
    )

    trainer.train(config['epochs'])

    # Compute thresholds
    print("\n[Extra] Computing anomaly thresholds...")
    best_checkpoint = torch.load(save_dir / 'best_model.pt', map_location=device)

    # Build a fresh single-GPU model for threshold computation (avoids DataParallel issues)
    threshold_model = TransformerAutoencoder(**config['model'])
    threshold_model.load_state_dict(best_checkpoint['model_state_dict'])
    threshold_model.to(device)
    threshold_model.eval()

    thresholds = compute_thresholds(threshold_model, val_loader, device, config['threshold_percentiles'])

    with open(save_dir / 'thresholds.json', 'w') as f:
        json.dump(thresholds, f, indent=2)

    print("\n" + "=" * 60)
    print("Training complete!")
    print(f"Model saved to: {args.save_dir}")
    print("=" * 60)


if __name__ == "__main__":
    main()
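
# Example invocation (illustrative; assumes this script is saved as train_v2.py):
#
#   python train_v2.py --data_dir ml/data_v2 --epochs 100 --batch_size 4096 \
#       --train_end 2024-06-30 --val_end 2024-09-30 --seq_len 10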