717 lines
23 KiB
Python
717 lines
23 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
异动检测器 V2 - 基于时间片对齐 + 持续性确认
|
||
|
||
核心改进:
|
||
1. Z-Score 特征:相对于同时间片历史的偏离
|
||
2. 短序列 LSTM:10分钟序列,开盘即可用
|
||
3. 持续性确认:5分钟窗口内60%时刻超标才确认为异动
|
||
|
||
检测流程:
|
||
1. 计算当前时刻的 Z-Score(对比同时间片历史基线)
|
||
2. 构建最近10分钟的 Z-Score 序列
|
||
3. LSTM 计算重构误差(ML分数)
|
||
4. 规则评分(基于 Z-Score 的规则)
|
||
5. 滑动窗口确认:最近5分钟内是否有足够多的时刻超标
|
||
6. 只有通过持续性确认的才输出为异动
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import pickle
|
||
from datetime import datetime, timedelta
|
||
from typing import Dict, List, Optional, Tuple
|
||
from collections import defaultdict, deque
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
from sqlalchemy import create_engine, text
|
||
from elasticsearch import Elasticsearch
|
||
from clickhouse_driver import Client
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from ml.model import TransformerAutoencoder
|
||
|
||
# ==================== 配置 ====================
|
||
|
||
# NOTE(security): the literal credentials below are kept only as fallback
# defaults so existing deployments behave exactly as before. They are
# committed to source control and should be rotated and supplied through the
# environment variables read here instead.
MYSQL_ENGINE = create_engine(
    os.environ.get(
        'STOCK_MYSQL_URL',
        "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    ),
    echo=False,
)

ES_CLIENT = Elasticsearch([os.environ.get('STOCK_ES_URL', 'http://127.0.0.1:9200')])
ES_INDEX = 'concept_library_v3'

# Keyword arguments passed verbatim to clickhouse_driver.Client.
CLICKHOUSE_CONFIG = {
    'host': os.environ.get('STOCK_CH_HOST', '127.0.0.1'),
    'port': int(os.environ.get('STOCK_CH_PORT', '9000')),
    'user': os.environ.get('STOCK_CH_USER', 'default'),
    'password': os.environ.get('STOCK_CH_PASSWORD', 'Zzl33818!'),
    'database': os.environ.get('STOCK_CH_DATABASE', 'stock'),
}

# Presumably the benchmark index used when computing alpha; it is not
# referenced anywhere else in this file - TODO confirm against prepare_data_v2.
REFERENCE_INDEX = '000001.SH'
|
||
|
||
# Detection settings, grouped by pipeline stage.
CONFIG = dict(
    # Sequence settings: number of per-minute feature rows kept per concept
    # and fed to the sequence model.
    seq_len=10,

    # Persistence confirmation (the core idea of V2): size of the
    # confirmation window in minutes, and the share of moments inside it
    # that must exceed the threshold.
    confirm_window=5,
    confirm_ratio=0.6,

    # Z-Score thresholds for alpha and for turnover amount.
    alpha_zscore_threshold=2.0,
    amt_zscore_threshold=2.5,

    # Weights used to fuse the rule score and the ML score.
    rule_weight=0.5,
    ml_weight=0.5,

    # Trigger thresholds for the rule, ML and fused scores.
    rule_trigger=60,
    ml_trigger=70,
    fusion_trigger=50,

    # Cooldown between alerts of one concept, and the per-minute alert cap.
    cooldown_minutes=10,
    max_alerts_per_minute=15,

    # Z-Scores are clipped to [-zscore_clip, +zscore_clip].
    zscore_clip=5.0,
)
|
||
|
||
# V2 feature names, listed in the same order in which
# RealtimeDataManagerV2.get_sequence lays out the model input columns.
FEATURES_V2 = [
    'alpha_zscore',
    'amt_zscore',
    'rank_zscore',
    'momentum_3m',
    'momentum_5m',
    'limit_up_ratio',
]
|
||
|
||
|
||
# ==================== 工具函数 ====================
|
||
|
||
def get_ch_client():
    """Create a new ClickHouse client from ``CLICKHOUSE_CONFIG``.

    Every call builds a fresh ``Client``; nothing is cached here.
    """
    client = Client(**CLICKHOUSE_CONFIG)
    return client
|
||
|
||
|
||
def code_to_ch_format(code: str) -> Optional[str]:
    """Append a market suffix to a bare six-digit stock code.

    Args:
        code: Six ASCII digits, e.g. ``"600000"``.

    Returns:
        ``"<code>.SH"`` when the code starts with ``6``, ``"<code>.SZ"`` when
        it starts with ``0`` or ``3``, and ``"<code>.BJ"`` for any other
        leading digit. ``None`` when ``code`` is not a string of exactly six
        ASCII digits.
    """
    # Non-string input (None, int, ...) is treated as "no code" instead of
    # failing inside len().
    if not isinstance(code, str):
        return None

    # str.isdigit() alone also accepts full-width and superscript digits;
    # requiring ASCII keeps such strings out of the suffixed code.
    if len(code) != 6 or not (code.isascii() and code.isdigit()):
        return None

    if code.startswith('6'):
        suffix = 'SH'
    elif code.startswith(('0', '3')):
        suffix = 'SZ'
    else:
        suffix = 'BJ'
    return f"{code}.{suffix}"
|
||
|
||
|
||
def time_to_slot(ts) -> str:
    """Convert a timestamp to its ``HH:MM`` time-slot label.

    Args:
        ts: A string (returned unchanged, assumed to already be a slot
            label), or a datetime-like value.

    Returns:
        The ``HH:MM`` label of ``ts``.
    """
    if isinstance(ts, str):
        return ts

    # numpy.datetime64 scalars have no strftime(); they can reach this
    # function when timestamps come from a NumPy array (for example
    # Series.unique() on some pandas versions), so normalise them first.
    if not hasattr(ts, 'strftime'):
        ts = pd.Timestamp(ts)
    return ts.strftime('%H:%M')
|
||
|
||
|
||
# ==================== 基线加载 ====================
|
||
|
||
def load_baselines(baseline_dir: str = 'ml/data_v2/baselines') -> Dict[str, pd.DataFrame]:
    """Load the per-time-slot baselines from ``<baseline_dir>/baselines.pkl``.

    Returns the unpickled object, or an empty dict when the file does not
    exist.

    NOTE(security): the file is read with ``pickle.load``, which can execute
    arbitrary code; only point this at baseline files you produced yourself.
    """
    pickle_path = os.path.join(baseline_dir, 'baselines.pkl')
    if not os.path.exists(pickle_path):
        return {}

    with open(pickle_path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded
|
||
|
||
|
||
# ==================== 规则评分(基于 Z-Score)====================
|
||
|
||
def score_rules_zscore(row: Dict) -> Tuple[float, List[str]]:
    """Score one feature row with fixed Z-Score based rules.

    The inputs are already standardised, so every rule is a plain threshold
    comparison. Each feature has a ladder of tiers; only the highest tier
    reached contributes. Two combination rules add a bonus on top.

    Args:
        row: Mapping with any of ``alpha_zscore``, ``amt_zscore``,
            ``rank_zscore``, ``momentum_3m``, ``momentum_5m`` and
            ``limit_up_ratio``; missing keys count as 0.

    Returns:
        ``(score, triggered)``: the score capped at 100 and the names of the
        rules that fired, in evaluation order.
    """
    alpha_mag = abs(row.get('alpha_zscore', 0))
    amt_z = row.get('amt_zscore', 0)
    rank_mag = abs(row.get('rank_zscore', 0))
    mom_3m = row.get('momentum_3m', 0)
    mom_5m = row.get('momentum_5m', 0)
    limit_up = row.get('limit_up_ratio', 0)

    # (observed value, tiers); tiers are ordered from strictest to loosest
    # and the first one satisfied wins. Alpha and rank are judged by
    # magnitude, the other features by their signed value.
    ladders = (
        (alpha_mag, (
            (4.0, 25, 'alpha_zscore_extreme'),
            (3.0, 18, 'alpha_zscore_strong'),
            (2.0, 10, 'alpha_zscore_moderate'),
        )),
        (amt_z, (
            (4.0, 20, 'amt_zscore_extreme'),
            (3.0, 12, 'amt_zscore_strong'),
            (2.0, 6, 'amt_zscore_moderate'),
        )),
        (rank_mag, (
            (3.0, 15, 'rank_zscore_extreme'),
            (2.0, 8, 'rank_zscore_strong'),
        )),
        (mom_3m, (
            (1.0, 12, 'momentum_3m_strong'),
            (0.5, 6, 'momentum_3m_moderate'),
        )),
        (mom_5m, (
            (1.5, 10, 'momentum_5m_strong'),
        )),
        (limit_up, (
            (0.3, 20, 'limit_up_extreme'),
            (0.15, 12, 'limit_up_strong'),
            (0.08, 5, 'limit_up_moderate'),
        )),
    )

    total = 0.0
    fired = []
    for observed, tiers in ladders:
        for floor, points, label in tiers:
            if observed >= floor:
                total += points
                fired.append(label)
                break

    # Combination bonuses: a notable alpha move together with a turnover
    # surge and/or a meaningful share of limit-up stocks.
    if alpha_mag >= 2.0:
        if amt_z >= 2.0:
            total += 15
            fired.append('combo_alpha_amt')
        if limit_up >= 0.1:
            total += 12
            fired.append('combo_alpha_limitup')

    return min(total, 100), fired
|
||
|
||
|
||
# ==================== ML 评分器 ====================
|
||
|
||
class MLScorerV2:
    """V2 ML scorer: turns reconstruction errors into 0-100 anomaly scores.

    The model and its score thresholds are loaded from ``model_dir`` on
    construction. When the model file is missing the scorer stays usable and
    ``score_batch`` returns all zeros.
    """

    def __init__(self, model_dir: str = 'ml/checkpoints_v2'):
        """Remember ``model_dir``, pick the device and load the model."""
        self.model_dir = model_dir
        self.model = None
        self.thresholds = None
        # Use CUDA when torch reports it as available, otherwise the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._load_model()

    def _load_model(self):
        """Load the model weights, its config and the optional thresholds.

        Reads ``best_model.pt``, ``config.json`` and (when present)
        ``thresholds.json`` from ``self.model_dir``. A missing model file only
        prints a warning; a missing ``config.json`` raises from ``open``.
        """
        model_path = os.path.join(self.model_dir, 'best_model.pt')
        threshold_path = os.path.join(self.model_dir, 'thresholds.json')
        config_path = os.path.join(self.model_dir, 'config.json')

        if not os.path.exists(model_path):
            print(f"警告: 模型文件不存在: {model_path}")
            return

        # Model hyper-parameters live under the "model" key of config.json.
        with open(config_path, 'r') as f:
            config = json.load(f)

        model_config = config.get('model', {})
        self.model = TransformerAutoencoder(**model_config)

        # NOTE(security): torch.load unpickles the checkpoint; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

        if os.path.exists(threshold_path):
            with open(threshold_path, 'r') as f:
                self.thresholds = json.load(f)

        print(f"V2 模型加载完成: {model_path}")

    @torch.no_grad()
    def score_batch(self, sequences: np.ndarray) -> np.ndarray:
        """Compute ML scores for a batch of feature sequences.

        Args:
            sequences: Batch of sequences; presumably shaped
                ``(batch, seq_len, n_features)`` - TODO confirm against the
                model definition.

        Returns:
            One score per sequence in ``[0, 100]``; higher means more
            anomalous. All zeros when no model is loaded.
        """
        if self.model is None:
            return np.zeros(len(sequences))

        # Nothing to score; avoid pushing an empty tensor through the model.
        if len(sequences) == 0:
            return np.zeros(0)

        x = torch.FloatTensor(sequences).to(self.device)

        # Per-position reconstruction error; only the last position of each
        # sequence (the current moment) is scored. Assumes the result is at
        # least 2-D with the sequence position on axis 1.
        errors = self.model.compute_reconstruction_error(x, reduction='none')
        last_errors = errors[:, -1].cpu().numpy()

        scores = None
        if self.thresholds:
            p50 = self.thresholds.get('median', 0.1)
            p99 = self.thresholds.get('p99', 1.0)
            span = p99 - p50
            # Linear map: p50 -> 50 points, p99 -> 99 points. A zero, negative
            # or non-finite span would divide by zero or invert the scale,
            # so such thresholds fall through to the plain scaling below.
            if np.isfinite(span) and span > 0:
                scores = 50 + (last_errors - p50) / span * 49

        if scores is None:
            # No usable thresholds: simple normalisation.
            scores = last_errors * 100

        return np.clip(scores, 0, 100)
|
||
|
||
|
||
# ==================== 实时数据管理器 ====================
|
||
|
||
class RealtimeDataManagerV2:
    """V2 realtime data manager.

    Keeps, per concept:
    1. the recent feature history (input of the sequence model),
    2. the recent fused scores (input of the persistence confirmation),
    3. the time of the last alert (cooldown).
    """

    def __init__(self, concepts: List[dict], baselines: Dict[str, pd.DataFrame]):
        """Index the concepts by id and create the empty per-concept state.

        Args:
            concepts: Dicts carrying at least ``concept_id`` and ``stocks``.
            baselines: Per-concept baseline frames keyed by concept id.
        """
        self.concepts = {c['concept_id']: c for c in concepts}
        self.baselines = baselines

        # Concept id -> set of member stock codes.
        self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}

        # {concept_id: deque([(timestamp, features_dict), ...], maxlen=seq_len)}
        self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))

        # {concept_id: deque([(timestamp, score), ...], maxlen=confirm_window)}
        self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))

        # {concept_id: timestamp of the last confirmed alert}
        self.cooldown = {}

        # Timestamp of the previous update; not written by any method here.
        self.last_timestamp = None

    @staticmethod
    def _safe_zscore(value, mean, std, clip) -> float:
        """Return ``(value - mean) / std`` clipped to ``[-clip, clip]``.

        A baseline that cannot standardise the value (zero, NaN or missing
        std, or NaN inputs) yields a neutral 0.0 rather than an exception or
        a NaN that would propagate into the rule and model scores.
        """
        try:
            z = (float(value) - float(mean)) / float(std)
        except (TypeError, ValueError, ZeroDivisionError):
            return 0.0
        if np.isnan(z):
            return 0.0
        return float(np.clip(z, -clip, clip))

    @staticmethod
    def _minutes_between(earlier, later) -> Optional[float]:
        """Return the minutes from ``earlier`` to ``later``, or ``None``.

        Both values go through ``pd.Timestamp`` so ``datetime``,
        ``pd.Timestamp`` and ``numpy.datetime64`` inputs are handled alike;
        ``None`` is returned when they cannot be converted or subtracted.
        """
        try:
            delta = pd.Timestamp(later) - pd.Timestamp(earlier)
            return delta.total_seconds() / 60
        except (TypeError, ValueError):
            return None

    def compute_zscore_features(
        self,
        concept_id: str,
        timestamp,
        alpha: float,
        total_amt: float,
        rank_pct: float,
        limit_up_ratio: float
    ) -> Optional[Dict]:
        """Compute the feature dict of one concept at one moment.

        Returns ``None`` when the concept has no baseline, the baseline has
        no row for the timestamp's time slot, or that row is backed by fewer
        than 10 samples.
        """
        if concept_id not in self.baselines:
            return None

        baseline = self.baselines[concept_id]
        time_slot = time_to_slot(timestamp)

        # Baseline row of the matching time slot.
        bl_row = baseline[baseline['time_slot'] == time_slot]
        if bl_row.empty:
            return None
        bl = bl_row.iloc[0]

        # Too few samples make the baseline statistics unreliable.
        if bl.get('sample_count', 0) < 10:
            return None

        clip = CONFIG['zscore_clip']
        alpha_zscore = self._safe_zscore(alpha, bl['alpha_mean'], bl['alpha_std'], clip)
        amt_zscore = self._safe_zscore(total_amt, bl['amt_mean'], bl['amt_std'], clip)
        rank_zscore = self._safe_zscore(rank_pct, bl['rank_mean'], bl['rank_std'], clip)

        # Momentum = mean alpha of the most recent window minus the mean of
        # the window before it, taken from the stored history (which does
        # not yet contain the current moment).
        # NOTE(review): while the history is too short for an older window,
        # the current alpha stands in for it, so the value is "recent history
        # minus now" - confirm this matches how the training features were
        # built.
        past_alphas = [entry[1]['alpha'] for entry in self.zscore_history[concept_id]]
        depth = len(past_alphas)
        momentum_3m = 0
        momentum_5m = 0

        if depth >= 3:
            older = past_alphas[-6:-3] if depth >= 6 else [alpha]
            momentum_3m = np.mean(past_alphas[-3:]) - np.mean(older)

        if depth >= 5:
            older = past_alphas[-10:-5] if depth >= 10 else [alpha]
            momentum_5m = np.mean(past_alphas[-5:]) - np.mean(older)

        return {
            'alpha': alpha,
            'alpha_zscore': alpha_zscore,
            'amt_zscore': amt_zscore,
            'rank_zscore': rank_zscore,
            'momentum_3m': momentum_3m,
            'momentum_5m': momentum_5m,
            'limit_up_ratio': limit_up_ratio,
            'total_amt': total_amt,
            'rank_pct': rank_pct,
        }

    def update(self, concept_id: str, timestamp, features: Dict):
        """Append ``(timestamp, features)`` to the concept's history."""
        self.zscore_history[concept_id].append((timestamp, features))

    def get_sequence(self, concept_id: str) -> Optional[np.ndarray]:
        """Return the concept's model input, or ``None`` if history is short.

        The array has one row per stored moment and six columns:
        alpha_zscore, amt_zscore, rank_zscore, momentum_3m, momentum_5m,
        limit_up_ratio.
        """
        history = self.zscore_history[concept_id]
        if len(history) < CONFIG['seq_len']:
            return None

        columns = (
            'alpha_zscore', 'amt_zscore', 'rank_zscore',
            'momentum_3m', 'momentum_5m', 'limit_up_ratio',
        )
        return np.array([[features[name] for name in columns] for _, features in history])

    def add_anomaly_candidate(self, concept_id: str, timestamp, score: float):
        """Record the score of this moment for persistence confirmation."""
        self.anomaly_candidates[concept_id].append((timestamp, score))

    def check_sustained_anomaly(self, concept_id: str, threshold: float) -> Tuple[bool, float]:
        """Check whether the concept's recent scores confirm an anomaly.

        Returns:
            ``(confirmed, ratio)`` where ``ratio`` is the share of recorded
            moments whose score is at least ``threshold``. While fewer than
            ``confirm_window`` moments are recorded the result is
            ``(False, 0.0)``.
        """
        candidates = self.anomaly_candidates[concept_id]
        if len(candidates) < CONFIG['confirm_window']:
            return False, 0.0

        exceed_count = sum(1 for _, score in candidates if score >= threshold)
        ratio = exceed_count / len(candidates)
        return ratio >= CONFIG['confirm_ratio'], ratio

    def check_cooldown(self, concept_id: str, timestamp) -> bool:
        """Return True while the concept is still inside its cooldown.

        Timestamps that cannot be compared count as "not cooling down".
        """
        if concept_id not in self.cooldown:
            return False

        # A bare subtraction of numpy.datetime64 values yields a timedelta64
        # without total_seconds(); the helper normalises both sides so the
        # cooldown is honoured for those timestamps too.
        elapsed = self._minutes_between(self.cooldown[concept_id], timestamp)
        if elapsed is None:
            return False
        return bool(elapsed < CONFIG['cooldown_minutes'])

    def set_cooldown(self, concept_id: str, timestamp):
        """Start the concept's cooldown at ``timestamp``."""
        self.cooldown[concept_id] = timestamp
|
||
|
||
|
||
# ==================== 异动检测器 V2 ====================
|
||
|
||
class AnomalyDetectorV2:
    """V2 anomaly detector.

    Flow:
    1. fetch the raw per-minute concept data,
    2. compute the Z-Score features,
    3. rule score + ML score,
    4. persistence confirmation,
    5. emit the confirmed anomalies.
    """

    def __init__(
        self,
        model_dir: str = 'ml/checkpoints_v2',
        baseline_dir: str = 'ml/data_v2/baselines'
    ):
        """Load concepts, baselines and the model, then build the state."""
        self.concepts = self._load_concepts()

        self.baselines = load_baselines(baseline_dir)
        print(f"加载了 {len(self.baselines)} 个概念的基线")

        self.ml_scorer = MLScorerV2(model_dir)
        self.data_manager = RealtimeDataManagerV2(self.concepts, self.baselines)

        # Every distinct stock code that appears in any concept.
        self.all_stocks = list(set(s for c in self.concepts for s in c['stocks']))

    def _load_concepts(self) -> List[dict]:
        """Load all concepts that have at least one stock code from ES.

        Pages through ``ES_INDEX`` with the scroll API and returns dicts with
        ``concept_id``, ``concept_name`` and ``stocks`` (list of codes).
        """
        concepts = []
        query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}

        resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
        scroll_id = resp['_scroll_id']
        try:
            hits = resp['hits']['hits']
            while hits:
                for hit in hits:
                    source = hit['_source']
                    raw_stocks = source.get('stocks')
                    if not isinstance(raw_stocks, list):
                        continue
                    # Keep only entries shaped like {"code": "<non-empty>"}.
                    stocks = [
                        stock['code']
                        for stock in raw_stocks
                        if isinstance(stock, dict) and stock.get('code')
                    ]
                    if stocks:
                        concepts.append({
                            'concept_id': source.get('concept_id'),
                            'concept_name': source.get('concept'),
                            'stocks': stocks
                        })

                resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
                scroll_id = resp['_scroll_id']
                hits = resp['hits']['hits']
        finally:
            # Release the server-side scroll context even when paging fails.
            ES_CLIENT.clear_scroll(scroll_id=scroll_id)

        print(f"加载了 {len(concepts)} 个概念")
        return concepts

    def detect(self, trade_date: str) -> List[Dict]:
        """Detect the anomalies of ``trade_date`` and return them as dicts.

        Timestamps are processed in ascending order because the data manager
        accumulates history from one moment to the next.
        """
        print(f"\n检测 {trade_date} 的异动...")

        raw_features = self._compute_raw_features(trade_date)
        if raw_features.empty:
            print("无数据")
            return []

        # One pass over the frame instead of re-filtering it per timestamp.
        by_minute = raw_features.groupby('timestamp', sort=True)
        print(f"时间点数: {by_minute.ngroups}")

        all_alerts = []
        for ts, ts_data in by_minute:
            all_alerts.extend(self._process_timestamp(ts, ts_data, trade_date))

        print(f"共检测到 {len(all_alerts)} 个异动")
        return all_alerts

    def _compute_raw_features(self, trade_date: str) -> pd.DataFrame:
        """Compute the raw per-concept features by reusing prepare_data_v2."""
        # Imported lazily, as in the original, so the module only needs
        # prepare_data_v2 when detection actually runs.
        from prepare_data_v2 import compute_raw_concept_features
        return compute_raw_concept_features(trade_date, self.concepts, self.all_stocks)

    @staticmethod
    def _classify_alert(alpha, amt_zscore) -> str:
        """Map the alpha level and turnover Z-Score to an alert type."""
        if alpha >= 1.5:
            return 'surge_up'
        if alpha <= -1.5:
            return 'surge_down'
        if amt_zscore >= 3.0:
            return 'volume_spike'
        return 'surge'

    @staticmethod
    def _describe_trigger(rule_score, ml_score, final_score, confirm_ratio) -> str:
        """Build the human-readable trigger reason (rule > ML > fusion)."""
        if rule_score >= CONFIG['rule_trigger']:
            return f'规则({rule_score:.0f})+持续确认({confirm_ratio:.0%})'
        if ml_score >= CONFIG['ml_trigger']:
            return f'ML({ml_score:.0f})+持续确认({confirm_ratio:.0%})'
        return f'融合({final_score:.0f})+持续确认({confirm_ratio:.0%})'

    def _process_timestamp(self, timestamp, ts_data: pd.DataFrame, trade_date: str) -> List[Dict]:
        """Score every concept of one moment and return the confirmed alerts.

        Args:
            timestamp: The moment being processed.
            ts_data: Rows of that moment with ``concept_id``, ``alpha``,
                ``total_amt``, ``rank_pct`` and ``limit_up_ratio``.
            trade_date: Copied into every alert.

        Returns:
            At most ``max_alerts_per_minute`` alert dicts. The cooldown is
            started only for the alerts actually returned.
        """
        manager = self.data_manager

        # Stage 1: features + rule score; also advances the history.
        scored = []  # (concept_id, features, rule_score, triggered_rules)
        for _, row in ts_data.iterrows():
            concept_id = row['concept_id']
            features = manager.compute_zscore_features(
                concept_id, timestamp,
                row['alpha'], row['total_amt'], row['rank_pct'], row['limit_up_ratio']
            )
            if features is None:
                continue

            manager.update(concept_id, timestamp, features)
            rule_score, triggered_rules = score_rules_zscore(features)
            scored.append((concept_id, features, rule_score, triggered_rules))

        if not scored:
            return []

        # Stage 2: batch ML scoring for concepts with a full sequence.
        sequences = []
        ready = []
        for candidate in scored:
            seq = manager.get_sequence(candidate[0])
            if seq is not None:
                sequences.append(seq)
                ready.append(candidate)

        if not sequences:
            return []

        ml_scores = self.ml_scorer.score_batch(np.array(sequences))

        # Stage 3: fusion, trigger check, cooldown and persistence check.
        alerts = []
        alerted_ids = set()
        for (concept_id, features, rule_score, triggered_rules), raw_ml in zip(ready, ml_scores):
            # Plain floats keep the alert dicts free of NumPy scalar types.
            ml_score = float(raw_ml)
            final_score = float(
                CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score
            )

            # Every scored moment feeds the confirmation window, triggered
            # or not.
            manager.add_anomaly_candidate(concept_id, timestamp, final_score)

            is_triggered = (
                rule_score >= CONFIG['rule_trigger'] or
                ml_score >= CONFIG['ml_trigger'] or
                final_score >= CONFIG['fusion_trigger']
            )
            if not is_triggered:
                continue

            # The cooldown is written only after the cap below, so a concept
            # already confirmed in this same moment is tracked locally.
            if concept_id in alerted_ids or manager.check_cooldown(concept_id, timestamp):
                continue

            is_sustained, confirm_ratio = manager.check_sustained_anomaly(
                concept_id, CONFIG['fusion_trigger']
            )
            if not is_sustained:
                continue

            alerted_ids.add(concept_id)
            alpha = features['alpha']
            alerts.append({
                'concept_id': concept_id,
                'concept_name': manager.concepts.get(concept_id, {}).get('concept_name', concept_id),
                'alert_time': timestamp,
                'trade_date': trade_date,
                'alert_type': self._classify_alert(alpha, features['amt_zscore']),
                'final_score': final_score,
                'rule_score': rule_score,
                'ml_score': ml_score,
                'trigger_reason': self._describe_trigger(
                    rule_score, ml_score, final_score, confirm_ratio
                ),
                'confirm_ratio': confirm_ratio,
                'alpha': alpha,
                'alpha_zscore': features['alpha_zscore'],
                'amt_zscore': features['amt_zscore'],
                'rank_zscore': features['rank_zscore'],
                'momentum_3m': features['momentum_3m'],
                'momentum_5m': features['momentum_5m'],
                'limit_up_ratio': features['limit_up_ratio'],
                'triggered_rules': triggered_rules,
            })

        # Keep at most N alerts per minute, highest fused score first.
        limit = CONFIG['max_alerts_per_minute']
        if len(alerts) > limit:
            alerts = sorted(alerts, key=lambda x: x['final_score'], reverse=True)[:limit]

        # Start the cooldown only for alerts that are really emitted; alerts
        # cut by the cap stay eligible for the following minutes instead of
        # being silenced without ever being reported.
        for alert in alerts:
            manager.set_cooldown(alert['concept_id'], timestamp)

        return alerts
|
||
|
||
|
||
# ==================== 主函数 ====================
|
||
|
||
def main():
    """Command-line entry point: run detection for one day and print it.

    Options: ``--date`` (defaults to today), ``--model_dir`` and
    ``--baseline_dir``. At most the first 20 alerts are printed.
    """
    import argparse

    cli = argparse.ArgumentParser(description='V2 异动检测器')
    cli.add_argument('--date', type=str, default=None, help='检测日期(默认今天)')
    cli.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
    cli.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
    options = cli.parse_args()

    # An empty or missing --date falls back to the current local date.
    trade_date = options.date if options.date else datetime.now().strftime('%Y-%m-%d')

    detector = AnomalyDetectorV2(
        model_dir=options.model_dir,
        baseline_dir=options.baseline_dir,
    )
    alerts = detector.detect(trade_date)

    shown = 20
    print("\n检测结果:")
    for alert in alerts[:shown]:
        # Alert times may be datetime-like or plain strings.
        moment = alert['alert_time']
        if hasattr(moment, 'strftime'):
            moment = moment.strftime('%H:%M')
        print(f" [{moment}] "
              f"{alert['concept_name']} ({alert['alert_type']}) "
              f"分数={alert['final_score']:.0f} "
              f"确认率={alert['confirm_ratio']:.0%}")

    if len(alerts) > shown:
        print(f" ... 共 {len(alerts)} 个异动")


if __name__ == "__main__":
    main()
|