#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fast hybrid anomaly backtest script.

Optimization strategy:
1. Pre-build all sequences up front (vectorized) to avoid repeated slicing inside loops
2. Batch ML inference (score all candidates in one pass)
3. Replace Python loops with vectorized NumPy operations

Performance comparison:
- Original: ~5 minutes/day
- Optimized: estimated 10-30 seconds/day
"""

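# Example invocation (a sketch -- the script filename and the date format are
# assumptions; dates must match the features_<date>.parquet files in --data_dir):
#   python backtest_fast.py --start 20240102 --end 20240131 --save-db --clear-first
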
import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sqlalchemy import create_engine, text

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


# ==================== Configuration ====================

MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)

FEATURES = ['alpha', 'alpha_delta', 'amt_ratio', 'amt_delta', 'rank_pct', 'limit_up_ratio']

CONFIG = {
    'seq_len': 15,             # sequence length (with cross-day support, detection can start at 09:30)
    'min_alpha_abs': 0.3,      # minimum |alpha| filter
    'cooldown_minutes': 8,
    'max_alerts_per_minute': 20,
    'clip_value': 10.0,
    # === Fusion weights: balanced ===
    'rule_weight': 0.5,
    'ml_weight': 0.5,
    # === Trigger thresholds ===
    'rule_trigger': 65,        # 60 -> 65, slightly raise the rule bar
    'ml_trigger': 70,          # 75 -> 70, slightly lower the ML bar
    'fusion_trigger': 45,
}


# ==================== Rule scoring (vectorized) ====================

def get_size_adjusted_thresholds(stock_count: np.ndarray) -> np.ndarray:
    """
    Compute dynamic threshold factors from the number of stocks in a concept.

    Rationale:
    - Small concepts (<10 stocks): large swings are normal, so require higher thresholds
    - Medium concepts (10-50 stocks): standard thresholds
    - Large concepts (>50 stocks): a visible move is a genuine anomaly, so lower thresholds

    Returns the per-row adjustment factor (multiplied into the base thresholds).
    """
    n = len(stock_count)

    # Size-based adjustment factor:
    # small concepts get a factor > 1 (higher thresholds, harder to trigger);
    # large concepts get a factor < 1 (lower thresholds, easier to trigger).
    size_factor = np.ones(n)

    # Tiny concepts (<5 stocks): thresholds x 1.8
    tiny = stock_count < 5
    size_factor[tiny] = 1.8

    # Small concepts (5-10 stocks): thresholds x 1.4
    small = (stock_count >= 5) & (stock_count < 10)
    size_factor[small] = 1.4

    # Medium-small concepts (10-20 stocks): thresholds x 1.2
    medium_small = (stock_count >= 10) & (stock_count < 20)
    size_factor[medium_small] = 1.2

    # Medium concepts (20-50 stocks): standard thresholds x 1.0
    medium = (stock_count >= 20) & (stock_count < 50)
    size_factor[medium] = 1.0

    # Large concepts (50-100 stocks): thresholds x 0.85
    large = (stock_count >= 50) & (stock_count < 100)
    size_factor[large] = 0.85

    # Extra-large concepts (>100 stocks): thresholds x 0.7
    xlarge = stock_count >= 100
    size_factor[xlarge] = 0.7

    return size_factor

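# Sanity check of the tiers above (a sketch; the expected factors follow
# directly from the code):
#   get_size_adjusted_thresholds(np.array([3, 7, 15, 30, 60, 150]))
#   # -> [1.8, 1.4, 1.2, 1.0, 0.85, 0.7]
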
def score_rules_batch(df: pd.DataFrame) -> Tuple[np.ndarray, List[List[str]]]:
    """
    Batch rule scoring (vectorized), aware of concept size.

    Design principles:
    - Rules are an auxiliary signal and should not dominate the decision alone
    - Thresholds adapt dynamically to the number of stocks in the concept
    - Anomalies in large concepts are more valuable; small concepts need bigger moves to count

    Args:
        df: DataFrame with all feature columns (must include stock_count)
    Returns:
        scores: (n,) array of rule scores
        triggered_rules: list of triggered rule names per row
    """
    n = len(df)
    scores = np.zeros(n)
    triggered = [[] for _ in range(n)]

    alpha = df['alpha'].values
    alpha_delta = df['alpha_delta'].values
    amt_ratio = df['amt_ratio'].values
    amt_delta = df['amt_delta'].values
    rank_pct = df['rank_pct'].values
    limit_up_ratio = df['limit_up_ratio'].values
    stock_count = df['stock_count'].values if 'stock_count' in df.columns else np.full(n, 20)

    alpha_abs = np.abs(alpha)
    alpha_delta_abs = np.abs(alpha_delta)

    # Size-based adjustment factors
    size_factor = get_size_adjusted_thresholds(stock_count)

    # ========== Alpha rules (dynamic thresholds) ==========
    # Base thresholds: extreme 5%, strong 4%, medium 3%
    # Effective threshold = base x size_factor

    # Extreme signal
    alpha_extreme_thresh = 5.0 * size_factor
    mask = alpha_abs >= alpha_extreme_thresh
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('alpha_extreme')

    # Strong signal
    alpha_strong_thresh = 4.0 * size_factor
    mask = (alpha_abs >= alpha_strong_thresh) & (alpha_abs < alpha_extreme_thresh)
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('alpha_strong')

    # Medium signal
    alpha_medium_thresh = 3.0 * size_factor
    mask = (alpha_abs >= alpha_medium_thresh) & (alpha_abs < alpha_strong_thresh)
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('alpha_medium')

    # ========== Alpha acceleration rules (dynamic thresholds) ==========
    delta_strong_thresh = 2.0 * size_factor
    mask = alpha_delta_abs >= delta_strong_thresh
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('alpha_delta_strong')

    delta_medium_thresh = 1.5 * size_factor
    mask = (alpha_delta_abs >= delta_medium_thresh) & (alpha_delta_abs < delta_strong_thresh)
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('alpha_delta_medium')

    # ========== Turnover rules (size-independent: a volume spike is a volume spike) ==========
    mask = amt_ratio >= 10.0
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('volume_extreme')

    mask = (amt_ratio >= 6.0) & (amt_ratio < 10.0)
    scores[mask] += 12
    for i in np.where(mask)[0]: triggered[i].append('volume_strong')

    # ========== Rank rules ==========
    mask = rank_pct >= 0.98
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('rank_top')

    mask = rank_pct <= 0.02
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('rank_bottom')

    # ========== Limit-up rules (dynamic thresholds) ==========
    # Limit-ups in a large concept are more meaningful
    limit_high_thresh = 0.30 * size_factor
    mask = limit_up_ratio >= limit_high_thresh
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('limit_up_high')

    limit_medium_thresh = 0.20 * size_factor
    mask = (limit_up_ratio >= limit_medium_thresh) & (limit_up_ratio < limit_high_thresh)
    scores[mask] += 12
    for i in np.where(mask)[0]: triggered[i].append('limit_up_medium')

    # ========== Concept-size bonus (anomalies in large concepts are more valuable) ==========
    # Large concepts (50+): extra points
    large_concept = stock_count >= 50
    has_signal = scores > 0  # at least one rule fired
    mask = large_concept & has_signal
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('large_concept_bonus')

    # Extra-large concepts (100+): further points
    xlarge_concept = stock_count >= 100
    mask = xlarge_concept & has_signal
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('xlarge_concept_bonus')

    # ========== Combination rules (dynamic thresholds) ==========
    combo_alpha_thresh = 3.0 * size_factor

    # Alpha + volume + rank (triple confirmation)
    mask = (alpha_abs >= combo_alpha_thresh) & (amt_ratio >= 5.0) & ((rank_pct >= 0.95) | (rank_pct <= 0.05))
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('triple_signal')

    # Alpha + limit-ups (strong combination)
    mask = (alpha_abs >= combo_alpha_thresh) & (limit_up_ratio >= 0.15 * size_factor)
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('alpha_with_limit')

    # ========== Small-concept penalty (noise filter) ==========
    # Tiny concepts (<5 stocks) with only a single signal get their score halved
    tiny_concept = stock_count < 5
    single_rule = np.array([len(t) <= 1 for t in triggered])
    mask = tiny_concept & single_rule & (scores > 0)
    scores[mask] *= 0.5  # halve
    for i in np.where(mask)[0]: triggered[i].append('tiny_concept_penalty')

    scores = np.clip(scores, 0, 100)
    return scores, triggered

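# Minimal usage sketch (synthetic one-row frame; the values are illustrative only):
#   demo = pd.DataFrame([{'alpha': 5.5, 'alpha_delta': 2.2, 'amt_ratio': 11.0,
#                         'amt_delta': 0.0, 'rank_pct': 0.99, 'limit_up_ratio': 0.35,
#                         'stock_count': 60}])
#   scores, rules = score_rules_batch(demo)
#   # scores[0] == 100.0 (capped); rules[0] includes 'alpha_extreme',
#   # 'volume_extreme', 'rank_top' and 'large_concept_bonus'
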
# ==================== ML scorer ====================

class FastMLScorer:
    """Fast ML scorer."""

    def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
        self.checkpoint_dir = Path(checkpoint_dir)

        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        elif device == 'cuda' and not torch.cuda.is_available():
            print("Warning: CUDA unavailable, falling back to CPU")
            self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)

        self.model = None
        self.thresholds = None
        self._load_model()

    def _load_model(self):
        model_path = self.checkpoint_dir / 'best_model.pt'
        thresholds_path = self.checkpoint_dir / 'thresholds.json'
        config_path = self.checkpoint_dir / 'config.json'

        if not model_path.exists():
            print(f"Warning: model not found at {model_path}")
            return

        try:
            from model import LSTMAutoencoder

            config = {}
            if config_path.exists():
                with open(config_path) as f:
                    config = json.load(f).get('model', {})

            # Map legacy config keys onto the LSTM model's signature
            if 'd_model' in config:
                config['hidden_dim'] = config.pop('d_model') // 2
            for key in ['num_encoder_layers', 'num_decoder_layers', 'nhead', 'dim_feedforward', 'max_seq_len', 'use_instance_norm']:
                config.pop(key, None)
            if 'num_layers' not in config:
                config['num_layers'] = 1

            checkpoint = torch.load(model_path, map_location='cpu')
            self.model = LSTMAutoencoder(**config)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.to(self.device)
            self.model.eval()

            if thresholds_path.exists():
                with open(thresholds_path) as f:
                    self.thresholds = json.load(f)

            print(f"ML model loaded (device: {self.device})")
        except Exception as e:
            print(f"Failed to load ML model: {e}")
            self.model = None

    def is_ready(self):
        return self.model is not None

    @torch.no_grad()
    def score_batch(self, sequences: np.ndarray) -> np.ndarray:
        """
        Batch ML scoring.

        Args:
            sequences: (batch, seq_len, n_features)
        Returns:
            scores: (batch,) values in 0-100
        """
        if not self.is_ready() or len(sequences) == 0:
            return np.zeros(len(sequences))

        x = torch.FloatTensor(sequences).to(self.device)
        output, _ = self.model(x)
        mse = ((output - x) ** 2).mean(dim=-1)
        errors = mse[:, -1].cpu().numpy()  # reconstruction error at the last time step

        p95 = self.thresholds.get('p95', 0.1) if self.thresholds else 0.1
        scores = np.clip(errors / p95 * 50, 0, 100)
        return scores

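# Error-to-score mapping sketch: with the default p95 = 0.1, a last-step
# reconstruction error equal to p95 maps to 50 points and 2*p95 saturates at 100:
#   np.clip(np.array([0.05, 0.1, 0.2, 0.5]) / 0.1 * 50, 0, 100)
#   # -> [25., 50., 100., 100.]
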
# ==================== Fast backtest ====================

def build_sequences_fast(
    df: pd.DataFrame,
    seq_len: int = 30,
    prev_df: pd.DataFrame = None
) -> Tuple[np.ndarray, pd.DataFrame]:
    """
    Quickly build all valid sequences.

    Supports cross-day sequences: the tail of the previous day's data is
    prepended to the current day's, so detection can start right at 09:30.

    Args:
        df: current day's data
        seq_len: sequence length
        prev_df: previous day's data (optional, used to build sequences at the open)

    Returns:
        sequences: (n_valid, seq_len, n_features) all valid sequences
        info_df: matching metadata DataFrame
    """
    # Ensure ordering by concept and time
    df = df.sort_values(['concept_id', 'timestamp']).reset_index(drop=True)

    # If previous-day data exists, cache each concept's tail (last seq_len - 1 rows)
    prev_cache = {}
    if prev_df is not None and len(prev_df) > 0:
        prev_df = prev_df.sort_values(['concept_id', 'timestamp'])
        for concept_id, gdf in prev_df.groupby('concept_id'):
            tail_data = gdf.tail(seq_len - 1)
            if len(tail_data) > 0:
                feat_matrix = tail_data[FEATURES].values
                feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
                feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])
                prev_cache[concept_id] = feat_matrix

    # Group by concept
    groups = df.groupby('concept_id')

    sequences = []
    infos = []

    for concept_id, gdf in groups:
        gdf = gdf.reset_index(drop=True)

        # Feature matrix for the day
        feat_matrix = gdf[FEATURES].values
        feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
        feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])

        # Prepend the previous day's cached tail, if any
        if concept_id in prev_cache:
            prev_data = prev_cache[concept_id]
            combined_matrix = np.vstack([prev_data, feat_matrix])
            # Offset = number of previous-day rows prepended
            offset = len(prev_data)
        else:
            combined_matrix = feat_matrix
            offset = 0

        # Sliding-window sequence construction
        n_total = len(combined_matrix)
        if n_total < seq_len:
            continue

        for i in range(n_total - seq_len + 1):
            seq = combined_matrix[i:i + seq_len]

            # Map back to the current day's index:
            # the last point of the window sits at i + seq_len - 1,
            # which corresponds to today's row (i + seq_len - 1) - offset
            today_idx = i + seq_len - 1 - offset

            # Keep the window only if its last point belongs to the current day
            if today_idx < 0 or today_idx >= len(gdf):
                continue

            sequences.append(seq)

            # Record metadata for the last time step (today's row)
            row = gdf.iloc[today_idx]
            infos.append({
                'concept_id': concept_id,
                'timestamp': row['timestamp'],
                'alpha': row['alpha'],
                'alpha_delta': row.get('alpha_delta', 0),
                'amt_ratio': row.get('amt_ratio', 1),
                'amt_delta': row.get('amt_delta', 0),
                'rank_pct': row.get('rank_pct', 0.5),
                'limit_up_ratio': row.get('limit_up_ratio', 0),
                'stock_count': row.get('stock_count', 0),
                'total_amt': row.get('total_amt', 0),
            })

    if not sequences:
        return np.array([]), pd.DataFrame()

    return np.array(sequences), pd.DataFrame(infos)

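# Offset arithmetic, worked through: with seq_len = 15 and a 14-row cache from
# the previous day (offset = 14), window i = 0 ends at combined index 14, so
# today_idx = 0 + 15 - 1 - 14 = 0, the very first minute of the current day --
# which is what allows detection from the 09:30 bar onwards.
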
def backtest_single_day_fast(
    ml_scorer: FastMLScorer,
    df: pd.DataFrame,
    date: str,
    config: Dict,
    prev_df: pd.DataFrame = None
) -> List[Dict]:
    """
    Fast single-day backtest (vectorized version).

    Args:
        ml_scorer: ML scorer
        df: current day's data
        date: trade date
        config: configuration
        prev_df: previous day's data (enables detection from 09:30)
    """
    seq_len = config.get('seq_len', 30)

    # 1. Build all sequences (with cross-day support)
    sequences, info_df = build_sequences_fast(df, seq_len, prev_df)

    if len(sequences) == 0:
        return []

    # 2. Filter out small moves
    alpha_abs = np.abs(info_df['alpha'].values)
    valid_mask = alpha_abs >= config['min_alpha_abs']

    sequences = sequences[valid_mask]
    info_df = info_df[valid_mask].reset_index(drop=True)

    if len(sequences) == 0:
        return []

    # 3. Batch rule scoring
    rule_scores, triggered_rules = score_rules_batch(info_df)

    # 4. Batch ML scoring (chunked to avoid GPU memory exhaustion)
    batch_size = 2048
    ml_scores = []
    for i in range(0, len(sequences), batch_size):
        batch_seq = sequences[i:i+batch_size]
        batch_scores = ml_scorer.score_batch(batch_seq)
        ml_scores.append(batch_scores)
    ml_scores = np.concatenate(ml_scores) if ml_scores else np.zeros(len(sequences))

    # 5. Fused score
    w1, w2 = config['rule_weight'], config['ml_weight']
    final_scores = w1 * rule_scores + w2 * ml_scores

    # 6. Anomaly decision
    is_anomaly = (
        (rule_scores >= config['rule_trigger']) |
        (ml_scores >= config['ml_trigger']) |
        (final_scores >= config['fusion_trigger'])
    )

    # 7. Apply the cooldown (processed sorted by concept + time)
    info_df['rule_score'] = rule_scores
    info_df['ml_score'] = ml_scores
    info_df['final_score'] = final_scores
    info_df['is_anomaly'] = is_anomaly
    info_df['triggered_rules'] = triggered_rules

    # Keep anomalies only
    anomaly_df = info_df[info_df['is_anomaly']].copy()

    if len(anomaly_df) == 0:
        return []

    # Apply the cooldown window
    anomaly_df = anomaly_df.sort_values(['concept_id', 'timestamp'])
    cooldown = {}
    keep_mask = []

    for _, row in anomaly_df.iterrows():
        cid = row['concept_id']
        ts = row['timestamp']

        if cid in cooldown:
            try:
                diff = (ts - cooldown[cid]).total_seconds() / 60
            except Exception:
                diff = config['cooldown_minutes'] + 1

            if diff < config['cooldown_minutes']:
                keep_mask.append(False)
                continue

        cooldown[cid] = ts
        keep_mask.append(True)

    anomaly_df = anomaly_df[keep_mask]

    # 8. Group by timestamp: at most max_alerts_per_minute alerts per minute
    alerts = []
    for ts, group in anomaly_df.groupby('timestamp'):
        group = group.nlargest(config['max_alerts_per_minute'], 'final_score')

        for _, row in group.iterrows():
            alpha = row['alpha']
            if alpha >= 1.5:
                atype = 'surge_up'
            elif alpha <= -1.5:
                atype = 'surge_down'
            elif row['amt_ratio'] >= 3.0:
                atype = 'volume_spike'
            else:
                atype = 'unknown'

            rule_score = row['rule_score']
            ml_score = row['ml_score']
            final_score = row['final_score']

            if rule_score >= config['rule_trigger']:
                trigger = f'Rule strong signal ({rule_score:.0f})'
            elif ml_score >= config['ml_trigger']:
                trigger = f'ML strong signal ({ml_score:.0f})'
            else:
                trigger = f'Fusion trigger ({final_score:.0f})'

            alerts.append({
                'concept_id': row['concept_id'],
                'alert_time': row['timestamp'],
                'trade_date': date,
                'alert_type': atype,
                'final_score': final_score,
                'rule_score': rule_score,
                'ml_score': ml_score,
                'trigger_reason': trigger,
                'triggered_rules': row['triggered_rules'],
                'alpha': alpha,
                'alpha_delta': row['alpha_delta'],
                'amt_ratio': row['amt_ratio'],
                'amt_delta': row['amt_delta'],
                'rank_pct': row['rank_pct'],
                'limit_up_ratio': row['limit_up_ratio'],
                'stock_count': row['stock_count'],
                'total_amt': row['total_amt'],
            })

    return alerts

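# Decision arithmetic with the default CONFIG: rule_score = 50 and ml_score = 45
# give final = 0.5*50 + 0.5*45 = 47.5 >= fusion_trigger (45), so the row is
# flagged even though neither single-source trigger (65 rule / 70 ML) fired.
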
# ==================== Data loading ====================

def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
    file_path = Path(data_dir) / f"features_{date}.parquet"
    if not file_path.exists():
        return None
    return pd.read_parquet(file_path)


def get_available_dates(data_dir: str, start: str, end: str) -> List[str]:
    data_path = Path(data_dir)
    dates = []
    for f in sorted(data_path.glob("features_*.parquet")):
        d = f.stem.replace('features_', '')
        if start <= d <= end:
            dates.append(d)
    return dates


def get_prev_trading_day(data_dir: str, date: str) -> Optional[str]:
    """Return the most recent trading day before `date` that has data on disk."""
    data_path = Path(data_dir)
    all_dates = sorted([f.stem.replace('features_', '') for f in data_path.glob("features_*.parquet")])

    for i, d in enumerate(all_dates):
        if d == date and i > 0:
            return all_dates[i - 1]
    return None


def export_to_csv(alerts: List[Dict], path: str):
    if alerts:
        pd.DataFrame(alerts).to_csv(path, index=False, encoding='utf-8-sig')
        print(f"Exported: {path}")


# ==================== Database persistence ====================

def init_db_table():
    """
    Initialize the database table (created if it does not exist).

    Schema notes:
    - concept_id: concept ID
    - alert_time: alert time (minute resolution)
    - trade_date: trade date
    - alert_type: alert type (surge_up/surge_down/volume_spike/unknown)
    - final_score: final score (0-100)
    - rule_score: rule score (0-100)
    - ml_score: ML score (0-100)
    - trigger_reason: trigger reason
    - alpha: excess return
    - alpha_delta: rate of change of alpha
    - amt_ratio: turnover amplification factor
    - rank_pct: rank percentile
    - stock_count: number of stocks in the concept
    - triggered_rules: list of triggered rules (JSON)
    """
    create_sql = text("""
        CREATE TABLE IF NOT EXISTS concept_anomaly_hybrid (
            id INT AUTO_INCREMENT PRIMARY KEY,
            concept_id VARCHAR(64) NOT NULL,
            alert_time DATETIME NOT NULL,
            trade_date DATE NOT NULL,
            alert_type VARCHAR(32) NOT NULL,
            final_score FLOAT NOT NULL,
            rule_score FLOAT NOT NULL,
            ml_score FLOAT NOT NULL,
            trigger_reason VARCHAR(64),
            alpha FLOAT,
            alpha_delta FLOAT,
            amt_ratio FLOAT,
            amt_delta FLOAT,
            rank_pct FLOAT,
            limit_up_ratio FLOAT,
            stock_count INT,
            total_amt FLOAT,
            triggered_rules JSON,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
            INDEX idx_trade_date (trade_date),
            INDEX idx_concept_id (concept_id),
            INDEX idx_final_score (final_score),
            INDEX idx_alert_type (alert_type)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='Concept anomaly detection results (hybrid)'
    """)

    with MYSQL_ENGINE.begin() as conn:
        conn.execute(create_sql)
    print("Database table ready: concept_anomaly_hybrid")

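# Example query against the table (a sketch; the date value is illustrative):
#   SELECT concept_id, alert_time, alert_type, final_score
#   FROM concept_anomaly_hybrid
#   WHERE trade_date = '2024-01-02'
#   ORDER BY final_score DESC LIMIT 20;
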
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """
    Save alerts to MySQL.

    Args:
        alerts: list of alerts
        dry_run: if True, only simulate; nothing is written

    Returns:
        Number of records actually saved
    """
    if not alerts:
        return 0

    if dry_run:
        print(f"  [Dry Run] would write {len(alerts)} alerts")
        return len(alerts)

    saved = 0
    skipped = 0

    # Duplicates are handled by the unique key; INSERT IGNORE skips them cheaply
    insert_sql = text("""
        INSERT IGNORE INTO concept_anomaly_hybrid
            (concept_id, alert_time, trade_date, alert_type,
             final_score, rule_score, ml_score, trigger_reason,
             alpha, alpha_delta, amt_ratio, amt_delta,
             rank_pct, limit_up_ratio, stock_count, total_amt,
             triggered_rules)
        VALUES
            (:concept_id, :alert_time, :trade_date, :alert_type,
             :final_score, :rule_score, :ml_score, :trigger_reason,
             :alpha, :alpha_delta, :amt_ratio, :amt_delta,
             :rank_pct, :limit_up_ratio, :stock_count, :total_amt,
             :triggered_rules)
    """)

    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                result = conn.execute(insert_sql, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                    'alert_type': alert['alert_type'],
                    'final_score': alert['final_score'],
                    'rule_score': alert['rule_score'],
                    'ml_score': alert['ml_score'],
                    'trigger_reason': alert['trigger_reason'],
                    'alpha': alert.get('alpha', 0),
                    'alpha_delta': alert.get('alpha_delta', 0),
                    'amt_ratio': alert.get('amt_ratio', 1),
                    'amt_delta': alert.get('amt_delta', 0),
                    'rank_pct': alert.get('rank_pct', 0.5),
                    'limit_up_ratio': alert.get('limit_up_ratio', 0),
                    'stock_count': alert.get('stock_count', 0),
                    'total_amt': alert.get('total_amt', 0),
                    'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False),
                })

                if result.rowcount > 0:
                    saved += 1
                else:
                    skipped += 1

            except Exception as e:
                print(f"  Save failed: {alert['concept_id']} @ {alert['alert_time']} - {e}")

    if skipped > 0:
        print(f"  Skipped {skipped} duplicate records")

    return saved

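# Batch-insert alternative (a sketch; it trades per-row accounting for speed):
# SQLAlchemy issues a single executemany when given a list of parameter dicts,
#   conn.execute(insert_sql, params_list)  # params_list: one dict per alert
# after which saved/skipped can only be inferred from the aggregate rowcount.
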
def clear_alerts_by_date(trade_date: str) -> int:
    """Delete alerts for the given date (used when re-running a backtest)."""
    with MYSQL_ENGINE.begin() as conn:
        result = conn.execute(
            text("DELETE FROM concept_anomaly_hybrid WHERE trade_date = :trade_date"),
            {'trade_date': trade_date}
        )
        return result.rowcount


def analyze_alerts(alerts: List[Dict]):
    if not alerts:
        print("No alerts")
        return

    df = pd.DataFrame(alerts)
    print(f"\nTotal alerts: {len(alerts)}")
    print(f"\nType distribution:\n{df['alert_type'].value_counts()}")
    print("\nScore statistics:")
    print(f"  Final: {df['final_score'].mean():.1f} (max: {df['final_score'].max():.1f})")
    print(f"  Rule:  {df['rule_score'].mean():.1f} (max: {df['rule_score'].max():.1f})")
    print(f"  ML:    {df['ml_score'].mean():.1f} (max: {df['ml_score'].max():.1f})")

    trigger_type = df['trigger_reason'].apply(
        lambda x: 'Rule' if 'Rule' in x else ('ML' if 'ML' in x else 'Fusion')
    )
    print(f"\nTrigger source:\n{trigger_type.value_counts()}")


# ==================== Main ====================

def main():
    parser = argparse.ArgumentParser(description='Fast hybrid anomaly backtest')
    parser.add_argument('--data_dir', default='ml/data')
    parser.add_argument('--checkpoint_dir', default='ml/checkpoints')
    parser.add_argument('--start', required=True)
    parser.add_argument('--end', default=None)
    parser.add_argument('--dry-run', action='store_true', help='simulate only, no database writes')
    parser.add_argument('--export-csv', default=None, help='path of the CSV file to export')
    parser.add_argument('--save-db', action='store_true', help='save results to the database')
    parser.add_argument('--clear-first', action='store_true', help='clear existing rows for each date before writing')
    parser.add_argument('--device', default='auto')

    args = parser.parse_args()
    if args.end is None:
        args.end = args.start

    print("=" * 60)
    print("Fast hybrid anomaly backtest")
    print("=" * 60)
    print(f"Dates: {args.start} ~ {args.end}")
    print(f"Device: {args.device}")
    print(f"Save to DB: {args.save_db}")
    print("=" * 60)

    # Initialize the database table (only when saving)
    if args.save_db and not args.dry_run:
        init_db_table()

    # Initialize the ML scorer
    ml_scorer = FastMLScorer(args.checkpoint_dir, args.device)

    # Collect dates
    dates = get_available_dates(args.data_dir, args.start, args.end)
    if not dates:
        print("No data")
        return

    print(f"Found {len(dates)} day(s) of data\n")

    # Backtest (with cross-day sequence support)
    all_alerts = []
    total_saved = 0
    prev_df = None  # cache of the previous day's data

    for i, date in enumerate(tqdm(dates, desc="Backtest")):
        df = load_daily_features(args.data_dir, date)
        if df is None or df.empty:
            prev_df = None  # no data today, drop the cache
            continue

        # For the first day, load the preceding day's data if it exists
        if i == 0 and prev_df is None:
            prev_date = get_prev_trading_day(args.data_dir, date)
            if prev_date:
                prev_df = load_daily_features(args.data_dir, prev_date)
                if prev_df is not None:
                    tqdm.write(f"  Loaded previous day's data: {prev_date}")

        alerts = backtest_single_day_fast(ml_scorer, df, date, CONFIG, prev_df)
        all_alerts.extend(alerts)

        # Persist to the database
        if args.save_db and alerts:
            if args.clear_first and not args.dry_run:
                cleared = clear_alerts_by_date(date)
                if cleared > 0:
                    tqdm.write(f"  Cleared old rows for {date}: {cleared}")

            saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
            total_saved += saved
            tqdm.write(f"  {date}: {len(alerts)} alerts, {saved} saved")
        elif alerts:
            tqdm.write(f"  {date}: {len(alerts)} alerts")

        # Today's data becomes tomorrow's prev_df
        prev_df = df

    # Export CSV
    if args.export_csv:
        export_to_csv(all_alerts, args.export_csv)

    # Analysis
    analyze_alerts(all_alerts)

    print(f"\nTotal: {len(all_alerts)} alerts")
    if args.save_db:
        print(f"Saved to database: {total_saved} rows")


if __name__ == "__main__":
    main()