update pay ui
This commit is contained in:
112
ml/README.md
112
ml/README.md
@@ -1,112 +0,0 @@
|
|||||||
# 概念异动检测 ML 模块
|
|
||||||
|
|
||||||
基于 Transformer Autoencoder 的概念异动检测系统。
|
|
||||||
|
|
||||||
## 环境要求
|
|
||||||
|
|
||||||
- Python 3.8+
|
|
||||||
- PyTorch 2.0+ (CUDA 12.x for 5090 GPU)
|
|
||||||
- ClickHouse, MySQL, Elasticsearch
|
|
||||||
|
|
||||||
## 数据库配置
|
|
||||||
|
|
||||||
当前配置(`prepare_data.py`):
|
|
||||||
- MySQL: `192.168.1.5:3306`
|
|
||||||
- Elasticsearch: `127.0.0.1:9200`
|
|
||||||
- ClickHouse: `127.0.0.1:9000`
|
|
||||||
|
|
||||||
## 快速开始
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 1. 安装依赖
|
|
||||||
pip install -r ml/requirements.txt
|
|
||||||
|
|
||||||
# 2. 安装 PyTorch (5090 需要 CUDA 12.4)
|
|
||||||
pip install torch --index-url https://download.pytorch.org/whl/cu124
|
|
||||||
|
|
||||||
# 3. 运行训练
|
|
||||||
chmod +x ml/run_training.sh
|
|
||||||
./ml/run_training.sh
|
|
||||||
```
|
|
||||||
|
|
||||||
## 文件说明
|
|
||||||
|
|
||||||
| 文件 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| `model.py` | Transformer Autoencoder 模型定义 |
|
|
||||||
| `prepare_data.py` | 数据提取和特征计算 |
|
|
||||||
| `train.py` | 模型训练脚本 |
|
|
||||||
| `inference.py` | 推理服务 |
|
|
||||||
| `enhanced_detector.py` | 增强版检测器(融合 Alpha + ML) |
|
|
||||||
|
|
||||||
## 训练参数
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 完整参数
|
|
||||||
./ml/run_training.sh --start 2022-01-01 --end 2024-12-01 --epochs 100 --batch_size 256
|
|
||||||
|
|
||||||
# 只准备数据
|
|
||||||
python ml/prepare_data.py --start 2022-01-01
|
|
||||||
|
|
||||||
# 只训练(数据已准备好)
|
|
||||||
python ml/train.py --epochs 100 --batch_size 256 --lr 1e-4
|
|
||||||
```
|
|
||||||
|
|
||||||
## 模型架构
|
|
||||||
|
|
||||||
```
|
|
||||||
输入: (batch, 30, 6) # 30分钟序列,6个特征
|
|
||||||
↓
|
|
||||||
Positional Encoding
|
|
||||||
↓
|
|
||||||
Transformer Encoder (4层, 8头, d=128)
|
|
||||||
↓
|
|
||||||
Bottleneck (压缩到 32 维)
|
|
||||||
↓
|
|
||||||
Transformer Decoder (4层)
|
|
||||||
↓
|
|
||||||
输出: (batch, 30, 6) # 重构序列
|
|
||||||
|
|
||||||
异动判断: reconstruction_error > threshold
|
|
||||||
```
|
|
||||||
|
|
||||||
## 6维特征
|
|
||||||
|
|
||||||
1. `alpha` - 超额收益(概念涨幅 - 大盘涨幅)
|
|
||||||
2. `alpha_delta` - Alpha 5分钟变化
|
|
||||||
3. `amt_ratio` - 成交额 / 20分钟均值
|
|
||||||
4. `amt_delta` - 成交额变化率
|
|
||||||
5. `rank_pct` - Alpha 排名百分位
|
|
||||||
6. `limit_up_ratio` - 涨停股占比
|
|
||||||
|
|
||||||
## 训练产出
|
|
||||||
|
|
||||||
训练完成后,`ml/checkpoints/` 包含:
|
|
||||||
- `best_model.pt` - 最佳模型权重
|
|
||||||
- `thresholds.json` - 异动阈值 (P90/P95/P99)
|
|
||||||
- `normalization_stats.json` - 数据标准化参数
|
|
||||||
- `config.json` - 训练配置
|
|
||||||
|
|
||||||
## 使用示例
|
|
||||||
|
|
||||||
```python
|
|
||||||
from ml.inference import ConceptAnomalyDetector
|
|
||||||
|
|
||||||
detector = ConceptAnomalyDetector('ml/checkpoints')
|
|
||||||
|
|
||||||
# 实时检测
|
|
||||||
is_anomaly, score = detector.detect(
|
|
||||||
concept_name="人工智能",
|
|
||||||
features={
|
|
||||||
'alpha': 2.5,
|
|
||||||
'alpha_delta': 0.8,
|
|
||||||
'amt_ratio': 1.5,
|
|
||||||
'amt_delta': 0.3,
|
|
||||||
'rank_pct': 0.95,
|
|
||||||
'limit_up_ratio': 0.15,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_anomaly:
|
|
||||||
print(f"检测到异动!分数: {score}")
|
|
||||||
```
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
概念异动检测 ML 模块
|
|
||||||
|
|
||||||
提供基于 Transformer Autoencoder 的异动检测功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .inference import ConceptAnomalyDetector, MLAnomalyService
|
|
||||||
|
|
||||||
__all__ = ['ConceptAnomalyDetector', 'MLAnomalyService']
|
|
||||||
Binary file not shown.
481
ml/backtest.py
481
ml/backtest.py
@@ -1,481 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
历史异动回测脚本
|
|
||||||
|
|
||||||
使用训练好的模型,对历史数据进行异动检测,生成异动记录
|
|
||||||
|
|
||||||
使用方法:
|
|
||||||
# 回测指定日期范围
|
|
||||||
python backtest.py --start 2024-01-01 --end 2024-12-01
|
|
||||||
|
|
||||||
# 回测单天
|
|
||||||
python backtest.py --start 2024-11-01 --end 2024-11-01
|
|
||||||
|
|
||||||
# 只生成结果,不写入数据库
|
|
||||||
python backtest.py --start 2024-01-01 --dry-run
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Tuple, Optional
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
from sqlalchemy import create_engine, text
|
|
||||||
|
|
||||||
# 添加父目录到路径
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
from model import TransformerAutoencoder
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 配置 ====================
|
|
||||||
|
|
||||||
MYSQL_ENGINE = create_engine(
|
|
||||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
|
||||||
echo=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# 特征列表(与训练一致)
|
|
||||||
FEATURES = [
|
|
||||||
'alpha',
|
|
||||||
'alpha_delta',
|
|
||||||
'amt_ratio',
|
|
||||||
'amt_delta',
|
|
||||||
'rank_pct',
|
|
||||||
'limit_up_ratio',
|
|
||||||
]
|
|
||||||
|
|
||||||
# 回测配置
|
|
||||||
BACKTEST_CONFIG = {
|
|
||||||
'seq_len': 30, # 序列长度
|
|
||||||
'threshold_key': 'p95', # 使用的阈值
|
|
||||||
'min_alpha_abs': 0.5, # 最小 Alpha 绝对值(过滤微小波动)
|
|
||||||
'cooldown_minutes': 8, # 同一概念冷却时间
|
|
||||||
'max_alerts_per_minute': 15, # 每分钟最多异动数
|
|
||||||
'clip_value': 10.0, # 极端值截断
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 模型加载 ====================
|
|
||||||
|
|
||||||
class AnomalyDetector:
|
|
||||||
"""异动检测器"""
|
|
||||||
|
|
||||||
def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
|
|
||||||
self.checkpoint_dir = Path(checkpoint_dir)
|
|
||||||
|
|
||||||
# 设备
|
|
||||||
if device == 'auto':
|
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
else:
|
|
||||||
self.device = torch.device(device)
|
|
||||||
|
|
||||||
# 加载配置
|
|
||||||
self._load_config()
|
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
self._load_model()
|
|
||||||
|
|
||||||
# 加载阈值
|
|
||||||
self._load_thresholds()
|
|
||||||
|
|
||||||
print(f"AnomalyDetector 初始化完成")
|
|
||||||
print(f" 设备: {self.device}")
|
|
||||||
print(f" 阈值 ({BACKTEST_CONFIG['threshold_key']}): {self.threshold:.6f}")
|
|
||||||
|
|
||||||
def _load_config(self):
|
|
||||||
config_path = self.checkpoint_dir / 'config.json'
|
|
||||||
with open(config_path, 'r') as f:
|
|
||||||
self.config = json.load(f)
|
|
||||||
|
|
||||||
def _load_model(self):
|
|
||||||
model_path = self.checkpoint_dir / 'best_model.pt'
|
|
||||||
checkpoint = torch.load(model_path, map_location=self.device)
|
|
||||||
|
|
||||||
model_config = self.config['model'].copy()
|
|
||||||
model_config['use_instance_norm'] = self.config.get('use_instance_norm', True)
|
|
||||||
|
|
||||||
self.model = TransformerAutoencoder(**model_config)
|
|
||||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
|
||||||
self.model.to(self.device)
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
def _load_thresholds(self):
|
|
||||||
thresholds_path = self.checkpoint_dir / 'thresholds.json'
|
|
||||||
with open(thresholds_path, 'r') as f:
|
|
||||||
thresholds = json.load(f)
|
|
||||||
|
|
||||||
self.threshold = thresholds[BACKTEST_CONFIG['threshold_key']]
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def compute_anomaly_scores(self, sequences: np.ndarray) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
计算异动分数
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sequences: (n_sequences, seq_len, n_features)
|
|
||||||
Returns:
|
|
||||||
scores: (n_sequences,) 每个序列最后时刻的异动分数
|
|
||||||
"""
|
|
||||||
# 截断极端值
|
|
||||||
sequences = np.clip(sequences, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])
|
|
||||||
|
|
||||||
# 转为 tensor
|
|
||||||
x = torch.FloatTensor(sequences).to(self.device)
|
|
||||||
|
|
||||||
# 计算重构误差
|
|
||||||
errors = self.model.compute_reconstruction_error(x, reduction='none')
|
|
||||||
|
|
||||||
# 取最后一个时刻的误差
|
|
||||||
scores = errors[:, -1].cpu().numpy()
|
|
||||||
|
|
||||||
return scores
|
|
||||||
|
|
||||||
def is_anomaly(self, score: float) -> bool:
|
|
||||||
"""判断是否异动"""
|
|
||||||
return score > self.threshold
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据加载 ====================
|
|
||||||
|
|
||||||
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
|
|
||||||
"""加载单天的特征数据"""
|
|
||||||
file_path = Path(data_dir) / f"features_{date}.parquet"
|
|
||||||
|
|
||||||
if not file_path.exists():
|
|
||||||
return None
|
|
||||||
|
|
||||||
df = pd.read_parquet(file_path)
|
|
||||||
return df
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
|
|
||||||
"""获取可用的日期列表"""
|
|
||||||
data_path = Path(data_dir)
|
|
||||||
all_files = sorted(data_path.glob("features_*.parquet"))
|
|
||||||
|
|
||||||
dates = []
|
|
||||||
for f in all_files:
|
|
||||||
date = f.stem.replace('features_', '')
|
|
||||||
if start_date <= date <= end_date:
|
|
||||||
dates.append(date)
|
|
||||||
|
|
||||||
return dates
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 回测逻辑 ====================
|
|
||||||
|
|
||||||
def backtest_single_day(
|
|
||||||
detector: AnomalyDetector,
|
|
||||||
df: pd.DataFrame,
|
|
||||||
date: str,
|
|
||||||
seq_len: int = 30
|
|
||||||
) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
回测单天数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
detector: 异动检测器
|
|
||||||
df: 当天的特征数据
|
|
||||||
date: 日期
|
|
||||||
seq_len: 序列长度
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
alerts: 异动列表
|
|
||||||
"""
|
|
||||||
alerts = []
|
|
||||||
|
|
||||||
# 按概念分组
|
|
||||||
grouped = df.groupby('concept_id', sort=False)
|
|
||||||
|
|
||||||
# 冷却记录 {concept_id: last_alert_timestamp}
|
|
||||||
cooldown = {}
|
|
||||||
|
|
||||||
# 获取所有时间点
|
|
||||||
all_timestamps = sorted(df['timestamp'].unique())
|
|
||||||
|
|
||||||
if len(all_timestamps) < seq_len:
|
|
||||||
return alerts
|
|
||||||
|
|
||||||
# 对每个时间点进行检测(从第 seq_len 个开始)
|
|
||||||
for t_idx in range(seq_len - 1, len(all_timestamps)):
|
|
||||||
current_time = all_timestamps[t_idx]
|
|
||||||
window_start_time = all_timestamps[t_idx - seq_len + 1]
|
|
||||||
|
|
||||||
minute_alerts = []
|
|
||||||
|
|
||||||
# 收集该时刻所有概念的序列
|
|
||||||
concept_sequences = []
|
|
||||||
concept_infos = []
|
|
||||||
|
|
||||||
for concept_id, concept_df in grouped:
|
|
||||||
# 获取该概念在时间窗口内的数据
|
|
||||||
mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
|
|
||||||
window_df = concept_df[mask].sort_values('timestamp')
|
|
||||||
|
|
||||||
if len(window_df) < seq_len:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 取最后 seq_len 个点
|
|
||||||
window_df = window_df.tail(seq_len)
|
|
||||||
|
|
||||||
# 提取特征
|
|
||||||
features = window_df[FEATURES].values
|
|
||||||
|
|
||||||
# 处理缺失值
|
|
||||||
features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
|
|
||||||
|
|
||||||
# 获取当前时刻的信息
|
|
||||||
current_row = window_df.iloc[-1]
|
|
||||||
|
|
||||||
concept_sequences.append(features)
|
|
||||||
concept_infos.append({
|
|
||||||
'concept_id': concept_id,
|
|
||||||
'timestamp': current_time,
|
|
||||||
'alpha': current_row.get('alpha', 0),
|
|
||||||
'alpha_delta': current_row.get('alpha_delta', 0),
|
|
||||||
'amt_ratio': current_row.get('amt_ratio', 1),
|
|
||||||
'limit_up_ratio': current_row.get('limit_up_ratio', 0),
|
|
||||||
'limit_down_ratio': current_row.get('limit_down_ratio', 0),
|
|
||||||
'rank_pct': current_row.get('rank_pct', 0.5),
|
|
||||||
'stock_count': current_row.get('stock_count', 0),
|
|
||||||
'total_amt': current_row.get('total_amt', 0),
|
|
||||||
})
|
|
||||||
|
|
||||||
if not concept_sequences:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 批量计算异动分数
|
|
||||||
sequences_array = np.array(concept_sequences)
|
|
||||||
scores = detector.compute_anomaly_scores(sequences_array)
|
|
||||||
|
|
||||||
# 检测异动
|
|
||||||
for i, (info, score) in enumerate(zip(concept_infos, scores)):
|
|
||||||
concept_id = info['concept_id']
|
|
||||||
alpha = info['alpha']
|
|
||||||
|
|
||||||
# 过滤小波动
|
|
||||||
if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查冷却
|
|
||||||
if concept_id in cooldown:
|
|
||||||
last_alert = cooldown[concept_id]
|
|
||||||
if isinstance(current_time, datetime):
|
|
||||||
time_diff = (current_time - last_alert).total_seconds() / 60
|
|
||||||
else:
|
|
||||||
# timestamp 是字符串或其他格式
|
|
||||||
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1 # 跳过冷却检查
|
|
||||||
|
|
||||||
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 判断是否异动
|
|
||||||
if not detector.is_anomaly(score):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 记录异动
|
|
||||||
alert_type = 'surge_up' if alpha > 0 else 'surge_down'
|
|
||||||
|
|
||||||
alert = {
|
|
||||||
'concept_id': concept_id,
|
|
||||||
'alert_time': current_time,
|
|
||||||
'trade_date': date,
|
|
||||||
'alert_type': alert_type,
|
|
||||||
'anomaly_score': float(score),
|
|
||||||
'threshold': detector.threshold,
|
|
||||||
**info
|
|
||||||
}
|
|
||||||
|
|
||||||
minute_alerts.append(alert)
|
|
||||||
cooldown[concept_id] = current_time
|
|
||||||
|
|
||||||
# 按分数排序,限制数量
|
|
||||||
minute_alerts.sort(key=lambda x: x['anomaly_score'], reverse=True)
|
|
||||||
alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])
|
|
||||||
|
|
||||||
return alerts
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库写入 ====================
|
|
||||||
|
|
||||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
|
|
||||||
"""保存异动到 MySQL"""
|
|
||||||
if not alerts:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
if dry_run:
|
|
||||||
print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
|
|
||||||
return len(alerts)
|
|
||||||
|
|
||||||
saved = 0
|
|
||||||
with MYSQL_ENGINE.begin() as conn:
|
|
||||||
for alert in alerts:
|
|
||||||
try:
|
|
||||||
# 检查是否已存在
|
|
||||||
check_sql = text("""
|
|
||||||
SELECT id FROM concept_minute_alert
|
|
||||||
WHERE concept_id = :concept_id
|
|
||||||
AND alert_time = :alert_time
|
|
||||||
AND trade_date = :trade_date
|
|
||||||
""")
|
|
||||||
exists = conn.execute(check_sql, {
|
|
||||||
'concept_id': alert['concept_id'],
|
|
||||||
'alert_time': alert['alert_time'],
|
|
||||||
'trade_date': alert['trade_date'],
|
|
||||||
}).fetchone()
|
|
||||||
|
|
||||||
if exists:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 插入新记录
|
|
||||||
insert_sql = text("""
|
|
||||||
INSERT INTO concept_minute_alert
|
|
||||||
(concept_id, concept_name, alert_time, alert_type, trade_date,
|
|
||||||
change_pct, zscore, importance_score, stock_count, extra_info)
|
|
||||||
VALUES
|
|
||||||
(:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
|
|
||||||
:change_pct, :zscore, :importance_score, :stock_count, :extra_info)
|
|
||||||
""")
|
|
||||||
|
|
||||||
conn.execute(insert_sql, {
|
|
||||||
'concept_id': alert['concept_id'],
|
|
||||||
'concept_name': alert.get('concept_name', ''),
|
|
||||||
'alert_time': alert['alert_time'],
|
|
||||||
'alert_type': alert['alert_type'],
|
|
||||||
'trade_date': alert['trade_date'],
|
|
||||||
'change_pct': alert.get('alpha', 0),
|
|
||||||
'zscore': alert['anomaly_score'],
|
|
||||||
'importance_score': alert['anomaly_score'],
|
|
||||||
'stock_count': alert.get('stock_count', 0),
|
|
||||||
'extra_info': json.dumps({
|
|
||||||
'detection_method': 'ml_autoencoder',
|
|
||||||
'threshold': alert['threshold'],
|
|
||||||
'alpha': alert.get('alpha', 0),
|
|
||||||
'amt_ratio': alert.get('amt_ratio', 1),
|
|
||||||
}, ensure_ascii=False)
|
|
||||||
})
|
|
||||||
|
|
||||||
saved += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f" 保存失败: {alert['concept_id']} - {e}")
|
|
||||||
|
|
||||||
return saved
|
|
||||||
|
|
||||||
|
|
||||||
def export_alerts_to_csv(alerts: List[Dict], output_path: str):
|
|
||||||
"""导出异动到 CSV"""
|
|
||||||
if not alerts:
|
|
||||||
return
|
|
||||||
|
|
||||||
df = pd.DataFrame(alerts)
|
|
||||||
df.to_csv(output_path, index=False, encoding='utf-8-sig')
|
|
||||||
print(f"已导出到: {output_path}")
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 主函数 ====================
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description='历史异动回测')
|
|
||||||
parser.add_argument('--data_dir', type=str, default='ml/data',
|
|
||||||
help='特征数据目录')
|
|
||||||
parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
|
|
||||||
help='模型检查点目录')
|
|
||||||
parser.add_argument('--start', type=str, required=True,
|
|
||||||
help='开始日期 (YYYY-MM-DD)')
|
|
||||||
parser.add_argument('--end', type=str, required=True,
|
|
||||||
help='结束日期 (YYYY-MM-DD)')
|
|
||||||
parser.add_argument('--dry-run', action='store_true',
|
|
||||||
help='只计算,不写入数据库')
|
|
||||||
parser.add_argument('--export-csv', type=str, default=None,
|
|
||||||
help='导出 CSV 文件路径')
|
|
||||||
parser.add_argument('--device', type=str, default='auto',
|
|
||||||
help='设备 (auto/cuda/cpu)')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("历史异动回测")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"日期范围: {args.start} ~ {args.end}")
|
|
||||||
print(f"数据目录: {args.data_dir}")
|
|
||||||
print(f"模型目录: {args.checkpoint_dir}")
|
|
||||||
print(f"Dry Run: {args.dry_run}")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# 初始化检测器
|
|
||||||
detector = AnomalyDetector(args.checkpoint_dir, args.device)
|
|
||||||
|
|
||||||
# 获取可用日期
|
|
||||||
dates = get_available_dates(args.data_dir, args.start, args.end)
|
|
||||||
|
|
||||||
if not dates:
|
|
||||||
print(f"未找到 {args.start} ~ {args.end} 范围内的数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"\n找到 {len(dates)} 天的数据")
|
|
||||||
|
|
||||||
# 回测
|
|
||||||
all_alerts = []
|
|
||||||
total_saved = 0
|
|
||||||
|
|
||||||
for date in tqdm(dates, desc="回测进度"):
|
|
||||||
# 加载数据
|
|
||||||
df = load_daily_features(args.data_dir, date)
|
|
||||||
|
|
||||||
if df is None or df.empty:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 回测单天
|
|
||||||
alerts = backtest_single_day(
|
|
||||||
detector, df, date,
|
|
||||||
seq_len=BACKTEST_CONFIG['seq_len']
|
|
||||||
)
|
|
||||||
|
|
||||||
if alerts:
|
|
||||||
all_alerts.extend(alerts)
|
|
||||||
|
|
||||||
# 写入数据库
|
|
||||||
saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
|
|
||||||
total_saved += saved
|
|
||||||
|
|
||||||
if not args.dry_run:
|
|
||||||
tqdm.write(f" {date}: 检测到 {len(alerts)} 个异动,保存 {saved} 条")
|
|
||||||
|
|
||||||
# 导出 CSV
|
|
||||||
if args.export_csv and all_alerts:
|
|
||||||
export_alerts_to_csv(all_alerts, args.export_csv)
|
|
||||||
|
|
||||||
# 汇总
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("回测完成!")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"总计检测到: {len(all_alerts)} 个异动")
|
|
||||||
print(f"保存到数据库: {total_saved} 条")
|
|
||||||
|
|
||||||
# 统计
|
|
||||||
if all_alerts:
|
|
||||||
df_alerts = pd.DataFrame(all_alerts)
|
|
||||||
print(f"\n异动类型分布:")
|
|
||||||
print(df_alerts['alert_type'].value_counts())
|
|
||||||
|
|
||||||
print(f"\n异动分数统计:")
|
|
||||||
print(f" Mean: {df_alerts['anomaly_score'].mean():.4f}")
|
|
||||||
print(f" Max: {df_alerts['anomaly_score'].max():.4f}")
|
|
||||||
print(f" Min: {df_alerts['anomaly_score'].min():.4f}")
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,859 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
快速融合异动回测脚本
|
|
||||||
|
|
||||||
优化策略:
|
|
||||||
1. 预先构建所有序列(向量化),避免循环内重复切片
|
|
||||||
2. 批量 ML 推理(一次推理所有候选)
|
|
||||||
3. 使用 NumPy 向量化操作替代 Python 循环
|
|
||||||
|
|
||||||
性能对比:
|
|
||||||
- 原版:5分钟/天
|
|
||||||
- 优化版:预计 10-30秒/天
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
from sqlalchemy import create_engine, text
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 配置 ====================
|
|
||||||
|
|
||||||
MYSQL_ENGINE = create_engine(
|
|
||||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
|
||||||
echo=False
|
|
||||||
)
|
|
||||||
|
|
||||||
FEATURES = ['alpha', 'alpha_delta', 'amt_ratio', 'amt_delta', 'rank_pct', 'limit_up_ratio']
|
|
||||||
|
|
||||||
CONFIG = {
|
|
||||||
'seq_len': 15, # 序列长度(支持跨日后可从 9:30 检测)
|
|
||||||
'min_alpha_abs': 0.3, # 最小 alpha 过滤
|
|
||||||
'cooldown_minutes': 8,
|
|
||||||
'max_alerts_per_minute': 20,
|
|
||||||
'clip_value': 10.0,
|
|
||||||
# === 融合权重:均衡 ===
|
|
||||||
'rule_weight': 0.5,
|
|
||||||
'ml_weight': 0.5,
|
|
||||||
# === 触发阈值 ===
|
|
||||||
'rule_trigger': 65, # 60 -> 65,略提高规则门槛
|
|
||||||
'ml_trigger': 70, # 75 -> 70,略降低 ML 门槛
|
|
||||||
'fusion_trigger': 45,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 规则评分(向量化版)====================
|
|
||||||
|
|
||||||
def get_size_adjusted_thresholds(stock_count: np.ndarray) -> dict:
|
|
||||||
"""
|
|
||||||
根据概念股票数量计算动态阈值
|
|
||||||
|
|
||||||
设计思路:
|
|
||||||
- 小概念(<10 只):波动大是正常的,需要更高阈值
|
|
||||||
- 中概念(10-50 只):标准阈值
|
|
||||||
- 大概念(>50 只):能有明显波动说明是真异动,降低阈值
|
|
||||||
|
|
||||||
返回各指标的调整系数(乘以基准阈值)
|
|
||||||
"""
|
|
||||||
n = len(stock_count)
|
|
||||||
|
|
||||||
# 基于股票数量的调整系数
|
|
||||||
# 小概念:系数 > 1(提高阈值,更难触发)
|
|
||||||
# 大概念:系数 < 1(降低阈值,更容易触发)
|
|
||||||
size_factor = np.ones(n)
|
|
||||||
|
|
||||||
# 微型概念(<5 只):阈值 × 1.8
|
|
||||||
tiny = stock_count < 5
|
|
||||||
size_factor[tiny] = 1.8
|
|
||||||
|
|
||||||
# 小概念(5-10 只):阈值 × 1.4
|
|
||||||
small = (stock_count >= 5) & (stock_count < 10)
|
|
||||||
size_factor[small] = 1.4
|
|
||||||
|
|
||||||
# 中小概念(10-20 只):阈值 × 1.2
|
|
||||||
medium_small = (stock_count >= 10) & (stock_count < 20)
|
|
||||||
size_factor[medium_small] = 1.2
|
|
||||||
|
|
||||||
# 中概念(20-50 只):标准阈值 × 1.0
|
|
||||||
medium = (stock_count >= 20) & (stock_count < 50)
|
|
||||||
size_factor[medium] = 1.0
|
|
||||||
|
|
||||||
# 大概念(50-100 只):阈值 × 0.85
|
|
||||||
large = (stock_count >= 50) & (stock_count < 100)
|
|
||||||
size_factor[large] = 0.85
|
|
||||||
|
|
||||||
# 超大概念(>100 只):阈值 × 0.7
|
|
||||||
xlarge = stock_count >= 100
|
|
||||||
size_factor[xlarge] = 0.7
|
|
||||||
|
|
||||||
return size_factor
|
|
||||||
|
|
||||||
|
|
||||||
def score_rules_batch(df: pd.DataFrame) -> Tuple[np.ndarray, List[List[str]]]:
|
|
||||||
"""
|
|
||||||
批量计算规则得分(向量化)- 考虑概念规模版
|
|
||||||
|
|
||||||
设计原则:
|
|
||||||
- 规则作为辅助信号,不应单独主导决策
|
|
||||||
- 根据概念股票数量动态调整阈值
|
|
||||||
- 大概念异动更有价值,小概念需要更大波动才算异动
|
|
||||||
|
|
||||||
Args:
|
|
||||||
df: DataFrame,包含所有特征列(必须包含 stock_count)
|
|
||||||
Returns:
|
|
||||||
scores: (n,) 规则得分数组
|
|
||||||
triggered_rules: 每行触发的规则列表
|
|
||||||
"""
|
|
||||||
n = len(df)
|
|
||||||
scores = np.zeros(n)
|
|
||||||
triggered = [[] for _ in range(n)]
|
|
||||||
|
|
||||||
alpha = df['alpha'].values
|
|
||||||
alpha_delta = df['alpha_delta'].values
|
|
||||||
amt_ratio = df['amt_ratio'].values
|
|
||||||
amt_delta = df['amt_delta'].values
|
|
||||||
rank_pct = df['rank_pct'].values
|
|
||||||
limit_up_ratio = df['limit_up_ratio'].values
|
|
||||||
stock_count = df['stock_count'].values if 'stock_count' in df.columns else np.full(n, 20)
|
|
||||||
|
|
||||||
alpha_abs = np.abs(alpha)
|
|
||||||
alpha_delta_abs = np.abs(alpha_delta)
|
|
||||||
|
|
||||||
# 获取基于规模的调整系数
|
|
||||||
size_factor = get_size_adjusted_thresholds(stock_count)
|
|
||||||
|
|
||||||
# ========== Alpha 规则(动态阈值)==========
|
|
||||||
# 基准阈值:极强 5%,强 4%,中等 3%
|
|
||||||
# 实际阈值 = 基准 × size_factor
|
|
||||||
|
|
||||||
# 极强信号
|
|
||||||
alpha_extreme_thresh = 5.0 * size_factor
|
|
||||||
mask = alpha_abs >= alpha_extreme_thresh
|
|
||||||
scores[mask] += 20
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('alpha_extreme')
|
|
||||||
|
|
||||||
# 强信号
|
|
||||||
alpha_strong_thresh = 4.0 * size_factor
|
|
||||||
mask = (alpha_abs >= alpha_strong_thresh) & (alpha_abs < alpha_extreme_thresh)
|
|
||||||
scores[mask] += 15
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('alpha_strong')
|
|
||||||
|
|
||||||
# 中等信号
|
|
||||||
alpha_medium_thresh = 3.0 * size_factor
|
|
||||||
mask = (alpha_abs >= alpha_medium_thresh) & (alpha_abs < alpha_strong_thresh)
|
|
||||||
scores[mask] += 10
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('alpha_medium')
|
|
||||||
|
|
||||||
# ========== Alpha 加速度规则(动态阈值)==========
|
|
||||||
delta_strong_thresh = 2.0 * size_factor
|
|
||||||
mask = alpha_delta_abs >= delta_strong_thresh
|
|
||||||
scores[mask] += 15
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('alpha_delta_strong')
|
|
||||||
|
|
||||||
delta_medium_thresh = 1.5 * size_factor
|
|
||||||
mask = (alpha_delta_abs >= delta_medium_thresh) & (alpha_delta_abs < delta_strong_thresh)
|
|
||||||
scores[mask] += 10
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('alpha_delta_medium')
|
|
||||||
|
|
||||||
# ========== 成交额规则(不受规模影响,放量就是放量)==========
|
|
||||||
mask = amt_ratio >= 10.0
|
|
||||||
scores[mask] += 20
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('volume_extreme')
|
|
||||||
|
|
||||||
mask = (amt_ratio >= 6.0) & (amt_ratio < 10.0)
|
|
||||||
scores[mask] += 12
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('volume_strong')
|
|
||||||
|
|
||||||
# ========== 排名规则 ==========
|
|
||||||
mask = rank_pct >= 0.98
|
|
||||||
scores[mask] += 15
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('rank_top')
|
|
||||||
|
|
||||||
mask = rank_pct <= 0.02
|
|
||||||
scores[mask] += 15
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('rank_bottom')
|
|
||||||
|
|
||||||
# ========== 涨停规则(动态阈值)==========
|
|
||||||
# 大概念有涨停更有意义
|
|
||||||
limit_high_thresh = 0.30 * size_factor
|
|
||||||
mask = limit_up_ratio >= limit_high_thresh
|
|
||||||
scores[mask] += 20
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('limit_up_high')
|
|
||||||
|
|
||||||
limit_medium_thresh = 0.20 * size_factor
|
|
||||||
mask = (limit_up_ratio >= limit_medium_thresh) & (limit_up_ratio < limit_high_thresh)
|
|
||||||
scores[mask] += 12
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('limit_up_medium')
|
|
||||||
|
|
||||||
# ========== 概念规模加分(大概念异动更有价值)==========
|
|
||||||
# 大概念(50+)额外加分
|
|
||||||
large_concept = stock_count >= 50
|
|
||||||
has_signal = scores > 0 # 至少触发了某个规则
|
|
||||||
mask = large_concept & has_signal
|
|
||||||
scores[mask] += 10
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('large_concept_bonus')
|
|
||||||
|
|
||||||
# 超大概念(100+)再加分
|
|
||||||
xlarge_concept = stock_count >= 100
|
|
||||||
mask = xlarge_concept & has_signal
|
|
||||||
scores[mask] += 10
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('xlarge_concept_bonus')
|
|
||||||
|
|
||||||
# ========== 组合规则(动态阈值)==========
|
|
||||||
combo_alpha_thresh = 3.0 * size_factor
|
|
||||||
|
|
||||||
# Alpha + 放量 + 排名(三重验证)
|
|
||||||
mask = (alpha_abs >= combo_alpha_thresh) & (amt_ratio >= 5.0) & ((rank_pct >= 0.95) | (rank_pct <= 0.05))
|
|
||||||
scores[mask] += 20
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('triple_signal')
|
|
||||||
|
|
||||||
# Alpha + 涨停(强组合)
|
|
||||||
mask = (alpha_abs >= combo_alpha_thresh) & (limit_up_ratio >= 0.15 * size_factor)
|
|
||||||
scores[mask] += 15
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('alpha_with_limit')
|
|
||||||
|
|
||||||
# ========== 小概念惩罚(过滤噪音)==========
|
|
||||||
# 微型概念(<5 只)如果只有单一信号,减分
|
|
||||||
tiny_concept = stock_count < 5
|
|
||||||
single_rule = np.array([len(t) <= 1 for t in triggered])
|
|
||||||
mask = tiny_concept & single_rule & (scores > 0)
|
|
||||||
scores[mask] *= 0.5 # 减半
|
|
||||||
for i in np.where(mask)[0]: triggered[i].append('tiny_concept_penalty')
|
|
||||||
|
|
||||||
scores = np.clip(scores, 0, 100)
|
|
||||||
return scores, triggered
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== ML 评分器 ====================
|
|
||||||
|
|
||||||
class FastMLScorer:
|
|
||||||
"""快速 ML 评分器"""
|
|
||||||
|
|
||||||
def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
|
|
||||||
self.checkpoint_dir = Path(checkpoint_dir)
|
|
||||||
|
|
||||||
if device == 'auto':
|
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
elif device == 'cuda' and not torch.cuda.is_available():
|
|
||||||
print("警告: CUDA 不可用,使用 CPU")
|
|
||||||
self.device = torch.device('cpu')
|
|
||||||
else:
|
|
||||||
self.device = torch.device(device)
|
|
||||||
|
|
||||||
self.model = None
|
|
||||||
self.thresholds = None
|
|
||||||
self._load_model()
|
|
||||||
|
|
||||||
def _load_model(self):
|
|
||||||
model_path = self.checkpoint_dir / 'best_model.pt'
|
|
||||||
thresholds_path = self.checkpoint_dir / 'thresholds.json'
|
|
||||||
config_path = self.checkpoint_dir / 'config.json'
|
|
||||||
|
|
||||||
if not model_path.exists():
|
|
||||||
print(f"警告: 模型不存在 {model_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
from model import LSTMAutoencoder
|
|
||||||
|
|
||||||
config = {}
|
|
||||||
if config_path.exists():
|
|
||||||
with open(config_path) as f:
|
|
||||||
config = json.load(f).get('model', {})
|
|
||||||
|
|
||||||
# 处理旧配置键名
|
|
||||||
if 'd_model' in config:
|
|
||||||
config['hidden_dim'] = config.pop('d_model') // 2
|
|
||||||
for key in ['num_encoder_layers', 'num_decoder_layers', 'nhead', 'dim_feedforward', 'max_seq_len', 'use_instance_norm']:
|
|
||||||
config.pop(key, None)
|
|
||||||
if 'num_layers' not in config:
|
|
||||||
config['num_layers'] = 1
|
|
||||||
|
|
||||||
checkpoint = torch.load(model_path, map_location='cpu')
|
|
||||||
self.model = LSTMAutoencoder(**config)
|
|
||||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
|
||||||
self.model.to(self.device)
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
if thresholds_path.exists():
|
|
||||||
with open(thresholds_path) as f:
|
|
||||||
self.thresholds = json.load(f)
|
|
||||||
|
|
||||||
print(f"ML模型加载成功 (设备: {self.device})")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ML模型加载失败: {e}")
|
|
||||||
self.model = None
|
|
||||||
|
|
||||||
def is_ready(self):
|
|
||||||
return self.model is not None
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def score_batch(self, sequences: np.ndarray) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
批量计算 ML 得分
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sequences: (batch, seq_len, n_features)
|
|
||||||
Returns:
|
|
||||||
scores: (batch,) 0-100 分数
|
|
||||||
"""
|
|
||||||
if not self.is_ready() or len(sequences) == 0:
|
|
||||||
return np.zeros(len(sequences))
|
|
||||||
|
|
||||||
x = torch.FloatTensor(sequences).to(self.device)
|
|
||||||
output, _ = self.model(x)
|
|
||||||
mse = ((output - x) ** 2).mean(dim=-1)
|
|
||||||
errors = mse[:, -1].cpu().numpy()
|
|
||||||
|
|
||||||
p95 = self.thresholds.get('p95', 0.1) if self.thresholds else 0.1
|
|
||||||
scores = np.clip(errors / p95 * 50, 0, 100)
|
|
||||||
return scores
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 快速回测 ====================
|
|
||||||
|
|
||||||
def build_sequences_fast(
|
|
||||||
df: pd.DataFrame,
|
|
||||||
seq_len: int = 30,
|
|
||||||
prev_df: pd.DataFrame = None
|
|
||||||
) -> Tuple[np.ndarray, pd.DataFrame]:
|
|
||||||
"""
|
|
||||||
快速构建所有有效序列
|
|
||||||
|
|
||||||
支持跨日序列:用前一天收盘数据 + 当天开盘数据拼接,实现 9:30 就能检测
|
|
||||||
|
|
||||||
Args:
|
|
||||||
df: 当天数据
|
|
||||||
seq_len: 序列长度
|
|
||||||
prev_df: 前一天数据(可选,用于构建开盘时的序列)
|
|
||||||
|
|
||||||
返回:
|
|
||||||
sequences: (n_valid, seq_len, n_features) 所有有效序列
|
|
||||||
info_df: 对应的元信息 DataFrame
|
|
||||||
"""
|
|
||||||
# 确保按概念和时间排序
|
|
||||||
df = df.sort_values(['concept_id', 'timestamp']).reset_index(drop=True)
|
|
||||||
|
|
||||||
# 如果有前一天数据,按概念构建尾部缓存(取每个概念最后 seq_len-1 条)
|
|
||||||
prev_cache = {}
|
|
||||||
if prev_df is not None and len(prev_df) > 0:
|
|
||||||
prev_df = prev_df.sort_values(['concept_id', 'timestamp'])
|
|
||||||
for concept_id, gdf in prev_df.groupby('concept_id'):
|
|
||||||
tail_data = gdf.tail(seq_len - 1)
|
|
||||||
if len(tail_data) > 0:
|
|
||||||
feat_matrix = tail_data[FEATURES].values
|
|
||||||
feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
|
|
||||||
feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])
|
|
||||||
prev_cache[concept_id] = feat_matrix
|
|
||||||
|
|
||||||
# 按概念分组
|
|
||||||
groups = df.groupby('concept_id')
|
|
||||||
|
|
||||||
sequences = []
|
|
||||||
infos = []
|
|
||||||
|
|
||||||
for concept_id, gdf in groups:
|
|
||||||
gdf = gdf.reset_index(drop=True)
|
|
||||||
|
|
||||||
# 获取特征矩阵
|
|
||||||
feat_matrix = gdf[FEATURES].values
|
|
||||||
feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
|
|
||||||
feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])
|
|
||||||
|
|
||||||
# 如果有前一天缓存,拼接到当天数据前面
|
|
||||||
if concept_id in prev_cache:
|
|
||||||
prev_data = prev_cache[concept_id]
|
|
||||||
combined_matrix = np.vstack([prev_data, feat_matrix])
|
|
||||||
# 计算偏移量:前一天数据的长度
|
|
||||||
offset = len(prev_data)
|
|
||||||
else:
|
|
||||||
combined_matrix = feat_matrix
|
|
||||||
offset = 0
|
|
||||||
|
|
||||||
# 滑动窗口构建序列
|
|
||||||
n_total = len(combined_matrix)
|
|
||||||
if n_total < seq_len:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for i in range(n_total - seq_len + 1):
|
|
||||||
seq = combined_matrix[i:i + seq_len]
|
|
||||||
|
|
||||||
# 计算对应当天数据的索引
|
|
||||||
# 序列最后一个点的位置 = i + seq_len - 1
|
|
||||||
# 对应当天数据的索引 = (i + seq_len - 1) - offset
|
|
||||||
today_idx = i + seq_len - 1 - offset
|
|
||||||
|
|
||||||
# 只要序列的最后一个点是当天的数据,就记录
|
|
||||||
if today_idx < 0 or today_idx >= len(gdf):
|
|
||||||
continue
|
|
||||||
|
|
||||||
sequences.append(seq)
|
|
||||||
|
|
||||||
# 记录最后一个时间步的信息(当天的)
|
|
||||||
row = gdf.iloc[today_idx]
|
|
||||||
infos.append({
|
|
||||||
'concept_id': concept_id,
|
|
||||||
'timestamp': row['timestamp'],
|
|
||||||
'alpha': row['alpha'],
|
|
||||||
'alpha_delta': row.get('alpha_delta', 0),
|
|
||||||
'amt_ratio': row.get('amt_ratio', 1),
|
|
||||||
'amt_delta': row.get('amt_delta', 0),
|
|
||||||
'rank_pct': row.get('rank_pct', 0.5),
|
|
||||||
'limit_up_ratio': row.get('limit_up_ratio', 0),
|
|
||||||
'stock_count': row.get('stock_count', 0),
|
|
||||||
'total_amt': row.get('total_amt', 0),
|
|
||||||
})
|
|
||||||
|
|
||||||
if not sequences:
|
|
||||||
return np.array([]), pd.DataFrame()
|
|
||||||
|
|
||||||
return np.array(sequences), pd.DataFrame(infos)
|
|
||||||
|
|
||||||
|
|
||||||
def backtest_single_day_fast(
|
|
||||||
ml_scorer: FastMLScorer,
|
|
||||||
df: pd.DataFrame,
|
|
||||||
date: str,
|
|
||||||
config: Dict,
|
|
||||||
prev_df: pd.DataFrame = None
|
|
||||||
) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
快速回测单天(向量化版本)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ml_scorer: ML 评分器
|
|
||||||
df: 当天数据
|
|
||||||
date: 日期
|
|
||||||
config: 配置
|
|
||||||
prev_df: 前一天数据(用于 9:30 开始检测)
|
|
||||||
"""
|
|
||||||
seq_len = config.get('seq_len', 30)
|
|
||||||
|
|
||||||
# 1. 构建所有序列(支持跨日)
|
|
||||||
sequences, info_df = build_sequences_fast(df, seq_len, prev_df)
|
|
||||||
|
|
||||||
if len(sequences) == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 2. 过滤小波动
|
|
||||||
alpha_abs = np.abs(info_df['alpha'].values)
|
|
||||||
valid_mask = alpha_abs >= config['min_alpha_abs']
|
|
||||||
|
|
||||||
sequences = sequences[valid_mask]
|
|
||||||
info_df = info_df[valid_mask].reset_index(drop=True)
|
|
||||||
|
|
||||||
if len(sequences) == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 3. 批量规则评分
|
|
||||||
rule_scores, triggered_rules = score_rules_batch(info_df)
|
|
||||||
|
|
||||||
# 4. 批量 ML 评分(分批处理避免显存溢出)
|
|
||||||
batch_size = 2048
|
|
||||||
ml_scores = []
|
|
||||||
for i in range(0, len(sequences), batch_size):
|
|
||||||
batch_seq = sequences[i:i+batch_size]
|
|
||||||
batch_scores = ml_scorer.score_batch(batch_seq)
|
|
||||||
ml_scores.append(batch_scores)
|
|
||||||
ml_scores = np.concatenate(ml_scores) if ml_scores else np.zeros(len(sequences))
|
|
||||||
|
|
||||||
# 5. 融合得分
|
|
||||||
w1, w2 = config['rule_weight'], config['ml_weight']
|
|
||||||
final_scores = w1 * rule_scores + w2 * ml_scores
|
|
||||||
|
|
||||||
# 6. 判断异动
|
|
||||||
is_anomaly = (
|
|
||||||
(rule_scores >= config['rule_trigger']) |
|
|
||||||
(ml_scores >= config['ml_trigger']) |
|
|
||||||
(final_scores >= config['fusion_trigger'])
|
|
||||||
)
|
|
||||||
|
|
||||||
# 7. 应用冷却期(按概念+时间排序后处理)
|
|
||||||
info_df['rule_score'] = rule_scores
|
|
||||||
info_df['ml_score'] = ml_scores
|
|
||||||
info_df['final_score'] = final_scores
|
|
||||||
info_df['is_anomaly'] = is_anomaly
|
|
||||||
info_df['triggered_rules'] = triggered_rules
|
|
||||||
|
|
||||||
# 只保留异动
|
|
||||||
anomaly_df = info_df[info_df['is_anomaly']].copy()
|
|
||||||
|
|
||||||
if len(anomaly_df) == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 应用冷却期
|
|
||||||
anomaly_df = anomaly_df.sort_values(['concept_id', 'timestamp'])
|
|
||||||
cooldown = {}
|
|
||||||
keep_mask = []
|
|
||||||
|
|
||||||
for _, row in anomaly_df.iterrows():
|
|
||||||
cid = row['concept_id']
|
|
||||||
ts = row['timestamp']
|
|
||||||
|
|
||||||
if cid in cooldown:
|
|
||||||
try:
|
|
||||||
diff = (ts - cooldown[cid]).total_seconds() / 60
|
|
||||||
except:
|
|
||||||
diff = config['cooldown_minutes'] + 1
|
|
||||||
|
|
||||||
if diff < config['cooldown_minutes']:
|
|
||||||
keep_mask.append(False)
|
|
||||||
continue
|
|
||||||
|
|
||||||
cooldown[cid] = ts
|
|
||||||
keep_mask.append(True)
|
|
||||||
|
|
||||||
anomaly_df = anomaly_df[keep_mask]
|
|
||||||
|
|
||||||
# 8. 按时间分组,每分钟最多 max_alerts_per_minute 个
|
|
||||||
alerts = []
|
|
||||||
for ts, group in anomaly_df.groupby('timestamp'):
|
|
||||||
group = group.nlargest(config['max_alerts_per_minute'], 'final_score')
|
|
||||||
|
|
||||||
for _, row in group.iterrows():
|
|
||||||
alpha = row['alpha']
|
|
||||||
if alpha >= 1.5:
|
|
||||||
atype = 'surge_up'
|
|
||||||
elif alpha <= -1.5:
|
|
||||||
atype = 'surge_down'
|
|
||||||
elif row['amt_ratio'] >= 3.0:
|
|
||||||
atype = 'volume_spike'
|
|
||||||
else:
|
|
||||||
atype = 'unknown'
|
|
||||||
|
|
||||||
rule_score = row['rule_score']
|
|
||||||
ml_score = row['ml_score']
|
|
||||||
final_score = row['final_score']
|
|
||||||
|
|
||||||
if rule_score >= config['rule_trigger']:
|
|
||||||
trigger = f'规则强信号({rule_score:.0f}分)'
|
|
||||||
elif ml_score >= config['ml_trigger']:
|
|
||||||
trigger = f'ML强信号({ml_score:.0f}分)'
|
|
||||||
else:
|
|
||||||
trigger = f'融合触发({final_score:.0f}分)'
|
|
||||||
|
|
||||||
alerts.append({
|
|
||||||
'concept_id': row['concept_id'],
|
|
||||||
'alert_time': row['timestamp'],
|
|
||||||
'trade_date': date,
|
|
||||||
'alert_type': atype,
|
|
||||||
'final_score': final_score,
|
|
||||||
'rule_score': rule_score,
|
|
||||||
'ml_score': ml_score,
|
|
||||||
'trigger_reason': trigger,
|
|
||||||
'triggered_rules': row['triggered_rules'],
|
|
||||||
'alpha': alpha,
|
|
||||||
'alpha_delta': row['alpha_delta'],
|
|
||||||
'amt_ratio': row['amt_ratio'],
|
|
||||||
'amt_delta': row['amt_delta'],
|
|
||||||
'rank_pct': row['rank_pct'],
|
|
||||||
'limit_up_ratio': row['limit_up_ratio'],
|
|
||||||
'stock_count': row['stock_count'],
|
|
||||||
'total_amt': row['total_amt'],
|
|
||||||
})
|
|
||||||
|
|
||||||
return alerts
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据加载 ====================
|
|
||||||
|
|
||||||
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
|
|
||||||
file_path = Path(data_dir) / f"features_{date}.parquet"
|
|
||||||
if not file_path.exists():
|
|
||||||
return None
|
|
||||||
return pd.read_parquet(file_path)
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_dates(data_dir: str, start: str, end: str) -> List[str]:
|
|
||||||
data_path = Path(data_dir)
|
|
||||||
dates = []
|
|
||||||
for f in sorted(data_path.glob("features_*.parquet")):
|
|
||||||
d = f.stem.replace('features_', '')
|
|
||||||
if start <= d <= end:
|
|
||||||
dates.append(d)
|
|
||||||
return dates
|
|
||||||
|
|
||||||
|
|
||||||
def get_prev_trading_day(data_dir: str, date: str) -> Optional[str]:
|
|
||||||
"""获取给定日期之前最近的有数据的交易日"""
|
|
||||||
data_path = Path(data_dir)
|
|
||||||
all_dates = sorted([f.stem.replace('features_', '') for f in data_path.glob("features_*.parquet")])
|
|
||||||
|
|
||||||
for i, d in enumerate(all_dates):
|
|
||||||
if d == date and i > 0:
|
|
||||||
return all_dates[i - 1]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def export_to_csv(alerts: List[Dict], path: str):
|
|
||||||
if alerts:
|
|
||||||
pd.DataFrame(alerts).to_csv(path, index=False, encoding='utf-8-sig')
|
|
||||||
print(f"已导出: {path}")
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库写入 ====================
|
|
||||||
|
|
||||||
def init_db_table():
|
|
||||||
"""
|
|
||||||
初始化数据库表(如果不存在则创建)
|
|
||||||
|
|
||||||
表结构说明:
|
|
||||||
- concept_id: 概念ID
|
|
||||||
- alert_time: 异动时间(精确到分钟)
|
|
||||||
- trade_date: 交易日期
|
|
||||||
- alert_type: 异动类型(surge_up/surge_down/volume_spike/unknown)
|
|
||||||
- final_score: 最终得分(0-100)
|
|
||||||
- rule_score: 规则得分(0-100)
|
|
||||||
- ml_score: ML得分(0-100)
|
|
||||||
- trigger_reason: 触发原因
|
|
||||||
- alpha: 超额收益率
|
|
||||||
- alpha_delta: alpha变化速度
|
|
||||||
- amt_ratio: 成交额放大倍数
|
|
||||||
- rank_pct: 排名百分位
|
|
||||||
- stock_count: 概念股票数量
|
|
||||||
- triggered_rules: 触发的规则列表(JSON)
|
|
||||||
"""
|
|
||||||
create_sql = text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS concept_anomaly_hybrid (
|
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
|
||||||
concept_id VARCHAR(64) NOT NULL,
|
|
||||||
alert_time DATETIME NOT NULL,
|
|
||||||
trade_date DATE NOT NULL,
|
|
||||||
alert_type VARCHAR(32) NOT NULL,
|
|
||||||
final_score FLOAT NOT NULL,
|
|
||||||
rule_score FLOAT NOT NULL,
|
|
||||||
ml_score FLOAT NOT NULL,
|
|
||||||
trigger_reason VARCHAR(64),
|
|
||||||
alpha FLOAT,
|
|
||||||
alpha_delta FLOAT,
|
|
||||||
amt_ratio FLOAT,
|
|
||||||
amt_delta FLOAT,
|
|
||||||
rank_pct FLOAT,
|
|
||||||
limit_up_ratio FLOAT,
|
|
||||||
stock_count INT,
|
|
||||||
total_amt FLOAT,
|
|
||||||
triggered_rules JSON,
|
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
|
|
||||||
INDEX idx_trade_date (trade_date),
|
|
||||||
INDEX idx_concept_id (concept_id),
|
|
||||||
INDEX idx_final_score (final_score),
|
|
||||||
INDEX idx_alert_type (alert_type)
|
|
||||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动检测结果(融合版)'
|
|
||||||
""")
|
|
||||||
|
|
||||||
with MYSQL_ENGINE.begin() as conn:
|
|
||||||
conn.execute(create_sql)
|
|
||||||
print("数据库表已就绪: concept_anomaly_hybrid")
|
|
||||||
|
|
||||||
|
|
||||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
|
|
||||||
"""
|
|
||||||
保存异动到 MySQL
|
|
||||||
|
|
||||||
Args:
|
|
||||||
alerts: 异动列表
|
|
||||||
dry_run: 是否只模拟,不实际写入
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
实际保存的记录数
|
|
||||||
"""
|
|
||||||
if not alerts:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
if dry_run:
|
|
||||||
print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
|
|
||||||
return len(alerts)
|
|
||||||
|
|
||||||
saved = 0
|
|
||||||
skipped = 0
|
|
||||||
|
|
||||||
with MYSQL_ENGINE.begin() as conn:
|
|
||||||
for alert in alerts:
|
|
||||||
try:
|
|
||||||
# 检查是否已存在(使用 INSERT IGNORE 更高效)
|
|
||||||
insert_sql = text("""
|
|
||||||
INSERT IGNORE INTO concept_anomaly_hybrid
|
|
||||||
(concept_id, alert_time, trade_date, alert_type,
|
|
||||||
final_score, rule_score, ml_score, trigger_reason,
|
|
||||||
alpha, alpha_delta, amt_ratio, amt_delta,
|
|
||||||
rank_pct, limit_up_ratio, stock_count, total_amt,
|
|
||||||
triggered_rules)
|
|
||||||
VALUES
|
|
||||||
(:concept_id, :alert_time, :trade_date, :alert_type,
|
|
||||||
:final_score, :rule_score, :ml_score, :trigger_reason,
|
|
||||||
:alpha, :alpha_delta, :amt_ratio, :amt_delta,
|
|
||||||
:rank_pct, :limit_up_ratio, :stock_count, :total_amt,
|
|
||||||
:triggered_rules)
|
|
||||||
""")
|
|
||||||
|
|
||||||
result = conn.execute(insert_sql, {
|
|
||||||
'concept_id': alert['concept_id'],
|
|
||||||
'alert_time': alert['alert_time'],
|
|
||||||
'trade_date': alert['trade_date'],
|
|
||||||
'alert_type': alert['alert_type'],
|
|
||||||
'final_score': alert['final_score'],
|
|
||||||
'rule_score': alert['rule_score'],
|
|
||||||
'ml_score': alert['ml_score'],
|
|
||||||
'trigger_reason': alert['trigger_reason'],
|
|
||||||
'alpha': alert.get('alpha', 0),
|
|
||||||
'alpha_delta': alert.get('alpha_delta', 0),
|
|
||||||
'amt_ratio': alert.get('amt_ratio', 1),
|
|
||||||
'amt_delta': alert.get('amt_delta', 0),
|
|
||||||
'rank_pct': alert.get('rank_pct', 0.5),
|
|
||||||
'limit_up_ratio': alert.get('limit_up_ratio', 0),
|
|
||||||
'stock_count': alert.get('stock_count', 0),
|
|
||||||
'total_amt': alert.get('total_amt', 0),
|
|
||||||
'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False),
|
|
||||||
})
|
|
||||||
|
|
||||||
if result.rowcount > 0:
|
|
||||||
saved += 1
|
|
||||||
else:
|
|
||||||
skipped += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f" 保存失败: {alert['concept_id']} @ {alert['alert_time']} - {e}")
|
|
||||||
|
|
||||||
if skipped > 0:
|
|
||||||
print(f" 跳过 {skipped} 条重复记录")
|
|
||||||
|
|
||||||
return saved
|
|
||||||
|
|
||||||
|
|
||||||
def clear_alerts_by_date(trade_date: str) -> int:
|
|
||||||
"""清除指定日期的异动记录(用于重新回测)"""
|
|
||||||
with MYSQL_ENGINE.begin() as conn:
|
|
||||||
result = conn.execute(
|
|
||||||
text("DELETE FROM concept_anomaly_hybrid WHERE trade_date = :trade_date"),
|
|
||||||
{'trade_date': trade_date}
|
|
||||||
)
|
|
||||||
return result.rowcount
|
|
||||||
|
|
||||||
|
|
||||||
def analyze_alerts(alerts: List[Dict]):
|
|
||||||
if not alerts:
|
|
||||||
print("无异动")
|
|
||||||
return
|
|
||||||
|
|
||||||
df = pd.DataFrame(alerts)
|
|
||||||
print(f"\n总异动: {len(alerts)}")
|
|
||||||
print(f"\n类型分布:\n{df['alert_type'].value_counts()}")
|
|
||||||
print(f"\n得分统计:")
|
|
||||||
print(f" 最终: {df['final_score'].mean():.1f} (max: {df['final_score'].max():.1f})")
|
|
||||||
print(f" 规则: {df['rule_score'].mean():.1f} (max: {df['rule_score'].max():.1f})")
|
|
||||||
print(f" ML: {df['ml_score'].mean():.1f} (max: {df['ml_score'].max():.1f})")
|
|
||||||
|
|
||||||
trigger_type = df['trigger_reason'].apply(
|
|
||||||
lambda x: '规则' if '规则' in x else ('ML' if 'ML' in x else '融合')
|
|
||||||
)
|
|
||||||
print(f"\n触发来源:\n{trigger_type.value_counts()}")
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 主函数 ====================
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description='快速融合异动回测')
|
|
||||||
parser.add_argument('--data_dir', default='ml/data')
|
|
||||||
parser.add_argument('--checkpoint_dir', default='ml/checkpoints')
|
|
||||||
parser.add_argument('--start', required=True)
|
|
||||||
parser.add_argument('--end', default=None)
|
|
||||||
parser.add_argument('--dry-run', action='store_true', help='模拟运行,不写入数据库')
|
|
||||||
parser.add_argument('--export-csv', default=None, help='导出 CSV 文件路径')
|
|
||||||
parser.add_argument('--save-db', action='store_true', help='保存结果到数据库')
|
|
||||||
parser.add_argument('--clear-first', action='store_true', help='写入前先清除该日期的旧数据')
|
|
||||||
parser.add_argument('--device', default='auto')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
if args.end is None:
|
|
||||||
args.end = args.start
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("快速融合异动回测")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"日期: {args.start} ~ {args.end}")
|
|
||||||
print(f"设备: {args.device}")
|
|
||||||
print(f"保存数据库: {args.save_db}")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# 初始化数据库表(如果需要保存)
|
|
||||||
if args.save_db and not args.dry_run:
|
|
||||||
init_db_table()
|
|
||||||
|
|
||||||
# 初始化 ML 评分器
|
|
||||||
ml_scorer = FastMLScorer(args.checkpoint_dir, args.device)
|
|
||||||
|
|
||||||
# 获取日期
|
|
||||||
dates = get_available_dates(args.data_dir, args.start, args.end)
|
|
||||||
if not dates:
|
|
||||||
print("无数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"找到 {len(dates)} 天数据\n")
|
|
||||||
|
|
||||||
# 回测(支持跨日序列)
|
|
||||||
all_alerts = []
|
|
||||||
total_saved = 0
|
|
||||||
prev_df = None # 缓存前一天数据
|
|
||||||
|
|
||||||
for i, date in enumerate(tqdm(dates, desc="回测")):
|
|
||||||
df = load_daily_features(args.data_dir, date)
|
|
||||||
if df is None or df.empty:
|
|
||||||
prev_df = None # 当天无数据,清空缓存
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 第一天需要加载前一天数据(如果存在)
|
|
||||||
if i == 0 and prev_df is None:
|
|
||||||
prev_date = get_prev_trading_day(args.data_dir, date)
|
|
||||||
if prev_date:
|
|
||||||
prev_df = load_daily_features(args.data_dir, prev_date)
|
|
||||||
if prev_df is not None:
|
|
||||||
tqdm.write(f" 加载前一天数据: {prev_date}")
|
|
||||||
|
|
||||||
alerts = backtest_single_day_fast(ml_scorer, df, date, CONFIG, prev_df)
|
|
||||||
all_alerts.extend(alerts)
|
|
||||||
|
|
||||||
# 保存到数据库
|
|
||||||
if args.save_db and alerts:
|
|
||||||
if args.clear_first and not args.dry_run:
|
|
||||||
cleared = clear_alerts_by_date(date)
|
|
||||||
if cleared > 0:
|
|
||||||
tqdm.write(f" 清除 {date} 旧数据: {cleared} 条")
|
|
||||||
|
|
||||||
saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
|
|
||||||
total_saved += saved
|
|
||||||
tqdm.write(f" {date}: {len(alerts)} 个异动, 保存 {saved} 条")
|
|
||||||
elif alerts:
|
|
||||||
tqdm.write(f" {date}: {len(alerts)} 个异动")
|
|
||||||
|
|
||||||
# 当天数据成为下一天的 prev_df
|
|
||||||
prev_df = df
|
|
||||||
|
|
||||||
# 导出 CSV
|
|
||||||
if args.export_csv:
|
|
||||||
export_to_csv(all_alerts, args.export_csv)
|
|
||||||
|
|
||||||
# 分析
|
|
||||||
analyze_alerts(all_alerts)
|
|
||||||
|
|
||||||
print(f"\n总计: {len(all_alerts)} 个异动")
|
|
||||||
if args.save_db:
|
|
||||||
print(f"已保存到数据库: {total_saved} 条")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,481 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
融合异动回测脚本
|
|
||||||
|
|
||||||
使用 HybridAnomalyDetector 进行回测:
|
|
||||||
- 规则评分 + LSTM Autoencoder 融合判断
|
|
||||||
- 输出更丰富的异动信息
|
|
||||||
|
|
||||||
使用方法:
|
|
||||||
python backtest_hybrid.py --start 2024-01-01 --end 2024-12-01
|
|
||||||
python backtest_hybrid.py --start 2024-11-01 --dry-run
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from tqdm import tqdm
|
|
||||||
from sqlalchemy import create_engine, text
|
|
||||||
|
|
||||||
# 添加父目录到路径
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
from detector import HybridAnomalyDetector, create_detector
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 配置 ====================
|
|
||||||
|
|
||||||
MYSQL_ENGINE = create_engine(
|
|
||||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
|
||||||
echo=False
|
|
||||||
)
|
|
||||||
|
|
||||||
FEATURES = [
|
|
||||||
'alpha',
|
|
||||||
'alpha_delta',
|
|
||||||
'amt_ratio',
|
|
||||||
'amt_delta',
|
|
||||||
'rank_pct',
|
|
||||||
'limit_up_ratio',
|
|
||||||
]
|
|
||||||
|
|
||||||
BACKTEST_CONFIG = {
|
|
||||||
'seq_len': 30,
|
|
||||||
'min_alpha_abs': 0.3, # 降低阈值,让规则也能发挥作用
|
|
||||||
'cooldown_minutes': 8,
|
|
||||||
'max_alerts_per_minute': 20,
|
|
||||||
'clip_value': 10.0,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据加载 ====================
|
|
||||||
|
|
||||||
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
|
|
||||||
"""加载单天的特征数据"""
|
|
||||||
file_path = Path(data_dir) / f"features_{date}.parquet"
|
|
||||||
|
|
||||||
if not file_path.exists():
|
|
||||||
return None
|
|
||||||
|
|
||||||
df = pd.read_parquet(file_path)
|
|
||||||
return df
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
|
|
||||||
"""获取可用的日期列表"""
|
|
||||||
data_path = Path(data_dir)
|
|
||||||
all_files = sorted(data_path.glob("features_*.parquet"))
|
|
||||||
|
|
||||||
dates = []
|
|
||||||
for f in all_files:
|
|
||||||
date = f.stem.replace('features_', '')
|
|
||||||
if start_date <= date <= end_date:
|
|
||||||
dates.append(date)
|
|
||||||
|
|
||||||
return dates
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 融合回测 ====================
|
|
||||||
|
|
||||||
def backtest_single_day_hybrid(
|
|
||||||
detector: HybridAnomalyDetector,
|
|
||||||
df: pd.DataFrame,
|
|
||||||
date: str,
|
|
||||||
seq_len: int = 30
|
|
||||||
) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
使用融合检测器回测单天数据(批量优化版)
|
|
||||||
"""
|
|
||||||
alerts = []
|
|
||||||
|
|
||||||
# 按概念分组,预先构建字典
|
|
||||||
grouped_dict = {cid: cdf for cid, cdf in df.groupby('concept_id', sort=False)}
|
|
||||||
|
|
||||||
# 冷却记录
|
|
||||||
cooldown = {}
|
|
||||||
|
|
||||||
# 获取所有时间点
|
|
||||||
all_timestamps = sorted(df['timestamp'].unique())
|
|
||||||
|
|
||||||
if len(all_timestamps) < seq_len:
|
|
||||||
return alerts
|
|
||||||
|
|
||||||
# 对每个时间点进行检测
|
|
||||||
for t_idx in range(seq_len - 1, len(all_timestamps)):
|
|
||||||
current_time = all_timestamps[t_idx]
|
|
||||||
window_start_time = all_timestamps[t_idx - seq_len + 1]
|
|
||||||
|
|
||||||
# 批量收集该时刻所有候选概念
|
|
||||||
batch_sequences = []
|
|
||||||
batch_features = []
|
|
||||||
batch_infos = []
|
|
||||||
|
|
||||||
for concept_id, concept_df in grouped_dict.items():
|
|
||||||
# 检查冷却(提前过滤)
|
|
||||||
if concept_id in cooldown:
|
|
||||||
last_alert = cooldown[concept_id]
|
|
||||||
if isinstance(current_time, datetime):
|
|
||||||
time_diff = (current_time - last_alert).total_seconds() / 60
|
|
||||||
else:
|
|
||||||
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1
|
|
||||||
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 获取时间窗口内的数据
|
|
||||||
mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
|
|
||||||
window_df = concept_df.loc[mask]
|
|
||||||
|
|
||||||
if len(window_df) < seq_len:
|
|
||||||
continue
|
|
||||||
|
|
||||||
window_df = window_df.sort_values('timestamp').tail(seq_len)
|
|
||||||
|
|
||||||
# 当前时刻特征
|
|
||||||
current_row = window_df.iloc[-1]
|
|
||||||
alpha = current_row.get('alpha', 0)
|
|
||||||
|
|
||||||
# 过滤微小波动(提前过滤)
|
|
||||||
if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 提取特征序列
|
|
||||||
sequence = window_df[FEATURES].values
|
|
||||||
sequence = np.nan_to_num(sequence, nan=0.0, posinf=0.0, neginf=0.0)
|
|
||||||
sequence = np.clip(sequence, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])
|
|
||||||
|
|
||||||
current_features = {
|
|
||||||
'alpha': alpha,
|
|
||||||
'alpha_delta': current_row.get('alpha_delta', 0),
|
|
||||||
'amt_ratio': current_row.get('amt_ratio', 1),
|
|
||||||
'amt_delta': current_row.get('amt_delta', 0),
|
|
||||||
'rank_pct': current_row.get('rank_pct', 0.5),
|
|
||||||
'limit_up_ratio': current_row.get('limit_up_ratio', 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
batch_sequences.append(sequence)
|
|
||||||
batch_features.append(current_features)
|
|
||||||
batch_infos.append({
|
|
||||||
'concept_id': concept_id,
|
|
||||||
'stock_count': current_row.get('stock_count', 0),
|
|
||||||
'total_amt': current_row.get('total_amt', 0),
|
|
||||||
})
|
|
||||||
|
|
||||||
if not batch_sequences:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 批量 ML 推理
|
|
||||||
sequences_array = np.array(batch_sequences)
|
|
||||||
ml_scores = detector.ml_scorer.score(sequences_array) if detector.ml_scorer.is_ready() else [0.0] * len(batch_sequences)
|
|
||||||
if isinstance(ml_scores, float):
|
|
||||||
ml_scores = [ml_scores]
|
|
||||||
|
|
||||||
# 批量规则评分 + 融合
|
|
||||||
minute_alerts = []
|
|
||||||
for i, (features, info) in enumerate(zip(batch_features, batch_infos)):
|
|
||||||
concept_id = info['concept_id']
|
|
||||||
|
|
||||||
# 规则评分
|
|
||||||
rule_score, rule_details = detector.rule_scorer.score(features)
|
|
||||||
|
|
||||||
# ML 评分
|
|
||||||
ml_score = ml_scores[i] if i < len(ml_scores) else 0.0
|
|
||||||
|
|
||||||
# 融合
|
|
||||||
w1 = detector.config['rule_weight']
|
|
||||||
w2 = detector.config['ml_weight']
|
|
||||||
final_score = w1 * rule_score + w2 * ml_score
|
|
||||||
|
|
||||||
# 判断是否异动
|
|
||||||
is_anomaly = False
|
|
||||||
trigger_reason = ''
|
|
||||||
|
|
||||||
if rule_score >= detector.config['rule_trigger']:
|
|
||||||
is_anomaly = True
|
|
||||||
trigger_reason = f'规则强信号({rule_score:.0f}分)'
|
|
||||||
elif ml_score >= detector.config['ml_trigger']:
|
|
||||||
is_anomaly = True
|
|
||||||
trigger_reason = f'ML强信号({ml_score:.0f}分)'
|
|
||||||
elif final_score >= detector.config['fusion_trigger']:
|
|
||||||
is_anomaly = True
|
|
||||||
trigger_reason = f'融合触发({final_score:.0f}分)'
|
|
||||||
|
|
||||||
if not is_anomaly:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 异动类型
|
|
||||||
alpha = features.get('alpha', 0)
|
|
||||||
if alpha >= 1.5:
|
|
||||||
anomaly_type = 'surge_up'
|
|
||||||
elif alpha <= -1.5:
|
|
||||||
anomaly_type = 'surge_down'
|
|
||||||
elif features.get('amt_ratio', 1) >= 3.0:
|
|
||||||
anomaly_type = 'volume_spike'
|
|
||||||
else:
|
|
||||||
anomaly_type = 'unknown'
|
|
||||||
|
|
||||||
alert = {
|
|
||||||
'concept_id': concept_id,
|
|
||||||
'alert_time': current_time,
|
|
||||||
'trade_date': date,
|
|
||||||
'alert_type': anomaly_type,
|
|
||||||
'final_score': final_score,
|
|
||||||
'rule_score': rule_score,
|
|
||||||
'ml_score': ml_score,
|
|
||||||
'trigger_reason': trigger_reason,
|
|
||||||
'triggered_rules': list(rule_details.keys()),
|
|
||||||
**features,
|
|
||||||
**info,
|
|
||||||
}
|
|
||||||
|
|
||||||
minute_alerts.append(alert)
|
|
||||||
cooldown[concept_id] = current_time
|
|
||||||
|
|
||||||
# 按最终得分排序
|
|
||||||
minute_alerts.sort(key=lambda x: x['final_score'], reverse=True)
|
|
||||||
alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])
|
|
||||||
|
|
||||||
return alerts
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库写入 ====================
|
|
||||||
|
|
||||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
|
|
||||||
"""保存异动到 MySQL(增强版字段)"""
|
|
||||||
if not alerts:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
if dry_run:
|
|
||||||
print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
|
|
||||||
return len(alerts)
|
|
||||||
|
|
||||||
saved = 0
|
|
||||||
with MYSQL_ENGINE.begin() as conn:
|
|
||||||
for alert in alerts:
|
|
||||||
try:
|
|
||||||
# 检查是否已存在
|
|
||||||
check_sql = text("""
|
|
||||||
SELECT id FROM concept_minute_alert
|
|
||||||
WHERE concept_id = :concept_id
|
|
||||||
AND alert_time = :alert_time
|
|
||||||
AND trade_date = :trade_date
|
|
||||||
""")
|
|
||||||
exists = conn.execute(check_sql, {
|
|
||||||
'concept_id': alert['concept_id'],
|
|
||||||
'alert_time': alert['alert_time'],
|
|
||||||
'trade_date': alert['trade_date'],
|
|
||||||
}).fetchone()
|
|
||||||
|
|
||||||
if exists:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 插入新记录
|
|
||||||
insert_sql = text("""
|
|
||||||
INSERT INTO concept_minute_alert
|
|
||||||
(concept_id, concept_name, alert_time, alert_type, trade_date,
|
|
||||||
change_pct, zscore, importance_score, stock_count, extra_info)
|
|
||||||
VALUES
|
|
||||||
(:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
|
|
||||||
:change_pct, :zscore, :importance_score, :stock_count, :extra_info)
|
|
||||||
""")
|
|
||||||
|
|
||||||
extra_info = {
|
|
||||||
'detection_method': 'hybrid',
|
|
||||||
'final_score': alert['final_score'],
|
|
||||||
'rule_score': alert['rule_score'],
|
|
||||||
'ml_score': alert['ml_score'],
|
|
||||||
'trigger_reason': alert['trigger_reason'],
|
|
||||||
'triggered_rules': alert['triggered_rules'],
|
|
||||||
'alpha': alert.get('alpha', 0),
|
|
||||||
'alpha_delta': alert.get('alpha_delta', 0),
|
|
||||||
'amt_ratio': alert.get('amt_ratio', 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.execute(insert_sql, {
|
|
||||||
'concept_id': alert['concept_id'],
|
|
||||||
'concept_name': alert.get('concept_name', ''),
|
|
||||||
'alert_time': alert['alert_time'],
|
|
||||||
'alert_type': alert['alert_type'],
|
|
||||||
'trade_date': alert['trade_date'],
|
|
||||||
'change_pct': alert.get('alpha', 0),
|
|
||||||
'zscore': alert['final_score'], # 用最终得分作为 zscore
|
|
||||||
'importance_score': alert['final_score'],
|
|
||||||
'stock_count': alert.get('stock_count', 0),
|
|
||||||
'extra_info': json.dumps(extra_info, ensure_ascii=False)
|
|
||||||
})
|
|
||||||
|
|
||||||
saved += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f" 保存失败: {alert['concept_id']} - {e}")
|
|
||||||
|
|
||||||
return saved
|
|
||||||
|
|
||||||
|
|
||||||
def export_alerts_to_csv(alerts: List[Dict], output_path: str):
|
|
||||||
"""导出异动到 CSV"""
|
|
||||||
if not alerts:
|
|
||||||
return
|
|
||||||
|
|
||||||
df = pd.DataFrame(alerts)
|
|
||||||
df.to_csv(output_path, index=False, encoding='utf-8-sig')
|
|
||||||
print(f"已导出到: {output_path}")
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 统计分析 ====================
|
|
||||||
|
|
||||||
def analyze_alerts(alerts: List[Dict]):
|
|
||||||
"""分析异动结果"""
|
|
||||||
if not alerts:
|
|
||||||
print("无异动数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
df = pd.DataFrame(alerts)
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("异动统计分析")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# 1. 基本统计
|
|
||||||
print(f"\n总异动数: {len(alerts)}")
|
|
||||||
|
|
||||||
# 2. 按类型统计
|
|
||||||
print(f"\n异动类型分布:")
|
|
||||||
print(df['alert_type'].value_counts())
|
|
||||||
|
|
||||||
# 3. 得分统计
|
|
||||||
print(f"\n得分统计:")
|
|
||||||
print(f" 最终得分 - Mean: {df['final_score'].mean():.1f}, Max: {df['final_score'].max():.1f}")
|
|
||||||
print(f" 规则得分 - Mean: {df['rule_score'].mean():.1f}, Max: {df['rule_score'].max():.1f}")
|
|
||||||
print(f" ML得分 - Mean: {df['ml_score'].mean():.1f}, Max: {df['ml_score'].max():.1f}")
|
|
||||||
|
|
||||||
# 4. 触发来源分析
|
|
||||||
print(f"\n触发来源分析:")
|
|
||||||
trigger_counts = df['trigger_reason'].apply(
|
|
||||||
lambda x: '规则' if '规则' in x else ('ML' if 'ML' in x else '融合')
|
|
||||||
).value_counts()
|
|
||||||
print(trigger_counts)
|
|
||||||
|
|
||||||
# 5. 规则触发频率
|
|
||||||
all_rules = []
|
|
||||||
for rules in df['triggered_rules']:
|
|
||||||
if isinstance(rules, list):
|
|
||||||
all_rules.extend(rules)
|
|
||||||
|
|
||||||
if all_rules:
|
|
||||||
print(f"\n最常触发的规则 (Top 10):")
|
|
||||||
from collections import Counter
|
|
||||||
rule_counts = Counter(all_rules)
|
|
||||||
for rule, count in rule_counts.most_common(10):
|
|
||||||
print(f" {rule}: {count}")
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 主函数 ====================
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description='融合异动回测')
|
|
||||||
parser.add_argument('--data_dir', type=str, default='ml/data',
|
|
||||||
help='特征数据目录')
|
|
||||||
parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
|
|
||||||
help='模型检查点目录')
|
|
||||||
parser.add_argument('--start', type=str, required=True,
|
|
||||||
help='开始日期 (YYYY-MM-DD)')
|
|
||||||
parser.add_argument('--end', type=str, default=None,
|
|
||||||
help='结束日期 (YYYY-MM-DD),默认=start')
|
|
||||||
parser.add_argument('--dry-run', action='store_true',
|
|
||||||
help='只计算,不写入数据库')
|
|
||||||
parser.add_argument('--export-csv', type=str, default=None,
|
|
||||||
help='导出 CSV 文件路径')
|
|
||||||
parser.add_argument('--rule-weight', type=float, default=0.6,
|
|
||||||
help='规则权重 (0-1)')
|
|
||||||
parser.add_argument('--ml-weight', type=float, default=0.4,
|
|
||||||
help='ML权重 (0-1)')
|
|
||||||
parser.add_argument('--device', type=str, default='cuda',
|
|
||||||
help='设备 (cuda/cpu),默认 cuda')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.end is None:
|
|
||||||
args.end = args.start
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("融合异动回测 (规则 + LSTM)")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"日期范围: {args.start} ~ {args.end}")
|
|
||||||
print(f"数据目录: {args.data_dir}")
|
|
||||||
print(f"模型目录: {args.checkpoint_dir}")
|
|
||||||
print(f"规则权重: {args.rule_weight}")
|
|
||||||
print(f"ML权重: {args.ml_weight}")
|
|
||||||
print(f"设备: {args.device}")
|
|
||||||
print(f"Dry Run: {args.dry_run}")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# 初始化融合检测器(使用 GPU)
|
|
||||||
config = {
|
|
||||||
'rule_weight': args.rule_weight,
|
|
||||||
'ml_weight': args.ml_weight,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 修改 detector.py 中 MLScorer 的设备
|
|
||||||
from detector import HybridAnomalyDetector
|
|
||||||
detector = HybridAnomalyDetector(config, args.checkpoint_dir, device=args.device)
|
|
||||||
|
|
||||||
# 获取可用日期
|
|
||||||
dates = get_available_dates(args.data_dir, args.start, args.end)
|
|
||||||
|
|
||||||
if not dates:
|
|
||||||
print(f"未找到 {args.start} ~ {args.end} 范围内的数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"\n找到 {len(dates)} 天的数据")
|
|
||||||
|
|
||||||
# 回测
|
|
||||||
all_alerts = []
|
|
||||||
total_saved = 0
|
|
||||||
|
|
||||||
for date in tqdm(dates, desc="回测进度"):
|
|
||||||
df = load_daily_features(args.data_dir, date)
|
|
||||||
|
|
||||||
if df is None or df.empty:
|
|
||||||
continue
|
|
||||||
|
|
||||||
alerts = backtest_single_day_hybrid(
|
|
||||||
detector, df, date,
|
|
||||||
seq_len=BACKTEST_CONFIG['seq_len']
|
|
||||||
)
|
|
||||||
|
|
||||||
if alerts:
|
|
||||||
all_alerts.extend(alerts)
|
|
||||||
|
|
||||||
saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
|
|
||||||
total_saved += saved
|
|
||||||
|
|
||||||
if not args.dry_run:
|
|
||||||
tqdm.write(f" {date}: 检测到 {len(alerts)} 个异动,保存 {saved} 条")
|
|
||||||
|
|
||||||
# 导出 CSV
|
|
||||||
if args.export_csv and all_alerts:
|
|
||||||
export_alerts_to_csv(all_alerts, args.export_csv)
|
|
||||||
|
|
||||||
# 统计分析
|
|
||||||
analyze_alerts(all_alerts)
|
|
||||||
|
|
||||||
# 汇总
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("回测完成!")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"总计检测到: {len(all_alerts)} 个异动")
|
|
||||||
print(f"保存到数据库: {total_saved} 条")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,294 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
V2 回测脚本 - 验证时间片对齐 + 持续性确认的效果
|
|
||||||
|
|
||||||
回测指标:
|
|
||||||
1. 准确率:异动后 N 分钟内 alpha 是否继续上涨/下跌
|
|
||||||
2. 虚警率:多少异动是噪音
|
|
||||||
3. 持续性:平均异动持续时长
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import argparse
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from tqdm import tqdm
|
|
||||||
from sqlalchemy import create_engine, text
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
from ml.detector_v2 import AnomalyDetectorV2, CONFIG
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 配置 ====================
|
|
||||||
|
|
||||||
MYSQL_ENGINE = create_engine(
|
|
||||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
|
||||||
echo=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 回测评估 ====================
|
|
||||||
|
|
||||||
def evaluate_alerts(
|
|
||||||
alerts: List[Dict],
|
|
||||||
raw_data: pd.DataFrame,
|
|
||||||
lookahead_minutes: int = 10
|
|
||||||
) -> Dict:
|
|
||||||
"""
|
|
||||||
评估异动质量
|
|
||||||
|
|
||||||
指标:
|
|
||||||
1. 方向正确率:异动后 N 分钟 alpha 方向是否一致
|
|
||||||
2. 持续率:异动后 N 分钟内有多少时刻 alpha 保持同向
|
|
||||||
3. 峰值收益:异动后 N 分钟内的最大 alpha
|
|
||||||
"""
|
|
||||||
if not alerts:
|
|
||||||
return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0, 'total_alerts': 0}
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for alert in alerts:
|
|
||||||
concept_id = alert['concept_id']
|
|
||||||
alert_time = alert['alert_time']
|
|
||||||
alert_alpha = alert['alpha']
|
|
||||||
is_up = alert_alpha > 0
|
|
||||||
|
|
||||||
# 获取该概念在异动后的数据
|
|
||||||
concept_data = raw_data[
|
|
||||||
(raw_data['concept_id'] == concept_id) &
|
|
||||||
(raw_data['timestamp'] > alert_time)
|
|
||||||
].head(lookahead_minutes)
|
|
||||||
|
|
||||||
if len(concept_data) < 3:
|
|
||||||
continue
|
|
||||||
|
|
||||||
future_alphas = concept_data['alpha'].values
|
|
||||||
|
|
||||||
# 方向正确:未来 alpha 平均值与当前同向
|
|
||||||
avg_future_alpha = np.mean(future_alphas)
|
|
||||||
direction_correct = (is_up and avg_future_alpha > 0) or (not is_up and avg_future_alpha < 0)
|
|
||||||
|
|
||||||
# 持续率:有多少时刻保持同向
|
|
||||||
if is_up:
|
|
||||||
sustained_count = sum(1 for a in future_alphas if a > 0)
|
|
||||||
else:
|
|
||||||
sustained_count = sum(1 for a in future_alphas if a < 0)
|
|
||||||
sustained_rate = sustained_count / len(future_alphas)
|
|
||||||
|
|
||||||
# 峰值收益
|
|
||||||
if is_up:
|
|
||||||
peak = max(future_alphas)
|
|
||||||
else:
|
|
||||||
peak = min(future_alphas)
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
'direction_correct': direction_correct,
|
|
||||||
'sustained_rate': sustained_rate,
|
|
||||||
'peak': peak,
|
|
||||||
'alert_alpha': alert_alpha,
|
|
||||||
})
|
|
||||||
|
|
||||||
if not results:
|
|
||||||
return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0, 'total_alerts': 0}
|
|
||||||
|
|
||||||
return {
|
|
||||||
'accuracy': np.mean([r['direction_correct'] for r in results]),
|
|
||||||
'sustained_rate': np.mean([r['sustained_rate'] for r in results]),
|
|
||||||
'avg_peak': np.mean([abs(r['peak']) for r in results]),
|
|
||||||
'total_alerts': len(alerts),
|
|
||||||
'evaluated_alerts': len(results),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
|
|
||||||
"""保存异动到 MySQL"""
|
|
||||||
if not alerts or dry_run:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# 确保表存在
|
|
||||||
with MYSQL_ENGINE.begin() as conn:
|
|
||||||
conn.execute(text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
|
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
|
||||||
concept_id VARCHAR(64) NOT NULL,
|
|
||||||
alert_time DATETIME NOT NULL,
|
|
||||||
trade_date DATE NOT NULL,
|
|
||||||
alert_type VARCHAR(32) NOT NULL,
|
|
||||||
final_score FLOAT NOT NULL,
|
|
||||||
rule_score FLOAT NOT NULL,
|
|
||||||
ml_score FLOAT NOT NULL,
|
|
||||||
trigger_reason VARCHAR(128),
|
|
||||||
confirm_ratio FLOAT,
|
|
||||||
alpha FLOAT,
|
|
||||||
alpha_zscore FLOAT,
|
|
||||||
amt_zscore FLOAT,
|
|
||||||
rank_zscore FLOAT,
|
|
||||||
momentum_3m FLOAT,
|
|
||||||
momentum_5m FLOAT,
|
|
||||||
limit_up_ratio FLOAT,
|
|
||||||
triggered_rules JSON,
|
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
|
|
||||||
INDEX idx_trade_date (trade_date),
|
|
||||||
INDEX idx_final_score (final_score)
|
|
||||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动 V2(时间片对齐+持续确认)'
|
|
||||||
"""))
|
|
||||||
|
|
||||||
# 插入数据
|
|
||||||
saved = 0
|
|
||||||
with MYSQL_ENGINE.begin() as conn:
|
|
||||||
for alert in alerts:
|
|
||||||
try:
|
|
||||||
conn.execute(text("""
|
|
||||||
INSERT IGNORE INTO concept_anomaly_v2
|
|
||||||
(concept_id, alert_time, trade_date, alert_type,
|
|
||||||
final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
|
|
||||||
alpha, alpha_zscore, amt_zscore, rank_zscore,
|
|
||||||
momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
|
|
||||||
VALUES
|
|
||||||
(:concept_id, :alert_time, :trade_date, :alert_type,
|
|
||||||
:final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
|
|
||||||
:alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
|
|
||||||
:momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
|
|
||||||
"""), {
|
|
||||||
'concept_id': alert['concept_id'],
|
|
||||||
'alert_time': alert['alert_time'],
|
|
||||||
'trade_date': alert['trade_date'],
|
|
||||||
'alert_type': alert['alert_type'],
|
|
||||||
'final_score': alert['final_score'],
|
|
||||||
'rule_score': alert['rule_score'],
|
|
||||||
'ml_score': alert['ml_score'],
|
|
||||||
'trigger_reason': alert['trigger_reason'],
|
|
||||||
'confirm_ratio': alert.get('confirm_ratio', 0),
|
|
||||||
'alpha': alert['alpha'],
|
|
||||||
'alpha_zscore': alert.get('alpha_zscore', 0),
|
|
||||||
'amt_zscore': alert.get('amt_zscore', 0),
|
|
||||||
'rank_zscore': alert.get('rank_zscore', 0),
|
|
||||||
'momentum_3m': alert.get('momentum_3m', 0),
|
|
||||||
'momentum_5m': alert.get('momentum_5m', 0),
|
|
||||||
'limit_up_ratio': alert.get('limit_up_ratio', 0),
|
|
||||||
'triggered_rules': json.dumps(alert.get('triggered_rules', [])),
|
|
||||||
})
|
|
||||||
saved += 1
|
|
||||||
except Exception as e:
|
|
||||||
print(f"保存失败: {e}")
|
|
||||||
|
|
||||||
return saved
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 主函数 ====================
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description='V2 回测')
|
|
||||||
parser.add_argument('--start', type=str, required=True, help='开始日期')
|
|
||||||
parser.add_argument('--end', type=str, default=None, help='结束日期')
|
|
||||||
parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
|
|
||||||
parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
|
|
||||||
parser.add_argument('--save', action='store_true', help='保存到数据库')
|
|
||||||
parser.add_argument('--lookahead', type=int, default=10, help='评估前瞻时间(分钟)')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
end_date = args.end or args.start
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("V2 回测 - 时间片对齐 + 持续性确认")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"日期范围: {args.start} ~ {end_date}")
|
|
||||||
print(f"模型目录: {args.model_dir}")
|
|
||||||
print(f"评估前瞻: {args.lookahead} 分钟")
|
|
||||||
|
|
||||||
# 初始化检测器
|
|
||||||
detector = AnomalyDetectorV2(
|
|
||||||
model_dir=args.model_dir,
|
|
||||||
baseline_dir=args.baseline_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取交易日
|
|
||||||
from prepare_data_v2 import get_trading_days
|
|
||||||
trading_days = get_trading_days(args.start, end_date)
|
|
||||||
|
|
||||||
if not trading_days:
|
|
||||||
print("无交易日")
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"交易日数: {len(trading_days)}")
|
|
||||||
|
|
||||||
# 回测统计
|
|
||||||
total_stats = {
|
|
||||||
'total_alerts': 0,
|
|
||||||
'accuracy_sum': 0,
|
|
||||||
'sustained_sum': 0,
|
|
||||||
'peak_sum': 0,
|
|
||||||
'day_count': 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
all_alerts = []
|
|
||||||
|
|
||||||
for trade_date in tqdm(trading_days, desc="回测进度"):
|
|
||||||
# 检测异动
|
|
||||||
alerts = detector.detect(trade_date)
|
|
||||||
|
|
||||||
if not alerts:
|
|
||||||
continue
|
|
||||||
|
|
||||||
all_alerts.extend(alerts)
|
|
||||||
|
|
||||||
# 评估
|
|
||||||
raw_data = detector._compute_raw_features(trade_date)
|
|
||||||
if raw_data.empty:
|
|
||||||
continue
|
|
||||||
|
|
||||||
stats = evaluate_alerts(alerts, raw_data, args.lookahead)
|
|
||||||
|
|
||||||
if stats['evaluated_alerts'] > 0:
|
|
||||||
total_stats['total_alerts'] += stats['total_alerts']
|
|
||||||
total_stats['accuracy_sum'] += stats['accuracy'] * stats['evaluated_alerts']
|
|
||||||
total_stats['sustained_sum'] += stats['sustained_rate'] * stats['evaluated_alerts']
|
|
||||||
total_stats['peak_sum'] += stats['avg_peak'] * stats['evaluated_alerts']
|
|
||||||
total_stats['day_count'] += 1
|
|
||||||
|
|
||||||
print(f"\n[{trade_date}] 异动: {stats['total_alerts']}, "
|
|
||||||
f"准确率: {stats['accuracy']:.1%}, "
|
|
||||||
f"持续率: {stats['sustained_rate']:.1%}, "
|
|
||||||
f"峰值: {stats['avg_peak']:.2f}%")
|
|
||||||
|
|
||||||
# 汇总
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("回测汇总")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
if total_stats['total_alerts'] > 0:
|
|
||||||
avg_accuracy = total_stats['accuracy_sum'] / total_stats['total_alerts']
|
|
||||||
avg_sustained = total_stats['sustained_sum'] / total_stats['total_alerts']
|
|
||||||
avg_peak = total_stats['peak_sum'] / total_stats['total_alerts']
|
|
||||||
|
|
||||||
print(f"总异动数: {total_stats['total_alerts']}")
|
|
||||||
print(f"回测天数: {total_stats['day_count']}")
|
|
||||||
print(f"平均每天: {total_stats['total_alerts'] / max(1, total_stats['day_count']):.1f} 个")
|
|
||||||
print(f"方向准确率: {avg_accuracy:.1%}")
|
|
||||||
print(f"持续率: {avg_sustained:.1%}")
|
|
||||||
print(f"平均峰值: {avg_peak:.2f}%")
|
|
||||||
else:
|
|
||||||
print("无异动检测结果")
|
|
||||||
|
|
||||||
# 保存
|
|
||||||
if args.save and all_alerts:
|
|
||||||
print(f"\n保存 {len(all_alerts)} 条异动到数据库...")
|
|
||||||
saved = save_alerts_to_mysql(all_alerts)
|
|
||||||
print(f"保存完成: {saved} 条")
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
{
|
|
||||||
"seq_len": 10,
|
|
||||||
"stride": 2,
|
|
||||||
"train_end_date": "2025-06-30",
|
|
||||||
"val_end_date": "2025-09-30",
|
|
||||||
"features": [
|
|
||||||
"alpha_zscore",
|
|
||||||
"amt_zscore",
|
|
||||||
"rank_zscore",
|
|
||||||
"momentum_3m",
|
|
||||||
"momentum_5m",
|
|
||||||
"limit_up_ratio"
|
|
||||||
],
|
|
||||||
"batch_size": 32768,
|
|
||||||
"epochs": 150,
|
|
||||||
"learning_rate": 0.0006,
|
|
||||||
"weight_decay": 1e-05,
|
|
||||||
"gradient_clip": 1.0,
|
|
||||||
"patience": 15,
|
|
||||||
"min_delta": 1e-06,
|
|
||||||
"model": {
|
|
||||||
"n_features": 6,
|
|
||||||
"hidden_dim": 32,
|
|
||||||
"latent_dim": 4,
|
|
||||||
"num_layers": 1,
|
|
||||||
"dropout": 0.2,
|
|
||||||
"bidirectional": true
|
|
||||||
},
|
|
||||||
"clip_value": 5.0,
|
|
||||||
"threshold_percentiles": [90, 95, 99]
|
|
||||||
}
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
{
|
|
||||||
"p90": 0.15,
|
|
||||||
"p95": 0.25,
|
|
||||||
"p99": 0.50,
|
|
||||||
"mean": 0.08,
|
|
||||||
"std": 0.12,
|
|
||||||
"median": 0.06
|
|
||||||
}
|
|
||||||
635
ml/detector.py
635
ml/detector.py
@@ -1,635 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
概念异动检测器 - 融合版
|
|
||||||
|
|
||||||
结合两种方法的优势:
|
|
||||||
1. 规则评分系统:可解释、稳定、覆盖已知模式
|
|
||||||
2. LSTM Autoencoder:发现未知的异常模式
|
|
||||||
|
|
||||||
融合策略:
|
|
||||||
┌─────────────────────────────────────────────────────────┐
|
|
||||||
│ 输入特征 │
|
|
||||||
│ (alpha, alpha_delta, amt_ratio, amt_delta, rank_pct, │
|
|
||||||
│ limit_up_ratio) │
|
|
||||||
├─────────────────────────────────────────────────────────┤
|
|
||||||
│ │
|
|
||||||
│ ┌──────────────┐ ┌──────────────┐ │
|
|
||||||
│ │ 规则评分系统 │ │ LSTM Autoencoder │ │
|
|
||||||
│ │ (0-100分) │ │ (重构误差) │ │
|
|
||||||
│ └──────┬───────┘ └──────┬───────┘ │
|
|
||||||
│ │ │ │
|
|
||||||
│ ▼ ▼ │
|
|
||||||
│ rule_score (0-100) ml_score (标准化后 0-100) │
|
|
||||||
│ │
|
|
||||||
├─────────────────────────────────────────────────────────┤
|
|
||||||
│ 融合策略 │
|
|
||||||
│ │
|
|
||||||
│ final_score = w1 * rule_score + w2 * ml_score │
|
|
||||||
│ │
|
|
||||||
│ 异动判定: │
|
|
||||||
│ - rule_score >= 60 → 直接触发(规则强信号) │
|
|
||||||
│ - ml_score >= 80 → 直接触发(ML强信号) │
|
|
||||||
│ - final_score >= 50 → 融合触发 │
|
|
||||||
│ │
|
|
||||||
└─────────────────────────────────────────────────────────┘
|
|
||||||
|
|
||||||
优势:
|
|
||||||
- 规则系统保证已知模式的检出率
|
|
||||||
- ML模型捕捉规则未覆盖的异常
|
|
||||||
- 两者互相验证,减少误报
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# 尝试导入模型(可能不存在)
|
|
||||||
try:
|
|
||||||
from model import LSTMAutoencoder, create_model
|
|
||||||
HAS_MODEL = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_MODEL = False
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AnomalyResult:
|
|
||||||
"""异动检测结果"""
|
|
||||||
is_anomaly: bool
|
|
||||||
final_score: float # 最终得分 (0-100)
|
|
||||||
rule_score: float # 规则得分 (0-100)
|
|
||||||
ml_score: float # ML得分 (0-100)
|
|
||||||
trigger_reason: str # 触发原因
|
|
||||||
rule_details: Dict # 规则明细
|
|
||||||
anomaly_type: str # 异动类型: surge_up / surge_down / volume_spike / unknown
|
|
||||||
|
|
||||||
|
|
||||||
class RuleBasedScorer:
|
|
||||||
"""
|
|
||||||
基于规则的评分系统
|
|
||||||
|
|
||||||
设计原则:
|
|
||||||
- 每个规则独立打分
|
|
||||||
- 分数可叠加
|
|
||||||
- 阈值可配置
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 默认规则配置
|
|
||||||
DEFAULT_RULES = {
|
|
||||||
# Alpha 相关(超额收益)
|
|
||||||
'alpha_strong': {
|
|
||||||
'condition': lambda r: abs(r.get('alpha', 0)) >= 3.0,
|
|
||||||
'score': 35,
|
|
||||||
'description': 'Alpha强信号(|α|≥3%)'
|
|
||||||
},
|
|
||||||
'alpha_medium': {
|
|
||||||
'condition': lambda r: 2.0 <= abs(r.get('alpha', 0)) < 3.0,
|
|
||||||
'score': 25,
|
|
||||||
'description': 'Alpha中等(2%≤|α|<3%)'
|
|
||||||
},
|
|
||||||
'alpha_weak': {
|
|
||||||
'condition': lambda r: 1.5 <= abs(r.get('alpha', 0)) < 2.0,
|
|
||||||
'score': 15,
|
|
||||||
'description': 'Alpha轻微(1.5%≤|α|<2%)'
|
|
||||||
},
|
|
||||||
|
|
||||||
# Alpha 变化率(加速度)
|
|
||||||
'alpha_delta_strong': {
|
|
||||||
'condition': lambda r: abs(r.get('alpha_delta', 0)) >= 1.0,
|
|
||||||
'score': 30,
|
|
||||||
'description': 'Alpha加速强(|Δα|≥1%)'
|
|
||||||
},
|
|
||||||
'alpha_delta_medium': {
|
|
||||||
'condition': lambda r: 0.5 <= abs(r.get('alpha_delta', 0)) < 1.0,
|
|
||||||
'score': 20,
|
|
||||||
'description': 'Alpha加速中(0.5%≤|Δα|<1%)'
|
|
||||||
},
|
|
||||||
|
|
||||||
# 成交额比率(放量)
|
|
||||||
'volume_spike_strong': {
|
|
||||||
'condition': lambda r: r.get('amt_ratio', 1) >= 5.0,
|
|
||||||
'score': 30,
|
|
||||||
'description': '极度放量(≥5倍)'
|
|
||||||
},
|
|
||||||
'volume_spike_medium': {
|
|
||||||
'condition': lambda r: 3.0 <= r.get('amt_ratio', 1) < 5.0,
|
|
||||||
'score': 20,
|
|
||||||
'description': '显著放量(3-5倍)'
|
|
||||||
},
|
|
||||||
'volume_spike_weak': {
|
|
||||||
'condition': lambda r: 2.0 <= r.get('amt_ratio', 1) < 3.0,
|
|
||||||
'score': 10,
|
|
||||||
'description': '轻微放量(2-3倍)'
|
|
||||||
},
|
|
||||||
|
|
||||||
# 成交额变化率
|
|
||||||
'amt_delta_strong': {
|
|
||||||
'condition': lambda r: abs(r.get('amt_delta', 0)) >= 1.0,
|
|
||||||
'score': 15,
|
|
||||||
'description': '成交额急变(|Δamt|≥100%)'
|
|
||||||
},
|
|
||||||
|
|
||||||
# 排名跳变
|
|
||||||
'rank_top': {
|
|
||||||
'condition': lambda r: r.get('rank_pct', 0.5) >= 0.95,
|
|
||||||
'score': 25,
|
|
||||||
'description': '排名前5%'
|
|
||||||
},
|
|
||||||
'rank_bottom': {
|
|
||||||
'condition': lambda r: r.get('rank_pct', 0.5) <= 0.05,
|
|
||||||
'score': 25,
|
|
||||||
'description': '排名后5%'
|
|
||||||
},
|
|
||||||
'rank_high': {
|
|
||||||
'condition': lambda r: 0.9 <= r.get('rank_pct', 0.5) < 0.95,
|
|
||||||
'score': 15,
|
|
||||||
'description': '排名前10%'
|
|
||||||
},
|
|
||||||
|
|
||||||
# 涨停比例
|
|
||||||
'limit_up_high': {
|
|
||||||
'condition': lambda r: r.get('limit_up_ratio', 0) >= 0.2,
|
|
||||||
'score': 25,
|
|
||||||
'description': '涨停比例≥20%'
|
|
||||||
},
|
|
||||||
'limit_up_medium': {
|
|
||||||
'condition': lambda r: 0.1 <= r.get('limit_up_ratio', 0) < 0.2,
|
|
||||||
'score': 15,
|
|
||||||
'description': '涨停比例10-20%'
|
|
||||||
},
|
|
||||||
|
|
||||||
# 组合条件(更可靠的信号)
|
|
||||||
'alpha_with_volume': {
|
|
||||||
'condition': lambda r: abs(r.get('alpha', 0)) >= 1.5 and r.get('amt_ratio', 1) >= 2.0,
|
|
||||||
'score': 20, # 额外加分
|
|
||||||
'description': 'Alpha+放量组合'
|
|
||||||
},
|
|
||||||
'acceleration_with_rank': {
|
|
||||||
'condition': lambda r: abs(r.get('alpha_delta', 0)) >= 0.5 and (r.get('rank_pct', 0.5) >= 0.9 or r.get('rank_pct', 0.5) <= 0.1),
|
|
||||||
'score': 15, # 额外加分
|
|
||||||
'description': '加速+排名异常组合'
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, rules: Dict = None):
|
|
||||||
"""
|
|
||||||
初始化规则评分器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rules: 自定义规则,格式同 DEFAULT_RULES
|
|
||||||
"""
|
|
||||||
self.rules = rules or self.DEFAULT_RULES
|
|
||||||
|
|
||||||
def score(self, features: Dict) -> Tuple[float, Dict]:
|
|
||||||
"""
|
|
||||||
计算规则得分
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features: 特征字典,包含 alpha, alpha_delta, amt_ratio 等
|
|
||||||
Returns:
|
|
||||||
score: 总分 (0-100)
|
|
||||||
details: 触发的规则明细
|
|
||||||
"""
|
|
||||||
total_score = 0
|
|
||||||
triggered_rules = {}
|
|
||||||
|
|
||||||
for rule_name, rule_config in self.rules.items():
|
|
||||||
try:
|
|
||||||
if rule_config['condition'](features):
|
|
||||||
total_score += rule_config['score']
|
|
||||||
triggered_rules[rule_name] = {
|
|
||||||
'score': rule_config['score'],
|
|
||||||
'description': rule_config['description']
|
|
||||||
}
|
|
||||||
except Exception:
|
|
||||||
# 忽略规则计算错误
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 限制在 0-100
|
|
||||||
total_score = min(100, max(0, total_score))
|
|
||||||
|
|
||||||
return total_score, triggered_rules
|
|
||||||
|
|
||||||
def get_anomaly_type(self, features: Dict) -> str:
|
|
||||||
"""判断异动类型"""
|
|
||||||
alpha = features.get('alpha', 0)
|
|
||||||
amt_ratio = features.get('amt_ratio', 1)
|
|
||||||
|
|
||||||
if alpha >= 1.5:
|
|
||||||
return 'surge_up'
|
|
||||||
elif alpha <= -1.5:
|
|
||||||
return 'surge_down'
|
|
||||||
elif amt_ratio >= 3.0:
|
|
||||||
return 'volume_spike'
|
|
||||||
else:
|
|
||||||
return 'unknown'
|
|
||||||
|
|
||||||
|
|
||||||
class MLScorer:
|
|
||||||
"""
|
|
||||||
基于 LSTM Autoencoder 的评分器
|
|
||||||
|
|
||||||
将重构误差转换为 0-100 的分数
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
checkpoint_dir: str = 'ml/checkpoints',
|
|
||||||
device: str = 'auto'
|
|
||||||
):
|
|
||||||
self.checkpoint_dir = Path(checkpoint_dir)
|
|
||||||
|
|
||||||
# 设备检测
|
|
||||||
if device == 'auto':
|
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
elif device == 'cuda' and not torch.cuda.is_available():
|
|
||||||
print("警告: CUDA 不可用,使用 CPU")
|
|
||||||
self.device = torch.device('cpu')
|
|
||||||
else:
|
|
||||||
self.device = torch.device(device)
|
|
||||||
|
|
||||||
self.model = None
|
|
||||||
self.thresholds = None
|
|
||||||
self.config = None
|
|
||||||
|
|
||||||
# 尝试加载模型
|
|
||||||
self._load_model()
|
|
||||||
|
|
||||||
def _load_model(self):
|
|
||||||
"""加载模型和阈值"""
|
|
||||||
if not HAS_MODEL:
|
|
||||||
print("警告: 无法导入模型模块")
|
|
||||||
return
|
|
||||||
|
|
||||||
model_path = self.checkpoint_dir / 'best_model.pt'
|
|
||||||
thresholds_path = self.checkpoint_dir / 'thresholds.json'
|
|
||||||
config_path = self.checkpoint_dir / 'config.json'
|
|
||||||
|
|
||||||
if not model_path.exists():
|
|
||||||
print(f"警告: 模型文件不存在 {model_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 加载配置
|
|
||||||
if config_path.exists():
|
|
||||||
with open(config_path, 'r') as f:
|
|
||||||
self.config = json.load(f)
|
|
||||||
|
|
||||||
# 先用 CPU 加载模型(避免 CUDA 不可用问题),再移动到目标设备
|
|
||||||
checkpoint = torch.load(model_path, map_location='cpu')
|
|
||||||
|
|
||||||
model_config = self.config.get('model', {}) if self.config else {}
|
|
||||||
self.model = create_model(model_config)
|
|
||||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
|
||||||
self.model.to(self.device)
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
# 加载阈值
|
|
||||||
if thresholds_path.exists():
|
|
||||||
with open(thresholds_path, 'r') as f:
|
|
||||||
self.thresholds = json.load(f)
|
|
||||||
|
|
||||||
print(f"MLScorer 加载成功 (设备: {self.device})")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"警告: 模型加载失败 - {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
self.model = None
|
|
||||||
|
|
||||||
def is_ready(self) -> bool:
|
|
||||||
"""检查模型是否就绪"""
|
|
||||||
return self.model is not None
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def score(self, sequence: np.ndarray) -> float:
|
|
||||||
"""
|
|
||||||
计算 ML 得分
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sequence: (seq_len, n_features) 或 (batch, seq_len, n_features)
|
|
||||||
Returns:
|
|
||||||
score: 0-100 的分数,越高越异常
|
|
||||||
"""
|
|
||||||
if not self.is_ready():
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
# 确保是 3D
|
|
||||||
if sequence.ndim == 2:
|
|
||||||
sequence = sequence[np.newaxis, ...]
|
|
||||||
|
|
||||||
# 转为 tensor
|
|
||||||
x = torch.FloatTensor(sequence).to(self.device)
|
|
||||||
|
|
||||||
# 计算重构误差
|
|
||||||
output, _ = self.model(x)
|
|
||||||
mse = ((output - x) ** 2).mean(dim=-1) # (batch, seq_len)
|
|
||||||
|
|
||||||
# 取最后时刻的误差
|
|
||||||
error = mse[:, -1].cpu().numpy()
|
|
||||||
|
|
||||||
# 转换为 0-100 分数
|
|
||||||
# 使用 p95 阈值作为参考
|
|
||||||
if self.thresholds:
|
|
||||||
p95 = self.thresholds.get('p95', 0.1)
|
|
||||||
p99 = self.thresholds.get('p99', 0.2)
|
|
||||||
else:
|
|
||||||
p95, p99 = 0.1, 0.2
|
|
||||||
|
|
||||||
# 线性映射:p95 -> 50分, p99 -> 80分
|
|
||||||
# error=0 -> 0分, error>=p99*1.5 -> 100分
|
|
||||||
score = np.clip(error / p95 * 50, 0, 100)
|
|
||||||
|
|
||||||
return float(score[0]) if len(score) == 1 else score.tolist()
|
|
||||||
|
|
||||||
|
|
||||||
class HybridAnomalyDetector:
|
|
||||||
"""
|
|
||||||
融合异动检测器
|
|
||||||
|
|
||||||
结合规则系统和 ML 模型
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 默认配置
|
|
||||||
DEFAULT_CONFIG = {
|
|
||||||
# 权重配置
|
|
||||||
'rule_weight': 0.6, # 规则权重
|
|
||||||
'ml_weight': 0.4, # ML权重
|
|
||||||
|
|
||||||
# 触发阈值
|
|
||||||
'rule_trigger': 60, # 规则直接触发阈值
|
|
||||||
'ml_trigger': 80, # ML直接触发阈值
|
|
||||||
'fusion_trigger': 50, # 融合触发阈值
|
|
||||||
|
|
||||||
# 特征列表
|
|
||||||
'features': [
|
|
||||||
'alpha', 'alpha_delta', 'amt_ratio',
|
|
||||||
'amt_delta', 'rank_pct', 'limit_up_ratio'
|
|
||||||
],
|
|
||||||
|
|
||||||
# 序列长度(ML模型需要)
|
|
||||||
'seq_len': 30,
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: Dict = None,
|
|
||||||
checkpoint_dir: str = 'ml/checkpoints',
|
|
||||||
device: str = 'auto'
|
|
||||||
):
|
|
||||||
self.config = {**self.DEFAULT_CONFIG, **(config or {})}
|
|
||||||
|
|
||||||
# 初始化评分器
|
|
||||||
self.rule_scorer = RuleBasedScorer()
|
|
||||||
self.ml_scorer = MLScorer(checkpoint_dir, device)
|
|
||||||
|
|
||||||
print(f"HybridAnomalyDetector 初始化完成")
|
|
||||||
print(f" 规则权重: {self.config['rule_weight']}")
|
|
||||||
print(f" ML权重: {self.config['ml_weight']}")
|
|
||||||
print(f" ML模型: {'就绪' if self.ml_scorer.is_ready() else '未加载'}")
|
|
||||||
|
|
||||||
def detect(
|
|
||||||
self,
|
|
||||||
features: Dict,
|
|
||||||
sequence: np.ndarray = None
|
|
||||||
) -> AnomalyResult:
|
|
||||||
"""
|
|
||||||
检测异动
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features: 当前时刻的特征字典
|
|
||||||
sequence: 历史序列 (seq_len, n_features),ML模型需要
|
|
||||||
Returns:
|
|
||||||
AnomalyResult: 检测结果
|
|
||||||
"""
|
|
||||||
# 1. 规则评分
|
|
||||||
rule_score, rule_details = self.rule_scorer.score(features)
|
|
||||||
|
|
||||||
# 2. ML评分
|
|
||||||
ml_score = 0.0
|
|
||||||
if sequence is not None and self.ml_scorer.is_ready():
|
|
||||||
ml_score = self.ml_scorer.score(sequence)
|
|
||||||
|
|
||||||
# 3. 融合得分
|
|
||||||
w1 = self.config['rule_weight']
|
|
||||||
w2 = self.config['ml_weight']
|
|
||||||
|
|
||||||
# 如果ML不可用,全部权重给规则
|
|
||||||
if not self.ml_scorer.is_ready():
|
|
||||||
w1, w2 = 1.0, 0.0
|
|
||||||
|
|
||||||
final_score = w1 * rule_score + w2 * ml_score
|
|
||||||
|
|
||||||
# 4. 判断是否异动
|
|
||||||
is_anomaly = False
|
|
||||||
trigger_reason = ''
|
|
||||||
|
|
||||||
if rule_score >= self.config['rule_trigger']:
|
|
||||||
is_anomaly = True
|
|
||||||
trigger_reason = f'规则强信号({rule_score:.0f}分)'
|
|
||||||
elif ml_score >= self.config['ml_trigger']:
|
|
||||||
is_anomaly = True
|
|
||||||
trigger_reason = f'ML强信号({ml_score:.0f}分)'
|
|
||||||
elif final_score >= self.config['fusion_trigger']:
|
|
||||||
is_anomaly = True
|
|
||||||
trigger_reason = f'融合触发({final_score:.0f}分)'
|
|
||||||
|
|
||||||
# 5. 判断异动类型
|
|
||||||
anomaly_type = self.rule_scorer.get_anomaly_type(features) if is_anomaly else ''
|
|
||||||
|
|
||||||
return AnomalyResult(
|
|
||||||
is_anomaly=is_anomaly,
|
|
||||||
final_score=final_score,
|
|
||||||
rule_score=rule_score,
|
|
||||||
ml_score=ml_score,
|
|
||||||
trigger_reason=trigger_reason,
|
|
||||||
rule_details=rule_details,
|
|
||||||
anomaly_type=anomaly_type
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_batch(
|
|
||||||
self,
|
|
||||||
features_list: List[Dict],
|
|
||||||
sequences: np.ndarray = None
|
|
||||||
) -> List[AnomalyResult]:
|
|
||||||
"""
|
|
||||||
批量检测
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features_list: 特征字典列表
|
|
||||||
sequences: (batch, seq_len, n_features)
|
|
||||||
Returns:
|
|
||||||
List[AnomalyResult]
|
|
||||||
"""
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for i, features in enumerate(features_list):
|
|
||||||
seq = sequences[i] if sequences is not None else None
|
|
||||||
result = self.detect(features, seq)
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 便捷函数 ====================
|
|
||||||
|
|
||||||
def create_detector(
|
|
||||||
checkpoint_dir: str = 'ml/checkpoints',
|
|
||||||
config: Dict = None
|
|
||||||
) -> HybridAnomalyDetector:
|
|
||||||
"""创建融合检测器"""
|
|
||||||
return HybridAnomalyDetector(config, checkpoint_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def quick_detect(features: Dict) -> bool:
|
|
||||||
"""
|
|
||||||
快速检测(只用规则,不需要ML模型)
|
|
||||||
|
|
||||||
适用于:
|
|
||||||
- 实时检测
|
|
||||||
- ML模型未训练完成时
|
|
||||||
"""
|
|
||||||
scorer = RuleBasedScorer()
|
|
||||||
score, _ = scorer.score(features)
|
|
||||||
return score >= 50
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 测试 ====================
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("=" * 60)
|
|
||||||
print("融合异动检测器测试")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# 创建检测器
|
|
||||||
detector = create_detector()
|
|
||||||
|
|
||||||
# 测试用例
|
|
||||||
test_cases = [
|
|
||||||
{
|
|
||||||
'name': '正常情况',
|
|
||||||
'features': {
|
|
||||||
'alpha': 0.5,
|
|
||||||
'alpha_delta': 0.1,
|
|
||||||
'amt_ratio': 1.2,
|
|
||||||
'amt_delta': 0.1,
|
|
||||||
'rank_pct': 0.5,
|
|
||||||
'limit_up_ratio': 0.02
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'name': 'Alpha异动',
|
|
||||||
'features': {
|
|
||||||
'alpha': 3.5,
|
|
||||||
'alpha_delta': 0.8,
|
|
||||||
'amt_ratio': 2.5,
|
|
||||||
'amt_delta': 0.5,
|
|
||||||
'rank_pct': 0.92,
|
|
||||||
'limit_up_ratio': 0.05
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'name': '放量异动',
|
|
||||||
'features': {
|
|
||||||
'alpha': 1.2,
|
|
||||||
'alpha_delta': 0.3,
|
|
||||||
'amt_ratio': 6.0,
|
|
||||||
'amt_delta': 1.5,
|
|
||||||
'rank_pct': 0.85,
|
|
||||||
'limit_up_ratio': 0.08
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'name': '涨停潮',
|
|
||||||
'features': {
|
|
||||||
'alpha': 2.5,
|
|
||||||
'alpha_delta': 0.6,
|
|
||||||
'amt_ratio': 3.5,
|
|
||||||
'amt_delta': 0.8,
|
|
||||||
'rank_pct': 0.98,
|
|
||||||
'limit_up_ratio': 0.25
|
|
||||||
}
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
print("\n" + "-" * 60)
|
|
||||||
print("测试1: 只用规则(无序列数据)")
|
|
||||||
print("-" * 60)
|
|
||||||
|
|
||||||
for case in test_cases:
|
|
||||||
result = detector.detect(case['features'])
|
|
||||||
|
|
||||||
print(f"\n{case['name']}:")
|
|
||||||
print(f" 异动: {'是' if result.is_anomaly else '否'}")
|
|
||||||
print(f" 最终得分: {result.final_score:.1f}")
|
|
||||||
print(f" 规则得分: {result.rule_score:.1f}")
|
|
||||||
print(f" ML得分: {result.ml_score:.1f}")
|
|
||||||
if result.is_anomaly:
|
|
||||||
print(f" 触发原因: {result.trigger_reason}")
|
|
||||||
print(f" 异动类型: {result.anomaly_type}")
|
|
||||||
print(f" 触发规则: {list(result.rule_details.keys())}")
|
|
||||||
|
|
||||||
# 测试2: 带序列数据的融合检测
|
|
||||||
print("\n" + "-" * 60)
|
|
||||||
print("测试2: 融合检测(规则 + ML)")
|
|
||||||
print("-" * 60)
|
|
||||||
|
|
||||||
# 生成模拟序列数据
|
|
||||||
seq_len = 30
|
|
||||||
n_features = 6
|
|
||||||
|
|
||||||
# 正常序列:小幅波动
|
|
||||||
normal_sequence = np.random.randn(seq_len, n_features) * 0.3
|
|
||||||
normal_sequence[:, 0] = np.linspace(0, 0.5, seq_len) # alpha 缓慢上升
|
|
||||||
normal_sequence[:, 2] = np.abs(normal_sequence[:, 2]) + 1 # amt_ratio > 0
|
|
||||||
|
|
||||||
# 异常序列:最后几个时间步突然变化
|
|
||||||
anomaly_sequence = np.random.randn(seq_len, n_features) * 0.3
|
|
||||||
anomaly_sequence[-5:, 0] = np.linspace(1, 4, 5) # alpha 突然飙升
|
|
||||||
anomaly_sequence[-5:, 1] = np.linspace(0.2, 1.5, 5) # alpha_delta 加速
|
|
||||||
anomaly_sequence[-5:, 2] = np.linspace(2, 6, 5) # amt_ratio 放量
|
|
||||||
anomaly_sequence[:, 2] = np.abs(anomaly_sequence[:, 2]) + 1
|
|
||||||
|
|
||||||
# 测试正常序列
|
|
||||||
normal_features = {
|
|
||||||
'alpha': float(normal_sequence[-1, 0]),
|
|
||||||
'alpha_delta': float(normal_sequence[-1, 1]),
|
|
||||||
'amt_ratio': float(normal_sequence[-1, 2]),
|
|
||||||
'amt_delta': float(normal_sequence[-1, 3]),
|
|
||||||
'rank_pct': 0.5,
|
|
||||||
'limit_up_ratio': 0.02
|
|
||||||
}
|
|
||||||
|
|
||||||
result = detector.detect(normal_features, normal_sequence)
|
|
||||||
print(f"\n正常序列:")
|
|
||||||
print(f" 异动: {'是' if result.is_anomaly else '否'}")
|
|
||||||
print(f" 最终得分: {result.final_score:.1f}")
|
|
||||||
print(f" 规则得分: {result.rule_score:.1f}")
|
|
||||||
print(f" ML得分: {result.ml_score:.1f}")
|
|
||||||
|
|
||||||
# 测试异常序列
|
|
||||||
anomaly_features = {
|
|
||||||
'alpha': float(anomaly_sequence[-1, 0]),
|
|
||||||
'alpha_delta': float(anomaly_sequence[-1, 1]),
|
|
||||||
'amt_ratio': float(anomaly_sequence[-1, 2]),
|
|
||||||
'amt_delta': float(anomaly_sequence[-1, 3]),
|
|
||||||
'rank_pct': 0.95,
|
|
||||||
'limit_up_ratio': 0.15
|
|
||||||
}
|
|
||||||
|
|
||||||
result = detector.detect(anomaly_features, anomaly_sequence)
|
|
||||||
print(f"\n异常序列:")
|
|
||||||
print(f" 异动: {'是' if result.is_anomaly else '否'}")
|
|
||||||
print(f" 最终得分: {result.final_score:.1f}")
|
|
||||||
print(f" 规则得分: {result.rule_score:.1f}")
|
|
||||||
print(f" ML得分: {result.ml_score:.1f}")
|
|
||||||
if result.is_anomaly:
|
|
||||||
print(f" 触发原因: {result.trigger_reason}")
|
|
||||||
print(f" 异动类型: {result.anomaly_type}")
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("测试完成!")
|
|
||||||
@@ -1,716 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
异动检测器 V2 - 基于时间片对齐 + 持续性确认
|
|
||||||
|
|
||||||
核心改进:
|
|
||||||
1. Z-Score 特征:相对于同时间片历史的偏离
|
|
||||||
2. 短序列 LSTM:10分钟序列,开盘即可用
|
|
||||||
3. 持续性确认:5分钟窗口内60%时刻超标才确认为异动
|
|
||||||
|
|
||||||
检测流程:
|
|
||||||
1. 计算当前时刻的 Z-Score(对比同时间片历史基线)
|
|
||||||
2. 构建最近10分钟的 Z-Score 序列
|
|
||||||
3. LSTM 计算重构误差(ML分数)
|
|
||||||
4. 规则评分(基于 Z-Score 的规则)
|
|
||||||
5. 滑动窗口确认:最近5分钟内是否有足够多的时刻超标
|
|
||||||
6. 只有通过持续性确认的才输出为异动
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import pickle
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from collections import defaultdict, deque
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
from sqlalchemy import create_engine, text
|
|
||||||
from elasticsearch import Elasticsearch
|
|
||||||
from clickhouse_driver import Client
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
from ml.model import TransformerAutoencoder
|
|
||||||
|
|
||||||
# ==================== 配置 ====================
|
|
||||||
|
|
||||||
MYSQL_ENGINE = create_engine(
|
|
||||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
|
||||||
echo=False
|
|
||||||
)
|
|
||||||
|
|
||||||
ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
|
|
||||||
ES_INDEX = 'concept_library_v3'
|
|
||||||
|
|
||||||
CLICKHOUSE_CONFIG = {
|
|
||||||
'host': '127.0.0.1',
|
|
||||||
'port': 9000,
|
|
||||||
'user': 'default',
|
|
||||||
'password': 'Zzl33818!',
|
|
||||||
'database': 'stock'
|
|
||||||
}
|
|
||||||
|
|
||||||
REFERENCE_INDEX = '000001.SH'
|
|
||||||
|
|
||||||
# 检测配置
|
|
||||||
CONFIG = {
|
|
||||||
# 序列配置
|
|
||||||
'seq_len': 10, # LSTM 序列长度(分钟)
|
|
||||||
|
|
||||||
# 持续性确认配置(核心!)
|
|
||||||
'confirm_window': 5, # 确认窗口(分钟)
|
|
||||||
'confirm_ratio': 0.6, # 确认比例(60%时刻需要超标)
|
|
||||||
|
|
||||||
# Z-Score 阈值
|
|
||||||
'alpha_zscore_threshold': 2.0, # Alpha Z-Score 阈值
|
|
||||||
'amt_zscore_threshold': 2.5, # 成交额 Z-Score 阈值
|
|
||||||
|
|
||||||
# 融合权重
|
|
||||||
'rule_weight': 0.5,
|
|
||||||
'ml_weight': 0.5,
|
|
||||||
|
|
||||||
# 触发阈值
|
|
||||||
'rule_trigger': 60,
|
|
||||||
'ml_trigger': 70,
|
|
||||||
'fusion_trigger': 50,
|
|
||||||
|
|
||||||
# 冷却期
|
|
||||||
'cooldown_minutes': 10,
|
|
||||||
'max_alerts_per_minute': 15,
|
|
||||||
|
|
||||||
# Z-Score 截断
|
|
||||||
'zscore_clip': 5.0,
|
|
||||||
}
|
|
||||||
|
|
||||||
# V2 特征列表
|
|
||||||
FEATURES_V2 = [
|
|
||||||
'alpha_zscore', 'amt_zscore', 'rank_zscore',
|
|
||||||
'momentum_3m', 'momentum_5m', 'limit_up_ratio'
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 工具函数 ====================
|
|
||||||
|
|
||||||
def get_ch_client():
|
|
||||||
return Client(**CLICKHOUSE_CONFIG)
|
|
||||||
|
|
||||||
|
|
||||||
def code_to_ch_format(code: str) -> str:
|
|
||||||
if not code or len(code) != 6 or not code.isdigit():
|
|
||||||
return None
|
|
||||||
if code.startswith('6'):
|
|
||||||
return f"{code}.SH"
|
|
||||||
elif code.startswith('0') or code.startswith('3'):
|
|
||||||
return f"{code}.SZ"
|
|
||||||
else:
|
|
||||||
return f"{code}.BJ"
|
|
||||||
|
|
||||||
|
|
||||||
def time_to_slot(ts) -> str:
|
|
||||||
"""时间戳转时间片(HH:MM)"""
|
|
||||||
if isinstance(ts, str):
|
|
||||||
return ts
|
|
||||||
return ts.strftime('%H:%M')
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 基线加载 ====================
|
|
||||||
|
|
||||||
def load_baselines(baseline_dir: str = 'ml/data_v2/baselines') -> Dict[str, pd.DataFrame]:
|
|
||||||
"""加载时间片基线"""
|
|
||||||
baseline_file = os.path.join(baseline_dir, 'baselines.pkl')
|
|
||||||
if os.path.exists(baseline_file):
|
|
||||||
with open(baseline_file, 'rb') as f:
|
|
||||||
return pickle.load(f)
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 规则评分(基于 Z-Score)====================
|
|
||||||
|
|
||||||
def score_rules_zscore(row: Dict) -> Tuple[float, List[str]]:
|
|
||||||
"""
|
|
||||||
基于 Z-Score 的规则评分
|
|
||||||
|
|
||||||
设计思路:Z-Score 已经标准化,直接用阈值判断
|
|
||||||
"""
|
|
||||||
score = 0.0
|
|
||||||
triggered = []
|
|
||||||
|
|
||||||
alpha_zscore = row.get('alpha_zscore', 0)
|
|
||||||
amt_zscore = row.get('amt_zscore', 0)
|
|
||||||
rank_zscore = row.get('rank_zscore', 0)
|
|
||||||
momentum_3m = row.get('momentum_3m', 0)
|
|
||||||
momentum_5m = row.get('momentum_5m', 0)
|
|
||||||
limit_up_ratio = row.get('limit_up_ratio', 0)
|
|
||||||
|
|
||||||
alpha_zscore_abs = abs(alpha_zscore)
|
|
||||||
amt_zscore_abs = abs(amt_zscore)
|
|
||||||
|
|
||||||
# ========== Alpha Z-Score 规则 ==========
|
|
||||||
if alpha_zscore_abs >= 4.0:
|
|
||||||
score += 25
|
|
||||||
triggered.append('alpha_zscore_extreme')
|
|
||||||
elif alpha_zscore_abs >= 3.0:
|
|
||||||
score += 18
|
|
||||||
triggered.append('alpha_zscore_strong')
|
|
||||||
elif alpha_zscore_abs >= 2.0:
|
|
||||||
score += 10
|
|
||||||
triggered.append('alpha_zscore_moderate')
|
|
||||||
|
|
||||||
# ========== 成交额 Z-Score 规则 ==========
|
|
||||||
if amt_zscore >= 4.0:
|
|
||||||
score += 20
|
|
||||||
triggered.append('amt_zscore_extreme')
|
|
||||||
elif amt_zscore >= 3.0:
|
|
||||||
score += 12
|
|
||||||
triggered.append('amt_zscore_strong')
|
|
||||||
elif amt_zscore >= 2.0:
|
|
||||||
score += 6
|
|
||||||
triggered.append('amt_zscore_moderate')
|
|
||||||
|
|
||||||
# ========== 排名 Z-Score 规则 ==========
|
|
||||||
if abs(rank_zscore) >= 3.0:
|
|
||||||
score += 15
|
|
||||||
triggered.append('rank_zscore_extreme')
|
|
||||||
elif abs(rank_zscore) >= 2.0:
|
|
||||||
score += 8
|
|
||||||
triggered.append('rank_zscore_strong')
|
|
||||||
|
|
||||||
# ========== 动量规则 ==========
|
|
||||||
if momentum_3m >= 1.0:
|
|
||||||
score += 12
|
|
||||||
triggered.append('momentum_3m_strong')
|
|
||||||
elif momentum_3m >= 0.5:
|
|
||||||
score += 6
|
|
||||||
triggered.append('momentum_3m_moderate')
|
|
||||||
|
|
||||||
if momentum_5m >= 1.5:
|
|
||||||
score += 10
|
|
||||||
triggered.append('momentum_5m_strong')
|
|
||||||
|
|
||||||
# ========== 涨停比例规则 ==========
|
|
||||||
if limit_up_ratio >= 0.3:
|
|
||||||
score += 20
|
|
||||||
triggered.append('limit_up_extreme')
|
|
||||||
elif limit_up_ratio >= 0.15:
|
|
||||||
score += 12
|
|
||||||
triggered.append('limit_up_strong')
|
|
||||||
elif limit_up_ratio >= 0.08:
|
|
||||||
score += 5
|
|
||||||
triggered.append('limit_up_moderate')
|
|
||||||
|
|
||||||
# ========== 组合规则 ==========
|
|
||||||
# Alpha Z-Score + 成交额放大
|
|
||||||
if alpha_zscore_abs >= 2.0 and amt_zscore >= 2.0:
|
|
||||||
score += 15
|
|
||||||
triggered.append('combo_alpha_amt')
|
|
||||||
|
|
||||||
# Alpha Z-Score + 涨停
|
|
||||||
if alpha_zscore_abs >= 2.0 and limit_up_ratio >= 0.1:
|
|
||||||
score += 12
|
|
||||||
triggered.append('combo_alpha_limitup')
|
|
||||||
|
|
||||||
return min(score, 100), triggered
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== ML 评分器 ====================
|
|
||||||
|
|
||||||
class MLScorerV2:
|
|
||||||
"""V2 ML 评分器"""
|
|
||||||
|
|
||||||
def __init__(self, model_dir: str = 'ml/checkpoints_v2'):
|
|
||||||
self.model_dir = model_dir
|
|
||||||
self.model = None
|
|
||||||
self.thresholds = None
|
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
self._load_model()
|
|
||||||
|
|
||||||
def _load_model(self):
|
|
||||||
"""加载模型和阈值"""
|
|
||||||
model_path = os.path.join(self.model_dir, 'best_model.pt')
|
|
||||||
threshold_path = os.path.join(self.model_dir, 'thresholds.json')
|
|
||||||
config_path = os.path.join(self.model_dir, 'config.json')
|
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
print(f"警告: 模型文件不存在: {model_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 加载配置
|
|
||||||
with open(config_path, 'r') as f:
|
|
||||||
config = json.load(f)
|
|
||||||
|
|
||||||
# 创建模型
|
|
||||||
model_config = config.get('model', {})
|
|
||||||
self.model = TransformerAutoencoder(**model_config)
|
|
||||||
|
|
||||||
# 加载权重
|
|
||||||
checkpoint = torch.load(model_path, map_location=self.device)
|
|
||||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
|
||||||
self.model.to(self.device)
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
# 加载阈值
|
|
||||||
if os.path.exists(threshold_path):
|
|
||||||
with open(threshold_path, 'r') as f:
|
|
||||||
self.thresholds = json.load(f)
|
|
||||||
|
|
||||||
print(f"V2 模型加载完成: {model_path}")
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def score_batch(self, sequences: np.ndarray) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
批量计算 ML 分数
|
|
||||||
|
|
||||||
返回 0-100 的分数,越高越异常
|
|
||||||
"""
|
|
||||||
if self.model is None:
|
|
||||||
return np.zeros(len(sequences))
|
|
||||||
|
|
||||||
# 转换为 tensor
|
|
||||||
x = torch.FloatTensor(sequences).to(self.device)
|
|
||||||
|
|
||||||
# 计算重构误差
|
|
||||||
errors = self.model.compute_reconstruction_error(x, reduction='none')
|
|
||||||
# 取最后一个时刻的误差
|
|
||||||
last_errors = errors[:, -1].cpu().numpy()
|
|
||||||
|
|
||||||
# 转换为 0-100 分数
|
|
||||||
if self.thresholds:
|
|
||||||
p50 = self.thresholds.get('median', 0.1)
|
|
||||||
p99 = self.thresholds.get('p99', 1.0)
|
|
||||||
|
|
||||||
# 线性映射:p50 -> 50分,p99 -> 99分
|
|
||||||
scores = 50 + (last_errors - p50) / (p99 - p50) * 49
|
|
||||||
scores = np.clip(scores, 0, 100)
|
|
||||||
else:
|
|
||||||
# 没有阈值时,简单归一化
|
|
||||||
scores = last_errors * 100
|
|
||||||
scores = np.clip(scores, 0, 100)
|
|
||||||
|
|
||||||
return scores
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 实时数据管理器 ====================
|
|
||||||
|
|
||||||
class RealtimeDataManagerV2:
|
|
||||||
"""
|
|
||||||
V2 实时数据管理器
|
|
||||||
|
|
||||||
维护:
|
|
||||||
1. 每个概念的历史 Z-Score 序列(用于 LSTM 输入)
|
|
||||||
2. 每个概念的异动候选队列(用于持续性确认)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, concepts: List[dict], baselines: Dict[str, pd.DataFrame]):
|
|
||||||
self.concepts = {c['concept_id']: c for c in concepts}
|
|
||||||
self.baselines = baselines
|
|
||||||
|
|
||||||
# 概念到股票的映射
|
|
||||||
self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
|
|
||||||
|
|
||||||
# 历史 Z-Score 序列(每个概念)
|
|
||||||
# {concept_id: deque([(timestamp, features_dict), ...], maxlen=seq_len)}
|
|
||||||
self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))
|
|
||||||
|
|
||||||
# 异动候选队列(用于持续性确认)
|
|
||||||
# {concept_id: deque([(timestamp, score), ...], maxlen=confirm_window)}
|
|
||||||
self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))
|
|
||||||
|
|
||||||
# 冷却期记录
|
|
||||||
self.cooldown = {}
|
|
||||||
|
|
||||||
# 上一次更新的时间戳
|
|
||||||
self.last_timestamp = None
|
|
||||||
|
|
||||||
def compute_zscore_features(
|
|
||||||
self,
|
|
||||||
concept_id: str,
|
|
||||||
timestamp,
|
|
||||||
alpha: float,
|
|
||||||
total_amt: float,
|
|
||||||
rank_pct: float,
|
|
||||||
limit_up_ratio: float
|
|
||||||
) -> Optional[Dict]:
|
|
||||||
"""计算单个概念单个时刻的 Z-Score 特征"""
|
|
||||||
if concept_id not in self.baselines:
|
|
||||||
return None
|
|
||||||
|
|
||||||
baseline = self.baselines[concept_id]
|
|
||||||
time_slot = time_to_slot(timestamp)
|
|
||||||
|
|
||||||
# 查找对应时间片的基线
|
|
||||||
bl_row = baseline[baseline['time_slot'] == time_slot]
|
|
||||||
if bl_row.empty:
|
|
||||||
return None
|
|
||||||
|
|
||||||
bl = bl_row.iloc[0]
|
|
||||||
|
|
||||||
# 检查样本量
|
|
||||||
if bl.get('sample_count', 0) < 10:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 计算 Z-Score
|
|
||||||
alpha_zscore = (alpha - bl['alpha_mean']) / bl['alpha_std']
|
|
||||||
amt_zscore = (total_amt - bl['amt_mean']) / bl['amt_std']
|
|
||||||
rank_zscore = (rank_pct - bl['rank_mean']) / bl['rank_std']
|
|
||||||
|
|
||||||
# 截断
|
|
||||||
clip = CONFIG['zscore_clip']
|
|
||||||
alpha_zscore = np.clip(alpha_zscore, -clip, clip)
|
|
||||||
amt_zscore = np.clip(amt_zscore, -clip, clip)
|
|
||||||
rank_zscore = np.clip(rank_zscore, -clip, clip)
|
|
||||||
|
|
||||||
# 计算动量(需要历史)
|
|
||||||
history = self.zscore_history[concept_id]
|
|
||||||
momentum_3m = 0
|
|
||||||
momentum_5m = 0
|
|
||||||
|
|
||||||
if len(history) >= 3:
|
|
||||||
recent_alphas = [h[1]['alpha'] for h in list(history)[-3:]]
|
|
||||||
older_alphas = [h[1]['alpha'] for h in list(history)[-6:-3]] if len(history) >= 6 else [alpha]
|
|
||||||
momentum_3m = np.mean(recent_alphas) - np.mean(older_alphas)
|
|
||||||
|
|
||||||
if len(history) >= 5:
|
|
||||||
recent_alphas = [h[1]['alpha'] for h in list(history)[-5:]]
|
|
||||||
older_alphas = [h[1]['alpha'] for h in list(history)[-10:-5]] if len(history) >= 10 else [alpha]
|
|
||||||
momentum_5m = np.mean(recent_alphas) - np.mean(older_alphas)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'alpha': alpha,
|
|
||||||
'alpha_zscore': alpha_zscore,
|
|
||||||
'amt_zscore': amt_zscore,
|
|
||||||
'rank_zscore': rank_zscore,
|
|
||||||
'momentum_3m': momentum_3m,
|
|
||||||
'momentum_5m': momentum_5m,
|
|
||||||
'limit_up_ratio': limit_up_ratio,
|
|
||||||
'total_amt': total_amt,
|
|
||||||
'rank_pct': rank_pct,
|
|
||||||
}
|
|
||||||
|
|
||||||
def update(self, concept_id: str, timestamp, features: Dict):
|
|
||||||
"""更新概念的历史数据"""
|
|
||||||
self.zscore_history[concept_id].append((timestamp, features))
|
|
||||||
|
|
||||||
def get_sequence(self, concept_id: str) -> Optional[np.ndarray]:
|
|
||||||
"""获取用于 LSTM 的序列"""
|
|
||||||
history = self.zscore_history[concept_id]
|
|
||||||
|
|
||||||
if len(history) < CONFIG['seq_len']:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 提取特征
|
|
||||||
feature_list = []
|
|
||||||
for _, features in history:
|
|
||||||
feature_list.append([
|
|
||||||
features['alpha_zscore'],
|
|
||||||
features['amt_zscore'],
|
|
||||||
features['rank_zscore'],
|
|
||||||
features['momentum_3m'],
|
|
||||||
features['momentum_5m'],
|
|
||||||
features['limit_up_ratio'],
|
|
||||||
])
|
|
||||||
|
|
||||||
return np.array(feature_list)
|
|
||||||
|
|
||||||
def add_anomaly_candidate(self, concept_id: str, timestamp, score: float):
|
|
||||||
"""添加异动候选"""
|
|
||||||
self.anomaly_candidates[concept_id].append((timestamp, score))
|
|
||||||
|
|
||||||
def check_sustained_anomaly(self, concept_id: str, threshold: float) -> Tuple[bool, float]:
|
|
||||||
"""
|
|
||||||
检查是否为持续性异动
|
|
||||||
|
|
||||||
返回:(是否确认, 确认比例)
|
|
||||||
"""
|
|
||||||
candidates = self.anomaly_candidates[concept_id]
|
|
||||||
|
|
||||||
if len(candidates) < CONFIG['confirm_window']:
|
|
||||||
return False, 0.0
|
|
||||||
|
|
||||||
# 统计超过阈值的时刻数量
|
|
||||||
exceed_count = sum(1 for _, score in candidates if score >= threshold)
|
|
||||||
ratio = exceed_count / len(candidates)
|
|
||||||
|
|
||||||
return ratio >= CONFIG['confirm_ratio'], ratio
|
|
||||||
|
|
||||||
def check_cooldown(self, concept_id: str, timestamp) -> bool:
|
|
||||||
"""检查是否在冷却期"""
|
|
||||||
if concept_id not in self.cooldown:
|
|
||||||
return False
|
|
||||||
|
|
||||||
last_alert = self.cooldown[concept_id]
|
|
||||||
try:
|
|
||||||
diff = (timestamp - last_alert).total_seconds() / 60
|
|
||||||
return diff < CONFIG['cooldown_minutes']
|
|
||||||
except:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def set_cooldown(self, concept_id: str, timestamp):
|
|
||||||
"""设置冷却期"""
|
|
||||||
self.cooldown[concept_id] = timestamp
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 异动检测器 V2 ====================
|
|
||||||
|
|
||||||
class AnomalyDetectorV2:
|
|
||||||
"""
|
|
||||||
V2 异动检测器
|
|
||||||
|
|
||||||
核心流程:
|
|
||||||
1. 获取实时数据
|
|
||||||
2. 计算 Z-Score 特征
|
|
||||||
3. 规则评分 + ML 评分
|
|
||||||
4. 持续性确认
|
|
||||||
5. 输出异动
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_dir: str = 'ml/checkpoints_v2',
|
|
||||||
baseline_dir: str = 'ml/data_v2/baselines'
|
|
||||||
):
|
|
||||||
# 加载概念
|
|
||||||
self.concepts = self._load_concepts()
|
|
||||||
|
|
||||||
# 加载基线
|
|
||||||
self.baselines = load_baselines(baseline_dir)
|
|
||||||
print(f"加载了 {len(self.baselines)} 个概念的基线")
|
|
||||||
|
|
||||||
# 初始化 ML 评分器
|
|
||||||
self.ml_scorer = MLScorerV2(model_dir)
|
|
||||||
|
|
||||||
# 初始化数据管理器
|
|
||||||
self.data_manager = RealtimeDataManagerV2(self.concepts, self.baselines)
|
|
||||||
|
|
||||||
# 收集所有股票
|
|
||||||
self.all_stocks = list(set(s for c in self.concepts for s in c['stocks']))
|
|
||||||
|
|
||||||
def _load_concepts(self) -> List[dict]:
|
|
||||||
"""从 ES 加载概念"""
|
|
||||||
concepts = []
|
|
||||||
query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}
|
|
||||||
|
|
||||||
resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
|
|
||||||
scroll_id = resp['_scroll_id']
|
|
||||||
hits = resp['hits']['hits']
|
|
||||||
|
|
||||||
while len(hits) > 0:
|
|
||||||
for hit in hits:
|
|
||||||
source = hit['_source']
|
|
||||||
stocks = []
|
|
||||||
if 'stocks' in source and isinstance(source['stocks'], list):
|
|
||||||
for stock in source['stocks']:
|
|
||||||
if isinstance(stock, dict) and 'code' in stock and stock['code']:
|
|
||||||
stocks.append(stock['code'])
|
|
||||||
if stocks:
|
|
||||||
concepts.append({
|
|
||||||
'concept_id': source.get('concept_id'),
|
|
||||||
'concept_name': source.get('concept'),
|
|
||||||
'stocks': stocks
|
|
||||||
})
|
|
||||||
|
|
||||||
resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
|
|
||||||
scroll_id = resp['_scroll_id']
|
|
||||||
hits = resp['hits']['hits']
|
|
||||||
|
|
||||||
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
|
|
||||||
print(f"加载了 {len(concepts)} 个概念")
|
|
||||||
return concepts
|
|
||||||
|
|
||||||
def detect(self, trade_date: str) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
检测指定日期的异动
|
|
||||||
|
|
||||||
返回异动列表
|
|
||||||
"""
|
|
||||||
print(f"\n检测 {trade_date} 的异动...")
|
|
||||||
|
|
||||||
# 获取原始数据
|
|
||||||
raw_features = self._compute_raw_features(trade_date)
|
|
||||||
if raw_features.empty:
|
|
||||||
print("无数据")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 按时间排序
|
|
||||||
timestamps = sorted(raw_features['timestamp'].unique())
|
|
||||||
print(f"时间点数: {len(timestamps)}")
|
|
||||||
|
|
||||||
all_alerts = []
|
|
||||||
|
|
||||||
for ts in timestamps:
|
|
||||||
ts_data = raw_features[raw_features['timestamp'] == ts]
|
|
||||||
ts_alerts = self._process_timestamp(ts, ts_data, trade_date)
|
|
||||||
all_alerts.extend(ts_alerts)
|
|
||||||
|
|
||||||
print(f"共检测到 {len(all_alerts)} 个异动")
|
|
||||||
return all_alerts
|
|
||||||
|
|
||||||
def _compute_raw_features(self, trade_date: str) -> pd.DataFrame:
|
|
||||||
"""计算原始特征(同 prepare_data_v2)"""
|
|
||||||
# 这里简化处理,直接调用数据准备逻辑
|
|
||||||
from prepare_data_v2 import compute_raw_concept_features
|
|
||||||
return compute_raw_concept_features(trade_date, self.concepts, self.all_stocks)
|
|
||||||
|
|
||||||
def _process_timestamp(self, timestamp, ts_data: pd.DataFrame, trade_date: str) -> List[Dict]:
|
|
||||||
"""处理单个时间戳"""
|
|
||||||
alerts = []
|
|
||||||
candidates = [] # (concept_id, features, rule_score, triggered_rules)
|
|
||||||
|
|
||||||
for _, row in ts_data.iterrows():
|
|
||||||
concept_id = row['concept_id']
|
|
||||||
|
|
||||||
# 计算 Z-Score 特征
|
|
||||||
features = self.data_manager.compute_zscore_features(
|
|
||||||
concept_id, timestamp,
|
|
||||||
row['alpha'], row['total_amt'], row['rank_pct'], row['limit_up_ratio']
|
|
||||||
)
|
|
||||||
|
|
||||||
if features is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 更新历史
|
|
||||||
self.data_manager.update(concept_id, timestamp, features)
|
|
||||||
|
|
||||||
# 规则评分
|
|
||||||
rule_score, triggered_rules = score_rules_zscore(features)
|
|
||||||
|
|
||||||
# 收集候选
|
|
||||||
candidates.append((concept_id, features, rule_score, triggered_rules))
|
|
||||||
|
|
||||||
if not candidates:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 批量 ML 评分
|
|
||||||
sequences = []
|
|
||||||
valid_candidates = []
|
|
||||||
|
|
||||||
for concept_id, features, rule_score, triggered_rules in candidates:
|
|
||||||
seq = self.data_manager.get_sequence(concept_id)
|
|
||||||
if seq is not None:
|
|
||||||
sequences.append(seq)
|
|
||||||
valid_candidates.append((concept_id, features, rule_score, triggered_rules))
|
|
||||||
|
|
||||||
if not sequences:
|
|
||||||
return []
|
|
||||||
|
|
||||||
sequences = np.array(sequences)
|
|
||||||
ml_scores = self.ml_scorer.score_batch(sequences)
|
|
||||||
|
|
||||||
# 融合评分 + 持续性确认
|
|
||||||
for i, (concept_id, features, rule_score, triggered_rules) in enumerate(valid_candidates):
|
|
||||||
ml_score = ml_scores[i]
|
|
||||||
final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score
|
|
||||||
|
|
||||||
# 判断是否触发
|
|
||||||
is_triggered = (
|
|
||||||
rule_score >= CONFIG['rule_trigger'] or
|
|
||||||
ml_score >= CONFIG['ml_trigger'] or
|
|
||||||
final_score >= CONFIG['fusion_trigger']
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加到候选队列
|
|
||||||
self.data_manager.add_anomaly_candidate(concept_id, timestamp, final_score)
|
|
||||||
|
|
||||||
if not is_triggered:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查冷却期
|
|
||||||
if self.data_manager.check_cooldown(concept_id, timestamp):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 持续性确认
|
|
||||||
is_sustained, confirm_ratio = self.data_manager.check_sustained_anomaly(
|
|
||||||
concept_id, CONFIG['fusion_trigger']
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_sustained:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 确认为异动!
|
|
||||||
self.data_manager.set_cooldown(concept_id, timestamp)
|
|
||||||
|
|
||||||
# 确定异动类型
|
|
||||||
alpha = features['alpha']
|
|
||||||
if alpha >= 1.5:
|
|
||||||
alert_type = 'surge_up'
|
|
||||||
elif alpha <= -1.5:
|
|
||||||
alert_type = 'surge_down'
|
|
||||||
elif features['amt_zscore'] >= 3.0:
|
|
||||||
alert_type = 'volume_spike'
|
|
||||||
else:
|
|
||||||
alert_type = 'surge'
|
|
||||||
|
|
||||||
# 确定触发原因
|
|
||||||
if rule_score >= CONFIG['rule_trigger']:
|
|
||||||
trigger_reason = f'规则({rule_score:.0f})+持续确认({confirm_ratio:.0%})'
|
|
||||||
elif ml_score >= CONFIG['ml_trigger']:
|
|
||||||
trigger_reason = f'ML({ml_score:.0f})+持续确认({confirm_ratio:.0%})'
|
|
||||||
else:
|
|
||||||
trigger_reason = f'融合({final_score:.0f})+持续确认({confirm_ratio:.0%})'
|
|
||||||
|
|
||||||
alerts.append({
|
|
||||||
'concept_id': concept_id,
|
|
||||||
'concept_name': self.data_manager.concepts.get(concept_id, {}).get('concept_name', concept_id),
|
|
||||||
'alert_time': timestamp,
|
|
||||||
'trade_date': trade_date,
|
|
||||||
'alert_type': alert_type,
|
|
||||||
'final_score': final_score,
|
|
||||||
'rule_score': rule_score,
|
|
||||||
'ml_score': ml_score,
|
|
||||||
'trigger_reason': trigger_reason,
|
|
||||||
'confirm_ratio': confirm_ratio,
|
|
||||||
'alpha': alpha,
|
|
||||||
'alpha_zscore': features['alpha_zscore'],
|
|
||||||
'amt_zscore': features['amt_zscore'],
|
|
||||||
'rank_zscore': features['rank_zscore'],
|
|
||||||
'momentum_3m': features['momentum_3m'],
|
|
||||||
'momentum_5m': features['momentum_5m'],
|
|
||||||
'limit_up_ratio': features['limit_up_ratio'],
|
|
||||||
'triggered_rules': triggered_rules,
|
|
||||||
})
|
|
||||||
|
|
||||||
# 每分钟最多 N 个
|
|
||||||
if len(alerts) > CONFIG['max_alerts_per_minute']:
|
|
||||||
alerts = sorted(alerts, key=lambda x: x['final_score'], reverse=True)
|
|
||||||
alerts = alerts[:CONFIG['max_alerts_per_minute']]
|
|
||||||
|
|
||||||
return alerts
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 主函数 ====================
|
|
||||||
|
|
||||||
def main():
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='V2 异动检测器')
|
|
||||||
parser.add_argument('--date', type=str, default=None, help='检测日期(默认今天)')
|
|
||||||
parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
|
|
||||||
parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
trade_date = args.date or datetime.now().strftime('%Y-%m-%d')
|
|
||||||
|
|
||||||
detector = AnomalyDetectorV2(
|
|
||||||
model_dir=args.model_dir,
|
|
||||||
baseline_dir=args.baseline_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
alerts = detector.detect(trade_date)
|
|
||||||
|
|
||||||
print(f"\n检测结果:")
|
|
||||||
for alert in alerts[:20]:
|
|
||||||
print(f" [{alert['alert_time'].strftime('%H:%M') if hasattr(alert['alert_time'], 'strftime') else alert['alert_time']}] "
|
|
||||||
f"{alert['concept_name']} ({alert['alert_type']}) "
|
|
||||||
f"分数={alert['final_score']:.0f} "
|
|
||||||
f"确认率={alert['confirm_ratio']:.0%}")
|
|
||||||
|
|
||||||
if len(alerts) > 20:
|
|
||||||
print(f" ... 共 {len(alerts)} 个异动")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,526 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
增强版概念异动检测器
|
|
||||||
|
|
||||||
融合两种检测方法:
|
|
||||||
1. Alpha-based Z-Score(规则方法,实时性好)
|
|
||||||
2. Transformer Autoencoder(ML方法,更准确)
|
|
||||||
|
|
||||||
使用策略:
|
|
||||||
- 当 ML 模型可用且历史数据足够时,优先使用 ML 方法
|
|
||||||
- 否则回退到 Alpha-based 方法
|
|
||||||
- 可以配置两种方法的融合权重
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Tuple, Optional
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from collections import deque
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# 添加父目录到路径
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 配置 ====================
|
|
||||||
|
|
||||||
ENHANCED_CONFIG = {
|
|
||||||
# 融合策略
|
|
||||||
'fusion_mode': 'adaptive', # 'ml_only', 'alpha_only', 'adaptive', 'ensemble'
|
|
||||||
|
|
||||||
# ML 权重(在 ensemble 模式下)
|
|
||||||
'ml_weight': 0.6,
|
|
||||||
'alpha_weight': 0.4,
|
|
||||||
|
|
||||||
# ML 模型配置
|
|
||||||
'ml_checkpoint_dir': 'ml/checkpoints',
|
|
||||||
'ml_threshold_key': 'p95', # p90, p95, p99
|
|
||||||
|
|
||||||
# Alpha 配置(与 concept_alert_alpha.py 一致)
|
|
||||||
'alpha_zscore_threshold': 2.0,
|
|
||||||
'alpha_absolute_threshold': 1.5,
|
|
||||||
'alpha_history_window': 60,
|
|
||||||
'alpha_min_history': 5,
|
|
||||||
|
|
||||||
# 共享配置
|
|
||||||
'cooldown_minutes': 8,
|
|
||||||
'max_alerts_per_minute': 15,
|
|
||||||
'min_alpha_abs': 0.5,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 特征配置(与训练一致)
|
|
||||||
FEATURE_NAMES = [
|
|
||||||
'alpha',
|
|
||||||
'alpha_delta',
|
|
||||||
'amt_ratio',
|
|
||||||
'amt_delta',
|
|
||||||
'rank_pct',
|
|
||||||
'limit_up_ratio',
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据结构 ====================
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AlphaStats:
|
|
||||||
"""概念的Alpha统计信息"""
|
|
||||||
history: deque = field(default_factory=lambda: deque(maxlen=ENHANCED_CONFIG['alpha_history_window']))
|
|
||||||
mean: float = 0.0
|
|
||||||
std: float = 1.0
|
|
||||||
|
|
||||||
def update(self, alpha: float):
|
|
||||||
self.history.append(alpha)
|
|
||||||
if len(self.history) >= 2:
|
|
||||||
self.mean = np.mean(self.history)
|
|
||||||
self.std = max(np.std(self.history), 0.1)
|
|
||||||
|
|
||||||
def get_zscore(self, alpha: float) -> float:
|
|
||||||
if len(self.history) < ENHANCED_CONFIG['alpha_min_history']:
|
|
||||||
return 0.0
|
|
||||||
return (alpha - self.mean) / self.std
|
|
||||||
|
|
||||||
def is_ready(self) -> bool:
|
|
||||||
return len(self.history) >= ENHANCED_CONFIG['alpha_min_history']
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ConceptFeatures:
|
|
||||||
"""概念的实时特征"""
|
|
||||||
alpha: float = 0.0
|
|
||||||
alpha_delta: float = 0.0
|
|
||||||
amt_ratio: float = 1.0
|
|
||||||
amt_delta: float = 0.0
|
|
||||||
rank_pct: float = 0.5
|
|
||||||
limit_up_ratio: float = 0.0
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, float]:
|
|
||||||
return {
|
|
||||||
'alpha': self.alpha,
|
|
||||||
'alpha_delta': self.alpha_delta,
|
|
||||||
'amt_ratio': self.amt_ratio,
|
|
||||||
'amt_delta': self.amt_delta,
|
|
||||||
'rank_pct': self.rank_pct,
|
|
||||||
'limit_up_ratio': self.limit_up_ratio,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 增强检测器 ====================
|
|
||||||
|
|
||||||
class EnhancedAnomalyDetector:
|
|
||||||
"""
|
|
||||||
增强版异动检测器
|
|
||||||
|
|
||||||
融合 Alpha-based 和 ML 两种方法
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: Dict = None,
|
|
||||||
ml_enabled: bool = True
|
|
||||||
):
|
|
||||||
self.config = config or ENHANCED_CONFIG
|
|
||||||
self.ml_enabled = ml_enabled
|
|
||||||
self.ml_detector = None
|
|
||||||
|
|
||||||
# Alpha 统计
|
|
||||||
self.alpha_stats: Dict[str, AlphaStats] = {}
|
|
||||||
|
|
||||||
# 特征历史(用于计算 delta)
|
|
||||||
self.feature_history: Dict[str, deque] = {}
|
|
||||||
|
|
||||||
# 冷却记录
|
|
||||||
self.cooldown_cache: Dict[str, datetime] = {}
|
|
||||||
|
|
||||||
# 尝试加载 ML 模型
|
|
||||||
if ml_enabled:
|
|
||||||
self._load_ml_model()
|
|
||||||
|
|
||||||
logger.info(f"EnhancedAnomalyDetector 初始化完成")
|
|
||||||
logger.info(f" 融合模式: {self.config['fusion_mode']}")
|
|
||||||
logger.info(f" ML 可用: {self.ml_detector is not None}")
|
|
||||||
|
|
||||||
def _load_ml_model(self):
|
|
||||||
"""加载 ML 模型"""
|
|
||||||
try:
|
|
||||||
from inference import ConceptAnomalyDetector
|
|
||||||
checkpoint_dir = Path(__file__).parent / 'checkpoints'
|
|
||||||
|
|
||||||
if (checkpoint_dir / 'best_model.pt').exists():
|
|
||||||
self.ml_detector = ConceptAnomalyDetector(
|
|
||||||
checkpoint_dir=str(checkpoint_dir),
|
|
||||||
threshold_key=self.config['ml_threshold_key']
|
|
||||||
)
|
|
||||||
logger.info("ML 模型加载成功")
|
|
||||||
else:
|
|
||||||
logger.warning(f"ML 模型不存在: {checkpoint_dir / 'best_model.pt'}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"ML 模型加载失败: {e}")
|
|
||||||
self.ml_detector = None
|
|
||||||
|
|
||||||
def _get_alpha_stats(self, concept_id: str) -> AlphaStats:
|
|
||||||
"""获取或创建 Alpha 统计"""
|
|
||||||
if concept_id not in self.alpha_stats:
|
|
||||||
self.alpha_stats[concept_id] = AlphaStats()
|
|
||||||
return self.alpha_stats[concept_id]
|
|
||||||
|
|
||||||
def _get_feature_history(self, concept_id: str) -> deque:
|
|
||||||
"""获取特征历史"""
|
|
||||||
if concept_id not in self.feature_history:
|
|
||||||
self.feature_history[concept_id] = deque(maxlen=10)
|
|
||||||
return self.feature_history[concept_id]
|
|
||||||
|
|
||||||
def _check_cooldown(self, concept_id: str, current_time: datetime) -> bool:
|
|
||||||
"""检查冷却"""
|
|
||||||
if concept_id not in self.cooldown_cache:
|
|
||||||
return False
|
|
||||||
|
|
||||||
last_alert = self.cooldown_cache[concept_id]
|
|
||||||
cooldown_td = (current_time - last_alert).total_seconds() / 60
|
|
||||||
|
|
||||||
return cooldown_td < self.config['cooldown_minutes']
|
|
||||||
|
|
||||||
def _set_cooldown(self, concept_id: str, current_time: datetime):
|
|
||||||
"""设置冷却"""
|
|
||||||
self.cooldown_cache[concept_id] = current_time
|
|
||||||
|
|
||||||
def compute_features(
|
|
||||||
self,
|
|
||||||
concept_id: str,
|
|
||||||
alpha: float,
|
|
||||||
amt_ratio: float,
|
|
||||||
rank_pct: float,
|
|
||||||
limit_up_ratio: float
|
|
||||||
) -> ConceptFeatures:
|
|
||||||
"""
|
|
||||||
计算概念的完整特征
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concept_id: 概念ID
|
|
||||||
alpha: 当前超额收益
|
|
||||||
amt_ratio: 成交额比率
|
|
||||||
rank_pct: 排名百分位
|
|
||||||
limit_up_ratio: 涨停股占比
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
完整特征
|
|
||||||
"""
|
|
||||||
history = self._get_feature_history(concept_id)
|
|
||||||
|
|
||||||
# 计算变化率
|
|
||||||
alpha_delta = 0.0
|
|
||||||
amt_delta = 0.0
|
|
||||||
|
|
||||||
if len(history) > 0:
|
|
||||||
last_features = history[-1]
|
|
||||||
alpha_delta = alpha - last_features.alpha
|
|
||||||
if last_features.amt_ratio > 0:
|
|
||||||
amt_delta = (amt_ratio - last_features.amt_ratio) / last_features.amt_ratio
|
|
||||||
|
|
||||||
features = ConceptFeatures(
|
|
||||||
alpha=alpha,
|
|
||||||
alpha_delta=alpha_delta,
|
|
||||||
amt_ratio=amt_ratio,
|
|
||||||
amt_delta=amt_delta,
|
|
||||||
rank_pct=rank_pct,
|
|
||||||
limit_up_ratio=limit_up_ratio,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 更新历史
|
|
||||||
history.append(features)
|
|
||||||
|
|
||||||
return features
|
|
||||||
|
|
||||||
def detect_alpha_anomaly(
|
|
||||||
self,
|
|
||||||
concept_id: str,
|
|
||||||
alpha: float
|
|
||||||
) -> Tuple[bool, float, str]:
|
|
||||||
"""
|
|
||||||
Alpha-based 异动检测
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
is_anomaly: 是否异动
|
|
||||||
score: 异动分数(Z-Score 绝对值)
|
|
||||||
reason: 触发原因
|
|
||||||
"""
|
|
||||||
stats = self._get_alpha_stats(concept_id)
|
|
||||||
|
|
||||||
# 计算 Z-Score(在更新前)
|
|
||||||
zscore = stats.get_zscore(alpha)
|
|
||||||
|
|
||||||
# 更新统计
|
|
||||||
stats.update(alpha)
|
|
||||||
|
|
||||||
# 判断
|
|
||||||
if stats.is_ready():
|
|
||||||
if abs(zscore) >= self.config['alpha_zscore_threshold']:
|
|
||||||
return True, abs(zscore), f"Z={zscore:.2f}"
|
|
||||||
else:
|
|
||||||
if abs(alpha) >= self.config['alpha_absolute_threshold']:
|
|
||||||
fake_zscore = alpha / 0.5
|
|
||||||
return True, abs(fake_zscore), f"Alpha={alpha:+.2f}%"
|
|
||||||
|
|
||||||
return False, abs(zscore) if zscore else 0.0, ""
|
|
||||||
|
|
||||||
def detect_ml_anomaly(
|
|
||||||
self,
|
|
||||||
concept_id: str,
|
|
||||||
features: ConceptFeatures
|
|
||||||
) -> Tuple[bool, float]:
|
|
||||||
"""
|
|
||||||
ML-based 异动检测
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
is_anomaly: 是否异动
|
|
||||||
score: 异动分数(重构误差)
|
|
||||||
"""
|
|
||||||
if self.ml_detector is None:
|
|
||||||
return False, 0.0
|
|
||||||
|
|
||||||
try:
|
|
||||||
is_anomaly, score = self.ml_detector.detect(
|
|
||||||
concept_id,
|
|
||||||
features.to_dict()
|
|
||||||
)
|
|
||||||
return is_anomaly, score or 0.0
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"ML 检测失败: {e}")
|
|
||||||
return False, 0.0
|
|
||||||
|
|
||||||
def detect(
|
|
||||||
self,
|
|
||||||
concept_id: str,
|
|
||||||
concept_name: str,
|
|
||||||
alpha: float,
|
|
||||||
amt_ratio: float,
|
|
||||||
rank_pct: float,
|
|
||||||
limit_up_ratio: float,
|
|
||||||
change_pct: float,
|
|
||||||
index_change: float,
|
|
||||||
current_time: datetime,
|
|
||||||
**extra_data
|
|
||||||
) -> Optional[Dict]:
|
|
||||||
"""
|
|
||||||
融合检测
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concept_id: 概念ID
|
|
||||||
concept_name: 概念名称
|
|
||||||
alpha: 超额收益
|
|
||||||
amt_ratio: 成交额比率
|
|
||||||
rank_pct: 排名百分位
|
|
||||||
limit_up_ratio: 涨停股占比
|
|
||||||
change_pct: 概念涨跌幅
|
|
||||||
index_change: 大盘涨跌幅
|
|
||||||
current_time: 当前时间
|
|
||||||
**extra_data: 其他数据(limit_up_count, stock_count 等)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
异动信息(如果触发),否则 None
|
|
||||||
"""
|
|
||||||
# Alpha 太小,不关注
|
|
||||||
if abs(alpha) < self.config['min_alpha_abs']:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 检查冷却
|
|
||||||
if self._check_cooldown(concept_id, current_time):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 计算特征
|
|
||||||
features = self.compute_features(
|
|
||||||
concept_id, alpha, amt_ratio, rank_pct, limit_up_ratio
|
|
||||||
)
|
|
||||||
|
|
||||||
# 执行检测
|
|
||||||
fusion_mode = self.config['fusion_mode']
|
|
||||||
|
|
||||||
alpha_anomaly, alpha_score, alpha_reason = self.detect_alpha_anomaly(concept_id, alpha)
|
|
||||||
ml_anomaly, ml_score = False, 0.0
|
|
||||||
|
|
||||||
if fusion_mode in ('ml_only', 'adaptive', 'ensemble'):
|
|
||||||
ml_anomaly, ml_score = self.detect_ml_anomaly(concept_id, features)
|
|
||||||
|
|
||||||
# 根据融合模式判断
|
|
||||||
is_anomaly = False
|
|
||||||
final_score = 0.0
|
|
||||||
detection_method = ''
|
|
||||||
|
|
||||||
if fusion_mode == 'alpha_only':
|
|
||||||
is_anomaly = alpha_anomaly
|
|
||||||
final_score = alpha_score
|
|
||||||
detection_method = 'alpha'
|
|
||||||
|
|
||||||
elif fusion_mode == 'ml_only':
|
|
||||||
is_anomaly = ml_anomaly
|
|
||||||
final_score = ml_score
|
|
||||||
detection_method = 'ml'
|
|
||||||
|
|
||||||
elif fusion_mode == 'adaptive':
|
|
||||||
# 优先 ML,回退 Alpha
|
|
||||||
if self.ml_detector and ml_score > 0:
|
|
||||||
is_anomaly = ml_anomaly
|
|
||||||
final_score = ml_score
|
|
||||||
detection_method = 'ml'
|
|
||||||
else:
|
|
||||||
is_anomaly = alpha_anomaly
|
|
||||||
final_score = alpha_score
|
|
||||||
detection_method = 'alpha'
|
|
||||||
|
|
||||||
elif fusion_mode == 'ensemble':
|
|
||||||
# 加权融合
|
|
||||||
# 归一化分数
|
|
||||||
norm_alpha = min(alpha_score / 5.0, 1.0) # Z > 5 视为 1.0
|
|
||||||
norm_ml = min(ml_score / (self.ml_detector.threshold if self.ml_detector else 1.0), 1.0)
|
|
||||||
|
|
||||||
final_score = (
|
|
||||||
self.config['alpha_weight'] * norm_alpha +
|
|
||||||
self.config['ml_weight'] * norm_ml
|
|
||||||
)
|
|
||||||
is_anomaly = final_score > 0.5 or alpha_anomaly or ml_anomaly
|
|
||||||
detection_method = 'ensemble'
|
|
||||||
|
|
||||||
if not is_anomaly:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 构建异动记录
|
|
||||||
self._set_cooldown(concept_id, current_time)
|
|
||||||
|
|
||||||
alert_type = 'surge_up' if alpha > 0 else 'surge_down'
|
|
||||||
|
|
||||||
alert = {
|
|
||||||
'concept_id': concept_id,
|
|
||||||
'concept_name': concept_name,
|
|
||||||
'alert_type': alert_type,
|
|
||||||
'alert_time': current_time,
|
|
||||||
'change_pct': change_pct,
|
|
||||||
'alpha': alpha,
|
|
||||||
'alpha_zscore': alpha_score,
|
|
||||||
'index_change_pct': index_change,
|
|
||||||
'detection_method': detection_method,
|
|
||||||
'alpha_score': alpha_score,
|
|
||||||
'ml_score': ml_score,
|
|
||||||
'final_score': final_score,
|
|
||||||
**extra_data
|
|
||||||
}
|
|
||||||
|
|
||||||
return alert
|
|
||||||
|
|
||||||
def batch_detect(
|
|
||||||
self,
|
|
||||||
concepts_data: List[Dict],
|
|
||||||
current_time: datetime
|
|
||||||
) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
批量检测
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concepts_data: 概念数据列表
|
|
||||||
current_time: 当前时间
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
异动列表(按分数排序,限制数量)
|
|
||||||
"""
|
|
||||||
alerts = []
|
|
||||||
|
|
||||||
for data in concepts_data:
|
|
||||||
alert = self.detect(
|
|
||||||
concept_id=data['concept_id'],
|
|
||||||
concept_name=data['concept_name'],
|
|
||||||
alpha=data.get('alpha', 0),
|
|
||||||
amt_ratio=data.get('amt_ratio', 1.0),
|
|
||||||
rank_pct=data.get('rank_pct', 0.5),
|
|
||||||
limit_up_ratio=data.get('limit_up_ratio', 0),
|
|
||||||
change_pct=data.get('change_pct', 0),
|
|
||||||
index_change=data.get('index_change', 0),
|
|
||||||
current_time=current_time,
|
|
||||||
limit_up_count=data.get('limit_up_count', 0),
|
|
||||||
limit_down_count=data.get('limit_down_count', 0),
|
|
||||||
stock_count=data.get('stock_count', 0),
|
|
||||||
concept_type=data.get('concept_type', 'leaf'),
|
|
||||||
)
|
|
||||||
|
|
||||||
if alert:
|
|
||||||
alerts.append(alert)
|
|
||||||
|
|
||||||
# 排序并限制数量
|
|
||||||
alerts.sort(key=lambda x: x['final_score'], reverse=True)
|
|
||||||
return alerts[:self.config['max_alerts_per_minute']]
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""重置所有状态(新交易日)"""
|
|
||||||
self.alpha_stats.clear()
|
|
||||||
self.feature_history.clear()
|
|
||||||
self.cooldown_cache.clear()
|
|
||||||
|
|
||||||
if self.ml_detector:
|
|
||||||
self.ml_detector.clear_history()
|
|
||||||
|
|
||||||
logger.info("检测器状态已重置")
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 测试 ====================
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import random
|
|
||||||
|
|
||||||
print("测试 EnhancedAnomalyDetector...")
|
|
||||||
|
|
||||||
# 初始化
|
|
||||||
detector = EnhancedAnomalyDetector(ml_enabled=False) # 不加载 ML(可能不存在)
|
|
||||||
|
|
||||||
# 模拟数据
|
|
||||||
concepts = [
|
|
||||||
{'concept_id': 'ai_001', 'concept_name': '人工智能'},
|
|
||||||
{'concept_id': 'chip_002', 'concept_name': '芯片半导体'},
|
|
||||||
{'concept_id': 'car_003', 'concept_name': '新能源汽车'},
|
|
||||||
]
|
|
||||||
|
|
||||||
print("\n模拟实时检测...")
|
|
||||||
current_time = datetime.now()
|
|
||||||
|
|
||||||
for minute in range(50):
|
|
||||||
concepts_data = []
|
|
||||||
|
|
||||||
for c in concepts:
|
|
||||||
# 生成随机数据
|
|
||||||
alpha = random.gauss(0, 0.8)
|
|
||||||
amt_ratio = max(0.3, random.gauss(1, 0.3))
|
|
||||||
rank_pct = random.random()
|
|
||||||
limit_up_ratio = random.random() * 0.1
|
|
||||||
|
|
||||||
# 模拟异动(第30分钟人工智能暴涨)
|
|
||||||
if minute == 30 and c['concept_id'] == 'ai_001':
|
|
||||||
alpha = 4.5
|
|
||||||
amt_ratio = 2.5
|
|
||||||
limit_up_ratio = 0.3
|
|
||||||
|
|
||||||
concepts_data.append({
|
|
||||||
**c,
|
|
||||||
'alpha': alpha,
|
|
||||||
'amt_ratio': amt_ratio,
|
|
||||||
'rank_pct': rank_pct,
|
|
||||||
'limit_up_ratio': limit_up_ratio,
|
|
||||||
'change_pct': alpha + 0.5,
|
|
||||||
'index_change': 0.5,
|
|
||||||
})
|
|
||||||
|
|
||||||
# 检测
|
|
||||||
alerts = detector.batch_detect(concepts_data, current_time)
|
|
||||||
|
|
||||||
if alerts:
|
|
||||||
for alert in alerts:
|
|
||||||
print(f" t={minute:02d} 🔥 {alert['concept_name']} "
|
|
||||||
f"Alpha={alert['alpha']:+.2f}% "
|
|
||||||
f"Score={alert['final_score']:.2f} "
|
|
||||||
f"Method={alert['detection_method']}")
|
|
||||||
|
|
||||||
current_time = current_time.replace(minute=current_time.minute + 1 if current_time.minute < 59 else 0)
|
|
||||||
|
|
||||||
print("\n测试完成!")
|
|
||||||
455
ml/inference.py
455
ml/inference.py
@@ -1,455 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
概念异动检测推理服务
|
|
||||||
|
|
||||||
在实时场景中使用训练好的 Transformer Autoencoder 进行异动检测
|
|
||||||
|
|
||||||
使用方法:
|
|
||||||
from ml.inference import ConceptAnomalyDetector
|
|
||||||
|
|
||||||
detector = ConceptAnomalyDetector('ml/checkpoints')
|
|
||||||
|
|
||||||
# 检测异动
|
|
||||||
features = {...} # 实时特征数据
|
|
||||||
is_anomaly, score = detector.detect(features)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Tuple, Optional
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from model import TransformerAutoencoder
|
|
||||||
|
|
||||||
|
|
||||||
class ConceptAnomalyDetector:
|
|
||||||
"""
|
|
||||||
概念异动检测器
|
|
||||||
|
|
||||||
使用训练好的 Transformer Autoencoder 进行实时异动检测
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
checkpoint_dir: str = 'ml/checkpoints',
|
|
||||||
device: str = 'auto',
|
|
||||||
threshold_key: str = 'p95'
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
初始化检测器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
checkpoint_dir: 模型检查点目录
|
|
||||||
device: 设备 (auto/cuda/cpu)
|
|
||||||
threshold_key: 使用的阈值键 (p90/p95/p99)
|
|
||||||
"""
|
|
||||||
self.checkpoint_dir = Path(checkpoint_dir)
|
|
||||||
self.threshold_key = threshold_key
|
|
||||||
|
|
||||||
# 设备选择
|
|
||||||
if device == 'auto':
|
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
else:
|
|
||||||
self.device = torch.device(device)
|
|
||||||
|
|
||||||
# 加载配置
|
|
||||||
self._load_config()
|
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
self._load_model()
|
|
||||||
|
|
||||||
# 加载阈值
|
|
||||||
self._load_thresholds()
|
|
||||||
|
|
||||||
# 加载标准化统计量
|
|
||||||
self._load_normalization_stats()
|
|
||||||
|
|
||||||
# 概念历史数据缓存
|
|
||||||
# {concept_name: deque(maxlen=seq_len)}
|
|
||||||
self.history_cache: Dict[str, deque] = {}
|
|
||||||
|
|
||||||
print(f"ConceptAnomalyDetector 初始化完成")
|
|
||||||
print(f" 设备: {self.device}")
|
|
||||||
print(f" 阈值: {self.threshold_key} = {self.threshold:.6f}")
|
|
||||||
print(f" 序列长度: {self.seq_len}")
|
|
||||||
|
|
||||||
def _load_config(self):
|
|
||||||
"""加载配置"""
|
|
||||||
config_path = self.checkpoint_dir / 'config.json'
|
|
||||||
if not config_path.exists():
|
|
||||||
raise FileNotFoundError(f"配置文件不存在: {config_path}")
|
|
||||||
|
|
||||||
with open(config_path, 'r') as f:
|
|
||||||
self.config = json.load(f)
|
|
||||||
|
|
||||||
self.features = self.config['features']
|
|
||||||
self.seq_len = self.config['seq_len']
|
|
||||||
self.model_config = self.config['model']
|
|
||||||
|
|
||||||
def _load_model(self):
|
|
||||||
"""加载模型"""
|
|
||||||
model_path = self.checkpoint_dir / 'best_model.pt'
|
|
||||||
if not model_path.exists():
|
|
||||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
|
||||||
|
|
||||||
# 创建模型
|
|
||||||
self.model = TransformerAutoencoder(**self.model_config)
|
|
||||||
|
|
||||||
# 加载权重
|
|
||||||
checkpoint = torch.load(model_path, map_location=self.device)
|
|
||||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
|
||||||
self.model.to(self.device)
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
print(f"模型已加载: {model_path}")
|
|
||||||
|
|
||||||
def _load_thresholds(self):
|
|
||||||
"""加载阈值"""
|
|
||||||
thresholds_path = self.checkpoint_dir / 'thresholds.json'
|
|
||||||
if not thresholds_path.exists():
|
|
||||||
raise FileNotFoundError(f"阈值文件不存在: {thresholds_path}")
|
|
||||||
|
|
||||||
with open(thresholds_path, 'r') as f:
|
|
||||||
self.thresholds = json.load(f)
|
|
||||||
|
|
||||||
if self.threshold_key not in self.thresholds:
|
|
||||||
available_keys = list(self.thresholds.keys())
|
|
||||||
raise KeyError(f"阈值键 '{self.threshold_key}' 不存在,可用: {available_keys}")
|
|
||||||
|
|
||||||
self.threshold = self.thresholds[self.threshold_key]
|
|
||||||
|
|
||||||
def _load_normalization_stats(self):
|
|
||||||
"""加载标准化统计量"""
|
|
||||||
stats_path = self.checkpoint_dir / 'normalization_stats.json'
|
|
||||||
if not stats_path.exists():
|
|
||||||
raise FileNotFoundError(f"标准化统计量文件不存在: {stats_path}")
|
|
||||||
|
|
||||||
with open(stats_path, 'r') as f:
|
|
||||||
stats = json.load(f)
|
|
||||||
|
|
||||||
self.norm_mean = np.array(stats['mean'])
|
|
||||||
self.norm_std = np.array(stats['std'])
|
|
||||||
|
|
||||||
def normalize(self, features: np.ndarray) -> np.ndarray:
|
|
||||||
"""标准化特征"""
|
|
||||||
return (features - self.norm_mean) / self.norm_std
|
|
||||||
|
|
||||||
def update_history(
|
|
||||||
self,
|
|
||||||
concept_name: str,
|
|
||||||
features: Dict[str, float]
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
更新概念历史数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concept_name: 概念名称
|
|
||||||
features: 当前时刻的特征字典
|
|
||||||
"""
|
|
||||||
# 初始化历史缓存
|
|
||||||
if concept_name not in self.history_cache:
|
|
||||||
self.history_cache[concept_name] = deque(maxlen=self.seq_len)
|
|
||||||
|
|
||||||
# 提取特征向量
|
|
||||||
feature_vector = np.array([
|
|
||||||
features.get(f, 0.0) for f in self.features
|
|
||||||
])
|
|
||||||
|
|
||||||
# 处理异常值
|
|
||||||
feature_vector = np.nan_to_num(feature_vector, nan=0.0, posinf=0.0, neginf=0.0)
|
|
||||||
|
|
||||||
# 添加到历史
|
|
||||||
self.history_cache[concept_name].append(feature_vector)
|
|
||||||
|
|
||||||
def get_history_length(self, concept_name: str) -> int:
|
|
||||||
"""获取概念的历史数据长度"""
|
|
||||||
if concept_name not in self.history_cache:
|
|
||||||
return 0
|
|
||||||
return len(self.history_cache[concept_name])
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def detect(
|
|
||||||
self,
|
|
||||||
concept_name: str,
|
|
||||||
features: Dict[str, float] = None,
|
|
||||||
return_score: bool = True
|
|
||||||
) -> Tuple[bool, Optional[float]]:
|
|
||||||
"""
|
|
||||||
检测概念是否异动
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concept_name: 概念名称
|
|
||||||
features: 当前时刻的特征(如果提供,会先更新历史)
|
|
||||||
return_score: 是否返回异动分数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
is_anomaly: 是否异动
|
|
||||||
score: 异动分数(如果 return_score=True)
|
|
||||||
"""
|
|
||||||
# 更新历史
|
|
||||||
if features is not None:
|
|
||||||
self.update_history(concept_name, features)
|
|
||||||
|
|
||||||
# 检查历史数据是否足够
|
|
||||||
if concept_name not in self.history_cache:
|
|
||||||
return False, None
|
|
||||||
|
|
||||||
history = self.history_cache[concept_name]
|
|
||||||
if len(history) < self.seq_len:
|
|
||||||
return False, None
|
|
||||||
|
|
||||||
# 构建输入序列
|
|
||||||
sequence = np.array(list(history)) # (seq_len, n_features)
|
|
||||||
|
|
||||||
# 标准化
|
|
||||||
sequence = self.normalize(sequence)
|
|
||||||
|
|
||||||
# 转为 tensor
|
|
||||||
x = torch.FloatTensor(sequence).unsqueeze(0) # (1, seq_len, n_features)
|
|
||||||
x = x.to(self.device)
|
|
||||||
|
|
||||||
# 计算重构误差
|
|
||||||
error = self.model.compute_reconstruction_error(x, reduction='none')
|
|
||||||
|
|
||||||
# 取最后一个时刻的误差作为当前分数
|
|
||||||
score = error[0, -1].item()
|
|
||||||
|
|
||||||
# 判断是否异动
|
|
||||||
is_anomaly = score > self.threshold
|
|
||||||
|
|
||||||
if return_score:
|
|
||||||
return is_anomaly, score
|
|
||||||
else:
|
|
||||||
return is_anomaly, None
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def batch_detect(
|
|
||||||
self,
|
|
||||||
concept_features: Dict[str, Dict[str, float]]
|
|
||||||
) -> Dict[str, Tuple[bool, float]]:
|
|
||||||
"""
|
|
||||||
批量检测多个概念
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concept_features: {concept_name: {feature_name: value}}
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
results: {concept_name: (is_anomaly, score)}
|
|
||||||
"""
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
for concept_name, features in concept_features.items():
|
|
||||||
is_anomaly, score = self.detect(concept_name, features)
|
|
||||||
results[concept_name] = (is_anomaly, score)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def get_anomaly_type(
|
|
||||||
self,
|
|
||||||
concept_name: str,
|
|
||||||
features: Dict[str, float]
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
判断异动类型
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concept_name: 概念名称
|
|
||||||
features: 当前特征
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
anomaly_type: 'surge_up' / 'surge_down' / 'normal'
|
|
||||||
"""
|
|
||||||
is_anomaly, score = self.detect(concept_name, features)
|
|
||||||
|
|
||||||
if not is_anomaly:
|
|
||||||
return 'normal'
|
|
||||||
|
|
||||||
# 根据 alpha 判断涨跌
|
|
||||||
alpha = features.get('alpha', 0.0)
|
|
||||||
|
|
||||||
if alpha > 0:
|
|
||||||
return 'surge_up'
|
|
||||||
else:
|
|
||||||
return 'surge_down'
|
|
||||||
|
|
||||||
def get_top_anomalies(
|
|
||||||
self,
|
|
||||||
concept_features: Dict[str, Dict[str, float]],
|
|
||||||
top_k: int = 10
|
|
||||||
) -> List[Tuple[str, float, str]]:
|
|
||||||
"""
|
|
||||||
获取异动分数最高的 top_k 个概念
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concept_features: {concept_name: {feature_name: value}}
|
|
||||||
top_k: 返回数量
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
anomalies: [(concept_name, score, anomaly_type), ...]
|
|
||||||
"""
|
|
||||||
results = self.batch_detect(concept_features)
|
|
||||||
|
|
||||||
# 按分数排序
|
|
||||||
sorted_results = sorted(
|
|
||||||
[(name, is_anomaly, score) for name, (is_anomaly, score) in results.items() if score is not None],
|
|
||||||
key=lambda x: x[2],
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 取 top_k
|
|
||||||
top_anomalies = []
|
|
||||||
for name, is_anomaly, score in sorted_results[:top_k]:
|
|
||||||
if is_anomaly:
|
|
||||||
alpha = concept_features[name].get('alpha', 0.0)
|
|
||||||
anomaly_type = 'surge_up' if alpha > 0 else 'surge_down'
|
|
||||||
top_anomalies.append((name, score, anomaly_type))
|
|
||||||
|
|
||||||
return top_anomalies
|
|
||||||
|
|
||||||
def clear_history(self, concept_name: str = None):
|
|
||||||
"""
|
|
||||||
清除历史缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concept_name: 概念名称(如果为 None,清除所有)
|
|
||||||
"""
|
|
||||||
if concept_name is None:
|
|
||||||
self.history_cache.clear()
|
|
||||||
elif concept_name in self.history_cache:
|
|
||||||
del self.history_cache[concept_name]
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 集成到现有系统 ====================
|
|
||||||
|
|
||||||
class MLAnomalyService:
|
|
||||||
"""
|
|
||||||
ML 异动检测服务
|
|
||||||
|
|
||||||
用于替换或增强现有的 Alpha-based 检测
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
checkpoint_dir: str = 'ml/checkpoints',
|
|
||||||
fallback_to_alpha: bool = True
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
checkpoint_dir: 模型检查点目录
|
|
||||||
fallback_to_alpha: 当 ML 模型不可用时是否回退到 Alpha 方法
|
|
||||||
"""
|
|
||||||
self.fallback_to_alpha = fallback_to_alpha
|
|
||||||
self.ml_detector = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.ml_detector = ConceptAnomalyDetector(checkpoint_dir)
|
|
||||||
print("ML 异动检测服务初始化成功")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ML 模型加载失败: {e}")
|
|
||||||
if not fallback_to_alpha:
|
|
||||||
raise
|
|
||||||
print("将回退到 Alpha-based 检测")
|
|
||||||
|
|
||||||
def is_ml_available(self) -> bool:
|
|
||||||
"""检查 ML 模型是否可用"""
|
|
||||||
return self.ml_detector is not None
|
|
||||||
|
|
||||||
def detect_anomaly(
|
|
||||||
self,
|
|
||||||
concept_name: str,
|
|
||||||
features: Dict[str, float],
|
|
||||||
alpha_threshold: float = 2.0
|
|
||||||
) -> Tuple[bool, float, str]:
|
|
||||||
"""
|
|
||||||
检测异动
|
|
||||||
|
|
||||||
Args:
|
|
||||||
concept_name: 概念名称
|
|
||||||
features: 特征字典(需包含 alpha, amt_ratio 等)
|
|
||||||
alpha_threshold: Alpha Z-Score 阈值(用于回退)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
is_anomaly: 是否异动
|
|
||||||
score: 异动分数
|
|
||||||
method: 检测方法 ('ml' / 'alpha')
|
|
||||||
"""
|
|
||||||
# 优先使用 ML 检测
|
|
||||||
if self.ml_detector is not None:
|
|
||||||
history_len = self.ml_detector.get_history_length(concept_name)
|
|
||||||
|
|
||||||
# 历史数据足够时使用 ML
|
|
||||||
if history_len >= self.ml_detector.seq_len - 1:
|
|
||||||
is_anomaly, score = self.ml_detector.detect(concept_name, features)
|
|
||||||
if score is not None:
|
|
||||||
return is_anomaly, score, 'ml'
|
|
||||||
else:
|
|
||||||
# 更新历史但使用 Alpha 方法
|
|
||||||
self.ml_detector.update_history(concept_name, features)
|
|
||||||
|
|
||||||
# 回退到 Alpha 方法
|
|
||||||
if self.fallback_to_alpha:
|
|
||||||
alpha = features.get('alpha', 0.0)
|
|
||||||
alpha_zscore = features.get('alpha_zscore', 0.0)
|
|
||||||
|
|
||||||
is_anomaly = abs(alpha_zscore) > alpha_threshold
|
|
||||||
score = abs(alpha_zscore)
|
|
||||||
|
|
||||||
return is_anomaly, score, 'alpha'
|
|
||||||
|
|
||||||
return False, 0.0, 'none'
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 测试 ====================
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import random
|
|
||||||
|
|
||||||
print("测试 ConceptAnomalyDetector...")
|
|
||||||
|
|
||||||
# 检查模型是否存在
|
|
||||||
checkpoint_dir = Path('ml/checkpoints')
|
|
||||||
if not (checkpoint_dir / 'best_model.pt').exists():
|
|
||||||
print("模型文件不存在,跳过测试")
|
|
||||||
print("请先运行 train.py 训练模型")
|
|
||||||
exit(0)
|
|
||||||
|
|
||||||
# 初始化检测器
|
|
||||||
detector = ConceptAnomalyDetector('ml/checkpoints')
|
|
||||||
|
|
||||||
# 模拟数据
|
|
||||||
print("\n模拟实时检测...")
|
|
||||||
concept_name = "人工智能"
|
|
||||||
|
|
||||||
for i in range(40):
|
|
||||||
# 生成随机特征
|
|
||||||
features = {
|
|
||||||
'alpha': random.gauss(0, 1),
|
|
||||||
'alpha_delta': random.gauss(0, 0.5),
|
|
||||||
'amt_ratio': random.gauss(1, 0.3),
|
|
||||||
'amt_delta': random.gauss(0, 0.2),
|
|
||||||
'rank_pct': random.random(),
|
|
||||||
'limit_up_ratio': random.random() * 0.1,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 在第 35 分钟模拟异动
|
|
||||||
if i == 35:
|
|
||||||
features['alpha'] = 5.0
|
|
||||||
features['alpha_delta'] = 2.0
|
|
||||||
features['amt_ratio'] = 3.0
|
|
||||||
|
|
||||||
is_anomaly, score = detector.detect(concept_name, features)
|
|
||||||
|
|
||||||
history_len = detector.get_history_length(concept_name)
|
|
||||||
|
|
||||||
if score is not None:
|
|
||||||
status = "🔥 异动!" if is_anomaly else "正常"
|
|
||||||
print(f" t={i:02d} | 历史={history_len} | 分数={score:.4f} | {status}")
|
|
||||||
else:
|
|
||||||
print(f" t={i:02d} | 历史={history_len} | 数据不足")
|
|
||||||
|
|
||||||
print("\n测试完成!")
|
|
||||||
393
ml/model.py
393
ml/model.py
@@ -1,393 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
LSTM Autoencoder 模型定义
|
|
||||||
|
|
||||||
用于概念异动检测:
|
|
||||||
- 学习"正常"市场模式
|
|
||||||
- 重构误差大的时刻 = 异动
|
|
||||||
|
|
||||||
模型结构(简洁有效):
|
|
||||||
┌─────────────────────────────────────┐
|
|
||||||
│ 输入: (batch, seq_len, n_features) │
|
|
||||||
│ 过去30分钟的特征序列 │
|
|
||||||
├─────────────────────────────────────┤
|
|
||||||
│ LSTM Encoder │
|
|
||||||
│ - 双向 LSTM │
|
|
||||||
│ - 输出最后隐藏状态 │
|
|
||||||
├─────────────────────────────────────┤
|
|
||||||
│ Bottleneck (压缩层) │
|
|
||||||
│ 降维到 latent_dim(关键!) │
|
|
||||||
├─────────────────────────────────────┤
|
|
||||||
│ LSTM Decoder │
|
|
||||||
│ - 单向 LSTM │
|
|
||||||
│ - 重构序列 │
|
|
||||||
├─────────────────────────────────────┤
|
|
||||||
│ 输出: (batch, seq_len, n_features) │
|
|
||||||
│ 重构的特征序列 │
|
|
||||||
└─────────────────────────────────────┘
|
|
||||||
|
|
||||||
为什么用 LSTM 而不是 Transformer:
|
|
||||||
1. 参数更少,不容易过拟合
|
|
||||||
2. 对于 6 维特征足够用
|
|
||||||
3. 训练更稳定
|
|
||||||
4. 瓶颈约束更容易控制
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
class LSTMAutoencoder(nn.Module):
|
|
||||||
"""
|
|
||||||
LSTM Autoencoder for Anomaly Detection
|
|
||||||
|
|
||||||
设计原则:
|
|
||||||
- 足够简单,避免过拟合
|
|
||||||
- 瓶颈层严格限制,迫使模型只学习主要模式
|
|
||||||
- 异常难以通过狭窄瓶颈,重构误差大
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
n_features: int = 6,
|
|
||||||
hidden_dim: int = 32, # LSTM 隐藏维度(小!)
|
|
||||||
latent_dim: int = 4, # 瓶颈维度(非常小!关键参数)
|
|
||||||
num_layers: int = 1, # LSTM 层数
|
|
||||||
dropout: float = 0.2,
|
|
||||||
bidirectional: bool = True, # 双向编码器
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.n_features = n_features
|
|
||||||
self.hidden_dim = hidden_dim
|
|
||||||
self.latent_dim = latent_dim
|
|
||||||
self.num_layers = num_layers
|
|
||||||
self.bidirectional = bidirectional
|
|
||||||
self.num_directions = 2 if bidirectional else 1
|
|
||||||
|
|
||||||
# Encoder: 双向 LSTM
|
|
||||||
self.encoder = nn.LSTM(
|
|
||||||
input_size=n_features,
|
|
||||||
hidden_size=hidden_dim,
|
|
||||||
num_layers=num_layers,
|
|
||||||
batch_first=True,
|
|
||||||
dropout=dropout if num_layers > 1 else 0,
|
|
||||||
bidirectional=bidirectional
|
|
||||||
)
|
|
||||||
|
|
||||||
# Bottleneck: 压缩到极小的 latent space
|
|
||||||
encoder_output_dim = hidden_dim * self.num_directions
|
|
||||||
self.bottleneck_down = nn.Sequential(
|
|
||||||
nn.Linear(encoder_output_dim, latent_dim),
|
|
||||||
nn.Tanh(), # 限制范围,增加约束
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用 LeakyReLU 替代 ReLU
|
|
||||||
# 原因:Z-Score 数据范围是 [-5, +5],ReLU 会截断负值,丢失跌幅信息
|
|
||||||
# LeakyReLU 保留负值信号(乘以 0.1)
|
|
||||||
self.bottleneck_up = nn.Sequential(
|
|
||||||
nn.Linear(latent_dim, hidden_dim),
|
|
||||||
nn.LeakyReLU(negative_slope=0.1),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decoder: 单向 LSTM
|
|
||||||
self.decoder = nn.LSTM(
|
|
||||||
input_size=hidden_dim,
|
|
||||||
hidden_size=hidden_dim,
|
|
||||||
num_layers=num_layers,
|
|
||||||
batch_first=True,
|
|
||||||
dropout=dropout if num_layers > 1 else 0,
|
|
||||||
bidirectional=False # 解码器用单向
|
|
||||||
)
|
|
||||||
|
|
||||||
# 输出层
|
|
||||||
self.output_layer = nn.Linear(hidden_dim, n_features)
|
|
||||||
|
|
||||||
# Dropout
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
# 初始化
|
|
||||||
self._init_weights()
|
|
||||||
|
|
||||||
def _init_weights(self):
|
|
||||||
"""初始化权重"""
|
|
||||||
for name, param in self.named_parameters():
|
|
||||||
if 'weight_ih' in name:
|
|
||||||
nn.init.xavier_uniform_(param)
|
|
||||||
elif 'weight_hh' in name:
|
|
||||||
nn.init.orthogonal_(param)
|
|
||||||
elif 'bias' in name:
|
|
||||||
nn.init.zeros_(param)
|
|
||||||
|
|
||||||
def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
编码器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: (batch, seq_len, n_features)
|
|
||||||
Returns:
|
|
||||||
latent: (batch, seq_len, latent_dim) 每个时间步的压缩表示
|
|
||||||
encoder_outputs: (batch, seq_len, hidden_dim * num_directions)
|
|
||||||
"""
|
|
||||||
# LSTM 编码
|
|
||||||
encoder_outputs, (h_n, c_n) = self.encoder(x)
|
|
||||||
# encoder_outputs: (batch, seq_len, hidden_dim * num_directions)
|
|
||||||
|
|
||||||
encoder_outputs = self.dropout(encoder_outputs)
|
|
||||||
|
|
||||||
# 压缩到 latent space(对每个时间步)
|
|
||||||
latent = self.bottleneck_down(encoder_outputs)
|
|
||||||
# latent: (batch, seq_len, latent_dim)
|
|
||||||
|
|
||||||
return latent, encoder_outputs
|
|
||||||
|
|
||||||
def decode(self, latent: torch.Tensor, seq_len: int) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
解码器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
latent: (batch, seq_len, latent_dim)
|
|
||||||
seq_len: 序列长度
|
|
||||||
Returns:
|
|
||||||
output: (batch, seq_len, n_features)
|
|
||||||
"""
|
|
||||||
# 从 latent space 恢复
|
|
||||||
decoder_input = self.bottleneck_up(latent)
|
|
||||||
# decoder_input: (batch, seq_len, hidden_dim)
|
|
||||||
|
|
||||||
# LSTM 解码
|
|
||||||
decoder_outputs, _ = self.decoder(decoder_input)
|
|
||||||
# decoder_outputs: (batch, seq_len, hidden_dim)
|
|
||||||
|
|
||||||
decoder_outputs = self.dropout(decoder_outputs)
|
|
||||||
|
|
||||||
# 投影到原始特征空间
|
|
||||||
output = self.output_layer(decoder_outputs)
|
|
||||||
# output: (batch, seq_len, n_features)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
前向传播
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: (batch, seq_len, n_features)
|
|
||||||
Returns:
|
|
||||||
output: (batch, seq_len, n_features) 重构结果
|
|
||||||
latent: (batch, seq_len, latent_dim) 隐向量
|
|
||||||
"""
|
|
||||||
batch_size, seq_len, _ = x.shape
|
|
||||||
|
|
||||||
# 编码
|
|
||||||
latent, _ = self.encode(x)
|
|
||||||
|
|
||||||
# 解码
|
|
||||||
output = self.decode(latent, seq_len)
|
|
||||||
|
|
||||||
return output, latent
|
|
||||||
|
|
||||||
def compute_reconstruction_error(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
reduction: str = 'none'
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
计算重构误差
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: (batch, seq_len, n_features)
|
|
||||||
reduction: 'none' | 'mean' | 'sum'
|
|
||||||
Returns:
|
|
||||||
error: 重构误差
|
|
||||||
"""
|
|
||||||
output, _ = self.forward(x)
|
|
||||||
|
|
||||||
# MSE per feature per timestep
|
|
||||||
error = F.mse_loss(output, x, reduction='none')
|
|
||||||
|
|
||||||
if reduction == 'none':
|
|
||||||
# (batch, seq_len, n_features) -> (batch, seq_len)
|
|
||||||
return error.mean(dim=-1)
|
|
||||||
elif reduction == 'mean':
|
|
||||||
return error.mean()
|
|
||||||
elif reduction == 'sum':
|
|
||||||
return error.sum()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown reduction: {reduction}")
|
|
||||||
|
|
||||||
def detect_anomaly(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
threshold: float = None,
|
|
||||||
return_scores: bool = True
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
||||||
"""
|
|
||||||
检测异动
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: (batch, seq_len, n_features)
|
|
||||||
threshold: 异动阈值(如果为 None,只返回分数)
|
|
||||||
return_scores: 是否返回异动分数
|
|
||||||
Returns:
|
|
||||||
is_anomaly: (batch, seq_len) bool tensor (if threshold is not None)
|
|
||||||
scores: (batch, seq_len) 异动分数 (if return_scores)
|
|
||||||
"""
|
|
||||||
scores = self.compute_reconstruction_error(x, reduction='none')
|
|
||||||
|
|
||||||
is_anomaly = None
|
|
||||||
if threshold is not None:
|
|
||||||
is_anomaly = scores > threshold
|
|
||||||
|
|
||||||
if return_scores:
|
|
||||||
return is_anomaly, scores
|
|
||||||
else:
|
|
||||||
return is_anomaly, None
|
|
||||||
|
|
||||||
|
|
||||||
# 为了兼容性,创建别名
|
|
||||||
TransformerAutoencoder = LSTMAutoencoder
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 损失函数 ====================
|
|
||||||
|
|
||||||
class AnomalyDetectionLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
异动检测损失函数
|
|
||||||
|
|
||||||
简单的 MSE 重构损失
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
feature_weights: torch.Tensor = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.feature_weights = feature_weights
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
output: torch.Tensor,
|
|
||||||
target: torch.Tensor,
|
|
||||||
latent: torch.Tensor = None
|
|
||||||
) -> Tuple[torch.Tensor, dict]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
output: (batch, seq_len, n_features) 重构结果
|
|
||||||
target: (batch, seq_len, n_features) 原始输入
|
|
||||||
latent: (batch, seq_len, latent_dim) 隐向量(未使用)
|
|
||||||
Returns:
|
|
||||||
loss: 总损失
|
|
||||||
loss_dict: 各项损失详情
|
|
||||||
"""
|
|
||||||
# 重构损失 (MSE)
|
|
||||||
mse = F.mse_loss(output, target, reduction='none')
|
|
||||||
|
|
||||||
# 特征加权(可选)
|
|
||||||
if self.feature_weights is not None:
|
|
||||||
weights = self.feature_weights.to(mse.device)
|
|
||||||
mse = mse * weights
|
|
||||||
|
|
||||||
reconstruction_loss = mse.mean()
|
|
||||||
|
|
||||||
loss_dict = {
|
|
||||||
'total': reconstruction_loss.item(),
|
|
||||||
'reconstruction': reconstruction_loss.item(),
|
|
||||||
}
|
|
||||||
|
|
||||||
return reconstruction_loss, loss_dict
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 工具函数 ====================
|
|
||||||
|
|
||||||
def count_parameters(model: nn.Module) -> int:
|
|
||||||
"""统计模型参数量"""
|
|
||||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
|
|
||||||
|
|
||||||
def create_model(config: dict = None) -> LSTMAutoencoder:
|
|
||||||
"""
|
|
||||||
创建模型
|
|
||||||
|
|
||||||
默认使用小型 LSTM 配置,适合异动检测
|
|
||||||
"""
|
|
||||||
default_config = {
|
|
||||||
'n_features': 6,
|
|
||||||
'hidden_dim': 32, # 小!
|
|
||||||
'latent_dim': 4, # 非常小!关键
|
|
||||||
'num_layers': 1,
|
|
||||||
'dropout': 0.2,
|
|
||||||
'bidirectional': True,
|
|
||||||
}
|
|
||||||
|
|
||||||
if config:
|
|
||||||
# 兼容旧的 Transformer 配置键名
|
|
||||||
if 'd_model' in config:
|
|
||||||
config['hidden_dim'] = config.pop('d_model') // 2
|
|
||||||
if 'num_encoder_layers' in config:
|
|
||||||
config['num_layers'] = config.pop('num_encoder_layers')
|
|
||||||
if 'num_decoder_layers' in config:
|
|
||||||
config.pop('num_decoder_layers')
|
|
||||||
if 'nhead' in config:
|
|
||||||
config.pop('nhead')
|
|
||||||
if 'dim_feedforward' in config:
|
|
||||||
config.pop('dim_feedforward')
|
|
||||||
if 'max_seq_len' in config:
|
|
||||||
config.pop('max_seq_len')
|
|
||||||
if 'use_instance_norm' in config:
|
|
||||||
config.pop('use_instance_norm')
|
|
||||||
|
|
||||||
default_config.update(config)
|
|
||||||
|
|
||||||
model = LSTMAutoencoder(**default_config)
|
|
||||||
param_count = count_parameters(model)
|
|
||||||
print(f"模型参数量: {param_count:,}")
|
|
||||||
|
|
||||||
if param_count > 100000:
|
|
||||||
print(f"⚠️ 警告: 参数量较大({param_count:,}),可能过拟合")
|
|
||||||
else:
|
|
||||||
print(f"✓ 参数量适中(LSTM Autoencoder)")
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 测试模型
|
|
||||||
print("测试 LSTM Autoencoder...")
|
|
||||||
|
|
||||||
# 创建模型
|
|
||||||
model = create_model()
|
|
||||||
|
|
||||||
# 测试输入
|
|
||||||
batch_size = 32
|
|
||||||
seq_len = 30
|
|
||||||
n_features = 6
|
|
||||||
|
|
||||||
x = torch.randn(batch_size, seq_len, n_features)
|
|
||||||
|
|
||||||
# 前向传播
|
|
||||||
output, latent = model(x)
|
|
||||||
|
|
||||||
print(f"输入形状: {x.shape}")
|
|
||||||
print(f"输出形状: {output.shape}")
|
|
||||||
print(f"隐向量形状: {latent.shape}")
|
|
||||||
|
|
||||||
# 计算重构误差
|
|
||||||
error = model.compute_reconstruction_error(x)
|
|
||||||
print(f"重构误差形状: {error.shape}")
|
|
||||||
print(f"平均重构误差: {error.mean().item():.4f}")
|
|
||||||
|
|
||||||
# 测试异动检测
|
|
||||||
is_anomaly, scores = model.detect_anomaly(x, threshold=0.5)
|
|
||||||
print(f"异动检测结果形状: {is_anomaly.shape if is_anomaly is not None else 'None'}")
|
|
||||||
print(f"异动分数形状: {scores.shape}")
|
|
||||||
|
|
||||||
# 测试损失函数
|
|
||||||
criterion = AnomalyDetectionLoss()
|
|
||||||
loss, loss_dict = criterion(output, x, latent)
|
|
||||||
print(f"损失: {loss.item():.4f}")
|
|
||||||
|
|
||||||
print("\n测试通过!")
|
|
||||||
@@ -1,537 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
数据准备脚本 - 为 Transformer Autoencoder 准备训练数据
|
|
||||||
|
|
||||||
从 ClickHouse 提取历史分钟数据,计算以下特征:
|
|
||||||
1. alpha - 超额收益(概念涨幅 - 大盘涨幅)
|
|
||||||
2. alpha_delta - Alpha 变化率(5分钟)
|
|
||||||
3. amt_ratio - 成交额相对均值(当前/过去20分钟均值)
|
|
||||||
4. amt_delta - 成交额变化率
|
|
||||||
5. rank_pct - Alpha 排名百分位
|
|
||||||
6. limit_up_ratio - 涨停股占比
|
|
||||||
|
|
||||||
输出:按交易日存储的特征文件(parquet格式)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from datetime import datetime, timedelta, date
|
|
||||||
from sqlalchemy import create_engine, text
|
|
||||||
from elasticsearch import Elasticsearch
|
|
||||||
from clickhouse_driver import Client
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Dict, List, Set, Tuple
|
|
||||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
||||||
from multiprocessing import Manager
|
|
||||||
import multiprocessing
|
|
||||||
import warnings
|
|
||||||
warnings.filterwarnings('ignore')
|
|
||||||
|
|
||||||
# ==================== 配置 ====================
|
|
||||||
|
|
||||||
MYSQL_ENGINE = create_engine(
|
|
||||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
|
||||||
echo=False
|
|
||||||
)
|
|
||||||
|
|
||||||
ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
|
|
||||||
ES_INDEX = 'concept_library_v3'
|
|
||||||
|
|
||||||
CLICKHOUSE_CONFIG = {
|
|
||||||
'host': '127.0.0.1',
|
|
||||||
'port': 9000,
|
|
||||||
'user': 'default',
|
|
||||||
'password': 'Zzl33818!',
|
|
||||||
'database': 'stock'
|
|
||||||
}
|
|
||||||
|
|
||||||
# 输出目录
|
|
||||||
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data')
|
|
||||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
||||||
|
|
||||||
# 特征计算参数
|
|
||||||
FEATURE_CONFIG = {
|
|
||||||
'alpha_delta_window': 5, # Alpha变化窗口(分钟)
|
|
||||||
'amt_ma_window': 20, # 成交额均值窗口(分钟)
|
|
||||||
'limit_up_threshold': 9.8, # 涨停阈值(%)
|
|
||||||
'limit_down_threshold': -9.8, # 跌停阈值(%)
|
|
||||||
}
|
|
||||||
|
|
||||||
REFERENCE_INDEX = '000001.SH'
|
|
||||||
|
|
||||||
# ==================== 日志 ====================
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ==================== 工具函数 ====================
|
|
||||||
|
|
||||||
def get_ch_client():
|
|
||||||
return Client(**CLICKHOUSE_CONFIG)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_id(name: str) -> str:
|
|
||||||
return hashlib.md5(name.encode('utf-8')).hexdigest()[:16]
|
|
||||||
|
|
||||||
|
|
||||||
def code_to_ch_format(code: str) -> str:
|
|
||||||
if not code or len(code) != 6 or not code.isdigit():
|
|
||||||
return None
|
|
||||||
if code.startswith('6'):
|
|
||||||
return f"{code}.SH"
|
|
||||||
elif code.startswith('0') or code.startswith('3'):
|
|
||||||
return f"{code}.SZ"
|
|
||||||
else:
|
|
||||||
return f"{code}.BJ"
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 获取概念列表 ====================
|
|
||||||
|
|
||||||
def get_all_concepts() -> List[dict]:
|
|
||||||
"""从ES获取所有叶子概念"""
|
|
||||||
concepts = []
|
|
||||||
|
|
||||||
query = {
|
|
||||||
"query": {"match_all": {}},
|
|
||||||
"size": 100,
|
|
||||||
"_source": ["concept_id", "concept", "stocks"]
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
|
|
||||||
scroll_id = resp['_scroll_id']
|
|
||||||
hits = resp['hits']['hits']
|
|
||||||
|
|
||||||
while len(hits) > 0:
|
|
||||||
for hit in hits:
|
|
||||||
source = hit['_source']
|
|
||||||
stocks = []
|
|
||||||
if 'stocks' in source and isinstance(source['stocks'], list):
|
|
||||||
for stock in source['stocks']:
|
|
||||||
if isinstance(stock, dict) and 'code' in stock and stock['code']:
|
|
||||||
stocks.append(stock['code'])
|
|
||||||
|
|
||||||
if stocks:
|
|
||||||
concepts.append({
|
|
||||||
'concept_id': source.get('concept_id'),
|
|
||||||
'concept_name': source.get('concept'),
|
|
||||||
'stocks': stocks
|
|
||||||
})
|
|
||||||
|
|
||||||
resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
|
|
||||||
scroll_id = resp['_scroll_id']
|
|
||||||
hits = resp['hits']['hits']
|
|
||||||
|
|
||||||
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
|
|
||||||
print(f"获取到 {len(concepts)} 个概念")
|
|
||||||
return concepts
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 获取交易日列表 ====================
|
|
||||||
|
|
||||||
def get_trading_days(start_date: str, end_date: str) -> List[str]:
|
|
||||||
"""获取交易日列表"""
|
|
||||||
client = get_ch_client()
|
|
||||||
|
|
||||||
query = f"""
|
|
||||||
SELECT DISTINCT toDate(timestamp) as trade_date
|
|
||||||
FROM stock_minute
|
|
||||||
WHERE toDate(timestamp) >= '{start_date}'
|
|
||||||
AND toDate(timestamp) <= '{end_date}'
|
|
||||||
ORDER BY trade_date
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = client.execute(query)
|
|
||||||
days = [row[0].strftime('%Y-%m-%d') for row in result]
|
|
||||||
print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
|
|
||||||
return days
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 获取单日数据 ====================
|
|
||||||
|
|
||||||
def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
|
|
||||||
"""获取单日所有股票的分钟数据"""
|
|
||||||
client = get_ch_client()
|
|
||||||
|
|
||||||
# 转换代码格式
|
|
||||||
ch_codes = []
|
|
||||||
code_map = {}
|
|
||||||
for code in stock_codes:
|
|
||||||
ch_code = code_to_ch_format(code)
|
|
||||||
if ch_code:
|
|
||||||
ch_codes.append(ch_code)
|
|
||||||
code_map[ch_code] = code
|
|
||||||
|
|
||||||
if not ch_codes:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
ch_codes_str = "','".join(ch_codes)
|
|
||||||
|
|
||||||
query = f"""
|
|
||||||
SELECT
|
|
||||||
code,
|
|
||||||
timestamp,
|
|
||||||
close,
|
|
||||||
volume,
|
|
||||||
amt
|
|
||||||
FROM stock_minute
|
|
||||||
WHERE toDate(timestamp) = '{trade_date}'
|
|
||||||
AND code IN ('{ch_codes_str}')
|
|
||||||
ORDER BY code, timestamp
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = client.execute(query)
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
|
|
||||||
df['code'] = df['ch_code'].map(code_map)
|
|
||||||
df = df.dropna(subset=['code'])
|
|
||||||
|
|
||||||
return df[['code', 'timestamp', 'close', 'volume', 'amt']]
|
|
||||||
|
|
||||||
|
|
||||||
def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
|
|
||||||
"""获取单日指数分钟数据"""
|
|
||||||
client = get_ch_client()
|
|
||||||
|
|
||||||
query = f"""
|
|
||||||
SELECT
|
|
||||||
timestamp,
|
|
||||||
close,
|
|
||||||
volume,
|
|
||||||
amt
|
|
||||||
FROM index_minute
|
|
||||||
WHERE toDate(timestamp) = '{trade_date}'
|
|
||||||
AND code = '{index_code}'
|
|
||||||
ORDER BY timestamp
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = client.execute(query)
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
df = pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt'])
|
|
||||||
return df
|
|
||||||
|
|
||||||
|
|
||||||
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
|
||||||
"""获取昨收价(上一交易日的收盘价 F007N)"""
|
|
||||||
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
|
|
||||||
if not valid_codes:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
codes_str = "','".join(valid_codes)
|
|
||||||
|
|
||||||
# 注意:F007N 是"最近成交价"即当日收盘价,F002N 是"昨日收盘价"
|
|
||||||
# 我们需要查上一交易日的 F007N(那天的收盘价)作为今天的昨收
|
|
||||||
query = f"""
|
|
||||||
SELECT SECCODE, F007N
|
|
||||||
FROM ea_trade
|
|
||||||
WHERE SECCODE IN ('{codes_str}')
|
|
||||||
AND TRADEDATE = (
|
|
||||||
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
|
|
||||||
)
|
|
||||||
AND F007N IS NOT NULL AND F007N > 0
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
with MYSQL_ENGINE.connect() as conn:
|
|
||||||
result = conn.execute(text(query))
|
|
||||||
return {row[0]: float(row[1]) for row in result if row[1]}
|
|
||||||
except Exception as e:
|
|
||||||
print(f"获取昨收价失败: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float:
|
|
||||||
"""获取指数昨收价"""
|
|
||||||
code_no_suffix = index_code.split('.')[0]
|
|
||||||
|
|
||||||
try:
|
|
||||||
with MYSQL_ENGINE.connect() as conn:
|
|
||||||
result = conn.execute(text("""
|
|
||||||
SELECT F006N FROM ea_exchangetrade
|
|
||||||
WHERE INDEXCODE = :code AND TRADEDATE < :today
|
|
||||||
ORDER BY TRADEDATE DESC LIMIT 1
|
|
||||||
"""), {'code': code_no_suffix, 'today': trade_date}).fetchone()
|
|
||||||
|
|
||||||
if result and result[0]:
|
|
||||||
return float(result[0])
|
|
||||||
except Exception as e:
|
|
||||||
print(f"获取指数昨收失败: {e}")
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 计算特征 ====================
|
|
||||||
|
|
||||||
def compute_daily_features(
|
|
||||||
trade_date: str,
|
|
||||||
concepts: List[dict],
|
|
||||||
all_stocks: List[str]
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
"""
|
|
||||||
计算单日所有概念的特征
|
|
||||||
|
|
||||||
返回 DataFrame:
|
|
||||||
- index: (timestamp, concept_id)
|
|
||||||
- columns: alpha, alpha_delta, amt_ratio, amt_delta, rank_pct, limit_up_ratio
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 1. 获取数据
|
|
||||||
stock_df = get_daily_stock_data(trade_date, all_stocks)
|
|
||||||
if stock_df.empty:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
index_df = get_daily_index_data(trade_date)
|
|
||||||
if index_df.empty:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# 2. 获取昨收价
|
|
||||||
prev_close = get_prev_close(all_stocks, trade_date)
|
|
||||||
index_prev_close = get_index_prev_close(trade_date)
|
|
||||||
|
|
||||||
if not prev_close or not index_prev_close:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# 3. 计算股票涨跌幅和成交额
|
|
||||||
stock_df['prev_close'] = stock_df['code'].map(prev_close)
|
|
||||||
stock_df = stock_df.dropna(subset=['prev_close'])
|
|
||||||
stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
|
|
||||||
|
|
||||||
# 4. 计算指数涨跌幅
|
|
||||||
index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
|
|
||||||
index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))
|
|
||||||
|
|
||||||
# 5. 获取所有时间点
|
|
||||||
timestamps = sorted(stock_df['timestamp'].unique())
|
|
||||||
|
|
||||||
# 6. 按时间点计算概念特征
|
|
||||||
results = []
|
|
||||||
|
|
||||||
# 概念到股票的映射
|
|
||||||
concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
|
|
||||||
concept_names = {c['concept_id']: c['concept_name'] for c in concepts}
|
|
||||||
|
|
||||||
# 历史数据缓存(用于计算变化率)
|
|
||||||
concept_history = {cid: {'alpha': [], 'amt': []} for cid in concept_stocks}
|
|
||||||
|
|
||||||
for ts in timestamps:
|
|
||||||
ts_stock_data = stock_df[stock_df['timestamp'] == ts]
|
|
||||||
index_change = index_change_map.get(ts, 0)
|
|
||||||
|
|
||||||
# 股票涨跌幅和成交额字典
|
|
||||||
stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
|
|
||||||
stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))
|
|
||||||
|
|
||||||
concept_features = []
|
|
||||||
|
|
||||||
for concept_id, stocks in concept_stocks.items():
|
|
||||||
# 该概念的股票数据
|
|
||||||
concept_changes = [stock_change[s] for s in stocks if s in stock_change]
|
|
||||||
concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]
|
|
||||||
|
|
||||||
if not concept_changes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 基础统计
|
|
||||||
avg_change = np.mean(concept_changes)
|
|
||||||
total_amt = sum(concept_amts)
|
|
||||||
|
|
||||||
# Alpha = 概念涨幅 - 指数涨幅
|
|
||||||
alpha = avg_change - index_change
|
|
||||||
|
|
||||||
# 涨停/跌停股占比
|
|
||||||
limit_up_count = sum(1 for c in concept_changes if c >= FEATURE_CONFIG['limit_up_threshold'])
|
|
||||||
limit_down_count = sum(1 for c in concept_changes if c <= FEATURE_CONFIG['limit_down_threshold'])
|
|
||||||
limit_up_ratio = limit_up_count / len(concept_changes)
|
|
||||||
limit_down_ratio = limit_down_count / len(concept_changes)
|
|
||||||
|
|
||||||
# 更新历史
|
|
||||||
history = concept_history[concept_id]
|
|
||||||
history['alpha'].append(alpha)
|
|
||||||
history['amt'].append(total_amt)
|
|
||||||
|
|
||||||
# 计算变化率
|
|
||||||
alpha_delta = 0
|
|
||||||
if len(history['alpha']) > FEATURE_CONFIG['alpha_delta_window']:
|
|
||||||
alpha_delta = alpha - history['alpha'][-FEATURE_CONFIG['alpha_delta_window']-1]
|
|
||||||
|
|
||||||
# 成交额相对均值
|
|
||||||
amt_ratio = 1.0
|
|
||||||
amt_delta = 0
|
|
||||||
if len(history['amt']) > FEATURE_CONFIG['amt_ma_window']:
|
|
||||||
amt_ma = np.mean(history['amt'][-FEATURE_CONFIG['amt_ma_window']-1:-1])
|
|
||||||
if amt_ma > 0:
|
|
||||||
amt_ratio = total_amt / amt_ma
|
|
||||||
amt_delta = total_amt - history['amt'][-2] if len(history['amt']) > 1 else 0
|
|
||||||
|
|
||||||
concept_features.append({
|
|
||||||
'concept_id': concept_id,
|
|
||||||
'alpha': alpha,
|
|
||||||
'alpha_delta': alpha_delta,
|
|
||||||
'amt_ratio': amt_ratio,
|
|
||||||
'amt_delta': amt_delta,
|
|
||||||
'limit_up_ratio': limit_up_ratio,
|
|
||||||
'limit_down_ratio': limit_down_ratio,
|
|
||||||
'total_amt': total_amt,
|
|
||||||
'stock_count': len(concept_changes),
|
|
||||||
})
|
|
||||||
|
|
||||||
if not concept_features:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 计算排名百分位
|
|
||||||
concept_df = pd.DataFrame(concept_features)
|
|
||||||
concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)
|
|
||||||
|
|
||||||
# 添加时间戳
|
|
||||||
concept_df['timestamp'] = ts
|
|
||||||
results.append(concept_df)
|
|
||||||
|
|
||||||
if not results:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# 合并所有时间点
|
|
||||||
final_df = pd.concat(results, ignore_index=True)
|
|
||||||
|
|
||||||
# 标准化成交额变化率
|
|
||||||
if 'amt_delta' in final_df.columns:
|
|
||||||
amt_delta_std = final_df['amt_delta'].std()
|
|
||||||
if amt_delta_std > 0:
|
|
||||||
final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std
|
|
||||||
|
|
||||||
return final_df
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 主流程 ====================
|
|
||||||
|
|
||||||
def process_single_day(args) -> Tuple[str, bool]:
|
|
||||||
"""
|
|
||||||
处理单个交易日(多进程版本)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: (trade_date, concepts, all_stocks) 元组
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(trade_date, success) 元组
|
|
||||||
"""
|
|
||||||
trade_date, concepts, all_stocks = args
|
|
||||||
output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
|
|
||||||
|
|
||||||
# 检查是否已处理
|
|
||||||
if os.path.exists(output_file):
|
|
||||||
print(f"[{trade_date}] 已存在,跳过")
|
|
||||||
return (trade_date, True)
|
|
||||||
|
|
||||||
print(f"[{trade_date}] 开始处理...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
df = compute_daily_features(trade_date, concepts, all_stocks)
|
|
||||||
|
|
||||||
if df.empty:
|
|
||||||
print(f"[{trade_date}] 无数据")
|
|
||||||
return (trade_date, False)
|
|
||||||
|
|
||||||
# 保存
|
|
||||||
df.to_parquet(output_file, index=False)
|
|
||||||
print(f"[{trade_date}] 保存完成")
|
|
||||||
return (trade_date, True)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[{trade_date}] 处理失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return (trade_date, False)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
import argparse
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='准备训练数据')
|
|
||||||
parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
|
|
||||||
parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
|
|
||||||
parser.add_argument('--workers', type=int, default=18, help='并行进程数(默认18)')
|
|
||||||
parser.add_argument('--force', action='store_true', help='强制重新处理已存在的文件')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
end_date = args.end or datetime.now().strftime('%Y-%m-%d')
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("数据准备 - Transformer Autoencoder 训练数据")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"日期范围: {args.start} ~ {end_date}")
|
|
||||||
print(f"并行进程数: {args.workers}")
|
|
||||||
|
|
||||||
# 1. 获取概念列表
|
|
||||||
concepts = get_all_concepts()
|
|
||||||
|
|
||||||
# 收集所有股票
|
|
||||||
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
|
||||||
print(f"股票总数: {len(all_stocks)}")
|
|
||||||
|
|
||||||
# 2. 获取交易日列表
|
|
||||||
trading_days = get_trading_days(args.start, end_date)
|
|
||||||
|
|
||||||
if not trading_days:
|
|
||||||
print("无交易日数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 如果强制模式,删除已有文件
|
|
||||||
if args.force:
|
|
||||||
for trade_date in trading_days:
|
|
||||||
output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
|
|
||||||
if os.path.exists(output_file):
|
|
||||||
os.remove(output_file)
|
|
||||||
print(f"删除已有文件: {output_file}")
|
|
||||||
|
|
||||||
# 3. 准备任务参数
|
|
||||||
tasks = [(trade_date, concepts, all_stocks) for trade_date in trading_days]
|
|
||||||
|
|
||||||
print(f"\n开始处理 {len(trading_days)} 个交易日({args.workers} 进程并行)...")
|
|
||||||
|
|
||||||
# 4. 多进程处理
|
|
||||||
success_count = 0
|
|
||||||
failed_dates = []
|
|
||||||
|
|
||||||
with ProcessPoolExecutor(max_workers=args.workers) as executor:
|
|
||||||
# 提交所有任务
|
|
||||||
futures = {executor.submit(process_single_day, task): task[0] for task in tasks}
|
|
||||||
|
|
||||||
# 使用 tqdm 显示进度
|
|
||||||
with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
|
|
||||||
for future in as_completed(futures):
|
|
||||||
trade_date = futures[future]
|
|
||||||
try:
|
|
||||||
result_date, success = future.result()
|
|
||||||
if success:
|
|
||||||
success_count += 1
|
|
||||||
else:
|
|
||||||
failed_dates.append(result_date)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n[{trade_date}] 进程异常: {e}")
|
|
||||||
failed_dates.append(trade_date)
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
|
|
||||||
if failed_dates:
|
|
||||||
print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
|
|
||||||
print(f"数据保存在: {OUTPUT_DIR}")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,715 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
数据准备 V2 - 基于时间片对齐的特征计算(修复版)
|
|
||||||
|
|
||||||
核心改进:
|
|
||||||
1. 时间片对齐:9:35 和历史的 9:35 比,而不是和前30分钟比
|
|
||||||
2. Z-Score 特征:相对于同时间片历史分布的偏离程度
|
|
||||||
3. 滚动窗口基线:每个日期使用它之前 N 天的数据作为基线(不是固定的最后 N 天!)
|
|
||||||
4. 基于 Z-Score 的动量:消除一天内波动率异构性
|
|
||||||
|
|
||||||
修复:
|
|
||||||
- 滚动窗口基线:避免未来数据泄露
|
|
||||||
- Z-Score 动量:消除早盘/尾盘波动率差异
|
|
||||||
- 进程级数据库单例:避免连接池爆炸
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from sqlalchemy import create_engine, text
|
|
||||||
from elasticsearch import Elasticsearch
|
|
||||||
from clickhouse_driver import Client
|
|
||||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
||||||
from typing import Dict, List, Tuple, Optional
|
|
||||||
from tqdm import tqdm
|
|
||||||
from collections import defaultdict
|
|
||||||
import warnings
|
|
||||||
import pickle
|
|
||||||
|
|
||||||
warnings.filterwarnings('ignore')
|
|
||||||
|
|
||||||
# ==================== 配置 ====================
|
|
||||||
|
|
||||||
MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock"
|
|
||||||
ES_HOST = 'http://127.0.0.1:9200'
|
|
||||||
ES_INDEX = 'concept_library_v3'
|
|
||||||
|
|
||||||
CLICKHOUSE_CONFIG = {
|
|
||||||
'host': '127.0.0.1',
|
|
||||||
'port': 9000,
|
|
||||||
'user': 'default',
|
|
||||||
'password': 'Zzl33818!',
|
|
||||||
'database': 'stock'
|
|
||||||
}
|
|
||||||
|
|
||||||
REFERENCE_INDEX = '000001.SH'
|
|
||||||
|
|
||||||
# 输出目录
|
|
||||||
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data_v2')
|
|
||||||
BASELINE_DIR = os.path.join(OUTPUT_DIR, 'baselines')
|
|
||||||
RAW_CACHE_DIR = os.path.join(OUTPUT_DIR, 'raw_cache')
|
|
||||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
||||||
os.makedirs(BASELINE_DIR, exist_ok=True)
|
|
||||||
os.makedirs(RAW_CACHE_DIR, exist_ok=True)
|
|
||||||
|
|
||||||
# 特征配置
|
|
||||||
CONFIG = {
|
|
||||||
'baseline_days': 20, # 滚动窗口大小
|
|
||||||
'min_baseline_samples': 10, # 最少需要10个样本才算有效基线
|
|
||||||
'limit_up_threshold': 9.8,
|
|
||||||
'limit_down_threshold': -9.8,
|
|
||||||
'zscore_clip': 5.0,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 特征列表
|
|
||||||
FEATURES_V2 = [
|
|
||||||
'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore',
|
|
||||||
'momentum_3m', 'momentum_5m', 'limit_up_ratio',
|
|
||||||
]
|
|
||||||
|
|
||||||
# ==================== 进程级单例(避免连接池爆炸)====================
|
|
||||||
|
|
||||||
# 进程级全局变量
|
|
||||||
_process_mysql_engine = None
|
|
||||||
_process_es_client = None
|
|
||||||
_process_ch_client = None
|
|
||||||
|
|
||||||
|
|
||||||
def init_process_connections():
|
|
||||||
"""进程初始化时调用,创建连接(单例)"""
|
|
||||||
global _process_mysql_engine, _process_es_client, _process_ch_client
|
|
||||||
_process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5)
|
|
||||||
_process_es_client = Elasticsearch([ES_HOST])
|
|
||||||
_process_ch_client = Client(**CLICKHOUSE_CONFIG)
|
|
||||||
|
|
||||||
|
|
||||||
def get_mysql_engine():
|
|
||||||
"""获取进程级 MySQL Engine(单例)"""
|
|
||||||
global _process_mysql_engine
|
|
||||||
if _process_mysql_engine is None:
|
|
||||||
_process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5)
|
|
||||||
return _process_mysql_engine
|
|
||||||
|
|
||||||
|
|
||||||
def get_es_client():
|
|
||||||
"""获取进程级 ES 客户端(单例)"""
|
|
||||||
global _process_es_client
|
|
||||||
if _process_es_client is None:
|
|
||||||
_process_es_client = Elasticsearch([ES_HOST])
|
|
||||||
return _process_es_client
|
|
||||||
|
|
||||||
|
|
||||||
def get_ch_client():
|
|
||||||
"""获取进程级 ClickHouse 客户端(单例)"""
|
|
||||||
global _process_ch_client
|
|
||||||
if _process_ch_client is None:
|
|
||||||
_process_ch_client = Client(**CLICKHOUSE_CONFIG)
|
|
||||||
return _process_ch_client
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 工具函数 ====================
|
|
||||||
|
|
||||||
def code_to_ch_format(code: str) -> str:
|
|
||||||
if not code or len(code) != 6 or not code.isdigit():
|
|
||||||
return None
|
|
||||||
if code.startswith('6'):
|
|
||||||
return f"{code}.SH"
|
|
||||||
elif code.startswith('0') or code.startswith('3'):
|
|
||||||
return f"{code}.SZ"
|
|
||||||
else:
|
|
||||||
return f"{code}.BJ"
|
|
||||||
|
|
||||||
|
|
||||||
def time_to_slot(ts) -> str:
|
|
||||||
"""将时间戳转换为时间片(HH:MM格式)"""
|
|
||||||
if isinstance(ts, str):
|
|
||||||
return ts
|
|
||||||
return ts.strftime('%H:%M')
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 获取概念列表 ====================
|
|
||||||
|
|
||||||
def get_all_concepts() -> List[dict]:
|
|
||||||
"""从ES获取所有叶子概念"""
|
|
||||||
es_client = get_es_client()
|
|
||||||
concepts = []
|
|
||||||
|
|
||||||
query = {
|
|
||||||
"query": {"match_all": {}},
|
|
||||||
"size": 100,
|
|
||||||
"_source": ["concept_id", "concept", "stocks"]
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = es_client.search(index=ES_INDEX, body=query, scroll='2m')
|
|
||||||
scroll_id = resp['_scroll_id']
|
|
||||||
hits = resp['hits']['hits']
|
|
||||||
|
|
||||||
while len(hits) > 0:
|
|
||||||
for hit in hits:
|
|
||||||
source = hit['_source']
|
|
||||||
stocks = []
|
|
||||||
if 'stocks' in source and isinstance(source['stocks'], list):
|
|
||||||
for stock in source['stocks']:
|
|
||||||
if isinstance(stock, dict) and 'code' in stock and stock['code']:
|
|
||||||
stocks.append(stock['code'])
|
|
||||||
|
|
||||||
if stocks:
|
|
||||||
concepts.append({
|
|
||||||
'concept_id': source.get('concept_id'),
|
|
||||||
'concept_name': source.get('concept'),
|
|
||||||
'stocks': stocks
|
|
||||||
})
|
|
||||||
|
|
||||||
resp = es_client.scroll(scroll_id=scroll_id, scroll='2m')
|
|
||||||
scroll_id = resp['_scroll_id']
|
|
||||||
hits = resp['hits']['hits']
|
|
||||||
|
|
||||||
es_client.clear_scroll(scroll_id=scroll_id)
|
|
||||||
print(f"获取到 {len(concepts)} 个概念")
|
|
||||||
return concepts
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 获取交易日列表 ====================
|
|
||||||
|
|
||||||
def get_trading_days(start_date: str, end_date: str) -> List[str]:
|
|
||||||
"""获取交易日列表"""
|
|
||||||
client = get_ch_client()
|
|
||||||
|
|
||||||
query = f"""
|
|
||||||
SELECT DISTINCT toDate(timestamp) as trade_date
|
|
||||||
FROM stock_minute
|
|
||||||
WHERE toDate(timestamp) >= '{start_date}'
|
|
||||||
AND toDate(timestamp) <= '{end_date}'
|
|
||||||
ORDER BY trade_date
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = client.execute(query)
|
|
||||||
days = [row[0].strftime('%Y-%m-%d') for row in result]
|
|
||||||
if days:
|
|
||||||
print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
|
|
||||||
return days
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 获取昨收价 ====================
|
|
||||||
|
|
||||||
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
|
||||||
"""获取昨收价(上一交易日的收盘价 F007N)"""
|
|
||||||
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
|
|
||||||
if not valid_codes:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
codes_str = "','".join(valid_codes)
|
|
||||||
query = f"""
|
|
||||||
SELECT SECCODE, F007N
|
|
||||||
FROM ea_trade
|
|
||||||
WHERE SECCODE IN ('{codes_str}')
|
|
||||||
AND TRADEDATE = (
|
|
||||||
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
|
|
||||||
)
|
|
||||||
AND F007N IS NOT NULL AND F007N > 0
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
engine = get_mysql_engine()
|
|
||||||
with engine.connect() as conn:
|
|
||||||
result = conn.execute(text(query))
|
|
||||||
return {row[0]: float(row[1]) for row in result if row[1]}
|
|
||||||
except Exception as e:
|
|
||||||
print(f"获取昨收价失败: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float:
|
|
||||||
"""获取指数昨收价"""
|
|
||||||
code_no_suffix = index_code.split('.')[0]
|
|
||||||
|
|
||||||
try:
|
|
||||||
engine = get_mysql_engine()
|
|
||||||
with engine.connect() as conn:
|
|
||||||
result = conn.execute(text("""
|
|
||||||
SELECT F006N FROM ea_exchangetrade
|
|
||||||
WHERE INDEXCODE = :code AND TRADEDATE < :today
|
|
||||||
ORDER BY TRADEDATE DESC LIMIT 1
|
|
||||||
"""), {'code': code_no_suffix, 'today': trade_date}).fetchone()
|
|
||||||
|
|
||||||
if result and result[0]:
|
|
||||||
return float(result[0])
|
|
||||||
except Exception as e:
|
|
||||||
print(f"获取指数昨收失败: {e}")
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 获取分钟数据 ====================
|
|
||||||
|
|
||||||
def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
|
|
||||||
"""获取单日所有股票的分钟数据"""
|
|
||||||
client = get_ch_client()
|
|
||||||
|
|
||||||
ch_codes = []
|
|
||||||
code_map = {}
|
|
||||||
for code in stock_codes:
|
|
||||||
ch_code = code_to_ch_format(code)
|
|
||||||
if ch_code:
|
|
||||||
ch_codes.append(ch_code)
|
|
||||||
code_map[ch_code] = code
|
|
||||||
|
|
||||||
if not ch_codes:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
ch_codes_str = "','".join(ch_codes)
|
|
||||||
|
|
||||||
query = f"""
|
|
||||||
SELECT code, timestamp, close, volume, amt
|
|
||||||
FROM stock_minute
|
|
||||||
WHERE toDate(timestamp) = '{trade_date}'
|
|
||||||
AND code IN ('{ch_codes_str}')
|
|
||||||
ORDER BY code, timestamp
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = client.execute(query)
|
|
||||||
if not result:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
|
|
||||||
df['code'] = df['ch_code'].map(code_map)
|
|
||||||
df = df.dropna(subset=['code'])
|
|
||||||
|
|
||||||
return df[['code', 'timestamp', 'close', 'volume', 'amt']]
|
|
||||||
|
|
||||||
|
|
||||||
def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
|
|
||||||
"""获取单日指数分钟数据"""
|
|
||||||
client = get_ch_client()
|
|
||||||
|
|
||||||
query = f"""
|
|
||||||
SELECT timestamp, close, volume, amt
|
|
||||||
FROM index_minute
|
|
||||||
WHERE toDate(timestamp) = '{trade_date}'
|
|
||||||
AND code = '{index_code}'
|
|
||||||
ORDER BY timestamp
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = client.execute(query)
|
|
||||||
if not result:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
df = pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt'])
|
|
||||||
return df
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 计算原始概念特征(单日)====================
|
|
||||||
|
|
||||||
def compute_raw_concept_features(
|
|
||||||
trade_date: str,
|
|
||||||
concepts: List[dict],
|
|
||||||
all_stocks: List[str]
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
"""计算单日概念的原始特征(alpha, amt, rank_pct, limit_up_ratio)"""
|
|
||||||
# 检查缓存
|
|
||||||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
|
||||||
if os.path.exists(cache_file):
|
|
||||||
return pd.read_parquet(cache_file)
|
|
||||||
|
|
||||||
# 获取数据
|
|
||||||
stock_df = get_daily_stock_data(trade_date, all_stocks)
|
|
||||||
if stock_df.empty:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
index_df = get_daily_index_data(trade_date)
|
|
||||||
if index_df.empty:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# 获取昨收价
|
|
||||||
prev_close = get_prev_close(all_stocks, trade_date)
|
|
||||||
index_prev_close = get_index_prev_close(trade_date)
|
|
||||||
|
|
||||||
if not prev_close or not index_prev_close:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# 计算涨跌幅
|
|
||||||
stock_df['prev_close'] = stock_df['code'].map(prev_close)
|
|
||||||
stock_df = stock_df.dropna(subset=['prev_close'])
|
|
||||||
stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
|
|
||||||
|
|
||||||
index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
|
|
||||||
index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))
|
|
||||||
|
|
||||||
# 获取所有时间点
|
|
||||||
timestamps = sorted(stock_df['timestamp'].unique())
|
|
||||||
|
|
||||||
# 概念到股票的映射
|
|
||||||
concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for ts in timestamps:
|
|
||||||
ts_stock_data = stock_df[stock_df['timestamp'] == ts]
|
|
||||||
index_change = index_change_map.get(ts, 0)
|
|
||||||
|
|
||||||
stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
|
|
||||||
stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))
|
|
||||||
|
|
||||||
concept_features = []
|
|
||||||
|
|
||||||
for concept_id, stocks in concept_stocks.items():
|
|
||||||
concept_changes = [stock_change[s] for s in stocks if s in stock_change]
|
|
||||||
concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]
|
|
||||||
|
|
||||||
if not concept_changes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
avg_change = np.mean(concept_changes)
|
|
||||||
total_amt = sum(concept_amts)
|
|
||||||
alpha = avg_change - index_change
|
|
||||||
|
|
||||||
limit_up_count = sum(1 for c in concept_changes if c >= CONFIG['limit_up_threshold'])
|
|
||||||
limit_up_ratio = limit_up_count / len(concept_changes)
|
|
||||||
|
|
||||||
concept_features.append({
|
|
||||||
'concept_id': concept_id,
|
|
||||||
'alpha': alpha,
|
|
||||||
'total_amt': total_amt,
|
|
||||||
'limit_up_ratio': limit_up_ratio,
|
|
||||||
'stock_count': len(concept_changes),
|
|
||||||
})
|
|
||||||
|
|
||||||
if not concept_features:
|
|
||||||
continue
|
|
||||||
|
|
||||||
concept_df = pd.DataFrame(concept_features)
|
|
||||||
concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)
|
|
||||||
concept_df['timestamp'] = ts
|
|
||||||
concept_df['time_slot'] = time_to_slot(ts)
|
|
||||||
concept_df['trade_date'] = trade_date
|
|
||||||
|
|
||||||
results.append(concept_df)
|
|
||||||
|
|
||||||
if not results:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
result_df = pd.concat(results, ignore_index=True)
|
|
||||||
|
|
||||||
# 保存缓存
|
|
||||||
result_df.to_parquet(cache_file, index=False)
|
|
||||||
|
|
||||||
return result_df
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 滚动窗口基线计算 ====================
|
|
||||||
|
|
||||||
def compute_rolling_baseline(
|
|
||||||
historical_data: pd.DataFrame,
|
|
||||||
concept_id: str
|
|
||||||
) -> Dict[str, Dict]:
|
|
||||||
"""
|
|
||||||
计算单个概念的滚动基线
|
|
||||||
|
|
||||||
返回: {time_slot: {alpha_mean, alpha_std, amt_mean, amt_std, rank_mean, rank_std, sample_count}}
|
|
||||||
"""
|
|
||||||
if historical_data.empty:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
concept_data = historical_data[historical_data['concept_id'] == concept_id]
|
|
||||||
if concept_data.empty:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
baseline_dict = {}
|
|
||||||
|
|
||||||
for time_slot, group in concept_data.groupby('time_slot'):
|
|
||||||
if len(group) < CONFIG['min_baseline_samples']:
|
|
||||||
continue
|
|
||||||
|
|
||||||
alpha_std = group['alpha'].std()
|
|
||||||
amt_std = group['total_amt'].std()
|
|
||||||
rank_std = group['rank_pct'].std()
|
|
||||||
|
|
||||||
baseline_dict[time_slot] = {
|
|
||||||
'alpha_mean': group['alpha'].mean(),
|
|
||||||
'alpha_std': max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1),
|
|
||||||
'amt_mean': group['total_amt'].mean(),
|
|
||||||
'amt_std': max(amt_std if pd.notna(amt_std) else group['total_amt'].mean() * 0.5, 1.0),
|
|
||||||
'rank_mean': group['rank_pct'].mean(),
|
|
||||||
'rank_std': max(rank_std if pd.notna(rank_std) else 0.2, 0.05),
|
|
||||||
'sample_count': len(group),
|
|
||||||
}
|
|
||||||
|
|
||||||
return baseline_dict
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 计算单日 Z-Score 特征(带滚动基线)====================
|
|
||||||
|
|
||||||
def compute_zscore_features_rolling(
|
|
||||||
trade_date: str,
|
|
||||||
concepts: List[dict],
|
|
||||||
all_stocks: List[str],
|
|
||||||
historical_raw_data: pd.DataFrame # 该日期之前 N 天的原始数据
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
"""
|
|
||||||
计算单日的 Z-Score 特征(使用滚动窗口基线)
|
|
||||||
|
|
||||||
关键改进:
|
|
||||||
1. 基线只使用 trade_date 之前的数据(无未来泄露)
|
|
||||||
2. 动量基于 Z-Score 计算(消除波动率异构性)
|
|
||||||
"""
|
|
||||||
# 计算当日原始特征
|
|
||||||
raw_df = compute_raw_concept_features(trade_date, concepts, all_stocks)
|
|
||||||
|
|
||||||
if raw_df.empty:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
zscore_records = []
|
|
||||||
|
|
||||||
for concept_id, group in raw_df.groupby('concept_id'):
|
|
||||||
# 计算该概念的滚动基线(只用历史数据)
|
|
||||||
baseline_dict = compute_rolling_baseline(historical_raw_data, concept_id)
|
|
||||||
|
|
||||||
if not baseline_dict:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 按时间排序
|
|
||||||
group = group.sort_values('timestamp').reset_index(drop=True)
|
|
||||||
|
|
||||||
# Z-Score 历史(用于计算基于 Z-Score 的动量)
|
|
||||||
zscore_history = []
|
|
||||||
|
|
||||||
for idx, row in group.iterrows():
|
|
||||||
time_slot = row['time_slot']
|
|
||||||
|
|
||||||
if time_slot not in baseline_dict:
|
|
||||||
continue
|
|
||||||
|
|
||||||
bl = baseline_dict[time_slot]
|
|
||||||
|
|
||||||
# 计算 Z-Score
|
|
||||||
alpha_zscore = (row['alpha'] - bl['alpha_mean']) / bl['alpha_std']
|
|
||||||
amt_zscore = (row['total_amt'] - bl['amt_mean']) / bl['amt_std']
|
|
||||||
rank_zscore = (row['rank_pct'] - bl['rank_mean']) / bl['rank_std']
|
|
||||||
|
|
||||||
# 截断极端值
|
|
||||||
clip = CONFIG['zscore_clip']
|
|
||||||
alpha_zscore = np.clip(alpha_zscore, -clip, clip)
|
|
||||||
amt_zscore = np.clip(amt_zscore, -clip, clip)
|
|
||||||
rank_zscore = np.clip(rank_zscore, -clip, clip)
|
|
||||||
|
|
||||||
# 记录 Z-Score 历史
|
|
||||||
zscore_history.append(alpha_zscore)
|
|
||||||
|
|
||||||
# 基于 Z-Score 计算动量(消除波动率异构性)
|
|
||||||
momentum_3m = 0.0
|
|
||||||
momentum_5m = 0.0
|
|
||||||
|
|
||||||
if len(zscore_history) >= 3:
|
|
||||||
recent_3 = zscore_history[-3:]
|
|
||||||
older_3 = zscore_history[-6:-3] if len(zscore_history) >= 6 else [zscore_history[0]]
|
|
||||||
momentum_3m = np.mean(recent_3) - np.mean(older_3)
|
|
||||||
|
|
||||||
if len(zscore_history) >= 5:
|
|
||||||
recent_5 = zscore_history[-5:]
|
|
||||||
older_5 = zscore_history[-10:-5] if len(zscore_history) >= 10 else [zscore_history[0]]
|
|
||||||
momentum_5m = np.mean(recent_5) - np.mean(older_5)
|
|
||||||
|
|
||||||
zscore_records.append({
|
|
||||||
'concept_id': concept_id,
|
|
||||||
'timestamp': row['timestamp'],
|
|
||||||
'time_slot': time_slot,
|
|
||||||
'trade_date': trade_date,
|
|
||||||
# 原始特征
|
|
||||||
'alpha': row['alpha'],
|
|
||||||
'total_amt': row['total_amt'],
|
|
||||||
'limit_up_ratio': row['limit_up_ratio'],
|
|
||||||
'stock_count': row['stock_count'],
|
|
||||||
'rank_pct': row['rank_pct'],
|
|
||||||
# Z-Score 特征
|
|
||||||
'alpha_zscore': alpha_zscore,
|
|
||||||
'amt_zscore': amt_zscore,
|
|
||||||
'rank_zscore': rank_zscore,
|
|
||||||
# 基于 Z-Score 的动量
|
|
||||||
'momentum_3m': momentum_3m,
|
|
||||||
'momentum_5m': momentum_5m,
|
|
||||||
})
|
|
||||||
|
|
||||||
if not zscore_records:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
return pd.DataFrame(zscore_records)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 多进程处理 ====================
|
|
||||||
|
|
||||||
def process_single_day_v2(args) -> Tuple[str, bool]:
|
|
||||||
"""处理单个交易日(多进程版本)"""
|
|
||||||
trade_date, day_index, concepts, all_stocks, all_trading_days = args
|
|
||||||
output_file = os.path.join(OUTPUT_DIR, f'features_v2_{trade_date}.parquet')
|
|
||||||
|
|
||||||
if os.path.exists(output_file):
|
|
||||||
return (trade_date, True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 计算滚动窗口范围(该日期之前的 N 天)
|
|
||||||
baseline_days = CONFIG['baseline_days']
|
|
||||||
|
|
||||||
# 找出 trade_date 之前的交易日
|
|
||||||
start_idx = max(0, day_index - baseline_days)
|
|
||||||
end_idx = day_index # 不包含当天
|
|
||||||
|
|
||||||
if end_idx <= start_idx:
|
|
||||||
# 没有足够的历史数据
|
|
||||||
return (trade_date, False)
|
|
||||||
|
|
||||||
historical_days = all_trading_days[start_idx:end_idx]
|
|
||||||
|
|
||||||
# 加载历史原始数据
|
|
||||||
historical_dfs = []
|
|
||||||
for hist_date in historical_days:
|
|
||||||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{hist_date}.parquet')
|
|
||||||
if os.path.exists(cache_file):
|
|
||||||
historical_dfs.append(pd.read_parquet(cache_file))
|
|
||||||
else:
|
|
||||||
# 需要计算
|
|
||||||
hist_df = compute_raw_concept_features(hist_date, concepts, all_stocks)
|
|
||||||
if not hist_df.empty:
|
|
||||||
historical_dfs.append(hist_df)
|
|
||||||
|
|
||||||
if not historical_dfs:
|
|
||||||
return (trade_date, False)
|
|
||||||
|
|
||||||
historical_raw_data = pd.concat(historical_dfs, ignore_index=True)
|
|
||||||
|
|
||||||
# 计算当日 Z-Score 特征(使用滚动基线)
|
|
||||||
df = compute_zscore_features_rolling(trade_date, concepts, all_stocks, historical_raw_data)
|
|
||||||
|
|
||||||
if df.empty:
|
|
||||||
return (trade_date, False)
|
|
||||||
|
|
||||||
df.to_parquet(output_file, index=False)
|
|
||||||
return (trade_date, True)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[{trade_date}] 处理失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return (trade_date, False)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 主流程 ====================
|
|
||||||
|
|
||||||
def main():
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='准备训练数据 V2(滚动窗口基线 + Z-Score 动量)')
|
|
||||||
parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
|
|
||||||
parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
|
|
||||||
parser.add_argument('--workers', type=int, default=18, help='并行进程数')
|
|
||||||
parser.add_argument('--baseline-days', type=int, default=20, help='滚动基线窗口大小')
|
|
||||||
parser.add_argument('--force', action='store_true', help='强制重新计算(忽略缓存)')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
end_date = args.end or datetime.now().strftime('%Y-%m-%d')
|
|
||||||
CONFIG['baseline_days'] = args.baseline_days
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("数据准备 V2 - 滚动窗口基线 + Z-Score 动量")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"日期范围: {args.start} ~ {end_date}")
|
|
||||||
print(f"并行进程数: {args.workers}")
|
|
||||||
print(f"滚动基线窗口: {args.baseline_days} 天")
|
|
||||||
|
|
||||||
# 初始化主进程连接
|
|
||||||
init_process_connections()
|
|
||||||
|
|
||||||
# 1. 获取概念列表
|
|
||||||
concepts = get_all_concepts()
|
|
||||||
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
|
||||||
print(f"股票总数: {len(all_stocks)}")
|
|
||||||
|
|
||||||
# 2. 获取交易日列表
|
|
||||||
trading_days = get_trading_days(args.start, end_date)
|
|
||||||
|
|
||||||
if not trading_days:
|
|
||||||
print("无交易日数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 3. 第一阶段:预计算所有原始特征(用于缓存)
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print("第一阶段:预计算原始特征(用于滚动基线)")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
|
|
||||||
# 如果强制重新计算,删除缓存
|
|
||||||
if args.force:
|
|
||||||
import shutil
|
|
||||||
if os.path.exists(RAW_CACHE_DIR):
|
|
||||||
shutil.rmtree(RAW_CACHE_DIR)
|
|
||||||
os.makedirs(RAW_CACHE_DIR, exist_ok=True)
|
|
||||||
if os.path.exists(OUTPUT_DIR):
|
|
||||||
for f in os.listdir(OUTPUT_DIR):
|
|
||||||
if f.startswith('features_v2_'):
|
|
||||||
os.remove(os.path.join(OUTPUT_DIR, f))
|
|
||||||
|
|
||||||
# 单线程预计算原始特征(因为需要顺序缓存)
|
|
||||||
print(f"预计算 {len(trading_days)} 天的原始特征...")
|
|
||||||
for trade_date in tqdm(trading_days, desc="预计算原始特征"):
|
|
||||||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
|
||||||
if not os.path.exists(cache_file):
|
|
||||||
compute_raw_concept_features(trade_date, concepts, all_stocks)
|
|
||||||
|
|
||||||
# 4. 第二阶段:计算 Z-Score 特征(多进程)
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print("第二阶段:计算 Z-Score 特征(滚动基线)")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
|
|
||||||
# 从第 baseline_days 天开始(前面的没有足够历史)
|
|
||||||
start_idx = args.baseline_days
|
|
||||||
processable_days = trading_days[start_idx:]
|
|
||||||
|
|
||||||
if not processable_days:
|
|
||||||
print(f"错误:需要至少 {args.baseline_days + 1} 天的数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"可处理日期: {processable_days[0]} ~ {processable_days[-1]} ({len(processable_days)} 天)")
|
|
||||||
print(f"跳过前 {start_idx} 天(基线预热期)")
|
|
||||||
|
|
||||||
# 构建任务
|
|
||||||
tasks = []
|
|
||||||
for i, trade_date in enumerate(trading_days):
|
|
||||||
if i >= start_idx:
|
|
||||||
tasks.append((trade_date, i, concepts, all_stocks, trading_days))
|
|
||||||
|
|
||||||
print(f"开始处理 {len(tasks)} 个交易日({args.workers} 进程并行)...")
|
|
||||||
|
|
||||||
success_count = 0
|
|
||||||
failed_dates = []
|
|
||||||
|
|
||||||
# 使用进程池初始化器
|
|
||||||
with ProcessPoolExecutor(max_workers=args.workers, initializer=init_process_connections) as executor:
|
|
||||||
futures = {executor.submit(process_single_day_v2, task): task[0] for task in tasks}
|
|
||||||
|
|
||||||
with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
|
|
||||||
for future in as_completed(futures):
|
|
||||||
trade_date = futures[future]
|
|
||||||
try:
|
|
||||||
result_date, success = future.result()
|
|
||||||
if success:
|
|
||||||
success_count += 1
|
|
||||||
else:
|
|
||||||
failed_dates.append(result_date)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n[{trade_date}] 进程异常: {e}")
|
|
||||||
failed_dates.append(trade_date)
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print(f"处理完成: {success_count}/{len(tasks)} 个交易日")
|
|
||||||
if failed_dates:
|
|
||||||
print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
|
|
||||||
print(f"数据保存在: {OUTPUT_DIR}")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,729 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
V2 实时异动检测器
|
|
||||||
|
|
||||||
使用方法:
|
|
||||||
# 作为模块导入
|
|
||||||
from ml.realtime_detector_v2 import RealtimeDetectorV2
|
|
||||||
|
|
||||||
detector = RealtimeDetectorV2()
|
|
||||||
alerts = detector.detect_realtime() # 检测当前时刻
|
|
||||||
|
|
||||||
# 或命令行测试
|
|
||||||
python ml/realtime_detector_v2.py --date 2025-12-09
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import pickle
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from collections import defaultdict, deque
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
from sqlalchemy import create_engine, text
|
|
||||||
from elasticsearch import Elasticsearch
|
|
||||||
from clickhouse_driver import Client
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
from ml.model import TransformerAutoencoder
|
|
||||||
|
|
||||||
# ==================== 配置 ====================
|
|
||||||
|
|
||||||
MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock"
|
|
||||||
ES_HOST = 'http://127.0.0.1:9200'
|
|
||||||
ES_INDEX = 'concept_library_v3'
|
|
||||||
|
|
||||||
CLICKHOUSE_CONFIG = {
|
|
||||||
'host': '127.0.0.1',
|
|
||||||
'port': 9000,
|
|
||||||
'user': 'default',
|
|
||||||
'password': 'Zzl33818!',
|
|
||||||
'database': 'stock'
|
|
||||||
}
|
|
||||||
|
|
||||||
REFERENCE_INDEX = '000001.SH'
|
|
||||||
BASELINE_FILE = 'ml/data_v2/baselines/realtime_baseline.pkl'
|
|
||||||
MODEL_DIR = 'ml/checkpoints_v2'
|
|
||||||
|
|
||||||
# 检测配置
|
|
||||||
CONFIG = {
|
|
||||||
'seq_len': 10, # LSTM 序列长度
|
|
||||||
'confirm_window': 5, # 持续确认窗口
|
|
||||||
'confirm_ratio': 0.6, # 确认比例
|
|
||||||
'rule_weight': 0.5,
|
|
||||||
'ml_weight': 0.5,
|
|
||||||
'rule_trigger': 60,
|
|
||||||
'ml_trigger': 70,
|
|
||||||
'fusion_trigger': 50,
|
|
||||||
'cooldown_minutes': 10,
|
|
||||||
'max_alerts_per_minute': 15,
|
|
||||||
'zscore_clip': 5.0,
|
|
||||||
'limit_up_threshold': 9.8,
|
|
||||||
}
|
|
||||||
|
|
||||||
FEATURES = ['alpha_zscore', 'amt_zscore', 'rank_zscore', 'momentum_3m', 'momentum_5m', 'limit_up_ratio']
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库连接 ====================
|
|
||||||
|
|
||||||
_mysql_engine = None
|
|
||||||
_es_client = None
|
|
||||||
_ch_client = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_mysql_engine():
|
|
||||||
global _mysql_engine
|
|
||||||
if _mysql_engine is None:
|
|
||||||
_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True)
|
|
||||||
return _mysql_engine
|
|
||||||
|
|
||||||
|
|
||||||
def get_es_client():
|
|
||||||
global _es_client
|
|
||||||
if _es_client is None:
|
|
||||||
_es_client = Elasticsearch([ES_HOST])
|
|
||||||
return _es_client
|
|
||||||
|
|
||||||
|
|
||||||
def get_ch_client():
|
|
||||||
global _ch_client
|
|
||||||
if _ch_client is None:
|
|
||||||
_ch_client = Client(**CLICKHOUSE_CONFIG)
|
|
||||||
return _ch_client
|
|
||||||
|
|
||||||
|
|
||||||
def code_to_ch_format(code: str) -> str:
|
|
||||||
if not code or len(code) != 6 or not code.isdigit():
|
|
||||||
return None
|
|
||||||
if code.startswith('6'):
|
|
||||||
return f"{code}.SH"
|
|
||||||
elif code.startswith('0') or code.startswith('3'):
|
|
||||||
return f"{code}.SZ"
|
|
||||||
return f"{code}.BJ"
|
|
||||||
|
|
||||||
|
|
||||||
def time_to_slot(ts) -> str:
|
|
||||||
if isinstance(ts, str):
|
|
||||||
return ts
|
|
||||||
return ts.strftime('%H:%M')
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 规则评分 ====================
|
|
||||||
|
|
||||||
def score_rules_zscore(features: Dict) -> Tuple[float, List[str]]:
|
|
||||||
"""基于 Z-Score 的规则评分"""
|
|
||||||
score = 0.0
|
|
||||||
triggered = []
|
|
||||||
|
|
||||||
alpha_z = abs(features.get('alpha_zscore', 0))
|
|
||||||
amt_z = features.get('amt_zscore', 0)
|
|
||||||
rank_z = abs(features.get('rank_zscore', 0))
|
|
||||||
mom_3m = features.get('momentum_3m', 0)
|
|
||||||
mom_5m = features.get('momentum_5m', 0)
|
|
||||||
limit_up = features.get('limit_up_ratio', 0)
|
|
||||||
|
|
||||||
# Alpha Z-Score
|
|
||||||
if alpha_z >= 4.0:
|
|
||||||
score += 25
|
|
||||||
triggered.append('alpha_extreme')
|
|
||||||
elif alpha_z >= 3.0:
|
|
||||||
score += 18
|
|
||||||
triggered.append('alpha_strong')
|
|
||||||
elif alpha_z >= 2.0:
|
|
||||||
score += 10
|
|
||||||
triggered.append('alpha_moderate')
|
|
||||||
|
|
||||||
# 成交额 Z-Score
|
|
||||||
if amt_z >= 4.0:
|
|
||||||
score += 20
|
|
||||||
triggered.append('amt_extreme')
|
|
||||||
elif amt_z >= 3.0:
|
|
||||||
score += 12
|
|
||||||
triggered.append('amt_strong')
|
|
||||||
elif amt_z >= 2.0:
|
|
||||||
score += 6
|
|
||||||
triggered.append('amt_moderate')
|
|
||||||
|
|
||||||
# 排名 Z-Score
|
|
||||||
if rank_z >= 3.0:
|
|
||||||
score += 15
|
|
||||||
triggered.append('rank_extreme')
|
|
||||||
elif rank_z >= 2.0:
|
|
||||||
score += 8
|
|
||||||
triggered.append('rank_strong')
|
|
||||||
|
|
||||||
# 动量(基于 Z-Score 的)
|
|
||||||
if mom_3m >= 1.0:
|
|
||||||
score += 12
|
|
||||||
triggered.append('momentum_3m_strong')
|
|
||||||
elif mom_3m >= 0.5:
|
|
||||||
score += 6
|
|
||||||
triggered.append('momentum_3m_moderate')
|
|
||||||
|
|
||||||
if mom_5m >= 1.5:
|
|
||||||
score += 10
|
|
||||||
triggered.append('momentum_5m_strong')
|
|
||||||
|
|
||||||
# 涨停比例
|
|
||||||
if limit_up >= 0.3:
|
|
||||||
score += 20
|
|
||||||
triggered.append('limit_up_extreme')
|
|
||||||
elif limit_up >= 0.15:
|
|
||||||
score += 12
|
|
||||||
triggered.append('limit_up_strong')
|
|
||||||
elif limit_up >= 0.08:
|
|
||||||
score += 5
|
|
||||||
triggered.append('limit_up_moderate')
|
|
||||||
|
|
||||||
# 组合规则
|
|
||||||
if alpha_z >= 2.0 and amt_z >= 2.0:
|
|
||||||
score += 15
|
|
||||||
triggered.append('combo_alpha_amt')
|
|
||||||
|
|
||||||
if alpha_z >= 2.0 and limit_up >= 0.1:
|
|
||||||
score += 12
|
|
||||||
triggered.append('combo_alpha_limitup')
|
|
||||||
|
|
||||||
return min(score, 100), triggered
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 实时检测器 ====================
|
|
||||||
|
|
||||||
class RealtimeDetectorV2:
|
|
||||||
"""V2 实时异动检测器"""
|
|
||||||
|
|
||||||
def __init__(self, model_dir: str = MODEL_DIR, baseline_file: str = BASELINE_FILE):
|
|
||||||
print("初始化 V2 实时检测器...")
|
|
||||||
|
|
||||||
# 加载概念
|
|
||||||
self.concepts = self._load_concepts()
|
|
||||||
self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in self.concepts}
|
|
||||||
self.all_stocks = list(set(s for c in self.concepts for s in c['stocks']))
|
|
||||||
|
|
||||||
# 加载基线
|
|
||||||
self.baselines = self._load_baselines(baseline_file)
|
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
self.model, self.thresholds, self.device = self._load_model(model_dir)
|
|
||||||
|
|
||||||
# 状态管理
|
|
||||||
self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))
|
|
||||||
self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))
|
|
||||||
self.cooldown = {}
|
|
||||||
|
|
||||||
print(f"初始化完成: {len(self.concepts)} 概念, {len(self.baselines)} 基线")
|
|
||||||
|
|
||||||
def _load_concepts(self) -> List[dict]:
|
|
||||||
"""从 ES 加载概念"""
|
|
||||||
es = get_es_client()
|
|
||||||
concepts = []
|
|
||||||
|
|
||||||
query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}
|
|
||||||
resp = es.search(index=ES_INDEX, body=query, scroll='2m')
|
|
||||||
scroll_id = resp['_scroll_id']
|
|
||||||
hits = resp['hits']['hits']
|
|
||||||
|
|
||||||
while hits:
|
|
||||||
for hit in hits:
|
|
||||||
src = hit['_source']
|
|
||||||
stocks = [s['code'] for s in src.get('stocks', []) if isinstance(s, dict) and s.get('code')]
|
|
||||||
if stocks:
|
|
||||||
concepts.append({
|
|
||||||
'concept_id': src.get('concept_id'),
|
|
||||||
'concept_name': src.get('concept'),
|
|
||||||
'stocks': stocks
|
|
||||||
})
|
|
||||||
resp = es.scroll(scroll_id=scroll_id, scroll='2m')
|
|
||||||
scroll_id = resp['_scroll_id']
|
|
||||||
hits = resp['hits']['hits']
|
|
||||||
|
|
||||||
es.clear_scroll(scroll_id=scroll_id)
|
|
||||||
return concepts
|
|
||||||
|
|
||||||
def _load_baselines(self, baseline_file: str) -> Dict:
|
|
||||||
"""加载基线"""
|
|
||||||
if not os.path.exists(baseline_file):
|
|
||||||
print(f"警告: 基线文件不存在: {baseline_file}")
|
|
||||||
print("请先运行: python ml/update_baseline.py")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
with open(baseline_file, 'rb') as f:
|
|
||||||
data = pickle.load(f)
|
|
||||||
|
|
||||||
print(f"基线日期范围: {data.get('date_range', 'unknown')}")
|
|
||||||
print(f"更新时间: {data.get('update_time', 'unknown')}")
|
|
||||||
|
|
||||||
return data.get('baselines', {})
|
|
||||||
|
|
||||||
def _load_model(self, model_dir: str):
|
|
||||||
"""加载模型"""
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
|
|
||||||
config_path = os.path.join(model_dir, 'config.json')
|
|
||||||
model_path = os.path.join(model_dir, 'best_model.pt')
|
|
||||||
threshold_path = os.path.join(model_dir, 'thresholds.json')
|
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
print(f"警告: 模型不存在: {model_path}")
|
|
||||||
return None, {}, device
|
|
||||||
|
|
||||||
with open(config_path) as f:
|
|
||||||
config = json.load(f)
|
|
||||||
|
|
||||||
model = TransformerAutoencoder(**config['model'])
|
|
||||||
checkpoint = torch.load(model_path, map_location=device)
|
|
||||||
model.load_state_dict(checkpoint['model_state_dict'])
|
|
||||||
model.to(device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
thresholds = {}
|
|
||||||
if os.path.exists(threshold_path):
|
|
||||||
with open(threshold_path) as f:
|
|
||||||
thresholds = json.load(f)
|
|
||||||
|
|
||||||
print(f"模型已加载: {model_path}")
|
|
||||||
return model, thresholds, device
|
|
||||||
|
|
||||||
def _get_realtime_data(self, trade_date: str) -> pd.DataFrame:
|
|
||||||
"""获取实时数据并计算原始特征"""
|
|
||||||
ch = get_ch_client()
|
|
||||||
|
|
||||||
# 获取股票数据
|
|
||||||
ch_codes = [code_to_ch_format(c) for c in self.all_stocks if code_to_ch_format(c)]
|
|
||||||
ch_codes_str = "','".join(ch_codes)
|
|
||||||
|
|
||||||
stock_query = f"""
|
|
||||||
SELECT code, timestamp, close, amt
|
|
||||||
FROM stock_minute
|
|
||||||
WHERE toDate(timestamp) = '{trade_date}'
|
|
||||||
AND code IN ('{ch_codes_str}')
|
|
||||||
ORDER BY timestamp
|
|
||||||
"""
|
|
||||||
stock_result = ch.execute(stock_query)
|
|
||||||
if not stock_result:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
stock_df = pd.DataFrame(stock_result, columns=['ch_code', 'timestamp', 'close', 'amt'])
|
|
||||||
|
|
||||||
# 映射回原始代码
|
|
||||||
ch_to_code = {code_to_ch_format(c): c for c in self.all_stocks if code_to_ch_format(c)}
|
|
||||||
stock_df['code'] = stock_df['ch_code'].map(ch_to_code)
|
|
||||||
stock_df = stock_df.dropna(subset=['code'])
|
|
||||||
|
|
||||||
# 获取指数数据
|
|
||||||
index_query = f"""
|
|
||||||
SELECT timestamp, close
|
|
||||||
FROM index_minute
|
|
||||||
WHERE toDate(timestamp) = '{trade_date}'
|
|
||||||
AND code = '{REFERENCE_INDEX}'
|
|
||||||
ORDER BY timestamp
|
|
||||||
"""
|
|
||||||
index_result = ch.execute(index_query)
|
|
||||||
if not index_result:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
index_df = pd.DataFrame(index_result, columns=['timestamp', 'close'])
|
|
||||||
|
|
||||||
# 获取昨收价
|
|
||||||
engine = get_mysql_engine()
|
|
||||||
codes_str = "','".join([c for c in self.all_stocks if c and len(c) == 6])
|
|
||||||
|
|
||||||
with engine.connect() as conn:
|
|
||||||
prev_result = conn.execute(text(f"""
|
|
||||||
SELECT SECCODE, F007N FROM ea_trade
|
|
||||||
WHERE SECCODE IN ('{codes_str}')
|
|
||||||
AND TRADEDATE = (SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}')
|
|
||||||
AND F007N > 0
|
|
||||||
"""))
|
|
||||||
prev_close = {row[0]: float(row[1]) for row in prev_result if row[1]}
|
|
||||||
|
|
||||||
idx_result = conn.execute(text("""
|
|
||||||
SELECT F006N FROM ea_exchangetrade
|
|
||||||
WHERE INDEXCODE = '000001' AND TRADEDATE < :today
|
|
||||||
ORDER BY TRADEDATE DESC LIMIT 1
|
|
||||||
"""), {'today': trade_date}).fetchone()
|
|
||||||
index_prev_close = float(idx_result[0]) if idx_result else None
|
|
||||||
|
|
||||||
if not prev_close or not index_prev_close:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
# 计算涨跌幅
|
|
||||||
stock_df['prev_close'] = stock_df['code'].map(prev_close)
|
|
||||||
stock_df = stock_df.dropna(subset=['prev_close'])
|
|
||||||
stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
|
|
||||||
|
|
||||||
index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
|
|
||||||
index_map = dict(zip(index_df['timestamp'], index_df['change_pct']))
|
|
||||||
|
|
||||||
# 按时间聚合概念特征
|
|
||||||
results = []
|
|
||||||
for ts in sorted(stock_df['timestamp'].unique()):
|
|
||||||
ts_data = stock_df[stock_df['timestamp'] == ts]
|
|
||||||
idx_chg = index_map.get(ts, 0)
|
|
||||||
|
|
||||||
stock_chg = dict(zip(ts_data['code'], ts_data['change_pct']))
|
|
||||||
stock_amt = dict(zip(ts_data['code'], ts_data['amt']))
|
|
||||||
|
|
||||||
for cid, stocks in self.concept_stocks.items():
|
|
||||||
changes = [stock_chg[s] for s in stocks if s in stock_chg]
|
|
||||||
amts = [stock_amt.get(s, 0) for s in stocks if s in stock_chg]
|
|
||||||
|
|
||||||
if not changes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
alpha = np.mean(changes) - idx_chg
|
|
||||||
total_amt = sum(amts)
|
|
||||||
limit_up_ratio = sum(1 for c in changes if c >= CONFIG['limit_up_threshold']) / len(changes)
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
'concept_id': cid,
|
|
||||||
'timestamp': ts,
|
|
||||||
'time_slot': time_to_slot(ts),
|
|
||||||
'alpha': alpha,
|
|
||||||
'total_amt': total_amt,
|
|
||||||
'limit_up_ratio': limit_up_ratio,
|
|
||||||
'stock_count': len(changes),
|
|
||||||
})
|
|
||||||
|
|
||||||
if not results:
|
|
||||||
return pd.DataFrame()
|
|
||||||
|
|
||||||
df = pd.DataFrame(results)
|
|
||||||
|
|
||||||
# 计算排名
|
|
||||||
for ts in df['timestamp'].unique():
|
|
||||||
mask = df['timestamp'] == ts
|
|
||||||
df.loc[mask, 'rank_pct'] = df.loc[mask, 'alpha'].rank(pct=True)
|
|
||||||
|
|
||||||
return df
|
|
||||||
|
|
||||||
def _compute_zscore(self, concept_id: str, time_slot: str, alpha: float, total_amt: float, rank_pct: float) -> Optional[Dict]:
|
|
||||||
"""计算 Z-Score"""
|
|
||||||
if concept_id not in self.baselines:
|
|
||||||
return None
|
|
||||||
|
|
||||||
baseline = self.baselines[concept_id]
|
|
||||||
if time_slot not in baseline:
|
|
||||||
return None
|
|
||||||
|
|
||||||
bl = baseline[time_slot]
|
|
||||||
|
|
||||||
alpha_z = np.clip((alpha - bl['alpha_mean']) / bl['alpha_std'], -5, 5)
|
|
||||||
amt_z = np.clip((total_amt - bl['amt_mean']) / bl['amt_std'], -5, 5)
|
|
||||||
rank_z = np.clip((rank_pct - bl['rank_mean']) / bl['rank_std'], -5, 5)
|
|
||||||
|
|
||||||
# 动量(基于 Z-Score 历史)
|
|
||||||
history = list(self.zscore_history[concept_id])
|
|
||||||
mom_3m = 0.0
|
|
||||||
mom_5m = 0.0
|
|
||||||
|
|
||||||
if len(history) >= 3:
|
|
||||||
recent = [h['alpha_zscore'] for h in history[-3:]]
|
|
||||||
older = [h['alpha_zscore'] for h in history[-6:-3]] if len(history) >= 6 else [history[0]['alpha_zscore']]
|
|
||||||
mom_3m = np.mean(recent) - np.mean(older)
|
|
||||||
|
|
||||||
if len(history) >= 5:
|
|
||||||
recent = [h['alpha_zscore'] for h in history[-5:]]
|
|
||||||
older = [h['alpha_zscore'] for h in history[-10:-5]] if len(history) >= 10 else [history[0]['alpha_zscore']]
|
|
||||||
mom_5m = np.mean(recent) - np.mean(older)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'alpha_zscore': float(alpha_z),
|
|
||||||
'amt_zscore': float(amt_z),
|
|
||||||
'rank_zscore': float(rank_z),
|
|
||||||
'momentum_3m': float(mom_3m),
|
|
||||||
'momentum_5m': float(mom_5m),
|
|
||||||
}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def _ml_score(self, sequences: np.ndarray) -> np.ndarray:
|
|
||||||
"""批量 ML 评分"""
|
|
||||||
if self.model is None or len(sequences) == 0:
|
|
||||||
return np.zeros(len(sequences))
|
|
||||||
|
|
||||||
x = torch.FloatTensor(sequences).to(self.device)
|
|
||||||
errors = self.model.compute_reconstruction_error(x, reduction='none')
|
|
||||||
last_errors = errors[:, -1].cpu().numpy()
|
|
||||||
|
|
||||||
# 转换为 0-100 分数
|
|
||||||
if self.thresholds:
|
|
||||||
p50 = self.thresholds.get('median', 0.001)
|
|
||||||
p99 = self.thresholds.get('p99', 0.05)
|
|
||||||
scores = 50 + (last_errors - p50) / (p99 - p50 + 1e-6) * 49
|
|
||||||
else:
|
|
||||||
scores = last_errors * 1000
|
|
||||||
|
|
||||||
return np.clip(scores, 0, 100)
|
|
||||||
|
|
||||||
def detect(self, trade_date: str = None) -> List[Dict]:
|
|
||||||
"""检测指定日期的异动"""
|
|
||||||
trade_date = trade_date or datetime.now().strftime('%Y-%m-%d')
|
|
||||||
print(f"\n检测 {trade_date} 的异动...")
|
|
||||||
|
|
||||||
# 重置状态
|
|
||||||
self.zscore_history.clear()
|
|
||||||
self.anomaly_candidates.clear()
|
|
||||||
self.cooldown.clear()
|
|
||||||
|
|
||||||
# 获取数据
|
|
||||||
raw_df = self._get_realtime_data(trade_date)
|
|
||||||
if raw_df.empty:
|
|
||||||
print("无数据")
|
|
||||||
return []
|
|
||||||
|
|
||||||
timestamps = sorted(raw_df['timestamp'].unique())
|
|
||||||
print(f"时间点数: {len(timestamps)}")
|
|
||||||
|
|
||||||
all_alerts = []
|
|
||||||
|
|
||||||
for ts in timestamps:
|
|
||||||
ts_data = raw_df[raw_df['timestamp'] == ts]
|
|
||||||
time_slot = time_to_slot(ts)
|
|
||||||
|
|
||||||
candidates = []
|
|
||||||
|
|
||||||
# 计算每个概念的 Z-Score
|
|
||||||
for _, row in ts_data.iterrows():
|
|
||||||
cid = row['concept_id']
|
|
||||||
|
|
||||||
zscore = self._compute_zscore(
|
|
||||||
cid, time_slot,
|
|
||||||
row['alpha'], row['total_amt'], row['rank_pct']
|
|
||||||
)
|
|
||||||
|
|
||||||
if zscore is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 完整特征
|
|
||||||
features = {
|
|
||||||
**zscore,
|
|
||||||
'alpha': row['alpha'],
|
|
||||||
'limit_up_ratio': row['limit_up_ratio'],
|
|
||||||
'total_amt': row['total_amt'],
|
|
||||||
}
|
|
||||||
|
|
||||||
# 更新历史
|
|
||||||
self.zscore_history[cid].append(zscore)
|
|
||||||
|
|
||||||
# 规则评分
|
|
||||||
rule_score, triggered = score_rules_zscore(features)
|
|
||||||
|
|
||||||
candidates.append((cid, features, rule_score, triggered))
|
|
||||||
|
|
||||||
if not candidates:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 批量 ML 评分
|
|
||||||
sequences = []
|
|
||||||
valid_candidates = []
|
|
||||||
|
|
||||||
for cid, features, rule_score, triggered in candidates:
|
|
||||||
history = list(self.zscore_history[cid])
|
|
||||||
if len(history) >= CONFIG['seq_len']:
|
|
||||||
seq = np.array([[h['alpha_zscore'], h['amt_zscore'], h['rank_zscore'],
|
|
||||||
h['momentum_3m'], h['momentum_5m'], features['limit_up_ratio']]
|
|
||||||
for h in history])
|
|
||||||
sequences.append(seq)
|
|
||||||
valid_candidates.append((cid, features, rule_score, triggered))
|
|
||||||
|
|
||||||
if not sequences:
|
|
||||||
continue
|
|
||||||
|
|
||||||
ml_scores = self._ml_score(np.array(sequences))
|
|
||||||
|
|
||||||
# 融合 + 确认
|
|
||||||
for i, (cid, features, rule_score, triggered) in enumerate(valid_candidates):
|
|
||||||
ml_score = ml_scores[i]
|
|
||||||
final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score
|
|
||||||
|
|
||||||
# 判断触发
|
|
||||||
is_triggered = (
|
|
||||||
rule_score >= CONFIG['rule_trigger'] or
|
|
||||||
ml_score >= CONFIG['ml_trigger'] or
|
|
||||||
final_score >= CONFIG['fusion_trigger']
|
|
||||||
)
|
|
||||||
|
|
||||||
self.anomaly_candidates[cid].append((ts, final_score))
|
|
||||||
|
|
||||||
if not is_triggered:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 冷却期
|
|
||||||
if cid in self.cooldown:
|
|
||||||
if (ts - self.cooldown[cid]).total_seconds() < CONFIG['cooldown_minutes'] * 60:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 持续确认
|
|
||||||
recent = list(self.anomaly_candidates[cid])
|
|
||||||
if len(recent) < CONFIG['confirm_window']:
|
|
||||||
continue
|
|
||||||
|
|
||||||
exceed = sum(1 for _, s in recent if s >= CONFIG['fusion_trigger'])
|
|
||||||
ratio = exceed / len(recent)
|
|
||||||
|
|
||||||
if ratio < CONFIG['confirm_ratio']:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 确认异动!
|
|
||||||
self.cooldown[cid] = ts
|
|
||||||
|
|
||||||
alpha = features['alpha']
|
|
||||||
alert_type = 'surge_up' if alpha >= 1.5 else 'surge_down' if alpha <= -1.5 else 'surge'
|
|
||||||
|
|
||||||
concept_name = next((c['concept_name'] for c in self.concepts if c['concept_id'] == cid), cid)
|
|
||||||
|
|
||||||
all_alerts.append({
|
|
||||||
'concept_id': cid,
|
|
||||||
'concept_name': concept_name,
|
|
||||||
'alert_time': ts,
|
|
||||||
'trade_date': trade_date,
|
|
||||||
'alert_type': alert_type,
|
|
||||||
'final_score': float(final_score),
|
|
||||||
'rule_score': float(rule_score),
|
|
||||||
'ml_score': float(ml_score),
|
|
||||||
'confirm_ratio': float(ratio),
|
|
||||||
'alpha': float(alpha),
|
|
||||||
'alpha_zscore': float(features['alpha_zscore']),
|
|
||||||
'amt_zscore': float(features['amt_zscore']),
|
|
||||||
'rank_zscore': float(features['rank_zscore']),
|
|
||||||
'momentum_3m': float(features['momentum_3m']),
|
|
||||||
'momentum_5m': float(features['momentum_5m']),
|
|
||||||
'limit_up_ratio': float(features['limit_up_ratio']),
|
|
||||||
'triggered_rules': triggered,
|
|
||||||
'trigger_reason': f"融合({final_score:.0f})+确认({ratio:.0%})",
|
|
||||||
})
|
|
||||||
|
|
||||||
print(f"检测到 {len(all_alerts)} 个异动")
|
|
||||||
return all_alerts
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据库存储 ====================
|
|
||||||
|
|
||||||
def create_v2_table():
|
|
||||||
"""创建 V2 异动表(如果不存在)"""
|
|
||||||
engine = get_mysql_engine()
|
|
||||||
with engine.begin() as conn:
|
|
||||||
conn.execute(text("""
|
|
||||||
CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
|
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
|
||||||
concept_id VARCHAR(50) NOT NULL,
|
|
||||||
alert_time DATETIME NOT NULL,
|
|
||||||
trade_date DATE NOT NULL,
|
|
||||||
alert_type VARCHAR(20) NOT NULL,
|
|
||||||
final_score FLOAT,
|
|
||||||
rule_score FLOAT,
|
|
||||||
ml_score FLOAT,
|
|
||||||
trigger_reason VARCHAR(200),
|
|
||||||
confirm_ratio FLOAT,
|
|
||||||
alpha FLOAT,
|
|
||||||
alpha_zscore FLOAT,
|
|
||||||
amt_zscore FLOAT,
|
|
||||||
rank_zscore FLOAT,
|
|
||||||
momentum_3m FLOAT,
|
|
||||||
momentum_5m FLOAT,
|
|
||||||
limit_up_ratio FLOAT,
|
|
||||||
triggered_rules TEXT,
|
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
UNIQUE KEY uk_concept_time (concept_id, alert_time),
|
|
||||||
INDEX idx_trade_date (trade_date),
|
|
||||||
INDEX idx_alert_type (alert_type)
|
|
||||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
|
||||||
"""))
|
|
||||||
print("concept_anomaly_v2 表已就绪")
|
|
||||||
|
|
||||||
|
|
||||||
def save_alerts_to_db(alerts: List[Dict]) -> int:
|
|
||||||
"""保存异动到数据库"""
|
|
||||||
if not alerts:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
engine = get_mysql_engine()
|
|
||||||
saved = 0
|
|
||||||
|
|
||||||
with engine.begin() as conn:
|
|
||||||
for alert in alerts:
|
|
||||||
try:
|
|
||||||
insert_sql = text("""
|
|
||||||
INSERT IGNORE INTO concept_anomaly_v2
|
|
||||||
(concept_id, alert_time, trade_date, alert_type,
|
|
||||||
final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
|
|
||||||
alpha, alpha_zscore, amt_zscore, rank_zscore,
|
|
||||||
momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
|
|
||||||
VALUES
|
|
||||||
(:concept_id, :alert_time, :trade_date, :alert_type,
|
|
||||||
:final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
|
|
||||||
:alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
|
|
||||||
:momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
|
|
||||||
""")
|
|
||||||
|
|
||||||
result = conn.execute(insert_sql, {
|
|
||||||
'concept_id': alert['concept_id'],
|
|
||||||
'alert_time': alert['alert_time'],
|
|
||||||
'trade_date': alert['trade_date'],
|
|
||||||
'alert_type': alert['alert_type'],
|
|
||||||
'final_score': alert['final_score'],
|
|
||||||
'rule_score': alert['rule_score'],
|
|
||||||
'ml_score': alert['ml_score'],
|
|
||||||
'trigger_reason': alert['trigger_reason'],
|
|
||||||
'confirm_ratio': alert['confirm_ratio'],
|
|
||||||
'alpha': alert['alpha'],
|
|
||||||
'alpha_zscore': alert['alpha_zscore'],
|
|
||||||
'amt_zscore': alert['amt_zscore'],
|
|
||||||
'rank_zscore': alert['rank_zscore'],
|
|
||||||
'momentum_3m': alert['momentum_3m'],
|
|
||||||
'momentum_5m': alert['momentum_5m'],
|
|
||||||
'limit_up_ratio': alert['limit_up_ratio'],
|
|
||||||
'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False),
|
|
||||||
})
|
|
||||||
|
|
||||||
if result.rowcount > 0:
|
|
||||||
saved += 1
|
|
||||||
except Exception as e:
|
|
||||||
print(f"保存失败: {alert['concept_id']} - {e}")
|
|
||||||
|
|
||||||
return saved
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
import argparse
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--date', type=str, default=None)
|
|
||||||
parser.add_argument('--no-save', action='store_true', help='不保存到数据库,只打印')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# 确保表存在
|
|
||||||
if not args.no_save:
|
|
||||||
create_v2_table()
|
|
||||||
|
|
||||||
detector = RealtimeDetectorV2()
|
|
||||||
alerts = detector.detect(args.date)
|
|
||||||
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"检测结果 ({len(alerts)} 个异动)")
|
|
||||||
print('='*60)
|
|
||||||
|
|
||||||
for a in alerts[:20]:
|
|
||||||
print(f"[{a['alert_time'].strftime('%H:%M') if hasattr(a['alert_time'], 'strftime') else a['alert_time']}] "
|
|
||||||
f"{a['concept_name']} | {a['alert_type']} | "
|
|
||||||
f"分数={a['final_score']:.0f} 确认={a['confirm_ratio']:.0%} "
|
|
||||||
f"α={a['alpha']:.2f}% αZ={a['alpha_zscore']:.1f}")
|
|
||||||
|
|
||||||
if len(alerts) > 20:
|
|
||||||
print(f"... 共 {len(alerts)} 个")
|
|
||||||
|
|
||||||
# 保存到数据库
|
|
||||||
if not args.no_save and alerts:
|
|
||||||
saved = save_alerts_to_db(alerts)
|
|
||||||
print(f"\n✅ 已保存 {saved}/{len(alerts)} 条到 concept_anomaly_v2 表")
|
|
||||||
elif args.no_save:
|
|
||||||
print(f"\n⚠️ --no-save 模式,未保存到数据库")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
# 概念异动检测 ML 模块依赖
|
|
||||||
# 安装: pip install -r ml/requirements.txt
|
|
||||||
|
|
||||||
# PyTorch (根据 CUDA 版本选择)
|
|
||||||
# 5090 显卡需要 CUDA 12.x
|
|
||||||
# pip install torch --index-url https://download.pytorch.org/whl/cu124
|
|
||||||
torch>=2.0.0
|
|
||||||
|
|
||||||
# 数据处理
|
|
||||||
numpy>=1.24.0
|
|
||||||
pandas>=2.0.0
|
|
||||||
pyarrow>=14.0.0
|
|
||||||
|
|
||||||
# 数据库
|
|
||||||
clickhouse-driver>=0.2.6
|
|
||||||
elasticsearch>=7.0.0,<8.0.0
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
pymysql>=1.1.0
|
|
||||||
|
|
||||||
# 训练工具
|
|
||||||
tqdm>=4.65.0
|
|
||||||
|
|
||||||
# 可选: 可视化
|
|
||||||
# matplotlib>=3.7.0
|
|
||||||
# tensorboard>=2.14.0
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# 概念异动检测模型训练脚本 (Linux)
|
|
||||||
#
|
|
||||||
# 使用方法:
|
|
||||||
# chmod +x run_training.sh
|
|
||||||
# ./run_training.sh
|
|
||||||
#
|
|
||||||
# 或指定参数:
|
|
||||||
# ./run_training.sh --start 2022-01-01 --epochs 100
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
echo "============================================================"
|
|
||||||
echo "概念异动检测模型训练流程"
|
|
||||||
echo "============================================================"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# 获取脚本所在目录
|
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
|
||||||
cd "$SCRIPT_DIR/.."
|
|
||||||
|
|
||||||
echo "[1/4] 检查环境..."
|
|
||||||
python3 --version || { echo "Python3 未找到!"; exit 1; }
|
|
||||||
|
|
||||||
# 检查 GPU
|
|
||||||
if python3 -c "import torch; print(f'CUDA: {torch.cuda.is_available()}')" 2>/dev/null; then
|
|
||||||
echo "PyTorch GPU 检测完成"
|
|
||||||
else
|
|
||||||
echo "警告: PyTorch 未安装或无法检测 GPU"
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "[2/4] 检查依赖..."
|
|
||||||
pip3 install -q torch pandas numpy pyarrow tqdm clickhouse-driver elasticsearch sqlalchemy pymysql
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "[3/4] 准备训练数据..."
|
|
||||||
echo "从 ClickHouse 提取历史数据,这可能需要较长时间..."
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# 解析参数
|
|
||||||
START_DATE="2022-01-01"
|
|
||||||
END_DATE=""
|
|
||||||
EPOCHS=100
|
|
||||||
BATCH_SIZE=256
|
|
||||||
TRAIN_END="2025-06-30"
|
|
||||||
VAL_END="2025-09-30"
|
|
||||||
|
|
||||||
while [[ $# -gt 0 ]]; do
|
|
||||||
case $1 in
|
|
||||||
--start)
|
|
||||||
START_DATE="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--end)
|
|
||||||
END_DATE="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--epochs)
|
|
||||||
EPOCHS="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--batch_size)
|
|
||||||
BATCH_SIZE="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--train_end)
|
|
||||||
TRAIN_END="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--val_end)
|
|
||||||
VAL_END="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
*)
|
|
||||||
shift
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
done
|
|
||||||
|
|
||||||
# 数据准备
|
|
||||||
if [ -n "$END_DATE" ]; then
|
|
||||||
python3 ml/prepare_data.py --start "$START_DATE" --end "$END_DATE"
|
|
||||||
else
|
|
||||||
python3 ml/prepare_data.py --start "$START_DATE"
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "[4/4] 训练模型..."
|
|
||||||
echo "使用 GPU 加速训练..."
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
python3 ml/train.py --epochs "$EPOCHS" --batch_size "$BATCH_SIZE" --train_end "$TRAIN_END" --val_end "$VAL_END"
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "============================================================"
|
|
||||||
echo "训练完成!"
|
|
||||||
echo "模型保存在: ml/checkpoints/"
|
|
||||||
echo "============================================================"
|
|
||||||
808
ml/train.py
808
ml/train.py
@@ -1,808 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
Transformer Autoencoder 训练脚本 (修复版)
|
|
||||||
|
|
||||||
修复问题:
|
|
||||||
1. 按概念分组构建序列,避免跨概念切片
|
|
||||||
2. 按时间(日期)切分数据集,避免数据泄露
|
|
||||||
3. 使用 RobustScaler + Clipping 处理非平稳性
|
|
||||||
4. 使用验证集计算阈值
|
|
||||||
|
|
||||||
训练流程:
|
|
||||||
1. 加载预处理好的特征数据(parquet 文件)
|
|
||||||
2. 按概念分组,在每个概念内部构建序列
|
|
||||||
3. 按日期划分训练/验证/测试集
|
|
||||||
4. 训练 Autoencoder(最小化重构误差)
|
|
||||||
5. 保存模型和阈值
|
|
||||||
|
|
||||||
使用方法:
|
|
||||||
python train.py --data_dir ml/data --epochs 100 --batch_size 256
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Tuple, Dict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
from torch.optim import AdamW
|
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters
|
|
||||||
|
|
||||||
# 性能优化:启用 cuDNN benchmark(对固定输入尺寸自动选择最快算法)
|
|
||||||
torch.backends.cudnn.benchmark = True
|
|
||||||
# 启用 TF32(RTX 30/40 系列特有,提速约 3 倍)
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
|
||||||
|
|
||||||
# 可视化(可选)
|
|
||||||
try:
|
|
||||||
import matplotlib
|
|
||||||
matplotlib.use('Agg') # 无头模式,不需要显示器
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
HAS_MATPLOTLIB = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_MATPLOTLIB = False
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 配置 ====================
|
|
||||||
|
|
||||||
TRAIN_CONFIG = {
|
|
||||||
# 数据配置
|
|
||||||
'seq_len': 30, # 输入序列长度(30分钟)
|
|
||||||
'stride': 5, # 滑动窗口步长
|
|
||||||
|
|
||||||
# 时间切分(按日期)
|
|
||||||
'train_end_date': '2024-06-30', # 训练集截止日期
|
|
||||||
'val_end_date': '2024-09-30', # 验证集截止日期(之后为测试集)
|
|
||||||
|
|
||||||
# 特征配置
|
|
||||||
'features': [
|
|
||||||
'alpha', # 超额收益
|
|
||||||
'alpha_delta', # Alpha 变化率
|
|
||||||
'amt_ratio', # 成交额比率
|
|
||||||
'amt_delta', # 成交额变化率
|
|
||||||
'rank_pct', # Alpha 排名百分位
|
|
||||||
'limit_up_ratio', # 涨停比例
|
|
||||||
],
|
|
||||||
|
|
||||||
# 训练配置(针对 4x RTX 4090 优化)
|
|
||||||
'batch_size': 4096, # 256 -> 4096(大幅增加,充分利用显存)
|
|
||||||
'epochs': 100,
|
|
||||||
'learning_rate': 3e-4, # 1e-4 -> 3e-4(大 batch 需要更大学习率)
|
|
||||||
'weight_decay': 1e-5,
|
|
||||||
'gradient_clip': 1.0,
|
|
||||||
|
|
||||||
# 早停配置
|
|
||||||
'patience': 10,
|
|
||||||
'min_delta': 1e-6,
|
|
||||||
|
|
||||||
# 模型配置(LSTM Autoencoder,简洁有效)
|
|
||||||
'model': {
|
|
||||||
'n_features': 6,
|
|
||||||
'hidden_dim': 32, # LSTM 隐藏维度(小)
|
|
||||||
'latent_dim': 4, # 瓶颈维度(非常小!关键)
|
|
||||||
'num_layers': 1, # LSTM 层数
|
|
||||||
'dropout': 0.2,
|
|
||||||
'bidirectional': True, # 双向编码器
|
|
||||||
},
|
|
||||||
|
|
||||||
# 标准化配置
|
|
||||||
'use_instance_norm': True, # 模型内部使用 Instance Norm(推荐)
|
|
||||||
'clip_value': 10.0, # 简单截断极端值
|
|
||||||
|
|
||||||
# 阈值配置
|
|
||||||
'threshold_percentiles': [90, 95, 99],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据加载(修复版)====================
|
|
||||||
|
|
||||||
def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
|
|
||||||
"""
|
|
||||||
按日期加载数据,返回 {date: DataFrame} 字典
|
|
||||||
|
|
||||||
每个 DataFrame 包含该日所有概念的所有时间点数据
|
|
||||||
"""
|
|
||||||
data_path = Path(data_dir)
|
|
||||||
parquet_files = sorted(data_path.glob("features_*.parquet"))
|
|
||||||
|
|
||||||
if not parquet_files:
|
|
||||||
raise FileNotFoundError(f"未找到 parquet 文件: {data_dir}")
|
|
||||||
|
|
||||||
print(f"找到 {len(parquet_files)} 个数据文件")
|
|
||||||
|
|
||||||
date_data = {}
|
|
||||||
|
|
||||||
for pf in tqdm(parquet_files, desc="加载数据"):
|
|
||||||
# 提取日期
|
|
||||||
date = pf.stem.replace('features_', '')
|
|
||||||
|
|
||||||
df = pd.read_parquet(pf)
|
|
||||||
|
|
||||||
# 检查必要列
|
|
||||||
required_cols = features + ['concept_id', 'timestamp']
|
|
||||||
missing_cols = [c for c in required_cols if c not in df.columns]
|
|
||||||
if missing_cols:
|
|
||||||
print(f"警告: {date} 缺少列: {missing_cols}, 跳过")
|
|
||||||
continue
|
|
||||||
|
|
||||||
date_data[date] = df
|
|
||||||
|
|
||||||
print(f"成功加载 {len(date_data)} 天的数据")
|
|
||||||
return date_data
|
|
||||||
|
|
||||||
|
|
||||||
def split_data_by_date(
|
|
||||||
date_data: Dict[str, pd.DataFrame],
|
|
||||||
train_end: str,
|
|
||||||
val_end: str
|
|
||||||
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
|
|
||||||
"""
|
|
||||||
按日期严格划分数据集
|
|
||||||
|
|
||||||
- 训练集: <= train_end
|
|
||||||
- 验证集: train_end < date <= val_end
|
|
||||||
- 测试集: > val_end
|
|
||||||
"""
|
|
||||||
train_data = {}
|
|
||||||
val_data = {}
|
|
||||||
test_data = {}
|
|
||||||
|
|
||||||
for date, df in date_data.items():
|
|
||||||
if date <= train_end:
|
|
||||||
train_data[date] = df
|
|
||||||
elif date <= val_end:
|
|
||||||
val_data[date] = df
|
|
||||||
else:
|
|
||||||
test_data[date] = df
|
|
||||||
|
|
||||||
print(f"数据集划分(按日期):")
|
|
||||||
print(f" 训练集: {len(train_data)} 天 (<= {train_end})")
|
|
||||||
print(f" 验证集: {len(val_data)} 天 ({train_end} ~ {val_end})")
|
|
||||||
print(f" 测试集: {len(test_data)} 天 (> {val_end})")
|
|
||||||
|
|
||||||
return train_data, val_data, test_data
|
|
||||||
|
|
||||||
|
|
||||||
def build_sequences_by_concept(
|
|
||||||
date_data: Dict[str, pd.DataFrame],
|
|
||||||
features: List[str],
|
|
||||||
seq_len: int,
|
|
||||||
stride: int
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
按概念分组构建序列(性能优化版)
|
|
||||||
|
|
||||||
使用 groupby 一次性分组,避免重复扫描大数组
|
|
||||||
|
|
||||||
1. 将所有日期的数据合并
|
|
||||||
2. 使用 groupby 按 concept_id 分组
|
|
||||||
3. 在每个概念内部,按时间排序并滑动窗口
|
|
||||||
4. 合并所有序列
|
|
||||||
"""
|
|
||||||
# 合并所有日期的数据
|
|
||||||
all_dfs = []
|
|
||||||
for date, df in sorted(date_data.items()):
|
|
||||||
df = df.copy()
|
|
||||||
df['date'] = date
|
|
||||||
all_dfs.append(df)
|
|
||||||
|
|
||||||
if not all_dfs:
|
|
||||||
return np.array([])
|
|
||||||
|
|
||||||
combined = pd.concat(all_dfs, ignore_index=True)
|
|
||||||
|
|
||||||
# 预先排序(按概念、日期、时间),这样 groupby 会更快
|
|
||||||
combined = combined.sort_values(['concept_id', 'date', 'timestamp'])
|
|
||||||
|
|
||||||
# 使用 groupby 一次性分组(性能关键!)
|
|
||||||
all_sequences = []
|
|
||||||
grouped = combined.groupby('concept_id', sort=False)
|
|
||||||
n_concepts = len(grouped)
|
|
||||||
|
|
||||||
for concept_id, concept_df in tqdm(grouped, desc="构建序列", total=n_concepts, leave=False):
|
|
||||||
# 已经排序过了,直接提取特征
|
|
||||||
feature_data = concept_df[features].values
|
|
||||||
|
|
||||||
# 处理缺失值
|
|
||||||
feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0)
|
|
||||||
|
|
||||||
# 在该概念内部滑动窗口
|
|
||||||
n_points = len(feature_data)
|
|
||||||
for start in range(0, n_points - seq_len + 1, stride):
|
|
||||||
seq = feature_data[start:start + seq_len]
|
|
||||||
all_sequences.append(seq)
|
|
||||||
|
|
||||||
if not all_sequences:
|
|
||||||
return np.array([])
|
|
||||||
|
|
||||||
sequences = np.array(all_sequences)
|
|
||||||
print(f" 构建序列: {len(sequences):,} 条 (来自 {n_concepts} 个概念)")
|
|
||||||
|
|
||||||
return sequences
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据集 ====================
|
|
||||||
|
|
||||||
class SequenceDataset(Dataset):
|
|
||||||
"""序列数据集(已经构建好的序列)"""
|
|
||||||
|
|
||||||
def __init__(self, sequences: np.ndarray):
|
|
||||||
self.sequences = torch.FloatTensor(sequences)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.sequences)
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> torch.Tensor:
|
|
||||||
return self.sequences[idx]
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 训练器 ====================
|
|
||||||
|
|
||||||
class EarlyStopping:
|
|
||||||
"""早停机制"""
|
|
||||||
|
|
||||||
def __init__(self, patience: int = 10, min_delta: float = 1e-6):
|
|
||||||
self.patience = patience
|
|
||||||
self.min_delta = min_delta
|
|
||||||
self.counter = 0
|
|
||||||
self.best_loss = float('inf')
|
|
||||||
self.early_stop = False
|
|
||||||
|
|
||||||
def __call__(self, val_loss: float) -> bool:
|
|
||||||
if val_loss < self.best_loss - self.min_delta:
|
|
||||||
self.best_loss = val_loss
|
|
||||||
self.counter = 0
|
|
||||||
else:
|
|
||||||
self.counter += 1
|
|
||||||
if self.counter >= self.patience:
|
|
||||||
self.early_stop = True
|
|
||||||
|
|
||||||
return self.early_stop
|
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
|
||||||
"""模型训练器(支持 AMP 混合精度加速)"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
train_loader: DataLoader,
|
|
||||||
val_loader: DataLoader,
|
|
||||||
config: Dict,
|
|
||||||
device: torch.device,
|
|
||||||
save_dir: str = 'ml/checkpoints'
|
|
||||||
):
|
|
||||||
self.model = model.to(device)
|
|
||||||
self.train_loader = train_loader
|
|
||||||
self.val_loader = val_loader
|
|
||||||
self.config = config
|
|
||||||
self.device = device
|
|
||||||
self.save_dir = Path(save_dir)
|
|
||||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# 优化器
|
|
||||||
self.optimizer = AdamW(
|
|
||||||
model.parameters(),
|
|
||||||
lr=config['learning_rate'],
|
|
||||||
weight_decay=config['weight_decay']
|
|
||||||
)
|
|
||||||
|
|
||||||
# 学习率调度器
|
|
||||||
self.scheduler = CosineAnnealingWarmRestarts(
|
|
||||||
self.optimizer,
|
|
||||||
T_0=10,
|
|
||||||
T_mult=2,
|
|
||||||
eta_min=1e-6
|
|
||||||
)
|
|
||||||
|
|
||||||
# 损失函数(简化版,只用 MSE)
|
|
||||||
self.criterion = AnomalyDetectionLoss()
|
|
||||||
|
|
||||||
# 早停
|
|
||||||
self.early_stopping = EarlyStopping(
|
|
||||||
patience=config['patience'],
|
|
||||||
min_delta=config['min_delta']
|
|
||||||
)
|
|
||||||
|
|
||||||
# AMP 混合精度训练(大幅提速 + 省显存)
|
|
||||||
self.use_amp = torch.cuda.is_available()
|
|
||||||
self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
|
|
||||||
if self.use_amp:
|
|
||||||
print(" ✓ 启用 AMP 混合精度训练")
|
|
||||||
|
|
||||||
# 训练历史
|
|
||||||
self.history = {
|
|
||||||
'train_loss': [],
|
|
||||||
'val_loss': [],
|
|
||||||
'learning_rate': [],
|
|
||||||
}
|
|
||||||
|
|
||||||
self.best_val_loss = float('inf')
|
|
||||||
|
|
||||||
def train_epoch(self) -> float:
|
|
||||||
"""训练一个 epoch(使用 AMP 混合精度)"""
|
|
||||||
self.model.train()
|
|
||||||
total_loss = 0.0
|
|
||||||
n_batches = 0
|
|
||||||
|
|
||||||
pbar = tqdm(self.train_loader, desc="Training", leave=False)
|
|
||||||
for batch in pbar:
|
|
||||||
batch = batch.to(self.device, non_blocking=True) # 异步传输
|
|
||||||
|
|
||||||
self.optimizer.zero_grad(set_to_none=True) # 更快的梯度清零
|
|
||||||
|
|
||||||
# AMP 混合精度前向传播
|
|
||||||
if self.use_amp:
|
|
||||||
with torch.cuda.amp.autocast():
|
|
||||||
output, latent = self.model(batch)
|
|
||||||
loss, loss_dict = self.criterion(output, batch, latent)
|
|
||||||
|
|
||||||
# AMP 反向传播
|
|
||||||
self.scaler.scale(loss).backward()
|
|
||||||
|
|
||||||
# 梯度裁剪(需要 unscale)
|
|
||||||
self.scaler.unscale_(self.optimizer)
|
|
||||||
torch.nn.utils.clip_grad_norm_(
|
|
||||||
self.model.parameters(),
|
|
||||||
self.config['gradient_clip']
|
|
||||||
)
|
|
||||||
|
|
||||||
self.scaler.step(self.optimizer)
|
|
||||||
self.scaler.update()
|
|
||||||
else:
|
|
||||||
# 非 AMP 模式
|
|
||||||
output, latent = self.model(batch)
|
|
||||||
loss, loss_dict = self.criterion(output, batch, latent)
|
|
||||||
|
|
||||||
loss.backward()
|
|
||||||
torch.nn.utils.clip_grad_norm_(
|
|
||||||
self.model.parameters(),
|
|
||||||
self.config['gradient_clip']
|
|
||||||
)
|
|
||||||
self.optimizer.step()
|
|
||||||
|
|
||||||
total_loss += loss.item()
|
|
||||||
n_batches += 1
|
|
||||||
|
|
||||||
pbar.set_postfix({'loss': f"{loss.item():.4f}"})
|
|
||||||
|
|
||||||
return total_loss / n_batches
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def validate(self) -> float:
|
|
||||||
"""验证(使用 AMP)"""
|
|
||||||
self.model.eval()
|
|
||||||
total_loss = 0.0
|
|
||||||
n_batches = 0
|
|
||||||
|
|
||||||
for batch in self.val_loader:
|
|
||||||
batch = batch.to(self.device, non_blocking=True)
|
|
||||||
|
|
||||||
if self.use_amp:
|
|
||||||
with torch.cuda.amp.autocast():
|
|
||||||
output, latent = self.model(batch)
|
|
||||||
loss, _ = self.criterion(output, batch, latent)
|
|
||||||
else:
|
|
||||||
output, latent = self.model(batch)
|
|
||||||
loss, _ = self.criterion(output, batch, latent)
|
|
||||||
|
|
||||||
total_loss += loss.item()
|
|
||||||
n_batches += 1
|
|
||||||
|
|
||||||
return total_loss / n_batches
|
|
||||||
|
|
||||||
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
|
|
||||||
"""保存检查点"""
|
|
||||||
# 处理 DataParallel 包装
|
|
||||||
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
|
|
||||||
|
|
||||||
checkpoint = {
|
|
||||||
'epoch': epoch,
|
|
||||||
'model_state_dict': model_to_save.state_dict(),
|
|
||||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
|
||||||
'scheduler_state_dict': self.scheduler.state_dict(),
|
|
||||||
'val_loss': val_loss,
|
|
||||||
'config': self.config,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 保存最新检查点
|
|
||||||
torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')
|
|
||||||
|
|
||||||
# 保存最佳模型
|
|
||||||
if is_best:
|
|
||||||
torch.save(checkpoint, self.save_dir / 'best_model.pt')
|
|
||||||
print(f" ✓ 保存最佳模型 (val_loss: {val_loss:.6f})")
|
|
||||||
|
|
||||||
def train(self, epochs: int):
|
|
||||||
"""完整训练流程"""
|
|
||||||
print(f"\n开始训练 ({epochs} epochs)...")
|
|
||||||
print(f"设备: {self.device}")
|
|
||||||
print(f"模型参数量: {count_parameters(self.model):,}")
|
|
||||||
|
|
||||||
for epoch in range(1, epochs + 1):
|
|
||||||
print(f"\nEpoch {epoch}/{epochs}")
|
|
||||||
|
|
||||||
# 训练
|
|
||||||
train_loss = self.train_epoch()
|
|
||||||
|
|
||||||
# 验证
|
|
||||||
val_loss = self.validate()
|
|
||||||
|
|
||||||
# 更新学习率
|
|
||||||
self.scheduler.step()
|
|
||||||
current_lr = self.optimizer.param_groups[0]['lr']
|
|
||||||
|
|
||||||
# 记录历史
|
|
||||||
self.history['train_loss'].append(train_loss)
|
|
||||||
self.history['val_loss'].append(val_loss)
|
|
||||||
self.history['learning_rate'].append(current_lr)
|
|
||||||
|
|
||||||
# 打印进度
|
|
||||||
print(f" Train Loss: {train_loss:.6f}")
|
|
||||||
print(f" Val Loss: {val_loss:.6f}")
|
|
||||||
print(f" LR: {current_lr:.2e}")
|
|
||||||
|
|
||||||
# 保存检查点
|
|
||||||
is_best = val_loss < self.best_val_loss
|
|
||||||
if is_best:
|
|
||||||
self.best_val_loss = val_loss
|
|
||||||
self.save_checkpoint(epoch, val_loss, is_best)
|
|
||||||
|
|
||||||
# 早停检查
|
|
||||||
if self.early_stopping(val_loss):
|
|
||||||
print(f"\n早停触发!验证损失已 {self.early_stopping.patience} 个 epoch 未改善")
|
|
||||||
break
|
|
||||||
|
|
||||||
print(f"\n训练完成!最佳验证损失: {self.best_val_loss:.6f}")
|
|
||||||
|
|
||||||
# 保存训练历史
|
|
||||||
self.save_history()
|
|
||||||
|
|
||||||
return self.history
|
|
||||||
|
|
||||||
def save_history(self):
|
|
||||||
"""保存训练历史"""
|
|
||||||
history_path = self.save_dir / 'training_history.json'
|
|
||||||
with open(history_path, 'w') as f:
|
|
||||||
json.dump(self.history, f, indent=2)
|
|
||||||
print(f"训练历史已保存: {history_path}")
|
|
||||||
|
|
||||||
# 绘制训练曲线
|
|
||||||
self.plot_training_curves()
|
|
||||||
|
|
||||||
def plot_training_curves(self):
|
|
||||||
"""绘制训练曲线"""
|
|
||||||
if not HAS_MATPLOTLIB:
|
|
||||||
print("matplotlib 未安装,跳过绘图")
|
|
||||||
return
|
|
||||||
|
|
||||||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
|
||||||
|
|
||||||
epochs = range(1, len(self.history['train_loss']) + 1)
|
|
||||||
|
|
||||||
# 1. Loss 曲线
|
|
||||||
ax1 = axes[0]
|
|
||||||
ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
|
|
||||||
ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
|
|
||||||
ax1.set_xlabel('Epoch', fontsize=12)
|
|
||||||
ax1.set_ylabel('Loss', fontsize=12)
|
|
||||||
ax1.set_title('Training & Validation Loss', fontsize=14)
|
|
||||||
ax1.legend(fontsize=11)
|
|
||||||
ax1.grid(True, alpha=0.3)
|
|
||||||
|
|
||||||
# 标记最佳点
|
|
||||||
best_epoch = np.argmin(self.history['val_loss']) + 1
|
|
||||||
best_val_loss = min(self.history['val_loss'])
|
|
||||||
ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}')
|
|
||||||
ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)
|
|
||||||
ax1.annotate(f'Best: {best_val_loss:.6f}', xy=(best_epoch, best_val_loss),
|
|
||||||
xytext=(best_epoch + 2, best_val_loss + 0.0005),
|
|
||||||
fontsize=10, color='green')
|
|
||||||
|
|
||||||
# 2. 学习率曲线
|
|
||||||
ax2 = axes[1]
|
|
||||||
ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
|
|
||||||
ax2.set_xlabel('Epoch', fontsize=12)
|
|
||||||
ax2.set_ylabel('Learning Rate', fontsize=12)
|
|
||||||
ax2.set_title('Learning Rate Schedule', fontsize=14)
|
|
||||||
ax2.set_yscale('log')
|
|
||||||
ax2.grid(True, alpha=0.3)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
# 保存图片
|
|
||||||
plot_path = self.save_dir / 'training_curves.png'
|
|
||||||
plt.savefig(plot_path, dpi=150, bbox_inches='tight')
|
|
||||||
plt.close()
|
|
||||||
print(f"训练曲线已保存: {plot_path}")
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 阈值计算(使用验证集)====================
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def compute_thresholds(
|
|
||||||
model: nn.Module,
|
|
||||||
data_loader: DataLoader,
|
|
||||||
device: torch.device,
|
|
||||||
percentiles: List[float] = [90, 95, 99]
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
"""
|
|
||||||
在验证集上计算重构误差的百分位数阈值
|
|
||||||
|
|
||||||
注:使用验证集而非测试集,避免数据泄露
|
|
||||||
"""
|
|
||||||
model.eval()
|
|
||||||
all_errors = []
|
|
||||||
|
|
||||||
print("计算异动阈值(使用验证集)...")
|
|
||||||
for batch in tqdm(data_loader, desc="Computing thresholds"):
|
|
||||||
batch = batch.to(device)
|
|
||||||
errors = model.compute_reconstruction_error(batch, reduction='none')
|
|
||||||
|
|
||||||
# 取每个序列的最后一个时刻误差(预测当前时刻)
|
|
||||||
seq_errors = errors[:, -1] # (batch,)
|
|
||||||
all_errors.append(seq_errors.cpu().numpy())
|
|
||||||
|
|
||||||
all_errors = np.concatenate(all_errors)
|
|
||||||
|
|
||||||
thresholds = {}
|
|
||||||
for p in percentiles:
|
|
||||||
threshold = np.percentile(all_errors, p)
|
|
||||||
thresholds[f'p{p}'] = float(threshold)
|
|
||||||
print(f" P{p}: {threshold:.6f}")
|
|
||||||
|
|
||||||
# 额外统计
|
|
||||||
thresholds['mean'] = float(np.mean(all_errors))
|
|
||||||
thresholds['std'] = float(np.std(all_errors))
|
|
||||||
thresholds['median'] = float(np.median(all_errors))
|
|
||||||
|
|
||||||
print(f" Mean: {thresholds['mean']:.6f}")
|
|
||||||
print(f" Median: {thresholds['median']:.6f}")
|
|
||||||
print(f" Std: {thresholds['std']:.6f}")
|
|
||||||
|
|
||||||
return thresholds
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 主函数 ====================
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description='训练概念异动检测模型')
|
|
||||||
parser.add_argument('--data_dir', type=str, default='ml/data',
|
|
||||||
help='数据目录路径')
|
|
||||||
parser.add_argument('--epochs', type=int, default=100,
|
|
||||||
help='训练轮数')
|
|
||||||
parser.add_argument('--batch_size', type=int, default=4096,
|
|
||||||
help='批次大小(4x RTX 4090 推荐 4096~8192)')
|
|
||||||
parser.add_argument('--lr', type=float, default=3e-4,
|
|
||||||
help='学习率(大 batch 推荐 3e-4)')
|
|
||||||
parser.add_argument('--device', type=str, default='auto',
|
|
||||||
help='设备 (auto/cuda/cpu)')
|
|
||||||
parser.add_argument('--save_dir', type=str, default='ml/checkpoints',
|
|
||||||
help='模型保存目录')
|
|
||||||
parser.add_argument('--train_end', type=str, default='2024-06-30',
|
|
||||||
help='训练集截止日期')
|
|
||||||
parser.add_argument('--val_end', type=str, default='2024-09-30',
|
|
||||||
help='验证集截止日期')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# 更新配置
|
|
||||||
config = TRAIN_CONFIG.copy()
|
|
||||||
config['batch_size'] = args.batch_size
|
|
||||||
config['epochs'] = args.epochs
|
|
||||||
config['learning_rate'] = args.lr
|
|
||||||
config['train_end_date'] = args.train_end
|
|
||||||
config['val_end_date'] = args.val_end
|
|
||||||
|
|
||||||
# 设备选择
|
|
||||||
if args.device == 'auto':
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
else:
|
|
||||||
device = torch.device(args.device)
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("概念异动检测模型训练(修复版)")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"配置:")
|
|
||||||
print(f" 数据目录: {args.data_dir}")
|
|
||||||
print(f" 设备: {device}")
|
|
||||||
print(f" 批次大小: {config['batch_size']}")
|
|
||||||
print(f" 学习率: {config['learning_rate']}")
|
|
||||||
print(f" 训练轮数: {config['epochs']}")
|
|
||||||
print(f" 训练集截止: {config['train_end_date']}")
|
|
||||||
print(f" 验证集截止: {config['val_end_date']}")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# 1. 按日期加载数据
|
|
||||||
print("\n[1/6] 加载数据...")
|
|
||||||
date_data = load_data_by_date(args.data_dir, config['features'])
|
|
||||||
|
|
||||||
# 2. 按日期划分
|
|
||||||
print("\n[2/6] 按日期划分数据集...")
|
|
||||||
train_data, val_data, test_data = split_data_by_date(
|
|
||||||
date_data,
|
|
||||||
config['train_end_date'],
|
|
||||||
config['val_end_date']
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. 按概念构建序列
|
|
||||||
print("\n[3/6] 按概念构建序列...")
|
|
||||||
print("训练集:")
|
|
||||||
train_sequences = build_sequences_by_concept(
|
|
||||||
train_data, config['features'], config['seq_len'], config['stride']
|
|
||||||
)
|
|
||||||
print("验证集:")
|
|
||||||
val_sequences = build_sequences_by_concept(
|
|
||||||
val_data, config['features'], config['seq_len'], config['stride']
|
|
||||||
)
|
|
||||||
print("测试集:")
|
|
||||||
test_sequences = build_sequences_by_concept(
|
|
||||||
test_data, config['features'], config['seq_len'], config['stride']
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(train_sequences) == 0:
|
|
||||||
print("错误: 训练集为空!请检查数据和日期范围")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 4. 数据预处理(简单截断极端值,标准化在模型内部通过 Instance Norm 完成)
|
|
||||||
print("\n[4/6] 数据预处理...")
|
|
||||||
print(" 注意: 使用 Instance Norm,每个序列在模型内部单独标准化")
|
|
||||||
print(" 这样可以处理不同概念波动率差异(银行 vs 半导体)")
|
|
||||||
|
|
||||||
clip_value = config['clip_value']
|
|
||||||
print(f" 截断极端值: ±{clip_value}")
|
|
||||||
|
|
||||||
# 简单截断极端值(防止异常数据影响训练)
|
|
||||||
train_sequences = np.clip(train_sequences, -clip_value, clip_value)
|
|
||||||
if len(val_sequences) > 0:
|
|
||||||
val_sequences = np.clip(val_sequences, -clip_value, clip_value)
|
|
||||||
if len(test_sequences) > 0:
|
|
||||||
test_sequences = np.clip(test_sequences, -clip_value, clip_value)
|
|
||||||
|
|
||||||
# 保存配置
|
|
||||||
save_dir = Path(args.save_dir)
|
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
preprocess_params = {
|
|
||||||
'features': config['features'],
|
|
||||||
'normalization': 'instance_norm', # 在模型内部完成
|
|
||||||
'clip_value': clip_value,
|
|
||||||
'note': '标准化在模型内部通过 InstanceNorm1d 完成,无需外部 Scaler'
|
|
||||||
}
|
|
||||||
|
|
||||||
with open(save_dir / 'normalization_stats.json', 'w') as f:
|
|
||||||
json.dump(preprocess_params, f, indent=2)
|
|
||||||
print(f" 预处理参数已保存")
|
|
||||||
|
|
||||||
# 5. 创建数据集和加载器
|
|
||||||
print("\n[5/6] 创建数据加载器...")
|
|
||||||
train_dataset = SequenceDataset(train_sequences)
|
|
||||||
val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None
|
|
||||||
test_dataset = SequenceDataset(test_sequences) if len(test_sequences) > 0 else None
|
|
||||||
|
|
||||||
print(f" 训练序列: {len(train_dataset):,}")
|
|
||||||
print(f" 验证序列: {len(val_dataset) if val_dataset else 0:,}")
|
|
||||||
print(f" 测试序列: {len(test_dataset) if test_dataset else 0:,}")
|
|
||||||
|
|
||||||
# 多卡时增加 num_workers(Linux 上可以用更多)
|
|
||||||
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
|
|
||||||
num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0
|
|
||||||
print(f" DataLoader workers: {num_workers}")
|
|
||||||
print(f" Batch size: {config['batch_size']}")
|
|
||||||
|
|
||||||
# 大 batch + 多 worker + prefetch 提速
|
|
||||||
train_loader = DataLoader(
|
|
||||||
train_dataset,
|
|
||||||
batch_size=config['batch_size'],
|
|
||||||
shuffle=True,
|
|
||||||
num_workers=num_workers,
|
|
||||||
pin_memory=True,
|
|
||||||
prefetch_factor=4 if num_workers > 0 else None, # 预取更多 batch
|
|
||||||
persistent_workers=True if num_workers > 0 else False, # 保持 worker 存活
|
|
||||||
drop_last=True # 丢弃不完整的最后一批,避免 batch 大小不一致
|
|
||||||
)
|
|
||||||
|
|
||||||
val_loader = DataLoader(
|
|
||||||
val_dataset,
|
|
||||||
batch_size=config['batch_size'] * 2, # 验证时可以用更大 batch(无梯度)
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=num_workers,
|
|
||||||
pin_memory=True,
|
|
||||||
prefetch_factor=4 if num_workers > 0 else None,
|
|
||||||
persistent_workers=True if num_workers > 0 else False,
|
|
||||||
) if val_dataset else None
|
|
||||||
|
|
||||||
test_loader = DataLoader(
|
|
||||||
test_dataset,
|
|
||||||
batch_size=config['batch_size'] * 2,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=num_workers,
|
|
||||||
pin_memory=True,
|
|
||||||
prefetch_factor=4 if num_workers > 0 else None,
|
|
||||||
persistent_workers=True if num_workers > 0 else False,
|
|
||||||
) if test_dataset else None
|
|
||||||
|
|
||||||
# 6. 训练
|
|
||||||
print("\n[6/6] 训练模型...")
|
|
||||||
model_config = config['model'].copy()
|
|
||||||
model = TransformerAutoencoder(**model_config)
|
|
||||||
|
|
||||||
# 多卡并行
|
|
||||||
if torch.cuda.device_count() > 1:
|
|
||||||
print(f" 使用 {torch.cuda.device_count()} 张 GPU 并行训练")
|
|
||||||
model = nn.DataParallel(model)
|
|
||||||
|
|
||||||
if val_loader is None:
|
|
||||||
print("警告: 验证集为空,将使用训练集的一部分作为验证")
|
|
||||||
# 简单处理:用训练集的后 10% 作为验证
|
|
||||||
split_idx = int(len(train_dataset) * 0.9)
|
|
||||||
train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
|
|
||||||
val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))
|
|
||||||
|
|
||||||
train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
|
|
||||||
val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)
|
|
||||||
|
|
||||||
trainer = Trainer(
|
|
||||||
model=model,
|
|
||||||
train_loader=train_loader,
|
|
||||||
val_loader=val_loader,
|
|
||||||
config=config,
|
|
||||||
device=device,
|
|
||||||
save_dir=args.save_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
history = trainer.train(config['epochs'])
|
|
||||||
|
|
||||||
# 7. 计算阈值(使用验证集)
|
|
||||||
print("\n[额外] 计算异动阈值...")
|
|
||||||
|
|
||||||
# 加载最佳模型
|
|
||||||
best_checkpoint = torch.load(
|
|
||||||
save_dir / 'best_model.pt',
|
|
||||||
map_location=device
|
|
||||||
)
|
|
||||||
model.load_state_dict(best_checkpoint['model_state_dict'])
|
|
||||||
model.to(device)
|
|
||||||
|
|
||||||
# 使用验证集计算阈值(避免数据泄露)
|
|
||||||
thresholds = compute_thresholds(
|
|
||||||
model,
|
|
||||||
val_loader,
|
|
||||||
device,
|
|
||||||
config['threshold_percentiles']
|
|
||||||
)
|
|
||||||
|
|
||||||
# 保存阈值
|
|
||||||
with open(save_dir / 'thresholds.json', 'w') as f:
|
|
||||||
json.dump(thresholds, f, indent=2)
|
|
||||||
print(f"阈值已保存")
|
|
||||||
|
|
||||||
# 保存完整配置
|
|
||||||
with open(save_dir / 'config.json', 'w') as f:
|
|
||||||
json.dump(config, f, indent=2)
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("训练完成!")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"模型保存位置: {args.save_dir}")
|
|
||||||
print(f" - best_model.pt: 最佳模型权重")
|
|
||||||
print(f" - thresholds.json: 异动阈值")
|
|
||||||
print(f" - normalization_stats.json: 标准化参数")
|
|
||||||
print(f" - config.json: 训练配置")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
622
ml/train_v2.py
622
ml/train_v2.py
@@ -1,622 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
训练脚本 V2 - 基于 Z-Score 特征的 LSTM Autoencoder
|
|
||||||
|
|
||||||
改进点:
|
|
||||||
1. 使用 Z-Score 特征(相对于同时间片历史的偏离)
|
|
||||||
2. 短序列:10分钟(不需要30分钟预热)
|
|
||||||
3. 开盘即可检测:9:30 直接有特征
|
|
||||||
|
|
||||||
模型输入:
|
|
||||||
- 过去10分钟的 Z-Score 特征序列
|
|
||||||
- 特征:alpha_zscore, amt_zscore, rank_zscore, momentum_3m, momentum_5m, limit_up_ratio
|
|
||||||
|
|
||||||
模型学习:
|
|
||||||
- 学习 Z-Score 序列的"正常演化模式"
|
|
||||||
- 异动 = Z-Score 序列的异常演化(重构误差大)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Tuple, Dict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
from torch.optim import AdamW
|
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters
|
|
||||||
|
|
||||||
# 性能优化
|
|
||||||
torch.backends.cudnn.benchmark = True
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
|
||||||
|
|
||||||
try:
|
|
||||||
import matplotlib
|
|
||||||
matplotlib.use('Agg')
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
HAS_MATPLOTLIB = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_MATPLOTLIB = False
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 配置 ====================
|
|
||||||
|
|
||||||
TRAIN_CONFIG = {
|
|
||||||
# 数据配置(改进!)
|
|
||||||
'seq_len': 10, # 10分钟序列(不是30分钟!)
|
|
||||||
'stride': 2, # 步长2分钟
|
|
||||||
|
|
||||||
# 时间切分
|
|
||||||
'train_end_date': '2024-06-30',
|
|
||||||
'val_end_date': '2024-09-30',
|
|
||||||
|
|
||||||
# V2 特征(Z-Score 为主)
|
|
||||||
'features': [
|
|
||||||
'alpha_zscore', # Alpha 的 Z-Score
|
|
||||||
'amt_zscore', # 成交额的 Z-Score
|
|
||||||
'rank_zscore', # 排名的 Z-Score
|
|
||||||
'momentum_3m', # 3分钟动量
|
|
||||||
'momentum_5m', # 5分钟动量
|
|
||||||
'limit_up_ratio', # 涨停占比
|
|
||||||
],
|
|
||||||
|
|
||||||
# 训练配置
|
|
||||||
'batch_size': 4096,
|
|
||||||
'epochs': 100,
|
|
||||||
'learning_rate': 3e-4,
|
|
||||||
'weight_decay': 1e-5,
|
|
||||||
'gradient_clip': 1.0,
|
|
||||||
|
|
||||||
# 早停配置
|
|
||||||
'patience': 15,
|
|
||||||
'min_delta': 1e-6,
|
|
||||||
|
|
||||||
# 模型配置(小型 LSTM)
|
|
||||||
'model': {
|
|
||||||
'n_features': 6,
|
|
||||||
'hidden_dim': 32,
|
|
||||||
'latent_dim': 4,
|
|
||||||
'num_layers': 1,
|
|
||||||
'dropout': 0.2,
|
|
||||||
'bidirectional': True,
|
|
||||||
},
|
|
||||||
|
|
||||||
# 标准化配置
|
|
||||||
'clip_value': 5.0, # Z-Score 已经标准化,clip 5.0 足够
|
|
||||||
|
|
||||||
# 阈值配置
|
|
||||||
'threshold_percentiles': [90, 95, 99],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据加载 ====================
|
|
||||||
|
|
||||||
def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
|
|
||||||
"""按日期加载 V2 数据"""
|
|
||||||
data_path = Path(data_dir)
|
|
||||||
parquet_files = sorted(data_path.glob("features_v2_*.parquet"))
|
|
||||||
|
|
||||||
if not parquet_files:
|
|
||||||
raise FileNotFoundError(f"未找到 V2 数据文件: {data_dir}")
|
|
||||||
|
|
||||||
print(f"找到 {len(parquet_files)} 个 V2 数据文件")
|
|
||||||
|
|
||||||
date_data = {}
|
|
||||||
|
|
||||||
for pf in tqdm(parquet_files, desc="加载数据"):
|
|
||||||
date = pf.stem.replace('features_v2_', '')
|
|
||||||
|
|
||||||
df = pd.read_parquet(pf)
|
|
||||||
|
|
||||||
required_cols = features + ['concept_id', 'timestamp']
|
|
||||||
missing_cols = [c for c in required_cols if c not in df.columns]
|
|
||||||
if missing_cols:
|
|
||||||
print(f"警告: {date} 缺少列: {missing_cols}, 跳过")
|
|
||||||
continue
|
|
||||||
|
|
||||||
date_data[date] = df
|
|
||||||
|
|
||||||
print(f"成功加载 {len(date_data)} 天的数据")
|
|
||||||
return date_data
|
|
||||||
|
|
||||||
|
|
||||||
def split_data_by_date(
|
|
||||||
date_data: Dict[str, pd.DataFrame],
|
|
||||||
train_end: str,
|
|
||||||
val_end: str
|
|
||||||
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
|
|
||||||
"""按日期划分数据集"""
|
|
||||||
train_data = {}
|
|
||||||
val_data = {}
|
|
||||||
test_data = {}
|
|
||||||
|
|
||||||
for date, df in date_data.items():
|
|
||||||
if date <= train_end:
|
|
||||||
train_data[date] = df
|
|
||||||
elif date <= val_end:
|
|
||||||
val_data[date] = df
|
|
||||||
else:
|
|
||||||
test_data[date] = df
|
|
||||||
|
|
||||||
print(f"数据集划分:")
|
|
||||||
print(f" 训练集: {len(train_data)} 天 (<= {train_end})")
|
|
||||||
print(f" 验证集: {len(val_data)} 天 ({train_end} ~ {val_end})")
|
|
||||||
print(f" 测试集: {len(test_data)} 天 (> {val_end})")
|
|
||||||
|
|
||||||
return train_data, val_data, test_data
|
|
||||||
|
|
||||||
|
|
||||||
def build_sequences_by_concept(
|
|
||||||
date_data: Dict[str, pd.DataFrame],
|
|
||||||
features: List[str],
|
|
||||||
seq_len: int,
|
|
||||||
stride: int
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""按概念分组构建序列"""
|
|
||||||
all_dfs = []
|
|
||||||
for date, df in sorted(date_data.items()):
|
|
||||||
df = df.copy()
|
|
||||||
df['date'] = date
|
|
||||||
all_dfs.append(df)
|
|
||||||
|
|
||||||
if not all_dfs:
|
|
||||||
return np.array([])
|
|
||||||
|
|
||||||
combined = pd.concat(all_dfs, ignore_index=True)
|
|
||||||
combined = combined.sort_values(['concept_id', 'date', 'timestamp'])
|
|
||||||
|
|
||||||
all_sequences = []
|
|
||||||
grouped = combined.groupby('concept_id', sort=False)
|
|
||||||
n_concepts = len(grouped)
|
|
||||||
|
|
||||||
for concept_id, concept_df in tqdm(grouped, desc="构建序列", total=n_concepts, leave=False):
|
|
||||||
feature_data = concept_df[features].values
|
|
||||||
feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0)
|
|
||||||
|
|
||||||
n_points = len(feature_data)
|
|
||||||
for start in range(0, n_points - seq_len + 1, stride):
|
|
||||||
seq = feature_data[start:start + seq_len]
|
|
||||||
all_sequences.append(seq)
|
|
||||||
|
|
||||||
if not all_sequences:
|
|
||||||
return np.array([])
|
|
||||||
|
|
||||||
sequences = np.array(all_sequences)
|
|
||||||
print(f" 构建序列: {len(sequences):,} 条 (来自 {n_concepts} 个概念)")
|
|
||||||
|
|
||||||
return sequences
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 数据集 ====================
|
|
||||||
|
|
||||||
class SequenceDataset(Dataset):
|
|
||||||
def __init__(self, sequences: np.ndarray):
|
|
||||||
self.sequences = torch.FloatTensor(sequences)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.sequences)
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> torch.Tensor:
|
|
||||||
return self.sequences[idx]
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 训练器 ====================
|
|
||||||
|
|
||||||
class EarlyStopping:
|
|
||||||
def __init__(self, patience: int = 10, min_delta: float = 1e-6):
|
|
||||||
self.patience = patience
|
|
||||||
self.min_delta = min_delta
|
|
||||||
self.counter = 0
|
|
||||||
self.best_loss = float('inf')
|
|
||||||
self.early_stop = False
|
|
||||||
|
|
||||||
def __call__(self, val_loss: float) -> bool:
|
|
||||||
if val_loss < self.best_loss - self.min_delta:
|
|
||||||
self.best_loss = val_loss
|
|
||||||
self.counter = 0
|
|
||||||
else:
|
|
||||||
self.counter += 1
|
|
||||||
if self.counter >= self.patience:
|
|
||||||
self.early_stop = True
|
|
||||||
return self.early_stop
|
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
train_loader: DataLoader,
|
|
||||||
val_loader: DataLoader,
|
|
||||||
config: Dict,
|
|
||||||
device: torch.device,
|
|
||||||
save_dir: str = 'ml/checkpoints_v2'
|
|
||||||
):
|
|
||||||
self.model = model.to(device)
|
|
||||||
self.train_loader = train_loader
|
|
||||||
self.val_loader = val_loader
|
|
||||||
self.config = config
|
|
||||||
self.device = device
|
|
||||||
self.save_dir = Path(save_dir)
|
|
||||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
self.optimizer = AdamW(
|
|
||||||
model.parameters(),
|
|
||||||
lr=config['learning_rate'],
|
|
||||||
weight_decay=config['weight_decay']
|
|
||||||
)
|
|
||||||
|
|
||||||
self.scheduler = CosineAnnealingWarmRestarts(
|
|
||||||
self.optimizer, T_0=10, T_mult=2, eta_min=1e-6
|
|
||||||
)
|
|
||||||
|
|
||||||
self.criterion = AnomalyDetectionLoss()
|
|
||||||
|
|
||||||
self.early_stopping = EarlyStopping(
|
|
||||||
patience=config['patience'],
|
|
||||||
min_delta=config['min_delta']
|
|
||||||
)
|
|
||||||
|
|
||||||
self.use_amp = torch.cuda.is_available()
|
|
||||||
self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
|
|
||||||
if self.use_amp:
|
|
||||||
print(" ✓ 启用 AMP 混合精度训练")
|
|
||||||
|
|
||||||
self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
|
|
||||||
self.best_val_loss = float('inf')
|
|
||||||
|
|
||||||
def train_epoch(self) -> float:
|
|
||||||
self.model.train()
|
|
||||||
total_loss = 0.0
|
|
||||||
n_batches = 0
|
|
||||||
|
|
||||||
pbar = tqdm(self.train_loader, desc="Training", leave=False)
|
|
||||||
for batch in pbar:
|
|
||||||
batch = batch.to(self.device, non_blocking=True)
|
|
||||||
self.optimizer.zero_grad(set_to_none=True)
|
|
||||||
|
|
||||||
if self.use_amp:
|
|
||||||
with torch.cuda.amp.autocast():
|
|
||||||
output, latent = self.model(batch)
|
|
||||||
loss, _ = self.criterion(output, batch, latent)
|
|
||||||
|
|
||||||
self.scaler.scale(loss).backward()
|
|
||||||
self.scaler.unscale_(self.optimizer)
|
|
||||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
|
|
||||||
self.scaler.step(self.optimizer)
|
|
||||||
self.scaler.update()
|
|
||||||
else:
|
|
||||||
output, latent = self.model(batch)
|
|
||||||
loss, _ = self.criterion(output, batch, latent)
|
|
||||||
loss.backward()
|
|
||||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
|
|
||||||
self.optimizer.step()
|
|
||||||
|
|
||||||
total_loss += loss.item()
|
|
||||||
n_batches += 1
|
|
||||||
pbar.set_postfix({'loss': f"{loss.item():.4f}"})
|
|
||||||
|
|
||||||
return total_loss / n_batches
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def validate(self) -> float:
|
|
||||||
self.model.eval()
|
|
||||||
total_loss = 0.0
|
|
||||||
n_batches = 0
|
|
||||||
|
|
||||||
for batch in self.val_loader:
|
|
||||||
batch = batch.to(self.device, non_blocking=True)
|
|
||||||
|
|
||||||
if self.use_amp:
|
|
||||||
with torch.cuda.amp.autocast():
|
|
||||||
output, latent = self.model(batch)
|
|
||||||
loss, _ = self.criterion(output, batch, latent)
|
|
||||||
else:
|
|
||||||
output, latent = self.model(batch)
|
|
||||||
loss, _ = self.criterion(output, batch, latent)
|
|
||||||
|
|
||||||
total_loss += loss.item()
|
|
||||||
n_batches += 1
|
|
||||||
|
|
||||||
return total_loss / n_batches
|
|
||||||
|
|
||||||
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
|
|
||||||
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
|
|
||||||
|
|
||||||
checkpoint = {
|
|
||||||
'epoch': epoch,
|
|
||||||
'model_state_dict': model_to_save.state_dict(),
|
|
||||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
|
||||||
'scheduler_state_dict': self.scheduler.state_dict(),
|
|
||||||
'val_loss': val_loss,
|
|
||||||
'config': self.config,
|
|
||||||
}
|
|
||||||
|
|
||||||
torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')
|
|
||||||
|
|
||||||
if is_best:
|
|
||||||
torch.save(checkpoint, self.save_dir / 'best_model.pt')
|
|
||||||
print(f" ✓ 保存最佳模型 (val_loss: {val_loss:.6f})")
|
|
||||||
|
|
||||||
def train(self, epochs: int):
|
|
||||||
print(f"\n开始训练 ({epochs} epochs)...")
|
|
||||||
print(f"设备: {self.device}")
|
|
||||||
print(f"模型参数量: {count_parameters(self.model):,}")
|
|
||||||
|
|
||||||
for epoch in range(1, epochs + 1):
|
|
||||||
print(f"\nEpoch {epoch}/{epochs}")
|
|
||||||
|
|
||||||
train_loss = self.train_epoch()
|
|
||||||
val_loss = self.validate()
|
|
||||||
|
|
||||||
self.scheduler.step()
|
|
||||||
current_lr = self.optimizer.param_groups[0]['lr']
|
|
||||||
|
|
||||||
self.history['train_loss'].append(train_loss)
|
|
||||||
self.history['val_loss'].append(val_loss)
|
|
||||||
self.history['learning_rate'].append(current_lr)
|
|
||||||
|
|
||||||
print(f" Train Loss: {train_loss:.6f}")
|
|
||||||
print(f" Val Loss: {val_loss:.6f}")
|
|
||||||
print(f" LR: {current_lr:.2e}")
|
|
||||||
|
|
||||||
is_best = val_loss < self.best_val_loss
|
|
||||||
if is_best:
|
|
||||||
self.best_val_loss = val_loss
|
|
||||||
self.save_checkpoint(epoch, val_loss, is_best)
|
|
||||||
|
|
||||||
if self.early_stopping(val_loss):
|
|
||||||
print(f"\n早停触发!")
|
|
||||||
break
|
|
||||||
|
|
||||||
print(f"\n训练完成!最佳验证损失: {self.best_val_loss:.6f}")
|
|
||||||
self.save_history()
|
|
||||||
|
|
||||||
return self.history
|
|
||||||
|
|
||||||
def save_history(self):
|
|
||||||
history_path = self.save_dir / 'training_history.json'
|
|
||||||
with open(history_path, 'w') as f:
|
|
||||||
json.dump(self.history, f, indent=2)
|
|
||||||
print(f"训练历史已保存: {history_path}")
|
|
||||||
|
|
||||||
if HAS_MATPLOTLIB:
|
|
||||||
self.plot_training_curves()
|
|
||||||
|
|
||||||
def plot_training_curves(self):
|
|
||||||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
|
||||||
epochs = range(1, len(self.history['train_loss']) + 1)
|
|
||||||
|
|
||||||
ax1 = axes[0]
|
|
||||||
ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
|
|
||||||
ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
|
|
||||||
ax1.set_xlabel('Epoch')
|
|
||||||
ax1.set_ylabel('Loss')
|
|
||||||
ax1.set_title('Training & Validation Loss (V2)')
|
|
||||||
ax1.legend()
|
|
||||||
ax1.grid(True, alpha=0.3)
|
|
||||||
|
|
||||||
best_epoch = np.argmin(self.history['val_loss']) + 1
|
|
||||||
best_val_loss = min(self.history['val_loss'])
|
|
||||||
ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7)
|
|
||||||
ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)
|
|
||||||
|
|
||||||
ax2 = axes[1]
|
|
||||||
ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
|
|
||||||
ax2.set_xlabel('Epoch')
|
|
||||||
ax2.set_ylabel('Learning Rate')
|
|
||||||
ax2.set_title('Learning Rate Schedule')
|
|
||||||
ax2.set_yscale('log')
|
|
||||||
ax2.grid(True, alpha=0.3)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig(self.save_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
|
|
||||||
plt.close()
|
|
||||||
print(f"训练曲线已保存")
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 阈值计算 ====================
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def compute_thresholds(
|
|
||||||
model: nn.Module,
|
|
||||||
data_loader: DataLoader,
|
|
||||||
device: torch.device,
|
|
||||||
percentiles: List[float] = [90, 95, 99]
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
"""在验证集上计算阈值"""
|
|
||||||
model.eval()
|
|
||||||
all_errors = []
|
|
||||||
|
|
||||||
print("计算异动阈值...")
|
|
||||||
for batch in tqdm(data_loader, desc="Computing thresholds"):
|
|
||||||
batch = batch.to(device)
|
|
||||||
errors = model.compute_reconstruction_error(batch, reduction='none')
|
|
||||||
seq_errors = errors[:, -1] # 最后一个时刻
|
|
||||||
all_errors.append(seq_errors.cpu().numpy())
|
|
||||||
|
|
||||||
all_errors = np.concatenate(all_errors)
|
|
||||||
|
|
||||||
thresholds = {}
|
|
||||||
for p in percentiles:
|
|
||||||
threshold = np.percentile(all_errors, p)
|
|
||||||
thresholds[f'p{p}'] = float(threshold)
|
|
||||||
print(f" P{p}: {threshold:.6f}")
|
|
||||||
|
|
||||||
thresholds['mean'] = float(np.mean(all_errors))
|
|
||||||
thresholds['std'] = float(np.std(all_errors))
|
|
||||||
thresholds['median'] = float(np.median(all_errors))
|
|
||||||
|
|
||||||
return thresholds
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 主函数 ====================
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description='训练 V2 模型')
|
|
||||||
parser.add_argument('--data_dir', type=str, default='ml/data_v2', help='V2 数据目录')
|
|
||||||
parser.add_argument('--epochs', type=int, default=100)
|
|
||||||
parser.add_argument('--batch_size', type=int, default=4096)
|
|
||||||
parser.add_argument('--lr', type=float, default=3e-4)
|
|
||||||
parser.add_argument('--device', type=str, default='auto')
|
|
||||||
parser.add_argument('--save_dir', type=str, default='ml/checkpoints_v2')
|
|
||||||
parser.add_argument('--train_end', type=str, default='2024-06-30')
|
|
||||||
parser.add_argument('--val_end', type=str, default='2024-09-30')
|
|
||||||
parser.add_argument('--seq_len', type=int, default=10, help='序列长度(分钟)')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
config = TRAIN_CONFIG.copy()
|
|
||||||
config['batch_size'] = args.batch_size
|
|
||||||
config['epochs'] = args.epochs
|
|
||||||
config['learning_rate'] = args.lr
|
|
||||||
config['train_end_date'] = args.train_end
|
|
||||||
config['val_end_date'] = args.val_end
|
|
||||||
config['seq_len'] = args.seq_len
|
|
||||||
|
|
||||||
if args.device == 'auto':
|
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
||||||
else:
|
|
||||||
device = torch.device(args.device)
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("概念异动检测模型训练 V2(Z-Score 特征)")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"数据目录: {args.data_dir}")
|
|
||||||
print(f"设备: {device}")
|
|
||||||
print(f"序列长度: {config['seq_len']} 分钟")
|
|
||||||
print(f"批次大小: {config['batch_size']}")
|
|
||||||
print(f"特征: {config['features']}")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# 1. 加载数据
|
|
||||||
print("\n[1/6] 加载 V2 数据...")
|
|
||||||
date_data = load_data_by_date(args.data_dir, config['features'])
|
|
||||||
|
|
||||||
# 2. 划分数据集
|
|
||||||
print("\n[2/6] 划分数据集...")
|
|
||||||
train_data, val_data, test_data = split_data_by_date(
|
|
||||||
date_data, config['train_end_date'], config['val_end_date']
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. 构建序列
|
|
||||||
print("\n[3/6] 构建序列...")
|
|
||||||
print("训练集:")
|
|
||||||
train_sequences = build_sequences_by_concept(
|
|
||||||
train_data, config['features'], config['seq_len'], config['stride']
|
|
||||||
)
|
|
||||||
print("验证集:")
|
|
||||||
val_sequences = build_sequences_by_concept(
|
|
||||||
val_data, config['features'], config['seq_len'], config['stride']
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(train_sequences) == 0:
|
|
||||||
print("错误: 训练集为空!")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 4. 预处理
|
|
||||||
print("\n[4/6] 数据预处理...")
|
|
||||||
clip_value = config['clip_value']
|
|
||||||
print(f" Z-Score 特征已标准化,截断: ±{clip_value}")
|
|
||||||
|
|
||||||
train_sequences = np.clip(train_sequences, -clip_value, clip_value)
|
|
||||||
if len(val_sequences) > 0:
|
|
||||||
val_sequences = np.clip(val_sequences, -clip_value, clip_value)
|
|
||||||
|
|
||||||
# 保存配置
|
|
||||||
save_dir = Path(args.save_dir)
|
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
with open(save_dir / 'config.json', 'w') as f:
|
|
||||||
json.dump(config, f, indent=2)
|
|
||||||
|
|
||||||
# 5. 创建数据加载器
|
|
||||||
print("\n[5/6] 创建数据加载器...")
|
|
||||||
train_dataset = SequenceDataset(train_sequences)
|
|
||||||
val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None
|
|
||||||
|
|
||||||
print(f" 训练序列: {len(train_dataset):,}")
|
|
||||||
print(f" 验证序列: {len(val_dataset) if val_dataset else 0:,}")
|
|
||||||
|
|
||||||
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
|
|
||||||
num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0
|
|
||||||
|
|
||||||
train_loader = DataLoader(
|
|
||||||
train_dataset,
|
|
||||||
batch_size=config['batch_size'],
|
|
||||||
shuffle=True,
|
|
||||||
num_workers=num_workers,
|
|
||||||
pin_memory=True,
|
|
||||||
prefetch_factor=4 if num_workers > 0 else None,
|
|
||||||
persistent_workers=True if num_workers > 0 else False,
|
|
||||||
drop_last=True
|
|
||||||
)
|
|
||||||
|
|
||||||
val_loader = DataLoader(
|
|
||||||
val_dataset,
|
|
||||||
batch_size=config['batch_size'] * 2,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=num_workers,
|
|
||||||
pin_memory=True,
|
|
||||||
) if val_dataset else None
|
|
||||||
|
|
||||||
# 6. 训练
|
|
||||||
print("\n[6/6] 训练模型...")
|
|
||||||
model = TransformerAutoencoder(**config['model'])
|
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
|
||||||
print(f" 使用 {torch.cuda.device_count()} 张 GPU 并行训练")
|
|
||||||
model = nn.DataParallel(model)
|
|
||||||
|
|
||||||
if val_loader is None:
|
|
||||||
print("警告: 验证集为空,使用训练集的 10% 作为验证")
|
|
||||||
split_idx = int(len(train_dataset) * 0.9)
|
|
||||||
train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
|
|
||||||
val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))
|
|
||||||
train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
|
|
||||||
val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)
|
|
||||||
|
|
||||||
trainer = Trainer(
|
|
||||||
model=model,
|
|
||||||
train_loader=train_loader,
|
|
||||||
val_loader=val_loader,
|
|
||||||
config=config,
|
|
||||||
device=device,
|
|
||||||
save_dir=args.save_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer.train(config['epochs'])
|
|
||||||
|
|
||||||
# 计算阈值
|
|
||||||
print("\n[额外] 计算异动阈值...")
|
|
||||||
best_checkpoint = torch.load(save_dir / 'best_model.pt', map_location=device)
|
|
||||||
|
|
||||||
# 创建新的单 GPU 模型用于计算阈值(避免 DataParallel 问题)
|
|
||||||
threshold_model = TransformerAutoencoder(**config['model'])
|
|
||||||
threshold_model.load_state_dict(best_checkpoint['model_state_dict'])
|
|
||||||
threshold_model.to(device)
|
|
||||||
threshold_model.eval()
|
|
||||||
|
|
||||||
thresholds = compute_thresholds(threshold_model, val_loader, device, config['threshold_percentiles'])
|
|
||||||
|
|
||||||
with open(save_dir / 'thresholds.json', 'w') as f:
|
|
||||||
json.dump(thresholds, f, indent=2)
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("训练完成!")
|
|
||||||
print(f"模型保存位置: {args.save_dir}")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,132 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
每日盘后运行:更新滚动基线
|
|
||||||
|
|
||||||
使用方法:
|
|
||||||
python ml/update_baseline.py
|
|
||||||
|
|
||||||
建议加入 crontab,每天 15:30 后运行:
|
|
||||||
30 15 * * 1-5 cd /path/to/project && python ml/update_baseline.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import pickle
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from pathlib import Path
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
from ml.prepare_data_v2 import (
|
|
||||||
get_all_concepts, get_trading_days, compute_raw_concept_features,
|
|
||||||
init_process_connections, CONFIG, RAW_CACHE_DIR, BASELINE_DIR
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def update_rolling_baseline(baseline_days: int = 20):
|
|
||||||
"""
|
|
||||||
更新滚动基线(用于实盘检测)
|
|
||||||
|
|
||||||
基线 = 最近 N 个交易日每个时间片的统计量
|
|
||||||
"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("更新滚动基线(用于实盘)")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# 初始化连接
|
|
||||||
init_process_connections()
|
|
||||||
|
|
||||||
# 获取概念列表
|
|
||||||
concepts = get_all_concepts()
|
|
||||||
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
|
||||||
|
|
||||||
# 获取最近的交易日
|
|
||||||
today = datetime.now().strftime('%Y-%m-%d')
|
|
||||||
start_date = (datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d') # 多取一些
|
|
||||||
|
|
||||||
trading_days = get_trading_days(start_date, today)
|
|
||||||
|
|
||||||
if len(trading_days) < baseline_days:
|
|
||||||
print(f"错误:交易日不足 {baseline_days} 天")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 只取最近 N 天
|
|
||||||
recent_days = trading_days[-baseline_days:]
|
|
||||||
print(f"使用 {len(recent_days)} 天数据: {recent_days[0]} ~ {recent_days[-1]}")
|
|
||||||
|
|
||||||
# 加载原始数据
|
|
||||||
all_data = []
|
|
||||||
for trade_date in tqdm(recent_days, desc="加载数据"):
|
|
||||||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
|
||||||
|
|
||||||
if os.path.exists(cache_file):
|
|
||||||
df = pd.read_parquet(cache_file)
|
|
||||||
else:
|
|
||||||
df = compute_raw_concept_features(trade_date, concepts, all_stocks)
|
|
||||||
|
|
||||||
if not df.empty:
|
|
||||||
all_data.append(df)
|
|
||||||
|
|
||||||
if not all_data:
|
|
||||||
print("错误:无数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
combined = pd.concat(all_data, ignore_index=True)
|
|
||||||
print(f"总数据量: {len(combined):,} 条")
|
|
||||||
|
|
||||||
# 按概念计算基线
|
|
||||||
baselines = {}
|
|
||||||
|
|
||||||
for concept_id, group in tqdm(combined.groupby('concept_id'), desc="计算基线"):
|
|
||||||
baseline_dict = {}
|
|
||||||
|
|
||||||
for time_slot, slot_group in group.groupby('time_slot'):
|
|
||||||
if len(slot_group) < CONFIG['min_baseline_samples']:
|
|
||||||
continue
|
|
||||||
|
|
||||||
alpha_std = slot_group['alpha'].std()
|
|
||||||
amt_std = slot_group['total_amt'].std()
|
|
||||||
rank_std = slot_group['rank_pct'].std()
|
|
||||||
|
|
||||||
baseline_dict[time_slot] = {
|
|
||||||
'alpha_mean': float(slot_group['alpha'].mean()),
|
|
||||||
'alpha_std': float(max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1)),
|
|
||||||
'amt_mean': float(slot_group['total_amt'].mean()),
|
|
||||||
'amt_std': float(max(amt_std if pd.notna(amt_std) else slot_group['total_amt'].mean() * 0.5, 1.0)),
|
|
||||||
'rank_mean': float(slot_group['rank_pct'].mean()),
|
|
||||||
'rank_std': float(max(rank_std if pd.notna(rank_std) else 0.2, 0.05)),
|
|
||||||
'sample_count': len(slot_group),
|
|
||||||
}
|
|
||||||
|
|
||||||
if baseline_dict:
|
|
||||||
baselines[concept_id] = baseline_dict
|
|
||||||
|
|
||||||
print(f"计算了 {len(baselines)} 个概念的基线")
|
|
||||||
|
|
||||||
# 保存
|
|
||||||
os.makedirs(BASELINE_DIR, exist_ok=True)
|
|
||||||
baseline_file = os.path.join(BASELINE_DIR, 'realtime_baseline.pkl')
|
|
||||||
|
|
||||||
with open(baseline_file, 'wb') as f:
|
|
||||||
pickle.dump({
|
|
||||||
'baselines': baselines,
|
|
||||||
'update_time': datetime.now().isoformat(),
|
|
||||||
'date_range': [recent_days[0], recent_days[-1]],
|
|
||||||
'baseline_days': baseline_days,
|
|
||||||
}, f)
|
|
||||||
|
|
||||||
print(f"基线已保存: {baseline_file}")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import argparse
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--days', type=int, default=20, help='基线天数')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
update_rolling_baseline(args.days)
|
|
||||||
@@ -23,11 +23,17 @@ import {
|
|||||||
Th,
|
Th,
|
||||||
Td,
|
Td,
|
||||||
TableContainer,
|
TableContainer,
|
||||||
|
Popover,
|
||||||
|
PopoverTrigger,
|
||||||
|
PopoverContent,
|
||||||
|
PopoverBody,
|
||||||
|
Portal,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { FaArrowUp, FaArrowDown, FaFire, FaChevronDown, FaChevronRight } from 'react-icons/fa';
|
import { FaArrowUp, FaArrowDown, FaFire, FaChevronDown, FaChevronRight } from 'react-icons/fa';
|
||||||
import { useNavigate } from 'react-router-dom';
|
import { useNavigate } from 'react-router-dom';
|
||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
import { getAlertTypeLabel, formatScore, getScoreColor } from '../utils/chartHelpers';
|
import { getAlertTypeLabel, formatScore, getScoreColor } from '../utils/chartHelpers';
|
||||||
|
import MiniTimelineChart from '@components/Charts/Stock/MiniTimelineChart';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 紧凑型异动卡片
|
* 紧凑型异动卡片
|
||||||
@@ -38,6 +44,8 @@ const AlertCard = ({ alert, isExpanded, onToggle, stocks, loadingStocks }) => {
|
|||||||
const hoverBg = useColorModeValue('gray.50', '#252525');
|
const hoverBg = useColorModeValue('gray.50', '#252525');
|
||||||
const borderColor = useColorModeValue('gray.200', '#333');
|
const borderColor = useColorModeValue('gray.200', '#333');
|
||||||
const expandedBg = useColorModeValue('purple.50', '#1e1e2e');
|
const expandedBg = useColorModeValue('purple.50', '#1e1e2e');
|
||||||
|
const tableBg = useColorModeValue('gray.50', '#151520');
|
||||||
|
const popoverBg = useColorModeValue('white', '#1a1a1a');
|
||||||
|
|
||||||
const isUp = alert.alert_type !== 'surge_down';
|
const isUp = alert.alert_type !== 'surge_down';
|
||||||
const typeColor = isUp ? 'red' : 'green';
|
const typeColor = isUp ? 'red' : 'green';
|
||||||
@@ -175,7 +183,7 @@ const AlertCard = ({ alert, isExpanded, onToggle, stocks, loadingStocks }) => {
|
|||||||
borderTopWidth="1px"
|
borderTopWidth="1px"
|
||||||
borderColor={borderColor}
|
borderColor={borderColor}
|
||||||
p={3}
|
p={3}
|
||||||
bg={useColorModeValue('gray.50', '#151520')}
|
bg={tableBg}
|
||||||
>
|
>
|
||||||
{loadingStocks ? (
|
{loadingStocks ? (
|
||||||
<HStack justify="center" py={4}>
|
<HStack justify="center" py={4}>
|
||||||
@@ -183,9 +191,33 @@ const AlertCard = ({ alert, isExpanded, onToggle, stocks, loadingStocks }) => {
|
|||||||
<Text fontSize="sm" color="gray.500">加载相关股票...</Text>
|
<Text fontSize="sm" color="gray.500">加载相关股票...</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
) : stocks && stocks.length > 0 ? (
|
) : stocks && stocks.length > 0 ? (
|
||||||
|
<>
|
||||||
|
{/* 概念涨跌幅统计 */}
|
||||||
|
{(() => {
|
||||||
|
const validStocks = stocks.filter(s => s.change_pct != null && !isNaN(s.change_pct));
|
||||||
|
if (validStocks.length === 0) return null;
|
||||||
|
const avgChange = validStocks.reduce((sum, s) => sum + s.change_pct, 0) / validStocks.length;
|
||||||
|
const upCount = validStocks.filter(s => s.change_pct > 0).length;
|
||||||
|
const downCount = validStocks.filter(s => s.change_pct < 0).length;
|
||||||
|
return (
|
||||||
|
<HStack spacing={4} mb={2} fontSize="xs" color="gray.500">
|
||||||
|
<HStack>
|
||||||
|
<Text>概念均涨:</Text>
|
||||||
|
<Text fontWeight="bold" color={avgChange >= 0 ? 'red.400' : 'green.400'}>
|
||||||
|
{avgChange >= 0 ? '+' : ''}{avgChange.toFixed(2)}%
|
||||||
|
</Text>
|
||||||
|
</HStack>
|
||||||
|
<HStack spacing={1}>
|
||||||
|
<Text color="red.400">{upCount}涨</Text>
|
||||||
|
<Text>/</Text>
|
||||||
|
<Text color="green.400">{downCount}跌</Text>
|
||||||
|
</HStack>
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
})()}
|
||||||
<TableContainer maxH="200px" overflowY="auto">
|
<TableContainer maxH="200px" overflowY="auto">
|
||||||
<Table size="sm" variant="simple">
|
<Table size="sm" variant="simple">
|
||||||
<Thead position="sticky" top={0} bg={useColorModeValue('gray.50', '#151520')} zIndex={1}>
|
<Thead position="sticky" top={0} bg={tableBg} zIndex={1}>
|
||||||
<Tr>
|
<Tr>
|
||||||
<Th px={2} py={1} fontSize="xs" color="gray.500">股票</Th>
|
<Th px={2} py={1} fontSize="xs" color="gray.500">股票</Th>
|
||||||
<Th px={2} py={1} fontSize="xs" color="gray.500" isNumeric>涨跌</Th>
|
<Th px={2} py={1} fontSize="xs" color="gray.500" isNumeric>涨跌</Th>
|
||||||
@@ -196,17 +228,45 @@ const AlertCard = ({ alert, isExpanded, onToggle, stocks, loadingStocks }) => {
|
|||||||
{stocks.slice(0, 10).map((stock, idx) => {
|
{stocks.slice(0, 10).map((stock, idx) => {
|
||||||
const changePct = stock.change_pct;
|
const changePct = stock.change_pct;
|
||||||
const hasChange = changePct != null && !isNaN(changePct);
|
const hasChange = changePct != null && !isNaN(changePct);
|
||||||
|
const stockCode = stock.code || stock.stock_code;
|
||||||
|
const stockName = stock.name || stock.stock_name || '-';
|
||||||
return (
|
return (
|
||||||
<Tr
|
<Tr
|
||||||
key={idx}
|
key={idx}
|
||||||
cursor="pointer"
|
cursor="pointer"
|
||||||
_hover={{ bg: hoverBg }}
|
_hover={{ bg: hoverBg }}
|
||||||
onClick={(e) => handleStockClick(e, stock.code || stock.stock_code)}
|
onClick={(e) => handleStockClick(e, stockCode)}
|
||||||
>
|
>
|
||||||
<Td px={2} py={1.5}>
|
<Td px={2} py={1.5}>
|
||||||
<Text fontSize="xs" color="cyan.400" fontWeight="medium">
|
<Popover trigger="hover" placement="right" isLazy>
|
||||||
{stock.name || stock.stock_name || '-'}
|
<PopoverTrigger>
|
||||||
|
<Text
|
||||||
|
fontSize="xs"
|
||||||
|
color="cyan.400"
|
||||||
|
fontWeight="medium"
|
||||||
|
_hover={{ textDecoration: 'underline' }}
|
||||||
|
>
|
||||||
|
{stockName}
|
||||||
</Text>
|
</Text>
|
||||||
|
</PopoverTrigger>
|
||||||
|
<Portal>
|
||||||
|
<PopoverContent
|
||||||
|
w="180px"
|
||||||
|
h="80px"
|
||||||
|
bg={popoverBg}
|
||||||
|
borderColor={borderColor}
|
||||||
|
boxShadow="lg"
|
||||||
|
onClick={(e) => e.stopPropagation()}
|
||||||
|
>
|
||||||
|
<PopoverBody p={2}>
|
||||||
|
<Text fontSize="xs" color="gray.500" mb={1}>{stockName} 分时</Text>
|
||||||
|
<Box h="50px">
|
||||||
|
<MiniTimelineChart stockCode={stockCode} />
|
||||||
|
</Box>
|
||||||
|
</PopoverBody>
|
||||||
|
</PopoverContent>
|
||||||
|
</Portal>
|
||||||
|
</Popover>
|
||||||
</Td>
|
</Td>
|
||||||
<Td px={2} py={1.5} isNumeric>
|
<Td px={2} py={1.5} isNumeric>
|
||||||
<Text
|
<Text
|
||||||
@@ -239,6 +299,7 @@ const AlertCard = ({ alert, isExpanded, onToggle, stocks, loadingStocks }) => {
|
|||||||
</Text>
|
</Text>
|
||||||
)}
|
)}
|
||||||
</TableContainer>
|
</TableContainer>
|
||||||
|
</>
|
||||||
) : (
|
) : (
|
||||||
<Text fontSize="sm" color="gray.500" textAlign="center" py={2}>
|
<Text fontSize="sm" color="gray.500" textAlign="center" py={2}>
|
||||||
暂无相关股票数据
|
暂无相关股票数据
|
||||||
|
|||||||
@@ -843,8 +843,22 @@ const StockOverview = () => {
|
|||||||
</Box>
|
</Box>
|
||||||
|
|
||||||
{/* 热点概览 - 大盘走势 + 概念异动 */}
|
{/* 热点概览 - 大盘走势 + 概念异动 */}
|
||||||
|
{/* 只在 selectedDate 确定后渲染,避免 null → 日期 的双重请求 */}
|
||||||
<Box mb={10}>
|
<Box mb={10}>
|
||||||
|
{selectedDate ? (
|
||||||
<HotspotOverview selectedDate={selectedDate} />
|
<HotspotOverview selectedDate={selectedDate} />
|
||||||
|
) : (
|
||||||
|
<Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
|
||||||
|
<CardBody>
|
||||||
|
<Center h="400px">
|
||||||
|
<VStack spacing={4}>
|
||||||
|
<Spinner size="xl" color="purple.500" thickness="4px" />
|
||||||
|
<Text color={subTextColor}>加载热点概览数据...</Text>
|
||||||
|
</VStack>
|
||||||
|
</Center>
|
||||||
|
</CardBody>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
</Box>
|
</Box>
|
||||||
|
|
||||||
{/* 灵活屏 - 实时行情监控 */}
|
{/* 灵活屏 - 实时行情监控 */}
|
||||||
|
|||||||
Reference in New Issue
Block a user