update pay ui
This commit is contained in:
455
ml/inference.py
Normal file
455
ml/inference.py
Normal file
@@ -0,0 +1,455 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
概念异动检测推理服务
|
||||
|
||||
在实时场景中使用训练好的 Transformer Autoencoder 进行异动检测
|
||||
|
||||
使用方法:
|
||||
from ml.inference import ConceptAnomalyDetector
|
||||
|
||||
detector = ConceptAnomalyDetector('ml/checkpoints')
|
||||
|
||||
# 检测异动
|
||||
features = {...} # 实时特征数据
|
||||
is_anomaly, score = detector.detect(features)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from model import TransformerAutoencoder
|
||||
|
||||
|
||||
class ConceptAnomalyDetector:
|
||||
"""
|
||||
概念异动检测器
|
||||
|
||||
使用训练好的 Transformer Autoencoder 进行实时异动检测
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_dir: str = 'ml/checkpoints',
|
||||
device: str = 'auto',
|
||||
threshold_key: str = 'p95'
|
||||
):
|
||||
"""
|
||||
初始化检测器
|
||||
|
||||
Args:
|
||||
checkpoint_dir: 模型检查点目录
|
||||
device: 设备 (auto/cuda/cpu)
|
||||
threshold_key: 使用的阈值键 (p90/p95/p99)
|
||||
"""
|
||||
self.checkpoint_dir = Path(checkpoint_dir)
|
||||
self.threshold_key = threshold_key
|
||||
|
||||
# 设备选择
|
||||
if device == 'auto':
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
else:
|
||||
self.device = torch.device(device)
|
||||
|
||||
# 加载配置
|
||||
self._load_config()
|
||||
|
||||
# 加载模型
|
||||
self._load_model()
|
||||
|
||||
# 加载阈值
|
||||
self._load_thresholds()
|
||||
|
||||
# 加载标准化统计量
|
||||
self._load_normalization_stats()
|
||||
|
||||
# 概念历史数据缓存
|
||||
# {concept_name: deque(maxlen=seq_len)}
|
||||
self.history_cache: Dict[str, deque] = {}
|
||||
|
||||
print(f"ConceptAnomalyDetector 初始化完成")
|
||||
print(f" 设备: {self.device}")
|
||||
print(f" 阈值: {self.threshold_key} = {self.threshold:.6f}")
|
||||
print(f" 序列长度: {self.seq_len}")
|
||||
|
||||
def _load_config(self):
|
||||
"""加载配置"""
|
||||
config_path = self.checkpoint_dir / 'config.json'
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"配置文件不存在: {config_path}")
|
||||
|
||||
with open(config_path, 'r') as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
self.features = self.config['features']
|
||||
self.seq_len = self.config['seq_len']
|
||||
self.model_config = self.config['model']
|
||||
|
||||
def _load_model(self):
|
||||
"""加载模型"""
|
||||
model_path = self.checkpoint_dir / 'best_model.pt'
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
|
||||
# 创建模型
|
||||
self.model = TransformerAutoencoder(**self.model_config)
|
||||
|
||||
# 加载权重
|
||||
checkpoint = torch.load(model_path, map_location=self.device)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
print(f"模型已加载: {model_path}")
|
||||
|
||||
def _load_thresholds(self):
|
||||
"""加载阈值"""
|
||||
thresholds_path = self.checkpoint_dir / 'thresholds.json'
|
||||
if not thresholds_path.exists():
|
||||
raise FileNotFoundError(f"阈值文件不存在: {thresholds_path}")
|
||||
|
||||
with open(thresholds_path, 'r') as f:
|
||||
self.thresholds = json.load(f)
|
||||
|
||||
if self.threshold_key not in self.thresholds:
|
||||
available_keys = list(self.thresholds.keys())
|
||||
raise KeyError(f"阈值键 '{self.threshold_key}' 不存在,可用: {available_keys}")
|
||||
|
||||
self.threshold = self.thresholds[self.threshold_key]
|
||||
|
||||
def _load_normalization_stats(self):
|
||||
"""加载标准化统计量"""
|
||||
stats_path = self.checkpoint_dir / 'normalization_stats.json'
|
||||
if not stats_path.exists():
|
||||
raise FileNotFoundError(f"标准化统计量文件不存在: {stats_path}")
|
||||
|
||||
with open(stats_path, 'r') as f:
|
||||
stats = json.load(f)
|
||||
|
||||
self.norm_mean = np.array(stats['mean'])
|
||||
self.norm_std = np.array(stats['std'])
|
||||
|
||||
def normalize(self, features: np.ndarray) -> np.ndarray:
|
||||
"""标准化特征"""
|
||||
return (features - self.norm_mean) / self.norm_std
|
||||
|
||||
def update_history(
|
||||
self,
|
||||
concept_name: str,
|
||||
features: Dict[str, float]
|
||||
):
|
||||
"""
|
||||
更新概念历史数据
|
||||
|
||||
Args:
|
||||
concept_name: 概念名称
|
||||
features: 当前时刻的特征字典
|
||||
"""
|
||||
# 初始化历史缓存
|
||||
if concept_name not in self.history_cache:
|
||||
self.history_cache[concept_name] = deque(maxlen=self.seq_len)
|
||||
|
||||
# 提取特征向量
|
||||
feature_vector = np.array([
|
||||
features.get(f, 0.0) for f in self.features
|
||||
])
|
||||
|
||||
# 处理异常值
|
||||
feature_vector = np.nan_to_num(feature_vector, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
# 添加到历史
|
||||
self.history_cache[concept_name].append(feature_vector)
|
||||
|
||||
def get_history_length(self, concept_name: str) -> int:
|
||||
"""获取概念的历史数据长度"""
|
||||
if concept_name not in self.history_cache:
|
||||
return 0
|
||||
return len(self.history_cache[concept_name])
|
||||
|
||||
@torch.no_grad()
|
||||
def detect(
|
||||
self,
|
||||
concept_name: str,
|
||||
features: Dict[str, float] = None,
|
||||
return_score: bool = True
|
||||
) -> Tuple[bool, Optional[float]]:
|
||||
"""
|
||||
检测概念是否异动
|
||||
|
||||
Args:
|
||||
concept_name: 概念名称
|
||||
features: 当前时刻的特征(如果提供,会先更新历史)
|
||||
return_score: 是否返回异动分数
|
||||
|
||||
Returns:
|
||||
is_anomaly: 是否异动
|
||||
score: 异动分数(如果 return_score=True)
|
||||
"""
|
||||
# 更新历史
|
||||
if features is not None:
|
||||
self.update_history(concept_name, features)
|
||||
|
||||
# 检查历史数据是否足够
|
||||
if concept_name not in self.history_cache:
|
||||
return False, None
|
||||
|
||||
history = self.history_cache[concept_name]
|
||||
if len(history) < self.seq_len:
|
||||
return False, None
|
||||
|
||||
# 构建输入序列
|
||||
sequence = np.array(list(history)) # (seq_len, n_features)
|
||||
|
||||
# 标准化
|
||||
sequence = self.normalize(sequence)
|
||||
|
||||
# 转为 tensor
|
||||
x = torch.FloatTensor(sequence).unsqueeze(0) # (1, seq_len, n_features)
|
||||
x = x.to(self.device)
|
||||
|
||||
# 计算重构误差
|
||||
error = self.model.compute_reconstruction_error(x, reduction='none')
|
||||
|
||||
# 取最后一个时刻的误差作为当前分数
|
||||
score = error[0, -1].item()
|
||||
|
||||
# 判断是否异动
|
||||
is_anomaly = score > self.threshold
|
||||
|
||||
if return_score:
|
||||
return is_anomaly, score
|
||||
else:
|
||||
return is_anomaly, None
|
||||
|
||||
@torch.no_grad()
|
||||
def batch_detect(
|
||||
self,
|
||||
concept_features: Dict[str, Dict[str, float]]
|
||||
) -> Dict[str, Tuple[bool, float]]:
|
||||
"""
|
||||
批量检测多个概念
|
||||
|
||||
Args:
|
||||
concept_features: {concept_name: {feature_name: value}}
|
||||
|
||||
Returns:
|
||||
results: {concept_name: (is_anomaly, score)}
|
||||
"""
|
||||
results = {}
|
||||
|
||||
for concept_name, features in concept_features.items():
|
||||
is_anomaly, score = self.detect(concept_name, features)
|
||||
results[concept_name] = (is_anomaly, score)
|
||||
|
||||
return results
|
||||
|
||||
def get_anomaly_type(
|
||||
self,
|
||||
concept_name: str,
|
||||
features: Dict[str, float]
|
||||
) -> str:
|
||||
"""
|
||||
判断异动类型
|
||||
|
||||
Args:
|
||||
concept_name: 概念名称
|
||||
features: 当前特征
|
||||
|
||||
Returns:
|
||||
anomaly_type: 'surge_up' / 'surge_down' / 'normal'
|
||||
"""
|
||||
is_anomaly, score = self.detect(concept_name, features)
|
||||
|
||||
if not is_anomaly:
|
||||
return 'normal'
|
||||
|
||||
# 根据 alpha 判断涨跌
|
||||
alpha = features.get('alpha', 0.0)
|
||||
|
||||
if alpha > 0:
|
||||
return 'surge_up'
|
||||
else:
|
||||
return 'surge_down'
|
||||
|
||||
def get_top_anomalies(
|
||||
self,
|
||||
concept_features: Dict[str, Dict[str, float]],
|
||||
top_k: int = 10
|
||||
) -> List[Tuple[str, float, str]]:
|
||||
"""
|
||||
获取异动分数最高的 top_k 个概念
|
||||
|
||||
Args:
|
||||
concept_features: {concept_name: {feature_name: value}}
|
||||
top_k: 返回数量
|
||||
|
||||
Returns:
|
||||
anomalies: [(concept_name, score, anomaly_type), ...]
|
||||
"""
|
||||
results = self.batch_detect(concept_features)
|
||||
|
||||
# 按分数排序
|
||||
sorted_results = sorted(
|
||||
[(name, is_anomaly, score) for name, (is_anomaly, score) in results.items() if score is not None],
|
||||
key=lambda x: x[2],
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# 取 top_k
|
||||
top_anomalies = []
|
||||
for name, is_anomaly, score in sorted_results[:top_k]:
|
||||
if is_anomaly:
|
||||
alpha = concept_features[name].get('alpha', 0.0)
|
||||
anomaly_type = 'surge_up' if alpha > 0 else 'surge_down'
|
||||
top_anomalies.append((name, score, anomaly_type))
|
||||
|
||||
return top_anomalies
|
||||
|
||||
def clear_history(self, concept_name: str = None):
|
||||
"""
|
||||
清除历史缓存
|
||||
|
||||
Args:
|
||||
concept_name: 概念名称(如果为 None,清除所有)
|
||||
"""
|
||||
if concept_name is None:
|
||||
self.history_cache.clear()
|
||||
elif concept_name in self.history_cache:
|
||||
del self.history_cache[concept_name]
|
||||
|
||||
|
||||
# ==================== 集成到现有系统 ====================
|
||||
|
||||
class MLAnomalyService:
|
||||
"""
|
||||
ML 异动检测服务
|
||||
|
||||
用于替换或增强现有的 Alpha-based 检测
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_dir: str = 'ml/checkpoints',
|
||||
fallback_to_alpha: bool = True
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
checkpoint_dir: 模型检查点目录
|
||||
fallback_to_alpha: 当 ML 模型不可用时是否回退到 Alpha 方法
|
||||
"""
|
||||
self.fallback_to_alpha = fallback_to_alpha
|
||||
self.ml_detector = None
|
||||
|
||||
try:
|
||||
self.ml_detector = ConceptAnomalyDetector(checkpoint_dir)
|
||||
print("ML 异动检测服务初始化成功")
|
||||
except Exception as e:
|
||||
print(f"ML 模型加载失败: {e}")
|
||||
if not fallback_to_alpha:
|
||||
raise
|
||||
print("将回退到 Alpha-based 检测")
|
||||
|
||||
def is_ml_available(self) -> bool:
|
||||
"""检查 ML 模型是否可用"""
|
||||
return self.ml_detector is not None
|
||||
|
||||
def detect_anomaly(
|
||||
self,
|
||||
concept_name: str,
|
||||
features: Dict[str, float],
|
||||
alpha_threshold: float = 2.0
|
||||
) -> Tuple[bool, float, str]:
|
||||
"""
|
||||
检测异动
|
||||
|
||||
Args:
|
||||
concept_name: 概念名称
|
||||
features: 特征字典(需包含 alpha, amt_ratio 等)
|
||||
alpha_threshold: Alpha Z-Score 阈值(用于回退)
|
||||
|
||||
Returns:
|
||||
is_anomaly: 是否异动
|
||||
score: 异动分数
|
||||
method: 检测方法 ('ml' / 'alpha')
|
||||
"""
|
||||
# 优先使用 ML 检测
|
||||
if self.ml_detector is not None:
|
||||
history_len = self.ml_detector.get_history_length(concept_name)
|
||||
|
||||
# 历史数据足够时使用 ML
|
||||
if history_len >= self.ml_detector.seq_len - 1:
|
||||
is_anomaly, score = self.ml_detector.detect(concept_name, features)
|
||||
if score is not None:
|
||||
return is_anomaly, score, 'ml'
|
||||
else:
|
||||
# 更新历史但使用 Alpha 方法
|
||||
self.ml_detector.update_history(concept_name, features)
|
||||
|
||||
# 回退到 Alpha 方法
|
||||
if self.fallback_to_alpha:
|
||||
alpha = features.get('alpha', 0.0)
|
||||
alpha_zscore = features.get('alpha_zscore', 0.0)
|
||||
|
||||
is_anomaly = abs(alpha_zscore) > alpha_threshold
|
||||
score = abs(alpha_zscore)
|
||||
|
||||
return is_anomaly, score, 'alpha'
|
||||
|
||||
return False, 0.0, 'none'
|
||||
|
||||
|
||||
# ==================== 测试 ====================
|
||||
|
||||
if __name__ == "__main__":
|
||||
import random
|
||||
|
||||
print("测试 ConceptAnomalyDetector...")
|
||||
|
||||
# 检查模型是否存在
|
||||
checkpoint_dir = Path('ml/checkpoints')
|
||||
if not (checkpoint_dir / 'best_model.pt').exists():
|
||||
print("模型文件不存在,跳过测试")
|
||||
print("请先运行 train.py 训练模型")
|
||||
exit(0)
|
||||
|
||||
# 初始化检测器
|
||||
detector = ConceptAnomalyDetector('ml/checkpoints')
|
||||
|
||||
# 模拟数据
|
||||
print("\n模拟实时检测...")
|
||||
concept_name = "人工智能"
|
||||
|
||||
for i in range(40):
|
||||
# 生成随机特征
|
||||
features = {
|
||||
'alpha': random.gauss(0, 1),
|
||||
'alpha_delta': random.gauss(0, 0.5),
|
||||
'amt_ratio': random.gauss(1, 0.3),
|
||||
'amt_delta': random.gauss(0, 0.2),
|
||||
'rank_pct': random.random(),
|
||||
'limit_up_ratio': random.random() * 0.1,
|
||||
}
|
||||
|
||||
# 在第 35 分钟模拟异动
|
||||
if i == 35:
|
||||
features['alpha'] = 5.0
|
||||
features['alpha_delta'] = 2.0
|
||||
features['amt_ratio'] = 3.0
|
||||
|
||||
is_anomaly, score = detector.detect(concept_name, features)
|
||||
|
||||
history_len = detector.get_history_length(concept_name)
|
||||
|
||||
if score is not None:
|
||||
status = "🔥 异动!" if is_anomaly else "正常"
|
||||
print(f" t={i:02d} | 历史={history_len} | 分数={score:.4f} | {status}")
|
||||
else:
|
||||
print(f" t={i:02d} | 历史={history_len} | 数据不足")
|
||||
|
||||
print("\n测试完成!")
|
||||
Reference in New Issue
Block a user