#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Real-time concept anomaly detection service (production-ready).

Runs once per minute during trading hours, detects concept anomalies,
and writes them to the database.

Data flow:
1. On startup, fetch the concept list from ES and previous closes from MySQL
2. Auto warm-up: load the current day's existing minute data from ClickHouse
3. Fetch the latest minute data incrementally every minute
4. Compute concept features in memory in real time (no look-ahead bias)
5. Detect anomalies with a fused rule + ML score
6. Write anomalies to MySQL

Feature computation notes (no look-ahead):
- alpha: the concept's excess return at the current time point
- alpha_delta: alpha change over the past 5 minutes
- amt_ratio: current turnover relative to the mean of the past 20 minutes
- rank_pct: alpha rank across all concepts at the current time point
- limit_up_ratio: share of limit-up stocks at the current time point

Usage:
    # Live mode (recommended) - auto warm-up, no dependency on prepare_data.py
    python realtime_detector.py

    # Single detection pass
    python realtime_detector.py --once

    # Backfill today's historical anomalies into the database
    # (requires the parquet files produced by prepare_data.py)
    python realtime_detector.py --backfill-only

    # Live mode + backfill history on startup
    python realtime_detector.py --backfill

Minimum data requirements:
- ML scoring needs a sequence of seq_len=15 minutes
- amt_ratio needs amt_ma_window=20 minutes of history
- i.e. the service is fully functional about 35 minutes after the open
"""

import os
import sys
import time
import json
import argparse
import schedule
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Set
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from sqlalchemy import create_engine, text
from elasticsearch import Elasticsearch
from clickhouse_driver import Client as CHClient

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


# ==================== Configuration ====================

MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False,
    pool_pre_ping=True,
    pool_recycle=3600,
)

ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
ES_INDEX = 'concept_library_v3'

CLICKHOUSE_CONFIG = {
    'host': '127.0.0.1',
    'port': 9000,
    'user': 'default',
    'password': 'Zzl33818!',
    'database': 'stock'
}

REFERENCE_INDEX = '000001.SH'

FEATURES = ['alpha', 'alpha_delta', 'amt_ratio', 'amt_delta', 'rank_pct', 'limit_up_ratio']

# Feature computation parameters
FEATURE_CONFIG = {
    'alpha_delta_window': 5,
    'amt_ma_window': 20,
    'limit_up_threshold': 9.8,
    'limit_down_threshold': -9.8,
}

# Detection config (kept consistent with backtest_fast.py)
CONFIG = {
    'seq_len': 15,
    'min_alpha_abs': 0.3,
    'cooldown_minutes': 8,
    'max_alerts_per_minute': 20,
    'clip_value': 10.0,
    # === Fusion weights: same as backtest_fast.py ===
    'rule_weight': 0.5,
    'ml_weight': 0.5,
    # === Trigger thresholds: same as backtest_fast.py ===
    'rule_trigger': 65,
    'ml_trigger': 70,
    'fusion_trigger': 45,
}

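# Worked fusion example (illustrative numbers only, not executed): with
# rule_weight=0.5 and ml_weight=0.5, a concept scoring rule=60 and ml=40
# fuses to 0.5*60 + 0.5*40 = 50, which fires via fusion_trigger=45 even
# though neither rule_trigger=65 nor ml_trigger=70 is reached on its own.
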
TRADING_PERIODS = [
    ('09:30', '11:30'),
    ('13:00', '15:00'),
]


# ==================== Utility functions ====================

def get_ch_client():
    return CHClient(**CLICKHOUSE_CONFIG)


def code_to_ch_format(code: str) -> Optional[str]:
    """Map a 6-digit A-share code to ClickHouse format (exchange suffix added)."""
    if not code or len(code) != 6 or not code.isdigit():
        return None
    if code.startswith('6'):
        return f"{code}.SH"
    elif code.startswith('0') or code.startswith('3'):
        return f"{code}.SZ"
    else:
        return f"{code}.BJ"


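# Doctest-style sketch of the mapping (example codes are arbitrary):
#   code_to_ch_format('600519') -> '600519.SH'  (6xxxxx: Shanghai)
#   code_to_ch_format('000001') -> '000001.SZ'  (0xxxxx/3xxxxx: Shenzhen)
#   code_to_ch_format('830799') -> '830799.BJ'  (everything else: Beijing)
#   code_to_ch_format('12345')  -> None         (not a 6-digit code)

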
def is_trading_time() -> bool:
    # Note: only weekdays and the session windows are checked; exchange
    # holidays are not.
    now = datetime.now()
    if now.weekday() >= 5:
        return False
    current_time = now.strftime('%H:%M')
    for start, end in TRADING_PERIODS:
        if start <= current_time <= end:
            return True
    return False


def get_current_trade_date() -> str:
    # Before 09:00, treat the previous calendar day as the trade date.
    now = datetime.now()
    if now.hour < 9:
        now = now - timedelta(days=1)
    return now.strftime('%Y-%m-%d')


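# Illustrative behavior (assuming a regular weekday): at 10:05
# is_trading_time() is True, at 12:30 (lunch break) it is False, and
# get_current_trade_date() called at 08:30 rolls back one calendar day so
# overnight runs still point at the just-finished session.

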
# ==================== Data access ====================

def get_all_concepts() -> List[dict]:
    """Fetch all concepts from ES."""
    concepts = []
    query = {
        "query": {"match_all": {}},
        "size": 100,
        "_source": ["concept_id", "concept", "stocks"]
    }

    resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
    scroll_id = resp['_scroll_id']
    hits = resp['hits']['hits']

    while len(hits) > 0:
        for hit in hits:
            source = hit['_source']
            stocks = []
            if 'stocks' in source and isinstance(source['stocks'], list):
                for stock in source['stocks']:
                    if isinstance(stock, dict) and 'code' in stock and stock['code']:
                        stocks.append(stock['code'])

            if stocks:
                concepts.append({
                    'concept_id': source.get('concept_id'),
                    'concept_name': source.get('concept'),
                    'stocks': stocks
                })

        resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
        scroll_id = resp['_scroll_id']
        hits = resp['hits']['hits']

    ES_CLIENT.clear_scroll(scroll_id=scroll_id)
    print(f"Fetched {len(concepts)} concepts")
    return concepts


def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
    """Fetch previous closes (the prior trading day's close, column F007N)."""
    valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
    if not valid_codes:
        return {}

    codes_str = "','".join(valid_codes)
    # Note: F007N is the "latest traded price", i.e. that day's close;
    # F002N is the "previous close". We want the prior trading day's F007N
    # (that day's close) to serve as today's previous close.
    query = f"""
        SELECT SECCODE, F007N
        FROM ea_trade
        WHERE SECCODE IN ('{codes_str}')
          AND TRADEDATE = (
              SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
          )
          AND F007N IS NOT NULL AND F007N > 0
    """

    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text(query))
            return {row[0]: float(row[1]) for row in result if row[1]}
    except Exception as e:
        print(f"Failed to fetch previous closes: {e}")
        return {}


def get_index_prev_close(trade_date: str) -> Optional[float]:
    """Fetch the reference index's previous close."""
    code_no_suffix = REFERENCE_INDEX.split('.')[0]
    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text("""
                SELECT F006N FROM ea_exchangetrade
                WHERE INDEXCODE = :code AND TRADEDATE < :today
                ORDER BY TRADEDATE DESC LIMIT 1
            """), {'code': code_no_suffix, 'today': trade_date}).fetchone()
            if result and result[0]:
                return float(result[0])
    except Exception as e:
        print(f"Failed to fetch index previous close: {e}")
    return None


def get_stock_minute_data(trade_date: str, stock_codes: List[str], since_time: datetime = None) -> pd.DataFrame:
    """
    Fetch stock minute bars from ClickHouse.

    Args:
        trade_date: trading date
        stock_codes: list of stock codes
        since_time: only fetch rows after this time (incremental fetch)
    """
    client = get_ch_client()

    ch_codes = []
    code_map = {}
    for code in stock_codes:
        ch_code = code_to_ch_format(code)
        if ch_code:
            ch_codes.append(ch_code)
            code_map[ch_code] = code

    if not ch_codes:
        return pd.DataFrame()

    ch_codes_str = "','".join(ch_codes)

    time_filter = ""
    if since_time:
        time_filter = f"AND timestamp > '{since_time.strftime('%Y-%m-%d %H:%M:%S')}'"

    query = f"""
        SELECT code, timestamp, close, volume, amt
        FROM stock_minute
        WHERE toDate(timestamp) = '{trade_date}'
          AND code IN ('{ch_codes_str}')
          {time_filter}
        ORDER BY code, timestamp
    """

    result = client.execute(query)
    if not result:
        return pd.DataFrame()

    df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
    df['code'] = df['ch_code'].map(code_map)
    df = df.dropna(subset=['code'])
    return df[['code', 'timestamp', 'close', 'volume', 'amt']]


def get_index_minute_data(trade_date: str, since_time: datetime = None) -> pd.DataFrame:
    """Fetch index minute bars from ClickHouse."""
    client = get_ch_client()

    time_filter = ""
    if since_time:
        time_filter = f"AND timestamp > '{since_time.strftime('%Y-%m-%d %H:%M:%S')}'"

    query = f"""
        SELECT timestamp, close, volume, amt
        FROM index_minute
        WHERE toDate(timestamp) = '{trade_date}'
          AND code = '{REFERENCE_INDEX}'
          {time_filter}
        ORDER BY timestamp
    """

    result = client.execute(query)
    if not result:
        return pd.DataFrame()

    return pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt'])


# ==================== Rule scoring ====================

def get_size_adjusted_thresholds(stock_count: np.ndarray) -> np.ndarray:
    """Compute per-concept threshold multipliers from concept size."""
    n = len(stock_count)
    size_factor = np.ones(n)

    size_factor[stock_count < 5] = 1.8
    size_factor[(stock_count >= 5) & (stock_count < 10)] = 1.4
    size_factor[(stock_count >= 10) & (stock_count < 20)] = 1.2
    size_factor[(stock_count >= 20) & (stock_count < 50)] = 1.0
    size_factor[(stock_count >= 50) & (stock_count < 100)] = 0.85
    size_factor[stock_count >= 100] = 0.7

    return size_factor


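# Size-factor example (illustrative): a 3-stock concept needs
# |alpha| >= 5.0 * 1.8 = 9.0 to count as "extreme", while a 120-stock
# concept only needs 5.0 * 0.7 = 3.5 - small concepts must move much
# harder before the same rule fires.

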
def score_rules_batch(df: pd.DataFrame) -> Tuple[np.ndarray, List[List[str]]]:
    """Compute rule scores for a batch of concept rows."""
    n = len(df)
    scores = np.zeros(n)
    triggered = [[] for _ in range(n)]

    alpha = df['alpha'].values
    alpha_delta = df['alpha_delta'].values
    amt_ratio = df['amt_ratio'].values
    rank_pct = df['rank_pct'].values
    limit_up_ratio = df['limit_up_ratio'].values
    stock_count = df['stock_count'].values if 'stock_count' in df.columns else np.full(n, 20)

    alpha_abs = np.abs(alpha)
    alpha_delta_abs = np.abs(alpha_delta)
    size_factor = get_size_adjusted_thresholds(stock_count)

    # Alpha rules
    alpha_extreme_thresh = 5.0 * size_factor
    mask = alpha_abs >= alpha_extreme_thresh
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('alpha_extreme')

    alpha_strong_thresh = 4.0 * size_factor
    mask = (alpha_abs >= alpha_strong_thresh) & (alpha_abs < alpha_extreme_thresh)
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('alpha_strong')

    alpha_medium_thresh = 3.0 * size_factor
    mask = (alpha_abs >= alpha_medium_thresh) & (alpha_abs < alpha_strong_thresh)
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('alpha_medium')

    # Alpha acceleration
    delta_strong_thresh = 2.0 * size_factor
    mask = alpha_delta_abs >= delta_strong_thresh
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('alpha_delta_strong')

    delta_medium_thresh = 1.5 * size_factor
    mask = (alpha_delta_abs >= delta_medium_thresh) & (alpha_delta_abs < delta_strong_thresh)
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('alpha_delta_medium')

    # Turnover
    mask = amt_ratio >= 10.0
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('volume_extreme')

    mask = (amt_ratio >= 6.0) & (amt_ratio < 10.0)
    scores[mask] += 12
    for i in np.where(mask)[0]: triggered[i].append('volume_strong')

    # Rank
    mask = rank_pct >= 0.98
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('rank_top')

    mask = rank_pct <= 0.02
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('rank_bottom')

    # Limit-up
    limit_high_thresh = 0.30 * size_factor
    mask = limit_up_ratio >= limit_high_thresh
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('limit_up_high')

    limit_medium_thresh = 0.20 * size_factor
    mask = (limit_up_ratio >= limit_medium_thresh) & (limit_up_ratio < limit_high_thresh)
    scores[mask] += 12
    for i in np.where(mask)[0]: triggered[i].append('limit_up_medium')

    # Concept-size bonus
    large_concept = stock_count >= 50
    has_signal = scores > 0
    mask = large_concept & has_signal
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('large_concept_bonus')

    xlarge_concept = stock_count >= 100
    mask = xlarge_concept & has_signal
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('xlarge_concept_bonus')

    # Combination rules
    combo_alpha_thresh = 3.0 * size_factor
    mask = (alpha_abs >= combo_alpha_thresh) & (amt_ratio >= 5.0) & ((rank_pct >= 0.95) | (rank_pct <= 0.05))
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('triple_signal')

    mask = (alpha_abs >= combo_alpha_thresh) & (limit_up_ratio >= 0.15 * size_factor)
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('alpha_with_limit')

    # Tiny-concept penalty
    tiny_concept = stock_count < 5
    single_rule = np.array([len(t) <= 1 for t in triggered])
    mask = tiny_concept & single_rule & (scores > 0)
    scores[mask] *= 0.5
    for i in np.where(mask)[0]: triggered[i].append('tiny_concept_penalty')

    scores = np.clip(scores, 0, 100)
    return scores, triggered


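def _demo_score_rules_batch():
    # Hedged sketch, not called by the service: shows the expected input
    # columns and output of score_rules_batch on one synthetic row.
    demo = pd.DataFrame([{
        'alpha': 4.2, 'alpha_delta': 1.6, 'amt_ratio': 6.5,
        'rank_pct': 0.99, 'limit_up_ratio': 0.25, 'stock_count': 60,
    }])
    scores, rules = score_rules_batch(demo)
    # With stock_count=60 the size factor is 0.85, so this row trips the
    # strong alpha, medium delta, strong volume, top rank, medium limit-up,
    # large-concept and combination rules, and the score clips at 100.
    print(scores[0], rules[0])

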
def rule_score_with_details(features: Dict, stock_count: int = 50) -> Tuple[float, Dict[str, float]]:
    """
    Rule score for a single record (with per-rule details).

    Args:
        features: feature dict with alpha, alpha_delta, amt_ratio, rank_pct, limit_up_ratio
        stock_count: number of stocks in the concept

    Returns:
        (score, details): total score and the rules that fired
    """
    score = 0.0
    details = {}

    alpha = features.get('alpha', 0)
    alpha_delta = features.get('alpha_delta', 0)
    amt_ratio = features.get('amt_ratio', 1)
    rank_pct = features.get('rank_pct', 0.5)
    limit_up_ratio = features.get('limit_up_ratio', 0)

    alpha_abs = abs(alpha)
    alpha_delta_abs = abs(alpha_delta)
    size_factor = get_size_adjusted_thresholds(np.array([stock_count]))[0]

    # Alpha rules
    alpha_extreme_thresh = 5.0 * size_factor
    alpha_strong_thresh = 4.0 * size_factor
    alpha_medium_thresh = 3.0 * size_factor

    if alpha_abs >= alpha_extreme_thresh:
        score += 20
        details['alpha_extreme'] = 20
    elif alpha_abs >= alpha_strong_thresh:
        score += 15
        details['alpha_strong'] = 15
    elif alpha_abs >= alpha_medium_thresh:
        score += 10
        details['alpha_medium'] = 10

    # Alpha acceleration
    delta_strong_thresh = 2.0 * size_factor
    delta_medium_thresh = 1.5 * size_factor

    if alpha_delta_abs >= delta_strong_thresh:
        score += 15
        details['alpha_delta_strong'] = 15
    elif alpha_delta_abs >= delta_medium_thresh:
        score += 10
        details['alpha_delta_medium'] = 10

    # Turnover
    if amt_ratio >= 10.0:
        score += 20
        details['volume_extreme'] = 20
    elif amt_ratio >= 6.0:
        score += 12
        details['volume_strong'] = 12

    # Rank
    if rank_pct >= 0.98:
        score += 15
        details['rank_top'] = 15
    elif rank_pct <= 0.02:
        score += 15
        details['rank_bottom'] = 15

    # Limit-up
    limit_high_thresh = 0.30 * size_factor
    limit_medium_thresh = 0.20 * size_factor

    if limit_up_ratio >= limit_high_thresh:
        score += 20
        details['limit_up_high'] = 20
    elif limit_up_ratio >= limit_medium_thresh:
        score += 12
        details['limit_up_medium'] = 12

    # Concept-size bonus
    if score > 0:
        if stock_count >= 50:
            score += 10
            details['large_concept_bonus'] = 10
        if stock_count >= 100:
            score += 10
            details['xlarge_concept_bonus'] = 10

    # Combination rules
    combo_alpha_thresh = 3.0 * size_factor

    if alpha_abs >= combo_alpha_thresh and amt_ratio >= 5.0 and (rank_pct >= 0.95 or rank_pct <= 0.05):
        score += 20
        details['triple_signal'] = 20

    if alpha_abs >= combo_alpha_thresh and limit_up_ratio >= 0.15 * size_factor:
        score += 15
        details['alpha_with_limit'] = 15

    # Tiny-concept penalty
    if stock_count < 5 and len(details) <= 1 and score > 0:
        penalty = score * 0.5
        score *= 0.5
        details['tiny_concept_penalty'] = -penalty

    score = min(max(score, 0), 100)
    return score, details


# ==================== ML scorer ====================

class MLScorer:
    def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
        self.checkpoint_dir = Path(checkpoint_dir)
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        self.model = None
        self.thresholds = None
        self._load_model()

    def _load_model(self):
        model_path = self.checkpoint_dir / 'best_model.pt'
        thresholds_path = self.checkpoint_dir / 'thresholds.json'
        config_path = self.checkpoint_dir / 'config.json'

        if not model_path.exists():
            print(f"Warning: model not found at {model_path}")
            return

        try:
            from model import LSTMAutoencoder

            config = {}
            if config_path.exists():
                with open(config_path) as f:
                    config = json.load(f).get('model', {})

            # Map transformer-style config keys onto the LSTM autoencoder.
            if 'd_model' in config:
                config['hidden_dim'] = config.pop('d_model') // 2
            for key in ['num_encoder_layers', 'num_decoder_layers', 'nhead', 'dim_feedforward', 'max_seq_len', 'use_instance_norm']:
                config.pop(key, None)
            if 'num_layers' not in config:
                config['num_layers'] = 1

            checkpoint = torch.load(model_path, map_location='cpu')
            self.model = LSTMAutoencoder(**config)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.to(self.device)
            self.model.eval()

            if thresholds_path.exists():
                with open(thresholds_path) as f:
                    self.thresholds = json.load(f)

            print(f"ML model loaded (device: {self.device})")
        except Exception as e:
            print(f"Failed to load ML model: {e}")

    def is_ready(self):
        return self.model is not None

    @torch.no_grad()
    def score_batch(self, sequences: np.ndarray, debug: bool = False) -> np.ndarray:
        if not self.is_ready() or len(sequences) == 0:
            return np.zeros(len(sequences))

        x = torch.FloatTensor(sequences).to(self.device)
        output, _ = self.model(x)
        mse = ((output - x) ** 2).mean(dim=-1)
        errors = mse[:, -1].cpu().numpy()

        p95 = self.thresholds.get('p95', 0.1) if self.thresholds else 0.1
        scores = np.clip(errors / p95 * 50, 0, 100)

        if debug and len(errors) > 0:
            print(f"[ML debug] p95={p95:.4f}, errors: min={errors.min():.4f}, max={errors.max():.4f}, mean={errors.mean():.4f}")
            print(f"[ML debug] scores: min={scores.min():.0f}, max={scores.max():.0f}, mean={scores.mean():.0f}, at cap (=100): {100*(scores>=100).mean():.1f}%")

        return scores


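# Score normalization sketch: the reconstruction error on the last time step
# is scaled so that err == p95 (the 95th-percentile training error from
# thresholds.json) maps to 50 points and err >= 2*p95 saturates at 100:
#   score = clip(err / p95 * 50, 0, 100)

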
# ==================== In-memory data manager ====================

class RealtimeDataManager:
    """
    In-memory data manager.

    - Caches stock minute data and index data
    - Fetches new data incrementally
    - Computes concept features in real time
    """

    def __init__(self, concepts: List[dict], prev_close: Dict[str, float], index_prev_close: float):
        self.concepts = concepts
        self.prev_close = prev_close
        self.index_prev_close = index_prev_close

        # Concept-to-stocks mapping
        self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
        self.all_stocks = list(set(s for c in concepts for s in c['stocks']))

        # In-memory caches: stock and index minute data
        self.stock_data = pd.DataFrame()  # code, timestamp, close, volume, amt, change_pct
        self.index_data = pd.DataFrame()  # timestamp, close, change_pct

        # Last update time
        self.last_update_time = None

        # Per-concept history (for change-rate features)
        self.concept_history = defaultdict(lambda: {'alpha': [], 'amt': []})

        # Concept feature time series (for ML)
        self.concept_features_history = defaultdict(list)  # concept_id -> list of feature dicts

    def update(self, trade_date: str) -> int:
        """
        Incrementally update the caches.

        Returns:
            Number of new time points.
        """
        # Fetch incremental data
        new_stock_df = get_stock_minute_data(trade_date, self.all_stocks, self.last_update_time)
        new_index_df = get_index_minute_data(trade_date, self.last_update_time)

        if new_stock_df.empty and new_index_df.empty:
            return 0

        # Compute percentage changes
        if not new_stock_df.empty:
            new_stock_df['prev_close'] = new_stock_df['code'].map(self.prev_close)
            new_stock_df = new_stock_df.dropna(subset=['prev_close'])
            new_stock_df['change_pct'] = (new_stock_df['close'] - new_stock_df['prev_close']) / new_stock_df['prev_close'] * 100

            # Merge into the cache
            self.stock_data = pd.concat([self.stock_data, new_stock_df], ignore_index=True)
            self.stock_data = self.stock_data.drop_duplicates(subset=['code', 'timestamp'], keep='last')

        if not new_index_df.empty:
            new_index_df['change_pct'] = (new_index_df['close'] - self.index_prev_close) / self.index_prev_close * 100

            # Debug: print the index change range
            if len(self.index_data) == 0:  # first batch
                print(f"[debug] index close range: {new_index_df['close'].min():.2f} ~ {new_index_df['close'].max():.2f}")
                print(f"[debug] index change_pct range: {new_index_df['change_pct'].min():.2f}% ~ {new_index_df['change_pct'].max():.2f}%")

            self.index_data = pd.concat([self.index_data, new_index_df], ignore_index=True)
            self.index_data = self.index_data.drop_duplicates(subset=['timestamp'], keep='last')

        # Update the last timestamp
        if not new_stock_df.empty:
            self.last_update_time = new_stock_df['timestamp'].max()
        elif not new_index_df.empty:
            self.last_update_time = new_index_df['timestamp'].max()

        # Collect the new time points
        new_timestamps = sorted(new_stock_df['timestamp'].unique()) if not new_stock_df.empty else []

        # Compute concept features for each new time point
        for ts in new_timestamps:
            self._compute_features_for_timestamp(ts)

        return len(new_timestamps)

    def _compute_features_for_timestamp(self, ts):
        """Compute concept features for a single time point."""
        ts_stock_data = self.stock_data[self.stock_data['timestamp'] == ts]
        index_row = self.index_data[self.index_data['timestamp'] == ts]

        if ts_stock_data.empty or index_row.empty:
            return

        index_change = index_row['change_pct'].values[0]
        stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
        stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))

        for concept_id, stocks in self.concept_stocks.items():
            concept_changes = [stock_change[s] for s in stocks if s in stock_change]
            concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]

            if not concept_changes:
                continue

            avg_change = np.mean(concept_changes)
            total_amt = sum(concept_amts)
            alpha = avg_change - index_change

            # Limit-up ratio
            limit_up_count = sum(1 for c in concept_changes if c >= FEATURE_CONFIG['limit_up_threshold'])
            limit_up_ratio = limit_up_count / len(concept_changes)

            # Update history
            history = self.concept_history[concept_id]
            history['alpha'].append(alpha)
            history['amt'].append(total_amt)

            # Change rates
            alpha_delta = 0
            if len(history['alpha']) > FEATURE_CONFIG['alpha_delta_window']:
                alpha_delta = alpha - history['alpha'][-FEATURE_CONFIG['alpha_delta_window'] - 1]

            amt_ratio = 1.0
            amt_delta = 0
            if len(history['amt']) > FEATURE_CONFIG['amt_ma_window']:
                amt_ma = np.mean(history['amt'][-FEATURE_CONFIG['amt_ma_window'] - 1:-1])
                if amt_ma > 0:
                    amt_ratio = total_amt / amt_ma
                amt_delta = total_amt - history['amt'][-2] if len(history['amt']) > 1 else 0

            features = {
                'timestamp': ts,
                'concept_id': concept_id,
                'alpha': alpha,
                'alpha_delta': alpha_delta,
                'amt_ratio': amt_ratio,
                'amt_delta': amt_delta,
                'limit_up_ratio': limit_up_ratio,
                'stock_count': len(concept_changes),
                'total_amt': total_amt,
            }

            self.concept_features_history[concept_id].append(features)

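    # Worked example (illustrative numbers): if a concept's members average
    # +2.1% while the index is +0.4%, alpha = 1.7. With alpha_delta_window=5,
    # alpha_delta subtracts the alpha recorded 5 minutes earlier (e.g.
    # 1.7 - 0.5 = 1.2), and amt_ratio divides this minute's turnover by the
    # mean of the previous 20 minutes, so 1.0 means "normal volume".
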
    def get_latest_features(self) -> pd.DataFrame:
        """Return the latest time point's features for every concept."""
        if not self.concept_features_history:
            return pd.DataFrame()

        latest_features = []
        for concept_id, history in self.concept_features_history.items():
            if history:
                latest_features.append(history[-1])

        if not latest_features:
            return pd.DataFrame()

        df = pd.DataFrame(latest_features)

        # Rank percentile
        if len(df) > 1:
            df['rank_pct'] = df['alpha'].rank(pct=True)
        else:
            df['rank_pct'] = 0.5

        return df

    def get_sequences_for_concepts(self, seq_len: int) -> Tuple[np.ndarray, pd.DataFrame]:
        """Build feature sequences for all concepts (for ML scoring)."""
        sequences = []
        infos = []

        for concept_id, history in self.concept_features_history.items():
            if len(history) < seq_len:
                continue

            # Take the most recent seq_len time points
            recent = history[-seq_len:]

            # Build the sequence. Note: history entries do not store
            # rank_pct, so the sequence falls back to the 0.5 default here.
            seq = np.array([[
                f['alpha'],
                f['alpha_delta'],
                f['amt_ratio'],
                f['amt_delta'],
                f.get('rank_pct', 0.5),
                f['limit_up_ratio']
            ] for f in recent])

            seq = np.nan_to_num(seq, nan=0.0, posinf=0.0, neginf=0.0)
            seq = np.clip(seq, -CONFIG['clip_value'], CONFIG['clip_value'])

            sequences.append(seq)
            infos.append(recent[-1])  # latest features

        if not sequences:
            return np.array([]), pd.DataFrame()

        # Fill in rank_pct
        info_df = pd.DataFrame(infos)
        if 'rank_pct' not in info_df.columns and len(info_df) > 1:
            info_df['rank_pct'] = info_df['alpha'].rank(pct=True)

        return np.array(sequences), info_df

    def get_all_timestamps(self) -> List:
        """Return all cached time points."""
        if self.stock_data.empty:
            return []
        return sorted(self.stock_data['timestamp'].unique())

    def get_concept_features_df(self) -> pd.DataFrame:
        """Return concept features as a DataFrame (for batch backtesting)."""
        if not self.concept_features_history:
            return pd.DataFrame()

        rows = []
        for concept_id, history in self.concept_features_history.items():
            for f in history:
                row = {
                    'concept_id': concept_id,
                    'timestamp': f['timestamp'],
                    'alpha': f['alpha'],
                    'alpha_delta': f['alpha_delta'],
                    'amt_ratio': f['amt_ratio'],
                    'amt_delta': f.get('amt_delta', 0),
                    'limit_up_ratio': f['limit_up_ratio'],
                    'stock_count': f.get('stock_count', 0),
                    'total_amt': f.get('total_amt', 0),
                }
                rows.append(row)

        if not rows:
            return pd.DataFrame()

        df = pd.DataFrame(rows)

        # Compute rank_pct per time point (rank within each timestamp)
        df['rank_pct'] = df.groupby('timestamp')['alpha'].rank(pct=True)

        return df


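# Shape sketch: with N concepts warmed up, get_sequences_for_concepts(15)
# returns a float array of shape (N, 15, 6) - one row per concept, 15
# minutes deep, 6 features wide (the FEATURES order above) - plus an info
# DataFrame holding each concept's latest feature row.

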
# ==================== Cooldown management ====================

class CooldownManager:
    def __init__(self, cooldown_minutes: int = 8):
        self.cooldown_minutes = cooldown_minutes
        self.last_alert_time = {}

    def is_in_cooldown(self, concept_id: str, current_time: datetime) -> bool:
        if concept_id not in self.last_alert_time:
            return False
        last_time = self.last_alert_time[concept_id]
        diff = (current_time - last_time).total_seconds() / 60
        return diff < self.cooldown_minutes

    def record_alert(self, concept_id: str, alert_time: datetime):
        self.last_alert_time[concept_id] = alert_time

    def cleanup_old(self, current_time: datetime):
        cutoff = current_time - timedelta(minutes=self.cooldown_minutes * 2)
        self.last_alert_time = {cid: t for cid, t in self.last_alert_time.items() if t > cutoff}


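def _demo_cooldown():
    # Hedged sketch, not called by the service: a concept that alerted at
    # 10:00 stays suppressed until the 8-minute window has fully elapsed.
    mgr = CooldownManager(cooldown_minutes=8)
    t0 = datetime(2024, 1, 2, 10, 0)  # hypothetical trading minute
    mgr.record_alert('concept_x', t0)
    assert mgr.is_in_cooldown('concept_x', t0 + timedelta(minutes=5))
    assert not mgr.is_in_cooldown('concept_x', t0 + timedelta(minutes=8))

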
# ==================== Anomaly detection ====================

def detect_anomalies(
    ml_scorer: MLScorer,
    data_mgr: RealtimeDataManager,
    cooldown_mgr: CooldownManager,
    trade_date: str,
    config: Dict
) -> List[Dict]:
    """Detect anomalies at the current time point."""

    # Latest features
    latest_df = data_mgr.get_latest_features()
    if latest_df.empty:
        return []

    # ML sequences
    sequences, info_df = data_mgr.get_sequences_for_concepts(config['seq_len'])

    if len(sequences) == 0:
        return []

    # Current time
    current_time = pd.to_datetime(info_df['timestamp'].iloc[0])

    # Clean up expired cooldowns
    cooldown_mgr.cleanup_old(current_time)

    # Filter out concepts still in cooldown
    valid_mask = []
    for _, row in info_df.iterrows():
        in_cooldown = cooldown_mgr.is_in_cooldown(row['concept_id'], current_time)
        valid_mask.append(not in_cooldown)

    valid_mask = np.array(valid_mask)
    sequences = sequences[valid_mask]
    info_df = info_df[valid_mask].reset_index(drop=True)

    if len(sequences) == 0:
        return []

    # Filter out small moves
    alpha_mask = np.abs(info_df['alpha'].values) >= config['min_alpha_abs']
    sequences = sequences[alpha_mask]
    info_df = info_df[alpha_mask].reset_index(drop=True)

    if len(sequences) == 0:
        return []

    # Rule scores
    rule_scores, triggered_rules = score_rules_batch(info_df)

    # ML scores
    ml_scores = ml_scorer.score_batch(sequences)

    # Fused scores
    w1, w2 = config['rule_weight'], config['ml_weight']
    final_scores = w1 * rule_scores + w2 * ml_scores

    # Decide anomalies
    alerts = []
    for i, row in info_df.iterrows():
        rule_score = rule_scores[i]
        ml_score = ml_scores[i]
        final_score = final_scores[i]

        is_anomaly = (
            rule_score >= config['rule_trigger'] or
            ml_score >= config['ml_trigger'] or
            final_score >= config['fusion_trigger']
        )

        if not is_anomaly:
            continue

        # Trigger reason
        if rule_score >= config['rule_trigger']:
            trigger = f'Rule strong signal ({rule_score:.0f})'
        elif ml_score >= config['ml_trigger']:
            trigger = f'ML strong signal ({ml_score:.0f})'
        else:
            trigger = f'Fusion trigger ({final_score:.0f})'

        # Anomaly type
        alpha = row['alpha']
        if alpha >= 1.5:
            alert_type = 'surge_up'
        elif alpha <= -1.5:
            alert_type = 'surge_down'
        elif row['amt_ratio'] >= 3.0:
            alert_type = 'volume_spike'
        else:
            alert_type = 'unknown'

        alert = {
            'concept_id': row['concept_id'],
            'alert_time': row['timestamp'],
            'trade_date': trade_date,
            'alert_type': alert_type,
            'final_score': final_score,
            'rule_score': rule_score,
            'ml_score': ml_score,
            'trigger_reason': trigger,
            'triggered_rules': triggered_rules[i],
            'alpha': row['alpha'],
            'alpha_delta': row['alpha_delta'],
            'amt_ratio': row['amt_ratio'],
            'amt_delta': row.get('amt_delta', 0),
            'rank_pct': row.get('rank_pct', 0.5),
            'limit_up_ratio': row['limit_up_ratio'],
            'stock_count': row['stock_count'],
            'total_amt': row['total_amt'],
        }

        alerts.append(alert)
        cooldown_mgr.record_alert(row['concept_id'], current_time)

    # Sort by score, cap alerts per minute
    alerts.sort(key=lambda x: x['final_score'], reverse=True)
    return alerts[:config['max_alerts_per_minute']]


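# Minimal usage sketch (assumes an initialized service; run_once below wires
# these pieces together for real):
#   alerts = detect_anomalies(ml_scorer, data_mgr, cooldown_mgr,
#                             '2024-01-02', CONFIG)  # date is hypothetical
#   save_alerts_to_mysql(alerts)

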
# ==================== Database writes ====================

def save_alerts_to_mysql(alerts: List[Dict]) -> int:
    if not alerts:
        return 0

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                insert_sql = text("""
                    INSERT IGNORE INTO concept_anomaly_hybrid
                    (concept_id, alert_time, trade_date, alert_type,
                     final_score, rule_score, ml_score, trigger_reason,
                     alpha, alpha_delta, amt_ratio, amt_delta,
                     rank_pct, limit_up_ratio, stock_count, total_amt,
                     triggered_rules)
                    VALUES
                    (:concept_id, :alert_time, :trade_date, :alert_type,
                     :final_score, :rule_score, :ml_score, :trigger_reason,
                     :alpha, :alpha_delta, :amt_ratio, :amt_delta,
                     :rank_pct, :limit_up_ratio, :stock_count, :total_amt,
                     :triggered_rules)
                """)

                result = conn.execute(insert_sql, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                    'alert_type': alert['alert_type'],
                    'final_score': alert['final_score'],
                    'rule_score': alert['rule_score'],
                    'ml_score': alert['ml_score'],
                    'trigger_reason': alert['trigger_reason'],
                    'alpha': alert.get('alpha', 0),
                    'alpha_delta': alert.get('alpha_delta', 0),
                    'amt_ratio': alert.get('amt_ratio', 1),
                    'amt_delta': alert.get('amt_delta', 0),
                    'rank_pct': alert.get('rank_pct', 0.5),
                    'limit_up_ratio': alert.get('limit_up_ratio', 0),
                    'stock_count': alert.get('stock_count', 0),
                    'total_amt': alert.get('total_amt', 0),
                    'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False),
                })

                if result.rowcount > 0:
                    saved += 1
            except Exception as e:
                print(f"Save failed: {alert['concept_id']} - {e}")

    return saved


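# Hedged DDL sketch for the target table (an assumption for reference, not
# taken from the repo; actual column types may differ). INSERT IGNORE only
# deduplicates if a unique key such as (concept_id, alert_time) exists:
#   CREATE TABLE IF NOT EXISTS concept_anomaly_hybrid (
#       id BIGINT AUTO_INCREMENT PRIMARY KEY,
#       concept_id VARCHAR(64), alert_time DATETIME, trade_date DATE,
#       alert_type VARCHAR(32), final_score DOUBLE, rule_score DOUBLE,
#       ml_score DOUBLE, trigger_reason VARCHAR(128),
#       alpha DOUBLE, alpha_delta DOUBLE, amt_ratio DOUBLE, amt_delta DOUBLE,
#       rank_pct DOUBLE, limit_up_ratio DOUBLE, stock_count INT,
#       total_amt DOUBLE, triggered_rules JSON,
#       UNIQUE KEY uk_concept_time (concept_id, alert_time)
#   );

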
# ==================== Main service ====================

class RealtimeDetectorService:
    def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
        self.checkpoint_dir = checkpoint_dir
        self.device = device

        # Initialize the ML scorer
        self.ml_scorer = MLScorer(checkpoint_dir, device)

        # These are initialized in init_for_trade_date
        self.data_mgr = None
        self.cooldown_mgr = None
        self.trade_date = None

    def init_for_trade_date(self, trade_date: str, preload_history: bool = True):
        """
        Initialize for the given trading day.

        Args:
            trade_date: trading date
            preload_history: whether to preload the day's existing data
                (must be True in live mode)
        """
        if self.trade_date == trade_date and self.data_mgr is not None:
            return

        print(f"[init] trade date: {trade_date}")

        # Fetch the concept list
        print(f"[init] fetching concept list...")
        concepts = get_all_concepts()

        # Collect all member stocks
        all_stocks = list(set(s for c in concepts for s in c['stocks']))
        print(f"[init] {len(all_stocks)} stocks in total")

        # Fetch previous closes
        print(f"[init] fetching previous closes...")
        prev_close = get_prev_close(all_stocks, trade_date)
        index_prev_close = get_index_prev_close(trade_date)
        print(f"[init] got previous closes for {len(prev_close)} stocks")
        print(f"[init] index previous close: {index_prev_close}")

        # Create the data manager
        self.data_mgr = RealtimeDataManager(concepts, prev_close, index_prev_close)
        self.cooldown_mgr = CooldownManager(CONFIG['cooldown_minutes'])
        self.trade_date = trade_date

        # Preload the day's existing history (critical in live mode)
        if preload_history:
            self._preload_today_history(trade_date)

    def _preload_today_history(self, trade_date: str):
        """
        Preload the day's existing history into memory.

        This is the key to running live:
        - When the service starts mid-session, it must first load the data
          already produced that day
        - Only then can alpha_delta (past 5 minutes) and amt_ratio (past 20
          minutes) be computed correctly
        - The same goes for building the ML sequences (seq_len=15 minutes)

        The whole process reads raw data straight from ClickHouse and does
        not depend on prepare_data.py.
        """
        print(f"[warmup] loading today's history...")

        # Call update directly without setting last_update_time:
        # data_mgr.last_update_time starts as None, so this fetches the
        # whole day so far.
        n_updates = self.data_mgr.update(trade_date)

        if n_updates > 0:
            print(f"[warmup] loaded {n_updates} time points")

            # Check against the minimum the ML pipeline needs
            min_required = CONFIG['seq_len'] + FEATURE_CONFIG['amt_ma_window']
            if n_updates < min_required:
                print(f"[warmup] warning: {n_updates} points < minimum {min_required}, some features may be inaccurate")
            else:
                print(f"[warmup] enough data, detection can run normally")
        else:
            print(f"[warmup] no history yet today (possibly before the open)")

    def backfill_today(self):
        """
        Fill in today's history and detect anomalies (backfill mode).

        Uses exactly the same logic as backtest_fast.py:
        1. First generate today's parquet file via prepare_data.py
        2. Read the parquet file and backtest over it

        Note: this method backfills historical anomaly records; it is not
        required for live trading. In live mode, init_for_trade_date warms
        up the history automatically.
        """
        trade_date = get_current_trade_date()
        print(f"[backfill] trade date: {trade_date}")

        # 1. Generate today's parquet file
        parquet_path = Path('ml/data') / f'features_{trade_date}.parquet'

        if not parquet_path.exists():
            print(f"[backfill] generating today's feature data...")
            self._generate_today_parquet(trade_date)

        if not parquet_path.exists():
            print(f"[backfill] could not generate feature data, skipping")
            return

        # 2. Read the parquet file
        df = pd.read_parquet(parquet_path)
        print(f"[backfill] read {len(df)} feature rows")

        if df.empty:
            print("[backfill] no data")
            return

        # Print the feature distribution (debug)
        print(f"[debug] alpha distribution: min={df['alpha'].min():.2f}, max={df['alpha'].max():.2f}, mean={df['alpha'].mean():.2f}")
        print(f"[debug] rows with |alpha| >= 0.3: {(df['alpha'].abs() >= 0.3).sum()}")

        # 3. Backtest with the same logic as backtest_fast.py
        alerts = self._backtest_from_parquet(df, trade_date)

        # 4. Save the results
        if alerts:
            saved = save_alerts_to_mysql(alerts)
            print(f"[backfill] done! {len(alerts)} anomalies, {saved} saved")

            # Count trigger sources
            trigger_stats = {'rule': 0, 'ml': 0, 'fusion': 0}
            for a in alerts:
                reason = a['trigger_reason']
                if 'Rule' in reason:
                    trigger_stats['rule'] += 1
                elif 'ML' in reason:
                    trigger_stats['ml'] += 1
                else:
                    trigger_stats['fusion'] += 1
            print(f"[backfill] trigger sources: {trigger_stats}")
        else:
            print("[backfill] no anomalies")

    def _generate_today_parquet(self, trade_date: str):
        """Generate today's parquet file (delegates to prepare_data.py)."""
        import subprocess
        cmd = ['python', 'ml/prepare_data.py', '--start', trade_date, '--end', trade_date]
        print(f"[backfill] running: {' '.join(cmd)}")
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
            if result.returncode != 0:
                print(f"[backfill] prepare_data.py failed: {result.stderr}")
        except Exception as e:
            print(f"[backfill] prepare_data.py raised: {e}")

    def _backtest_from_parquet(self, df: pd.DataFrame, trade_date: str) -> List[Dict]:
        """Backtest over parquet data (identical logic to backtest_fast.py)."""
        seq_len = CONFIG['seq_len']
        now = datetime.now()

        # Sort by concept and time
        df = df.sort_values(['concept_id', 'timestamp']).reset_index(drop=True)

        # All time points
        all_timestamps = sorted(df['timestamp'].unique())

        # Keep only time points before "now"
        past_timestamps = []
        for ts in all_timestamps:
            try:
                ts_dt = pd.to_datetime(ts)
                if ts_dt.tzinfo is not None:
                    ts_dt = ts_dt.tz_localize(None)
                if ts_dt < now:
                    past_timestamps.append(ts)
            except Exception:
                continue

        print(f"[backfill] processing {len(past_timestamps)} historical time points...")

        if len(past_timestamps) < seq_len:
            print(f"[backfill] fewer than {seq_len} time points, skipping")
            return []

        # Build sequences (same as build_sequences_fast in backtest_fast.py)
        sequences = []
        infos = []

        groups = df.groupby('concept_id')

        for concept_id, gdf in groups:
            gdf = gdf.reset_index(drop=True)
            feat_matrix = gdf[FEATURES].values
            feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
            feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])

            n_total = len(feat_matrix)
            if n_total < seq_len:
                continue

            for i in range(n_total - seq_len + 1):
                seq = feat_matrix[i:i + seq_len]
                row = gdf.iloc[i + seq_len - 1]

                # Keep only sequences ending before "now"
                ts = row['timestamp']
                try:
                    ts_dt = pd.to_datetime(ts)
                    if ts_dt.tzinfo is not None:
                        ts_dt = ts_dt.tz_localize(None)
                    if ts_dt >= now:
                        continue
                except Exception:
                    continue

                sequences.append(seq)
                infos.append({
                    'concept_id': concept_id,
                    'timestamp': row['timestamp'],
                    'alpha': row['alpha'],
                    'alpha_delta': row.get('alpha_delta', 0),
                    'amt_ratio': row.get('amt_ratio', 1),
                    'amt_delta': row.get('amt_delta', 0),
                    'rank_pct': row.get('rank_pct', 0.5),
                    'limit_up_ratio': row.get('limit_up_ratio', 0),
                    'stock_count': row.get('stock_count', 0),
                    'total_amt': row.get('total_amt', 0),
                })

        if not sequences:
            return []

        sequences = np.array(sequences)
        info_df = pd.DataFrame(infos)

        print(f"[backfill] built {len(sequences)} sequences")

        # Filter out small moves
        alpha_abs = np.abs(info_df['alpha'].values)
        valid_mask = alpha_abs >= CONFIG['min_alpha_abs']
        sequences = sequences[valid_mask]
        info_df = info_df[valid_mask].reset_index(drop=True)

        if len(sequences) == 0:
            return []

        print(f"[backfill] {len(sequences)} sequences after filtering")

        # Batch rule scoring
        rule_scores, triggered_rules = score_rules_batch(info_df)

        # Batch ML scoring
        batch_size = 2048
        ml_scores = []
        for i in range(0, len(sequences), batch_size):
            batch_seq = sequences[i:i+batch_size]
            batch_scores = self.ml_scorer.score_batch(batch_seq)
            ml_scores.append(batch_scores)
        ml_scores = np.concatenate(ml_scores) if ml_scores else np.zeros(len(sequences))

        # Fused scores
        w1, w2 = CONFIG['rule_weight'], CONFIG['ml_weight']
        final_scores = w1 * rule_scores + w2 * ml_scores

        # Decide anomalies
        is_anomaly = (
            (rule_scores >= CONFIG['rule_trigger']) |
            (ml_scores >= CONFIG['ml_trigger']) |
            (final_scores >= CONFIG['fusion_trigger'])
        )

        # Attach scores to info_df
        info_df['rule_score'] = rule_scores
        info_df['ml_score'] = ml_scores
        info_df['final_score'] = final_scores
        info_df['is_anomaly'] = is_anomaly
        info_df['triggered_rules'] = triggered_rules

        # Keep anomalies only
        anomaly_df = info_df[info_df['is_anomaly']].copy()

        if len(anomaly_df) == 0:
            return []

        print(f"[backfill] found {len(anomaly_df)} candidate anomalies")

        # Apply the cooldown window
        anomaly_df = anomaly_df.sort_values(['concept_id', 'timestamp'])
        cooldown = {}
        keep_mask = []

        for _, row in anomaly_df.iterrows():
            cid = row['concept_id']
            ts = row['timestamp']

            if cid in cooldown:
                try:
                    diff = (pd.to_datetime(ts) - pd.to_datetime(cooldown[cid])).total_seconds() / 60
                except Exception:
                    diff = CONFIG['cooldown_minutes'] + 1

                if diff < CONFIG['cooldown_minutes']:
                    keep_mask.append(False)
                    continue

            cooldown[cid] = ts
            keep_mask.append(True)

        anomaly_df = anomaly_df[keep_mask]

        print(f"[backfill] {len(anomaly_df)} anomalies after cooldown")

        # Group by minute, keep at most max_alerts_per_minute each
        alerts = []
        for ts, group in anomaly_df.groupby('timestamp'):
            group = group.nlargest(CONFIG['max_alerts_per_minute'], 'final_score')

            for _, row in group.iterrows():
                alpha = row['alpha']
                if alpha >= 1.5:
                    atype = 'surge_up'
                elif alpha <= -1.5:
                    atype = 'surge_down'
                elif row['amt_ratio'] >= 3.0:
                    atype = 'volume_spike'
                else:
                    atype = 'unknown'

                rule_score = row['rule_score']
                ml_score = row['ml_score']
                final_score = row['final_score']

                if rule_score >= CONFIG['rule_trigger']:
                    trigger = f'Rule strong signal ({rule_score:.0f})'
                elif ml_score >= CONFIG['ml_trigger']:
                    trigger = f'ML strong signal ({ml_score:.0f})'
                else:
                    trigger = f'Fusion trigger ({final_score:.0f})'

                alerts.append({
                    'concept_id': row['concept_id'],
                    'alert_time': row['timestamp'],
                    'trade_date': trade_date,
                    'alert_type': atype,
                    'final_score': final_score,
                    'rule_score': rule_score,
                    'ml_score': ml_score,
                    'trigger_reason': trigger,
                    'triggered_rules': row['triggered_rules'],
                    'alpha': alpha,
                    'alpha_delta': row['alpha_delta'],
                    'amt_ratio': row['amt_ratio'],
                    'amt_delta': row['amt_delta'],
                    'rank_pct': row['rank_pct'],
                    'limit_up_ratio': row['limit_up_ratio'],
                    'stock_count': row['stock_count'],
                    'total_amt': row['total_amt'],
                })

        return alerts

    def run_once(self):
        """Run a single detection pass."""
        now = datetime.now()
        trade_date = get_current_trade_date()

        if not is_trading_time():
            print(f"[{now.strftime('%H:%M:%S')}] outside trading hours, skipping")
            return

        # Initialize
        self.init_for_trade_date(trade_date)

        print(f"[{now.strftime('%H:%M:%S')}] fetching new data...")

        # Incremental update
        n_updates = self.data_mgr.update(trade_date)
        print(f"  {n_updates} new time points")

        if n_updates == 0:
            print(f"  no new data")
            return

        # Detect
        alerts = detect_anomalies(
            self.ml_scorer,
            self.data_mgr,
            self.cooldown_mgr,
            trade_date,
            CONFIG
        )

        if alerts:
            saved = save_alerts_to_mysql(alerts)
            print(f"  detected {len(alerts)} anomalies, saved {saved}")

            for alert in alerts[:5]:
                print(f"    - {alert['concept_id']}: {alert['alert_type']} "
                      f"(final={alert['final_score']:.0f}, rule={alert['rule_score']:.0f}, ml={alert['ml_score']:.0f})")
        else:
            print(f"  no anomalies")

    def run_loop(self, backfill: bool = False):
        """
        Run continuously (live mode).

        Args:
            backfill: whether to backfill historical anomalies into the
                database (via prepare_data.py). Defaults to False because in
                live mode init_for_trade_date warms up the data automatically.
        """
        print("=" * 60)
        print("Real-time concept anomaly detection service (live mode)")
        print("=" * 60)
        print(f"Model dir: {self.checkpoint_dir}")
        print(f"Trading sessions: {TRADING_PERIODS}")
        print(f"ML sequence length: {CONFIG['seq_len']} minutes")
        print(f"Turnover MA window: {FEATURE_CONFIG['amt_ma_window']} minutes")
        print("=" * 60)

        # Initialize and warm up immediately (even outside trading hours,
        # which makes testing easier)
        trade_date = get_current_trade_date()
        print(f"\n[startup] initializing trade date {trade_date}...")
        self.init_for_trade_date(trade_date, preload_history=True)

        # Optional: backfill historical anomaly records into the database
        if backfill and is_trading_time():
            print("\n[startup] backfilling historical anomalies...")
            self.backfill_today()

        # Run at second :10 of every minute
        schedule.every().minute.at(":10").do(self.run_once)

        print("\nService started, waiting for the next minute...")

        while True:
            schedule.run_pending()
            time.sleep(1)


# ==================== Main ====================

def main():
    parser = argparse.ArgumentParser(description='Real-time concept anomaly detection')
    parser.add_argument('--checkpoint_dir', default='ml/checkpoints', help='model directory')
    parser.add_argument('--device', default='auto', help='device (auto/cpu/cuda)')
    parser.add_argument('--once', action='store_true', help='run a single detection pass')
    parser.add_argument('--backfill', action='store_true', help='backfill historical anomalies on startup')
    parser.add_argument('--backfill-only', action='store_true', help='only backfill history (do not keep running)')

    args = parser.parse_args()

    service = RealtimeDetectorService(
        checkpoint_dir=args.checkpoint_dir,
        device=args.device
    )

    if args.once:
        # Single detection pass
        service.run_once()
    elif args.backfill_only:
        # Backfill-only mode (requires prepare_data.py)
        service.backfill_today()
    else:
        # Continuous live mode (auto warm-up, no prepare_data.py dependency)
        service.run_loop(backfill=args.backfill)


if __name__ == "__main__":
    main()