#!/bin/bash
# Training script for the concept anomaly detection model (Linux)
#
# Usage:
#   chmod +x run_training.sh
#   ./run_training.sh
#
# Or with explicit parameters:
#   ./run_training.sh --start 2022-01-01 --epochs 100

# Abort on the first failing command so that later steps never run on top of
# a failed earlier step.
set -e

echo "============================================================"
echo "概念异动检测模型训练流程"
echo "============================================================"
echo ""

# Resolve the directory containing this script, then switch to its parent so
# that relative paths used afterwards do not depend on the caller's working
# directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" || {
  echo "错误: 无法确定脚本所在目录" >&2
  exit 1
}
cd "$SCRIPT_DIR/.." || {
  echo "错误: 无法进入目录 $SCRIPT_DIR/.." >&2
  exit 1
}

echo "[1/4] 检查环境..."
# Print the interpreter version. Every later step invokes python3, so stop
# right away when it cannot be run. The diagnostic goes to stderr.
if ! python3 --version; then
  echo "Python3 未找到!" >&2
  exit 1
fi

# GPU probe (best effort): the Python one-liner prints whether CUDA is
# available. Its stderr is discarded on purpose so that a failing import
# shows up as the warning below rather than a traceback; the script
# continues in either case.
if python3 -c "import torch; print(f'CUDA: {torch.cuda.is_available()}')" 2>/dev/null; then
  echo "PyTorch GPU 检测完成"
else
  echo "警告: PyTorch 未安装或无法检测 GPU" >&2
fi

echo ""
echo "[2/4] 检查依赖..."
# Python packages installed (quietly) before the data and training steps.
required_packages=(
  torch pandas numpy pyarrow tqdm
  clickhouse-driver elasticsearch sqlalchemy pymysql
)
# Prefer the pip module of the same python3 that runs the rest of this
# script: a standalone pip3 on PATH may be bound to a different interpreter,
# in which case the packages would land where python3 cannot import them.
# Fall back to pip3 when python3 has no usable pip module.
if python3 -m pip --version >/dev/null 2>&1; then
  python3 -m pip install -q "${required_packages[@]}"
else
  pip3 install -q "${required_packages[@]}"
fi

echo ""
echo "[3/4] 准备训练数据..."
echo "从 ClickHouse 提取历史数据,这可能需要较长时间..."
echo ""

# Parameter defaults; each one can be overridden by the matching flag below.
# An empty END_DATE means "no end date was supplied".
START_DATE="2022-01-01"
END_DATE=""
EPOCHS=100
BATCH_SIZE=256
TRAIN_END="2025-06-30"
VAL_END="2025-09-30"

#######################################
# Parse command-line flags into the global settings.
# Globals:
#   START_DATE, END_DATE, EPOCHS, BATCH_SIZE, TRAIN_END, VAL_END (written)
# Arguments:
#   Any mix of --start, --end, --epochs, --batch_size, --train_end and
#   --val_end, each followed by its value.
# Outputs:
#   Diagnostics on stderr.
# Returns:
#   0 on success. Exits the script with status 1 when a flag is given
#   without a value.
#######################################
parse_args() {
  while [[ $# -gt 0 ]]; do
    case "$1" in
      --start | --end | --epochs | --batch_size | --train_end | --val_end)
        # Without this check, `shift 2` fails on a trailing flag that has no
        # value and `set -e` ends the script without any explanation.
        if [[ $# -lt 2 ]]; then
          echo "错误: 参数 $1 缺少取值" >&2
          exit 1
        fi
        case "$1" in
          --start) START_DATE="$2" ;;
          --end) END_DATE="$2" ;;
          --epochs) EPOCHS="$2" ;;
          --batch_size) BATCH_SIZE="$2" ;;
          --train_end) TRAIN_END="$2" ;;
          --val_end) VAL_END="$2" ;;
        esac
        shift 2
        ;;
      *)
        # Unrecognized arguments are skipped, but the user is told about it
        # so that a mistyped flag does not go unnoticed.
        echo "警告: 忽略未知参数: $1" >&2
        shift
        ;;
    esac
  done
}

parse_args "$@"

# Data preparation: --start is always passed to ml/prepare_data.py; --end is
# appended only when an end date was supplied (END_DATE is non-empty).
prepare_args=(--start "$START_DATE")
if [[ -n "$END_DATE" ]]; then
  prepare_args+=(--end "$END_DATE")
fi
python3 ml/prepare_data.py "${prepare_args[@]}"

echo ""
echo "[4/4] 训练模型..."
echo "使用 GPU 加速训练..."
echo ""

# Run the training script, forwarding the EPOCHS, BATCH_SIZE, TRAIN_END and
# VAL_END settings as its command-line flags.
python3 ml/train.py \
  --epochs "$EPOCHS" \
  --batch_size "$BATCH_SIZE" \
  --train_end "$TRAIN_END" \
  --val_end "$VAL_END"

# Completion banner, printed after the training command.
echo ""
echo "============================================================"
echo "训练完成!"
echo "模型保存在: ml/checkpoints/"
echo "============================================================"