update pay ui
This commit is contained in:
418
ml/backtest_hybrid.py
Normal file
418
ml/backtest_hybrid.py
Normal file
@@ -0,0 +1,418 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
融合异动回测脚本
|
||||
|
||||
使用 HybridAnomalyDetector 进行回测:
|
||||
- 规则评分 + LSTM Autoencoder 融合判断
|
||||
- 输出更丰富的异动信息
|
||||
|
||||
使用方法:
|
||||
python backtest_hybrid.py --start 2024-01-01 --end 2024-12-01
|
||||
python backtest_hybrid.py --start 2024-11-01 --dry-run
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
# 添加父目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from detector import HybridAnomalyDetector, create_detector
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
MYSQL_ENGINE = create_engine(
|
||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
||||
echo=False
|
||||
)
|
||||
|
||||
FEATURES = [
|
||||
'alpha',
|
||||
'alpha_delta',
|
||||
'amt_ratio',
|
||||
'amt_delta',
|
||||
'rank_pct',
|
||||
'limit_up_ratio',
|
||||
]
|
||||
|
||||
BACKTEST_CONFIG = {
|
||||
'seq_len': 30,
|
||||
'min_alpha_abs': 0.3, # 降低阈值,让规则也能发挥作用
|
||||
'cooldown_minutes': 8,
|
||||
'max_alerts_per_minute': 20,
|
||||
'clip_value': 10.0,
|
||||
}
|
||||
|
||||
|
||||
# ==================== 数据加载 ====================
|
||||
|
||||
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
|
||||
"""加载单天的特征数据"""
|
||||
file_path = Path(data_dir) / f"features_{date}.parquet"
|
||||
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
df = pd.read_parquet(file_path)
|
||||
return df
|
||||
|
||||
|
||||
def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
|
||||
"""获取可用的日期列表"""
|
||||
data_path = Path(data_dir)
|
||||
all_files = sorted(data_path.glob("features_*.parquet"))
|
||||
|
||||
dates = []
|
||||
for f in all_files:
|
||||
date = f.stem.replace('features_', '')
|
||||
if start_date <= date <= end_date:
|
||||
dates.append(date)
|
||||
|
||||
return dates
|
||||
|
||||
|
||||
# ==================== 融合回测 ====================
|
||||
|
||||
def backtest_single_day_hybrid(
|
||||
detector: HybridAnomalyDetector,
|
||||
df: pd.DataFrame,
|
||||
date: str,
|
||||
seq_len: int = 30
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
使用融合检测器回测单天数据
|
||||
"""
|
||||
alerts = []
|
||||
|
||||
# 按概念分组
|
||||
grouped = df.groupby('concept_id', sort=False)
|
||||
|
||||
# 冷却记录
|
||||
cooldown = {}
|
||||
|
||||
# 获取所有时间点
|
||||
all_timestamps = sorted(df['timestamp'].unique())
|
||||
|
||||
if len(all_timestamps) < seq_len:
|
||||
return alerts
|
||||
|
||||
# 对每个时间点进行检测
|
||||
for t_idx in range(seq_len - 1, len(all_timestamps)):
|
||||
current_time = all_timestamps[t_idx]
|
||||
window_start_time = all_timestamps[t_idx - seq_len + 1]
|
||||
|
||||
minute_alerts = []
|
||||
|
||||
for concept_id, concept_df in grouped:
|
||||
# 获取时间窗口内的数据
|
||||
mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
|
||||
window_df = concept_df[mask].sort_values('timestamp')
|
||||
|
||||
if len(window_df) < seq_len:
|
||||
continue
|
||||
|
||||
window_df = window_df.tail(seq_len)
|
||||
|
||||
# 提取特征序列(给 ML 模型)
|
||||
sequence = window_df[FEATURES].values
|
||||
sequence = np.nan_to_num(sequence, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
sequence = np.clip(sequence, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])
|
||||
|
||||
# 当前时刻特征(给规则系统)
|
||||
current_row = window_df.iloc[-1]
|
||||
current_features = {
|
||||
'alpha': current_row.get('alpha', 0),
|
||||
'alpha_delta': current_row.get('alpha_delta', 0),
|
||||
'amt_ratio': current_row.get('amt_ratio', 1),
|
||||
'amt_delta': current_row.get('amt_delta', 0),
|
||||
'rank_pct': current_row.get('rank_pct', 0.5),
|
||||
'limit_up_ratio': current_row.get('limit_up_ratio', 0),
|
||||
}
|
||||
|
||||
# 过滤微小波动
|
||||
if abs(current_features['alpha']) < BACKTEST_CONFIG['min_alpha_abs']:
|
||||
continue
|
||||
|
||||
# 检查冷却
|
||||
if concept_id in cooldown:
|
||||
last_alert = cooldown[concept_id]
|
||||
if isinstance(current_time, datetime):
|
||||
time_diff = (current_time - last_alert).total_seconds() / 60
|
||||
else:
|
||||
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1
|
||||
|
||||
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
|
||||
continue
|
||||
|
||||
# 融合检测
|
||||
result = detector.detect(current_features, sequence)
|
||||
|
||||
if not result.is_anomaly:
|
||||
continue
|
||||
|
||||
# 记录异动
|
||||
alert = {
|
||||
'concept_id': concept_id,
|
||||
'alert_time': current_time,
|
||||
'trade_date': date,
|
||||
'alert_type': result.anomaly_type,
|
||||
'final_score': result.final_score,
|
||||
'rule_score': result.rule_score,
|
||||
'ml_score': result.ml_score,
|
||||
'trigger_reason': result.trigger_reason,
|
||||
'triggered_rules': list(result.rule_details.keys()),
|
||||
**current_features,
|
||||
'stock_count': current_row.get('stock_count', 0),
|
||||
'total_amt': current_row.get('total_amt', 0),
|
||||
}
|
||||
|
||||
minute_alerts.append(alert)
|
||||
cooldown[concept_id] = current_time
|
||||
|
||||
# 按最终得分排序
|
||||
minute_alerts.sort(key=lambda x: x['final_score'], reverse=True)
|
||||
alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])
|
||||
|
||||
return alerts
|
||||
|
||||
|
||||
# ==================== 数据库写入 ====================
|
||||
|
||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
|
||||
"""保存异动到 MySQL(增强版字段)"""
|
||||
if not alerts:
|
||||
return 0
|
||||
|
||||
if dry_run:
|
||||
print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
|
||||
return len(alerts)
|
||||
|
||||
saved = 0
|
||||
with MYSQL_ENGINE.begin() as conn:
|
||||
for alert in alerts:
|
||||
try:
|
||||
# 检查是否已存在
|
||||
check_sql = text("""
|
||||
SELECT id FROM concept_minute_alert
|
||||
WHERE concept_id = :concept_id
|
||||
AND alert_time = :alert_time
|
||||
AND trade_date = :trade_date
|
||||
""")
|
||||
exists = conn.execute(check_sql, {
|
||||
'concept_id': alert['concept_id'],
|
||||
'alert_time': alert['alert_time'],
|
||||
'trade_date': alert['trade_date'],
|
||||
}).fetchone()
|
||||
|
||||
if exists:
|
||||
continue
|
||||
|
||||
# 插入新记录
|
||||
insert_sql = text("""
|
||||
INSERT INTO concept_minute_alert
|
||||
(concept_id, concept_name, alert_time, alert_type, trade_date,
|
||||
change_pct, zscore, importance_score, stock_count, extra_info)
|
||||
VALUES
|
||||
(:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
|
||||
:change_pct, :zscore, :importance_score, :stock_count, :extra_info)
|
||||
""")
|
||||
|
||||
extra_info = {
|
||||
'detection_method': 'hybrid',
|
||||
'final_score': alert['final_score'],
|
||||
'rule_score': alert['rule_score'],
|
||||
'ml_score': alert['ml_score'],
|
||||
'trigger_reason': alert['trigger_reason'],
|
||||
'triggered_rules': alert['triggered_rules'],
|
||||
'alpha': alert.get('alpha', 0),
|
||||
'alpha_delta': alert.get('alpha_delta', 0),
|
||||
'amt_ratio': alert.get('amt_ratio', 1),
|
||||
}
|
||||
|
||||
conn.execute(insert_sql, {
|
||||
'concept_id': alert['concept_id'],
|
||||
'concept_name': alert.get('concept_name', ''),
|
||||
'alert_time': alert['alert_time'],
|
||||
'alert_type': alert['alert_type'],
|
||||
'trade_date': alert['trade_date'],
|
||||
'change_pct': alert.get('alpha', 0),
|
||||
'zscore': alert['final_score'], # 用最终得分作为 zscore
|
||||
'importance_score': alert['final_score'],
|
||||
'stock_count': alert.get('stock_count', 0),
|
||||
'extra_info': json.dumps(extra_info, ensure_ascii=False)
|
||||
})
|
||||
|
||||
saved += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" 保存失败: {alert['concept_id']} - {e}")
|
||||
|
||||
return saved
|
||||
|
||||
|
||||
def export_alerts_to_csv(alerts: List[Dict], output_path: str):
|
||||
"""导出异动到 CSV"""
|
||||
if not alerts:
|
||||
return
|
||||
|
||||
df = pd.DataFrame(alerts)
|
||||
df.to_csv(output_path, index=False, encoding='utf-8-sig')
|
||||
print(f"已导出到: {output_path}")
|
||||
|
||||
|
||||
# ==================== 统计分析 ====================
|
||||
|
||||
def analyze_alerts(alerts: List[Dict]):
|
||||
"""分析异动结果"""
|
||||
if not alerts:
|
||||
print("无异动数据")
|
||||
return
|
||||
|
||||
df = pd.DataFrame(alerts)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("异动统计分析")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 基本统计
|
||||
print(f"\n总异动数: {len(alerts)}")
|
||||
|
||||
# 2. 按类型统计
|
||||
print(f"\n异动类型分布:")
|
||||
print(df['alert_type'].value_counts())
|
||||
|
||||
# 3. 得分统计
|
||||
print(f"\n得分统计:")
|
||||
print(f" 最终得分 - Mean: {df['final_score'].mean():.1f}, Max: {df['final_score'].max():.1f}")
|
||||
print(f" 规则得分 - Mean: {df['rule_score'].mean():.1f}, Max: {df['rule_score'].max():.1f}")
|
||||
print(f" ML得分 - Mean: {df['ml_score'].mean():.1f}, Max: {df['ml_score'].max():.1f}")
|
||||
|
||||
# 4. 触发来源分析
|
||||
print(f"\n触发来源分析:")
|
||||
trigger_counts = df['trigger_reason'].apply(
|
||||
lambda x: '规则' if '规则' in x else ('ML' if 'ML' in x else '融合')
|
||||
).value_counts()
|
||||
print(trigger_counts)
|
||||
|
||||
# 5. 规则触发频率
|
||||
all_rules = []
|
||||
for rules in df['triggered_rules']:
|
||||
if isinstance(rules, list):
|
||||
all_rules.extend(rules)
|
||||
|
||||
if all_rules:
|
||||
print(f"\n最常触发的规则 (Top 10):")
|
||||
from collections import Counter
|
||||
rule_counts = Counter(all_rules)
|
||||
for rule, count in rule_counts.most_common(10):
|
||||
print(f" {rule}: {count}")
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='融合异动回测')
|
||||
parser.add_argument('--data_dir', type=str, default='ml/data',
|
||||
help='特征数据目录')
|
||||
parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
|
||||
help='模型检查点目录')
|
||||
parser.add_argument('--start', type=str, required=True,
|
||||
help='开始日期 (YYYY-MM-DD)')
|
||||
parser.add_argument('--end', type=str, default=None,
|
||||
help='结束日期 (YYYY-MM-DD),默认=start')
|
||||
parser.add_argument('--dry-run', action='store_true',
|
||||
help='只计算,不写入数据库')
|
||||
parser.add_argument('--export-csv', type=str, default=None,
|
||||
help='导出 CSV 文件路径')
|
||||
parser.add_argument('--rule-weight', type=float, default=0.6,
|
||||
help='规则权重 (0-1)')
|
||||
parser.add_argument('--ml-weight', type=float, default=0.4,
|
||||
help='ML权重 (0-1)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.end is None:
|
||||
args.end = args.start
|
||||
|
||||
print("=" * 60)
|
||||
print("融合异动回测 (规则 + LSTM)")
|
||||
print("=" * 60)
|
||||
print(f"日期范围: {args.start} ~ {args.end}")
|
||||
print(f"数据目录: {args.data_dir}")
|
||||
print(f"模型目录: {args.checkpoint_dir}")
|
||||
print(f"规则权重: {args.rule_weight}")
|
||||
print(f"ML权重: {args.ml_weight}")
|
||||
print(f"Dry Run: {args.dry_run}")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化融合检测器
|
||||
config = {
|
||||
'rule_weight': args.rule_weight,
|
||||
'ml_weight': args.ml_weight,
|
||||
}
|
||||
detector = create_detector(args.checkpoint_dir, config)
|
||||
|
||||
# 获取可用日期
|
||||
dates = get_available_dates(args.data_dir, args.start, args.end)
|
||||
|
||||
if not dates:
|
||||
print(f"未找到 {args.start} ~ {args.end} 范围内的数据")
|
||||
return
|
||||
|
||||
print(f"\n找到 {len(dates)} 天的数据")
|
||||
|
||||
# 回测
|
||||
all_alerts = []
|
||||
total_saved = 0
|
||||
|
||||
for date in tqdm(dates, desc="回测进度"):
|
||||
df = load_daily_features(args.data_dir, date)
|
||||
|
||||
if df is None or df.empty:
|
||||
continue
|
||||
|
||||
alerts = backtest_single_day_hybrid(
|
||||
detector, df, date,
|
||||
seq_len=BACKTEST_CONFIG['seq_len']
|
||||
)
|
||||
|
||||
if alerts:
|
||||
all_alerts.extend(alerts)
|
||||
|
||||
saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
|
||||
total_saved += saved
|
||||
|
||||
if not args.dry_run:
|
||||
tqdm.write(f" {date}: 检测到 {len(alerts)} 个异动,保存 {saved} 条")
|
||||
|
||||
# 导出 CSV
|
||||
if args.export_csv and all_alerts:
|
||||
export_alerts_to_csv(all_alerts, args.export_csv)
|
||||
|
||||
# 统计分析
|
||||
analyze_alerts(all_alerts)
|
||||
|
||||
# 汇总
|
||||
print("\n" + "=" * 60)
|
||||
print("回测完成!")
|
||||
print("=" * 60)
|
||||
print(f"总计检测到: {len(all_alerts)} 个异动")
|
||||
print(f"保存到数据库: {total_saved} 条")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user