Files
vf_react/ml/inference.py
2025-12-09 08:31:18 +08:00

456 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
概念异动检测推理服务
在实时场景中使用训练好的 Transformer Autoencoder 进行异动检测
使用方法:
from ml.inference import ConceptAnomalyDetector
detector = ConceptAnomalyDetector('ml/checkpoints')
# 检测异动
features = {...} # 实时特征数据
is_anomaly, score = detector.detect('人工智能', features)
"""
import os
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from collections import deque
import numpy as np
import torch
from model import TransformerAutoencoder
class ConceptAnomalyDetector:
    """
    Concept anomaly detector.

    Keeps a rolling window of the last ``seq_len`` feature vectors per concept
    and scores the newest time step with a trained Transformer Autoencoder:
    that step's reconstruction error is compared against a threshold loaded
    from the checkpoint directory.

    Files read from ``checkpoint_dir``:
        config.json               -- keys ``features``, ``seq_len``, ``model``
        best_model.pt             -- torch checkpoint with ``model_state_dict``
        thresholds.json           -- {threshold_key: value}
        normalization_stats.json  -- keys ``mean`` and ``std``
    """

    def __init__(
        self,
        checkpoint_dir: str = 'ml/checkpoints',
        device: str = 'auto',
        threshold_key: str = 'p95'
    ):
        """
        Load config, model, threshold and normalization statistics.

        Args:
            checkpoint_dir: Model checkpoint directory.
            device: 'auto' picks CUDA when available, otherwise CPU; any other
                value is passed to ``torch.device`` (e.g. 'cuda' / 'cpu').
            threshold_key: Which entry of thresholds.json to use (p90/p95/p99).

        Raises:
            FileNotFoundError: A required checkpoint file is missing.
            KeyError: ``threshold_key`` is not present in thresholds.json.
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.threshold_key = threshold_key

        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        # The config must be loaded first: the model is built from its
        # 'model' section and the history window size comes from 'seq_len'.
        self._load_config()
        self._load_model()
        self._load_thresholds()
        self._load_normalization_stats()

        # Per-concept rolling window: {concept_name: deque(maxlen=seq_len)}
        self.history_cache: Dict[str, deque] = {}

        print(f"ConceptAnomalyDetector 初始化完成")
        print(f" 设备: {self.device}")
        print(f" 阈值: {self.threshold_key} = {self.threshold:.6f}")
        print(f" 序列长度: {self.seq_len}")

    def _load_config(self):
        """Read config.json and expose ``features``, ``seq_len`` and ``model_config``."""
        config_path = self.checkpoint_dir / 'config.json'
        if not config_path.exists():
            raise FileNotFoundError(f"配置文件不存在: {config_path}")
        # Explicit UTF-8: the platform default encoding (e.g. GBK on Chinese
        # Windows) could mis-decode any non-ASCII content in the file.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)
        self.features = self.config['features']
        self.seq_len = self.config['seq_len']
        self.model_config = self.config['model']

    def _load_model(self):
        """Build the autoencoder from ``model_config`` and load best_model.pt in eval mode."""
        model_path = self.checkpoint_dir / 'best_model.pt'
        if not model_path.exists():
            raise FileNotFoundError(f"模型文件不存在: {model_path}")
        self.model = TransformerAutoencoder(**self.model_config)
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True;
        # loading may fail if the checkpoint stores objects outside that
        # allowlist — confirm against what the training script saves.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()
        print(f"模型已加载: {model_path}")

    def _load_thresholds(self):
        """Read thresholds.json and select ``self.threshold`` by ``threshold_key``."""
        thresholds_path = self.checkpoint_dir / 'thresholds.json'
        if not thresholds_path.exists():
            raise FileNotFoundError(f"阈值文件不存在: {thresholds_path}")
        with open(thresholds_path, 'r', encoding='utf-8') as f:
            self.thresholds = json.load(f)
        if self.threshold_key not in self.thresholds:
            available_keys = list(self.thresholds.keys())
            raise KeyError(f"阈值键 '{self.threshold_key}' 不存在,可用: {available_keys}")
        self.threshold = float(self.thresholds[self.threshold_key])

    def _load_normalization_stats(self):
        """Read the per-feature mean/std used by ``normalize``."""
        stats_path = self.checkpoint_dir / 'normalization_stats.json'
        if not stats_path.exists():
            raise FileNotFoundError(f"标准化统计量文件不存在: {stats_path}")
        with open(stats_path, 'r', encoding='utf-8') as f:
            stats = json.load(f)
        self.norm_mean = np.asarray(stats['mean'], dtype=np.float64)
        std = np.asarray(stats['std'], dtype=np.float64)
        # A feature that was constant during training has std == 0; dividing
        # by it would yield inf/nan and poison the model input. Such features
        # (and any non-finite std) are centred but left unscaled.
        self.norm_std = np.where(np.isfinite(std) & (std > 0), std, 1.0)

    def normalize(self, features: np.ndarray) -> np.ndarray:
        """Standardize features as ``(x - mean) / std`` with the loaded statistics."""
        return (features - self.norm_mean) / self.norm_std

    def update_history(
        self,
        concept_name: str,
        features: Dict[str, float]
    ):
        """
        Append the current step's feature vector to a concept's history.

        Features are taken in ``self.features`` order. Missing keys, ``None``,
        NaN and +/-inf values are all stored as 0.0.

        Args:
            concept_name: Concept name.
            features: Feature dict for the current time step.
        """
        if concept_name not in self.history_cache:
            self.history_cache[concept_name] = deque(maxlen=self.seq_len)

        # dtype=float64 turns None into NaN (instead of building an object
        # array that nan_to_num would silently skip), so it is zeroed below.
        feature_vector = np.array(
            [features.get(f, 0.0) for f in self.features], dtype=np.float64
        )
        feature_vector = np.nan_to_num(feature_vector, nan=0.0, posinf=0.0, neginf=0.0)

        self.history_cache[concept_name].append(feature_vector)

    def get_history_length(self, concept_name: str) -> int:
        """Return how many steps are cached for the concept (0 if unknown)."""
        history = self.history_cache.get(concept_name)
        return 0 if history is None else len(history)

    @torch.no_grad()
    def detect(
        self,
        concept_name: str,
        features: Optional[Dict[str, float]] = None,
        return_score: bool = True
    ) -> Tuple[bool, Optional[float]]:
        """
        Score the newest time step of a concept.

        Args:
            concept_name: Concept name.
            features: Current-step features; if given, they are appended to
                the history before scoring.
            return_score: Whether to return the anomaly score.

        Returns:
            ``(is_anomaly, score)``. Returns ``(False, None)`` while the concept
            has fewer than ``seq_len`` cached steps; ``score`` is also None
            when ``return_score`` is False.
        """
        if features is not None:
            self.update_history(concept_name, features)

        history = self.history_cache.get(concept_name)
        if history is None or len(history) < self.seq_len:
            return False, None

        # (seq_len, n_features) -> normalized -> (1, seq_len, n_features)
        sequence = self.normalize(np.array(list(history)))
        x = torch.FloatTensor(sequence).unsqueeze(0).to(self.device)

        # Assumes reduction='none' yields per-step errors shaped
        # (batch, seq_len) — TODO confirm in model.py.
        error = self.model.compute_reconstruction_error(x, reduction='none')
        # The newest step's error is the current score.
        score = error[0, -1].item()
        is_anomaly = score > self.threshold

        return is_anomaly, (score if return_score else None)

    @torch.no_grad()
    def batch_detect(
        self,
        concept_features: Dict[str, Dict[str, float]]
    ) -> Dict[str, Tuple[bool, Optional[float]]]:
        """
        Run ``detect`` for several concepts; every call updates that concept's history.

        Args:
            concept_features: {concept_name: {feature_name: value}}

        Returns:
            {concept_name: (is_anomaly, score)}; ``score`` is None for concepts
            that do not yet have a full window.
        """
        return {
            name: self.detect(name, feats)
            for name, feats in concept_features.items()
        }

    @staticmethod
    def _surge_direction(features: Dict[str, float]) -> str:
        """Return 'surge_up' for a positive alpha, otherwise 'surge_down' (missing/None counts as not positive)."""
        alpha = features.get('alpha')
        return 'surge_up' if alpha is not None and alpha > 0 else 'surge_down'

    def get_anomaly_type(
        self,
        concept_name: str,
        features: Dict[str, float]
    ) -> str:
        """
        Classify the current step of a concept.

        This calls ``detect`` with ``features``, so it also appends them to
        the concept's history.

        Args:
            concept_name: Concept name.
            features: Current-step features.

        Returns:
            'surge_up' / 'surge_down' for an anomaly (by the sign of alpha),
            otherwise 'normal'.
        """
        is_anomaly, _ = self.detect(concept_name, features)
        if not is_anomaly:
            return 'normal'
        return self._surge_direction(features)

    def get_top_anomalies(
        self,
        concept_features: Dict[str, Dict[str, float]],
        top_k: int = 10
    ) -> List[Tuple[str, float, str]]:
        """
        Return the anomalous concepts among the ``top_k`` highest scores.

        History is updated for every concept, even when ``top_k <= 0``.

        Args:
            concept_features: {concept_name: {feature_name: value}}
            top_k: Number of highest-scoring concepts to consider.

        Returns:
            [(concept_name, score, anomaly_type), ...] in descending score order.
        """
        results = self.batch_detect(concept_features)
        if top_k <= 0:
            return []

        # Concepts without a full window (score None) cannot be ranked.
        ranked = sorted(
            (
                (name, is_anomaly, score)
                for name, (is_anomaly, score) in results.items()
                if score is not None
            ),
            key=lambda item: item[2],
            reverse=True,
        )
        # Among the top_k scores, keep only those above the threshold.
        return [
            (name, score, self._surge_direction(concept_features[name]))
            for name, is_anomaly, score in ranked[:top_k]
            if is_anomaly
        ]

    def clear_history(self, concept_name: Optional[str] = None):
        """
        Drop cached history.

        Args:
            concept_name: Concept to forget; None clears every concept.
                Unknown names are ignored.
        """
        if concept_name is None:
            self.history_cache.clear()
        else:
            self.history_cache.pop(concept_name, None)
# ==================== 集成到现有系统 ====================
class MLAnomalyService:
    """
    ML anomaly detection service.

    Wraps ConceptAnomalyDetector so it can replace or augment the existing
    alpha-based detection. It falls back to an alpha Z-score rule when the
    model failed to load or a concept does not yet have a full history window.
    """

    def __init__(
        self,
        checkpoint_dir: str = 'ml/checkpoints',
        fallback_to_alpha: bool = True
    ):
        """
        Args:
            checkpoint_dir: Model checkpoint directory.
            fallback_to_alpha: Whether to fall back to the alpha method when the
                ML model is unavailable. If False, a model-loading failure is
                re-raised.
        """
        self.fallback_to_alpha = fallback_to_alpha
        self.ml_detector = None

        try:
            self.ml_detector = ConceptAnomalyDetector(checkpoint_dir)
            print("ML 异动检测服务初始化成功")
        except Exception as e:
            # Deliberately broad: any load failure (missing files, bad config,
            # torch errors) should degrade to the alpha method unless disabled.
            print(f"ML 模型加载失败: {e}")
            if not fallback_to_alpha:
                raise
            print("将回退到 Alpha-based 检测")

    def is_ml_available(self) -> bool:
        """Return True if the ML detector was loaded successfully."""
        return self.ml_detector is not None

    def detect_anomaly(
        self,
        concept_name: str,
        features: Dict[str, float],
        alpha_threshold: float = 2.0
    ) -> Tuple[bool, float, str]:
        """
        Detect an anomaly, preferring the ML model over the alpha rule.

        The features are always appended to the ML history (when the model is
        loaded), so the window keeps filling while the alpha rule answers.

        Args:
            concept_name: Concept name.
            features: Feature dict (the alpha rule reads ``alpha_zscore``).
            alpha_threshold: |alpha Z-score| threshold for the fallback rule.

        Returns:
            ``(is_anomaly, score, method)`` where method is 'ml', 'alpha', or
            'none' (no ML score and fallback disabled).
        """
        detector = self.ml_detector
        if detector is not None:
            # detect() appends the current features before scoring, so
            # seq_len - 1 cached steps are already enough for a full window.
            if detector.get_history_length(concept_name) >= detector.seq_len - 1:
                is_anomaly, score = detector.detect(concept_name, features)
                if score is not None:
                    return is_anomaly, score, 'ml'
            else:
                # Keep filling the window while answering with the alpha rule.
                detector.update_history(concept_name, features)

        if not self.fallback_to_alpha:
            return False, 0.0, 'none'

        alpha_zscore = features.get('alpha_zscore')
        # Treat a missing/None/NaN z-score as "no signal" (matching how the
        # ML path zeroes such values) instead of crashing in abs() or
        # returning a NaN score.
        if alpha_zscore is None or np.isnan(alpha_zscore):
            alpha_zscore = 0.0
        score = abs(float(alpha_zscore))
        return score > alpha_threshold, score, 'alpha'
# ==================== 测试 ====================
if __name__ == "__main__":
    # Smoke test: feed 40 steps of random features for one concept and inject
    # a spike at t=35 to see whether the detector flags it.
    import random
    import sys

    print("测试 ConceptAnomalyDetector...")

    # Skip gracefully when no trained model is present.
    checkpoint_dir = Path('ml/checkpoints')
    if not (checkpoint_dir / 'best_model.pt').exists():
        print("模型文件不存在,跳过测试")
        print("请先运行 train.py 训练模型")
        # sys.exit rather than the site-provided exit(), which is intended for
        # the interactive shell and is missing when Python runs with -S.
        sys.exit(0)

    detector = ConceptAnomalyDetector(checkpoint_dir)

    print("\n模拟实时检测...")
    concept_name = "人工智能"

    for i in range(40):
        # Random features around typical values.
        features = {
            'alpha': random.gauss(0, 1),
            'alpha_delta': random.gauss(0, 0.5),
            'amt_ratio': random.gauss(1, 0.3),
            'amt_delta': random.gauss(0, 0.2),
            'rank_pct': random.random(),
            'limit_up_ratio': random.random() * 0.1,
        }

        # Inject a synthetic anomaly at minute 35.
        if i == 35:
            features['alpha'] = 5.0
            features['alpha_delta'] = 2.0
            features['amt_ratio'] = 3.0

        is_anomaly, score = detector.detect(concept_name, features)
        history_len = detector.get_history_length(concept_name)

        if score is None:
            # Not enough history yet for a full window.
            print(f" t={i:02d} | 历史={history_len} | 数据不足")
            continue

        status = "🔥 异动!" if is_anomaly else "正常"
        print(f" t={i:02d} | 历史={history_len} | 分数={score:.4f} | {status}")

    print("\n测试完成!")