update pay ui

This commit is contained in:
2025-12-09 08:31:18 +08:00
parent e4937c2719
commit 25492caf15
26 changed files with 15577 additions and 1061 deletions

ml/backtest_hybrid.py Normal file

@@ -0,0 +1,418 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
融合异动回测脚本
使用 HybridAnomalyDetector 进行回测:
- 规则评分 + LSTM Autoencoder 融合判断
- 输出更丰富的异动信息
使用方法:
python backtest_hybrid.py --start 2024-01-01 --end 2024-12-01
python backtest_hybrid.py --start 2024-11-01 --dry-run
"""
import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from collections import Counter
import numpy as np
import pandas as pd
from tqdm import tqdm
from sqlalchemy import create_engine, text
# Add the parent directory to the path so the detector module can be imported
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from detector import HybridAnomalyDetector, create_detector
# ==================== Configuration ====================
MYSQL_ENGINE = create_engine(
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
echo=False
)
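# NOTE: the DSN above hard-codes credentials, as in the original; in practice
# it would typically come from configuration, e.g. (hypothetical variable name):
#   MYSQL_ENGINE = create_engine(os.environ["STOCK_MYSQL_DSN"], echo=False)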
FEATURES = [
'alpha',
'alpha_delta',
'amt_ratio',
'amt_delta',
'rank_pct',
'limit_up_ratio',
]
BACKTEST_CONFIG = {
'seq_len': 30,
    'min_alpha_abs': 0.3,   # lowered threshold so the rule side can also contribute
'cooldown_minutes': 8,
'max_alerts_per_minute': 20,
'clip_value': 10.0,
}
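# With seq_len=30, each detection sees a 30-step window; assuming one row per
# concept per minute-bar, that is roughly the last 30 minutes of trading.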
# ==================== Data loading ====================
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
"""加载单天的特征数据"""
file_path = Path(data_dir) / f"features_{date}.parquet"
if not file_path.exists():
return None
df = pd.read_parquet(file_path)
return df
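# Expected parquet columns (inferred from usage below): concept_id, timestamp,
# the six FEATURES above, plus stock_count and total_amt.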
def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
"""获取可用的日期列表"""
data_path = Path(data_dir)
all_files = sorted(data_path.glob("features_*.parquet"))
dates = []
for f in all_files:
date = f.stem.replace('features_', '')
if start_date <= date <= end_date:
dates.append(date)
return dates
# ==================== Hybrid backtest ====================
def backtest_single_day_hybrid(
detector: HybridAnomalyDetector,
df: pd.DataFrame,
date: str,
seq_len: int = 30
) -> List[Dict]:
"""
使用融合检测器回测单天数据
"""
alerts = []
# 按概念分组
grouped = df.groupby('concept_id', sort=False)
# 冷却记录
cooldown = {}
# 获取所有时间点
all_timestamps = sorted(df['timestamp'].unique())
if len(all_timestamps) < seq_len:
return alerts
    # Run detection at each timestamp once a full window is available
for t_idx in range(seq_len - 1, len(all_timestamps)):
current_time = all_timestamps[t_idx]
window_start_time = all_timestamps[t_idx - seq_len + 1]
minute_alerts = []
for concept_id, concept_df in grouped:
            # Rows within the current time window
mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
window_df = concept_df[mask].sort_values('timestamp')
if len(window_df) < seq_len:
continue
window_df = window_df.tail(seq_len)
            # Feature sequence (fed to the ML model)
            sequence = window_df[FEATURES].values
            sequence = np.nan_to_num(sequence, nan=0.0, posinf=0.0, neginf=0.0)
            sequence = np.clip(sequence, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])
            # Current-minute features (fed to the rule system)
current_row = window_df.iloc[-1]
current_features = {
'alpha': current_row.get('alpha', 0),
'alpha_delta': current_row.get('alpha_delta', 0),
'amt_ratio': current_row.get('amt_ratio', 1),
'amt_delta': current_row.get('amt_delta', 0),
'rank_pct': current_row.get('rank_pct', 0.5),
'limit_up_ratio': current_row.get('limit_up_ratio', 0),
}
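            # Fallback defaults are neutral values: amt_ratio=1 (average
            # volume), rank_pct=0.5 (median rank), everything else 0.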
            # Filter out tiny fluctuations
if abs(current_features['alpha']) < BACKTEST_CONFIG['min_alpha_abs']:
continue
            # Cooldown check
if concept_id in cooldown:
last_alert = cooldown[concept_id]
if isinstance(current_time, datetime):
time_diff = (current_time - last_alert).total_seconds() / 60
else:
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
continue
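            # Non-datetime timestamps bypass the cooldown (time_diff is forced
            # above the threshold) rather than raising on subtraction.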
            # Hybrid detection
result = detector.detect(current_features, sequence)
if not result.is_anomaly:
continue
            # Record the alert
alert = {
'concept_id': concept_id,
'alert_time': current_time,
'trade_date': date,
'alert_type': result.anomaly_type,
'final_score': result.final_score,
'rule_score': result.rule_score,
'ml_score': result.ml_score,
'trigger_reason': result.trigger_reason,
'triggered_rules': list(result.rule_details.keys()),
**current_features,
'stock_count': current_row.get('stock_count', 0),
'total_amt': current_row.get('total_amt', 0),
}
minute_alerts.append(alert)
cooldown[concept_id] = current_time
        # Sort by final score
minute_alerts.sort(key=lambda x: x['final_score'], reverse=True)
alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])
return alerts
# ==================== Database writes ====================
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
"""保存异动到 MySQL增强版字段"""
if not alerts:
return 0
if dry_run:
print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
return len(alerts)
saved = 0
with MYSQL_ENGINE.begin() as conn:
for alert in alerts:
try:
                # Skip if this alert already exists
check_sql = text("""
SELECT id FROM concept_minute_alert
WHERE concept_id = :concept_id
AND alert_time = :alert_time
AND trade_date = :trade_date
""")
exists = conn.execute(check_sql, {
'concept_id': alert['concept_id'],
'alert_time': alert['alert_time'],
'trade_date': alert['trade_date'],
}).fetchone()
if exists:
continue
                # Insert the new record
insert_sql = text("""
INSERT INTO concept_minute_alert
(concept_id, concept_name, alert_time, alert_type, trade_date,
change_pct, zscore, importance_score, stock_count, extra_info)
VALUES
(:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
:change_pct, :zscore, :importance_score, :stock_count, :extra_info)
""")
extra_info = {
'detection_method': 'hybrid',
'final_score': alert['final_score'],
'rule_score': alert['rule_score'],
'ml_score': alert['ml_score'],
'trigger_reason': alert['trigger_reason'],
'triggered_rules': alert['triggered_rules'],
'alpha': alert.get('alpha', 0),
'alpha_delta': alert.get('alpha_delta', 0),
'amt_ratio': alert.get('amt_ratio', 1),
}
conn.execute(insert_sql, {
'concept_id': alert['concept_id'],
'concept_name': alert.get('concept_name', ''),
'alert_time': alert['alert_time'],
'alert_type': alert['alert_type'],
'trade_date': alert['trade_date'],
'change_pct': alert.get('alpha', 0),
                    'zscore': alert['final_score'],  # reuse the final score as zscore
'importance_score': alert['final_score'],
'stock_count': alert.get('stock_count', 0),
'extra_info': json.dumps(extra_info, ensure_ascii=False)
})
saved += 1
except Exception as e:
print(f" 保存失败: {alert['concept_id']} - {e}")
return saved
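# NOTE: the SELECT-then-INSERT in save_alerts_to_mysql is not atomic. With a
# unique index on (concept_id, alert_time, trade_date) -- assumed here, not
# created by this script -- the dedupe could be one statement instead, e.g.:
#   INSERT IGNORE INTO concept_minute_alert (...) VALUES (...)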
def export_alerts_to_csv(alerts: List[Dict], output_path: str):
"""导出异动到 CSV"""
if not alerts:
return
df = pd.DataFrame(alerts)
df.to_csv(output_path, index=False, encoding='utf-8-sig')
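    # utf-8-sig above prepends a BOM so Excel renders non-ASCII (e.g. Chinese)
    # column values correctly.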
print(f"已导出到: {output_path}")
# ==================== Statistics ====================
def analyze_alerts(alerts: List[Dict]):
"""分析异动结果"""
if not alerts:
print("无异动数据")
return
df = pd.DataFrame(alerts)
print("\n" + "=" * 60)
print("异动统计分析")
print("=" * 60)
# 1. 基本统计
print(f"\n总异动数: {len(alerts)}")
# 2. 按类型统计
print(f"\n异动类型分布:")
print(df['alert_type'].value_counts())
# 3. 得分统计
print(f"\n得分统计:")
print(f" 最终得分 - Mean: {df['final_score'].mean():.1f}, Max: {df['final_score'].max():.1f}")
print(f" 规则得分 - Mean: {df['rule_score'].mean():.1f}, Max: {df['rule_score'].max():.1f}")
print(f" ML得分 - Mean: {df['ml_score'].mean():.1f}, Max: {df['ml_score'].max():.1f}")
# 4. 触发来源分析
print(f"\n触发来源分析:")
trigger_counts = df['trigger_reason'].apply(
lambda x: '规则' if '规则' in x else ('ML' if 'ML' in x else '融合')
).value_counts()
print(trigger_counts)
    # 5. Rule trigger frequency
    all_rules = []
    for rules in df['triggered_rules']:
        if isinstance(rules, list):
            all_rules.extend(rules)
    if all_rules:
        print("\nMost frequently triggered rules (Top 10):")
        rule_counts = Counter(all_rules)
        for rule, count in rule_counts.most_common(10):
            print(f"  {rule}: {count}")
# ==================== Main ====================
def main():
    parser = argparse.ArgumentParser(description='Hybrid anomaly backtest')
    parser.add_argument('--data_dir', type=str, default='ml/data',
                        help='feature data directory')
    parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
                        help='model checkpoint directory')
    parser.add_argument('--start', type=str, required=True,
                        help='start date (YYYY-MM-DD)')
    parser.add_argument('--end', type=str, default=None,
                        help='end date (YYYY-MM-DD); defaults to --start')
    parser.add_argument('--dry-run', action='store_true',
                        help='compute only, do not write to the database')
    parser.add_argument('--export-csv', type=str, default=None,
                        help='path to export a CSV file')
    parser.add_argument('--rule-weight', type=float, default=0.6,
                        help='rule weight (0-1)')
    parser.add_argument('--ml-weight', type=float, default=0.4,
                        help='ML weight (0-1)')
args = parser.parse_args()
if args.end is None:
args.end = args.start
print("=" * 60)
print("融合异动回测 (规则 + LSTM)")
print("=" * 60)
print(f"日期范围: {args.start} ~ {args.end}")
print(f"数据目录: {args.data_dir}")
print(f"模型目录: {args.checkpoint_dir}")
print(f"规则权重: {args.rule_weight}")
print(f"ML权重: {args.ml_weight}")
print(f"Dry Run: {args.dry_run}")
print("=" * 60)
    # Initialize the hybrid detector
config = {
'rule_weight': args.rule_weight,
'ml_weight': args.ml_weight,
}
detector = create_detector(args.checkpoint_dir, config)
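    # create_detector comes from the external detector module; the weights
    # above are forwarded as-is, and nothing here enforces that they sum to 1.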
    # Collect the available dates
dates = get_available_dates(args.data_dir, args.start, args.end)
    if not dates:
        print(f"No data found for {args.start} ~ {args.end}")
        return
    print(f"\nFound {len(dates)} day(s) of data")
    # Run the backtest
all_alerts = []
total_saved = 0
    for date in tqdm(dates, desc="Backtest progress"):
df = load_daily_features(args.data_dir, date)
if df is None or df.empty:
continue
alerts = backtest_single_day_hybrid(
detector, df, date,
seq_len=BACKTEST_CONFIG['seq_len']
)
if alerts:
all_alerts.extend(alerts)
saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
total_saved += saved
if not args.dry_run:
tqdm.write(f" {date}: 检测到 {len(alerts)} 个异动,保存 {saved}")
    # Export to CSV
if args.export_csv and all_alerts:
export_alerts_to_csv(all_alerts, args.export_csv)
    # Statistics
analyze_alerts(all_alerts)
    # Summary
    print("\n" + "=" * 60)
    print("Backtest complete!")
    print("=" * 60)
    print(f"Total alerts detected: {len(all_alerts)}")
    print(f"Saved to database: {total_saved}")
    print("=" * 60)
if __name__ == "__main__":
main()