update pay ui
This commit is contained in:
294
ml/backtest_v2.py
Normal file
294
ml/backtest_v2.py
Normal file
@@ -0,0 +1,294 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
V2 回测脚本 - 验证时间片对齐 + 持续性确认的效果
|
||||
|
||||
回测指标:
|
||||
1. 准确率:异动后 N 分钟内 alpha 是否继续上涨/下跌
|
||||
2. 虚警率:多少异动是噪音
|
||||
3. 持续性:平均异动持续时长
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from ml.detector_v2 import AnomalyDetectorV2, CONFIG
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
# NOTE(review): the connection URL (including the root password) was
# hard-coded in source.  Prefer the STOCK_MYSQL_URL environment variable;
# the original URL is kept as a fallback so existing deployments keep
# working, but the committed secret should be rotated and removed from
# version control.
MYSQL_ENGINE = create_engine(
    os.environ.get(
        "STOCK_MYSQL_URL",
        "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    ),
    echo=False
)
|
||||
|
||||
|
||||
# ==================== 回测评估 ====================
|
||||
|
||||
def evaluate_alerts(
    alerts: List[Dict],
    raw_data: pd.DataFrame,
    lookahead_minutes: int = 10
) -> Dict:
    """Score alert quality against the market data that followed each alert.

    Metrics:
        1. accuracy       -- direction hit rate: does the mean alpha over the
           next ``lookahead_minutes`` rows share the sign of the alert alpha?
        2. sustained_rate -- fraction of future rows whose alpha keeps the
           alert's sign.
        3. avg_peak       -- mean absolute extreme (max for up-alerts, min for
           down-alerts) alpha reached after the alert.

    Args:
        alerts: alert dicts; each must carry 'concept_id', 'alert_time'
            and 'alpha'.
        raw_data: per-minute feature rows with at least 'concept_id',
            'timestamp' and 'alpha' columns (one row per concept per minute,
            presumably — confirm against the detector's output).
        lookahead_minutes: number of future rows taken per alert.

    Returns:
        Dict with keys 'accuracy', 'sustained_rate', 'avg_peak',
        'total_alerts' and 'evaluated_alerts'.  'evaluated_alerts' counts the
        alerts that had at least 3 future rows and were actually scored.
    """
    # Bug fix: the early-return dicts previously omitted 'evaluated_alerts',
    # so callers reading stats['evaluated_alerts'] raised KeyError whenever
    # no alert could be evaluated.
    empty_stats = {
        'accuracy': 0,
        'sustained_rate': 0,
        'avg_peak': 0,
        'total_alerts': 0,
        'evaluated_alerts': 0,
    }
    if not alerts:
        return empty_stats

    results = []

    for alert in alerts:
        concept_id = alert['concept_id']
        alert_time = alert['alert_time']
        alert_alpha = alert['alpha']
        is_up = alert_alpha > 0

        # Rows for this concept strictly after the alert time.
        concept_data = raw_data[
            (raw_data['concept_id'] == concept_id) &
            (raw_data['timestamp'] > alert_time)
        ].head(lookahead_minutes)

        # Not enough future data to judge this alert; skip it.
        if len(concept_data) < 3:
            continue

        future_alphas = concept_data['alpha'].values

        # Direction correct: the future mean alpha shares the alert's sign.
        avg_future_alpha = np.mean(future_alphas)
        direction_correct = (is_up and avg_future_alpha > 0) or (not is_up and avg_future_alpha < 0)

        # Sustained rate: share of future rows keeping the alert's sign.
        if is_up:
            sustained_count = sum(1 for a in future_alphas if a > 0)
        else:
            sustained_count = sum(1 for a in future_alphas if a < 0)
        sustained_rate = sustained_count / len(future_alphas)

        # Peak: most favorable alpha in the alert's direction.
        if is_up:
            peak = max(future_alphas)
        else:
            peak = min(future_alphas)

        results.append({
            'direction_correct': direction_correct,
            'sustained_rate': sustained_rate,
            'peak': peak,
            'alert_alpha': alert_alpha,
        })

    if not results:
        # Every alert was skipped for lack of future data: report zero
        # metrics but still the true alert count (previously reported as 0).
        stats = dict(empty_stats)
        stats['total_alerts'] = len(alerts)
        return stats

    return {
        'accuracy': np.mean([r['direction_correct'] for r in results]),
        'sustained_rate': np.mean([r['sustained_rate'] for r in results]),
        'avg_peak': np.mean([abs(r['peak']) for r in results]),
        'total_alerts': len(alerts),
        'evaluated_alerts': len(results),
    }
|
||||
|
||||
|
||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Persist alerts into the `concept_anomaly_v2` MySQL table.

    The table is created on first use.  Rows are written with INSERT IGNORE,
    so duplicates on (concept_id, alert_time, trade_date) are silently
    skipped by the unique key.  Per-row failures are printed and do not abort
    the batch.

    Args:
        alerts: alert dicts produced by the detector.
        dry_run: when True, touch nothing and return 0.

    Returns:
        Number of alerts whose INSERT executed without raising.
    """
    if dry_run or not alerts:
        return 0

    # Make sure the destination table exists before inserting.
    with MYSQL_ENGINE.begin() as connection:
        connection.execute(text("""
            CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
                id INT AUTO_INCREMENT PRIMARY KEY,
                concept_id VARCHAR(64) NOT NULL,
                alert_time DATETIME NOT NULL,
                trade_date DATE NOT NULL,
                alert_type VARCHAR(32) NOT NULL,
                final_score FLOAT NOT NULL,
                rule_score FLOAT NOT NULL,
                ml_score FLOAT NOT NULL,
                trigger_reason VARCHAR(128),
                confirm_ratio FLOAT,
                alpha FLOAT,
                alpha_zscore FLOAT,
                amt_zscore FLOAT,
                rank_zscore FLOAT,
                momentum_3m FLOAT,
                momentum_5m FLOAT,
                limit_up_ratio FLOAT,
                triggered_rules JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
                INDEX idx_trade_date (trade_date),
                INDEX idx_final_score (final_score)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动 V2(时间片对齐+持续确认)'
        """))

    # Prepare the statement once; bind per-alert parameters in the loop.
    insert_stmt = text("""
        INSERT IGNORE INTO concept_anomaly_v2
        (concept_id, alert_time, trade_date, alert_type,
         final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
         alpha, alpha_zscore, amt_zscore, rank_zscore,
         momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
        VALUES
        (:concept_id, :alert_time, :trade_date, :alert_type,
         :final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
         :alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
         :momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
    """)

    inserted = 0
    with MYSQL_ENGINE.begin() as connection:
        for alert in alerts:
            try:
                # Parameter dict built inside try so a malformed alert
                # (missing required key) is reported, not fatal.
                connection.execute(insert_stmt, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                    'alert_type': alert['alert_type'],
                    'final_score': alert['final_score'],
                    'rule_score': alert['rule_score'],
                    'ml_score': alert['ml_score'],
                    'trigger_reason': alert['trigger_reason'],
                    'confirm_ratio': alert.get('confirm_ratio', 0),
                    'alpha': alert['alpha'],
                    'alpha_zscore': alert.get('alpha_zscore', 0),
                    'amt_zscore': alert.get('amt_zscore', 0),
                    'rank_zscore': alert.get('rank_zscore', 0),
                    'momentum_3m': alert.get('momentum_3m', 0),
                    'momentum_5m': alert.get('momentum_5m', 0),
                    'limit_up_ratio': alert.get('limit_up_ratio', 0),
                    'triggered_rules': json.dumps(alert.get('triggered_rules', [])),
                })
                inserted += 1
            except Exception as exc:
                print(f"保存失败: {exc}")

    return inserted
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
    """Run the V2 backtest over a date range and print summary metrics.

    CLI flags: --start/--end (date range), --model_dir/--baseline_dir
    (detector artifacts), --lookahead (evaluation horizon in minutes),
    --save (persist alerts to MySQL).
    """
    parser = argparse.ArgumentParser(description='V2 回测')
    parser.add_argument('--start', type=str, required=True, help='开始日期')
    parser.add_argument('--end', type=str, default=None, help='结束日期')
    parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
    parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
    parser.add_argument('--save', action='store_true', help='保存到数据库')
    parser.add_argument('--lookahead', type=int, default=10, help='评估前瞻时间(分钟)')

    args = parser.parse_args()

    # A missing --end means a single-day backtest.
    end_date = args.end or args.start

    print("=" * 60)
    print("V2 回测 - 时间片对齐 + 持续性确认")
    print("=" * 60)
    print(f"日期范围: {args.start} ~ {end_date}")
    print(f"模型目录: {args.model_dir}")
    print(f"评估前瞻: {args.lookahead} 分钟")

    # Build the detector (loads model and baselines from disk).
    detector = AnomalyDetectorV2(
        model_dir=args.model_dir,
        baseline_dir=args.baseline_dir
    )

    # Resolve trading days for the requested range.
    from prepare_data_v2 import get_trading_days
    trading_days = get_trading_days(args.start, end_date)

    if not trading_days:
        print("无交易日")
        return

    print(f"交易日数: {len(trading_days)}")

    # Running totals.  Per-day metric means are weighted by the number of
    # alerts actually evaluated that day ('evaluated_sum'), and the final
    # averages divide by that same weight.
    # Bug fix: the previous version divided the weighted sums by
    # 'total_alerts', which also counts alerts skipped for lack of future
    # data, biasing the reported averages low.
    total_stats = {
        'total_alerts': 0,
        'evaluated_sum': 0,
        'accuracy_sum': 0,
        'sustained_sum': 0,
        'peak_sum': 0,
        'day_count': 0,
    }

    all_alerts = []

    for trade_date in tqdm(trading_days, desc="回测进度"):
        # Detect anomalies for the day.
        alerts = detector.detect(trade_date)

        if not alerts:
            continue

        all_alerts.extend(alerts)

        # Evaluate against the day's raw per-minute features.
        raw_data = detector._compute_raw_features(trade_date)
        if raw_data.empty:
            continue

        stats = evaluate_alerts(alerts, raw_data, args.lookahead)

        # .get(): robust against stats dicts that omit 'evaluated_alerts'.
        evaluated = stats.get('evaluated_alerts', 0)
        if evaluated > 0:
            total_stats['total_alerts'] += stats['total_alerts']
            total_stats['evaluated_sum'] += evaluated
            total_stats['accuracy_sum'] += stats['accuracy'] * evaluated
            total_stats['sustained_sum'] += stats['sustained_rate'] * evaluated
            total_stats['peak_sum'] += stats['avg_peak'] * evaluated
            total_stats['day_count'] += 1

            print(f"\n[{trade_date}] 异动: {stats['total_alerts']}, "
                  f"准确率: {stats['accuracy']:.1%}, "
                  f"持续率: {stats['sustained_rate']:.1%}, "
                  f"峰值: {stats['avg_peak']:.2f}%")

    # Summary across all backtested days.
    print("\n" + "=" * 60)
    print("回测汇总")
    print("=" * 60)

    if total_stats['evaluated_sum'] > 0:
        avg_accuracy = total_stats['accuracy_sum'] / total_stats['evaluated_sum']
        avg_sustained = total_stats['sustained_sum'] / total_stats['evaluated_sum']
        avg_peak = total_stats['peak_sum'] / total_stats['evaluated_sum']

        print(f"总异动数: {total_stats['total_alerts']}")
        print(f"回测天数: {total_stats['day_count']}")
        print(f"平均每天: {total_stats['total_alerts'] / max(1, total_stats['day_count']):.1f} 个")
        print(f"方向准确率: {avg_accuracy:.1%}")
        print(f"持续率: {avg_sustained:.1%}")
        print(f"平均峰值: {avg_peak:.2f}%")
    else:
        print("无异动检测结果")

    # Optional persistence of everything detected in the run.
    if args.save and all_alerts:
        print(f"\n保存 {len(all_alerts)} 条异动到数据库...")
        saved = save_alerts_to_mysql(all_alerts)
        print(f"保存完成: {saved} 条")

    print("=" * 60)


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user