update pay ui
This commit is contained in:
526
ml/enhanced_detector.py
Normal file
526
ml/enhanced_detector.py
Normal file
@@ -0,0 +1,526 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强版概念异动检测器
|
||||
|
||||
融合两种检测方法:
|
||||
1. Alpha-based Z-Score(规则方法,实时性好)
|
||||
2. Transformer Autoencoder(ML方法,更准确)
|
||||
|
||||
使用策略:
|
||||
- 当 ML 模型可用且历史数据足够时,优先使用 ML 方法
|
||||
- 否则回退到 Alpha-based 方法
|
||||
- 可以配置两种方法的融合权重
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from collections import deque
|
||||
import numpy as np
|
||||
|
||||
# Add the parent of this file's directory to the front of the import path.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
# Default settings for EnhancedAnomalyDetector; used when no config is passed.
ENHANCED_CONFIG = {
    # Fusion strategy
    'fusion_mode': 'adaptive',  # 'ml_only', 'alpha_only', 'adaptive', 'ensemble'

    # Score weights (used in 'ensemble' mode)
    'ml_weight': 0.6,
    'alpha_weight': 0.4,

    # ML model settings
    # NOTE(review): 'ml_checkpoint_dir' is not read by the code visible in this
    # file; the loader derives the checkpoint directory from __file__ instead.
    'ml_checkpoint_dir': 'ml/checkpoints',
    'ml_threshold_key': 'p95',  # p90, p95, p99

    # Alpha settings (kept consistent with concept_alert_alpha.py)
    'alpha_zscore_threshold': 2.0,     # |Z| at or above this counts as an anomaly
    'alpha_absolute_threshold': 1.5,   # |alpha| fallback while history is too short
    'alpha_history_window': 60,        # maxlen of the per-concept alpha history
    'alpha_min_history': 5,            # samples required before a Z-score is produced

    # Shared settings
    'cooldown_minutes': 8,             # minimum minutes between alerts for one concept
    'max_alerts_per_minute': 15,       # cap on alerts returned by one batch_detect call
    'min_alpha_abs': 0.5,              # detect() ignores samples with |alpha| below this
}
|
||||
|
||||
# Feature configuration (kept consistent with training).
# The names match the fields of ConceptFeatures, in the same order.
FEATURE_NAMES = [
    'alpha',
    'alpha_delta',
    'amt_ratio',
    'amt_delta',
    'rank_pct',
    'limit_up_ratio',
]
|
||||
|
||||
|
||||
# ==================== 数据结构 ====================
|
||||
|
||||
@dataclass
class AlphaStats:
    """Rolling alpha statistics for a single concept.

    Keeps a bounded history of alpha values together with the mean and
    standard deviation of that history. The history length is capped by
    ENHANCED_CONFIG['alpha_history_window'] unless a deque is supplied.
    """
    history: deque = field(
        default_factory=lambda: deque(maxlen=ENHANCED_CONFIG['alpha_history_window'])
    )
    mean: float = 0.0
    std: float = 1.0

    def update(self, alpha: float):
        """Append ``alpha`` and refresh mean/std once two samples exist."""
        samples = self.history
        samples.append(alpha)
        if len(samples) < 2:
            # A single sample leaves the current mean/std untouched.
            return
        self.mean = np.mean(samples)
        spread = np.std(samples)
        # Floor the deviation at 0.1 so the Z-score never divides by ~0.
        self.std = max(spread, 0.1)

    def get_zscore(self, alpha: float) -> float:
        """Return the Z-score of ``alpha``, or 0.0 while history is too short."""
        required = ENHANCED_CONFIG['alpha_min_history']
        if len(self.history) >= required:
            return (alpha - self.mean) / self.std
        return 0.0

    def is_ready(self) -> bool:
        """Tell whether at least ``alpha_min_history`` samples are stored."""
        required = ENHANCED_CONFIG['alpha_min_history']
        return not len(self.history) < required
|
||||
|
||||
|
||||
@dataclass
class ConceptFeatures:
    """Real-time feature values of one concept for a single observation."""
    alpha: float = 0.0
    alpha_delta: float = 0.0
    amt_ratio: float = 1.0
    amt_delta: float = 0.0
    rank_pct: float = 0.5
    limit_up_ratio: float = 0.0

    def to_dict(self) -> Dict[str, float]:
        """Return the six features as a plain dict, in declaration order."""
        ordered_names = (
            'alpha',
            'alpha_delta',
            'amt_ratio',
            'amt_delta',
            'rank_pct',
            'limit_up_ratio',
        )
        return {name: getattr(self, name) for name in ordered_names}
|
||||
|
||||
|
||||
# ==================== 增强检测器 ====================
|
||||
|
||||
class EnhancedAnomalyDetector:
    """
    Enhanced concept anomaly detector.

    Fuses two detection methods: the rule-based alpha Z-score and an
    optional ML detector loaded from a checkpoint. The way the two results
    are combined is selected by ``config['fusion_mode']``.
    """

    # Fusion strategies that the detector knows how to evaluate.
    _FUSION_MODES = ('ml_only', 'alpha_only', 'adaptive', 'ensemble')

    def __init__(
        self,
        config: Optional[Dict] = None,
        ml_enabled: bool = True
    ):
        """
        Create a detector.

        Args:
            config: Optional settings. Keys that are missing fall back to the
                values in ENHANCED_CONFIG. When omitted (or empty),
                ENHANCED_CONFIG itself is used.
            ml_enabled: When True, try to load the ML model; a failure to
                load is logged and leaves the detector in alpha-only shape.
        """
        # Overlay caller overrides on the defaults so that a partial config
        # does not raise KeyError later on.
        self.config = {**ENHANCED_CONFIG, **config} if config else ENHANCED_CONFIG
        self.ml_enabled = ml_enabled
        self.ml_detector = None

        # Per-concept alpha statistics
        self.alpha_stats: Dict[str, AlphaStats] = {}

        # Per-concept feature history (used to compute deltas)
        self.feature_history: Dict[str, deque] = {}

        # Time of the last alert per concept (cooldown bookkeeping)
        self.cooldown_cache: Dict[str, datetime] = {}

        # An unknown mode matches no fusion branch, so no alert would ever
        # be produced; make that visible instead of failing silently.
        if self.config['fusion_mode'] not in self._FUSION_MODES:
            logger.warning(
                "未知的融合模式 %r,检测将不会触发任何异动",
                self.config['fusion_mode'],
            )

        # Try to load the ML model
        if ml_enabled:
            self._load_ml_model()

        logger.info("EnhancedAnomalyDetector 初始化完成")
        logger.info(f" 融合模式: {self.config['fusion_mode']}")
        logger.info(f" ML 可用: {self.ml_detector is not None}")

    def _load_ml_model(self):
        """
        Load the ML model from ``<this file's directory>/checkpoints``.

        Any failure (import error, missing file, constructor error) is logged
        as a warning and leaves ``self.ml_detector`` as None.

        NOTE(review): config['ml_checkpoint_dir'] is not consulted here, and
        ``inference`` is imported as a top-level module, which presumably
        requires this file's directory to be importable - confirm.
        """
        try:
            from inference import ConceptAnomalyDetector
            checkpoint_dir = Path(__file__).parent / 'checkpoints'

            if (checkpoint_dir / 'best_model.pt').exists():
                self.ml_detector = ConceptAnomalyDetector(
                    checkpoint_dir=str(checkpoint_dir),
                    threshold_key=self.config['ml_threshold_key']
                )
                logger.info("ML 模型加载成功")
            else:
                logger.warning(f"ML 模型不存在: {checkpoint_dir / 'best_model.pt'}")
        except Exception as e:
            # Deliberate best effort: the detector still works without ML.
            logger.warning(f"ML 模型加载失败: {e}")
            self.ml_detector = None

    def _get_alpha_stats(self, concept_id: str) -> AlphaStats:
        """Return the AlphaStats for ``concept_id``, creating it on first use."""
        if concept_id not in self.alpha_stats:
            self.alpha_stats[concept_id] = AlphaStats()
        return self.alpha_stats[concept_id]

    def _get_feature_history(self, concept_id: str) -> deque:
        """Return the feature history (last 10 entries) for ``concept_id``."""
        if concept_id not in self.feature_history:
            self.feature_history[concept_id] = deque(maxlen=10)
        return self.feature_history[concept_id]

    def _check_cooldown(self, concept_id: str, current_time: datetime) -> bool:
        """Return True while ``concept_id`` is still inside its cooldown window."""
        if concept_id not in self.cooldown_cache:
            return False

        last_alert = self.cooldown_cache[concept_id]
        elapsed_minutes = (current_time - last_alert).total_seconds() / 60

        return elapsed_minutes < self.config['cooldown_minutes']

    def _set_cooldown(self, concept_id: str, current_time: datetime):
        """Record ``current_time`` as the last alert time of ``concept_id``."""
        self.cooldown_cache[concept_id] = current_time

    def compute_features(
        self,
        concept_id: str,
        alpha: float,
        amt_ratio: float,
        rank_pct: float,
        limit_up_ratio: float
    ) -> ConceptFeatures:
        """
        Build the full feature set for a concept and append it to its history.

        Args:
            concept_id: Concept ID
            alpha: Current excess return
            amt_ratio: Turnover-amount ratio
            rank_pct: Rank percentile
            limit_up_ratio: Share of limit-up stocks

        Returns:
            The features, including ``alpha_delta`` (difference to the
            previously stored alpha) and ``amt_delta`` (relative change of
            ``amt_ratio``). Both deltas are 0.0 for the first observation.
        """
        history = self._get_feature_history(concept_id)

        # Deltas relative to the previously stored observation
        alpha_delta = 0.0
        amt_delta = 0.0

        if len(history) > 0:
            last_features = history[-1]
            alpha_delta = alpha - last_features.alpha
            # Skip the relative change when the previous ratio is not positive
            if last_features.amt_ratio > 0:
                amt_delta = (amt_ratio - last_features.amt_ratio) / last_features.amt_ratio

        features = ConceptFeatures(
            alpha=alpha,
            alpha_delta=alpha_delta,
            amt_ratio=amt_ratio,
            amt_delta=amt_delta,
            rank_pct=rank_pct,
            limit_up_ratio=limit_up_ratio,
        )

        # Update history
        history.append(features)

        return features

    def detect_alpha_anomaly(
        self,
        concept_id: str,
        alpha: float
    ) -> Tuple[bool, float, str]:
        """
        Alpha-based anomaly detection.

        The Z-score is measured against the statistics collected before this
        sample; the sample is then added to the statistics.

        Returns:
            is_anomaly: Whether an anomaly was detected
            score: Anomaly score (absolute Z-score, or a substitute score
                derived from alpha while the history is too short)
            reason: Short text describing what triggered
        """
        stats = self._get_alpha_stats(concept_id)

        # Readiness and Z-score must both describe the state BEFORE the
        # update. Checking readiness after the update would pair a "ready"
        # flag with the placeholder Z-score of 0.0 on the sample that
        # completes the warm-up, so neither rule could fire for it.
        was_ready = stats.is_ready()
        zscore = stats.get_zscore(alpha)

        # Update statistics
        stats.update(alpha)

        if was_ready:
            if abs(zscore) >= self.config['alpha_zscore_threshold']:
                return True, abs(zscore), f"Z={zscore:.2f}"
        else:
            # Not enough history yet: fall back to an absolute threshold
            if abs(alpha) >= self.config['alpha_absolute_threshold']:
                fake_zscore = alpha / 0.5
                return True, abs(fake_zscore), f"Alpha={alpha:+.2f}%"

        return False, abs(zscore) if zscore else 0.0, ""

    def detect_ml_anomaly(
        self,
        concept_id: str,
        features: ConceptFeatures
    ) -> Tuple[bool, float]:
        """
        ML-based anomaly detection.

        Returns ``(False, 0.0)`` when no ML detector is loaded or when the
        detector raises (the error is logged as a warning).

        Returns:
            is_anomaly: Whether an anomaly was detected
            score: Anomaly score reported by the ML detector (0.0 if falsy)
        """
        if self.ml_detector is None:
            return False, 0.0

        try:
            is_anomaly, score = self.ml_detector.detect(
                concept_id,
                features.to_dict()
            )
            return is_anomaly, score or 0.0
        except Exception as e:
            # Deliberate best effort: an ML failure must not stop detection.
            logger.warning(f"ML 检测失败: {e}")
            return False, 0.0

    def _fuse(
        self,
        alpha_anomaly: bool,
        alpha_score: float,
        ml_anomaly: bool,
        ml_score: float
    ) -> Tuple[bool, float, str]:
        """Combine both detector results according to ``config['fusion_mode']``.

        Returns ``(is_anomaly, final_score, detection_method)``; an unknown
        mode yields ``(False, 0.0, '')``.
        """
        fusion_mode = self.config['fusion_mode']

        if fusion_mode == 'alpha_only':
            return alpha_anomaly, alpha_score, 'alpha'

        if fusion_mode == 'ml_only':
            return ml_anomaly, ml_score, 'ml'

        if fusion_mode == 'adaptive':
            # Prefer ML when it produced a score, otherwise fall back to alpha
            if self.ml_detector and ml_score > 0:
                return ml_anomaly, ml_score, 'ml'
            return alpha_anomaly, alpha_score, 'alpha'

        if fusion_mode == 'ensemble':
            # Weighted fusion of normalised scores
            norm_alpha = min(alpha_score / 5.0, 1.0)  # Z-scores above 5 saturate at 1.0

            # Guard against a missing, zero or negative threshold, which
            # would otherwise raise or flip the sign of the normalised score.
            threshold = getattr(self.ml_detector, 'threshold', None) if self.ml_detector else None
            if not threshold or threshold <= 0:
                threshold = 1.0
            norm_ml = min(ml_score / threshold, 1.0)

            final_score = (
                self.config['alpha_weight'] * norm_alpha +
                self.config['ml_weight'] * norm_ml
            )
            is_anomaly = final_score > 0.5 or alpha_anomaly or ml_anomaly
            return is_anomaly, final_score, 'ensemble'

        return False, 0.0, ''

    def _evaluate(
        self,
        concept_id: str,
        concept_name: str,
        alpha: float,
        amt_ratio: float,
        rank_pct: float,
        limit_up_ratio: float,
        change_pct: float,
        index_change: float,
        current_time: datetime,
        extra_data: Dict
    ) -> Optional[Dict]:
        """Run one detection and build the alert dict WITHOUT starting a cooldown.

        NOTE(review): both early returns happen before the feature history
        and alpha statistics are updated, so samples with a small |alpha| or
        inside a cooldown window never enter the rolling state - confirm this
        is intended.
        """
        # Alpha too small: not interesting
        if abs(alpha) < self.config['min_alpha_abs']:
            return None

        # Still cooling down
        if self._check_cooldown(concept_id, current_time):
            return None

        # Compute features
        features = self.compute_features(
            concept_id, alpha, amt_ratio, rank_pct, limit_up_ratio
        )

        # Run the detectors; the alpha detector always runs so that its
        # statistics keep being updated in every mode.
        alpha_anomaly, alpha_score, _ = self.detect_alpha_anomaly(concept_id, alpha)
        ml_anomaly, ml_score = False, 0.0

        if self.config['fusion_mode'] in ('ml_only', 'adaptive', 'ensemble'):
            ml_anomaly, ml_score = self.detect_ml_anomaly(concept_id, features)

        is_anomaly, final_score, detection_method = self._fuse(
            alpha_anomaly, alpha_score, ml_anomaly, ml_score
        )

        if not is_anomaly:
            return None

        alert_type = 'surge_up' if alpha > 0 else 'surge_down'

        return {
            'concept_id': concept_id,
            'concept_name': concept_name,
            'alert_type': alert_type,
            'alert_time': current_time,
            'change_pct': change_pct,
            'alpha': alpha,
            'alpha_zscore': alpha_score,
            'index_change_pct': index_change,
            'detection_method': detection_method,
            'alpha_score': alpha_score,
            'ml_score': ml_score,
            'final_score': final_score,
            **extra_data
        }

    def detect(
        self,
        concept_id: str,
        concept_name: str,
        alpha: float,
        amt_ratio: float,
        rank_pct: float,
        limit_up_ratio: float,
        change_pct: float,
        index_change: float,
        current_time: datetime,
        **extra_data
    ) -> Optional[Dict]:
        """
        Fused detection for a single concept.

        Args:
            concept_id: Concept ID
            concept_name: Concept name
            alpha: Excess return
            amt_ratio: Turnover-amount ratio
            rank_pct: Rank percentile
            limit_up_ratio: Share of limit-up stocks
            change_pct: Concept price change
            index_change: Market index change
            current_time: Current time
            **extra_data: Additional fields copied into the alert
                (limit_up_count, stock_count, ...)

        Returns:
            The alert dict if an anomaly fired (this also starts the
            concept's cooldown), otherwise None.
        """
        alert = self._evaluate(
            concept_id, concept_name, alpha, amt_ratio, rank_pct,
            limit_up_ratio, change_pct, index_change, current_time, extra_data
        )
        if alert is not None:
            self._set_cooldown(concept_id, current_time)
        return alert

    def batch_detect(
        self,
        concepts_data: List[Dict],
        current_time: datetime
    ) -> List[Dict]:
        """
        Batch detection.

        Args:
            concepts_data: List of per-concept data dicts; each must contain
                'concept_id' and 'concept_name', other fields are optional
            current_time: Current time

        Returns:
            Alerts sorted by ``final_score`` (highest first), truncated to
            ``config['max_alerts_per_minute']``. Only the returned alerts
            start a cooldown; candidates dropped by the cap stay eligible
            for the next call.
        """
        candidates = []  # (concept_id, alert) pairs
        fired_ids = set()

        for data in concepts_data:
            concept_id = data['concept_id']
            # A concept fires at most once per batch.
            if concept_id in fired_ids:
                continue

            alert = self._evaluate(
                concept_id,
                data['concept_name'],
                data.get('alpha', 0),
                data.get('amt_ratio', 1.0),
                data.get('rank_pct', 0.5),
                data.get('limit_up_ratio', 0),
                data.get('change_pct', 0),
                data.get('index_change', 0),
                current_time,
                {
                    'limit_up_count': data.get('limit_up_count', 0),
                    'limit_down_count': data.get('limit_down_count', 0),
                    'stock_count': data.get('stock_count', 0),
                    'concept_type': data.get('concept_type', 'leaf'),
                },
            )

            if alert:
                fired_ids.add(concept_id)
                candidates.append((concept_id, alert))

        # Sort and cap the number of alerts
        candidates.sort(key=lambda item: item[1]['final_score'], reverse=True)
        kept = candidates[:self.config['max_alerts_per_minute']]

        # Start the cooldown only for alerts that are actually emitted.
        for concept_id, _ in kept:
            self._set_cooldown(concept_id, current_time)

        return [alert for _, alert in kept]

    def reset(self):
        """Reset all state (new trading day)."""
        self.alpha_stats.clear()
        self.feature_history.clear()
        self.cooldown_cache.clear()

        if self.ml_detector:
            self.ml_detector.clear_history()

        logger.info("检测器状态已重置")
|
||||
|
||||
|
||||
# ==================== 测试 ====================
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: feed 50 simulated minutes of random data for three
    # concepts through the detector and print every alert.
    import random
    from datetime import timedelta

    print("测试 EnhancedAnomalyDetector...")

    # Initialise without ML (the model files may not exist)
    detector = EnhancedAnomalyDetector(ml_enabled=False)

    # Simulated concepts
    concepts = [
        {'concept_id': 'ai_001', 'concept_name': '人工智能'},
        {'concept_id': 'chip_002', 'concept_name': '芯片半导体'},
        {'concept_id': 'car_003', 'concept_name': '新能源汽车'},
    ]

    print("\n模拟实时检测...")
    current_time = datetime.now()

    for minute in range(50):
        concepts_data = []

        for c in concepts:
            # Generate random data
            alpha = random.gauss(0, 0.8)
            amt_ratio = max(0.3, random.gauss(1, 0.3))
            rank_pct = random.random()
            limit_up_ratio = random.random() * 0.1

            # Inject an anomaly: the AI concept surges at minute 30
            if minute == 30 and c['concept_id'] == 'ai_001':
                alpha = 4.5
                amt_ratio = 2.5
                limit_up_ratio = 0.3

            concepts_data.append({
                **c,
                'alpha': alpha,
                'amt_ratio': amt_ratio,
                'rank_pct': rank_pct,
                'limit_up_ratio': limit_up_ratio,
                'change_pct': alpha + 0.5,
                'index_change': 0.5,
            })

        # Detect
        alerts = detector.batch_detect(concepts_data, current_time)

        if alerts:
            for alert in alerts:
                print(f" t={minute:02d} 🔥 {alert['concept_name']} "
                      f"Alpha={alert['alpha']:+.2f}% "
                      f"Score={alert['final_score']:.2f} "
                      f"Method={alert['detection_method']}")

        # Advance simulated time by one minute. timedelta carries over into
        # the hour, so time never jumps backwards (replacing only the minute
        # field wrapped 59 -> 0 and kept concepts in cooldown).
        current_time += timedelta(minutes=1)

    print("\n测试完成!")
|
||||
Reference in New Issue
Block a user