#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Historical anomaly backtest script.

Runs anomaly detection over historical data with a trained model and
generates anomaly records.

Usage:
    # Backtest a date range
    python backtest.py --start 2024-01-01 --end 2024-12-01

    # Backtest a single day
    python backtest.py --start 2024-11-01 --end 2024-11-01

    # Compute results only, without writing to the database
    python backtest.py --start 2024-01-01 --end 2024-12-01 --dry-run
"""

import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sqlalchemy import create_engine, text

# Add the parent directory to the import path so `model` can be resolved
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model import TransformerAutoencoder


# ==================== Configuration ====================

MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)

# Feature columns (must match training)
FEATURES = [
    'alpha',
    'alpha_delta',
    'amt_ratio',
    'amt_delta',
    'rank_pct',
    'limit_up_ratio',
]

# Backtest configuration
BACKTEST_CONFIG = {
    'seq_len': 30,                # sequence length
    'threshold_key': 'p95',       # which stored threshold to use
    'min_alpha_abs': 0.5,         # minimum |alpha| (filters out tiny moves)
    'cooldown_minutes': 8,        # per-concept cooldown window
    'max_alerts_per_minute': 15,  # cap on alerts per minute
    'clip_value': 10.0,           # clip extreme feature values
}
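
# The detector reads its cutoff from checkpoints/thresholds.json via
# BACKTEST_CONFIG['threshold_key']. A minimal sketch of the expected file,
# assuming the training pipeline exports percentile cutoffs of the validation
# reconstruction error (values and keys other than 'p95' are illustrative):
#
#   {
#       "p90": 0.0123,
#       "p95": 0.0187,
#       "p99": 0.0312
#   }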


# ==================== Model loading ====================

class AnomalyDetector:
    """Anomaly detector backed by the trained autoencoder."""

    def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
        self.checkpoint_dir = Path(checkpoint_dir)

        # Resolve device
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        # Load config
        self._load_config()

        # Load model
        self._load_model()

        # Load thresholds
        self._load_thresholds()

        print("AnomalyDetector initialized")
        print(f"  Device: {self.device}")
        print(f"  Threshold ({BACKTEST_CONFIG['threshold_key']}): {self.threshold:.6f}")

    def _load_config(self):
        config_path = self.checkpoint_dir / 'config.json'
        with open(config_path, 'r') as f:
            self.config = json.load(f)

    def _load_model(self):
        model_path = self.checkpoint_dir / 'best_model.pt'
        checkpoint = torch.load(model_path, map_location=self.device)

        model_config = self.config['model'].copy()
        model_config['use_instance_norm'] = self.config.get('use_instance_norm', True)

        self.model = TransformerAutoencoder(**model_config)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

    def _load_thresholds(self):
        thresholds_path = self.checkpoint_dir / 'thresholds.json'
        with open(thresholds_path, 'r') as f:
            thresholds = json.load(f)

        self.threshold = thresholds[BACKTEST_CONFIG['threshold_key']]

    @torch.no_grad()
    def compute_anomaly_scores(self, sequences: np.ndarray) -> np.ndarray:
        """
        Compute anomaly scores.

        Args:
            sequences: (n_sequences, seq_len, n_features)
        Returns:
            scores: (n_sequences,) anomaly score at the last timestep of each sequence
        """
        # Clip extreme values
        sequences = np.clip(sequences, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])

        # Convert to tensor
        x = torch.FloatTensor(sequences).to(self.device)

        # Per-timestep reconstruction error
        errors = self.model.compute_reconstruction_error(x, reduction='none')

        # Keep only the last timestep's error
        scores = errors[:, -1].cpu().numpy()

        return scores

    def is_anomaly(self, score: float) -> bool:
        """Return True if the score exceeds the detection threshold."""
        return score > self.threshold
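
# Minimal usage sketch (illustrative; assumes ml/checkpoints contains the
# config.json, best_model.pt and thresholds.json produced by training):
#
#   detector = AnomalyDetector('ml/checkpoints', device='cpu')
#   seqs = np.zeros((4, BACKTEST_CONFIG['seq_len'], len(FEATURES)), dtype=np.float32)
#   scores = detector.compute_anomaly_scores(seqs)   # shape (4,)
#   flags = [detector.is_anomaly(s) for s in scores]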


# ==================== Data loading ====================

def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
    """Load one day's feature data."""
    file_path = Path(data_dir) / f"features_{date}.parquet"

    if not file_path.exists():
        return None

    df = pd.read_parquet(file_path)
    return df


def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
    """List the available dates within [start_date, end_date]."""
    data_path = Path(data_dir)
    all_files = sorted(data_path.glob("features_*.parquet"))

    dates = []
    for f in all_files:
        date = f.stem.replace('features_', '')
        if start_date <= date <= end_date:
            dates.append(date)

    return dates
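
# Expected layout of each features_<date>.parquet file, inferred from the
# columns referenced in the backtest below (one row per concept per minute;
# this is a sketch, not a checked schema):
#
#   concept_id | timestamp | alpha | alpha_delta | amt_ratio | amt_delta
#   | rank_pct | limit_up_ratio | limit_down_ratio | stock_count | total_amt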


# ==================== Backtest logic ====================

def backtest_single_day(
    detector: AnomalyDetector,
    df: pd.DataFrame,
    date: str,
    seq_len: int = 30
) -> List[Dict]:
    """
    Backtest a single day of data.

    Args:
        detector: anomaly detector
        df: feature data for the day
        date: trade date
        seq_len: sequence length

    Returns:
        alerts: list of alert dicts
    """
    alerts = []

    # Group by concept
    grouped = df.groupby('concept_id', sort=False)

    # Cooldown bookkeeping: {concept_id: last_alert_timestamp}
    cooldown = {}

    # All timestamps for the day
    all_timestamps = sorted(df['timestamp'].unique())

    if len(all_timestamps) < seq_len:
        return alerts

    # Scan every timestamp once a full window is available (from index seq_len - 1)
    for t_idx in range(seq_len - 1, len(all_timestamps)):
        current_time = all_timestamps[t_idx]
        window_start_time = all_timestamps[t_idx - seq_len + 1]

        minute_alerts = []

        # Collect every concept's sequence at this timestamp
        concept_sequences = []
        concept_infos = []

        for concept_id, concept_df in grouped:
            # Rows for this concept inside the time window
            mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
            window_df = concept_df[mask].sort_values('timestamp')

            if len(window_df) < seq_len:
                continue

            # Keep the last seq_len points
            window_df = window_df.tail(seq_len)

            # Extract features
            features = window_df[FEATURES].values

            # Replace NaN/inf values
            features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

            # Snapshot of the current timestep
            current_row = window_df.iloc[-1]

            concept_sequences.append(features)
            concept_infos.append({
                'concept_id': concept_id,
                'timestamp': current_time,
                'alpha': current_row.get('alpha', 0),
                'alpha_delta': current_row.get('alpha_delta', 0),
                'amt_ratio': current_row.get('amt_ratio', 1),
                'limit_up_ratio': current_row.get('limit_up_ratio', 0),
                'limit_down_ratio': current_row.get('limit_down_ratio', 0),
                'rank_pct': current_row.get('rank_pct', 0.5),
                'stock_count': current_row.get('stock_count', 0),
                'total_amt': current_row.get('total_amt', 0),
            })

        if not concept_sequences:
            continue

        # Score all concepts in one batch
        sequences_array = np.array(concept_sequences)
        scores = detector.compute_anomaly_scores(sequences_array)

        # Flag anomalies
        for info, score in zip(concept_infos, scores):
            concept_id = info['concept_id']
            alpha = info['alpha']

            # Filter out tiny moves
            if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']:
                continue

            # Cooldown check
            if concept_id in cooldown:
                last_alert = cooldown[concept_id]
                if isinstance(current_time, datetime):
                    time_diff = (current_time - last_alert).total_seconds() / 60
                else:
                    # timestamp is a string or other format: skip the cooldown check
                    time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1

                if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
                    continue

            # Threshold check
            if not detector.is_anomaly(score):
                continue

            # Record the alert
            alert_type = 'surge_up' if alpha > 0 else 'surge_down'

            alert = {
                'concept_id': concept_id,
                'alert_time': current_time,
                'trade_date': date,
                'alert_type': alert_type,
                'anomaly_score': float(score),
                'threshold': detector.threshold,
                **info
            }

            minute_alerts.append(alert)
            # Note: the cooldown is stamped here, so a concept stays in cooldown
            # even if its alert is dropped by the per-minute cap below.
            cooldown[concept_id] = current_time

        # Sort by score and cap the number of alerts kept for this minute
        minute_alerts.sort(key=lambda x: x['anomaly_score'], reverse=True)
        alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])

    return alerts
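
# Smoke-test sketch on synthetic data (hypothetical values; two concepts with
# 60 one-minute bars each):
#
#   times = pd.date_range('2024-11-01 09:30', periods=60, freq='min')
#   rows = [{'concept_id': cid, 'timestamp': t, **{f: 0.0 for f in FEATURES}}
#           for cid in ('C1', 'C2') for t in times]
#   day_df = pd.DataFrame(rows)
#   alerts = backtest_single_day(detector, day_df, '2024-11-01', seq_len=30)
#
# With seq_len=30, the first scored minute is the 30th bar (index 29), and each
# concept is then scored on its trailing 30-bar window at every minute.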


# ==================== Database writes ====================

def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Save alerts to MySQL."""
    if not alerts:
        return 0

    if dry_run:
        print(f"  [Dry Run] would write {len(alerts)} alerts")
        return len(alerts)

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                # Skip rows that already exist
                check_sql = text("""
                    SELECT id FROM concept_minute_alert
                    WHERE concept_id = :concept_id
                      AND alert_time = :alert_time
                      AND trade_date = :trade_date
                """)
                exists = conn.execute(check_sql, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                }).fetchone()

                if exists:
                    continue

                # Insert a new record
                insert_sql = text("""
                    INSERT INTO concept_minute_alert
                    (concept_id, concept_name, alert_time, alert_type, trade_date,
                     change_pct, zscore, importance_score, stock_count, extra_info)
                    VALUES
                    (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
                     :change_pct, :zscore, :importance_score, :stock_count, :extra_info)
                """)

                conn.execute(insert_sql, {
                    'concept_id': alert['concept_id'],
                    # backtest_single_day does not populate concept_name,
                    # so this falls back to an empty string
                    'concept_name': alert.get('concept_name', ''),
                    'alert_time': alert['alert_time'],
                    'alert_type': alert['alert_type'],
                    'trade_date': alert['trade_date'],
                    'change_pct': alert.get('alpha', 0),
                    'zscore': alert['anomaly_score'],
                    'importance_score': alert['anomaly_score'],
                    'stock_count': alert.get('stock_count', 0),
                    'extra_info': json.dumps({
                        'detection_method': 'ml_autoencoder',
                        'threshold': alert['threshold'],
                        'alpha': alert.get('alpha', 0),
                        'amt_ratio': alert.get('amt_ratio', 1),
                    }, ensure_ascii=False)
                })

                saved += 1

            except Exception as e:
                print(f"  Failed to save: {alert['concept_id']} - {e}")

    return saved


def export_alerts_to_csv(alerts: List[Dict], output_path: str):
    """Export alerts to CSV."""
    if not alerts:
        return

    df = pd.DataFrame(alerts)
    df.to_csv(output_path, index=False, encoding='utf-8-sig')
    print(f"Exported to: {output_path}")
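
# Sketch of an assumed concept_minute_alert table matching the statements
# above (the real DDL lives elsewhere; column types are guesses):
#
#   CREATE TABLE concept_minute_alert (
#       id               BIGINT AUTO_INCREMENT PRIMARY KEY,
#       concept_id       VARCHAR(32)  NOT NULL,
#       concept_name     VARCHAR(64),
#       alert_time       DATETIME     NOT NULL,
#       alert_type       VARCHAR(16)  NOT NULL,
#       trade_date       DATE         NOT NULL,
#       change_pct       DOUBLE,
#       zscore           DOUBLE,
#       importance_score DOUBLE,
#       stock_count      INT,
#       extra_info       JSON
#   );
#
# A unique key on (concept_id, alert_time, trade_date) would let the manual
# existence check above be replaced with INSERT IGNORE.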


# ==================== Main ====================

def main():
    parser = argparse.ArgumentParser(description='Historical anomaly backtest')
    parser.add_argument('--data_dir', type=str, default='ml/data',
                        help='feature data directory')
    parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
                        help='model checkpoint directory')
    parser.add_argument('--start', type=str, required=True,
                        help='start date (YYYY-MM-DD)')
    parser.add_argument('--end', type=str, required=True,
                        help='end date (YYYY-MM-DD)')
    parser.add_argument('--dry-run', action='store_true',
                        help='compute only, do not write to the database')
    parser.add_argument('--export-csv', type=str, default=None,
                        help='CSV export path')
    parser.add_argument('--device', type=str, default='auto',
                        help='device (auto/cuda/cpu)')

    args = parser.parse_args()

    print("=" * 60)
    print("Historical anomaly backtest")
    print("=" * 60)
    print(f"Date range: {args.start} ~ {args.end}")
    print(f"Data directory: {args.data_dir}")
    print(f"Checkpoint directory: {args.checkpoint_dir}")
    print(f"Dry run: {args.dry_run}")
    print("=" * 60)

    # Initialize the detector
    detector = AnomalyDetector(args.checkpoint_dir, args.device)

    # Collect available dates
    dates = get_available_dates(args.data_dir, args.start, args.end)

    if not dates:
        print(f"No data found in range {args.start} ~ {args.end}")
        return

    print(f"\nFound {len(dates)} days of data")

    # Backtest
    all_alerts = []
    total_saved = 0

    for date in tqdm(dates, desc="Backtest progress"):
        # Load the day's data
        df = load_daily_features(args.data_dir, date)

        if df is None or df.empty:
            continue

        # Backtest a single day
        alerts = backtest_single_day(
            detector, df, date,
            seq_len=BACKTEST_CONFIG['seq_len']
        )

        if alerts:
            all_alerts.extend(alerts)

            # Write to the database
            saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
            total_saved += saved

            if not args.dry_run:
                tqdm.write(f"  {date}: detected {len(alerts)} alerts, saved {saved}")

    # Export CSV
    if args.export_csv and all_alerts:
        export_alerts_to_csv(all_alerts, args.export_csv)

    # Summary
    print("\n" + "=" * 60)
    print("Backtest complete!")
    print("=" * 60)
    print(f"Total alerts detected: {len(all_alerts)}")
    print(f"Saved to database: {total_saved}")

    # Statistics
    if all_alerts:
        df_alerts = pd.DataFrame(all_alerts)
        print("\nAlert type distribution:")
        print(df_alerts['alert_type'].value_counts())

        print("\nAnomaly score statistics:")
        print(f"  Mean: {df_alerts['anomaly_score'].mean():.4f}")
        print(f"  Max:  {df_alerts['anomaly_score'].max():.4f}")
        print(f"  Min:  {df_alerts['anomaly_score'].min():.4f}")

    print("=" * 60)


if __name__ == "__main__":
    main()