update pay ui
859	ml/backtest_fast.py	Normal file
@@ -0,0 +1,859 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fast rule/ML fusion anomaly backtest script.

Optimization strategy:
1. Pre-build all sequences (vectorized) to avoid repeated slicing inside loops.
2. Batch ML inference (score all candidates in one pass).
3. Replace Python loops with NumPy vectorized operations.

Performance comparison:
- Original: ~5 minutes per day
- Optimized: estimated 10-30 seconds per day
"""

import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sqlalchemy import create_engine, text

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


# ==================== Configuration ====================

MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)

FEATURES = ['alpha', 'alpha_delta', 'amt_ratio', 'amt_delta', 'rank_pct', 'limit_up_ratio']

CONFIG = {
    'seq_len': 15,               # sequence length (with cross-day support, detection can start at 9:30)
    'min_alpha_abs': 0.3,        # minimum |alpha| filter
    'cooldown_minutes': 8,
    'max_alerts_per_minute': 20,
    'clip_value': 10.0,
    # === Fusion weights: balanced ===
    'rule_weight': 0.5,
    'ml_weight': 0.5,
    # === Trigger thresholds ===
    'rule_trigger': 65,          # 60 -> 65, slightly raise the rule bar
    'ml_trigger': 70,            # 75 -> 70, slightly lower the ML bar
    'fusion_trigger': 45,
}
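
# Worked example of how the three triggers above combine (values illustrative):
# a candidate fires if rule_score >= 65, OR ml_score >= 70, OR the fused score
# 0.5 * rule_score + 0.5 * ml_score >= 45. So rule_score = 50 and ml_score = 44
# trip neither individual gate, but fuse to 47 >= 45 and still raise an alert.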


# ==================== Rule scoring (vectorized) ====================

def get_size_adjusted_thresholds(stock_count: np.ndarray) -> np.ndarray:
    """
    Compute dynamic threshold multipliers from the number of stocks in a concept.

    Rationale:
    - Small concepts (<10 stocks): large swings are normal, so require higher thresholds.
    - Medium concepts (10-50 stocks): standard thresholds.
    - Large concepts (>50 stocks): a visible move is more likely a real anomaly, so lower thresholds.

    Returns an adjustment factor per row (multiplied into the base thresholds).
    """
    n = len(stock_count)

    # Size-based adjustment factors:
    # small concepts get a factor > 1 (higher thresholds, harder to trigger);
    # large concepts get a factor < 1 (lower thresholds, easier to trigger).
    size_factor = np.ones(n)

    # Tiny concepts (<5 stocks): thresholds x 1.8
    tiny = stock_count < 5
    size_factor[tiny] = 1.8

    # Small concepts (5-10 stocks): thresholds x 1.4
    small = (stock_count >= 5) & (stock_count < 10)
    size_factor[small] = 1.4

    # Small-to-medium concepts (10-20 stocks): thresholds x 1.2
    medium_small = (stock_count >= 10) & (stock_count < 20)
    size_factor[medium_small] = 1.2

    # Medium concepts (20-50 stocks): standard thresholds x 1.0
    medium = (stock_count >= 20) & (stock_count < 50)
    size_factor[medium] = 1.0

    # Large concepts (50-100 stocks): thresholds x 0.85
    large = (stock_count >= 50) & (stock_count < 100)
    size_factor[large] = 0.85

    # Very large concepts (>100 stocks): thresholds x 0.7
    xlarge = stock_count >= 100
    size_factor[xlarge] = 0.7

    return size_factor
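
# Illustrative check of the tiers above:
#   get_size_adjusted_thresholds(np.array([3, 15, 80]))  # -> [1.8, 1.2, 0.85]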


def score_rules_batch(df: pd.DataFrame) -> Tuple[np.ndarray, List[List[str]]]:
    """
    Compute rule scores in batch (vectorized), adjusted for concept size.

    Design principles:
    - Rules are an auxiliary signal and should not dominate the decision alone.
    - Thresholds are adjusted dynamically by the number of stocks in the concept.
    - Moves in large concepts are more valuable; small concepts need bigger moves to count.

    Args:
        df: DataFrame containing all feature columns (must include stock_count)
    Returns:
        scores: (n,) array of rule scores
        triggered_rules: list of triggered rule names per row
    """
    n = len(df)
    scores = np.zeros(n)
    triggered = [[] for _ in range(n)]

    alpha = df['alpha'].values
    alpha_delta = df['alpha_delta'].values
    amt_ratio = df['amt_ratio'].values
    amt_delta = df['amt_delta'].values
    rank_pct = df['rank_pct'].values
    limit_up_ratio = df['limit_up_ratio'].values
    stock_count = df['stock_count'].values if 'stock_count' in df.columns else np.full(n, 20)

    alpha_abs = np.abs(alpha)
    alpha_delta_abs = np.abs(alpha_delta)

    # Size-based threshold adjustment factors
    size_factor = get_size_adjusted_thresholds(stock_count)

    # ========== Alpha rules (dynamic thresholds) ==========
    # Base thresholds: extreme 5%, strong 4%, medium 3%.
    # Effective threshold = base x size_factor.

    # Extreme signal
    alpha_extreme_thresh = 5.0 * size_factor
    mask = alpha_abs >= alpha_extreme_thresh
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('alpha_extreme')

    # Strong signal
    alpha_strong_thresh = 4.0 * size_factor
    mask = (alpha_abs >= alpha_strong_thresh) & (alpha_abs < alpha_extreme_thresh)
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('alpha_strong')

    # Medium signal
    alpha_medium_thresh = 3.0 * size_factor
    mask = (alpha_abs >= alpha_medium_thresh) & (alpha_abs < alpha_strong_thresh)
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('alpha_medium')

    # ========== Alpha acceleration rules (dynamic thresholds) ==========
    delta_strong_thresh = 2.0 * size_factor
    mask = alpha_delta_abs >= delta_strong_thresh
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('alpha_delta_strong')

    delta_medium_thresh = 1.5 * size_factor
    mask = (alpha_delta_abs >= delta_medium_thresh) & (alpha_delta_abs < delta_strong_thresh)
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('alpha_delta_medium')

    # ========== Turnover rules (size-independent: a volume spike is a volume spike) ==========
    mask = amt_ratio >= 10.0
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('volume_extreme')

    mask = (amt_ratio >= 6.0) & (amt_ratio < 10.0)
    scores[mask] += 12
    for i in np.where(mask)[0]: triggered[i].append('volume_strong')

    # ========== Rank rules ==========
    mask = rank_pct >= 0.98
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('rank_top')

    mask = rank_pct <= 0.02
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('rank_bottom')

    # ========== Limit-up rules (dynamic thresholds) ==========
    # Limit-ups inside a large concept carry more meaning.
    limit_high_thresh = 0.30 * size_factor
    mask = limit_up_ratio >= limit_high_thresh
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('limit_up_high')

    limit_medium_thresh = 0.20 * size_factor
    mask = (limit_up_ratio >= limit_medium_thresh) & (limit_up_ratio < limit_high_thresh)
    scores[mask] += 12
    for i in np.where(mask)[0]: triggered[i].append('limit_up_medium')

    # ========== Concept-size bonus (moves in large concepts are worth more) ==========
    # Large concepts (50+) get extra points.
    large_concept = stock_count >= 50
    has_signal = scores > 0  # at least one rule already fired
    mask = large_concept & has_signal
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('large_concept_bonus')

    # Very large concepts (100+) get a further bonus.
    xlarge_concept = stock_count >= 100
    mask = xlarge_concept & has_signal
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('xlarge_concept_bonus')

    # ========== Combination rules (dynamic thresholds) ==========
    combo_alpha_thresh = 3.0 * size_factor

    # Alpha + volume + rank (triple confirmation)
    mask = (alpha_abs >= combo_alpha_thresh) & (amt_ratio >= 5.0) & ((rank_pct >= 0.95) | (rank_pct <= 0.05))
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('triple_signal')

    # Alpha + limit-ups (strong combination)
    mask = (alpha_abs >= combo_alpha_thresh) & (limit_up_ratio >= 0.15 * size_factor)
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('alpha_with_limit')

    # ========== Small-concept penalty (noise filter) ==========
    # Tiny concepts (<5 stocks) that fired only a single rule get their score cut.
    tiny_concept = stock_count < 5
    single_rule = np.array([len(t) <= 1 for t in triggered])
    mask = tiny_concept & single_rule & (scores > 0)
    scores[mask] *= 0.5  # halve
    for i in np.where(mask)[0]: triggered[i].append('tiny_concept_penalty')

    scores = np.clip(scores, 0, 100)
    return scores, triggered
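
# Illustrative call (column values made up; a medium concept, size_factor 1.0,
# where only the extreme-alpha rule fires):
#   demo = pd.DataFrame({'alpha': [5.5], 'alpha_delta': [0.2], 'amt_ratio': [2.0],
#                        'amt_delta': [0.0], 'rank_pct': [0.5],
#                        'limit_up_ratio': [0.0], 'stock_count': [30]})
#   scores, rules = score_rules_batch(demo)
#   # scores[0] == 20.0, rules[0] == ['alpha_extreme']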


# ==================== ML scorer ====================

class FastMLScorer:
    """Fast ML scorer."""

    def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
        self.checkpoint_dir = Path(checkpoint_dir)

        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        elif device == 'cuda' and not torch.cuda.is_available():
            print("Warning: CUDA unavailable, falling back to CPU")
            self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)

        self.model = None
        self.thresholds = None
        self._load_model()

    def _load_model(self):
        model_path = self.checkpoint_dir / 'best_model.pt'
        thresholds_path = self.checkpoint_dir / 'thresholds.json'
        config_path = self.checkpoint_dir / 'config.json'

        if not model_path.exists():
            print(f"Warning: model not found at {model_path}")
            return

        try:
            from model import LSTMAutoencoder

            config = {}
            if config_path.exists():
                with open(config_path) as f:
                    config = json.load(f).get('model', {})

            # Handle legacy config key names
            if 'd_model' in config:
                config['hidden_dim'] = config.pop('d_model') // 2
            for key in ['num_encoder_layers', 'num_decoder_layers', 'nhead', 'dim_feedforward', 'max_seq_len', 'use_instance_norm']:
                config.pop(key, None)
            if 'num_layers' not in config:
                config['num_layers'] = 1

            checkpoint = torch.load(model_path, map_location='cpu')
            self.model = LSTMAutoencoder(**config)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.to(self.device)
            self.model.eval()

            if thresholds_path.exists():
                with open(thresholds_path) as f:
                    self.thresholds = json.load(f)

            print(f"ML model loaded (device: {self.device})")
        except Exception as e:
            print(f"Failed to load ML model: {e}")
            self.model = None

    def is_ready(self):
        return self.model is not None

    @torch.no_grad()
    def score_batch(self, sequences: np.ndarray) -> np.ndarray:
        """
        Compute ML scores in batch.

        Args:
            sequences: (batch, seq_len, n_features)
        Returns:
            scores: (batch,) scores in [0, 100]
        """
        if not self.is_ready() or len(sequences) == 0:
            return np.zeros(len(sequences))

        x = torch.FloatTensor(sequences).to(self.device)
        output, _ = self.model(x)
        mse = ((output - x) ** 2).mean(dim=-1)
        errors = mse[:, -1].cpu().numpy()

        p95 = self.thresholds.get('p95', 0.1) if self.thresholds else 0.1
        scores = np.clip(errors / p95 * 50, 0, 100)
        return scores
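
# The mapping above is linear in the reconstruction error of the last time
# step: an error equal to the calibrated p95 threshold maps to 50 points,
# 2 x p95 (or more) saturates at 100, and near-zero error scores near 0.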


# ==================== Fast backtest ====================

def build_sequences_fast(
    df: pd.DataFrame,
    seq_len: int = 30,
    prev_df: pd.DataFrame = None
) -> Tuple[np.ndarray, pd.DataFrame]:
    """
    Build all valid sequences quickly.

    Supports cross-day sequences: the previous day's closing data is stitched
    onto the current day's opening data, so detection can start right at 9:30.

    Args:
        df: current-day data
        seq_len: sequence length
        prev_df: previous-day data (optional, used to build sequences at the open)

    Returns:
        sequences: (n_valid, seq_len, n_features) all valid sequences
        info_df: DataFrame of matching metadata
    """
    # Ensure ordering by concept and time
    df = df.sort_values(['concept_id', 'timestamp']).reset_index(drop=True)

    # If previous-day data is available, cache each concept's tail (last seq_len-1 rows)
    prev_cache = {}
    if prev_df is not None and len(prev_df) > 0:
        prev_df = prev_df.sort_values(['concept_id', 'timestamp'])
        for concept_id, gdf in prev_df.groupby('concept_id'):
            tail_data = gdf.tail(seq_len - 1)
            if len(tail_data) > 0:
                feat_matrix = tail_data[FEATURES].values
                feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
                feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])
                prev_cache[concept_id] = feat_matrix

    # Group by concept
    groups = df.groupby('concept_id')

    sequences = []
    infos = []

    for concept_id, gdf in groups:
        gdf = gdf.reset_index(drop=True)

        # Feature matrix
        feat_matrix = gdf[FEATURES].values
        feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
        feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])

        # If a previous-day cache exists, prepend it to today's data
        if concept_id in prev_cache:
            prev_data = prev_cache[concept_id]
            combined_matrix = np.vstack([prev_data, feat_matrix])
            # Offset: the length of the previous-day data
            offset = len(prev_data)
        else:
            combined_matrix = feat_matrix
            offset = 0

        # Build sequences with a sliding window
        n_total = len(combined_matrix)
        if n_total < seq_len:
            continue

        for i in range(n_total - seq_len + 1):
            seq = combined_matrix[i:i + seq_len]

            # Index into today's data:
            # position of the sequence's last point = i + seq_len - 1
            # matching index in today's data = (i + seq_len - 1) - offset
            today_idx = i + seq_len - 1 - offset

            # Keep the sequence only if its last point falls on today's data
            if today_idx < 0 or today_idx >= len(gdf):
                continue

            sequences.append(seq)

            # Record the last time step's metadata (today's row)
            row = gdf.iloc[today_idx]
            infos.append({
                'concept_id': concept_id,
                'timestamp': row['timestamp'],
                'alpha': row['alpha'],
                'alpha_delta': row.get('alpha_delta', 0),
                'amt_ratio': row.get('amt_ratio', 1),
                'amt_delta': row.get('amt_delta', 0),
                'rank_pct': row.get('rank_pct', 0.5),
                'limit_up_ratio': row.get('limit_up_ratio', 0),
                'stock_count': row.get('stock_count', 0),
                'total_amt': row.get('total_amt', 0),
            })

    if not sequences:
        return np.array([]), pd.DataFrame()

    return np.array(sequences), pd.DataFrame(infos)
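
# The window loop above is still per-row Python. A minimal sketch of a fully
# vectorized alternative for one concept's matrix, assuming NumPy >= 1.20
# (not wired in here; shown for illustration only):
#
#   from numpy.lib.stride_tricks import sliding_window_view
#   # (n_windows, n_features, seq_len) -> (n_windows, seq_len, n_features)
#   windows = sliding_window_view(combined_matrix, seq_len, axis=0).transpose(0, 2, 1)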


def backtest_single_day_fast(
    ml_scorer: FastMLScorer,
    df: pd.DataFrame,
    date: str,
    config: Dict,
    prev_df: pd.DataFrame = None
) -> List[Dict]:
    """
    Fast single-day backtest (vectorized version).

    Args:
        ml_scorer: ML scorer
        df: current-day data
        date: date
        config: configuration
        prev_df: previous-day data (enables detection from 9:30)
    """
    seq_len = config.get('seq_len', 30)

    # 1. Build all sequences (cross-day aware)
    sequences, info_df = build_sequences_fast(df, seq_len, prev_df)

    if len(sequences) == 0:
        return []

    # 2. Filter out small moves
    alpha_abs = np.abs(info_df['alpha'].values)
    valid_mask = alpha_abs >= config['min_alpha_abs']

    sequences = sequences[valid_mask]
    info_df = info_df[valid_mask].reset_index(drop=True)

    if len(sequences) == 0:
        return []

    # 3. Batch rule scoring
    rule_scores, triggered_rules = score_rules_batch(info_df)

    # 4. Batch ML scoring (chunked to avoid exhausting GPU memory)
    batch_size = 2048
    ml_scores = []
    for i in range(0, len(sequences), batch_size):
        batch_seq = sequences[i:i+batch_size]
        batch_scores = ml_scorer.score_batch(batch_seq)
        ml_scores.append(batch_scores)
    ml_scores = np.concatenate(ml_scores) if ml_scores else np.zeros(len(sequences))

    # 5. Fuse scores
    w1, w2 = config['rule_weight'], config['ml_weight']
    final_scores = w1 * rule_scores + w2 * ml_scores

    # 6. Flag anomalies
    is_anomaly = (
        (rule_scores >= config['rule_trigger']) |
        (ml_scores >= config['ml_trigger']) |
        (final_scores >= config['fusion_trigger'])
    )

    # 7. Apply the cooldown (processed after sorting by concept + time)
    info_df['rule_score'] = rule_scores
    info_df['ml_score'] = ml_scores
    info_df['final_score'] = final_scores
    info_df['is_anomaly'] = is_anomaly
    info_df['triggered_rules'] = triggered_rules

    # Keep anomalies only
    anomaly_df = info_df[info_df['is_anomaly']].copy()

    if len(anomaly_df) == 0:
        return []

    # Apply the cooldown
    anomaly_df = anomaly_df.sort_values(['concept_id', 'timestamp'])
    cooldown = {}
    keep_mask = []

    for _, row in anomaly_df.iterrows():
        cid = row['concept_id']
        ts = row['timestamp']

        if cid in cooldown:
            try:
                diff = (ts - cooldown[cid]).total_seconds() / 60
            except Exception:
                diff = config['cooldown_minutes'] + 1

            if diff < config['cooldown_minutes']:
                keep_mask.append(False)
                continue

        cooldown[cid] = ts
        keep_mask.append(True)

    anomaly_df = anomaly_df[keep_mask]

    # 8. Group by timestamp; keep at most max_alerts_per_minute per minute
    alerts = []
    for ts, group in anomaly_df.groupby('timestamp'):
        group = group.nlargest(config['max_alerts_per_minute'], 'final_score')

        for _, row in group.iterrows():
            alpha = row['alpha']
            if alpha >= 1.5:
                atype = 'surge_up'
            elif alpha <= -1.5:
                atype = 'surge_down'
            elif row['amt_ratio'] >= 3.0:
                atype = 'volume_spike'
            else:
                atype = 'unknown'

            rule_score = row['rule_score']
            ml_score = row['ml_score']
            final_score = row['final_score']

            if rule_score >= config['rule_trigger']:
                trigger = f'strong rule signal ({rule_score:.0f} pts)'
            elif ml_score >= config['ml_trigger']:
                trigger = f'strong ML signal ({ml_score:.0f} pts)'
            else:
                trigger = f'fusion trigger ({final_score:.0f} pts)'

            alerts.append({
                'concept_id': row['concept_id'],
                'alert_time': row['timestamp'],
                'trade_date': date,
                'alert_type': atype,
                'final_score': final_score,
                'rule_score': rule_score,
                'ml_score': ml_score,
                'trigger_reason': trigger,
                'triggered_rules': row['triggered_rules'],
                'alpha': alpha,
                'alpha_delta': row['alpha_delta'],
                'amt_ratio': row['amt_ratio'],
                'amt_delta': row['amt_delta'],
                'rank_pct': row['rank_pct'],
                'limit_up_ratio': row['limit_up_ratio'],
                'stock_count': row['stock_count'],
                'total_amt': row['total_amt'],
            })

    return alerts
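
# Cooldown example (times illustrative): with cooldown_minutes = 8, alerts for
# the same concept at 09:31, 09:36 and 09:41 keep 09:31, drop 09:36 (5 min
# after the last kept alert), and keep 09:41 (10 min after 09:31) -- the
# cooldown clock only resets on alerts that are kept.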


# ==================== Data loading ====================

def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
    file_path = Path(data_dir) / f"features_{date}.parquet"
    if not file_path.exists():
        return None
    return pd.read_parquet(file_path)


def get_available_dates(data_dir: str, start: str, end: str) -> List[str]:
    data_path = Path(data_dir)
    dates = []
    for f in sorted(data_path.glob("features_*.parquet")):
        d = f.stem.replace('features_', '')
        if start <= d <= end:
            dates.append(d)
    return dates


def get_prev_trading_day(data_dir: str, date: str) -> Optional[str]:
    """Return the most recent trading day before the given date that has data."""
    data_path = Path(data_dir)
    all_dates = sorted([f.stem.replace('features_', '') for f in data_path.glob("features_*.parquet")])

    for i, d in enumerate(all_dates):
        if d == date and i > 0:
            return all_dates[i - 1]
    return None


def export_to_csv(alerts: List[Dict], path: str):
    if alerts:
        pd.DataFrame(alerts).to_csv(path, index=False, encoding='utf-8-sig')
        print(f"Exported: {path}")


# ==================== Database writes ====================

def init_db_table():
    """
    Initialize the database table (created if it does not exist).

    Schema notes:
    - concept_id: concept ID
    - alert_time: anomaly time (minute resolution)
    - trade_date: trading date
    - alert_type: anomaly type (surge_up/surge_down/volume_spike/unknown)
    - final_score: final score (0-100)
    - rule_score: rule score (0-100)
    - ml_score: ML score (0-100)
    - trigger_reason: trigger reason
    - alpha: excess return
    - alpha_delta: rate of change of alpha
    - amt_ratio: turnover amplification factor
    - rank_pct: rank percentile
    - stock_count: number of stocks in the concept
    - triggered_rules: list of triggered rules (JSON)
    """
    create_sql = text("""
        CREATE TABLE IF NOT EXISTS concept_anomaly_hybrid (
            id INT AUTO_INCREMENT PRIMARY KEY,
            concept_id VARCHAR(64) NOT NULL,
            alert_time DATETIME NOT NULL,
            trade_date DATE NOT NULL,
            alert_type VARCHAR(32) NOT NULL,
            final_score FLOAT NOT NULL,
            rule_score FLOAT NOT NULL,
            ml_score FLOAT NOT NULL,
            trigger_reason VARCHAR(64),
            alpha FLOAT,
            alpha_delta FLOAT,
            amt_ratio FLOAT,
            amt_delta FLOAT,
            rank_pct FLOAT,
            limit_up_ratio FLOAT,
            stock_count INT,
            total_amt FLOAT,
            triggered_rules JSON,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
            INDEX idx_trade_date (trade_date),
            INDEX idx_concept_id (concept_id),
            INDEX idx_final_score (final_score),
            INDEX idx_alert_type (alert_type)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='Concept anomaly detections (hybrid)'
    """)

    with MYSQL_ENGINE.begin() as conn:
        conn.execute(create_sql)
    print("Database table ready: concept_anomaly_hybrid")


def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """
    Save anomalies to MySQL.

    Args:
        alerts: list of anomalies
        dry_run: simulate only, without actually writing

    Returns:
        Number of records actually saved
    """
    if not alerts:
        return 0

    if dry_run:
        print(f"  [Dry Run] would write {len(alerts)} anomalies")
        return len(alerts)

    saved = 0
    skipped = 0

    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                # Deduplicate via INSERT IGNORE (cheaper than checking for existence first)
                insert_sql = text("""
                    INSERT IGNORE INTO concept_anomaly_hybrid
                    (concept_id, alert_time, trade_date, alert_type,
                     final_score, rule_score, ml_score, trigger_reason,
                     alpha, alpha_delta, amt_ratio, amt_delta,
                     rank_pct, limit_up_ratio, stock_count, total_amt,
                     triggered_rules)
                    VALUES
                    (:concept_id, :alert_time, :trade_date, :alert_type,
                     :final_score, :rule_score, :ml_score, :trigger_reason,
                     :alpha, :alpha_delta, :amt_ratio, :amt_delta,
                     :rank_pct, :limit_up_ratio, :stock_count, :total_amt,
                     :triggered_rules)
                """)

                result = conn.execute(insert_sql, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                    'alert_type': alert['alert_type'],
                    'final_score': alert['final_score'],
                    'rule_score': alert['rule_score'],
                    'ml_score': alert['ml_score'],
                    'trigger_reason': alert['trigger_reason'],
                    'alpha': alert.get('alpha', 0),
                    'alpha_delta': alert.get('alpha_delta', 0),
                    'amt_ratio': alert.get('amt_ratio', 1),
                    'amt_delta': alert.get('amt_delta', 0),
                    'rank_pct': alert.get('rank_pct', 0.5),
                    'limit_up_ratio': alert.get('limit_up_ratio', 0),
                    'stock_count': alert.get('stock_count', 0),
                    'total_amt': alert.get('total_amt', 0),
                    'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False),
                })

                if result.rowcount > 0:
                    saved += 1
                else:
                    skipped += 1

            except Exception as e:
                print(f"  Save failed: {alert['concept_id']} @ {alert['alert_time']} - {e}")

    if skipped > 0:
        print(f"  Skipped {skipped} duplicate records")

    return saved
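
# If write throughput ever matters more than per-row diagnostics, SQLAlchemy
# also accepts a list of parameter dicts for an executemany-style batch -- a
# sketch under that assumption (build_params is a hypothetical helper that
# returns the same dict as the per-row call above):
#   with MYSQL_ENGINE.begin() as conn:
#       conn.execute(insert_sql, [build_params(a) for a in alerts])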


def clear_alerts_by_date(trade_date: str) -> int:
    """Delete anomaly records for a given date (for re-running a backtest)."""
    with MYSQL_ENGINE.begin() as conn:
        result = conn.execute(
            text("DELETE FROM concept_anomaly_hybrid WHERE trade_date = :trade_date"),
            {'trade_date': trade_date}
        )
        return result.rowcount


def analyze_alerts(alerts: List[Dict]):
    if not alerts:
        print("No anomalies")
        return

    df = pd.DataFrame(alerts)
    print(f"\nTotal anomalies: {len(alerts)}")
    print(f"\nType distribution:\n{df['alert_type'].value_counts()}")
    print(f"\nScore statistics:")
    print(f"  final: {df['final_score'].mean():.1f} (max: {df['final_score'].max():.1f})")
    print(f"  rule:  {df['rule_score'].mean():.1f} (max: {df['rule_score'].max():.1f})")
    print(f"  ML:    {df['ml_score'].mean():.1f} (max: {df['ml_score'].max():.1f})")

    trigger_type = df['trigger_reason'].apply(
        lambda x: 'rule' if 'rule' in x else ('ML' if 'ML' in x else 'fusion')
    )
    print(f"\nTrigger source:\n{trigger_type.value_counts()}")


# ==================== Main ====================

def main():
    parser = argparse.ArgumentParser(description='Fast fusion anomaly backtest')
    parser.add_argument('--data_dir', default='ml/data')
    parser.add_argument('--checkpoint_dir', default='ml/checkpoints')
    parser.add_argument('--start', required=True)
    parser.add_argument('--end', default=None)
    parser.add_argument('--dry-run', action='store_true', help='simulate only, do not write to the database')
    parser.add_argument('--export-csv', default=None, help='path for the exported CSV file')
    parser.add_argument('--save-db', action='store_true', help='save results to the database')
    parser.add_argument('--clear-first', action='store_true', help='clear old data for each date before writing')
    parser.add_argument('--device', default='auto')

    args = parser.parse_args()
    if args.end is None:
        args.end = args.start

    print("=" * 60)
    print("Fast fusion anomaly backtest")
    print("=" * 60)
    print(f"Dates: {args.start} ~ {args.end}")
    print(f"Device: {args.device}")
    print(f"Save to DB: {args.save_db}")
    print("=" * 60)

    # Initialize the database table (only if saving)
    if args.save_db and not args.dry_run:
        init_db_table()

    # Initialize the ML scorer
    ml_scorer = FastMLScorer(args.checkpoint_dir, args.device)

    # Collect dates
    dates = get_available_dates(args.data_dir, args.start, args.end)
    if not dates:
        print("No data")
        return

    print(f"Found {len(dates)} days of data\n")

    # Backtest (with cross-day sequences)
    all_alerts = []
    total_saved = 0
    prev_df = None  # cache of the previous day's data

    for i, date in enumerate(tqdm(dates, desc="Backtest")):
        df = load_daily_features(args.data_dir, date)
        if df is None or df.empty:
            prev_df = None  # no data today, drop the cache
            continue

        # For the first day, load the previous day's data (if it exists)
        if i == 0 and prev_df is None:
            prev_date = get_prev_trading_day(args.data_dir, date)
            if prev_date:
                prev_df = load_daily_features(args.data_dir, prev_date)
                if prev_df is not None:
                    tqdm.write(f"  Loaded previous day's data: {prev_date}")

        alerts = backtest_single_day_fast(ml_scorer, df, date, CONFIG, prev_df)
        all_alerts.extend(alerts)

        # Save to the database
        if args.save_db and alerts:
            if args.clear_first and not args.dry_run:
                cleared = clear_alerts_by_date(date)
                if cleared > 0:
                    tqdm.write(f"  Cleared old data for {date}: {cleared} rows")

            saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
            total_saved += saved
            tqdm.write(f"  {date}: {len(alerts)} anomalies, saved {saved}")
        elif alerts:
            tqdm.write(f"  {date}: {len(alerts)} anomalies")

        # Today's data becomes the next day's prev_df
        prev_df = df

    # Export CSV
    if args.export_csv:
        export_to_csv(all_alerts, args.export_csv)

    # Analysis
    analyze_alerts(all_alerts)

    print(f"\nTotal: {len(all_alerts)} anomalies")
    if args.save_db:
        print(f"Saved to database: {total_saved} rows")


if __name__ == "__main__":
    main()