Files
vf_react/concept_alert_ml.py
2025-12-09 08:31:18 +08:00

1626 lines
52 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
概念异动智能检测服务 - 基于 Z-Score + SVM
- Z-Score: 动态阈值,根据历史波动率判断异常
- SVM: 多特征分类,综合判断是否为有意义的异动
- 支持暴涨和暴跌检测
- 异动重要性评分,避免图表过于密集
"""
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from sqlalchemy import create_engine, text
from elasticsearch import Elasticsearch
from clickhouse_driver import Client
from collections import deque
import time
import logging
import json
import os
import hashlib
import argparse
import pickle
from typing import Dict, List, Optional, Tuple
# Try to import scikit-learn. When it is not installed, SKLEARN_AVAILABLE is
# set to False and a hint is printed; the SVM section of SMART_ALERT_CONFIG
# reads this flag to decide whether SVM checks are enabled.
try:
    from sklearn.svm import OneClassSVM
    from sklearn.preprocessing import StandardScaler
    from sklearn.ensemble import IsolationForest
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False
    print("警告: sklearn 未安装,将使用纯 Z-Score 方法")
    print("安装命令: pip install scikit-learn")
# ==================== 配置 ====================
# NOTE(review): the connection strings and passwords below are committed in
# source. Each value can now be overridden through an environment variable;
# the literals are kept only as backward-compatible fallbacks and should be
# rotated and removed once deployments supply the variables.

# MySQL configuration (override with CONCEPT_ALERT_MYSQL_URL).
MYSQL_ENGINE = create_engine(
    os.environ.get(
        'CONCEPT_ALERT_MYSQL_URL',
        "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock",
    ),
    echo=False
)
# Elasticsearch configuration (override with CONCEPT_ALERT_ES_URL).
ES_CLIENT = Elasticsearch([
    os.environ.get('CONCEPT_ALERT_ES_URL', 'http://222.128.1.157:19200')
])
INDEX_NAME = 'concept_library_v3'
# ClickHouse configuration (override with CONCEPT_ALERT_CH_* variables).
CLICKHOUSE_CONFIG = {
    'host': os.environ.get('CONCEPT_ALERT_CH_HOST', '222.128.1.157'),
    'port': int(os.environ.get('CONCEPT_ALERT_CH_PORT', '18000')),
    'user': os.environ.get('CONCEPT_ALERT_CH_USER', 'default'),
    'password': os.environ.get('CONCEPT_ALERT_CH_PASSWORD', 'Zzl33818!'),
    'database': os.environ.get('CONCEPT_ALERT_CH_DATABASE', 'stock')
}
# Concept hierarchy file (resolved relative to this module when loaded).
HIERARCHY_FILE = 'concept_hierarchy_v3.json'
# Directory where trained models are persisted; created at import time.
MODEL_DIR = 'models'
os.makedirs(MODEL_DIR, exist_ok=True)
# ==================== 智能异动检测配置 ====================
SMART_ALERT_CONFIG = {
    # Z-Score settings
    'zscore': {
        'enabled': True,
        'lookback_days': 20,       # number of days of history to look back
        'threshold_up': 2.5,       # surge-up threshold (in standard deviations)
        'threshold_down': -2.5,    # surge-down threshold (in standard deviations)
        'min_data_points': 10,     # minimum number of data points required
    },
    # SVM anomaly-detection settings
    'svm': {
        'enabled': SKLEARN_AVAILABLE,
        'nu': 0.05,                # expected anomaly fraction (5%)
        'kernel': 'rbf',           # kernel function
        'gamma': 'auto',
        'retrain_days': 7,         # retrain every N days
    },
    # Feature names used for the SVM
    'features': [
        'change_pct',              # current change percentage
        'change_delta_5min',       # change in change-pct over 5 minutes
        'change_delta_10min',      # change in change-pct over 10 minutes
        'rank_delta_5min',         # rank change over 5 minutes
        'limit_up_ratio',          # share of limit-up stocks
        'volume_ratio',            # volume ratio (reserved)
        'index_correlation',       # relation to the reference index
    ],
    # Weights of the importance score components
    'importance_weights': {
        'zscore_abs': 0.3,         # absolute Z-Score
        'rank_position': 0.2,      # rank position (higher rank = more important)
        'limit_up_count': 0.2,     # number of limit-up stocks
        'stock_count': 0.1,        # number of stocks in the concept
        'change_magnitude': 0.2,   # magnitude of the move
    },
    # Display throttling
    'display': {
        'max_alerts_per_hour': 20,     # maximum alerts shown per hour
        'min_importance_score': 0.3,   # minimum importance score
        'cooldown_minutes': 15,        # cooldown per concept
    },
}
# Reference index
REFERENCE_INDEX = '000001.SH'
# ==================== 日志配置 ====================
# Configure root logging at import time: INFO level, written both to a
# date-stamped UTF-8 log file in the current working directory and to the
# default stream handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f'concept_alert_ml_{datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
# ==================== 全局变量 ====================
# Lazily created ClickHouse client (see get_ch_client).
ch_client = None
# Cache of historical statistics per concept.
# Layout as written by load_historical_stats:
#   {concept_id: {'mean': float, 'std': float, 'count': int, 'loaded': bool}}
stats_cache: Dict[str, dict] = {}
# Per-concept cache of recent snapshots (used to compute rates of change).
minute_cache: Dict[str, deque] = {}
# Maximum number of snapshots kept per concept (used as the deque maxlen);
# presumably one snapshot per minute, i.e. about 15 minutes — TODO confirm.
MINUTE_WINDOW = 15
# Cooldown bookkeeping: (concept_id, alert_type) -> time of the last alert.
cooldown_cache: Dict[Tuple[str, str], datetime] = {}
# SVM model state
svm_model = None
svm_scaler = None
svm_last_train = None
def get_ch_client():
    """Return the shared ClickHouse client, creating it on first use."""
    global ch_client
    if ch_client is not None:
        return ch_client
    # First call: build the client from CLICKHOUSE_CONFIG and remember it.
    ch_client = Client(**CLICKHOUSE_CONFIG)
    return ch_client
def generate_id(name: str) -> str:
    """Derive a concept ID: the first 16 hex characters of the MD5 of *name* (UTF-8)."""
    digest = hashlib.md5(name.encode('utf-8'))
    hex_digest = digest.hexdigest()
    return hex_digest[:16]
def code_to_ch_format(code: str) -> Optional[str]:
    """Convert a 6-digit stock code to the ClickHouse ``CODE.EXCHANGE`` form.

    Args:
        code: Six-digit stock code, e.g. ``"600000"``.

    Returns:
        ``"<code>.SH"`` for codes starting with ``6``, ``"<code>.SZ"`` for codes
        starting with ``0`` or ``3``, ``"<code>.BJ"`` for every other 6-digit
        code, or ``None`` when *code* is not a 6-digit string.
    """
    # Reject non-strings explicitly so callers get None instead of a TypeError.
    if not isinstance(code, str) or len(code) != 6 or not code.isdigit():
        return None
    if code.startswith('6'):
        return f"{code}.SH"
    if code.startswith(('0', '3')):
        return f"{code}.SZ"
    return f"{code}.BJ"
# ==================== 概念数据获取(复用原有代码)====================
def get_all_concepts():
    """Fetch every leaf concept and its stock codes from Elasticsearch.

    Pages through INDEX_NAME with the scroll API (100 documents per page).
    Concepts that end up with no stock codes are skipped.

    Returns:
        A list of dicts with keys ``concept_id``, ``concept_name``, ``stocks``
        (list of stock codes) and ``concept_type`` (always ``'leaf'``).
    """
    concepts = []
    query = {
        "query": {"match_all": {}},
        "size": 100,
        "_source": ["concept_id", "concept", "stocks"]
    }
    resp = ES_CLIENT.search(index=INDEX_NAME, body=query, scroll='2m')
    scroll_id = resp['_scroll_id']
    try:
        hits = resp['hits']['hits']
        while hits:
            for hit in hits:
                source = hit['_source']
                concept_info = {
                    'concept_id': source.get('concept_id'),
                    'concept_name': source.get('concept'),
                    'stocks': [],
                    'concept_type': 'leaf'
                }
                if 'stocks' in source and isinstance(source['stocks'], list):
                    for stock in source['stocks']:
                        if isinstance(stock, dict) and 'code' in stock and stock['code']:
                            concept_info['stocks'].append(stock['code'])
                if concept_info['stocks']:
                    concepts.append(concept_info)
            resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
            scroll_id = resp['_scroll_id']
            hits = resp['hits']['hits']
    finally:
        # Release the server-side scroll context even when a page fails,
        # instead of leaving it to expire on its own.
        ES_CLIENT.clear_scroll(scroll_id=scroll_id)
    return concepts
def load_hierarchy_concepts(leaf_concepts: list) -> list:
    """Build parent (lv1/lv2/lv3) concepts from the hierarchy file.

    Each parent concept's stock list is the union of the stocks of the leaf
    concepts listed beneath it. Parents without any stock are not emitted.
    Returns an empty list (and logs a warning) when the hierarchy file is
    missing.
    """
    hierarchy_path = os.path.join(os.path.dirname(__file__), HIERARCHY_FILE)
    if not os.path.exists(hierarchy_path):
        logger.warning(f"层级文件不存在: {hierarchy_path}")
        return []
    with open(hierarchy_path, 'r', encoding='utf-8') as f:
        hierarchy_data = json.load(f)

    concept_to_stocks = {c['concept_name']: set(c['stocks']) for c in leaf_concepts}
    level_labels = {'lv1': '一级', 'lv2': '二级', 'lv3': '三级'}

    def _union_of(names):
        # Union of the stocks of every known leaf concept in *names*.
        merged = set()
        for name in names:
            if name in concept_to_stocks:
                merged.update(concept_to_stocks[name])
        return merged

    def _entry(level, name, members):
        # One parent-concept record for the given hierarchy level.
        return {
            'concept_id': generate_id(f"{level}_{name}"),
            'concept_name': f"[{level_labels[level]}] {name}",
            'stocks': list(members),
            'concept_type': level
        }

    parent_concepts = []
    for top in hierarchy_data.get('hierarchy', []):
        top_members = set()
        for mid in top.get('children', []):
            if 'children' in mid:
                # Three-level branch: collect lv3 nodes, then roll up into lv2.
                mid_members = set()
                for low in mid.get('children', []):
                    low_members = _union_of(low.get('concepts', []))
                    if not low_members:
                        continue
                    parent_concepts.append(_entry('lv3', low.get('lv3', ''), low_members))
                    mid_members.update(low_members)
            else:
                # Two-level branch: lv2 lists leaf concepts directly.
                mid_members = _union_of(mid.get('concepts', []))
            if not mid_members:
                continue
            parent_concepts.append(_entry('lv2', mid.get('lv2', ''), mid_members))
            top_members.update(mid_members)
        if top_members:
            parent_concepts.append(_entry('lv1', top.get('lv1', ''), top_members))
    return parent_concepts
# ==================== 价格数据获取 ====================
def get_base_prices(stock_codes: list, current_date: str) -> dict:
    """Fetch the base price (F002N) of each stock from the latest trade date before *current_date*.

    Args:
        stock_codes: Stock codes; entries that are not 6-digit strings are ignored.
        current_date: Date string; rows are taken from the most recent
            TRADEDATE strictly earlier than this value.

    Returns:
        ``{code: price}`` for every code with a positive price; an empty dict
        when there is nothing to query or the query fails (the error is logged).
    """
    if not stock_codes:
        return {}
    valid_codes = [code for code in stock_codes if code and len(code) == 6 and code.isdigit()]
    if not valid_codes:
        return {}
    # The codes were just validated as digit-only, so inlining them in the
    # IN (...) list is safe; the date is caller-supplied and is bound instead.
    stock_codes_str = "','".join(valid_codes)
    query = f"""
        SELECT SECCODE, F002N
        FROM ea_trade
        WHERE SECCODE IN ('{stock_codes_str}')
        AND TRADEDATE = (
            SELECT MAX(TRADEDATE)
            FROM ea_trade
            WHERE TRADEDATE < :current_date
        )
        AND F002N IS NOT NULL AND F002N > 0
    """
    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text(query), {'current_date': current_date})
            base_prices = {row[0]: float(row[1]) for row in result if row[1] and float(row[1]) > 0}
        return base_prices
    except Exception as e:
        logger.error(f"获取基准价格失败: {e}")
        return {}
def get_latest_prices(stock_codes: list) -> dict:
    """Read today's most recent minute bar for each stock from ClickHouse.

    Returns ``{code: {'close': float, 'timestamp': ts}}`` keyed by the plain
    6-digit code. Codes that cannot be converted, or whose close is missing
    or non-positive, are left out. Query errors are logged and yield ``{}``.
    """
    if not stock_codes:
        return {}
    client = get_ch_client()

    # Pair every convertible code with its ClickHouse form.
    converted = [(code_to_ch_format(code), code) for code in stock_codes]
    converted = [(ch_code, code) for ch_code, code in converted if ch_code]
    if not converted:
        return {}
    ch_codes = [ch_code for ch_code, _ in converted]
    code_mapping = dict(converted)

    ch_codes_str = "','".join(ch_codes)
    query = f"""
        SELECT code, close, timestamp
        FROM (
            SELECT code, close, timestamp,
            ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn
            FROM stock_minute
            WHERE code IN ('{ch_codes_str}')
            AND toDate(timestamp) = today()
        )
        WHERE rn = 1
    """
    try:
        rows = client.execute(query)
        if not rows:
            return {}
        latest_prices = {}
        for ch_code, close, ts in rows:
            if not (close and close > 0):
                continue
            pure_code = code_mapping.get(ch_code)
            if not pure_code:
                continue
            latest_prices[pure_code] = {'close': float(close), 'timestamp': ts}
        return latest_prices
    except Exception as e:
        logger.error(f"获取最新价格失败: {e}")
        return {}
def get_prices_at_time(stock_codes: list, timestamp: datetime) -> dict:
    """Read the minute bar stamped exactly *timestamp* for each stock.

    Returns ``{code: {'close': float, 'timestamp': ts}}`` keyed by the plain
    6-digit code; rows with a missing or non-positive close are skipped.
    Query errors are logged and yield ``{}``.
    """
    if not stock_codes:
        return {}
    client = get_ch_client()

    ch_codes = []
    code_mapping = {}
    for plain in stock_codes:
        formatted = code_to_ch_format(plain)
        if not formatted:
            continue
        ch_codes.append(formatted)
        code_mapping[formatted] = plain
    if not ch_codes:
        return {}

    ch_codes_str = "','".join(ch_codes)
    query = f"""
        SELECT code, close, timestamp
        FROM stock_minute
        WHERE code IN ('{ch_codes_str}')
        AND timestamp = '{timestamp.strftime('%Y-%m-%d %H:%M:%S')}'
    """
    try:
        prices = {}
        for ch_code, close, ts in client.execute(query):
            if not (close and close > 0):
                continue
            pure_code = code_mapping.get(ch_code)
            if pure_code:
                prices[pure_code] = {'close': float(close), 'timestamp': ts}
        return prices
    except Exception as e:
        logger.error(f"获取历史价格失败: {e}")
        return {}
def get_index_realtime(index_code: str = REFERENCE_INDEX) -> Optional[dict]:
    """Fetch today's latest minute bar of an index and its change vs. the previous close.

    Args:
        index_code: Index code as stored in ``index_minute`` (e.g. ``'000001.SH'``).

    Returns:
        A dict with ``code``, ``price``, ``prev_close``, ``change_pct`` and
        ``timestamp``; ``change_pct`` is ``None`` only when it cannot be
        computed. Returns ``None`` when there is no bar for today or the
        lookup fails (the error is logged).
    """
    client = get_ch_client()
    try:
        query = f"""
            SELECT close, timestamp
            FROM index_minute
            WHERE code = '{index_code}'
            AND toDate(timestamp) = today()
            ORDER BY timestamp DESC
            LIMIT 1
        """
        result = client.execute(query)
        if not result:
            return None
        close, ts = result[0]
        # Previous close comes from MySQL.
        prev_close = get_index_prev_close(index_code, datetime.now().strftime('%Y-%m-%d'))
        change_pct = None
        if close and prev_close and prev_close > 0:
            change_pct = (float(close) - prev_close) / prev_close * 100
        return {
            'code': index_code,
            'price': float(close),
            'prev_close': prev_close,
            # Test against None explicitly: a flat index (0.0%) is a valid
            # value and must not be reported as "unknown".
            'change_pct': round(change_pct, 4) if change_pct is not None else None,
            'timestamp': ts
        }
    except Exception as e:
        logger.error(f"获取指数数据失败: {e}")
        return None
def get_index_at_time(index_code: str, timestamp: datetime, prev_close: float) -> Optional[dict]:
    """Fetch the index minute bar stamped exactly *timestamp*.

    Args:
        index_code: Index code as stored in ``index_minute``.
        timestamp: Exact bar timestamp to look up.
        prev_close: Previous close used as the base for ``change_pct``.

    Returns:
        A dict with ``code``, ``price``, ``prev_close``, ``change_pct`` and
        ``timestamp``; ``change_pct`` is ``None`` only when it cannot be
        computed. Returns ``None`` when no bar exists or the query fails
        (the error is logged).
    """
    client = get_ch_client()
    query = f"""
        SELECT close, timestamp
        FROM index_minute
        WHERE code = '{index_code}'
        AND timestamp = '{timestamp.strftime('%Y-%m-%d %H:%M:%S')}'
        LIMIT 1
    """
    try:
        result = client.execute(query)
        if not result:
            return None
        close, ts = result[0]
        change_pct = None
        if close and prev_close and prev_close > 0:
            change_pct = (float(close) - prev_close) / prev_close * 100
        return {
            'code': index_code,
            'price': float(close),
            'prev_close': prev_close,
            # Test against None explicitly: a flat index (0.0%) is a valid
            # value and must not be reported as "unknown".
            'change_pct': round(change_pct, 4) if change_pct is not None else None,
            'timestamp': ts
        }
    except Exception as e:
        logger.error(f"获取指数数据失败: {e}")
        return None
def get_index_prev_close(index_code: str, trade_date: str) -> float:
    """Look up F006N of the latest ea_exchangetrade row before *trade_date*.

    The exchange suffix (everything after the first ``.``) is stripped from
    *index_code* before querying. Returns ``None`` when no row, or an empty
    value, is found.
    """
    code_no_suffix = index_code.split('.')[0]
    statement = text("""
        SELECT F006N FROM ea_exchangetrade
        WHERE INDEXCODE = :code
        AND TRADEDATE < :today
        ORDER BY TRADEDATE DESC LIMIT 1
    """)
    bind_values = {'code': code_no_suffix, 'today': trade_date}
    with MYSQL_ENGINE.connect() as conn:
        row = conn.execute(statement, bind_values).fetchone()
        if not row or not row[0]:
            return None
        return float(row[0])
# ==================== 涨跌幅计算 ====================
def calculate_change_pct(base_prices: dict, latest_prices: dict) -> dict:
    """Compute each stock's percentage change of its latest close over its base price.

    Only codes present in both inputs with a positive base price are
    returned, as ``{code: {'change_pct', 'close', 'base'}}`` with the
    percentage rounded to 4 decimals.
    """
    changes = {}
    for code, quote in latest_prices.items():
        if code not in base_prices:
            continue
        base = base_prices[code]
        if not base > 0:
            continue
        close = quote['close']
        pct = (close - base) / base * 100
        changes[code] = {'change_pct': round(pct, 4), 'close': close, 'base': base}
    return changes
def calculate_concept_stats(concepts: list, stock_changes: dict) -> list:
    """Aggregate per-stock changes into per-concept statistics and rank them.

    For every concept with at least one priced stock this yields the mean and
    standard deviation of member changes, the limit-up / limit-down members
    (change >= 9.8% / <= -9.8%) with their counts and ratios, and a 1-based
    ``rank`` by descending average change. Concepts without any priced stock
    are omitted.
    """
    stats = []
    for concept in concepts:
        # (code, change_pct) for every member that has a computed change.
        moves = [
            (code, stock_changes[code]['change_pct'])
            for code in concept['stocks']
            if code in stock_changes
        ]
        if not moves:
            continue
        changes = [pct for _, pct in moves]
        limit_up_stocks = [code for code, pct in moves if pct >= 9.8]
        limit_down_stocks = [code for code, pct in moves if pct <= -9.8]
        member_count = len(changes)

        if member_count > 1:
            spread = round(np.std(changes), 4)
        else:
            spread = 0

        stats.append({
            'concept_id': concept['concept_id'],
            'concept_name': concept['concept_name'],
            'avg_change_pct': round(np.mean(changes), 4),
            'std_change_pct': spread,
            'stock_count': member_count,
            'concept_type': concept.get('concept_type', 'leaf'),
            'limit_up_count': len(limit_up_stocks),
            'limit_down_count': len(limit_down_stocks),
            'limit_up_stocks': limit_up_stocks,
            'limit_down_stocks': limit_down_stocks,
            'limit_up_ratio': len(limit_up_stocks) / member_count,
            'limit_down_ratio': len(limit_down_stocks) / member_count,
        })

    # Rank by average change, best first.
    stats.sort(key=lambda item: item['avg_change_pct'], reverse=True)
    for position, item in enumerate(stats, start=1):
        item['rank'] = position
    return stats
# ==================== Z-Score 计算 ====================
def load_historical_stats(concept_id: str, lookback_days: int = 20) -> dict:
    """Load (and cache) the baseline mean/std used for a concept's Z-Score.

    The baseline is computed from the per-day average ``change_delta`` of
    ``concept_minute_alert`` rows with ``alert_type = 'surge'`` over the last
    *lookback_days* days. When fewer usable days than
    ``SMART_ALERT_CONFIG['zscore']['min_data_points']`` exist, or the query
    fails, default values (mean 0.5, std 0.8) are cached instead.

    Args:
        concept_id: Concept identifier.
        lookback_days: Number of calendar days to look back.

    Returns:
        The cached dict with keys ``mean``, ``std``, ``count`` and ``loaded``.
    """
    cached = stats_cache.get(concept_id)
    if cached and cached.get('loaded'):
        return cached

    def _default_stats() -> dict:
        # Fallback baseline: 0.5% average move, 0.8% standard deviation.
        return {'mean': 0.5, 'std': 0.8, 'count': 0, 'loaded': True}

    try:
        with MYSQL_ENGINE.connect() as conn:
            # Per-day statistics of historical alert deltas.
            result = conn.execute(text("""
                SELECT
                trade_date,
                AVG(change_delta) as avg_delta,
                MAX(change_delta) as max_delta,
                MIN(change_delta) as min_delta
                FROM concept_minute_alert
                WHERE concept_id = :concept_id
                AND trade_date >= DATE_SUB(CURDATE(), INTERVAL :days DAY)
                AND alert_type = 'surge'
                GROUP BY trade_date
            """), {'concept_id': concept_id, 'days': lookback_days})
            rows = list(result)
        # Drop only NULL averages (a 0.0 average is real data), and apply the
        # minimum-sample check to the values actually used, so an empty list
        # can never reach np.mean and poison the cache with NaN.
        deltas = [float(r[1]) for r in rows if r[1] is not None]
        if deltas and len(deltas) >= SMART_ALERT_CONFIG['zscore']['min_data_points']:
            stats_cache[concept_id] = {
                'mean': np.mean(deltas),
                'std': np.std(deltas) if len(deltas) > 1 else 1.0,
                'count': len(deltas),
                'loaded': True
            }
        else:
            # Not enough data: fall back to defaults.
            stats_cache[concept_id] = _default_stats()
    except Exception as e:
        logger.error(f"加载历史统计失败: {e}")
        stats_cache[concept_id] = _default_stats()
    return stats_cache[concept_id]
def calculate_zscore(concept_id: str, change_delta: float) -> float:
    """Return the Z-Score of *change_delta* against the concept's baseline.

    Z = (X - mean) / std, rounded to 4 decimals. The standard deviation is
    floored at 0.1 to avoid dividing by (near) zero.
    """
    baseline = load_historical_stats(concept_id)
    sigma = 0.1 if baseline['std'] < 0.1 else baseline['std']
    return round((change_delta - baseline['mean']) / sigma, 4)
def update_minute_cache(concept_id: str, timestamp: datetime, data: dict):
    """Append a snapshot for *concept_id* to its bounded history.

    The stored record is ``{'timestamp': timestamp, **data}``; each concept
    keeps at most MINUTE_WINDOW records, oldest dropped first.
    """
    try:
        bucket = minute_cache[concept_id]
    except KeyError:
        bucket = minute_cache[concept_id] = deque(maxlen=MINUTE_WINDOW)
    record = {'timestamp': timestamp}
    record.update(data)
    bucket.append(record)
def get_minute_history(concept_id: str, minutes_ago: int) -> Optional[dict]:
    """Return the newest cached snapshot at least *minutes_ago* minutes old.

    Age is measured against the timestamp of the concept's most recent
    snapshot, not the wall clock. Returns ``None`` when the concept has no
    history or no snapshot is old enough.
    """
    history = minute_cache.get(concept_id)
    if not history:
        return None
    cutoff = history[-1]['timestamp'] - timedelta(minutes=minutes_ago)
    # Walk newest -> oldest and stop at the first record old enough.
    return next(
        (snapshot for snapshot in reversed(history) if snapshot['timestamp'] <= cutoff),
        None,
    )
# ==================== SVM 异常检测 ====================
def extract_features(stat: dict, index_data: dict) -> np.ndarray:
    """Build the 7-element feature vector fed to the SVM for one concept.

    Order: current change, 5-minute change delta, 10-minute change delta,
    5-minute rank improvement, limit-up ratio in percent, a reserved volume
    slot (always 0), and a signed divergence from the index (positive when
    concept and index move in the same direction).
    """
    concept_id = stat['concept_id']
    snap_5 = get_minute_history(concept_id, 5)
    snap_10 = get_minute_history(concept_id, 10)
    current = stat['avg_change_pct']

    # Deltas default to 0 when no snapshot that old exists.
    delta_5 = (current - snap_5.get('change_pct', 0)) if snap_5 else 0
    delta_10 = (current - snap_10.get('change_pct', 0)) if snap_10 else 0
    rank_gain = (snap_5.get('rank', stat['rank']) - stat['rank']) if snap_5 else 0

    limit_up_pct = stat.get('limit_up_ratio', 0) * 100
    volume_ratio = 0  # reserved, not populated yet

    # Simplified "correlation": direction agreement times the size of the gap.
    index_change = index_data.get('change_pct') if index_data else None
    if index_change:
        direction = 1 if (index_change * current > 0) else -1
        index_relation = direction * abs(current - index_change)
    else:
        index_relation = 0

    return np.array([
        current,
        delta_5,
        delta_10,
        rank_gain,
        limit_up_pct,
        volume_ratio,
        index_relation,
    ])
def train_svm_model(training_data: List[np.ndarray]):
    """Train the One-Class SVM used to flag abnormal feature patterns.

    Fits a StandardScaler and a OneClassSVM (parameters from
    ``SMART_ALERT_CONFIG['svm']``), publishes both to the module globals and
    pickles them under MODEL_DIR.

    Args:
        training_data: Feature vectors; at least 100 are required.

    Returns:
        True when training and saving succeed; False when sklearn is
        unavailable, there are too few samples, or any step raises (logged).
    """
    global svm_model, svm_scaler, svm_last_train
    if not SKLEARN_AVAILABLE:
        return False
    if len(training_data) < 100:
        logger.warning(f"训练数据不足: {len(training_data)}")
        return False
    try:
        X = np.array(training_data)
        # Standardize features.
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
        # Fit the One-Class SVM.
        model = OneClassSVM(
            nu=SMART_ALERT_CONFIG['svm']['nu'],
            kernel=SMART_ALERT_CONFIG['svm']['kernel'],
            gamma=SMART_ALERT_CONFIG['svm']['gamma']
        )
        model.fit(X_scaled)
        # Publish only after both fits succeeded, so a failed run can never
        # leave a new scaler paired with the previous model.
        svm_scaler, svm_model = scaler, model
        svm_last_train = datetime.now()
        # Persist the model and scaler.
        model_path = os.path.join(MODEL_DIR, 'svm_model.pkl')
        scaler_path = os.path.join(MODEL_DIR, 'svm_scaler.pkl')
        with open(model_path, 'wb') as f:
            pickle.dump(svm_model, f)
        with open(scaler_path, 'wb') as f:
            pickle.dump(svm_scaler, f)
        logger.info(f"SVM模型训练完成使用 {len(training_data)} 条数据")
        return True
    except Exception as e:
        logger.error(f"SVM模型训练失败: {e}")
        return False
def load_svm_model():
    """Load the pickled SVM model and scaler from MODEL_DIR into the module globals.

    Returns:
        True when both files exist and load successfully; False when sklearn
        is unavailable, a file is missing, or loading fails (logged). On
        failure the previously loaded model and scaler are left untouched.
    """
    global svm_model, svm_scaler
    if not SKLEARN_AVAILABLE:
        return False
    model_path = os.path.join(MODEL_DIR, 'svm_model.pkl')
    scaler_path = os.path.join(MODEL_DIR, 'svm_scaler.pkl')
    if os.path.exists(model_path) and os.path.exists(scaler_path):
        try:
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # MODEL_DIR must only ever contain files written by this service.
            with open(model_path, 'rb') as f:
                model = pickle.load(f)
            with open(scaler_path, 'rb') as f:
                scaler = pickle.load(f)
        except Exception as e:
            logger.error(f"SVM模型加载失败: {e}")
            return False
        # Assign both together so a half-failed load cannot pair a fresh
        # model with a stale (or missing) scaler.
        svm_model, svm_scaler = model, scaler
        logger.info("SVM模型加载成功")
        return True
    return False
def predict_anomaly(features: np.ndarray) -> Tuple[bool, float]:
    """Score one feature vector with the loaded SVM.

    Returns ``(is_anomaly, score)`` where ``is_anomaly`` is true when the
    model predicts -1 and ``score`` is its decision-function value. When no
    model/scaler is loaded, or prediction fails (logged), returns
    ``(False, 0.0)``.
    """
    model, scaler = svm_model, svm_scaler
    if model is None or scaler is None:
        return False, 0.0
    try:
        sample = scaler.transform(features.reshape(1, -1))
        label = model.predict(sample)[0]
        distance = model.decision_function(sample)[0]
        # The model labels normal samples 1 and anomalies -1.
        return label == -1, float(distance)
    except Exception as e:
        logger.error(f"SVM预测失败: {e}")
        return False, 0.0
# ==================== 重要性评分 ====================
def calculate_importance_score(
    zscore: float,
    rank: int,
    limit_up_count: int,
    stock_count: int,
    change_pct: float,
    total_concepts: int
) -> float:
    """Combine several signals into a weighted importance score, rounded to 4 decimals.

    Each component is normalised to [0, 1] and weighted by
    ``SMART_ALERT_CONFIG['importance_weights']``:
    |Z| (5 sigma = full marks), rank (top of at most 100 = full marks),
    limit count (5 = full marks), concept size (10-100 stocks = full marks)
    and |change| (5% = full marks).
    """
    weights = SMART_ALERT_CONFIG['importance_weights']

    # Concept size: too small is penalised linearly, too large decays to 0.5.
    if stock_count < 10:
        size_component = stock_count / 10.0
    elif stock_count > 100:
        size_component = max(0.5, 1 - (stock_count - 100) / 200)
    else:
        size_component = 1.0

    scores = {
        'zscore_abs': min(abs(zscore) / 5.0, 1.0),
        'rank_position': max(0, 1 - (rank - 1) / min(100, total_concepts)),
        'limit_up_count': min(limit_up_count / 5.0, 1.0),
        'stock_count': size_component,
        'change_magnitude': min(abs(change_pct) / 5.0, 1.0),
    }

    # Weighted sum over the configured weights.
    total_score = sum(scores[k] * weights[k] for k in weights)
    return round(total_score, 4)
# ==================== 智能异动检测 ====================
def check_cooldown(concept_id: str, alert_type: str, now: Optional[datetime] = None) -> bool:
    """Tell whether an alert of this type for this concept is still cooling down.

    Args:
        concept_id: Concept identifier.
        alert_type: Alert type that was recorded with ``set_cooldown``.
        now: Reference time to measure the cooldown against. Defaults to the
            wall clock; pass the simulated time when replaying historical
            data, otherwise old alert times always look expired.

    Returns:
        True when the last recorded alert is less than
        ``SMART_ALERT_CONFIG['display']['cooldown_minutes']`` before *now*.
    """
    last_alert = cooldown_cache.get((concept_id, alert_type))
    if last_alert is None:
        return False
    reference = now if now is not None else datetime.now()
    window = timedelta(minutes=SMART_ALERT_CONFIG['display']['cooldown_minutes'])
    return reference - last_alert < window
def set_cooldown(concept_id: str, alert_type: str, alert_time: datetime = None):
    """Record when (concept_id, alert_type) last alerted; defaults to the current time."""
    stamp = alert_time if alert_time else datetime.now()
    cooldown_cache[(concept_id, alert_type)] = stamp
def _cooldown_active(concept_id: str, alert_type: str, now: datetime) -> bool:
    """Return True when (concept_id, alert_type) alerted within the cooldown window before *now*."""
    last_alert = cooldown_cache.get((concept_id, alert_type))
    if last_alert is None:
        return False
    window = timedelta(minutes=SMART_ALERT_CONFIG['display']['cooldown_minutes'])
    return now - last_alert < window


def detect_smart_alerts(
    current_stats: list,
    index_data: dict,
    trade_date: str,
    current_time: datetime = None
) -> list:
    """Detect concept surges by combining a Z-Score test with optional SVM confirmation.

    For every concept the current snapshot is cached, the change over the
    last 5 minutes is turned into a Z-Score, and an alert is produced when the
    score passes the configured threshold, the concept is not cooling down,
    the SVM (if loaded) agrees or |Z| >= 3.5, and the importance score is high
    enough.

    Args:
        current_stats: Output of ``calculate_concept_stats``.
        index_data: Reference index snapshot, or None.
        trade_date: Trade date stored on each alert.
        current_time: Time of this tick; defaults to the wall clock. It is
            used both to stamp alerts and to evaluate cooldowns, so replays
            of historical data throttle correctly.

    Returns:
        Alerts sorted by descending importance, truncated to
        ``SMART_ALERT_CONFIG['display']['max_alerts_per_hour']`` entries.
    """
    alerts = []
    current_time = current_time or datetime.now()
    total_concepts = len(current_stats)
    for stat in current_stats:
        concept_id = stat['concept_id']
        concept_name = stat['concept_name']
        change_pct = stat['avg_change_pct']
        rank = stat['rank']
        limit_up_count = stat['limit_up_count']
        limit_down_count = stat['limit_down_count']
        stock_count = stat['stock_count']
        concept_type = stat['concept_type']
        # Record this tick before looking back.
        update_minute_cache(concept_id, current_time, {
            'change_pct': change_pct,
            'rank': rank,
            'limit_up_count': limit_up_count,
            'limit_down_count': limit_down_count
        })
        # Need a snapshot at least 5 minutes old to measure the move.
        prev_5min = get_minute_history(concept_id, 5)
        if not prev_5min:
            continue
        change_delta = change_pct - prev_5min.get('change_pct', 0)
        # ========== Z-Score test ==========
        # The score is computed on the magnitude; direction comes from the sign.
        zscore = calculate_zscore(concept_id, abs(change_delta))
        is_surge_up = change_delta > 0 and zscore >= SMART_ALERT_CONFIG['zscore']['threshold_up']
        is_surge_down = change_delta < 0 and zscore >= abs(SMART_ALERT_CONFIG['zscore']['threshold_down'])
        if not (is_surge_up or is_surge_down):
            continue
        alert_type = 'surge_up' if is_surge_up else 'surge_down'
        # Cooldown is evaluated against this tick's time (the same clock
        # set_cooldown is fed below), not the wall clock.
        if _cooldown_active(concept_id, alert_type, current_time):
            continue
        # ========== SVM confirmation (optional) ==========
        svm_is_anomaly = False
        svm_score = 0.0
        if SMART_ALERT_CONFIG['svm']['enabled'] and svm_model is not None:
            features = extract_features(stat, index_data)
            svm_is_anomaly, svm_score = predict_anomaly(features)
            # When the SVM does not flag the sample, require a stronger Z-Score.
            if not svm_is_anomaly and abs(zscore) < 3.5:
                continue
        # ========== Importance score ==========
        importance = calculate_importance_score(
            zscore=zscore,
            rank=rank,
            limit_up_count=limit_up_count if is_surge_up else limit_down_count,
            stock_count=stock_count,
            change_pct=change_pct,
            total_concepts=total_concepts
        )
        # Drop low-importance alerts.
        if importance < SMART_ALERT_CONFIG['display']['min_importance_score']:
            continue
        # ========== Build the alert record ==========
        alert = {
            'concept_id': concept_id,
            'concept_name': concept_name,
            'alert_type': alert_type,
            'alert_time': current_time,
            'trade_date': trade_date,
            'change_pct': change_pct,
            'prev_change_pct': prev_5min.get('change_pct'),
            'change_delta': round(change_delta, 4),
            'zscore': zscore,
            'svm_score': svm_score,
            'importance_score': importance,
            'limit_up_count': limit_up_count,
            'limit_down_count': limit_down_count,
            'prev_limit_up_count': prev_5min.get('limit_up_count', 0),
            'rank_position': rank,
            'prev_rank_position': prev_5min.get('rank'),
            'rank_delta': (prev_5min.get('rank', rank) - rank) if prev_5min else 0,
            'stock_count': stock_count,
            'concept_type': concept_type,
            'index_code': REFERENCE_INDEX,
            'index_price': index_data['price'] if index_data else None,
            'index_change_pct': index_data['change_pct'] if index_data else None,
            'extra_info': {
                'limit_up_stocks': stat.get('limit_up_stocks', []),
                'limit_down_stocks': stat.get('limit_down_stocks', []),
            }
        }
        alerts.append(alert)
        set_cooldown(concept_id, alert_type, current_time)
        # Log the alert.
        direction = "🔥 暴涨" if is_surge_up else "💧 暴跌"
        logger.info(
            f"{direction}: {concept_name} "
            f"涨幅 {prev_5min.get('change_pct', 0):.2f}% -> {change_pct:.2f}% "
            f"{change_delta:+.2f}%, Z={zscore:.2f}, 重要性={importance:.2f})"
        )
    # Most important first, capped at the configured maximum.
    alerts.sort(key=lambda x: x['importance_score'], reverse=True)
    max_alerts = SMART_ALERT_CONFIG['display']['max_alerts_per_hour']
    return alerts[:max_alerts]
# ==================== 数据持久化 ====================
def save_alerts_to_mysql(alerts: list):
    """Insert alerts into ``concept_minute_alert`` and return how many were saved.

    All inserts run inside one ``MYSQL_ENGINE.begin()`` block; a failing
    insert is logged and skipped. The Z-Score, SVM score and importance score
    are folded into the JSON ``extra_info`` column when the alert carries
    ``extra_info``.
    """
    if not alerts:
        return 0

    insert_sql = text("""
        INSERT INTO concept_minute_alert
        (concept_id, concept_name, alert_time, alert_type, trade_date,
        change_pct, prev_change_pct, change_delta,
        limit_up_count, prev_limit_up_count, limit_up_delta,
        rank_position, prev_rank_position, rank_delta,
        index_code, index_price, index_change_pct,
        stock_count, concept_type, extra_info)
        VALUES
        (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
        :change_pct, :prev_change_pct, :change_delta,
        :limit_up_count, :prev_limit_up_count, :limit_up_delta,
        :rank_position, :prev_rank_position, :rank_delta,
        :index_code, :index_price, :index_change_pct,
        :stock_count, :concept_type, :extra_info)
    """)

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                now_limit_up = alert.get('limit_up_count', 0)
                before_limit_up = alert.get('prev_limit_up_count', 0)

                # Scores travel inside the JSON extra_info column.
                extra = alert.get('extra_info')
                if extra:
                    payload = dict(extra)
                    payload['zscore'] = alert.get('zscore')
                    payload['svm_score'] = alert.get('svm_score')
                    payload['importance_score'] = alert.get('importance_score')
                    extra_json = json.dumps(payload, ensure_ascii=False)
                else:
                    extra_json = None

                params = {
                    'concept_id': alert['concept_id'],
                    'concept_name': alert['concept_name'],
                    'alert_time': alert['alert_time'],
                    'alert_type': alert['alert_type'],
                    'trade_date': alert['trade_date'],
                    'change_pct': alert.get('change_pct'),
                    'prev_change_pct': alert.get('prev_change_pct'),
                    'change_delta': alert.get('change_delta'),
                    'limit_up_count': now_limit_up,
                    'prev_limit_up_count': before_limit_up,
                    'limit_up_delta': now_limit_up - before_limit_up,
                    'rank_position': alert.get('rank_position'),
                    'prev_rank_position': alert.get('prev_rank_position'),
                    'rank_delta': alert.get('rank_delta'),
                    'index_code': alert.get('index_code', REFERENCE_INDEX),
                    'index_price': alert.get('index_price'),
                    'index_change_pct': alert.get('index_change_pct'),
                    'stock_count': alert.get('stock_count'),
                    'concept_type': alert.get('concept_type', 'leaf'),
                    'extra_info': extra_json,
                }
                conn.execute(insert_sql, params)
                saved += 1
            except Exception as e:
                logger.error(f"保存异动失败: {alert['concept_name']} - {e}")
    return saved
def save_index_snapshot(index_data: dict, trade_date: str):
    """Upsert one index snapshot into ``index_minute_snapshot`` (via REPLACE INTO).

    Does nothing when *index_data* is empty; database errors are logged and
    swallowed.
    """
    if not index_data:
        return
    upsert_sql = text("""
        REPLACE INTO index_minute_snapshot
        (index_code, trade_date, snapshot_time, price, prev_close, change_pct)
        VALUES (:index_code, :trade_date, :snapshot_time, :price, :prev_close, :change_pct)
    """)
    try:
        with MYSQL_ENGINE.begin() as conn:
            row = {
                'index_code': index_data['code'],
                'trade_date': trade_date,
                'snapshot_time': index_data['timestamp'],
                'price': index_data['price'],
                'prev_close': index_data.get('prev_close'),
                'change_pct': index_data.get('change_pct'),
            }
            conn.execute(upsert_sql, row)
    except Exception as e:
        logger.error(f"保存指数快照失败: {e}")
# ==================== 交易时间判断 ====================
def is_trading_time() -> bool:
    """Tell whether the local clock is inside a weekday trading session.

    Sessions are 09:30-11:30 and 13:00-15:00 (both ends inclusive), Monday to
    Friday. Holidays are not considered.
    """
    now = datetime.now()
    if now.weekday() >= 5:
        return False
    minute_of_day = now.hour * 60 + now.minute
    sessions = (
        (9 * 60 + 30, 11 * 60 + 30),   # morning
        (13 * 60, 15 * 60),            # afternoon
    )
    return any(start <= minute_of_day <= end for start, end in sessions)
def get_next_update_time() -> int:
    """Return how many seconds to wait before the next detection run.

    During trading hours this is the time left until the next full minute.
    Otherwise it is the time until the next session start implied by the
    clock (09:30, 13:00, or 09:30 of the following day), never less than 60.
    """
    now = datetime.now()
    if is_trading_time():
        return 60 - now.second

    clock = (now.hour, now.minute)
    if clock < (9, 30):
        # Before the morning open.
        target = now.replace(hour=9, minute=30, second=0, microsecond=0)
    elif (11, 30) <= clock < (13, 0):
        # Lunch break.
        target = now.replace(hour=13, minute=0, second=0, microsecond=0)
    elif now.hour >= 15:
        # After the close: next calendar day's open.
        tomorrow = now + timedelta(days=1)
        target = tomorrow.replace(hour=9, minute=30, second=0, microsecond=0)
    else:
        target = now + timedelta(minutes=1)

    remaining = int((target - now).total_seconds())
    return max(60, remaining)
# ==================== 主运行逻辑 ====================
def run_once(concepts: list, all_stocks: list) -> tuple:
    """Run one detection pass and return ``(concept_count, alert_count)``.

    Steps: base prices -> latest prices -> index snapshot -> per-stock change
    -> per-concept stats -> alert detection -> persistence. Returns ``(0, 0)``
    as soon as a required price input is unavailable.
    """
    now = datetime.now()
    trade_date = now.strftime('%Y-%m-%d')
    nothing_done = (0, 0)

    # Base prices.
    base_prices = get_base_prices(all_stocks, trade_date)
    if not base_prices:
        logger.warning("无法获取基准价格")
        return nothing_done

    # Latest prices.
    latest_prices = get_latest_prices(all_stocks)
    if not latest_prices:
        logger.warning("无法获取最新价格")
        return nothing_done

    # Reference index; snapshot it whenever it is available.
    index_data = get_index_realtime(REFERENCE_INDEX)
    if index_data:
        save_index_snapshot(index_data, trade_date)

    # Per-stock change.
    stock_changes = calculate_change_pct(base_prices, latest_prices)
    if not stock_changes:
        logger.warning("无涨跌幅数据")
        return nothing_done
    logger.info(f"获取到 {len(stock_changes)} 只股票的涨跌幅")

    # Per-concept statistics.
    stats = calculate_concept_stats(concepts, stock_changes)
    logger.info(f"计算了 {len(stats)} 个概念的涨跌幅")

    # Detect and persist alerts.
    alerts = detect_smart_alerts(stats, index_data, trade_date, now)
    if alerts:
        saved = save_alerts_to_mysql(alerts)
        logger.info(f"💾 保存了 {saved} 条异动记录")
    return len(stats), len(alerts)
def run_realtime():
    """Main loop of the realtime detection service.

    Loads the SVM model (if sklearn is available) and the concept universe,
    then runs one detection pass per minute during trading hours. Concepts
    are reloaded hourly. Outside trading hours the loop sleeps (at most 300
    seconds at a time). Ctrl-C stops the loop; any other error is logged and
    retried after 60 seconds.
    """
    separator = "=" * 60
    logger.info(separator)
    logger.info("🚀 启动智能概念异动检测服务 (Z-Score + SVM)")
    logger.info(separator)
    logger.info(f"配置: {json.dumps(SMART_ALERT_CONFIG, indent=2, ensure_ascii=False, default=str)}")

    # Try to load a previously trained SVM model.
    if SKLEARN_AVAILABLE:
        load_svm_model()

    # Load the concept universe.
    logger.info("加载概念数据...")
    leaf_concepts = get_all_concepts()
    logger.info(f"获取到 {len(leaf_concepts)} 个叶子概念")
    parent_concepts = load_hierarchy_concepts(leaf_concepts)
    logger.info(f"生成了 {len(parent_concepts)} 个母概念")
    all_concepts = leaf_concepts + parent_concepts
    logger.info(f"总计 {len(all_concepts)} 个概念")

    # Collect every stock code referenced by any concept.
    stock_pool = set()
    for concept in all_concepts:
        stock_pool.update(concept['stocks'])
    all_stocks = list(stock_pool)
    logger.info(f"监控 {len(all_stocks)} 只股票")

    last_concept_update = datetime.now()
    total_alerts = 0
    reload_interval = timedelta(seconds=3600)

    while True:
        try:
            now = datetime.now()

            # Reload the concept universe once an hour.
            if now - last_concept_update > reload_interval:
                logger.info("重新加载概念数据...")
                leaf_concepts = get_all_concepts()
                parent_concepts = load_hierarchy_concepts(leaf_concepts)
                all_concepts = leaf_concepts + parent_concepts
                stock_pool = set()
                for concept in all_concepts:
                    stock_pool.update(concept['stocks'])
                all_stocks = list(stock_pool)
                last_concept_update = now
                logger.info(f"更新完成: {len(all_concepts)} 个概念, {len(all_stocks)} 只股票")

            # Outside trading hours: wait and check again.
            if not is_trading_time():
                wait_sec = get_next_update_time()
                wait_min = wait_sec // 60
                logger.info(f"⏰ 非交易时间,等待 {wait_min} 分钟后重试...")
                time.sleep(min(wait_sec, 300))
                continue

            # Run one detection pass.
            logger.info(f"\n{'=' * 40}")
            logger.info(f"🔍 检测时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
            _, alert_count = run_once(all_concepts, all_stocks)
            total_alerts += alert_count
            if alert_count > 0:
                logger.info(f"📊 本次检测到 {alert_count} 条异动,累计 {total_alerts}")

            # Sleep until the next minute boundary.
            sleep_sec = 60 - datetime.now().second
            logger.info(f"⏳ 等待 {sleep_sec} 秒后继续...")
            time.sleep(sleep_sec)
        except KeyboardInterrupt:
            logger.info("\n收到退出信号,停止服务...")
            break
        except Exception as e:
            logger.error(f"发生错误: {e}")
            import traceback
            traceback.print_exc()
            time.sleep(60)
def run_single():
    """Run exactly one detection pass (no loop) and log the outcome."""
    logger.info("单次检测模式")
    if SKLEARN_AVAILABLE:
        load_svm_model()

    leaf_concepts = get_all_concepts()
    all_concepts = leaf_concepts + load_hierarchy_concepts(leaf_concepts)

    # Collect every stock code referenced by any concept.
    stock_pool = set()
    for concept in all_concepts:
        stock_pool.update(concept['stocks'])
    all_stocks = list(stock_pool)
    logger.info(f"概念数: {len(all_concepts)}, 股票数: {len(all_stocks)}")

    updated, alerts = run_once(all_concepts, all_stocks)
    logger.info(f"检测完成: {updated} 个概念, {alerts} 条异动")
# ==================== 回测功能 ====================
def get_minute_timestamps(trade_date: str) -> list:
    """Return every distinct ``stock_minute`` timestamp on *trade_date*, ascending.

    Args:
        trade_date: Date string compared against ``toDate(timestamp)``.

    Returns:
        List of timestamps (first column of each result row).
    """
    client = get_ch_client()
    # trade_date is caller-supplied, so let the driver quote and escape it
    # instead of formatting it into the SQL text.
    query = """
        SELECT DISTINCT timestamp
        FROM stock_minute
        WHERE toDate(timestamp) = %(trade_date)s
        ORDER BY timestamp
    """
    result = client.execute(query, {'trade_date': trade_date})
    return [row[0] for row in result]
def run_backtest(trade_date: str, clear_existing: bool = True):
    """Replay smart alert detection minute by minute for one trading day.

    The module-level ``minute_cache``, ``cooldown_cache`` and ``stats_cache``
    are reset to empty dicts, concept and price inputs are loaded, and then
    every minute timestamp of ``trade_date`` is processed in order: prices at
    that minute are turned into change percentages, concept statistics are
    computed, alerts are detected and saved, and the index snapshot is stored
    when available.

    Args:
        trade_date: Trading day to replay.
        clear_existing: When true, rows for ``trade_date`` are deleted from
            ``concept_minute_alert`` and ``index_minute_snapshot`` before the
            replay starts. The deletion only happens once base prices and
            minute timestamps have been found, so a run that cannot proceed
            leaves previously stored results untouched.

    Returns:
        None. An error is logged and the function returns early when no base
        prices or no minute timestamps are available.
    """
    global minute_cache, cooldown_cache, stats_cache

    logger.info("=" * 60)
    logger.info(f"🔄 开始智能回测: {trade_date}")
    logger.info("=" * 60)

    # Reset the module-level caches so state from an earlier run is not reused.
    minute_cache = {}
    cooldown_cache = {}
    stats_cache = {}

    # Load concept data.
    logger.info("加载概念数据...")
    leaf_concepts = get_all_concepts()
    logger.info(f"获取到 {len(leaf_concepts)} 个叶子概念")

    parent_concepts = load_hierarchy_concepts(leaf_concepts)
    logger.info(f"生成了 {len(parent_concepts)} 个母概念")

    all_concepts = leaf_concepts + parent_concepts
    logger.info(f"总计 {len(all_concepts)} 个概念")

    # Collect every stock code referenced by any concept (de-duplicated).
    all_stocks = set()
    for c in all_concepts:
        all_stocks.update(c['stocks'])
    all_stocks = list(all_stocks)
    logger.info(f"监控 {len(all_stocks)} 只股票")

    # Baseline prices (previous close) used to compute change percentages.
    base_prices = get_base_prices(all_stocks, trade_date)
    if not base_prices:
        logger.error("无法获取基准价格,退出回测")
        return
    logger.info(f"获取到 {len(base_prices)} 个基准价格")

    # Previous close of the reference index.
    index_prev_close = get_index_prev_close(REFERENCE_INDEX, trade_date)
    logger.info(f"指数昨收价: {index_prev_close}")

    # All minute timestamps available for the day.
    timestamps = get_minute_timestamps(trade_date)
    if not timestamps:
        logger.error(f"未找到 {trade_date} 的分钟数据")
        return
    logger.info(f"找到 {len(timestamps)} 个分钟时间点")

    # Wipe previously stored results only now that the replay is known to be
    # feasible; deleting earlier would leave the day empty whenever one of the
    # checks above aborts the run.
    if clear_existing:
        with MYSQL_ENGINE.begin() as conn:
            conn.execute(text("DELETE FROM concept_minute_alert WHERE trade_date = :date"), {'date': trade_date})
            conn.execute(text("DELETE FROM index_minute_snapshot WHERE trade_date = :date"), {'date': trade_date})
        logger.info(f"已清除 {trade_date} 的已有数据")

    total_alerts = 0
    processed = 0

    # Process the day one minute at a time.
    for ts in timestamps:
        # Counted even when the minute is skipped below, so the progress
        # percentage tracks position in the timestamp list.
        processed += 1

        # Prices as of this minute; skip the minute when none are returned.
        latest_prices = get_prices_at_time(all_stocks, ts)
        if not latest_prices:
            continue

        # Index data for this minute; stored when present and passed on to
        # detection either way.
        index_data = get_index_at_time(REFERENCE_INDEX, ts, index_prev_close)
        if index_data:
            save_index_snapshot(index_data, trade_date)

        # Change percentages relative to the baseline prices.
        stock_changes = calculate_change_pct(base_prices, latest_prices)
        if not stock_changes:
            continue

        # Per-concept statistics.
        stats = calculate_concept_stats(all_concepts, stock_changes)

        # Smart alert detection.
        alerts = detect_smart_alerts(stats, index_data, trade_date, ts)

        # Persist alerts and accumulate the number actually saved.
        if alerts:
            saved = save_alerts_to_mysql(alerts)
            total_alerts += saved

        # Progress report every 30 processed timestamps.
        if processed % 30 == 0:
            logger.info(f"进度: {processed}/{len(timestamps)} ({processed*100//len(timestamps)}%), 已检测到 {total_alerts} 条异动")

    logger.info("=" * 60)
    logger.info("✅ 回测完成!")
    logger.info(f" 处理分钟数: {processed}")
    logger.info(f" 检测到异动: {total_alerts}")
    logger.info("=" * 60)
def show_status():
    """Print a status report for the alert service to stdout.

    The report contains the current time, the trading-time and sklearn
    availability lines, whether the SVM model file exists under
    ``MODEL_DIR``, today's alert counts grouped by type, and the ten most
    recent alerts of the day. Any exception raised while querying MySQL is
    caught and reported as a single "query failed" line.
    """
    print("\n" + "=" * 60)
    print("智能概念异动检测服务 (Z-Score + SVM) - 状态")
    print("=" * 60)

    now = datetime.now()
    print(f"\n当前时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
    # NOTE(review): both branches of the next two conditionals are empty
    # strings in this copy of the file; the yes/no markers were probably lost
    # to an encoding problem - confirm against the original source.
    print(f"是否交易时间: {'' if is_trading_time() else ''}")
    print(f"sklearn可用: {'' if SKLEARN_AVAILABLE else ''}")

    # Model status: only checks that the pickle file exists on disk.
    model_path = os.path.join(MODEL_DIR, 'svm_model.pkl')
    print(f"SVM模型: {'已加载' if os.path.exists(model_path) else '未训练'}")

    # Display names for alert types in the summary; unknown types are shown
    # unchanged.
    summary_type_names = {
        'surge_up': '暴涨',
        'surge_down': '暴跌',
        'surge': '急涨(旧)',
        'limit_up': '涨停增加',
        'rank_jump': '排名跃升'
    }
    # Shorter mapping for the detail table; other types are cut to 4 chars.
    detail_type_names = {'surge_up': '暴涨', 'surge_down': '暴跌'}

    # Today's alert statistics.
    print("\n今日异动统计:")
    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text("""
                SELECT alert_type, COUNT(*) as cnt, AVG(change_delta) as avg_delta
                FROM concept_minute_alert
                WHERE trade_date = CURDATE()
                GROUP BY alert_type
            """))
            rows = list(result)

            if rows:
                for row in rows:
                    alert_type_name = summary_type_names.get(row[0], row[0])
                    # Compare with None so an average of exactly 0 is still
                    # printed as a number instead of "-".
                    avg_delta = f"{row[2]:.2f}%" if row[2] is not None else "-"
                    print(f" {alert_type_name}: {row[1]} 条 (平均变化: {avg_delta})")
            else:
                print(" 今日暂无异动")

            # Latest alerts.
            print("\n最新异动 (前10条):")
            result = conn.execute(text("""
                SELECT concept_name, alert_type, alert_time, change_pct, change_delta, extra_info
                FROM concept_minute_alert
                WHERE trade_date = CURDATE()
                ORDER BY alert_time DESC
                LIMIT 10
            """))
            rows = list(result)

            if rows:
                print(f" {'概念':<20} | {'类型':<6} | {'时间':<8} | {'涨幅':>6} | {'变化':>6} | {'Z分':>5}")
                print(" " + "-" * 70)
                for row in rows:
                    # Slicing already leaves names of 18 chars or fewer intact.
                    name = row[0][:18]
                    alert_type = detail_type_names.get(row[1], row[1][:4])
                    time_str = row[2].strftime('%H:%M') if row[2] is not None else '-'
                    # "-" is reserved for NULL; a real 0 is formatted normally.
                    change = f"{row[3]:.2f}%" if row[3] is not None else '-'
                    delta = f"{row[4]:+.2f}%" if row[4] is not None else '-'

                    # Pull the z-score out of extra_info (JSON text or an
                    # already-decoded object).
                    zscore = '-'
                    if row[5]:
                        try:
                            extra = json.loads(row[5]) if isinstance(row[5], str) else row[5]
                            zscore = f"{extra.get('zscore', 0):.1f}"
                        except (ValueError, TypeError, AttributeError):
                            # Malformed or unexpected extra_info: keep the
                            # placeholder. Narrow on purpose so Ctrl-C and
                            # SystemExit are not swallowed here.
                            pass

                    print(f" {name:<20} | {alert_type:<6} | {time_str:<8} | {change:>6} | {delta:>6} | {zscore:>5}")
            else:
                print(" 暂无异动记录")

    except Exception as e:
        print(f" 查询失败: {e}")
def train_model():
    """Train the SVM model from alerts recorded over the last 30 days.

    Returns immediately with an installation hint when scikit-learn is not
    available. Otherwise rows of ``concept_minute_alert`` from the lookback
    window are converted into 7-value feature lists and, when at least 100
    rows were collected, passed to ``train_svm_model``. Any exception raised
    while querying or training is logged together with its traceback and is
    not propagated.
    """
    if not SKLEARN_AVAILABLE:
        print("错误: sklearn未安装无法训练模型")
        print("安装命令: pip install scikit-learn")
        return

    logger.info("=" * 60)
    logger.info("🎓 开始训练SVM模型")
    logger.info("=" * 60)

    # Load concept data. In this function it is only used for the summary
    # log line below.
    leaf_concepts = get_all_concepts()
    parent_concepts = load_hierarchy_concepts(leaf_concepts)
    all_concepts = leaf_concepts + parent_concepts

    all_stocks = set()
    for c in all_concepts:
        all_stocks.update(c['stocks'])
    all_stocks = list(all_stocks)

    logger.info(f"概念数: {len(all_concepts)}, 股票数: {len(all_stocks)}")

    # Training samples are derived from historical alert rows.
    training_features = []

    # Number of past days of alerts to draw features from.
    lookback_days = 30

    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text("""
                SELECT change_pct, change_delta, rank_position, limit_up_count,
                       stock_count, index_change_pct, extra_info
                FROM concept_minute_alert
                WHERE trade_date >= DATE_SUB(CURDATE(), INTERVAL :days DAY)
            """), {'days': lookback_days})

            for row in result:
                features = [
                    float(row[0]) if row[0] else 0,  # change_pct
                    float(row[1]) if row[1] else 0,  # change_delta
                    0,  # change_delta_10min (not available)
                    -float(row[2]) if row[2] else 0,  # rank_delta (approximation)
                    float(row[3]) / max(1, float(row[4])) * 100 if row[3] and row[4] else 0,  # limit_up_ratio
                    0,  # volume_ratio
                    # index_correlation: concept change minus index change.
                    # Tested against None rather than truthiness so that a
                    # value of exactly 0 on either side still yields the
                    # real difference; only NULL columns fall back to 0.
                    float(row[0]) - float(row[5]) if row[0] is not None and row[5] is not None else 0,
                ]
                training_features.append(features)

        logger.info(f"收集到 {len(training_features)} 条训练数据")

        if len(training_features) >= 100:
            success = train_svm_model(training_features)
            if success:
                logger.info("✅ 模型训练成功!")
            else:
                logger.error("❌ 模型训练失败")
        else:
            logger.warning("训练数据不足100条跳过训练")

    except Exception as e:
        logger.error(f"训练失败: {e}")
        import traceback
        traceback.print_exc()
# ==================== 主函数 ====================
def main():
    """Parse the command line and run the selected command.

    The positional ``command`` defaults to ``realtime``. ``--date`` and
    ``--keep`` are only consulted by the ``backtest`` command: the date falls
    back to today's date, and existing rows are cleared unless ``--keep`` is
    given.
    """
    parser = argparse.ArgumentParser(description='智能概念异动检测服务 (Z-Score + SVM)')
    parser.add_argument(
        'command',
        nargs='?',
        default='realtime',
        choices=['realtime', 'once', 'status', 'backtest', 'train'],
        help='命令: realtime(实时运行), once(单次运行), status(状态), backtest(回测), train(训练模型)',
    )
    parser.add_argument(
        '--date', '-d',
        type=str,
        default=None,
        help='回测日期,格式: YYYY-MM-DD默认为今天',
    )
    parser.add_argument(
        '--keep', '-k',
        action='store_true',
        help='回测时保留已有数据(默认会清除)',
    )
    args = parser.parse_args()

    # Backtest is the only command that takes arguments, so handle it first.
    if args.command == 'backtest':
        day = args.date if args.date else datetime.now().strftime('%Y-%m-%d')
        run_backtest(day, not args.keep)
        return

    # Every other command maps to a zero-argument entry point.
    handlers = {
        'realtime': run_realtime,
        'once': run_single,
        'status': show_status,
        'train': train_model,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        handler()
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()