295 lines
10 KiB
Python
295 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
V2 回测脚本 - 验证时间片对齐 + 持续性确认的效果
|
||
|
||
回测指标:
|
||
1. 准确率:异动后 N 分钟内 alpha 是否继续上涨/下跌
|
||
2. 虚警率:多少异动是噪音
|
||
3. 持续性:平均异动持续时长
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import argparse
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Dict, List, Tuple
|
||
from collections import defaultdict
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
from tqdm import tqdm
|
||
from sqlalchemy import create_engine, text
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from ml.detector_v2 import AnomalyDetectorV2, CONFIG
|
||
|
||
|
||
# ==================== 配置 ====================
|
||
|
||
# Connection URL for the `stock` MySQL database. It is read from the
# STOCK_MYSQL_URL environment variable so credentials need not live in source
# control; the fallback preserves the previous hard-coded behaviour.
# TODO(security): drop the embedded fallback credential once every environment
# sets STOCK_MYSQL_URL, and rotate the password that was committed here.
MYSQL_URL = os.environ.get(
    "STOCK_MYSQL_URL",
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
)

MYSQL_ENGINE = create_engine(MYSQL_URL, echo=False)
|
||
|
||
|
||
# ==================== 回测评估 ====================
|
||
|
||
def evaluate_alerts(
    alerts: List[Dict],
    raw_data: pd.DataFrame,
    lookahead_minutes: int = 10,
    min_points: int = 3,
) -> Dict:
    """
    Evaluate alert quality against the rows that follow each alert.

    For every alert, the next ``lookahead_minutes`` rows of the same concept
    (strictly after ``alert_time``, in timestamp order) are examined. Each
    future alpha is multiplied by the alert direction (+1 when the alert alpha
    is positive, -1 otherwise) so that "favourable" always means > 0.

    Metrics:
    1. accuracy: share of alerts whose mean future alpha has the alert's sign.
    2. sustained_rate: per alert, share of future rows that keep the alert's
       direction; averaged over evaluated alerts.
    3. avg_peak: per alert, the largest favourable move (max alpha for up
       alerts, -min alpha for down alerts); averaged over evaluated alerts.
       A negative value means the series only moved against the alert.

    Args:
        alerts: alert dicts; each must provide 'concept_id', 'alert_time'
            and 'alpha'.
        raw_data: per-row features with at least 'concept_id', 'timestamp'
            and 'alpha' columns.
        lookahead_minutes: number of future rows examined per alert
            (presumably one row per minute -- confirm against the detector).
        min_points: alerts with fewer future rows than this are skipped.

    Returns:
        dict with 'accuracy', 'sustained_rate', 'avg_peak', 'total_alerts'
        (number of alerts passed in) and 'evaluated_alerts' (number that had
        enough future data). All keys are always present.
    """
    summary = {
        'accuracy': 0,
        'sustained_rate': 0,
        'avg_peak': 0,
        'total_alerts': len(alerts) if alerts else 0,
        'evaluated_alerts': 0,
    }
    if not alerts or raw_data.empty:
        return summary

    # Split and sort once instead of re-filtering the whole frame per alert;
    # sorting guarantees the slice below takes the *next* rows after the alert.
    by_concept = {
        cid: grp.sort_values('timestamp', kind='stable')
        for cid, grp in raw_data.groupby('concept_id', sort=False)
    }

    correct: List[bool] = []
    sustained: List[float] = []
    peaks: List[float] = []

    for alert in alerts:
        concept_data = by_concept.get(alert['concept_id'])
        if concept_data is None:
            continue

        after = concept_data['timestamp'] > alert['alert_time']
        future_alphas = concept_data.loc[after, 'alpha'].to_numpy(dtype=float)[:lookahead_minutes]
        if len(future_alphas) < min_points:
            continue

        # alpha == 0 is treated as a "down" alert, matching the sign test used before.
        direction = 1.0 if alert['alpha'] > 0 else -1.0
        signed = future_alphas * direction

        correct.append(bool(np.mean(signed) > 0))
        sustained.append(float(np.mean(signed > 0)))
        peaks.append(float(np.max(signed)))

    if not correct:
        return summary

    summary.update(
        accuracy=float(np.mean(correct)),
        sustained_rate=float(np.mean(sustained)),
        avg_peak=float(np.mean(peaks)),
        evaluated_alerts=len(correct),
    )
    return summary
|
||
|
||
|
||
def _ensure_anomaly_table(conn) -> None:
    """Create the ``concept_anomaly_v2`` table on *conn* if it does not exist."""
    conn.execute(text("""
        CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
            id INT AUTO_INCREMENT PRIMARY KEY,
            concept_id VARCHAR(64) NOT NULL,
            alert_time DATETIME NOT NULL,
            trade_date DATE NOT NULL,
            alert_type VARCHAR(32) NOT NULL,
            final_score FLOAT NOT NULL,
            rule_score FLOAT NOT NULL,
            ml_score FLOAT NOT NULL,
            trigger_reason VARCHAR(128),
            confirm_ratio FLOAT,
            alpha FLOAT,
            alpha_zscore FLOAT,
            amt_zscore FLOAT,
            rank_zscore FLOAT,
            momentum_3m FLOAT,
            momentum_5m FLOAT,
            limit_up_ratio FLOAT,
            triggered_rules JSON,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
            INDEX idx_trade_date (trade_date),
            INDEX idx_final_score (final_score)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动 V2(时间片对齐+持续确认)'
    """))


def _alert_to_params(alert: Dict) -> Dict:
    """Build the INSERT bind parameters for one alert dict.

    Required keys ('concept_id', 'alert_time', 'trade_date', 'alert_type',
    'final_score', 'rule_score', 'ml_score', 'trigger_reason', 'alpha') raise
    KeyError when missing; the optional numeric features default to 0 and
    'triggered_rules' defaults to an empty list, stored as a JSON string.
    """
    return {
        'concept_id': alert['concept_id'],
        'alert_time': alert['alert_time'],
        'trade_date': alert['trade_date'],
        'alert_type': alert['alert_type'],
        'final_score': alert['final_score'],
        'rule_score': alert['rule_score'],
        'ml_score': alert['ml_score'],
        'trigger_reason': alert['trigger_reason'],
        'confirm_ratio': alert.get('confirm_ratio', 0),
        'alpha': alert['alpha'],
        'alpha_zscore': alert.get('alpha_zscore', 0),
        'amt_zscore': alert.get('amt_zscore', 0),
        'rank_zscore': alert.get('rank_zscore', 0),
        'momentum_3m': alert.get('momentum_3m', 0),
        'momentum_5m': alert.get('momentum_5m', 0),
        'limit_up_ratio': alert.get('limit_up_ratio', 0),
        'triggered_rules': json.dumps(alert.get('triggered_rules', [])),
    }


def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Persist alerts into the ``concept_anomaly_v2`` MySQL table.

    The table is created first if missing. Rows are written with
    ``INSERT IGNORE``, so an alert whose (concept_id, alert_time, trade_date)
    already exists is skipped by MySQL rather than raising.

    Args:
        alerts: alert dicts as produced by the detector.
        dry_run: when True nothing is written.

    Returns:
        Number of rows actually inserted. Rows ignored as duplicates and rows
        whose insert raised are not counted. 0 for empty input or a dry run.
    """
    if not alerts or dry_run:
        return 0

    with MYSQL_ENGINE.begin() as conn:
        _ensure_anomaly_table(conn)

    # Built once and reused for every row.
    insert_stmt = text("""
        INSERT IGNORE INTO concept_anomaly_v2
        (concept_id, alert_time, trade_date, alert_type,
         final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
         alpha, alpha_zscore, amt_zscore, rank_zscore,
         momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
        VALUES
        (:concept_id, :alert_time, :trade_date, :alert_type,
         :final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
         :alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
         :momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
    """)

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                result = conn.execute(insert_stmt, _alert_to_params(alert))
            except Exception as e:
                # Best-effort: report the failure and continue with the rest.
                print(f"保存失败: {e}")
                continue
            # INSERT IGNORE affects 0 rows for a duplicate key, so count the
            # driver-reported rowcount instead of assuming every row landed.
            saved += max(result.rowcount, 0)

    return saved
|
||
|
||
|
||
# ==================== 主函数 ====================
|
||
|
||
def main():
    """Command-line entry point for the V2 backtest.

    Runs the detector for every trading day in [--start, --end], evaluates
    each day's alerts against the following rows via ``evaluate_alerts``,
    prints per-day and aggregate metrics, and optionally saves all detected
    alerts to MySQL (--save).
    """
    parser = argparse.ArgumentParser(description='V2 回测')
    parser.add_argument('--start', type=str, required=True, help='开始日期')
    parser.add_argument('--end', type=str, default=None, help='结束日期')
    parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
    parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
    parser.add_argument('--save', action='store_true', help='保存到数据库')
    parser.add_argument('--lookahead', type=int, default=10, help='评估前瞻时间(分钟)')
    args = parser.parse_args()

    end_date = args.end or args.start

    print("=" * 60)
    print("V2 回测 - 时间片对齐 + 持续性确认")
    print("=" * 60)
    print(f"日期范围: {args.start} ~ {end_date}")
    print(f"模型目录: {args.model_dir}")
    print(f"评估前瞻: {args.lookahead} 分钟")

    detector = AnomalyDetectorV2(
        model_dir=args.model_dir,
        baseline_dir=args.baseline_dir,
    )

    # Imported lazily; prepare_data_v2 must be importable from sys.path at run time.
    from prepare_data_v2 import get_trading_days
    trading_days = get_trading_days(args.start, end_date)
    if not trading_days:
        print("无交易日")
        return

    print(f"交易日数: {len(trading_days)}")

    # The *_sum entries are weighted by the number of evaluated alerts per day,
    # so they must later be divided by 'evaluated_alerts'.
    total_stats = {
        'total_alerts': 0,
        'evaluated_alerts': 0,
        'accuracy_sum': 0.0,
        'sustained_sum': 0.0,
        'peak_sum': 0.0,
        'day_count': 0,
    }
    all_alerts = []

    for trade_date in tqdm(trading_days, desc="回测进度"):
        alerts = detector.detect(trade_date)
        if not alerts:
            continue
        all_alerts.extend(alerts)

        raw_data = detector._compute_raw_features(trade_date)
        if raw_data.empty:
            continue

        stats = evaluate_alerts(alerts, raw_data, args.lookahead)
        # .get(): a result without 'evaluated_alerts' counts as nothing evaluated.
        evaluated = stats.get('evaluated_alerts', 0)
        if evaluated <= 0:
            continue

        total_stats['total_alerts'] += stats['total_alerts']
        total_stats['evaluated_alerts'] += evaluated
        total_stats['accuracy_sum'] += stats['accuracy'] * evaluated
        total_stats['sustained_sum'] += stats['sustained_rate'] * evaluated
        total_stats['peak_sum'] += stats['avg_peak'] * evaluated
        total_stats['day_count'] += 1

        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f"\n[{trade_date}] 异动: {stats['total_alerts']}, "
                   f"准确率: {stats['accuracy']:.1%}, "
                   f"持续率: {stats['sustained_rate']:.1%}, "
                   f"峰值: {stats['avg_peak']:.2f}%")

    print("\n" + "=" * 60)
    print("回测汇总")
    print("=" * 60)

    evaluated_total = total_stats['evaluated_alerts']
    if evaluated_total > 0:
        # Divide the weighted sums by the evaluated count, not by total_alerts
        # (which also includes alerts skipped for lack of future data).
        avg_accuracy = total_stats['accuracy_sum'] / evaluated_total
        avg_sustained = total_stats['sustained_sum'] / evaluated_total
        avg_peak = total_stats['peak_sum'] / evaluated_total

        print(f"总异动数: {total_stats['total_alerts']}")
        print(f"有效评估: {evaluated_total}")
        print(f"回测天数: {total_stats['day_count']}")
        print(f"平均每天: {total_stats['total_alerts'] / max(1, total_stats['day_count']):.1f} 个")
        print(f"方向准确率: {avg_accuracy:.1%}")
        print(f"持续率: {avg_sustained:.1%}")
        print(f"平均峰值: {avg_peak:.2f}%")
    else:
        print("无异动检测结果")

    if args.save and all_alerts:
        print(f"\n保存 {len(all_alerts)} 条异动到数据库...")
        saved = save_alerts_to_mysql(all_alerts)
        print(f"保存完成: {saved} 条")

    print("=" * 60)


if __name__ == "__main__":
    main()
|