1626 lines
52 KiB
Python
1626 lines
52 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
概念异动智能检测服务 - 基于 Z-Score + SVM
|
||
- Z-Score: 动态阈值,根据历史波动率判断异常
|
||
- SVM: 多特征分类,综合判断是否为有意义的异动
|
||
- 支持暴涨和暴跌检测
|
||
- 异动重要性评分,避免图表过于密集
|
||
"""
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
from datetime import datetime, timedelta
|
||
from sqlalchemy import create_engine, text
|
||
from elasticsearch import Elasticsearch
|
||
from clickhouse_driver import Client
|
||
from collections import deque
|
||
import time
|
||
import logging
|
||
import json
|
||
import os
|
||
import hashlib
|
||
import argparse
|
||
import pickle
|
||
from typing import Dict, List, Optional, Tuple
|
||
|
||
# 尝试导入sklearn,如果不存在则提示安装
|
||
try:
    # scikit-learn is optional; these names only exist when it is installed.
    from sklearn.svm import OneClassSVM
    from sklearn.preprocessing import StandardScaler
    from sklearn.ensemble import IsolationForest
except ImportError:
    # Flag consulted by the SVM-related code to skip model training/loading.
    SKLEARN_AVAILABLE = False
    print("警告: sklearn 未安装,将使用纯 Z-Score 方法")
    print("安装命令: pip install scikit-learn")
else:
    SKLEARN_AVAILABLE = True
|
||
|
||
# ==================== 配置 ====================
|
||
|
||
# MySQL配置
|
||
# NOTE(security): connection credentials used to be hard-coded only. Every
# setting below can now be overridden through a CONCEPT_ALERT_* environment
# variable; the literals remain purely as backward-compatible fallbacks and
# should be rotated and removed once deployments provide the variables.
MYSQL_ENGINE = create_engine(
    os.environ.get(
        'CONCEPT_ALERT_MYSQL_URL',
        "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock",
    ),
    echo=False
)

# Elasticsearch client and the index holding the concept library
ES_CLIENT = Elasticsearch([
    os.environ.get('CONCEPT_ALERT_ES_URL', 'http://222.128.1.157:19200')
])
INDEX_NAME = 'concept_library_v3'

# ClickHouse connection arguments (passed to Client(**CLICKHOUSE_CONFIG))
CLICKHOUSE_CONFIG = {
    'host': os.environ.get('CONCEPT_ALERT_CH_HOST', '222.128.1.157'),
    'port': int(os.environ.get('CONCEPT_ALERT_CH_PORT', 18000)),
    'user': os.environ.get('CONCEPT_ALERT_CH_USER', 'default'),
    'password': os.environ.get('CONCEPT_ALERT_CH_PASSWORD', 'Zzl33818!'),
    'database': os.environ.get('CONCEPT_ALERT_CH_DATABASE', 'stock')
}

# Concept hierarchy file (resolved relative to this module by its loader)
HIERARCHY_FILE = 'concept_hierarchy_v3.json'

# Directory where trained models are persisted; created at import time
MODEL_DIR = 'models'
os.makedirs(MODEL_DIR, exist_ok=True)
|
||
|
||
# ==================== 智能异动检测配置 ====================
|
||
|
||
# Tunables for the smart alert pipeline, grouped by stage.
SMART_ALERT_CONFIG = dict(
    # Z-Score stage
    zscore=dict(
        enabled=True,
        lookback_days=20,        # days of history to look back over
        threshold_up=2.5,        # surge-up threshold (multiples of std dev)
        threshold_down=-2.5,     # surge-down threshold (multiples of std dev)
        min_data_points=10,      # minimum number of data points
    ),

    # SVM anomaly-detection stage
    svm=dict(
        enabled=SKLEARN_AVAILABLE,
        nu=0.05,                 # anomaly fraction (5% expected to be anomalous)
        kernel='rbf',            # kernel function
        gamma='auto',
        retrain_days=7,          # retrain every N days
    ),

    # Feature names (for the SVM)
    features=[
        'change_pct',            # current change percentage
        'change_delta_5min',     # change-percentage delta over 5 minutes
        'change_delta_10min',    # change-percentage delta over 10 minutes
        'rank_delta_5min',       # rank change over 5 minutes
        'limit_up_ratio',        # share of limit-up stocks
        'volume_ratio',          # volume ratio (reserved)
        'index_correlation',     # relation to the index
    ],

    # Weights of the importance score
    importance_weights=dict(
        zscore_abs=0.3,          # absolute Z-Score
        rank_position=0.2,       # rank position (closer to the top matters more)
        limit_up_count=0.2,      # number of limit-up stocks
        stock_count=0.1,         # number of stocks in the concept
        change_magnitude=0.2,    # magnitude of the change
    ),

    # Display control
    display=dict(
        max_alerts_per_hour=20,      # maximum number of alerts shown per hour
        min_importance_score=0.3,    # minimum importance score
        cooldown_minutes=15,         # cooldown for the same concept
    ),
)

# Reference index
REFERENCE_INDEX = '000001.SH'
|
||
|
||
# ==================== 日志配置 ====================
|
||
|
||
# Root logger: INFO level, written to a per-day log file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(
            f'concept_alert_ml_{datetime.now().strftime("%Y%m%d")}.log',
            encoding='utf-8',
        ),
        logging.StreamHandler(),
    ],
)

# Module-level logger used throughout this file
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 全局变量 ====================
|
||
|
||
# ClickHouse client, created lazily by get_ch_client()
ch_client = None

# Per-concept historical statistics cache.
# Entries written by load_historical_stats look like
# {concept_id: {'mean': ..., 'std': ..., 'count': int, 'loaded': True}}
stats_cache: Dict[str, dict] = dict()

# Per-concept minute-level history (used to compute change rates)
minute_cache: Dict[str, deque] = dict()
MINUTE_WINDOW = 15  # keep 15 minutes of data

# Cooldown bookkeeping: (concept_id, alert_type) -> time of the last alert
cooldown_cache: Dict[Tuple[str, str], datetime] = dict()

# SVM model state: estimator, its scaler, and the time of the last training
svm_model = svm_scaler = svm_last_train = None
|
||
|
||
|
||
def get_ch_client():
    """Return the shared ClickHouse client, creating it on first use.

    The client is built from CLICKHOUSE_CONFIG and stored in the module-level
    ``ch_client`` so later calls reuse the same instance.
    """
    global ch_client
    if ch_client is not None:
        return ch_client
    ch_client = Client(**CLICKHOUSE_CONFIG)
    return ch_client
|
||
|
||
|
||
def generate_id(name: str) -> str:
    """Derive a concept ID from a name.

    Returns the first 16 hexadecimal characters of the MD5 digest of the
    UTF-8 encoded name.
    """
    digest = hashlib.md5(name.encode('utf-8'))
    return digest.hexdigest()[:16]
|
||
|
||
|
||
def code_to_ch_format(code: str) -> Optional[str]:
    """Convert a 6-digit stock code to the suffixed form used in ClickHouse.

    Args:
        code: Stock code such as ``'600000'``.

    Returns:
        ``'<code>.SH'`` for codes starting with 6, ``'<code>.SZ'`` for codes
        starting with 0 or 3, ``'<code>.BJ'`` for any other 6-digit code, or
        ``None`` when the input is not a 6-character digit string.
    """
    # Reject non-string input explicitly: one malformed code (e.g. an int)
    # should be skipped by callers instead of raising from len().
    if not isinstance(code, str) or len(code) != 6 or not code.isdigit():
        return None

    if code.startswith('6'):
        suffix = 'SH'
    elif code.startswith(('0', '3')):
        suffix = 'SZ'
    else:
        suffix = 'BJ'
    return f"{code}.{suffix}"
|
||
|
||
|
||
# ==================== 概念数据获取(复用原有代码)====================
|
||
|
||
def get_all_concepts():
    """Fetch every leaf concept and its stock codes from Elasticsearch.

    Pages through INDEX_NAME with the scroll API (100 documents per page).

    Returns:
        A list of dicts with keys ``concept_id``, ``concept_name``, ``stocks``
        (list of stock codes) and ``concept_type`` (always ``'leaf'``).
        Concepts without any usable stock code are omitted.
    """
    concepts = []

    query = {
        "query": {"match_all": {}},
        "size": 100,
        "_source": ["concept_id", "concept", "stocks"]
    }

    resp = ES_CLIENT.search(index=INDEX_NAME, body=query, scroll='2m')
    scroll_id = resp['_scroll_id']

    try:
        hits = resp['hits']['hits']
        while hits:
            for hit in hits:
                source = hit['_source']
                raw_stocks = source.get('stocks')

                # Keep only well-formed entries: dicts carrying a non-empty 'code'.
                stocks = []
                if isinstance(raw_stocks, list):
                    stocks = [
                        stock['code']
                        for stock in raw_stocks
                        if isinstance(stock, dict) and stock.get('code')
                    ]

                if stocks:
                    concepts.append({
                        'concept_id': source.get('concept_id'),
                        'concept_name': source.get('concept'),
                        'stocks': stocks,
                        'concept_type': 'leaf'
                    })

            resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
            scroll_id = resp['_scroll_id']
            hits = resp['hits']['hits']
    finally:
        # Always release the server-side scroll context, even when paging
        # fails half-way; a failed cleanup must not discard fetched data.
        try:
            ES_CLIENT.clear_scroll(scroll_id=scroll_id)
        except Exception as e:
            logger.warning(f"清理ES scroll上下文失败: {e}")

    return concepts
|
||
|
||
|
||
def load_hierarchy_concepts(leaf_concepts: list) -> list:
    """Build parent (lv1/lv2/lv3) concepts from the hierarchy file.

    Each parent concept's stock list is the union of the stocks of the leaf
    concepts listed beneath it. Parents with no stocks are skipped.

    Args:
        leaf_concepts: Leaf concept dicts with ``concept_name`` and ``stocks``.

    Returns:
        A list of parent concept dicts (``concept_id``, ``concept_name``,
        ``stocks``, ``concept_type``); empty when the hierarchy file is missing.
    """
    hierarchy_path = os.path.join(os.path.dirname(__file__), HIERARCHY_FILE)
    if not os.path.exists(hierarchy_path):
        logger.warning(f"层级文件不存在: {hierarchy_path}")
        return []

    with open(hierarchy_path, 'r', encoding='utf-8') as f:
        hierarchy_data = json.load(f)

    stocks_by_leaf = {c['concept_name']: set(c['stocks']) for c in leaf_concepts}
    level_labels = {'lv1': '一级', 'lv2': '二级', 'lv3': '三级'}
    parent_concepts = []

    def union_of(leaf_names):
        """Merge the stock sets of the known leaf concepts in *leaf_names*."""
        merged = set()
        for leaf_name in leaf_names:
            if leaf_name in stocks_by_leaf:
                merged.update(stocks_by_leaf[leaf_name])
        return merged

    def emit(level, title, stocks):
        """Append one parent concept unless it has no stocks."""
        if not stocks:
            return
        parent_concepts.append({
            'concept_id': generate_id(f"{level}_{title}"),
            'concept_name': f"[{level_labels[level]}] {title}",
            'stocks': list(stocks),
            'concept_type': level
        })

    for lv1 in hierarchy_data.get('hierarchy', []):
        lv1_stocks = set()

        for child in lv1.get('children', []):
            if 'children' in child:
                # Three-level branch: lv2 aggregates its lv3 children.
                lv2_stocks = set()
                for lv3_child in child.get('children', []):
                    lv3_stocks = union_of(lv3_child.get('concepts', []))
                    emit('lv3', lv3_child.get('lv3', ''), lv3_stocks)
                    lv2_stocks.update(lv3_stocks)
            else:
                # Two-level branch: lv2 lists leaf concepts directly.
                lv2_stocks = union_of(child.get('concepts', []))

            emit('lv2', child.get('lv2', ''), lv2_stocks)
            lv1_stocks.update(lv2_stocks)

        emit('lv1', lv1.get('lv1', ''), lv1_stocks)

    return parent_concepts
|
||
|
||
|
||
# ==================== 价格数据获取 ====================
|
||
|
||
def get_base_prices(stock_codes: list, current_date: str) -> dict:
    """Fetch the previous close of each stock, used as the base price.

    Reads column F002N from ``ea_trade`` for the latest TRADEDATE strictly
    before ``current_date``.

    Args:
        stock_codes: 6-digit stock codes; anything else is ignored.
        current_date: Date string compared against TRADEDATE.

    Returns:
        ``{code: price}`` for rows with a positive price; ``{}`` when there is
        nothing to query or the query fails (the error is logged).
    """
    if not stock_codes:
        return {}

    # isinstance guard: a non-string code must be skipped, not crash len().
    valid_codes = [
        code for code in stock_codes
        if isinstance(code, str) and len(code) == 6 and code.isdigit()
    ]
    if not valid_codes:
        return {}

    # Codes were validated as 6 digits above, so inlining them is safe; the
    # caller-supplied date is sent as a bound parameter instead.
    stock_codes_str = "','".join(valid_codes)

    query = f"""
        SELECT SECCODE, F002N
        FROM ea_trade
        WHERE SECCODE IN ('{stock_codes_str}')
          AND TRADEDATE = (
              SELECT MAX(TRADEDATE)
              FROM ea_trade
              WHERE TRADEDATE < :current_date
          )
          AND F002N IS NOT NULL AND F002N > 0
    """

    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text(query), {'current_date': current_date})
            return {
                row[0]: float(row[1])
                for row in result
                if row[1] and float(row[1]) > 0
            }
    except Exception as e:
        logger.error(f"获取基准价格失败: {e}")
        return {}
|
||
|
||
|
||
def get_latest_prices(stock_codes: list) -> dict:
    """Fetch today's most recent minute close of each stock from ClickHouse.

    Args:
        stock_codes: 6-digit stock codes; codes that cannot be converted by
            code_to_ch_format are ignored.

    Returns:
        ``{code: {'close': float, 'timestamp': ts}}`` keyed by the original
        6-digit code; ``{}`` when nothing is found or the query fails.
    """
    if not stock_codes:
        return {}

    client = get_ch_client()

    # (ClickHouse code, original code) pairs for every convertible input.
    converted = [(code_to_ch_format(code), code) for code in stock_codes]
    pairs = [(ch_code, code) for ch_code, code in converted if ch_code]
    if not pairs:
        return {}

    code_mapping = dict(pairs)
    ch_codes_str = "','".join(ch_code for ch_code, _ in pairs)

    # rn = 1 keeps only the newest row per code for the current day.
    query = f"""
        SELECT code, close, timestamp
        FROM (
            SELECT code, close, timestamp,
                ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn
            FROM stock_minute
            WHERE code IN ('{ch_codes_str}')
            AND toDate(timestamp) = today()
        )
        WHERE rn = 1
    """

    try:
        rows = client.execute(query)
        if not rows:
            return {}

        latest_prices = {}
        for ch_code, close, ts in rows:
            if not (close and close > 0):
                continue
            pure_code = code_mapping.get(ch_code)
            if pure_code:
                latest_prices[pure_code] = {'close': float(close), 'timestamp': ts}
        return latest_prices
    except Exception as e:
        logger.error(f"获取最新价格失败: {e}")
        return {}
|
||
|
||
|
||
def get_prices_at_time(stock_codes: list, timestamp: datetime) -> dict:
    """Fetch the minute close of each stock at one exact timestamp.

    Args:
        stock_codes: 6-digit stock codes; codes that cannot be converted by
            code_to_ch_format are ignored.
        timestamp: Minute bar to read; matched exactly (second resolution).

    Returns:
        ``{code: {'close': float, 'timestamp': ts}}`` keyed by the original
        6-digit code; ``{}`` when nothing is found or the query fails.
    """
    if not stock_codes:
        return {}

    client = get_ch_client()

    # (ClickHouse code, original code) pairs for every convertible input.
    converted = [(code_to_ch_format(code), code) for code in stock_codes]
    pairs = [(ch_code, code) for ch_code, code in converted if ch_code]
    if not pairs:
        return {}

    code_mapping = dict(pairs)
    ch_codes_str = "','".join(ch_code for ch_code, _ in pairs)
    ts_text = timestamp.strftime('%Y-%m-%d %H:%M:%S')

    query = f"""
        SELECT code, close, timestamp
        FROM stock_minute
        WHERE code IN ('{ch_codes_str}')
        AND timestamp = '{ts_text}'
    """

    try:
        prices = {}
        for ch_code, close, ts in client.execute(query):
            if not (close and close > 0):
                continue
            pure_code = code_mapping.get(ch_code)
            if pure_code:
                prices[pure_code] = {'close': float(close), 'timestamp': ts}
        return prices
    except Exception as e:
        logger.error(f"获取历史价格失败: {e}")
        return {}
|
||
|
||
|
||
def get_index_realtime(index_code: str = REFERENCE_INDEX) -> Optional[dict]:
    """Fetch today's latest minute bar of an index and its change percentage.

    Args:
        index_code: Index code with exchange suffix, e.g. ``'000001.SH'``.

    Returns:
        A dict with ``code``, ``price``, ``prev_close``, ``change_pct``
        (rounded to 4 decimals, ``None`` when the previous close is not
        available) and ``timestamp``; ``None`` when there is no bar for today
        or any query fails (the error is logged).
    """
    client = get_ch_client()

    try:
        query = f"""
            SELECT close, timestamp
            FROM index_minute
            WHERE code = '{index_code}'
            AND toDate(timestamp) = today()
            ORDER BY timestamp DESC
            LIMIT 1
        """
        result = client.execute(query)

        if not result:
            return None

        close, ts = result[0]

        # Previous close, looked up relative to today's date
        prev_close = get_index_prev_close(index_code, datetime.now().strftime('%Y-%m-%d'))

        change_pct = None
        if close and prev_close and prev_close > 0:
            change_pct = (float(close) - prev_close) / prev_close * 100

        return {
            'code': index_code,
            'price': float(close),
            'prev_close': prev_close,
            # Compare against None: a flat index (0.0) is a valid value and
            # must not be reported as missing.
            'change_pct': round(change_pct, 4) if change_pct is not None else None,
            'timestamp': ts
        }

    except Exception as e:
        logger.error(f"获取指数数据失败: {e}")
        return None
|
||
|
||
|
||
def get_index_at_time(index_code: str, timestamp: datetime, prev_close: float) -> Optional[dict]:
    """Fetch the minute bar of an index at one exact timestamp.

    Args:
        index_code: Index code with exchange suffix, e.g. ``'000001.SH'``.
        timestamp: Minute bar to read; matched exactly (second resolution).
        prev_close: Previous close used as the base of the change percentage.

    Returns:
        A dict with ``code``, ``price``, ``prev_close``, ``change_pct``
        (rounded to 4 decimals, ``None`` when ``prev_close`` is missing or not
        positive) and ``timestamp``; ``None`` when no bar exists or the query
        fails (the error is logged).
    """
    client = get_ch_client()

    query = f"""
        SELECT close, timestamp
        FROM index_minute
        WHERE code = '{index_code}'
        AND timestamp = '{timestamp.strftime('%Y-%m-%d %H:%M:%S')}'
        LIMIT 1
    """

    try:
        result = client.execute(query)
        if not result:
            return None

        close, ts = result[0]
        change_pct = None
        if close and prev_close and prev_close > 0:
            change_pct = (float(close) - prev_close) / prev_close * 100

        return {
            'code': index_code,
            'price': float(close),
            'prev_close': prev_close,
            # Compare against None: a flat index (0.0) is a valid value and
            # must not be reported as missing.
            'change_pct': round(change_pct, 4) if change_pct is not None else None,
            'timestamp': ts
        }
    except Exception as e:
        logger.error(f"获取指数数据失败: {e}")
        return None
|
||
|
||
|
||
def get_index_prev_close(index_code: str, trade_date: str) -> float:
    """Look up an index's previous close.

    Reads column F006N from ``ea_exchangetrade`` for the latest TRADEDATE
    strictly before ``trade_date``. The exchange suffix of ``index_code``
    (everything from the first ``'.'``) is dropped for the lookup.

    Returns:
        The value as a float, or ``None`` when no row exists or the stored
        value is empty/zero. Database errors propagate to the caller.
    """
    code_no_suffix = index_code.split('.')[0]

    statement = text("""
        SELECT F006N FROM ea_exchangetrade
        WHERE INDEXCODE = :code
        AND TRADEDATE < :today
        ORDER BY TRADEDATE DESC LIMIT 1
    """)
    bind_values = {
        'code': code_no_suffix,
        'today': trade_date
    }

    with MYSQL_ENGINE.connect() as conn:
        row = conn.execute(statement, bind_values).fetchone()
        if not (row and row[0]):
            return None
        return float(row[0])
|
||
|
||
|
||
# ==================== 涨跌幅计算 ====================
|
||
|
||
def calculate_change_pct(base_prices: dict, latest_prices: dict) -> dict:
    """Compute each stock's percentage change against its base price.

    Args:
        base_prices: ``{code: base_price}``.
        latest_prices: ``{code: {'close': price, ...}}``.

    Returns:
        ``{code: {'change_pct', 'close', 'base'}}`` for every code present in
        both inputs whose base price is positive; ``change_pct`` is rounded to
        4 decimals.
    """
    changes = {}
    for code, latest in latest_prices.items():
        if code not in base_prices:
            continue
        base = base_prices[code]
        if not base > 0:
            continue
        close = latest['close']
        changes[code] = {
            'change_pct': round((close - base) / base * 100, 4),
            'close': close,
            'base': base
        }
    return changes
|
||
|
||
|
||
def calculate_concept_stats(concepts: list, stock_changes: dict) -> list:
    """Aggregate per-stock changes into per-concept statistics.

    Args:
        concepts: Concept dicts with ``concept_id``, ``concept_name``,
            ``stocks`` and optionally ``concept_type`` (default ``'leaf'``).
        stock_changes: ``{code: {'change_pct': ...}}``.

    Returns:
        One dict per concept that has at least one stock in
        ``stock_changes``, sorted by average change (descending) and carrying
        a 1-based ``rank``. A stock counts as limit-up at >= 9.8% and as
        limit-down at <= -9.8%.
    """
    stats = []

    for concept in concepts:
        concept_id = concept['concept_id']
        concept_name = concept['concept_name']
        stock_codes = concept['stocks']
        concept_type = concept.get('concept_type', 'leaf')

        # (code, change_pct) for the member stocks we have data for.
        members = [
            (code, stock_changes[code]['change_pct'])
            for code in stock_codes
            if code in stock_changes
        ]
        if not members:
            continue

        changes = [pct for _, pct in members]
        limit_up_stocks = [code for code, pct in members if pct >= 9.8]
        limit_down_stocks = [code for code, pct in members if pct <= -9.8]
        member_count = len(changes)

        stats.append({
            'concept_id': concept_id,
            'concept_name': concept_name,
            'avg_change_pct': round(np.mean(changes), 4),
            # A single sample has no spread; report 0 instead of calling np.std
            'std_change_pct': round(np.std(changes), 4) if member_count > 1 else 0,
            'stock_count': member_count,
            'concept_type': concept_type,
            'limit_up_count': len(limit_up_stocks),
            'limit_down_count': len(limit_down_stocks),
            'limit_up_stocks': limit_up_stocks,
            'limit_down_stocks': limit_down_stocks,
            'limit_up_ratio': len(limit_up_stocks) / member_count,
            'limit_down_ratio': len(limit_down_stocks) / member_count,
        })

    # Rank by average change, best first
    stats.sort(key=lambda item: item['avg_change_pct'], reverse=True)
    for position, item in enumerate(stats, start=1):
        item['rank'] = position

    return stats
|
||
|
||
|
||
# ==================== Z-Score 计算 ====================
|
||
|
||
def load_historical_stats(concept_id: str, lookback_days: int = 20) -> dict:
    """Load (and cache) the historical baseline used for a concept's Z-Score.

    The baseline is the mean and standard deviation of the per-day average
    ``change_delta`` found in ``concept_minute_alert`` over the last
    ``lookback_days`` days. Results are cached in ``stats_cache`` for the
    lifetime of the process.

    Args:
        concept_id: Concept to load.
        lookback_days: Size of the look-back window in calendar days.

    Returns:
        A dict with ``mean``, ``std``, ``count`` and ``loaded``. When fewer
        than ``min_data_points`` usable daily values exist, or the query
        fails, the defaults mean=0.5 / std=0.8 / count=0 are returned.
    """
    cached = stats_cache.get(concept_id)
    if cached and cached.get('loaded'):
        return cached

    # Defaults used when history is insufficient: 0.5% mean, 0.8% std.
    # A fresh dict per concept so cache entries never share state.
    entry = {'mean': 0.5, 'std': 0.8, 'count': 0, 'loaded': True}

    try:
        with MYSQL_ENGINE.connect() as conn:
            # NOTE(review): this filters on alert_type = 'surge', while the
            # detector in this file emits 'surge_up' / 'surge_down' — confirm
            # which producer is meant to feed this baseline.
            result = conn.execute(text("""
                SELECT
                    trade_date,
                    AVG(change_delta) as avg_delta,
                    MAX(change_delta) as max_delta,
                    MIN(change_delta) as min_delta
                FROM concept_minute_alert
                WHERE concept_id = :concept_id
                AND trade_date >= DATE_SUB(CURDATE(), INTERVAL :days DAY)
                AND alert_type = 'surge'
                GROUP BY trade_date
            """), {'concept_id': concept_id, 'days': lookback_days})
            rows = list(result)

        # Only NULL averages are unusable; a 0.0 average is a real data point.
        deltas = [float(r[1]) for r in rows if r[1] is not None]

        # Judge sufficiency on the usable values, not on the raw row count,
        # so an all-NULL result can never produce a NaN mean.
        if len(deltas) >= SMART_ALERT_CONFIG['zscore']['min_data_points']:
            entry = {
                'mean': float(np.mean(deltas)),
                'std': float(np.std(deltas)) if len(deltas) > 1 else 1.0,
                'count': len(deltas),
                'loaded': True
            }
    except Exception as e:
        logger.error(f"加载历史统计失败: {e}")

    stats_cache[concept_id] = entry
    return entry
|
||
|
||
|
||
def calculate_zscore(concept_id: str, change_delta: float) -> float:
    """Return the Z-Score of a change delta against the concept's baseline.

    Z = (X - mean) / std, using the statistics from load_historical_stats and
    rounded to 4 decimals.
    """
    baseline = load_historical_stats(concept_id)

    # Floor the deviation at 0.1 so a near-zero std cannot blow up the score.
    spread = 0.1 if baseline['std'] < 0.1 else baseline['std']

    return round((change_delta - baseline['mean']) / spread, 4)
|
||
|
||
|
||
def update_minute_cache(concept_id: str, timestamp: datetime, data: dict):
    """Append one snapshot to the concept's minute-level history.

    The history is a deque bounded to MINUTE_WINDOW entries; it is created on
    the first snapshot for a concept. The stored record is ``data`` plus a
    ``timestamp`` key.
    """
    try:
        window = minute_cache[concept_id]
    except KeyError:
        window = minute_cache[concept_id] = deque(maxlen=MINUTE_WINDOW)

    record = {'timestamp': timestamp}
    record.update(data)
    window.append(record)
|
||
|
||
|
||
def get_minute_history(concept_id: str, minutes_ago: int) -> Optional[dict]:
    """Return the concept's snapshot from about ``minutes_ago`` minutes back.

    The reference point is the timestamp of the newest cached snapshot. The
    newest record whose timestamp is at or before ``reference - minutes_ago``
    is returned, or ``None`` when the concept has no such record.
    """
    history = minute_cache.get(concept_id)
    if not history:
        return None

    target_time = history[-1]['timestamp'] - timedelta(minutes=minutes_ago)

    # Walk from newest to oldest and stop at the first record old enough.
    return next(
        (record for record in reversed(history) if record['timestamp'] <= target_time),
        None,
    )
|
||
|
||
|
||
# ==================== SVM 异常检测 ====================
|
||
|
||
def extract_features(stat: dict, index_data: dict) -> np.ndarray:
    """Build the 7-element feature vector fed to the SVM.

    Order: current change, 5-minute change delta, 10-minute change delta,
    5-minute rank change, limit-up ratio (in percent), volume ratio
    (reserved, always 0), and a signed index-divergence term.
    """
    concept_id = stat['concept_id']
    current_change = stat['avg_change_pct']

    prev_5min = get_minute_history(concept_id, 5)
    prev_10min = get_minute_history(concept_id, 10)

    # Deltas default to 0 when no sufficiently old snapshot exists.
    delta_5min = (current_change - prev_5min.get('change_pct', 0)) if prev_5min else 0
    delta_10min = (current_change - prev_10min.get('change_pct', 0)) if prev_10min else 0

    # Positive when the concept moved up the ranking.
    rank_shift = (prev_5min.get('rank', stat['rank']) - stat['rank']) if prev_5min else 0

    limit_up_pct = stat.get('limit_up_ratio', 0) * 100

    # Simplified index relation: size of the gap to the index change, signed
    # positive when both move in the same direction and negative otherwise.
    index_term = 0
    if index_data and index_data.get('change_pct'):
        index_change = index_data['change_pct']
        direction = 1 if (index_change * current_change > 0) else -1
        index_term = direction * abs(current_change - index_change)

    return np.array([
        current_change,
        delta_5min,
        delta_10min,
        rank_shift,
        limit_up_pct,
        0,  # volume ratio: reserved, not computed yet
        index_term,
    ])
|
||
|
||
|
||
def train_svm_model(training_data: List[np.ndarray]):
    """Fit the One-Class SVM (and its scaler) and persist both to MODEL_DIR.

    Args:
        training_data: Feature vectors; at least 100 are required.

    Returns:
        True on success; False when scikit-learn is unavailable, there is too
        little data, or training/saving fails (the error is logged).
    """
    global svm_model, svm_scaler, svm_last_train

    if not SKLEARN_AVAILABLE:
        return False

    sample_count = len(training_data)
    if sample_count < 100:
        logger.warning(f"训练数据不足: {sample_count} 条")
        return False

    try:
        samples = np.array(training_data)

        # Standardize features before fitting
        svm_scaler = StandardScaler()
        scaled_samples = svm_scaler.fit_transform(samples)

        svm_cfg = SMART_ALERT_CONFIG['svm']
        svm_model = OneClassSVM(
            nu=svm_cfg['nu'],
            kernel=svm_cfg['kernel'],
            gamma=svm_cfg['gamma']
        )
        svm_model.fit(scaled_samples)

        svm_last_train = datetime.now()

        # Persist the model first, then the scaler
        artefacts = (('svm_model.pkl', svm_model), ('svm_scaler.pkl', svm_scaler))
        for filename, artefact in artefacts:
            with open(os.path.join(MODEL_DIR, filename), 'wb') as f:
                pickle.dump(artefact, f)

        logger.info(f"SVM模型训练完成,使用 {sample_count} 条数据")
        return True

    except Exception as e:
        logger.error(f"SVM模型训练失败: {e}")
        return False
|
||
|
||
|
||
def load_svm_model():
    """Load the persisted SVM model and scaler from MODEL_DIR.

    Returns:
        True when both files were loaded; False when scikit-learn is
        unavailable, a file is missing, or loading fails (error is logged).
    """
    global svm_model, svm_scaler

    if not SKLEARN_AVAILABLE:
        return False

    model_path = os.path.join(MODEL_DIR, 'svm_model.pkl')
    scaler_path = os.path.join(MODEL_DIR, 'svm_scaler.pkl')

    if not (os.path.exists(model_path) and os.path.exists(scaler_path)):
        return False

    try:
        # NOTE(security): pickle executes code on load; these files must only
        # ever be ones written by train_svm_model.
        with open(model_path, 'rb') as f:
            svm_model = pickle.load(f)
        with open(scaler_path, 'rb') as f:
            svm_scaler = pickle.load(f)
        logger.info("SVM模型加载成功")
        return True
    except Exception as e:
        logger.error(f"SVM模型加载失败: {e}")
        return False
|
||
|
||
|
||
def predict_anomaly(features: np.ndarray) -> Tuple[bool, float]:
    """Classify one feature vector with the SVM.

    Returns:
        ``(is_anomaly, score)`` where ``score`` is the model's decision
        function value. ``(False, 0.0)`` when no model/scaler is loaded or
        prediction fails (the error is logged).
    """
    global svm_model, svm_scaler

    model_ready = svm_model is not None and svm_scaler is not None
    if not model_ready:
        return False, 0.0

    try:
        sample = svm_scaler.transform(features.reshape(1, -1))
        label = svm_model.predict(sample)[0]
        score = svm_model.decision_function(sample)[0]

        # The model labels normal samples 1 and anomalies -1
        return label == -1, float(score)
    except Exception as e:
        logger.error(f"SVM预测失败: {e}")
        return False, 0.0
|
||
|
||
|
||
# ==================== 重要性评分 ====================
|
||
|
||
def calculate_importance_score(
    zscore: float,
    rank: int,
    limit_up_count: int,
    stock_count: int,
    change_pct: float,
    total_concepts: int
) -> float:
    """Score how noteworthy an alert is, as a weighted sum in the 0-1 range.

    Each component is normalized to 0-1 and combined with the weights in
    ``SMART_ALERT_CONFIG['importance_weights']``.

    Args:
        zscore: Z-Score of the move; 5 standard deviations scores full marks.
        rank: 1-based rank of the concept; rank 1 scores full marks.
        limit_up_count: Number of limit stocks; 5 scores full marks.
        stock_count: Stocks in the concept; 10-100 scores full marks.
        change_pct: Current change percentage; 5% scores full marks.
        total_concepts: Number of ranked concepts.

    Returns:
        The weighted score rounded to 4 decimals.
    """
    weights = SMART_ALERT_CONFIG['importance_weights']

    # Rank is scaled over the top 100 (or fewer when fewer concepts exist).
    # Floor the span at 1 so an empty ranking cannot divide by zero.
    rank_span = max(1, min(100, total_concepts))

    # Concept size: too small is penalized linearly, too large decays to 0.5.
    if stock_count < 10:
        stock_score = stock_count / 10.0
    elif stock_count > 100:
        stock_score = max(0.5, 1 - (stock_count - 100) / 200)
    else:
        stock_score = 1.0

    scores = {
        'zscore_abs': min(abs(zscore) / 5.0, 1.0),
        'rank_position': max(0, 1 - (rank - 1) / rank_span),
        'limit_up_count': min(limit_up_count / 5.0, 1.0),
        'stock_count': stock_score,
        'change_magnitude': min(abs(change_pct) / 5.0, 1.0),
    }

    total_score = sum(scores[k] * weights[k] for k in weights)
    return round(total_score, 4)
|
||
|
||
|
||
# ==================== 智能异动检测 ====================
|
||
|
||
def check_cooldown(concept_id: str, alert_type: str, current_time: datetime = None) -> bool:
    """Tell whether a (concept, alert type) pair is still cooling down.

    Args:
        concept_id: Concept to check.
        alert_type: Alert type to check.
        current_time: Clock to evaluate against; defaults to the wall clock.
            Pass the evaluation time when replaying historical data so it
            matches the time recorded by set_cooldown.

    Returns:
        True when the last alert is more recent than the configured
        ``cooldown_minutes``; False otherwise (including "never alerted").
    """
    last_alert = cooldown_cache.get((concept_id, alert_type))
    if last_alert is None:
        return False

    now = current_time or datetime.now()
    cooldown = timedelta(minutes=SMART_ALERT_CONFIG['display']['cooldown_minutes'])
    return now - last_alert < cooldown


def set_cooldown(concept_id: str, alert_type: str, alert_time: datetime = None):
    """Record the time of the latest alert for a (concept, alert type) pair.

    ``alert_time`` defaults to the wall clock.
    """
    cooldown_cache[(concept_id, alert_type)] = alert_time or datetime.now()


def detect_smart_alerts(
    current_stats: list,
    index_data: dict,
    trade_date: str,
    current_time: datetime = None
) -> list:
    """Detect concept surges by combining the Z-Score and SVM checks.

    For every concept the current snapshot is added to the minute cache, the
    change against the snapshot from 5 minutes earlier is scored, and an
    alert is produced when it passes the Z-Score threshold, the cooldown, the
    optional SVM check and the minimum importance score.

    Args:
        current_stats: Output of calculate_concept_stats.
        index_data: Reference index snapshot, or None.
        trade_date: Trade date stored on each alert.
        current_time: Evaluation time; defaults to the wall clock.

    Returns:
        Alert dicts sorted by importance (descending), truncated to the
        configured ``max_alerts_per_hour``.
    """
    alerts = []
    current_time = current_time or datetime.now()
    total_concepts = len(current_stats)
    zscore_cfg = SMART_ALERT_CONFIG['zscore']
    display_cfg = SMART_ALERT_CONFIG['display']

    for stat in current_stats:
        concept_id = stat['concept_id']
        concept_name = stat['concept_name']
        change_pct = stat['avg_change_pct']
        rank = stat['rank']
        limit_up_count = stat['limit_up_count']
        limit_down_count = stat['limit_down_count']
        stock_count = stat['stock_count']
        concept_type = stat['concept_type']

        # Record the current snapshot before looking back
        update_minute_cache(concept_id, current_time, {
            'change_pct': change_pct,
            'rank': rank,
            'limit_up_count': limit_up_count,
            'limit_down_count': limit_down_count
        })

        # Without a snapshot from 5 minutes ago there is nothing to compare
        prev_5min = get_minute_history(concept_id, 5)
        if not prev_5min:
            continue

        change_delta = change_pct - prev_5min.get('change_pct', 0)

        # ---------- Z-Score check ----------
        # The score is computed on the magnitude; direction comes from the sign.
        zscore = calculate_zscore(concept_id, abs(change_delta))

        is_surge_up = change_delta > 0 and zscore >= zscore_cfg['threshold_up']
        is_surge_down = change_delta < 0 and zscore >= abs(zscore_cfg['threshold_down'])

        if not (is_surge_up or is_surge_down):
            continue

        alert_type = 'surge_up' if is_surge_up else 'surge_down'

        # Cooldown is evaluated on the same clock that set_cooldown records
        # below, so it also works when current_time is not the wall clock.
        if check_cooldown(concept_id, alert_type, current_time):
            continue

        # ---------- SVM check (optional) ----------
        svm_is_anomaly = False
        svm_score = 0.0

        if SMART_ALERT_CONFIG['svm']['enabled'] and svm_model is not None:
            features = extract_features(stat, index_data)
            svm_is_anomaly, svm_score = predict_anomaly(features)

            # When the SVM does not confirm the anomaly, require a stronger
            # Z-Score (at least 3.5) before alerting.
            if not svm_is_anomaly and abs(zscore) < 3.5:
                continue

        # ---------- Importance score ----------
        importance = calculate_importance_score(
            zscore=zscore,
            rank=rank,
            limit_up_count=limit_up_count if is_surge_up else limit_down_count,
            stock_count=stock_count,
            change_pct=change_pct,
            total_concepts=total_concepts
        )

        if importance < display_cfg['min_importance_score']:
            continue

        # ---------- Alert record ----------
        alert = {
            'concept_id': concept_id,
            'concept_name': concept_name,
            'alert_type': alert_type,
            'alert_time': current_time,
            'trade_date': trade_date,
            'change_pct': change_pct,
            'prev_change_pct': prev_5min.get('change_pct'),
            'change_delta': round(change_delta, 4),
            'zscore': zscore,
            'svm_score': svm_score,
            'importance_score': importance,
            'limit_up_count': limit_up_count,
            'limit_down_count': limit_down_count,
            'prev_limit_up_count': prev_5min.get('limit_up_count', 0),
            'rank_position': rank,
            'prev_rank_position': prev_5min.get('rank'),
            'rank_delta': prev_5min.get('rank', rank) - rank,
            'stock_count': stock_count,
            'concept_type': concept_type,
            'index_code': REFERENCE_INDEX,
            'index_price': index_data['price'] if index_data else None,
            'index_change_pct': index_data['change_pct'] if index_data else None,
            'extra_info': {
                'limit_up_stocks': stat.get('limit_up_stocks', []),
                'limit_down_stocks': stat.get('limit_down_stocks', []),
            }
        }

        alerts.append(alert)
        set_cooldown(concept_id, alert_type, current_time)

        direction = "🔥 暴涨" if is_surge_up else "💧 暴跌"
        logger.info(
            f"{direction}: {concept_name} "
            f"涨幅 {prev_5min.get('change_pct', 0):.2f}% -> {change_pct:.2f}% "
            f"(Δ{change_delta:+.2f}%, Z={zscore:.2f}, 重要性={importance:.2f})"
        )

    # Most important first, capped at the configured maximum
    alerts.sort(key=lambda x: x['importance_score'], reverse=True)
    return alerts[:display_cfg['max_alerts_per_hour']]
|
||
|
||
|
||
# ==================== 数据持久化 ====================
|
||
|
||
def _to_native(value):
    """Convert a NumPy scalar to the built-in Python value; pass others through."""
    return value.item() if isinstance(value, np.generic) else value


def save_alerts_to_mysql(alerts: list):
    """Insert alerts into ``concept_minute_alert``.

    All inserts share one transaction. A failing alert is logged and skipped
    without aborting the remaining ones.

    Args:
        alerts: Alert dicts as produced by detect_smart_alerts.

    Returns:
        The number of alerts whose insert statement succeeded.
    """
    if not alerts:
        return 0

    # The statement is identical for every alert, so build it once.
    insert_sql = text("""
        INSERT INTO concept_minute_alert
        (concept_id, concept_name, alert_time, alert_type, trade_date,
        change_pct, prev_change_pct, change_delta,
        limit_up_count, prev_limit_up_count, limit_up_delta,
        rank_position, prev_rank_position, rank_delta,
        index_code, index_price, index_change_pct,
        stock_count, concept_type, extra_info)
        VALUES
        (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
        :change_pct, :prev_change_pct, :change_delta,
        :limit_up_count, :prev_limit_up_count, :limit_up_delta,
        :rank_position, :prev_rank_position, :rank_delta,
        :index_code, :index_price, :index_change_pct,
        :stock_count, :concept_type, :extra_info)
    """)

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                limit_up_delta = alert.get('limit_up_count', 0) - alert.get('prev_limit_up_count', 0)

                # Scores are stored inside the extra_info JSON document;
                # without extra_info the column is left NULL.
                extra_info = None
                if alert.get('extra_info'):
                    extra_info = json.dumps({
                        **alert.get('extra_info', {}),
                        'zscore': _to_native(alert.get('zscore')),
                        'svm_score': _to_native(alert.get('svm_score')),
                        'importance_score': _to_native(alert.get('importance_score')),
                    }, ensure_ascii=False)

                params = {
                    'concept_id': alert['concept_id'],
                    'concept_name': alert['concept_name'],
                    'alert_time': alert['alert_time'],
                    'alert_type': alert['alert_type'],
                    'trade_date': alert['trade_date'],
                    'change_pct': alert.get('change_pct'),
                    'prev_change_pct': alert.get('prev_change_pct'),
                    'change_delta': alert.get('change_delta'),
                    'limit_up_count': alert.get('limit_up_count', 0),
                    'prev_limit_up_count': alert.get('prev_limit_up_count', 0),
                    'limit_up_delta': limit_up_delta,
                    'rank_position': alert.get('rank_position'),
                    'prev_rank_position': alert.get('prev_rank_position'),
                    'rank_delta': alert.get('rank_delta'),
                    'index_code': alert.get('index_code', REFERENCE_INDEX),
                    'index_price': alert.get('index_price'),
                    'index_change_pct': alert.get('index_change_pct'),
                    'stock_count': alert.get('stock_count'),
                    'concept_type': alert.get('concept_type', 'leaf'),
                    'extra_info': extra_info
                }

                # Values derived from np.mean()/round() are NumPy scalars;
                # hand the driver plain Python numbers instead, since DB-API
                # drivers are not guaranteed to know how to encode them.
                params = {key: _to_native(value) for key, value in params.items()}

                conn.execute(insert_sql, params)
                saved += 1

            except Exception as e:
                logger.error(f"保存异动失败: {alert['concept_name']} - {e}")

    return saved
|
||
|
||
|
||
def save_index_snapshot(index_data: dict, trade_date: str):
    """Upsert one index snapshot into ``index_minute_snapshot``.

    Does nothing when ``index_data`` is empty. Failures are logged and not
    propagated.

    Args:
        index_data: Snapshot with ``code``, ``timestamp``, ``price`` and
            optionally ``prev_close`` / ``change_pct``.
        trade_date: Trade date stored with the snapshot.
    """
    if not index_data:
        return

    upsert_sql = text("""
        REPLACE INTO index_minute_snapshot
        (index_code, trade_date, snapshot_time, price, prev_close, change_pct)
        VALUES (:index_code, :trade_date, :snapshot_time, :price, :prev_close, :change_pct)
    """)

    try:
        with MYSQL_ENGINE.begin() as conn:
            row = {
                'index_code': index_data['code'],
                'trade_date': trade_date,
                'snapshot_time': index_data['timestamp'],
                'price': index_data['price'],
                'prev_close': index_data.get('prev_close'),
                'change_pct': index_data.get('change_pct')
            }
            conn.execute(upsert_sql, row)
    except Exception as e:
        logger.error(f"保存指数快照失败: {e}")
|
||
|
||
|
||
# ==================== 交易时间判断 ====================
|
||
|
||
def is_trading_time(now: Optional[datetime] = None) -> bool:
    """Return True when ``now`` falls inside a trading session.

    Sessions are 09:30-11:30 and 13:00-15:00 on Monday through Friday, with
    both end minutes inclusive (so 11:30:59 and 15:00:59 still count).
    Exchange holidays are NOT taken into account.

    Args:
        now: Moment to test. Defaults to ``datetime.now()``; passing an
            explicit value makes the check deterministic (tests, replays).

    Returns:
        bool: Whether ``now`` is within one of the two sessions.
    """
    if now is None:
        now = datetime.now()

    # weekday(): Monday == 0 ... Sunday == 6
    if now.weekday() >= 5:
        return False

    minute_of_day = now.hour * 60 + now.minute

    # (start, end) bounds in minutes since midnight, inclusive on both ends.
    sessions = (
        (9 * 60 + 30, 11 * 60 + 30),
        (13 * 60, 15 * 60),
    )
    return any(start <= minute_of_day <= end for start, end in sessions)
|
||
|
||
|
||
def get_next_update_time(now: Optional[datetime] = None) -> int:
    """Return the number of seconds to wait before the next update.

    During a session the result is the time remaining in the current minute.
    Outside a session it is the time until the next session opens (09:30 or
    13:00), never less than 60 seconds. Weekends are skipped: after Friday's
    close, or at any time on Saturday/Sunday, the target is the next
    weekday's 09:30. Exchange holidays are not taken into account.

    Args:
        now: Reference moment. Defaults to ``datetime.now()``. The clock is
            read once so that the session check and the wait computation
            cannot disagree.

    Returns:
        int: Seconds to wait.
    """
    if now is None:
        now = datetime.now()

    # Session bounds in minutes since midnight (inclusive), same as the
    # trading-time check used by the main loop.
    morning_open, morning_close = 9 * 60 + 30, 11 * 60 + 30
    afternoon_open, afternoon_close = 13 * 60, 15 * 60

    minute_of_day = now.hour * 60 + now.minute
    is_weekday = now.weekday() < 5

    in_session = is_weekday and (
        morning_open <= minute_of_day <= morning_close
        or afternoon_open <= minute_of_day <= afternoon_close
    )
    if in_session:
        return 60 - now.second

    if is_weekday and minute_of_day < morning_open:
        # Before the morning open.
        target = now.replace(hour=9, minute=30, second=0, microsecond=0)
    elif is_weekday and morning_close < minute_of_day < afternoon_open:
        # Lunch break.
        target = now.replace(hour=13, minute=0, second=0, microsecond=0)
    else:
        # After the close or on a weekend: next weekday's morning open.
        target = (now + timedelta(days=1)).replace(
            hour=9, minute=30, second=0, microsecond=0
        )
        while target.weekday() >= 5:
            target += timedelta(days=1)

    wait_seconds = (target - now).total_seconds()
    return max(60, int(wait_seconds))
|
||
|
||
|
||
# ==================== 主运行逻辑 ====================
|
||
|
||
def run_once(concepts: list, all_stocks: list) -> tuple:
    """Run a single detection pass against the current market state.

    Steps, in order: fetch base prices, fetch latest prices, fetch and store
    the reference index snapshot, compute per-stock change percentages,
    aggregate per-concept statistics, detect alerts and persist them.

    Args:
        concepts: Concept definitions passed to ``calculate_concept_stats``.
        all_stocks: Stock codes whose prices are fetched.

    Returns:
        tuple: ``(number_of_concept_stats, number_of_alerts)``; ``(0, 0)``
        when any of the price/change inputs is unavailable.
    """
    started_at = datetime.now()
    trade_date = f"{started_at:%Y-%m-%d}"

    # Base (reference) prices
    base_prices = get_base_prices(all_stocks, trade_date)
    if not base_prices:
        logger.warning("无法获取基准价格")
        return 0, 0

    # Latest prices
    latest_prices = get_latest_prices(all_stocks)
    if not latest_prices:
        logger.warning("无法获取最新价格")
        return 0, 0

    # Reference index: store a snapshot whenever data is available
    index_data = get_index_realtime(REFERENCE_INDEX)
    if index_data:
        save_index_snapshot(index_data, trade_date)

    # Per-stock change percentages
    stock_changes = calculate_change_pct(base_prices, latest_prices)
    if not stock_changes:
        logger.warning("无涨跌幅数据")
        return 0, 0
    logger.info(f"获取到 {len(stock_changes)} 只股票的涨跌幅")

    # Per-concept statistics
    concept_stats = calculate_concept_stats(concepts, stock_changes)
    logger.info(f"计算了 {len(concept_stats)} 个概念的涨跌幅")

    # Alert detection, then persistence
    detected = detect_smart_alerts(concept_stats, index_data, trade_date, started_at)
    if detected:
        saved_count = save_alerts_to_mysql(detected)
        logger.info(f"💾 保存了 {saved_count} 条异动记录")

    return len(concept_stats), len(detected)
|
||
|
||
|
||
def run_realtime():
    """Run the real-time detection loop until interrupted.

    Loads the SVM model (when sklearn is available) and the concept universe
    once at start-up, then loops: reload concepts when more than an hour has
    passed, sleep outside trading hours (at most 300 seconds per nap), run
    one detection pass, and sleep until the next minute boundary.
    ``KeyboardInterrupt`` ends the loop; any other exception is logged and
    followed by a 60 second pause.
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info("🚀 启动智能概念异动检测服务 (Z-Score + SVM)")
    logger.info(banner)
    config_dump = json.dumps(SMART_ALERT_CONFIG, indent=2, ensure_ascii=False, default=str)
    logger.info(f"配置: {config_dump}")

    # Try to load the SVM model
    if SKLEARN_AVAILABLE:
        load_svm_model()

    # Load concept data
    logger.info("加载概念数据...")
    leaf_concepts = get_all_concepts()
    logger.info(f"获取到 {len(leaf_concepts)} 个叶子概念")

    parent_concepts = load_hierarchy_concepts(leaf_concepts)
    logger.info(f"生成了 {len(parent_concepts)} 个母概念")

    all_concepts = leaf_concepts + parent_concepts
    logger.info(f"总计 {len(all_concepts)} 个概念")

    # Collect every stock code referenced by any concept
    all_stocks = list(set().union(*(concept['stocks'] for concept in all_concepts)))
    logger.info(f"监控 {len(all_stocks)} 只股票")

    last_concept_update = datetime.now()
    total_alerts = 0
    reload_interval = timedelta(seconds=3600)

    while True:
        try:
            now = datetime.now()

            # Reload concept data once the reload interval has elapsed
            if now - last_concept_update > reload_interval:
                logger.info("重新加载概念数据...")
                leaf_concepts = get_all_concepts()
                parent_concepts = load_hierarchy_concepts(leaf_concepts)
                all_concepts = leaf_concepts + parent_concepts
                all_stocks = list(
                    set().union(*(concept['stocks'] for concept in all_concepts))
                )
                last_concept_update = now
                logger.info(f"更新完成: {len(all_concepts)} 个概念, {len(all_stocks)} 只股票")

            # Outside trading hours: nap (capped at 300 s) and re-check
            if not is_trading_time():
                wait_sec = get_next_update_time()
                wait_min = wait_sec // 60
                logger.info(f"⏰ 非交易时间,等待 {wait_min} 分钟后重试...")
                time.sleep(min(wait_sec, 300))
                continue

            # Run one detection pass
            logger.info("\n" + "=" * 40)
            logger.info(f"🔍 检测时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
            _, alert_count = run_once(all_concepts, all_stocks)
            total_alerts += alert_count

            if alert_count > 0:
                logger.info(f"📊 本次检测到 {alert_count} 条异动,累计 {total_alerts} 条")

            # Sleep until the next minute boundary
            sleep_sec = 60 - datetime.now().second
            logger.info(f"⏳ 等待 {sleep_sec} 秒后继续...")
            time.sleep(sleep_sec)

        except KeyboardInterrupt:
            logger.info("\n收到退出信号,停止服务...")
            break
        except Exception as e:
            logger.error(f"发生错误: {e}")
            import traceback
            traceback.print_exc()
            time.sleep(60)
|
||
|
||
|
||
def run_single():
    """Run exactly one detection pass and log a summary.

    Loads the SVM model when sklearn is available, builds the concept and
    stock universe, then delegates to ``run_once``.
    """
    logger.info("单次检测模式")

    if SKLEARN_AVAILABLE:
        load_svm_model()

    leaf_concepts = get_all_concepts()
    all_concepts = leaf_concepts + load_hierarchy_concepts(leaf_concepts)

    # Union of the stock codes of every concept
    all_stocks = list(set().union(*(concept['stocks'] for concept in all_concepts)))

    logger.info(f"概念数: {len(all_concepts)}, 股票数: {len(all_stocks)}")

    concept_count, alert_count = run_once(all_concepts, all_stocks)
    logger.info(f"检测完成: {concept_count} 个概念, {alert_count} 条异动")
|
||
|
||
|
||
# ==================== 回测功能 ====================
|
||
|
||
def get_minute_timestamps(trade_date: str) -> list:
    """Return the distinct minute timestamps recorded for one trade date.

    Args:
        trade_date: Date compared against ``toDate(timestamp)`` in the
            ``stock_minute`` table (callers pass ``YYYY-MM-DD``).

    Returns:
        list: The first column of each result row, in ascending order.
    """
    client = get_ch_client()

    # The date is passed as a query parameter rather than interpolated into
    # the SQL text: it can originate from the --date command-line argument,
    # so the driver is left to quote and escape it.
    query = """
        SELECT DISTINCT timestamp
        FROM stock_minute
        WHERE toDate(timestamp) = %(trade_date)s
        ORDER BY timestamp
    """

    rows = client.execute(query, {'trade_date': trade_date})
    return [row[0] for row in rows]
|
||
|
||
|
||
def run_backtest(trade_date: str, clear_existing: bool = True):
    """Replay alert detection minute by minute for one historical date.

    Resets the module-level caches, optionally deletes the rows already
    stored for ``trade_date``, loads the concept universe, then walks every
    minute timestamp of that date: fetch prices at that minute, store the
    index snapshot, compute changes and concept statistics, detect alerts
    and persist them.

    Args:
        trade_date: Date to replay (callers pass ``YYYY-MM-DD``).
        clear_existing: When True, delete existing ``concept_minute_alert``
            and ``index_minute_snapshot`` rows for the date before replaying.

    Returns:
        None. Returns early (after logging an error) when base prices or
        minute timestamps cannot be obtained.
    """
    global minute_cache, cooldown_cache, stats_cache

    banner = "=" * 60
    logger.info(banner)
    logger.info(f"🔄 开始智能回测: {trade_date}")
    logger.info(banner)

    # Start from empty caches so earlier runs cannot influence the replay
    minute_cache = {}
    cooldown_cache = {}
    stats_cache = {}

    # Remove rows already stored for this date
    if clear_existing:
        date_param = {'date': trade_date}
        with MYSQL_ENGINE.begin() as conn:
            conn.execute(text("DELETE FROM concept_minute_alert WHERE trade_date = :date"), date_param)
            conn.execute(text("DELETE FROM index_minute_snapshot WHERE trade_date = :date"), date_param)
        logger.info(f"已清除 {trade_date} 的已有数据")

    # Load concept data
    logger.info("加载概念数据...")
    leaf_concepts = get_all_concepts()
    logger.info(f"获取到 {len(leaf_concepts)} 个叶子概念")

    parent_concepts = load_hierarchy_concepts(leaf_concepts)
    logger.info(f"生成了 {len(parent_concepts)} 个母概念")

    all_concepts = leaf_concepts + parent_concepts
    logger.info(f"总计 {len(all_concepts)} 个概念")

    # Collect every stock code referenced by any concept
    all_stocks = list(set().union(*(concept['stocks'] for concept in all_concepts)))
    logger.info(f"监控 {len(all_stocks)} 只股票")

    # Base prices used as the reference for change percentages
    base_prices = get_base_prices(all_stocks, trade_date)
    if not base_prices:
        logger.error("无法获取基准价格,退出回测")
        return
    logger.info(f"获取到 {len(base_prices)} 个基准价格")

    # Previous close of the reference index
    index_prev_close = get_index_prev_close(REFERENCE_INDEX, trade_date)
    logger.info(f"指数昨收价: {index_prev_close}")

    # Every minute timestamp available for the date
    timestamps = get_minute_timestamps(trade_date)
    if not timestamps:
        logger.error(f"未找到 {trade_date} 的分钟数据")
        return
    total_minutes = len(timestamps)
    logger.info(f"找到 {total_minutes} 个分钟时间点")

    total_alerts = 0
    processed = 0

    # Minute-by-minute replay
    for processed, ts in enumerate(timestamps, start=1):
        # Prices at this minute; skip the minute when nothing is available
        latest_prices = get_prices_at_time(all_stocks, ts)
        if not latest_prices:
            continue

        # Index data at this minute
        index_data = get_index_at_time(REFERENCE_INDEX, ts, index_prev_close)
        if index_data:
            save_index_snapshot(index_data, trade_date)

        # Per-stock change percentages
        stock_changes = calculate_change_pct(base_prices, latest_prices)
        if not stock_changes:
            continue

        # Per-concept statistics, then alert detection
        concept_stats = calculate_concept_stats(all_concepts, stock_changes)
        detected = detect_smart_alerts(concept_stats, index_data, trade_date, ts)

        # Persist alerts and count the rows actually saved
        if detected:
            total_alerts += save_alerts_to_mysql(detected)

        # Progress report every 30 processed minutes
        if processed % 30 == 0:
            percent_done = processed * 100 // total_minutes
            logger.info(f"进度: {processed}/{total_minutes} ({percent_done}%), 已检测到 {total_alerts} 条异动")

    logger.info(banner)
    logger.info("✅ 回测完成!")
    logger.info(f" 处理分钟数: {processed}")
    logger.info(f" 检测到异动: {total_alerts} 条")
    logger.info(banner)
|
||
|
||
|
||
def _status_format_pct(value, signed: bool = False) -> str:
    """Format a percentage for the status report; ``None`` renders as '-'."""
    if value is None:
        return '-'
    return f"{value:+.2f}%" if signed else f"{value:.2f}%"


def _status_format_zscore(extra_info) -> str:
    """Extract the ``zscore`` entry from an ``extra_info`` value as text.

    ``extra_info`` may be a JSON string or an already-decoded mapping.
    Returns '-' when it is empty, malformed, or carries no usable z-score.
    """
    if not extra_info:
        return '-'
    try:
        extra = json.loads(extra_info) if isinstance(extra_info, str) else extra_info
        zscore = extra.get('zscore')
        return '-' if zscore is None else f"{zscore:.1f}"
    except (ValueError, TypeError, AttributeError):
        # Invalid JSON, a non-mapping payload, or a non-numeric z-score.
        return '-'


def show_status():
    """Print a status report to stdout.

    The report contains the current time, whether it is trading time,
    whether sklearn is available, whether the SVM model file exists, today's
    alert counts grouped by type, and the 10 most recent alerts of today.
    A database failure is reported inline instead of being raised.
    """
    print("\n" + "=" * 60)
    print("智能概念异动检测服务 (Z-Score + SVM) - 状态")
    print("=" * 60)

    now = datetime.now()
    print(f"\n当前时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"是否交易时间: {'是' if is_trading_time() else '否'}")
    print(f"sklearn可用: {'是' if SKLEARN_AVAILABLE else '否'}")

    # Model status: only the presence of the model file is checked here.
    model_path = os.path.join(MODEL_DIR, 'svm_model.pkl')
    print(f"SVM模型: {'已加载' if os.path.exists(model_path) else '未训练'}")

    # Display names for alert types; unknown types fall back to the raw value.
    alert_type_names = {
        'surge_up': '暴涨',
        'surge_down': '暴跌',
        'surge': '急涨(旧)',
        'limit_up': '涨停增加',
        'rank_jump': '排名跃升',
    }
    short_type_names = {'surge_up': '暴涨', 'surge_down': '暴跌'}

    # Today's alert statistics
    print("\n今日异动统计:")
    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text("""
                SELECT alert_type, COUNT(*) as cnt, AVG(change_delta) as avg_delta
                FROM concept_minute_alert
                WHERE trade_date = CURDATE()
                GROUP BY alert_type
            """))
            rows = list(result)
            if rows:
                for row in rows:
                    alert_type_name = alert_type_names.get(row[0], row[0])
                    # `is None` check inside the helper: an average of exactly
                    # 0 is a real value and must not be shown as missing.
                    avg_delta = _status_format_pct(row[2])
                    print(f" {alert_type_name}: {row[1]} 条 (平均变化: {avg_delta})")
            else:
                print(" 今日暂无异动")

            # Latest alerts
            print("\n最新异动 (前10条):")
            result = conn.execute(text("""
                SELECT concept_name, alert_type, alert_time, change_pct, change_delta, extra_info
                FROM concept_minute_alert
                WHERE trade_date = CURDATE()
                ORDER BY alert_time DESC
                LIMIT 10
            """))
            rows = list(result)
            if rows:
                print(f" {'概念':<20} | {'类型':<6} | {'时间':<8} | {'涨幅':>6} | {'变化':>6} | {'Z分':>5}")
                print(" " + "-" * 70)
                for row in rows:
                    name = row[0][:18]
                    alert_type = short_type_names.get(row[1], row[1][:4])
                    time_str = row[2].strftime('%H:%M') if row[2] else '-'
                    change = _status_format_pct(row[3])
                    delta = _status_format_pct(row[4], signed=True)
                    zscore = _status_format_zscore(row[5])

                    print(f" {name:<20} | {alert_type:<6} | {time_str:<8} | {change:>6} | {delta:>6} | {zscore:>5}")
            else:
                print(" 暂无异动记录")

    except Exception as e:
        print(f" 查询失败: {e}")
|
||
|
||
|
||
def _alert_row_to_features(row) -> list:
    """Turn one ``concept_minute_alert`` row into a 7-element feature vector.

    Expected column order: change_pct, change_delta, rank_position,
    limit_up_count, stock_count, index_change_pct. Slots for which no stored
    data exists (10-minute delta, volume ratio) are filled with 0.
    """
    change_pct, change_delta, rank_position, limit_up_count, stock_count, index_change_pct = row[:6]

    if limit_up_count and stock_count:
        limit_up_ratio = float(limit_up_count) / max(1, float(stock_count)) * 100
    else:
        limit_up_ratio = 0

    # Explicit None checks: a change of exactly 0 (flat index or flat concept)
    # is a valid observation and must still contribute to the difference.
    if change_pct is not None and index_change_pct is not None:
        relative_to_index = float(change_pct) - float(index_change_pct)
    else:
        relative_to_index = 0

    return [
        float(change_pct) if change_pct else 0,        # change_pct
        float(change_delta) if change_delta else 0,    # change_delta
        0,                                             # change_delta_10min (not available)
        -float(rank_position) if rank_position else 0, # rank_delta (approximation)
        limit_up_ratio,                                # limit_up_ratio
        0,                                             # volume_ratio
        relative_to_index,                             # index_correlation
    ]


def train_model():
    """Train the SVM model from recently stored alerts.

    Reads the alerts of the last 30 days from ``concept_minute_alert``,
    converts each row into a feature vector and calls ``train_svm_model``
    when at least 100 samples are available. Returns early with a printed
    hint when sklearn is not installed. Errors during data collection or
    training are logged and not raised.
    """
    if not SKLEARN_AVAILABLE:
        print("错误: sklearn未安装,无法训练模型")
        print("安装命令: pip install scikit-learn")
        return

    logger.info("=" * 60)
    logger.info("🎓 开始训练SVM模型")
    logger.info("=" * 60)

    # Load concept data (used for the summary log line below)
    leaf_concepts = get_all_concepts()
    parent_concepts = load_hierarchy_concepts(leaf_concepts)
    all_concepts = leaf_concepts + parent_concepts

    all_stocks = set()
    for c in all_concepts:
        all_stocks.update(c['stocks'])
    all_stocks = list(all_stocks)

    logger.info(f"概念数: {len(all_concepts)}, 股票数: {len(all_stocks)}")

    # Number of days of stored alerts used as training data
    lookback_days = 30

    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text("""
                SELECT change_pct, change_delta, rank_position, limit_up_count,
                       stock_count, index_change_pct
                FROM concept_minute_alert
                WHERE trade_date >= DATE_SUB(CURDATE(), INTERVAL :days DAY)
            """), {'days': lookback_days})

            training_features = [_alert_row_to_features(row) for row in result]

        logger.info(f"收集到 {len(training_features)} 条训练数据")

        if len(training_features) >= 100:
            success = train_svm_model(training_features)
            if success:
                logger.info("✅ 模型训练成功!")
            else:
                logger.error("❌ 模型训练失败")
        else:
            logger.warning("训练数据不足100条,跳过训练")

    except Exception as e:
        logger.error(f"训练失败: {e}")
        import traceback
        traceback.print_exc()
|
||
|
||
|
||
# ==================== 主函数 ====================
|
||
|
||
def main():
    """Parse command-line arguments and dispatch to the selected command.

    Commands: ``realtime`` (default), ``once``, ``status``, ``backtest`` and
    ``train``. ``--date`` and ``--keep`` are only used by ``backtest``.
    """
    parser = argparse.ArgumentParser(description='智能概念异动检测服务 (Z-Score + SVM)')
    parser.add_argument(
        'command',
        nargs='?',
        default='realtime',
        choices=['realtime', 'once', 'status', 'backtest', 'train'],
        help='命令: realtime(实时运行), once(单次运行), status(状态), backtest(回测), train(训练模型)',
    )
    parser.add_argument(
        '--date', '-d',
        type=str,
        default=None,
        help='回测日期,格式: YYYY-MM-DD,默认为今天',
    )
    parser.add_argument(
        '--keep', '-k',
        action='store_true',
        help='回测时保留已有数据(默认会清除)',
    )

    args = parser.parse_args()

    if args.command == 'backtest':
        # Without --date the backtest targets today; --keep disables the
        # clean-up of rows already stored for that date.
        trade_date = args.date or datetime.now().strftime('%Y-%m-%d')
        run_backtest(trade_date, not args.keep)
        return

    # Commands that take no arguments
    handlers = {
        'realtime': run_realtime,
        'once': run_single,
        'status': show_status,
        'train': train_model,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        handler()
|
||
|
||
|
||
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|