update pay ui
This commit is contained in:
729
ml/realtime_detector_v2.py
Normal file
729
ml/realtime_detector_v2.py
Normal file
@@ -0,0 +1,729 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
V2 real-time concept anomaly detector.

Usage:
    # Import as a module
    from ml.realtime_detector_v2 import RealtimeDetectorV2

    detector = RealtimeDetectorV2()
    alerts = detector.detect()  # scans today; pass 'YYYY-MM-DD' for another day

    # Or test from the command line
    python ml/realtime_detector_v2.py --date 2025-12-09
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from sqlalchemy import create_engine, text
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from ml.model import TransformerAutoencoder
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
# Connection settings. Every value can be overridden through an environment
# variable; the defaults are the values that used to be hard-coded here, so
# existing deployments keep working unchanged.
# NOTE(review): the default credentials are still embedded in source control;
# rotate them and drop the defaults once the environment variables are set.
MYSQL_URL = os.environ.get(
    'STOCK_MYSQL_URL',
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
)
ES_HOST = os.environ.get('STOCK_ES_HOST', 'http://127.0.0.1:9200')
ES_INDEX = os.environ.get('STOCK_ES_INDEX', 'concept_library_v3')

# Keyword arguments passed verbatim to clickhouse_driver.Client(**CLICKHOUSE_CONFIG).
CLICKHOUSE_CONFIG = {
    'host': os.environ.get('STOCK_CH_HOST', '127.0.0.1'),
    'port': int(os.environ.get('STOCK_CH_PORT', '9000')),
    'user': os.environ.get('STOCK_CH_USER', 'default'),
    'password': os.environ.get('STOCK_CH_PASSWORD', 'Zzl33818!'),
    'database': os.environ.get('STOCK_CH_DATABASE', 'stock'),
}

# Index whose intraday change is subtracted from a concept's mean change to get alpha.
REFERENCE_INDEX = '000001.SH'
# Pickle holding the per-concept, per-time-slot baseline statistics.
BASELINE_FILE = 'ml/data_v2/baselines/realtime_baseline.pkl'
# Directory holding config.json, best_model.pt and thresholds.json.
MODEL_DIR = 'ml/checkpoints_v2'
|
||||
|
||||
# Detection settings
CONFIG = dict(
    seq_len=10,                # history length per concept required before ML scoring
    confirm_window=5,          # number of recent fused scores examined for confirmation
    confirm_ratio=0.6,         # fraction of that window that must reach fusion_trigger
    rule_weight=0.5,           # weight of the rule score in the fused score
    ml_weight=0.5,             # weight of the ML score in the fused score
    rule_trigger=60,           # rule score that triggers on its own
    ml_trigger=70,             # ML score that triggers on its own
    fusion_trigger=50,         # fused score that triggers (also used for confirmation)
    cooldown_minutes=10,       # minimum gap between two alerts of the same concept
    max_alerts_per_minute=15,
    zscore_clip=5.0,
    limit_up_threshold=9.8,    # change (%) at or above which a stock counts as limit-up
)

# Per-time-step feature names, in the order the model input columns are built.
FEATURES = [
    'alpha_zscore',
    'amt_zscore',
    'rank_zscore',
    'momentum_3m',
    'momentum_5m',
    'limit_up_ratio',
]
|
||||
|
||||
|
||||
# ==================== 数据库连接 ====================
|
||||
|
||||
# Process-wide caches for the lazily created database clients; each one is
# filled on first use by the matching get_*() accessor below.
_mysql_engine = _es_client = _ch_client = None
|
||||
|
||||
|
||||
def get_mysql_engine():
    """Return the shared SQLAlchemy engine, creating it on first call.

    The engine is built from ``MYSQL_URL`` with ``pool_pre_ping=True`` and is
    cached in the module-level ``_mysql_engine``.
    """
    global _mysql_engine
    if _mysql_engine is not None:
        return _mysql_engine
    _mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True)
    return _mysql_engine
|
||||
|
||||
|
||||
def get_es_client():
    """Return the shared Elasticsearch client, creating it on first call.

    The client connects to ``ES_HOST`` and is cached in the module-level
    ``_es_client``.
    """
    global _es_client
    if _es_client is not None:
        return _es_client
    _es_client = Elasticsearch([ES_HOST])
    return _es_client
|
||||
|
||||
|
||||
def get_ch_client():
    """Return the shared ClickHouse client, creating it on first call.

    The client is built from ``CLICKHOUSE_CONFIG`` and cached in the
    module-level ``_ch_client``.
    """
    global _ch_client
    if _ch_client is not None:
        return _ch_client
    _ch_client = Client(**CLICKHOUSE_CONFIG)
    return _ch_client
|
||||
|
||||
|
||||
def code_to_ch_format(code: str) -> Optional[str]:
    """Convert a bare 6-digit stock code to the suffixed form used in ClickHouse.

    Args:
        code: Stock code such as ``'600000'``.

    Returns:
        ``'<code>.SH'`` when the code starts with ``6``, ``'<code>.SZ'`` when it
        starts with ``0`` or ``3``, ``'<code>.BJ'`` for any other 6-digit code,
        and ``None`` when ``code`` is empty or is not exactly six digits.
    """
    if not code or len(code) != 6 or not code.isdigit():
        return None
    if code.startswith('6'):
        suffix = 'SH'
    elif code.startswith(('0', '3')):
        suffix = 'SZ'
    else:
        suffix = 'BJ'
    return f"{code}.{suffix}"
|
||||
|
||||
|
||||
def time_to_slot(ts) -> str:
    """Return the ``'HH:MM'`` slot key for a timestamp.

    Args:
        ts: A string (returned unchanged, assumed to already be a slot key),
            any object with ``strftime`` (``datetime``, ``pandas.Timestamp``),
            or anything ``pandas.Timestamp`` can parse, such as
            ``numpy.datetime64``.

    Returns:
        The time of day formatted as ``'%H:%M'``.
    """
    if isinstance(ts, str):
        return ts
    # numpy.datetime64 values have no strftime(); normalise them through pandas
    # first. Objects that already support strftime are used as they are.
    if not hasattr(ts, 'strftime'):
        ts = pd.Timestamp(ts)
    return ts.strftime('%H:%M')
|
||||
|
||||
|
||||
# ==================== 规则评分 ====================
|
||||
|
||||
def score_rules_zscore(features: Dict) -> Tuple[float, List[str]]:
    """Score one concept snapshot against the Z-score rule set.

    Args:
        features: Mapping that may contain ``alpha_zscore``, ``amt_zscore``,
            ``rank_zscore``, ``momentum_3m``, ``momentum_5m`` and
            ``limit_up_ratio``; missing keys count as 0.

    Returns:
        ``(score, triggered)`` where ``score`` is capped at 100 and
        ``triggered`` lists the names of the rules that fired, in evaluation
        order.
    """
    # Alpha and rank are judged by magnitude; the other inputs are signed.
    alpha_mag = abs(features.get('alpha_zscore', 0))
    amount_z = features.get('amt_zscore', 0)
    rank_mag = abs(features.get('rank_zscore', 0))
    momentum_short = features.get('momentum_3m', 0)
    momentum_long = features.get('momentum_5m', 0)
    limit_ratio = features.get('limit_up_ratio', 0)

    # Each ladder is checked top-down and only its first matching tier counts.
    ladders = (
        (alpha_mag, ((4.0, 25, 'alpha_extreme'),
                     (3.0, 18, 'alpha_strong'),
                     (2.0, 10, 'alpha_moderate'))),
        (amount_z, ((4.0, 20, 'amt_extreme'),
                    (3.0, 12, 'amt_strong'),
                    (2.0, 6, 'amt_moderate'))),
        (rank_mag, ((3.0, 15, 'rank_extreme'),
                    (2.0, 8, 'rank_strong'))),
        (momentum_short, ((1.0, 12, 'momentum_3m_strong'),
                          (0.5, 6, 'momentum_3m_moderate'))),
        (momentum_long, ((1.5, 10, 'momentum_5m_strong'),)),
        (limit_ratio, ((0.3, 20, 'limit_up_extreme'),
                       (0.15, 12, 'limit_up_strong'),
                       (0.08, 5, 'limit_up_moderate'))),
    )

    total = 0.0
    fired = []
    for value, tiers in ladders:
        for floor, points, label in tiers:
            if value >= floor:
                total += points
                fired.append(label)
                break

    # Combination bonuses are added on top of the individual tiers.
    if alpha_mag >= 2.0:
        if amount_z >= 2.0:
            total += 15
            fired.append('combo_alpha_amt')
        if limit_ratio >= 0.1:
            total += 12
            fired.append('combo_alpha_limitup')

    return min(total, 100), fired
|
||||
|
||||
|
||||
# ==================== 实时检测器 ====================
|
||||
|
||||
class RealtimeDetectorV2:
    """V2 real-time concept anomaly detector.

    For every minute of a trading day the detector aggregates stock data into
    per-concept features, normalises them against per-time-slot baselines,
    scores them with rules and with the reconstruction error of a
    ``TransformerAutoencoder``, and emits an alert once the fused score has
    stayed high across a confirmation window and the concept is outside its
    cooldown period.
    """

    def __init__(self, model_dir: str = MODEL_DIR, baseline_file: str = BASELINE_FILE):
        """Load concepts, baselines and the model, then set up rolling state.

        Args:
            model_dir: Directory holding ``config.json``, ``best_model.pt`` and
                optionally ``thresholds.json``.
            baseline_file: Pickle file with the baseline statistics.
        """
        print("初始化 V2 实时检测器...")

        # Concept universe plus two derived lookups.
        self.concepts = self._load_concepts()
        self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in self.concepts}
        self.all_stocks = list({s for c in self.concepts for s in c['stocks']})

        self.baselines = self._load_baselines(baseline_file)
        self.model, self.thresholds, self.device = self._load_model(model_dir)

        # Rolling per-concept state; the deques drop their oldest entry themselves.
        self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))
        self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))
        self.cooldown = {}

        print(f"初始化完成: {len(self.concepts)} 概念, {len(self.baselines)} 基线")

    def _load_concepts(self) -> List[dict]:
        """Load all concepts that have at least one stock from Elasticsearch.

        Returns:
            A list of dicts with ``concept_id``, ``concept_name`` and
            ``stocks`` (list of stock codes).
        """
        es = get_es_client()
        concepts = []

        query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}
        resp = es.search(index=ES_INDEX, body=query, scroll='2m')
        scroll_id = resp['_scroll_id']

        try:
            hits = resp['hits']['hits']
            while hits:
                for hit in hits:
                    src = hit['_source']
                    # Keep only well-formed stock entries that carry a code.
                    stocks = [s['code'] for s in src.get('stocks', []) if isinstance(s, dict) and s.get('code')]
                    if stocks:
                        concepts.append({
                            'concept_id': src.get('concept_id'),
                            'concept_name': src.get('concept'),
                            'stocks': stocks,
                        })
                resp = es.scroll(scroll_id=scroll_id, scroll='2m')
                scroll_id = resp['_scroll_id']
                hits = resp['hits']['hits']
        finally:
            # Release the server-side scroll context even if paging failed.
            es.clear_scroll(scroll_id=scroll_id)

        return concepts

    def _load_baselines(self, baseline_file: str) -> Dict:
        """Load the baseline statistics from a pickle file.

        Returns:
            The ``'baselines'`` entry of the pickled dict, indexed as
            ``baselines[concept_id][time_slot]``; an empty dict when the file
            does not exist.
        """
        if not os.path.exists(baseline_file):
            print(f"警告: 基线文件不存在: {baseline_file}")
            print("请先运行: python ml/update_baseline.py")
            return {}

        # NOTE(review): pickle.load executes arbitrary code from the file; only
        # load baseline files produced by a trusted job.
        with open(baseline_file, 'rb') as f:
            data = pickle.load(f)

        print(f"基线日期范围: {data.get('date_range', 'unknown')}")
        print(f"更新时间: {data.get('update_time', 'unknown')}")

        return data.get('baselines', {})

    def _load_model(self, model_dir: str):
        """Load the autoencoder, its score thresholds and the target device.

        Returns:
            ``(model, thresholds, device)``. ``model`` is ``None`` and
            ``thresholds`` is empty when the checkpoint or its config file is
            missing, in which case ML scores are all zero.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        config_path = os.path.join(model_dir, 'config.json')
        model_path = os.path.join(model_dir, 'best_model.pt')
        threshold_path = os.path.join(model_dir, 'thresholds.json')

        if not os.path.exists(model_path):
            print(f"警告: 模型不存在: {model_path}")
            return None, {}, device
        # A checkpoint without its config cannot be instantiated; degrade the
        # same way as a missing checkpoint instead of crashing.
        if not os.path.exists(config_path):
            print(f"警告: 模型配置不存在: {config_path}")
            return None, {}, device

        with open(config_path, encoding='utf-8') as f:
            config = json.load(f)

        model = TransformerAutoencoder(**config['model'])
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        thresholds = {}
        if os.path.exists(threshold_path):
            with open(threshold_path, encoding='utf-8') as f:
                thresholds = json.load(f)

        print(f"模型已加载: {model_path}")
        return model, thresholds, device

    def _get_realtime_data(self, trade_date: str) -> pd.DataFrame:
        """Fetch minute data for ``trade_date`` and aggregate it per concept.

        Args:
            trade_date: Trading day as ``'YYYY-MM-DD'``.

        Returns:
            One row per (concept, timestamp) with the columns ``concept_id``,
            ``timestamp``, ``time_slot``, ``alpha``, ``total_amt``,
            ``limit_up_ratio``, ``stock_count`` and ``rank_pct``; an empty
            DataFrame when any required input is missing.

        Raises:
            ValueError: If ``trade_date`` is not a valid ``YYYY-MM-DD`` date.
        """
        # The date is interpolated into ClickHouse SQL below, so force it into
        # canonical form first; anything that is not a date is rejected here.
        trade_date = datetime.strptime(str(trade_date), '%Y-%m-%d').strftime('%Y-%m-%d')

        # ClickHouse code -> original code, computed once. Only 6-digit numeric
        # codes survive, so joining them into the SQL text is safe.
        ch_to_code = {}
        for code in self.all_stocks:
            ch_code = code_to_ch_format(code)
            if ch_code:
                ch_to_code[ch_code] = code
        if not ch_to_code:
            return pd.DataFrame()

        ch = get_ch_client()

        # Stock minute bars.
        ch_codes_str = "','".join(ch_to_code)
        stock_result = ch.execute(f"""
            SELECT code, timestamp, close, amt
            FROM stock_minute
            WHERE toDate(timestamp) = '{trade_date}'
            AND code IN ('{ch_codes_str}')
            ORDER BY timestamp
        """)
        if not stock_result:
            return pd.DataFrame()

        stock_df = pd.DataFrame(stock_result, columns=['ch_code', 'timestamp', 'close', 'amt'])
        stock_df['code'] = stock_df['ch_code'].map(ch_to_code)
        stock_df = stock_df.dropna(subset=['code'])

        # Reference index minute bars.
        index_result = ch.execute(f"""
            SELECT timestamp, close
            FROM index_minute
            WHERE toDate(timestamp) = '{trade_date}'
            AND code = '{REFERENCE_INDEX}'
            ORDER BY timestamp
        """)
        if not index_result:
            return pd.DataFrame()

        index_df = pd.DataFrame(index_result, columns=['timestamp', 'close'])

        # Previous closes (stocks and index) from MySQL.
        engine = get_mysql_engine()
        codes_str = "','".join(ch_to_code.values())

        with engine.connect() as conn:
            prev_result = conn.execute(text(f"""
                SELECT SECCODE, F007N FROM ea_trade
                WHERE SECCODE IN ('{codes_str}')
                AND TRADEDATE = (SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < :trade_date)
                AND F007N > 0
            """), {'trade_date': trade_date})
            prev_close = {row[0]: float(row[1]) for row in prev_result if row[1]}

            idx_result = conn.execute(text("""
                SELECT F006N FROM ea_exchangetrade
                WHERE INDEXCODE = '000001' AND TRADEDATE < :today
                ORDER BY TRADEDATE DESC LIMIT 1
            """), {'today': trade_date}).fetchone()
            has_index_close = bool(idx_result) and idx_result[0] is not None
            index_prev_close = float(idx_result[0]) if has_index_close else None

        if not prev_close or not index_prev_close:
            return pd.DataFrame()

        # Percentage change versus the previous close.
        stock_df = stock_df.assign(prev_close=stock_df['code'].map(prev_close))
        stock_df = stock_df.dropna(subset=['prev_close']).copy()
        stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100

        index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
        index_map = dict(zip(index_df['timestamp'], index_df['change_pct']))

        # Aggregate concept features minute by minute. groupby visits each row
        # once and yields pandas Timestamp keys (never numpy.datetime64).
        limit_up_threshold = CONFIG['limit_up_threshold']
        results = []
        for ts, ts_data in stock_df.groupby('timestamp', sort=True):
            idx_chg = index_map.get(ts, 0)

            stock_chg = dict(zip(ts_data['code'], ts_data['change_pct']))
            stock_amt = dict(zip(ts_data['code'], ts_data['amt']))

            for cid, stocks in self.concept_stocks.items():
                members = [s for s in stocks if s in stock_chg]
                if not members:
                    continue

                changes = [stock_chg[s] for s in members]
                hit_limit = sum(1 for c in changes if c >= limit_up_threshold)

                results.append({
                    'concept_id': cid,
                    'timestamp': ts,
                    'time_slot': time_to_slot(ts),
                    # Alpha: mean member change minus the index change.
                    'alpha': np.mean(changes) - idx_chg,
                    'total_amt': sum(stock_amt.get(s, 0) for s in members),
                    'limit_up_ratio': hit_limit / len(changes),
                    'stock_count': len(changes),
                })

        if not results:
            return pd.DataFrame()

        df = pd.DataFrame(results)

        # Percentile rank of alpha among all concepts at the same timestamp.
        df['rank_pct'] = df.groupby('timestamp')['alpha'].rank(pct=True)

        return df

    def _compute_zscore(self, concept_id: str, time_slot: str, alpha: float, total_amt: float, rank_pct: float) -> Optional[Dict]:
        """Normalise raw features against the concept's baseline for a slot.

        Args:
            concept_id: Concept whose baseline and history are used.
            time_slot: ``'HH:MM'`` key into the concept's baseline.
            alpha: Concept mean change minus the index change.
            total_amt: Summed turnover of the concept's stocks.
            rank_pct: Percentile rank of ``alpha`` among concepts.

        Returns:
            Dict with ``alpha_zscore``, ``amt_zscore``, ``rank_zscore``,
            ``momentum_3m`` and ``momentum_5m``, or ``None`` when no baseline
            exists for this concept and slot.
        """
        baseline = self.baselines.get(concept_id)
        if baseline is None or time_slot not in baseline:
            return None

        bl = baseline[time_slot]
        clip = CONFIG['zscore_clip']

        def _clipped_z(value, mean, std):
            # A zero or non-finite std gives no usable scale; report "no
            # deviation" rather than letting inf/NaN poison the history.
            if not np.isfinite(std) or std <= 0:
                return 0.0
            z = (value - mean) / std
            if not np.isfinite(z):
                return 0.0
            return float(np.clip(z, -clip, clip))

        alpha_z = _clipped_z(alpha, bl['alpha_mean'], bl['alpha_std'])
        amt_z = _clipped_z(total_amt, bl['amt_mean'], bl['amt_std'])
        rank_z = _clipped_z(rank_pct, bl['rank_mean'], bl['rank_std'])

        # Momentum: mean alpha Z-score of the latest k entries minus that of
        # the k entries before them (or the oldest entry when history is short).
        history = list(self.zscore_history[concept_id])
        mom_3m = 0.0
        mom_5m = 0.0

        if len(history) >= 3:
            recent = [h['alpha_zscore'] for h in history[-3:]]
            older = [h['alpha_zscore'] for h in history[-6:-3]] if len(history) >= 6 else [history[0]['alpha_zscore']]
            mom_3m = np.mean(recent) - np.mean(older)

        if len(history) >= 5:
            recent = [h['alpha_zscore'] for h in history[-5:]]
            older = [h['alpha_zscore'] for h in history[-10:-5]] if len(history) >= 10 else [history[0]['alpha_zscore']]
            mom_5m = np.mean(recent) - np.mean(older)

        return {
            'alpha_zscore': alpha_z,
            'amt_zscore': amt_z,
            'rank_zscore': rank_z,
            'momentum_3m': float(mom_3m),
            'momentum_5m': float(mom_5m),
        }

    @torch.no_grad()
    def _ml_score(self, sequences: np.ndarray) -> np.ndarray:
        """Score a batch of feature sequences with the autoencoder.

        Args:
            sequences: Batch of sequences; the first axis is the batch.

        Returns:
            One score per sequence, clipped to ``[0, 100]``. All zeros when no
            model is loaded or the batch is empty.
        """
        if self.model is None or len(sequences) == 0:
            return np.zeros(len(sequences))

        x = torch.as_tensor(sequences, dtype=torch.float32).to(self.device)
        errors = self.model.compute_reconstruction_error(x, reduction='none')
        # Only the error of the last time step is scored.
        last_errors = errors[:, -1].cpu().numpy()

        if self.thresholds:
            # Linear map: the median error scores 50, the p99 error scores 99.
            p50 = self.thresholds.get('median', 0.001)
            p99 = self.thresholds.get('p99', 0.05)
            scores = 50 + (last_errors - p50) / (p99 - p50 + 1e-6) * 49
        else:
            scores = last_errors * 1000

        return np.clip(scores, 0, 100)

    def detect(self, trade_date: str = None) -> List[Dict]:
        """Replay one trading day minute by minute and return confirmed alerts.

        All rolling state is reset first. A concept is scored by the model
        only once it has ``CONFIG['seq_len']`` history entries; it then alerts
        when a trigger fires, it is outside its cooldown, and at least
        ``CONFIG['confirm_ratio']`` of its last ``CONFIG['confirm_window']``
        fused scores reach ``CONFIG['fusion_trigger']``.
        ``CONFIG['max_alerts_per_minute']`` is not enforced by this method.

        Args:
            trade_date: Day to scan as ``'YYYY-MM-DD'``; defaults to today.

        Returns:
            A list of alert dicts (possibly empty).
        """
        trade_date = trade_date or datetime.now().strftime('%Y-%m-%d')
        print(f"\n检测 {trade_date} 的异动...")

        # Reset state so repeated calls do not leak history between days.
        self.zscore_history.clear()
        self.anomaly_candidates.clear()
        self.cooldown.clear()

        raw_df = self._get_realtime_data(trade_date)
        if raw_df.empty:
            print("无数据")
            return []

        print(f"时间点数: {raw_df['timestamp'].nunique()}")

        # id -> name lookup built once; the first occurrence of an id wins.
        concept_names = {}
        for concept in self.concepts:
            concept_names.setdefault(concept['concept_id'], concept['concept_name'])

        seq_len = CONFIG['seq_len']
        cooldown_seconds = CONFIG['cooldown_minutes'] * 60
        all_alerts = []

        # groupby yields pandas Timestamp keys, so strftime() and timedelta
        # arithmetic below work regardless of the installed pandas version.
        for ts, ts_data in raw_df.groupby('timestamp', sort=True):
            time_slot = time_to_slot(ts)

            candidates = []

            # Z-scores and rule score for every concept present at this minute.
            for row in ts_data.itertuples(index=False):
                cid = row.concept_id

                zscore = self._compute_zscore(cid, time_slot, row.alpha, row.total_amt, row.rank_pct)
                if zscore is None:
                    continue

                features = {
                    **zscore,
                    'alpha': row.alpha,
                    'limit_up_ratio': row.limit_up_ratio,
                    'total_amt': row.total_amt,
                }

                # Store this step's own limit_up_ratio so the model sequence
                # carries per-step values instead of repeating the latest one.
                self.zscore_history[cid].append({**zscore, 'limit_up_ratio': float(row.limit_up_ratio)})

                rule_score, triggered = score_rules_zscore(features)
                candidates.append((cid, features, rule_score, triggered))

            if not candidates:
                continue

            # Build model inputs for concepts with a full history window.
            sequences = []
            valid_candidates = []

            for candidate in candidates:
                history = self.zscore_history[candidate[0]]
                if len(history) < seq_len:
                    continue
                sequences.append(np.array([
                    [h['alpha_zscore'], h['amt_zscore'], h['rank_zscore'],
                     h['momentum_3m'], h['momentum_5m'], h['limit_up_ratio']]
                    for h in history
                ]))
                valid_candidates.append(candidate)

            if not sequences:
                continue

            ml_scores = self._ml_score(np.array(sequences))

            # Fuse the scores, then apply trigger, cooldown and confirmation.
            for (cid, features, rule_score, triggered), ml_score in zip(valid_candidates, ml_scores):
                final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score

                is_triggered = (
                    rule_score >= CONFIG['rule_trigger'] or
                    ml_score >= CONFIG['ml_trigger'] or
                    final_score >= CONFIG['fusion_trigger']
                )

                # Every fused score is recorded, triggered or not, because the
                # confirmation ratio is computed over this window.
                self.anomaly_candidates[cid].append((ts, final_score))

                if not is_triggered:
                    continue

                last_alert = self.cooldown.get(cid)
                if last_alert is not None and (ts - last_alert).total_seconds() < cooldown_seconds:
                    continue

                recent = list(self.anomaly_candidates[cid])
                if len(recent) < CONFIG['confirm_window']:
                    continue

                exceed = sum(1 for _, s in recent if s >= CONFIG['fusion_trigger'])
                ratio = exceed / len(recent)

                if ratio < CONFIG['confirm_ratio']:
                    continue

                # Confirmed: start the cooldown and record the alert.
                self.cooldown[cid] = ts

                alpha = features['alpha']
                if alpha >= 1.5:
                    alert_type = 'surge_up'
                elif alpha <= -1.5:
                    alert_type = 'surge_down'
                else:
                    alert_type = 'surge'

                all_alerts.append({
                    'concept_id': cid,
                    'concept_name': concept_names.get(cid, cid),
                    'alert_time': ts,
                    'trade_date': trade_date,
                    'alert_type': alert_type,
                    'final_score': float(final_score),
                    'rule_score': float(rule_score),
                    'ml_score': float(ml_score),
                    'confirm_ratio': float(ratio),
                    'alpha': float(alpha),
                    'alpha_zscore': float(features['alpha_zscore']),
                    'amt_zscore': float(features['amt_zscore']),
                    'rank_zscore': float(features['rank_zscore']),
                    'momentum_3m': float(features['momentum_3m']),
                    'momentum_5m': float(features['momentum_5m']),
                    'limit_up_ratio': float(features['limit_up_ratio']),
                    'triggered_rules': triggered,
                    'trigger_reason': f"融合({final_score:.0f})+确认({ratio:.0%})",
                })

        print(f"检测到 {len(all_alerts)} 个异动")
        return all_alerts
|
||||
|
||||
|
||||
# ==================== 数据库存储 ====================
|
||||
|
||||
def create_v2_table():
    """Create the ``concept_anomaly_v2`` table if it does not exist yet.

    The statement runs inside a transaction on the shared MySQL engine; the
    unique key on ``(concept_id, alert_time)`` is what lets later inserts use
    ``INSERT IGNORE`` to skip duplicates.
    """
    ddl = text("""
        CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
            id INT AUTO_INCREMENT PRIMARY KEY,
            concept_id VARCHAR(50) NOT NULL,
            alert_time DATETIME NOT NULL,
            trade_date DATE NOT NULL,
            alert_type VARCHAR(20) NOT NULL,
            final_score FLOAT,
            rule_score FLOAT,
            ml_score FLOAT,
            trigger_reason VARCHAR(200),
            confirm_ratio FLOAT,
            alpha FLOAT,
            alpha_zscore FLOAT,
            amt_zscore FLOAT,
            rank_zscore FLOAT,
            momentum_3m FLOAT,
            momentum_5m FLOAT,
            limit_up_ratio FLOAT,
            triggered_rules TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE KEY uk_concept_time (concept_id, alert_time),
            INDEX idx_trade_date (trade_date),
            INDEX idx_alert_type (alert_type)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
    """)
    with get_mysql_engine().begin() as conn:
        conn.execute(ddl)
    print("concept_anomaly_v2 表已就绪")
|
||||
|
||||
|
||||
def save_alerts_to_db(alerts: List[Dict]) -> int:
    """Insert alerts into ``concept_anomaly_v2`` and count the new rows.

    Rows are written with ``INSERT IGNORE``, so an alert that collides with an
    existing row is skipped and not counted. A failure on one alert is printed
    and does not stop the remaining alerts from being saved.

    Args:
        alerts: Alert dicts as produced by ``RealtimeDetectorV2.detect``.

    Returns:
        The number of rows actually inserted.
    """
    if not alerts:
        return 0

    # Columns copied straight from the alert dict; triggered_rules is added
    # separately because it is serialised to JSON.
    scalar_columns = (
        'concept_id', 'alert_time', 'trade_date', 'alert_type',
        'final_score', 'rule_score', 'ml_score', 'trigger_reason', 'confirm_ratio',
        'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore',
        'momentum_3m', 'momentum_5m', 'limit_up_ratio',
    )
    columns = scalar_columns + ('triggered_rules',)

    # Built once: the statement is identical for every alert.
    insert_sql = text(
        f"INSERT IGNORE INTO concept_anomaly_v2 ({', '.join(columns)}) "
        f"VALUES ({', '.join(':' + col for col in columns)})"
    )

    engine = get_mysql_engine()
    saved = 0

    with engine.begin() as conn:
        for alert in alerts:
            try:
                params = {col: alert[col] for col in scalar_columns}
                params['triggered_rules'] = json.dumps(alert.get('triggered_rules', []), ensure_ascii=False)

                result = conn.execute(insert_sql, params)
                if result.rowcount > 0:
                    saved += 1
            except Exception as e:
                # Deliberate best-effort: report and continue. .get() keeps the
                # report itself from raising when concept_id is the missing key.
                print(f"保存失败: {alert.get('concept_id')} - {e}")

    return saved
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: detect anomalies for one day and store them.

    Options:
        --date     Trading day to scan (``YYYY-MM-DD``); defaults to today.
        --no-save  Print the alerts only; skip table creation and the insert.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--date', type=str, default=None)
    cli.add_argument('--no-save', action='store_true', help='不保存到数据库,只打印')
    opts = cli.parse_args()

    persist = not opts.no_save

    # Make sure the target table exists before doing any detection work.
    if persist:
        create_v2_table()

    detector = RealtimeDetectorV2()
    alerts = detector.detect(opts.date)

    banner = '=' * 60
    print(f"\n{banner}")
    print(f"检测结果 ({len(alerts)} 个异动)")
    print(banner)

    # Only the first alerts are listed; the total follows if more exist.
    preview_limit = 20
    for alert in alerts[:preview_limit]:
        when = alert['alert_time']
        clock = when.strftime('%H:%M') if hasattr(when, 'strftime') else when
        print(f"[{clock}] "
              f"{alert['concept_name']} | {alert['alert_type']} | "
              f"分数={alert['final_score']:.0f} 确认={alert['confirm_ratio']:.0%} "
              f"α={alert['alpha']:.2f}% αZ={alert['alpha_zscore']:.1f}")

    if len(alerts) > preview_limit:
        print(f"... 共 {len(alerts)} 个")

    if not persist:
        print("\n⚠️ --no-save 模式,未保存到数据库")
    elif alerts:
        saved = save_alerts_to_db(alerts)
        print(f"\n✅ 已保存 {saved}/{len(alerts)} 条到 concept_anomaly_v2 表")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user