update pay ui
This commit is contained in:
481
ml/backtest.py
Normal file
481
ml/backtest.py
Normal file
@@ -0,0 +1,481 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
历史异动回测脚本
|
||||
|
||||
使用训练好的模型,对历史数据进行异动检测,生成异动记录
|
||||
|
||||
使用方法:
|
||||
# 回测指定日期范围
|
||||
python backtest.py --start 2024-01-01 --end 2024-12-01
|
||||
|
||||
# 回测单天
|
||||
python backtest.py --start 2024-11-01 --end 2024-11-01
|
||||
|
||||
# 只生成结果,不写入数据库
|
||||
python backtest.py --start 2024-01-01 --dry-run
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
# 添加父目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from model import TransformerAutoencoder
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
MYSQL_ENGINE = create_engine(
|
||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
||||
echo=False
|
||||
)
|
||||
|
||||
# 特征列表(与训练一致)
|
||||
FEATURES = [
|
||||
'alpha',
|
||||
'alpha_delta',
|
||||
'amt_ratio',
|
||||
'amt_delta',
|
||||
'rank_pct',
|
||||
'limit_up_ratio',
|
||||
]
|
||||
|
||||
# 回测配置
|
||||
BACKTEST_CONFIG = {
|
||||
'seq_len': 30, # 序列长度
|
||||
'threshold_key': 'p95', # 使用的阈值
|
||||
'min_alpha_abs': 0.5, # 最小 Alpha 绝对值(过滤微小波动)
|
||||
'cooldown_minutes': 8, # 同一概念冷却时间
|
||||
'max_alerts_per_minute': 15, # 每分钟最多异动数
|
||||
'clip_value': 10.0, # 极端值截断
|
||||
}
|
||||
|
||||
|
||||
# ==================== 模型加载 ====================
|
||||
|
||||
class AnomalyDetector:
|
||||
"""异动检测器"""
|
||||
|
||||
def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
|
||||
self.checkpoint_dir = Path(checkpoint_dir)
|
||||
|
||||
# 设备
|
||||
if device == 'auto':
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
else:
|
||||
self.device = torch.device(device)
|
||||
|
||||
# 加载配置
|
||||
self._load_config()
|
||||
|
||||
# 加载模型
|
||||
self._load_model()
|
||||
|
||||
# 加载阈值
|
||||
self._load_thresholds()
|
||||
|
||||
print(f"AnomalyDetector 初始化完成")
|
||||
print(f" 设备: {self.device}")
|
||||
print(f" 阈值 ({BACKTEST_CONFIG['threshold_key']}): {self.threshold:.6f}")
|
||||
|
||||
def _load_config(self):
|
||||
config_path = self.checkpoint_dir / 'config.json'
|
||||
with open(config_path, 'r') as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
def _load_model(self):
|
||||
model_path = self.checkpoint_dir / 'best_model.pt'
|
||||
checkpoint = torch.load(model_path, map_location=self.device)
|
||||
|
||||
model_config = self.config['model'].copy()
|
||||
model_config['use_instance_norm'] = self.config.get('use_instance_norm', True)
|
||||
|
||||
self.model = TransformerAutoencoder(**model_config)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
def _load_thresholds(self):
|
||||
thresholds_path = self.checkpoint_dir / 'thresholds.json'
|
||||
with open(thresholds_path, 'r') as f:
|
||||
thresholds = json.load(f)
|
||||
|
||||
self.threshold = thresholds[BACKTEST_CONFIG['threshold_key']]
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_anomaly_scores(self, sequences: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
计算异动分数
|
||||
|
||||
Args:
|
||||
sequences: (n_sequences, seq_len, n_features)
|
||||
Returns:
|
||||
scores: (n_sequences,) 每个序列最后时刻的异动分数
|
||||
"""
|
||||
# 截断极端值
|
||||
sequences = np.clip(sequences, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])
|
||||
|
||||
# 转为 tensor
|
||||
x = torch.FloatTensor(sequences).to(self.device)
|
||||
|
||||
# 计算重构误差
|
||||
errors = self.model.compute_reconstruction_error(x, reduction='none')
|
||||
|
||||
# 取最后一个时刻的误差
|
||||
scores = errors[:, -1].cpu().numpy()
|
||||
|
||||
return scores
|
||||
|
||||
def is_anomaly(self, score: float) -> bool:
|
||||
"""判断是否异动"""
|
||||
return score > self.threshold
|
||||
|
||||
|
||||
# ==================== 数据加载 ====================
|
||||
|
||||
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
|
||||
"""加载单天的特征数据"""
|
||||
file_path = Path(data_dir) / f"features_{date}.parquet"
|
||||
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
df = pd.read_parquet(file_path)
|
||||
return df
|
||||
|
||||
|
||||
def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
|
||||
"""获取可用的日期列表"""
|
||||
data_path = Path(data_dir)
|
||||
all_files = sorted(data_path.glob("features_*.parquet"))
|
||||
|
||||
dates = []
|
||||
for f in all_files:
|
||||
date = f.stem.replace('features_', '')
|
||||
if start_date <= date <= end_date:
|
||||
dates.append(date)
|
||||
|
||||
return dates
|
||||
|
||||
|
||||
# ==================== 回测逻辑 ====================
|
||||
|
||||
def backtest_single_day(
|
||||
detector: AnomalyDetector,
|
||||
df: pd.DataFrame,
|
||||
date: str,
|
||||
seq_len: int = 30
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
回测单天数据
|
||||
|
||||
Args:
|
||||
detector: 异动检测器
|
||||
df: 当天的特征数据
|
||||
date: 日期
|
||||
seq_len: 序列长度
|
||||
|
||||
Returns:
|
||||
alerts: 异动列表
|
||||
"""
|
||||
alerts = []
|
||||
|
||||
# 按概念分组
|
||||
grouped = df.groupby('concept_id', sort=False)
|
||||
|
||||
# 冷却记录 {concept_id: last_alert_timestamp}
|
||||
cooldown = {}
|
||||
|
||||
# 获取所有时间点
|
||||
all_timestamps = sorted(df['timestamp'].unique())
|
||||
|
||||
if len(all_timestamps) < seq_len:
|
||||
return alerts
|
||||
|
||||
# 对每个时间点进行检测(从第 seq_len 个开始)
|
||||
for t_idx in range(seq_len - 1, len(all_timestamps)):
|
||||
current_time = all_timestamps[t_idx]
|
||||
window_start_time = all_timestamps[t_idx - seq_len + 1]
|
||||
|
||||
minute_alerts = []
|
||||
|
||||
# 收集该时刻所有概念的序列
|
||||
concept_sequences = []
|
||||
concept_infos = []
|
||||
|
||||
for concept_id, concept_df in grouped:
|
||||
# 获取该概念在时间窗口内的数据
|
||||
mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
|
||||
window_df = concept_df[mask].sort_values('timestamp')
|
||||
|
||||
if len(window_df) < seq_len:
|
||||
continue
|
||||
|
||||
# 取最后 seq_len 个点
|
||||
window_df = window_df.tail(seq_len)
|
||||
|
||||
# 提取特征
|
||||
features = window_df[FEATURES].values
|
||||
|
||||
# 处理缺失值
|
||||
features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
# 获取当前时刻的信息
|
||||
current_row = window_df.iloc[-1]
|
||||
|
||||
concept_sequences.append(features)
|
||||
concept_infos.append({
|
||||
'concept_id': concept_id,
|
||||
'timestamp': current_time,
|
||||
'alpha': current_row.get('alpha', 0),
|
||||
'alpha_delta': current_row.get('alpha_delta', 0),
|
||||
'amt_ratio': current_row.get('amt_ratio', 1),
|
||||
'limit_up_ratio': current_row.get('limit_up_ratio', 0),
|
||||
'limit_down_ratio': current_row.get('limit_down_ratio', 0),
|
||||
'rank_pct': current_row.get('rank_pct', 0.5),
|
||||
'stock_count': current_row.get('stock_count', 0),
|
||||
'total_amt': current_row.get('total_amt', 0),
|
||||
})
|
||||
|
||||
if not concept_sequences:
|
||||
continue
|
||||
|
||||
# 批量计算异动分数
|
||||
sequences_array = np.array(concept_sequences)
|
||||
scores = detector.compute_anomaly_scores(sequences_array)
|
||||
|
||||
# 检测异动
|
||||
for i, (info, score) in enumerate(zip(concept_infos, scores)):
|
||||
concept_id = info['concept_id']
|
||||
alpha = info['alpha']
|
||||
|
||||
# 过滤小波动
|
||||
if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']:
|
||||
continue
|
||||
|
||||
# 检查冷却
|
||||
if concept_id in cooldown:
|
||||
last_alert = cooldown[concept_id]
|
||||
if isinstance(current_time, datetime):
|
||||
time_diff = (current_time - last_alert).total_seconds() / 60
|
||||
else:
|
||||
# timestamp 是字符串或其他格式
|
||||
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1 # 跳过冷却检查
|
||||
|
||||
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
|
||||
continue
|
||||
|
||||
# 判断是否异动
|
||||
if not detector.is_anomaly(score):
|
||||
continue
|
||||
|
||||
# 记录异动
|
||||
alert_type = 'surge_up' if alpha > 0 else 'surge_down'
|
||||
|
||||
alert = {
|
||||
'concept_id': concept_id,
|
||||
'alert_time': current_time,
|
||||
'trade_date': date,
|
||||
'alert_type': alert_type,
|
||||
'anomaly_score': float(score),
|
||||
'threshold': detector.threshold,
|
||||
**info
|
||||
}
|
||||
|
||||
minute_alerts.append(alert)
|
||||
cooldown[concept_id] = current_time
|
||||
|
||||
# 按分数排序,限制数量
|
||||
minute_alerts.sort(key=lambda x: x['anomaly_score'], reverse=True)
|
||||
alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])
|
||||
|
||||
return alerts
|
||||
|
||||
|
||||
# ==================== 数据库写入 ====================
|
||||
|
||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
|
||||
"""保存异动到 MySQL"""
|
||||
if not alerts:
|
||||
return 0
|
||||
|
||||
if dry_run:
|
||||
print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
|
||||
return len(alerts)
|
||||
|
||||
saved = 0
|
||||
with MYSQL_ENGINE.begin() as conn:
|
||||
for alert in alerts:
|
||||
try:
|
||||
# 检查是否已存在
|
||||
check_sql = text("""
|
||||
SELECT id FROM concept_minute_alert
|
||||
WHERE concept_id = :concept_id
|
||||
AND alert_time = :alert_time
|
||||
AND trade_date = :trade_date
|
||||
""")
|
||||
exists = conn.execute(check_sql, {
|
||||
'concept_id': alert['concept_id'],
|
||||
'alert_time': alert['alert_time'],
|
||||
'trade_date': alert['trade_date'],
|
||||
}).fetchone()
|
||||
|
||||
if exists:
|
||||
continue
|
||||
|
||||
# 插入新记录
|
||||
insert_sql = text("""
|
||||
INSERT INTO concept_minute_alert
|
||||
(concept_id, concept_name, alert_time, alert_type, trade_date,
|
||||
change_pct, zscore, importance_score, stock_count, extra_info)
|
||||
VALUES
|
||||
(:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
|
||||
:change_pct, :zscore, :importance_score, :stock_count, :extra_info)
|
||||
""")
|
||||
|
||||
conn.execute(insert_sql, {
|
||||
'concept_id': alert['concept_id'],
|
||||
'concept_name': alert.get('concept_name', ''),
|
||||
'alert_time': alert['alert_time'],
|
||||
'alert_type': alert['alert_type'],
|
||||
'trade_date': alert['trade_date'],
|
||||
'change_pct': alert.get('alpha', 0),
|
||||
'zscore': alert['anomaly_score'],
|
||||
'importance_score': alert['anomaly_score'],
|
||||
'stock_count': alert.get('stock_count', 0),
|
||||
'extra_info': json.dumps({
|
||||
'detection_method': 'ml_autoencoder',
|
||||
'threshold': alert['threshold'],
|
||||
'alpha': alert.get('alpha', 0),
|
||||
'amt_ratio': alert.get('amt_ratio', 1),
|
||||
}, ensure_ascii=False)
|
||||
})
|
||||
|
||||
saved += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" 保存失败: {alert['concept_id']} - {e}")
|
||||
|
||||
return saved
|
||||
|
||||
|
||||
def export_alerts_to_csv(alerts: List[Dict], output_path: str):
|
||||
"""导出异动到 CSV"""
|
||||
if not alerts:
|
||||
return
|
||||
|
||||
df = pd.DataFrame(alerts)
|
||||
df.to_csv(output_path, index=False, encoding='utf-8-sig')
|
||||
print(f"已导出到: {output_path}")
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='历史异动回测')
|
||||
parser.add_argument('--data_dir', type=str, default='ml/data',
|
||||
help='特征数据目录')
|
||||
parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
|
||||
help='模型检查点目录')
|
||||
parser.add_argument('--start', type=str, required=True,
|
||||
help='开始日期 (YYYY-MM-DD)')
|
||||
parser.add_argument('--end', type=str, required=True,
|
||||
help='结束日期 (YYYY-MM-DD)')
|
||||
parser.add_argument('--dry-run', action='store_true',
|
||||
help='只计算,不写入数据库')
|
||||
parser.add_argument('--export-csv', type=str, default=None,
|
||||
help='导出 CSV 文件路径')
|
||||
parser.add_argument('--device', type=str, default='auto',
|
||||
help='设备 (auto/cuda/cpu)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=" * 60)
|
||||
print("历史异动回测")
|
||||
print("=" * 60)
|
||||
print(f"日期范围: {args.start} ~ {args.end}")
|
||||
print(f"数据目录: {args.data_dir}")
|
||||
print(f"模型目录: {args.checkpoint_dir}")
|
||||
print(f"Dry Run: {args.dry_run}")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化检测器
|
||||
detector = AnomalyDetector(args.checkpoint_dir, args.device)
|
||||
|
||||
# 获取可用日期
|
||||
dates = get_available_dates(args.data_dir, args.start, args.end)
|
||||
|
||||
if not dates:
|
||||
print(f"未找到 {args.start} ~ {args.end} 范围内的数据")
|
||||
return
|
||||
|
||||
print(f"\n找到 {len(dates)} 天的数据")
|
||||
|
||||
# 回测
|
||||
all_alerts = []
|
||||
total_saved = 0
|
||||
|
||||
for date in tqdm(dates, desc="回测进度"):
|
||||
# 加载数据
|
||||
df = load_daily_features(args.data_dir, date)
|
||||
|
||||
if df is None or df.empty:
|
||||
continue
|
||||
|
||||
# 回测单天
|
||||
alerts = backtest_single_day(
|
||||
detector, df, date,
|
||||
seq_len=BACKTEST_CONFIG['seq_len']
|
||||
)
|
||||
|
||||
if alerts:
|
||||
all_alerts.extend(alerts)
|
||||
|
||||
# 写入数据库
|
||||
saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
|
||||
total_saved += saved
|
||||
|
||||
if not args.dry_run:
|
||||
tqdm.write(f" {date}: 检测到 {len(alerts)} 个异动,保存 {saved} 条")
|
||||
|
||||
# 导出 CSV
|
||||
if args.export_csv and all_alerts:
|
||||
export_alerts_to_csv(all_alerts, args.export_csv)
|
||||
|
||||
# 汇总
|
||||
print("\n" + "=" * 60)
|
||||
print("回测完成!")
|
||||
print("=" * 60)
|
||||
print(f"总计检测到: {len(all_alerts)} 个异动")
|
||||
print(f"保存到数据库: {total_saved} 条")
|
||||
|
||||
# 统计
|
||||
if all_alerts:
|
||||
df_alerts = pd.DataFrame(all_alerts)
|
||||
print(f"\n异动类型分布:")
|
||||
print(df_alerts['alert_type'].value_counts())
|
||||
|
||||
print(f"\n异动分数统计:")
|
||||
print(f" Mean: {df_alerts['anomaly_score'].mean():.4f}")
|
||||
print(f" Max: {df_alerts['anomaly_score'].max():.4f}")
|
||||
print(f" Min: {df_alerts['anomaly_score'].min():.4f}")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user