113 lines
2.4 KiB
Markdown
113 lines
2.4 KiB
Markdown
# 概念异动检测 ML 模块
|
||
|
||
基于 Transformer Autoencoder 的概念异动检测系统。
|
||
|
||
## 环境要求
|
||
|
||
- Python 3.8+
|
||
- PyTorch 2.0+ (CUDA 12.x for 5090 GPU)
|
||
- ClickHouse, MySQL, Elasticsearch
|
||
|
||
## 数据库配置
|
||
|
||
当前配置(`prepare_data.py`):
|
||
- MySQL: `192.168.1.5:3306`
|
||
- Elasticsearch: `127.0.0.1:9200`
|
||
- ClickHouse: `127.0.0.1:9000`
|
||
|
||
## 快速开始
|
||
|
||
```bash
|
||
# 1. 安装依赖
|
||
pip install -r ml/requirements.txt
|
||
|
||
# 2. 安装 PyTorch (5090 需要 CUDA 12.4)
|
||
pip install torch --index-url https://download.pytorch.org/whl/cu124
|
||
|
||
# 3. 运行训练
|
||
chmod +x ml/run_training.sh
|
||
./ml/run_training.sh
|
||
```
|
||
|
||
## 文件说明
|
||
|
||
| 文件 | 说明 |
|
||
|------|------|
|
||
| `model.py` | Transformer Autoencoder 模型定义 |
|
||
| `prepare_data.py` | 数据提取和特征计算 |
|
||
| `train.py` | 模型训练脚本 |
|
||
| `inference.py` | 推理服务 |
|
||
| `enhanced_detector.py` | 增强版检测器(融合 Alpha + ML) |
|
||
|
||
## 训练参数
|
||
|
||
```bash
|
||
# 完整参数
|
||
./ml/run_training.sh --start 2022-01-01 --end 2024-12-01 --epochs 100 --batch_size 256
|
||
|
||
# 只准备数据
|
||
python ml/prepare_data.py --start 2022-01-01
|
||
|
||
# 只训练(数据已准备好)
|
||
python ml/train.py --epochs 100 --batch_size 256 --lr 1e-4
|
||
```
|
||
|
||
## 模型架构
|
||
|
||
```
|
||
输入: (batch, 30, 6) # 30分钟序列,6个特征
|
||
↓
|
||
Positional Encoding
|
||
↓
|
||
Transformer Encoder (4层, 8头, d=128)
|
||
↓
|
||
Bottleneck (压缩到 32 维)
|
||
↓
|
||
Transformer Decoder (4层)
|
||
↓
|
||
输出: (batch, 30, 6) # 重构序列
|
||
|
||
异动判断: reconstruction_error > threshold
|
||
```
|
||
|
||
## 6维特征
|
||
|
||
1. `alpha` - 超额收益(概念涨幅 - 大盘涨幅)
|
||
2. `alpha_delta` - Alpha 5分钟变化
|
||
3. `amt_ratio` - 成交额 / 20分钟均值
|
||
4. `amt_delta` - 成交额变化率
|
||
5. `rank_pct` - Alpha 排名百分位
|
||
6. `limit_up_ratio` - 涨停股占比
|
||
|
||
## 训练产出
|
||
|
||
训练完成后,`ml/checkpoints/` 包含:
|
||
- `best_model.pt` - 最佳模型权重
|
||
- `thresholds.json` - 异动阈值 (P90/P95/P99)
|
||
- `normalization_stats.json` - 数据标准化参数
|
||
- `config.json` - 训练配置
|
||
|
||
## 使用示例
|
||
|
||
```python
|
||
from ml.inference import ConceptAnomalyDetector
|
||
|
||
detector = ConceptAnomalyDetector('ml/checkpoints')
|
||
|
||
# 实时检测
|
||
is_anomaly, score = detector.detect(
|
||
concept_name="人工智能",
|
||
features={
|
||
'alpha': 2.5,
|
||
'alpha_delta': 0.8,
|
||
'amt_ratio': 1.5,
|
||
'amt_delta': 0.3,
|
||
'rank_pct': 0.95,
|
||
'limit_up_ratio': 0.15,
|
||
}
|
||
)
|
||
|
||
if is_anomaly:
|
||
print(f"检测到异动!分数: {score}")
|
||
```
|