#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Hybrid anomaly backtest script.

Runs a backtest with HybridAnomalyDetector:
- Fuses rule-based scoring with an LSTM autoencoder
- Emits richer per-alert information

Usage:
    python backtest_hybrid.py --start 2024-01-01 --end 2024-12-01
    python backtest_hybrid.py --start 2024-11-01 --dry-run
"""

import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from collections import Counter

import numpy as np
import pandas as pd
from tqdm import tqdm
from sqlalchemy import create_engine, text

# Add the parent directory to the path so `detector` is importable
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from detector import HybridAnomalyDetector, create_detector

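# Interface assumed from the call sites below (the authoritative definitions
# live in detector.py; treat this as documentation of assumptions, not API):
#
#     detector.rule_scorer.score(features: dict) -> (score, details_dict)
#     detector.ml_scorer.is_ready() -> bool
#     detector.ml_scorer.score(batch: np.ndarray) -> list of scores (or a float)
#     detector.config: dict with 'rule_weight', 'ml_weight', 'rule_trigger',
#                      'ml_trigger', 'fusion_trigger'
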
# ==================== Configuration ====================

# NOTE: the DSN (including credentials) is hardcoded for convenience;
# consider loading it from an environment variable rather than committing
# credentials to source control.
MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)

FEATURES = [
    'alpha',
    'alpha_delta',
    'amt_ratio',
    'amt_delta',
    'rank_pct',
    'limit_up_ratio',
]

BACKTEST_CONFIG = {
    'seq_len': 30,
    'min_alpha_abs': 0.3,  # lowered threshold so the rule scorer can contribute too
    'cooldown_minutes': 8,
    'max_alerts_per_minute': 20,
    'clip_value': 10.0,
}

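# Fusion scoring at a glance. The weights come from detector.config (seeded
# via --rule-weight / --ml-weight in main); the numbers below are purely
# illustrative, not the shipped defaults:
#
#     final_score = rule_weight * rule_score + ml_weight * ml_score
#     e.g. 0.6 * 80 + 0.4 * 40 = 64
#
# An alert fires if the rule score, the ML score, or the fused score crosses
# its respective trigger threshold (see backtest_single_day_hybrid below).
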
# ==================== Data loading ====================

def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
    """Load one day's feature data."""
    file_path = Path(data_dir) / f"features_{date}.parquet"

    if not file_path.exists():
        return None

    df = pd.read_parquet(file_path)
    return df


def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
    """List the dates with feature files inside [start_date, end_date]."""
    data_path = Path(data_dir)
    all_files = sorted(data_path.glob("features_*.parquet"))

    dates = []
    for f in all_files:
        date = f.stem.replace('features_', '')
        if start_date <= date <= end_date:
            dates.append(date)

    return dates

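# Quick sanity check for the loaders, e.g. from a REPL (a sketch: assumes the
# default layout of ml/data holding files named features_YYYY-MM-DD.parquet):
#
#     dates = get_available_dates('ml/data', '2024-01-01', '2024-01-31')
#     df = load_daily_features('ml/data', dates[0]) if dates else None
#     print(len(dates), None if df is None else df.shape)
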
# ==================== Hybrid backtest ====================

def backtest_single_day_hybrid(
    detector: HybridAnomalyDetector,
    df: pd.DataFrame,
    date: str,
    seq_len: int = 30
) -> List[Dict]:
    """
    Backtest one day of data with the hybrid detector (batch-optimized).
    """
    alerts = []

    # Group by concept_id once up front and keep a dict
    grouped_dict = {cid: cdf for cid, cdf in df.groupby('concept_id', sort=False)}

    # Per-concept cooldown bookkeeping
    cooldown = {}

    # All timestamps for the day
    all_timestamps = sorted(df['timestamp'].unique())

    if len(all_timestamps) < seq_len:
        return alerts

    # Run detection at each timestamp
    for t_idx in range(seq_len - 1, len(all_timestamps)):
        current_time = all_timestamps[t_idx]
        window_start_time = all_timestamps[t_idx - seq_len + 1]

        # Collect all candidate concepts at this minute into one batch
        batch_sequences = []
        batch_features = []
        batch_infos = []

        for concept_id, concept_df in grouped_dict.items():
            # Cooldown check (early filter)
            if concept_id in cooldown:
                last_alert = cooldown[concept_id]
                if isinstance(current_time, datetime):
                    time_diff = (current_time - last_alert).total_seconds() / 60
                else:
                    time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1
                if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
                    continue

            # Slice the data inside the time window
            mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
            window_df = concept_df.loc[mask]

            if len(window_df) < seq_len:
                continue

            window_df = window_df.sort_values('timestamp').tail(seq_len)

            # Features at the current minute
            current_row = window_df.iloc[-1]
            alpha = current_row.get('alpha', 0)

            # Skip tiny moves (early filter)
            if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']:
                continue

            # Extract the feature sequence
            sequence = window_df[FEATURES].values
            sequence = np.nan_to_num(sequence, nan=0.0, posinf=0.0, neginf=0.0)
            sequence = np.clip(sequence, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])

            current_features = {
                'alpha': alpha,
                'alpha_delta': current_row.get('alpha_delta', 0),
                'amt_ratio': current_row.get('amt_ratio', 1),
                'amt_delta': current_row.get('amt_delta', 0),
                'rank_pct': current_row.get('rank_pct', 0.5),
                'limit_up_ratio': current_row.get('limit_up_ratio', 0),
            }

            batch_sequences.append(sequence)
            batch_features.append(current_features)
            batch_infos.append({
                'concept_id': concept_id,
                'stock_count': current_row.get('stock_count', 0),
                'total_amt': current_row.get('total_amt', 0),
            })

        if not batch_sequences:
            continue

        # Batched ML inference
        sequences_array = np.array(batch_sequences)
        if detector.ml_scorer.is_ready():
            ml_scores = detector.ml_scorer.score(sequences_array)
        else:
            ml_scores = [0.0] * len(batch_sequences)
        if isinstance(ml_scores, float):
            ml_scores = [ml_scores]

        # Rule scoring + fusion for the batch
        minute_alerts = []
        for i, (features, info) in enumerate(zip(batch_features, batch_infos)):
            concept_id = info['concept_id']

            # Rule score
            rule_score, rule_details = detector.rule_scorer.score(features)

            # ML score
            ml_score = ml_scores[i] if i < len(ml_scores) else 0.0

            # Fuse
            w1 = detector.config['rule_weight']
            w2 = detector.config['ml_weight']
            final_score = w1 * rule_score + w2 * ml_score

            # Decide whether this is an anomaly
            is_anomaly = False
            trigger_reason = ''

            if rule_score >= detector.config['rule_trigger']:
                is_anomaly = True
                trigger_reason = f'Rule strong signal ({rule_score:.0f} pts)'
            elif ml_score >= detector.config['ml_trigger']:
                is_anomaly = True
                trigger_reason = f'ML strong signal ({ml_score:.0f} pts)'
            elif final_score >= detector.config['fusion_trigger']:
                is_anomaly = True
                trigger_reason = f'Fusion trigger ({final_score:.0f} pts)'

            if not is_anomaly:
                continue

            # Anomaly type
            alpha = features.get('alpha', 0)
            if alpha >= 1.5:
                anomaly_type = 'surge_up'
            elif alpha <= -1.5:
                anomaly_type = 'surge_down'
            elif features.get('amt_ratio', 1) >= 3.0:
                anomaly_type = 'volume_spike'
            else:
                anomaly_type = 'unknown'

            alert = {
                'concept_id': concept_id,
                'alert_time': current_time,
                'trade_date': date,
                'alert_type': anomaly_type,
                'final_score': final_score,
                'rule_score': rule_score,
                'ml_score': ml_score,
                'trigger_reason': trigger_reason,
                'triggered_rules': list(rule_details.keys()),
                **features,
                **info,
            }

            minute_alerts.append(alert)
            cooldown[concept_id] = current_time

        # Sort by final score and cap alerts per minute
        minute_alerts.sort(key=lambda x: x['final_score'], reverse=True)
        alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])

    return alerts

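# Minimal standalone driver for the function above (a sketch: assumes
# create_detector() builds a detector with sensible defaults; its actual
# signature lives in detector.py):
#
#     df = load_daily_features('ml/data', '2024-11-01')
#     detector = create_detector()
#     alerts = backtest_single_day_hybrid(detector, df, '2024-11-01', seq_len=30)
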
# ==================== Database writes ====================

def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Persist alerts to MySQL (extended field set)."""
    if not alerts:
        return 0

    if dry_run:
        print(f"  [Dry Run] would write {len(alerts)} alerts")
        return len(alerts)

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                # Skip rows that already exist
                check_sql = text("""
                    SELECT id FROM concept_minute_alert
                    WHERE concept_id = :concept_id
                    AND alert_time = :alert_time
                    AND trade_date = :trade_date
                """)
                exists = conn.execute(check_sql, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                }).fetchone()

                if exists:
                    continue

                # Insert a new record
                insert_sql = text("""
                    INSERT INTO concept_minute_alert
                    (concept_id, concept_name, alert_time, alert_type, trade_date,
                     change_pct, zscore, importance_score, stock_count, extra_info)
                    VALUES
                    (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
                     :change_pct, :zscore, :importance_score, :stock_count, :extra_info)
                """)

                extra_info = {
                    'detection_method': 'hybrid',
                    'final_score': alert['final_score'],
                    'rule_score': alert['rule_score'],
                    'ml_score': alert['ml_score'],
                    'trigger_reason': alert['trigger_reason'],
                    'triggered_rules': alert['triggered_rules'],
                    'alpha': alert.get('alpha', 0),
                    'alpha_delta': alert.get('alpha_delta', 0),
                    'amt_ratio': alert.get('amt_ratio', 1),
                }

                conn.execute(insert_sql, {
                    'concept_id': alert['concept_id'],
                    'concept_name': alert.get('concept_name', ''),
                    'alert_time': alert['alert_time'],
                    'alert_type': alert['alert_type'],
                    'trade_date': alert['trade_date'],
                    'change_pct': alert.get('alpha', 0),
                    'zscore': alert['final_score'],  # reuse the final score as zscore
                    'importance_score': alert['final_score'],
                    'stock_count': alert.get('stock_count', 0),
                    'extra_info': json.dumps(extra_info, ensure_ascii=False)
                })

                saved += 1

            except Exception as e:
                print(f"  Failed to save: {alert['concept_id']} - {e}")

    return saved

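# Design note: the SELECT-then-INSERT above is simple but not race-safe if
# several writers run concurrently. A unique key on the dedup columns plus
# INSERT IGNORE would push deduplication into the database (a sketch; check
# that the index name is free before applying):
#
#     ALTER TABLE concept_minute_alert
#         ADD UNIQUE KEY uk_concept_alert (concept_id, alert_time, trade_date);
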
def export_alerts_to_csv(alerts: List[Dict], output_path: str):
    """Export alerts to CSV."""
    if not alerts:
        return

    df = pd.DataFrame(alerts)
    df.to_csv(output_path, index=False, encoding='utf-8-sig')
    print(f"Exported to: {output_path}")

# ==================== Statistics ====================

def analyze_alerts(alerts: List[Dict]):
    """Summarize the alert results."""
    if not alerts:
        print("No alert data")
        return

    df = pd.DataFrame(alerts)

    print("\n" + "=" * 60)
    print("Alert statistics")
    print("=" * 60)

    # 1. Basic counts
    print(f"\nTotal alerts: {len(alerts)}")

    # 2. By alert type
    print("\nAlert type distribution:")
    print(df['alert_type'].value_counts())

    # 3. Score statistics
    print("\nScore statistics:")
    print(f"  Final score - mean: {df['final_score'].mean():.1f}, max: {df['final_score'].max():.1f}")
    print(f"  Rule score  - mean: {df['rule_score'].mean():.1f}, max: {df['rule_score'].max():.1f}")
    print(f"  ML score    - mean: {df['ml_score'].mean():.1f}, max: {df['ml_score'].max():.1f}")

    # 4. Trigger source breakdown (the prefixes match the trigger_reason
    #    strings built in backtest_single_day_hybrid)
    print("\nTrigger source breakdown:")
    trigger_counts = df['trigger_reason'].apply(
        lambda x: 'Rule' if x.startswith('Rule') else ('ML' if x.startswith('ML') else 'Fusion')
    ).value_counts()
    print(trigger_counts)

    # 5. Rule trigger frequency
    all_rules = []
    for rules in df['triggered_rules']:
        if isinstance(rules, list):
            all_rules.extend(rules)

    if all_rules:
        print("\nMost frequently triggered rules (top 10):")
        rule_counts = Counter(all_rules)
        for rule, count in rule_counts.most_common(10):
            print(f"  {rule}: {count}")

# ==================== Main ====================

def main():
    parser = argparse.ArgumentParser(description='Hybrid anomaly backtest')
    parser.add_argument('--data_dir', type=str, default='ml/data',
                        help='Feature data directory')
    parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
                        help='Model checkpoint directory')
    parser.add_argument('--start', type=str, required=True,
                        help='Start date (YYYY-MM-DD)')
    parser.add_argument('--end', type=str, default=None,
                        help='End date (YYYY-MM-DD); defaults to --start')
    parser.add_argument('--dry-run', action='store_true',
                        help='Compute only; do not write to the database')
    parser.add_argument('--export-csv', type=str, default=None,
                        help='CSV export path')
    parser.add_argument('--rule-weight', type=float, default=0.6,
                        help='Rule weight (0-1)')
    parser.add_argument('--ml-weight', type=float, default=0.4,
                        help='ML weight (0-1)')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device (cuda/cpu); defaults to cuda')

    args = parser.parse_args()

    if args.end is None:
        args.end = args.start

    print("=" * 60)
    print("Hybrid anomaly backtest (rules + LSTM)")
    print("=" * 60)
    print(f"Date range: {args.start} ~ {args.end}")
    print(f"Data dir: {args.data_dir}")
    print(f"Checkpoint dir: {args.checkpoint_dir}")
    print(f"Rule weight: {args.rule_weight}")
    print(f"ML weight: {args.ml_weight}")
    print(f"Device: {args.device}")
    print(f"Dry run: {args.dry_run}")
    print("=" * 60)

    # Initialize the hybrid detector; the MLScorer device in detector.py is
    # driven by the `device` argument
    config = {
        'rule_weight': args.rule_weight,
        'ml_weight': args.ml_weight,
    }

    detector = HybridAnomalyDetector(config, args.checkpoint_dir, device=args.device)

    # Discover available dates
    dates = get_available_dates(args.data_dir, args.start, args.end)

    if not dates:
        print(f"No data found in range {args.start} ~ {args.end}")
        return

    print(f"\nFound {len(dates)} day(s) of data")

    # Backtest
    all_alerts = []
    total_saved = 0

    for date in tqdm(dates, desc="Backtesting"):
        df = load_daily_features(args.data_dir, date)

        if df is None or df.empty:
            continue

        alerts = backtest_single_day_hybrid(
            detector, df, date,
            seq_len=BACKTEST_CONFIG['seq_len']
        )

        if alerts:
            all_alerts.extend(alerts)

            saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
            total_saved += saved

            if not args.dry_run:
                tqdm.write(f"  {date}: detected {len(alerts)} alerts, saved {saved}")

    # CSV export
    if args.export_csv and all_alerts:
        export_alerts_to_csv(all_alerts, args.export_csv)

    # Statistics
    analyze_alerts(all_alerts)

    # Summary
    print("\n" + "=" * 60)
    print("Backtest complete!")
    print("=" * 60)
    print(f"Total alerts detected: {len(all_alerts)}")
    print(f"Saved to database: {total_saved}")
    print("=" * 60)

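# Example invocations (all flags are defined in main above; paths are the
# script defaults):
#
#     python backtest_hybrid.py --start 2024-11-01 --end 2024-11-30 \
#         --dry-run --export-csv alerts_nov.csv --device cpu
#     python backtest_hybrid.py --start 2024-12-02 --rule-weight 0.7 --ml-weight 0.3
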
if __name__ == "__main__":
    main()