730 lines
25 KiB
Python
730 lines
25 KiB
Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
V2 real-time concept anomaly detector.

Usage:
    # Import as a module
    from ml.realtime_detector_v2 import RealtimeDetectorV2

    detector = RealtimeDetectorV2()
    alerts = detector.detect()  # defaults to today's date

    # Or run a test from the command line
    python ml/realtime_detector_v2.py --date 2025-12-09
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import pickle
|
||
from datetime import datetime, timedelta
|
||
from typing import Dict, List, Optional, Tuple
|
||
from collections import defaultdict, deque
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
from sqlalchemy import create_engine, text
|
||
from elasticsearch import Elasticsearch
|
||
from clickhouse_driver import Client
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from ml.model import TransformerAutoencoder
|
||
|
||
# ==================== Configuration ====================

# SECURITY NOTE(review): the fallback values below embed real credentials in the
# source tree. They are kept only so existing deployments keep working; set the
# corresponding environment variables and rotate/remove the literals.
MYSQL_URL = os.environ.get(
    'MYSQL_URL',
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
)
ES_HOST = os.environ.get('ES_HOST', 'http://127.0.0.1:9200')
ES_INDEX = os.environ.get('ES_INDEX', 'concept_library_v3')

CLICKHOUSE_CONFIG = {
    'host': os.environ.get('CLICKHOUSE_HOST', '127.0.0.1'),
    'port': int(os.environ.get('CLICKHOUSE_PORT', 9000)),
    'user': os.environ.get('CLICKHOUSE_USER', 'default'),
    'password': os.environ.get('CLICKHOUSE_PASSWORD', 'Zzl33818!'),
    'database': os.environ.get('CLICKHOUSE_DATABASE', 'stock'),
}

# Index whose intraday change is subtracted from each concept's mean change.
REFERENCE_INDEX = '000001.SH'
BASELINE_FILE = 'ml/data_v2/baselines/realtime_baseline.pkl'
MODEL_DIR = 'ml/checkpoints_v2'

# Detection settings
CONFIG = {
    'seq_len': 10,              # number of time steps in one model input sequence
    'confirm_window': 5,        # number of recent scores inspected before confirming
    'confirm_ratio': 0.6,       # share of that window that must reach fusion_trigger
    'rule_weight': 0.5,         # weight of the rule score in the fused score
    'ml_weight': 0.5,           # weight of the ML score in the fused score
    'rule_trigger': 60,         # rule score that triggers on its own
    'ml_trigger': 70,           # ML score that triggers on its own
    'fusion_trigger': 50,       # fused score that triggers / counts as "exceeding"
    'cooldown_minutes': 10,     # minimum gap between two alerts of one concept
    'max_alerts_per_minute': 15,
    'zscore_clip': 5.0,
    'limit_up_threshold': 9.8,  # change (%) at or above which a stock counts as limit-up
}

# Column order of one time step in the model input.
FEATURES = ['alpha_zscore', 'amt_zscore', 'rank_zscore', 'momentum_3m', 'momentum_5m', 'limit_up_ratio']


# ==================== Database connections ====================

# Lazily created, process-wide client singletons (see the get_* accessors).
_mysql_engine = None
_es_client = None
_ch_client = None
|
||
|
||
|
||
def get_mysql_engine():
    """Return the shared SQLAlchemy engine, creating it on first use."""
    global _mysql_engine
    if _mysql_engine is not None:
        return _mysql_engine
    # pool_pre_ping is passed so SQLAlchemy tests pooled connections on checkout.
    _mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True)
    return _mysql_engine
|
||
|
||
|
||
def get_es_client():
    """Return the shared Elasticsearch client, creating it on first use."""
    global _es_client
    if _es_client is not None:
        return _es_client
    _es_client = Elasticsearch([ES_HOST])
    return _es_client
|
||
|
||
|
||
def get_ch_client():
    """Return the shared ClickHouse client, creating it on first use."""
    global _ch_client
    if _ch_client is not None:
        return _ch_client
    _ch_client = Client(**CLICKHOUSE_CONFIG)
    return _ch_client
|
||
|
||
|
||
def code_to_ch_format(code: str) -> Optional[str]:
    """Convert a bare 6-digit stock code to its exchange-suffixed form.

    Args:
        code: Stock code such as ``"600000"``.

    Returns:
        ``"<code>.SH"`` for codes starting with ``6``, ``"<code>.SZ"`` for codes
        starting with ``0`` or ``3``, ``"<code>.BJ"`` for any other 6-digit code,
        or ``None`` when *code* is not a string of exactly six digits.
    """
    # The isinstance guard makes non-string input (None, int, ...) yield None
    # instead of raising inside len()/isdigit().
    if not isinstance(code, str) or len(code) != 6 or not code.isdigit():
        return None
    if code.startswith('6'):
        suffix = 'SH'
    elif code.startswith(('0', '3')):
        suffix = 'SZ'
    else:
        suffix = 'BJ'
    return f"{code}.{suffix}"
|
||
|
||
|
||
def time_to_slot(ts) -> str:
    """Return the ``HH:MM`` time-slot label for a timestamp.

    Args:
        ts: A string (returned unchanged), any object with ``strftime``
            (``datetime``, ``pandas.Timestamp``), or a value that
            ``pandas.Timestamp`` can wrap, such as ``numpy.datetime64``.

    Returns:
        The slot label, e.g. ``"09:31"``.
    """
    if isinstance(ts, str):
        return ts
    if not hasattr(ts, 'strftime'):
        # numpy.datetime64 scalars have no strftime; pandas < 2.0 yields them
        # from Series.unique() on datetime columns.
        ts = pd.Timestamp(ts)
    return ts.strftime('%H:%M')
|
||
|
||
|
||
# ==================== 规则评分 ====================
|
||
|
||
def score_rules_zscore(features: Dict) -> Tuple[float, List[str]]:
    """Score one concept snapshot against the hand-written Z-Score rules.

    Args:
        features: Mapping that may contain ``alpha_zscore``, ``amt_zscore``,
            ``rank_zscore``, ``momentum_3m``, ``momentum_5m`` and
            ``limit_up_ratio``; a missing key counts as 0.

    Returns:
        ``(score, rule_names)``: the summed points capped at 100 and the names
        of the rules that fired, in evaluation order.
    """
    read = features.get
    alpha_mag = abs(read('alpha_zscore', 0))  # magnitude only: both directions score
    amt = read('amt_zscore', 0)               # signed: only above-baseline turnover scores
    rank_mag = abs(read('rank_zscore', 0))
    mom3 = read('momentum_3m', 0)
    mom5 = read('momentum_5m', 0)
    limit_ratio = read('limit_up_ratio', 0)

    # One ladder per feature, rungs ordered from the highest floor down; only the
    # first rung whose floor is reached contributes.
    ladders = (
        (alpha_mag, ((4.0, 25, 'alpha_extreme'),
                     (3.0, 18, 'alpha_strong'),
                     (2.0, 10, 'alpha_moderate'))),
        (amt, ((4.0, 20, 'amt_extreme'),
               (3.0, 12, 'amt_strong'),
               (2.0, 6, 'amt_moderate'))),
        (rank_mag, ((3.0, 15, 'rank_extreme'),
                    (2.0, 8, 'rank_strong'))),
        (mom3, ((1.0, 12, 'momentum_3m_strong'),
                (0.5, 6, 'momentum_3m_moderate'))),
        (mom5, ((1.5, 10, 'momentum_5m_strong'),)),
        (limit_ratio, ((0.3, 20, 'limit_up_extreme'),
                       (0.15, 12, 'limit_up_strong'),
                       (0.08, 5, 'limit_up_moderate'))),
    )

    total = 0.0
    fired = []
    for value, rungs in ladders:
        for floor, points, label in rungs:
            if value >= floor:
                total += points
                fired.append(label)
                break

    # Combination bonuses
    alpha_hit = alpha_mag >= 2.0
    if alpha_hit and amt >= 2.0:
        total += 15
        fired.append('combo_alpha_amt')
    if alpha_hit and limit_ratio >= 0.1:
        total += 12
        fired.append('combo_alpha_limitup')

    return min(total, 100), fired
|
||
|
||
|
||
# ==================== 实时检测器 ====================
|
||
|
||
class RealtimeDetectorV2:
    """V2 real-time concept anomaly detector.

    For every timestamp of a trading day the detector aggregates member-stock
    data into per-concept features, converts them to Z-Scores against a
    per-concept / per-time-slot baseline, scores them with hand-written rules
    and with the reconstruction error of a ``TransformerAutoencoder``, fuses
    both scores, and emits an alert only after a persistence check and a
    per-concept cooldown.
    """

    # Floor for baseline standard deviations, so a degenerate (zero-variance)
    # baseline cannot cause a division by zero or a NaN Z-Score.
    _MIN_STD = 1e-6

    def __init__(self, model_dir: str = MODEL_DIR, baseline_file: str = BASELINE_FILE):
        """Load concepts, baselines and the model, then create empty rolling state.

        Args:
            model_dir: Directory holding ``config.json``, ``best_model.pt`` and,
                optionally, ``thresholds.json``.
            baseline_file: Pickle file containing the baselines.
        """
        print("初始化 V2 实时检测器...")

        # Concept universe: id -> member stock codes, plus the union of all members.
        self.concepts = self._load_concepts()
        self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in self.concepts}
        self.all_stocks = list(set(s for c in self.concepts for s in c['stocks']))

        self.baselines = self._load_baselines(baseline_file)
        self.model, self.thresholds, self.device = self._load_model(model_dir)

        # Rolling per-concept state; cleared at the start of every detect() run.
        self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))
        self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))
        self.cooldown = {}  # concept_id -> timestamp of the last confirmed alert

        print(f"初始化完成: {len(self.concepts)} 概念, {len(self.baselines)} 基线")

    @staticmethod
    def _normalize_trade_date(trade_date: str) -> str:
        """Validate *trade_date* and return it in canonical ``YYYY-MM-DD`` form.

        Raises:
            ValueError: If the value is not a valid ``YYYY-MM-DD`` date. This is
                what makes it safe to embed the date in SQL text.
        """
        return datetime.strptime(str(trade_date), '%Y-%m-%d').strftime('%Y-%m-%d')

    @classmethod
    def _clipped_zscore(cls, value: float, mean: float, std: float) -> float:
        """Return ``(value - mean) / std`` clipped to ``±CONFIG['zscore_clip']``."""
        limit = CONFIG['zscore_clip']
        if not std > cls._MIN_STD:  # also true for NaN, zero and negative std
            std = cls._MIN_STD
        return float(np.clip((value - mean) / std, -limit, limit))

    def _load_concepts(self) -> List[dict]:
        """Load all concepts that have at least one member stock from Elasticsearch.

        Returns:
            A list of ``{'concept_id', 'concept_name', 'stocks'}`` dicts, where
            ``stocks`` is the list of member stock codes.
        """
        es = get_es_client()
        concepts = []

        query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}
        resp = es.search(index=ES_INDEX, body=query, scroll='2m')
        scroll_id = resp['_scroll_id']

        try:
            hits = resp['hits']['hits']
            while hits:
                for hit in hits:
                    src = hit['_source']
                    # `or []` also covers a document whose stocks field is null.
                    stocks = [
                        s['code']
                        for s in (src.get('stocks') or [])
                        if isinstance(s, dict) and s.get('code')
                    ]
                    if stocks:
                        concepts.append({
                            'concept_id': src.get('concept_id'),
                            'concept_name': src.get('concept'),
                            'stocks': stocks,
                        })
                resp = es.scroll(scroll_id=scroll_id, scroll='2m')
                scroll_id = resp['_scroll_id']
                hits = resp['hits']['hits']
        finally:
            # Release the server-side scroll context even when paging fails.
            es.clear_scroll(scroll_id=scroll_id)

        return concepts

    def _load_baselines(self, baseline_file: str) -> Dict:
        """Load the baselines mapping from *baseline_file*.

        Returns:
            The value stored under the ``'baselines'`` key of the pickle, or an
            empty dict when the file does not exist or has no such key.
        """
        if not os.path.exists(baseline_file):
            print(f"警告: 基线文件不存在: {baseline_file}")
            print("请先运行: python ml/update_baseline.py")
            return {}

        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only load baseline files produced by a trusted job.
        with open(baseline_file, 'rb') as f:
            data = pickle.load(f)

        print(f"基线日期范围: {data.get('date_range', 'unknown')}")
        print(f"更新时间: {data.get('update_time', 'unknown')}")

        return data.get('baselines', {})

    def _load_model(self, model_dir: str):
        """Load the autoencoder and its score thresholds.

        Returns:
            ``(model, thresholds, device)``. ``model`` is ``None`` and
            ``thresholds`` is empty when ``best_model.pt`` is missing, in which
            case ML scores are all zero.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        config_path = os.path.join(model_dir, 'config.json')
        model_path = os.path.join(model_dir, 'best_model.pt')
        threshold_path = os.path.join(model_dir, 'thresholds.json')

        if not os.path.exists(model_path):
            print(f"警告: 模型不存在: {model_path}")
            return None, {}, device

        with open(config_path) as f:
            config = json.load(f)

        model = TransformerAutoencoder(**config['model'])
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        thresholds = {}
        if os.path.exists(threshold_path):
            with open(threshold_path) as f:
                thresholds = json.load(f)

        print(f"模型已加载: {model_path}")
        return model, thresholds, device

    def _get_realtime_data(self, trade_date: str) -> pd.DataFrame:
        """Fetch minute data for *trade_date* and aggregate raw per-concept features.

        Args:
            trade_date: Trading day as ``YYYY-MM-DD``.

        Returns:
            One row per (timestamp, concept) with columns ``concept_id``,
            ``timestamp``, ``time_slot``, ``alpha``, ``total_amt``,
            ``limit_up_ratio``, ``stock_count`` and ``rank_pct``; an empty frame
            when any required input is missing.

        Raises:
            ValueError: If *trade_date* is not a valid ``YYYY-MM-DD`` date.
        """
        # The date is embedded in ClickHouse SQL text below, so validate it first.
        trade_date = self._normalize_trade_date(trade_date)
        ch = get_ch_client()

        # bare code -> ClickHouse code, computed once. Only codes accepted by
        # code_to_ch_format (six digits) survive, so the IN (...) lists built
        # from this mapping contain nothing but digits and an exchange suffix.
        code_to_ch = {}
        for code in self.all_stocks:
            ch_code = code_to_ch_format(code)
            if ch_code:
                code_to_ch[code] = ch_code
        if not code_to_ch:
            return pd.DataFrame()
        ch_to_code = {ch_code: code for code, ch_code in code_to_ch.items()}
        ch_codes_str = "','".join(code_to_ch.values())

        # Stock minute bars
        stock_query = f"""
            SELECT code, timestamp, close, amt
            FROM stock_minute
            WHERE toDate(timestamp) = '{trade_date}'
            AND code IN ('{ch_codes_str}')
            ORDER BY timestamp
        """
        stock_result = ch.execute(stock_query)
        if not stock_result:
            return pd.DataFrame()

        stock_df = pd.DataFrame(stock_result, columns=['ch_code', 'timestamp', 'close', 'amt'])

        # Map back to the bare codes used by the concept membership sets.
        stock_df['code'] = stock_df['ch_code'].map(ch_to_code)
        stock_df = stock_df.dropna(subset=['code'])

        # Reference index minute bars
        index_query = f"""
            SELECT timestamp, close
            FROM index_minute
            WHERE toDate(timestamp) = '{trade_date}'
            AND code = '{REFERENCE_INDEX}'
            ORDER BY timestamp
        """
        index_result = ch.execute(index_query)
        if not index_result:
            return pd.DataFrame()

        index_df = pd.DataFrame(index_result, columns=['timestamp', 'close'])

        # Previous-session reference values from MySQL
        engine = get_mysql_engine()
        codes_str = "','".join(code_to_ch)

        with engine.connect() as conn:
            prev_result = conn.execute(text(f"""
                SELECT SECCODE, F007N FROM ea_trade
                WHERE SECCODE IN ('{codes_str}')
                AND TRADEDATE = (SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < :today)
                AND F007N > 0
            """), {'today': trade_date})
            prev_close = {row[0]: float(row[1]) for row in prev_result if row[1]}

            idx_result = conn.execute(text("""
                SELECT F006N FROM ea_exchangetrade
                WHERE INDEXCODE = '000001' AND TRADEDATE < :today
                ORDER BY TRADEDATE DESC LIMIT 1
            """), {'today': trade_date}).fetchone()
            index_prev_close = float(idx_result[0]) if idx_result else None

        if not prev_close or not index_prev_close:
            return pd.DataFrame()

        # Percentage change versus the previous-session value
        stock_df['prev_close'] = stock_df['code'].map(prev_close)
        stock_df = stock_df.dropna(subset=['prev_close'])
        stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100

        index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
        index_map = dict(zip(index_df['timestamp'], index_df['change_pct']))

        # Aggregate concept features per timestamp. groupby visits every row once
        # instead of re-scanning the whole frame for each timestamp.
        limit_up_threshold = CONFIG['limit_up_threshold']
        results = []
        for ts, ts_data in stock_df.groupby('timestamp', sort=True):
            # A missing index bar is treated as an index change of 0.
            idx_chg = index_map.get(ts, 0)
            slot = time_to_slot(ts)

            stock_chg = dict(zip(ts_data['code'], ts_data['change_pct']))
            stock_amt = dict(zip(ts_data['code'], ts_data['amt']))

            for cid, stocks in self.concept_stocks.items():
                members = [s for s in stocks if s in stock_chg]
                if not members:
                    continue

                changes = [stock_chg[s] for s in members]
                limit_up_count = sum(1 for c in changes if c >= limit_up_threshold)

                results.append({
                    'concept_id': cid,
                    'timestamp': ts,
                    'time_slot': slot,
                    'alpha': np.mean(changes) - idx_chg,
                    'total_amt': sum(stock_amt.get(s, 0) for s in members),
                    'limit_up_ratio': limit_up_count / len(changes),
                    'stock_count': len(changes),
                })

        if not results:
            return pd.DataFrame()

        df = pd.DataFrame(results)

        # Percentile rank of alpha among all concepts at the same timestamp
        df['rank_pct'] = df.groupby('timestamp')['alpha'].rank(pct=True)

        return df

    def _compute_zscore(self, concept_id: str, time_slot: str, alpha: float, total_amt: float, rank_pct: float) -> Optional[Dict]:
        """Compute clipped Z-Scores and Z-Score momentum for one concept snapshot.

        Momentum is derived from ``self.zscore_history[concept_id]`` as it stands
        when this method is called; the method itself does not modify it.

        Returns:
            Dict with ``alpha_zscore``, ``amt_zscore``, ``rank_zscore``,
            ``momentum_3m`` and ``momentum_5m``, or ``None`` when no baseline
            exists for this concept and time slot.
        """
        baseline = self.baselines.get(concept_id)
        if baseline is None or time_slot not in baseline:
            return None
        bl = baseline[time_slot]

        alpha_z = self._clipped_zscore(alpha, bl['alpha_mean'], bl['alpha_std'])
        amt_z = self._clipped_zscore(total_amt, bl['amt_mean'], bl['amt_std'])
        rank_z = self._clipped_zscore(rank_pct, bl['rank_mean'], bl['rank_std'])

        # Momentum: mean of the newest k alpha Z-Scores minus the mean of the k
        # before them; while fewer than 2k entries exist, the oldest entry is
        # used as the comparison value.
        history = list(self.zscore_history[concept_id])
        mom_3m = 0.0
        mom_5m = 0.0

        if len(history) >= 3:
            recent = [h['alpha_zscore'] for h in history[-3:]]
            older = [h['alpha_zscore'] for h in history[-6:-3]] if len(history) >= 6 else [history[0]['alpha_zscore']]
            mom_3m = np.mean(recent) - np.mean(older)

        if len(history) >= 5:
            recent = [h['alpha_zscore'] for h in history[-5:]]
            older = [h['alpha_zscore'] for h in history[-10:-5]] if len(history) >= 10 else [history[0]['alpha_zscore']]
            mom_5m = np.mean(recent) - np.mean(older)

        return {
            'alpha_zscore': float(alpha_z),
            'amt_zscore': float(amt_z),
            'rank_zscore': float(rank_z),
            'momentum_3m': float(mom_3m),
            'momentum_5m': float(mom_5m),
        }

    @torch.no_grad()
    def _ml_score(self, sequences: np.ndarray) -> np.ndarray:
        """Score a batch of feature sequences with the autoencoder.

        Args:
            sequences: Batch of sequences; ``detect`` passes an array of shape
                ``(batch, seq_len, len(FEATURES))``.

        Returns:
            One score per sequence, clipped to ``[0, 100]``; all zeros when no
            model is loaded or the batch is empty.
        """
        if self.model is None or len(sequences) == 0:
            return np.zeros(len(sequences))

        x = torch.FloatTensor(sequences).to(self.device)
        errors = self.model.compute_reconstruction_error(x, reduction='none')
        # NOTE(review): assumes `errors` is (batch, seq_len) so that [:, -1] is
        # the error of the newest time step — confirm against ml.model.
        last_errors = errors[:, -1].cpu().numpy()

        # Map the error to a 0-100 score: the median maps to 50, p99 to 99.
        if self.thresholds:
            p50 = self.thresholds.get('median', 0.001)
            p99 = self.thresholds.get('p99', 0.05)
            scores = 50 + (last_errors - p50) / (p99 - p50 + 1e-6) * 49
        else:
            scores = last_errors * 1000

        return np.clip(scores, 0, 100)

    def detect(self, trade_date: str = None) -> List[Dict]:
        """Replay one trading day and return the confirmed anomaly alerts.

        Args:
            trade_date: Trading day as ``YYYY-MM-DD``; defaults to today.

        Returns:
            A list of alert dicts (concept, time, scores, features and the
            triggered rule names), in chronological order.

        Raises:
            ValueError: If *trade_date* is not a valid ``YYYY-MM-DD`` date.
        """
        trade_date = self._normalize_trade_date(trade_date or datetime.now().strftime('%Y-%m-%d'))
        print(f"\n检测 {trade_date} 的异动...")

        # Reset rolling state so runs are independent of each other.
        self.zscore_history.clear()
        self.anomaly_candidates.clear()
        self.cooldown.clear()

        raw_df = self._get_realtime_data(trade_date)
        if raw_df.empty:
            print("无数据")
            return []

        # id -> name lookup built once; the first occurrence of an id wins.
        concept_names = {}
        for concept in self.concepts:
            concept_names.setdefault(concept['concept_id'], concept['concept_name'])

        by_timestamp = raw_df.groupby('timestamp', sort=True)
        print(f"时间点数: {by_timestamp.ngroups}")

        seq_len = CONFIG['seq_len']
        all_alerts = []

        for ts, ts_data in by_timestamp:
            time_slot = time_to_slot(ts)
            candidates = []

            # Z-Scores and rule score for every concept at this timestamp
            for _, row in ts_data.iterrows():
                cid = row['concept_id']

                zscore = self._compute_zscore(
                    cid, time_slot,
                    row['alpha'], row['total_amt'], row['rank_pct']
                )
                if zscore is None:
                    continue

                features = {
                    **zscore,
                    'alpha': row['alpha'],
                    'limit_up_ratio': row['limit_up_ratio'],
                    'total_amt': row['total_amt'],
                }

                # Keep each step's own limit_up_ratio in the history so the model
                # sequence carries per-step values for all FEATURES columns.
                self.zscore_history[cid].append(
                    {**zscore, 'limit_up_ratio': float(row['limit_up_ratio'])}
                )

                rule_score, triggered = score_rules_zscore(features)
                candidates.append((cid, features, rule_score, triggered))

            if not candidates:
                continue

            # Batch ML scoring for concepts that already have a full window
            sequences = []
            valid_candidates = []
            for candidate in candidates:
                history = self.zscore_history[candidate[0]]
                if len(history) >= seq_len:
                    sequences.append([[step[name] for name in FEATURES] for step in history])
                    valid_candidates.append(candidate)

            if not sequences:
                continue

            ml_scores = self._ml_score(np.array(sequences))

            # Fusion + persistence confirmation
            for (cid, features, rule_score, triggered), ml_score in zip(valid_candidates, ml_scores):
                final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score

                is_triggered = (
                    rule_score >= CONFIG['rule_trigger'] or
                    ml_score >= CONFIG['ml_trigger'] or
                    final_score >= CONFIG['fusion_trigger']
                )

                # Every scored step enters the confirmation window, triggered or not.
                self.anomaly_candidates[cid].append((ts, final_score))

                if not is_triggered:
                    continue

                # Cooldown: skip while the previous alert of this concept is recent.
                last_alert = self.cooldown.get(cid)
                if last_alert is not None:
                    if (ts - last_alert).total_seconds() < CONFIG['cooldown_minutes'] * 60:
                        continue

                # Persistence: require a full window with enough exceeding scores.
                recent = list(self.anomaly_candidates[cid])
                if len(recent) < CONFIG['confirm_window']:
                    continue

                exceed = sum(1 for _, s in recent if s >= CONFIG['fusion_trigger'])
                ratio = exceed / len(recent)
                if ratio < CONFIG['confirm_ratio']:
                    continue

                # Confirmed anomaly
                self.cooldown[cid] = ts

                alpha = features['alpha']
                alert_type = 'surge_up' if alpha >= 1.5 else 'surge_down' if alpha <= -1.5 else 'surge'

                all_alerts.append({
                    'concept_id': cid,
                    'concept_name': concept_names.get(cid, cid),
                    'alert_time': ts,
                    'trade_date': trade_date,
                    'alert_type': alert_type,
                    'final_score': float(final_score),
                    'rule_score': float(rule_score),
                    'ml_score': float(ml_score),
                    'confirm_ratio': float(ratio),
                    'alpha': float(alpha),
                    'alpha_zscore': float(features['alpha_zscore']),
                    'amt_zscore': float(features['amt_zscore']),
                    'rank_zscore': float(features['rank_zscore']),
                    'momentum_3m': float(features['momentum_3m']),
                    'momentum_5m': float(features['momentum_5m']),
                    'limit_up_ratio': float(features['limit_up_ratio']),
                    'triggered_rules': triggered,
                    'trigger_reason': f"融合({final_score:.0f})+确认({ratio:.0%})",
                })

        print(f"检测到 {len(all_alerts)} 个异动")
        return all_alerts
|
||
|
||
|
||
# ==================== 数据库存储 ====================
|
||
|
||
def create_v2_table():
    """Create the ``concept_anomaly_v2`` MySQL table when it does not exist yet."""
    ddl = text("""
        CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
            id INT AUTO_INCREMENT PRIMARY KEY,
            concept_id VARCHAR(50) NOT NULL,
            alert_time DATETIME NOT NULL,
            trade_date DATE NOT NULL,
            alert_type VARCHAR(20) NOT NULL,
            final_score FLOAT,
            rule_score FLOAT,
            ml_score FLOAT,
            trigger_reason VARCHAR(200),
            confirm_ratio FLOAT,
            alpha FLOAT,
            alpha_zscore FLOAT,
            amt_zscore FLOAT,
            rank_zscore FLOAT,
            momentum_3m FLOAT,
            momentum_5m FLOAT,
            limit_up_ratio FLOAT,
            triggered_rules TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE KEY uk_concept_time (concept_id, alert_time),
            INDEX idx_trade_date (trade_date),
            INDEX idx_alert_type (alert_type)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
    """)
    # engine.begin() commits on success and rolls back if the statement raises.
    with get_mysql_engine().begin() as conn:
        conn.execute(ddl)
    print("concept_anomaly_v2 表已就绪")
|
||
|
||
|
||
def save_alerts_to_db(alerts: List[Dict]) -> int:
    """Insert alerts into ``concept_anomaly_v2`` and return how many rows were added.

    Rows are written with ``INSERT IGNORE``, so an alert that collides with the
    table's unique key is skipped and not counted. A failure on one alert is
    printed and does not stop the remaining alerts.

    Args:
        alerts: Alert dicts as produced by ``RealtimeDetectorV2.detect``.

    Returns:
        Number of rows actually inserted; 0 for an empty input.
    """
    if not alerts:
        return 0

    # Columns copied verbatim from the alert dict; triggered_rules is JSON-encoded.
    plain_columns = (
        'concept_id', 'alert_time', 'trade_date', 'alert_type',
        'final_score', 'rule_score', 'ml_score', 'trigger_reason', 'confirm_ratio',
        'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore',
        'momentum_3m', 'momentum_5m', 'limit_up_ratio',
    )
    statement = text("""
        INSERT IGNORE INTO concept_anomaly_v2
        (concept_id, alert_time, trade_date, alert_type,
        final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
        alpha, alpha_zscore, amt_zscore, rank_zscore,
        momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
        VALUES
        (:concept_id, :alert_time, :trade_date, :alert_type,
        :final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
        :alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
        :momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
    """)

    engine = get_mysql_engine()
    inserted = 0

    with engine.begin() as conn:
        for alert in alerts:
            try:
                params = {column: alert[column] for column in plain_columns}
                params['triggered_rules'] = json.dumps(
                    alert.get('triggered_rules', []), ensure_ascii=False
                )
                outcome = conn.execute(statement, params)
                if outcome.rowcount > 0:
                    inserted += 1
            except Exception as e:
                print(f"保存失败: {alert['concept_id']} - {e}")

    return inserted
|
||
|
||
|
||
def main():
    """CLI entry point: detect anomalies for one date, print them, optionally save them."""
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--date', type=str, default=None)
    cli.add_argument('--no-save', action='store_true', help='不保存到数据库,只打印')
    opts = cli.parse_args()

    persist = not opts.no_save

    # Make sure the target table exists before any detection work is done.
    if persist:
        create_v2_table()

    alerts = RealtimeDetectorV2().detect(opts.date)

    print(f"\n{'='*60}")
    print(f"检测结果 ({len(alerts)} 个异动)")
    print('='*60)

    # Only the first 20 alerts are listed.
    for a in alerts[:20]:
        when = a['alert_time']
        label = when.strftime('%H:%M') if hasattr(when, 'strftime') else when
        print(f"[{label}] "
              f"{a['concept_name']} | {a['alert_type']} | "
              f"分数={a['final_score']:.0f} 确认={a['confirm_ratio']:.0%} "
              f"α={a['alpha']:.2f}% αZ={a['alpha_zscore']:.1f}")

    if len(alerts) > 20:
        print(f"... 共 {len(alerts)} 个")

    # Persist to the database
    if persist and alerts:
        saved = save_alerts_to_db(alerts)
        print(f"\n✅ 已保存 {saved}/{len(alerts)} 条到 concept_anomaly_v2 表")
    elif not persist:
        print(f"\n⚠️ --no-save 模式,未保存到数据库")
|
||
|
||
|
||
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|