Files
vf_react/ml/run_training.sh
2025-12-09 08:31:18 +08:00

100 lines
2.3 KiB
Bash

#!/bin/bash
# Training pipeline driver for the concept-anomaly detection model (Linux).
#
# Usage:
#   chmod +x run_training.sh
#   ./run_training.sh
#
# With explicit options:
#   ./run_training.sh --start 2022-01-01 --epochs 100

# Abort as soon as any unguarded command fails.
set -e

banner_rule="============================================================"
printf '%s\n' "$banner_rule" "概念异动检测模型训练流程" "$banner_rule" ""

# Resolve the absolute directory that holds this script, then switch to its
# parent directory, which becomes the working directory for the rest of the run.
script_path="${BASH_SOURCE[0]}"
SCRIPT_DIR="$(cd "$(dirname "$script_path")" && pwd)"
cd "$SCRIPT_DIR/.."
# Step 1: verify the interpreter is present and report CUDA visibility.
echo "[1/4] 检查环境..."
if ! python3 --version; then
  echo "Python3 未找到!"
  exit 1
fi

# GPU probe: the one-liner prints whether torch sees CUDA. Its stderr is
# discarded, so a failing probe (e.g. torch cannot be imported) only yields the
# warning below and never stops the script.
gpu_probe="import torch; print(f'CUDA: {torch.cuda.is_available()}')"
if ! python3 -c "$gpu_probe" 2>/dev/null; then
  echo "警告: PyTorch 未安装或无法检测 GPU"
else
  echo "PyTorch GPU 检测完成"
fi
echo ""
# Step 2: install the Python packages the pipeline needs.
echo "[2/4] 检查依赖..."
required_packages=(
  torch pandas numpy pyarrow tqdm
  clickhouse-driver elasticsearch sqlalchemy pymysql
)
# Run pip through python3 itself rather than the standalone pip3 launcher:
# pip3 may be bound to a different interpreter, in which case the packages
# would be installed where the python3 used by this script cannot import them.
python3 -m pip install -q "${required_packages[@]}"
echo ""
echo "[3/4] 准备训练数据..."
echo "从 ClickHouse 提取历史数据,这可能需要较长时间..."
echo ""

# Defaults; each one can be overridden by the command-line option named beside it.
START_DATE="2022-01-01" # --start
END_DATE=""             # --end (empty by default)
EPOCHS=100              # --epochs
BATCH_SIZE=256          # --batch_size
TRAIN_END="2025-06-30"  # --train_end
VAL_END="2025-09-30"    # --val_end

#######################################
# Parse "--option value" pairs into the configuration globals.
# Globals:   START_DATE, END_DATE, EPOCHS, BATCH_SIZE, TRAIN_END, VAL_END
#            (written when the matching option is present)
# Arguments: the script's command-line arguments
# Outputs:   diagnostics on stderr
# Returns:   0 on success; 1 when a known option is missing its value
#######################################
parse_args() {
  local opt
  while [[ $# -gt 0 ]]; do
    opt="$1"
    case "$opt" in
      --start | --end | --epochs | --batch_size | --train_end | --val_end)
        # Without this guard "shift 2" fails when the option is the last
        # argument, killing the script silently under "set -e".
        if [[ $# -lt 2 ]]; then
          echo "错误: 选项 $opt 缺少参数值" >&2
          return 1
        fi
        case "$opt" in
          --start) START_DATE="$2" ;;
          --end) END_DATE="$2" ;;
          --epochs) EPOCHS="$2" ;;
          --batch_size) BATCH_SIZE="$2" ;;
          --train_end) TRAIN_END="$2" ;;
          --val_end) VAL_END="$2" ;;
        esac
        shift 2
        ;;
      *)
        # Unknown arguments are still skipped (as before), but a warning makes
        # typos such as "--epoch" visible instead of silently using defaults.
        echo "警告: 忽略未知参数: $opt" >&2
        shift
        ;;
    esac
  done
}

# Exit explicitly on a parse error rather than relying on "set -e".
parse_args "$@" || exit 1
# Data preparation: always pass --start; append --end only when END_DATE is
# non-empty.
prepare_cmd=(python3 ml/prepare_data.py --start "$START_DATE")
if [[ -n "$END_DATE" ]]; then
  prepare_cmd+=(--end "$END_DATE")
fi
"${prepare_cmd[@]}"
# Step 4: announce and launch training with the configured hyper-parameters.
printf '%s\n' "" "[4/4] 训练模型..." "使用 GPU 加速训练..." ""

train_cmd=(
  python3 ml/train.py
  --epochs "$EPOCHS"
  --batch_size "$BATCH_SIZE"
  --train_end "$TRAIN_END"
  --val_end "$VAL_END"
)
"${train_cmd[@]}"

# Closing banner, printed once the training command has returned.
closing_rule="============================================================"
printf '%s\n' "" "$closing_rule" "训练完成!" "模型保存在: ml/checkpoints/" "$closing_rule"