#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 概念异动智能检测服务 - 基于 Z-Score + SVM - Z-Score: 动态阈值,根据历史波动率判断异常 - SVM: 多特征分类,综合判断是否为有意义的异动 - 支持暴涨和暴跌检测 - 异动重要性评分,避免图表过于密集 """ import pandas as pd import numpy as np from datetime import datetime, timedelta from sqlalchemy import create_engine, text from elasticsearch import Elasticsearch from clickhouse_driver import Client from collections import deque import time import logging import json import os import hashlib import argparse import pickle from typing import Dict, List, Optional, Tuple # 尝试导入sklearn,如果不存在则提示安装 try: from sklearn.svm import OneClassSVM from sklearn.preprocessing import StandardScaler from sklearn.ensemble import IsolationForest SKLEARN_AVAILABLE = True except ImportError: SKLEARN_AVAILABLE = False print("警告: sklearn 未安装,将使用纯 Z-Score 方法") print("安装命令: pip install scikit-learn") # ==================== 配置 ==================== # MySQL配置 MYSQL_ENGINE = create_engine( "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock", echo=False ) # Elasticsearch配置 ES_CLIENT = Elasticsearch(['http://222.128.1.157:19200']) INDEX_NAME = 'concept_library_v3' # ClickHouse配置 CLICKHOUSE_CONFIG = { 'host': '222.128.1.157', 'port': 18000, 'user': 'default', 'password': 'Zzl33818!', 'database': 'stock' } # 层级结构文件 HIERARCHY_FILE = 'concept_hierarchy_v3.json' # 模型保存路径 MODEL_DIR = 'models' os.makedirs(MODEL_DIR, exist_ok=True) # ==================== 智能异动检测配置 ==================== SMART_ALERT_CONFIG = { # Z-Score 配置 'zscore': { 'enabled': True, 'lookback_days': 20, # 历史数据回看天数 'threshold_up': 2.5, # 上涨异动阈值(标准差倍数) 'threshold_down': -2.5, # 下跌异动阈值(标准差倍数) 'min_data_points': 10, # 最少数据点数 }, # SVM 异常检测配置 'svm': { 'enabled': SKLEARN_AVAILABLE, 'nu': 0.05, # 异常比例(预期5%为异常) 'kernel': 'rbf', # 核函数 'gamma': 'auto', 'retrain_days': 7, # 每N天重新训练 }, # 特征配置(用于SVM) 'features': [ 'change_pct', # 当前涨跌幅 'change_delta_5min', # 5分钟涨跌幅变化 'change_delta_10min', # 10分钟涨跌幅变化 'rank_delta_5min', # 5分钟排名变化 'limit_up_ratio', # 涨停股占比 'volume_ratio', # 成交量比率(预留) 'index_correlation', # 与指数相关性 ], # 重要性评分权重 'importance_weights': { 'zscore_abs': 0.3, # Z-Score 绝对值 'rank_position': 0.2, # 排名位置(越靠前越重要) 'limit_up_count': 0.2, # 涨停数 'stock_count': 0.1, # 概念股票数 'change_magnitude': 0.2, # 涨跌幅度 }, # 显示控制 'display': { 'max_alerts_per_hour': 20, # 每小时最多显示异动数 'min_importance_score': 0.3, # 最低重要性分数 'cooldown_minutes': 15, # 同一概念冷却时间 }, } # 参考指数 REFERENCE_INDEX = '000001.SH' # ==================== 日志配置 ==================== logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(f'concept_alert_ml_{datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8'), logging.StreamHandler() ] ) logger = logging.getLogger(__name__) # ==================== 全局变量 ==================== ch_client = None # 历史统计数据缓存 # 结构: {concept_id: {'mean': float, 'std': float, 'history': deque}} stats_cache: Dict[str, dict] = {} # 分钟级历史缓存(用于计算变化率) minute_cache: Dict[str, deque] = {} MINUTE_WINDOW = 15 # 保留15分钟数据 # 冷却记录 cooldown_cache: Dict[Tuple[str, str], datetime] = {} # SVM 模型 svm_model = None svm_scaler = None svm_last_train = None def get_ch_client(): """获取ClickHouse客户端""" global ch_client if ch_client is None: ch_client = Client(**CLICKHOUSE_CONFIG) return ch_client def generate_id(name: str) -> str: """生成概念ID""" return hashlib.md5(name.encode('utf-8')).hexdigest()[:16] def code_to_ch_format(code: str) -> str: """将6位股票代码转换为ClickHouse格式""" if not code or len(code) != 6 or not code.isdigit(): return None if code.startswith('6'): return f"{code}.SH" elif code.startswith('0') or code.startswith('3'): return f"{code}.SZ" else: return f"{code}.BJ" # ==================== 概念数据获取(复用原有代码)==================== def get_all_concepts(): """从ES获取所有叶子概念及其股票列表""" concepts = [] query = { "query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"] } resp = ES_CLIENT.search(index=INDEX_NAME, body=query, scroll='2m') scroll_id = resp['_scroll_id'] hits = resp['hits']['hits'] while len(hits) > 0: for hit in hits: source = hit['_source'] concept_info = { 'concept_id': source.get('concept_id'), 'concept_name': source.get('concept'), 'stocks': [], 'concept_type': 'leaf' } if 'stocks' in source and isinstance(source['stocks'], list): for stock in source['stocks']: if isinstance(stock, dict) and 'code' in stock and stock['code']: concept_info['stocks'].append(stock['code']) if concept_info['stocks']: concepts.append(concept_info) resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m') scroll_id = resp['_scroll_id'] hits = resp['hits']['hits'] ES_CLIENT.clear_scroll(scroll_id=scroll_id) return concepts def load_hierarchy_concepts(leaf_concepts: list) -> list: """加载层级结构,生成母概念""" hierarchy_path = os.path.join(os.path.dirname(__file__), HIERARCHY_FILE) if not os.path.exists(hierarchy_path): logger.warning(f"层级文件不存在: {hierarchy_path}") return [] with open(hierarchy_path, 'r', encoding='utf-8') as f: hierarchy_data = json.load(f) concept_to_stocks = {} for c in leaf_concepts: concept_to_stocks[c['concept_name']] = set(c['stocks']) parent_concepts = [] for lv1 in hierarchy_data.get('hierarchy', []): lv1_name = lv1.get('lv1', '') lv1_stocks = set() for child in lv1.get('children', []): lv2_name = child.get('lv2', '') lv2_stocks = set() if 'children' in child: for lv3_child in child.get('children', []): lv3_name = lv3_child.get('lv3', '') lv3_stocks = set() for concept_name in lv3_child.get('concepts', []): if concept_name in concept_to_stocks: lv3_stocks.update(concept_to_stocks[concept_name]) if lv3_stocks: parent_concepts.append({ 'concept_id': generate_id(f"lv3_{lv3_name}"), 'concept_name': f"[三级] {lv3_name}", 'stocks': list(lv3_stocks), 'concept_type': 'lv3' }) lv2_stocks.update(lv3_stocks) else: for concept_name in child.get('concepts', []): if concept_name in concept_to_stocks: lv2_stocks.update(concept_to_stocks[concept_name]) if lv2_stocks: parent_concepts.append({ 'concept_id': generate_id(f"lv2_{lv2_name}"), 'concept_name': f"[二级] {lv2_name}", 'stocks': list(lv2_stocks), 'concept_type': 'lv2' }) lv1_stocks.update(lv2_stocks) if lv1_stocks: parent_concepts.append({ 'concept_id': generate_id(f"lv1_{lv1_name}"), 'concept_name': f"[一级] {lv1_name}", 'stocks': list(lv1_stocks), 'concept_type': 'lv1' }) return parent_concepts # ==================== 价格数据获取 ==================== def get_base_prices(stock_codes: list, current_date: str) -> dict: """获取昨收价作为基准""" if not stock_codes: return {} valid_codes = [code for code in stock_codes if code and len(code) == 6 and code.isdigit()] if not valid_codes: return {} stock_codes_str = "','".join(valid_codes) query = f""" SELECT SECCODE, F002N FROM ea_trade WHERE SECCODE IN ('{stock_codes_str}') AND TRADEDATE = ( SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{current_date}' ) AND F002N IS NOT NULL AND F002N > 0 """ try: with MYSQL_ENGINE.connect() as conn: result = conn.execute(text(query)) base_prices = {row[0]: float(row[1]) for row in result if row[1] and float(row[1]) > 0} return base_prices except Exception as e: logger.error(f"获取基准价格失败: {e}") return {} def get_latest_prices(stock_codes: list) -> dict: """从ClickHouse获取最新价格""" if not stock_codes: return {} client = get_ch_client() ch_codes = [] code_mapping = {} for code in stock_codes: ch_code = code_to_ch_format(code) if ch_code: ch_codes.append(ch_code) code_mapping[ch_code] = code if not ch_codes: return {} ch_codes_str = "','".join(ch_codes) query = f""" SELECT code, close, timestamp FROM ( SELECT code, close, timestamp, ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn FROM stock_minute WHERE code IN ('{ch_codes_str}') AND toDate(timestamp) = today() ) WHERE rn = 1 """ try: result = client.execute(query) if not result: return {} latest_prices = {} for row in result: ch_code, close, ts = row if close and close > 0: pure_code = code_mapping.get(ch_code) if pure_code: latest_prices[pure_code] = { 'close': float(close), 'timestamp': ts } return latest_prices except Exception as e: logger.error(f"获取最新价格失败: {e}") return {} def get_prices_at_time(stock_codes: list, timestamp: datetime) -> dict: """获取指定时间点的股票价格""" if not stock_codes: return {} client = get_ch_client() ch_codes = [] code_mapping = {} for code in stock_codes: ch_code = code_to_ch_format(code) if ch_code: ch_codes.append(ch_code) code_mapping[ch_code] = code if not ch_codes: return {} ch_codes_str = "','".join(ch_codes) query = f""" SELECT code, close, timestamp FROM stock_minute WHERE code IN ('{ch_codes_str}') AND timestamp = '{timestamp.strftime('%Y-%m-%d %H:%M:%S')}' """ try: result = client.execute(query) prices = {} for row in result: ch_code, close, ts = row if close and close > 0: pure_code = code_mapping.get(ch_code) if pure_code: prices[pure_code] = { 'close': float(close), 'timestamp': ts } return prices except Exception as e: logger.error(f"获取历史价格失败: {e}") return {} def get_index_realtime(index_code: str = REFERENCE_INDEX) -> dict: """获取指数实时数据""" client = get_ch_client() try: query = f""" SELECT close, timestamp FROM index_minute WHERE code = '{index_code}' AND toDate(timestamp) = today() ORDER BY timestamp DESC LIMIT 1 """ result = client.execute(query) if not result: return None close, ts = result[0] # 获取昨收价 prev_close = get_index_prev_close(index_code, datetime.now().strftime('%Y-%m-%d')) change_pct = None if close and prev_close and prev_close > 0: change_pct = (float(close) - prev_close) / prev_close * 100 return { 'code': index_code, 'price': float(close), 'prev_close': prev_close, 'change_pct': round(change_pct, 4) if change_pct else None, 'timestamp': ts } except Exception as e: logger.error(f"获取指数数据失败: {e}") return None def get_index_at_time(index_code: str, timestamp: datetime, prev_close: float) -> dict: """获取指定时间点的指数数据""" client = get_ch_client() query = f""" SELECT close, timestamp FROM index_minute WHERE code = '{index_code}' AND timestamp = '{timestamp.strftime('%Y-%m-%d %H:%M:%S')}' LIMIT 1 """ try: result = client.execute(query) if not result: return None close, ts = result[0] change_pct = None if close and prev_close and prev_close > 0: change_pct = (float(close) - prev_close) / prev_close * 100 return { 'code': index_code, 'price': float(close), 'prev_close': prev_close, 'change_pct': round(change_pct, 4) if change_pct else None, 'timestamp': ts } except Exception as e: logger.error(f"获取指数数据失败: {e}") return None def get_index_prev_close(index_code: str, trade_date: str) -> float: """获取指数昨收价""" code_no_suffix = index_code.split('.')[0] with MYSQL_ENGINE.connect() as conn: result = conn.execute(text(""" SELECT F006N FROM ea_exchangetrade WHERE INDEXCODE = :code AND TRADEDATE < :today ORDER BY TRADEDATE DESC LIMIT 1 """), { 'code': code_no_suffix, 'today': trade_date }).fetchone() if result and result[0]: return float(result[0]) return None # ==================== 涨跌幅计算 ==================== def calculate_change_pct(base_prices: dict, latest_prices: dict) -> dict: """计算涨跌幅""" changes = {} for code, latest in latest_prices.items(): if code in base_prices and base_prices[code] > 0: base = base_prices[code] close = latest['close'] change_pct = (close - base) / base * 100 changes[code] = { 'change_pct': round(change_pct, 4), 'close': close, 'base': base } return changes def calculate_concept_stats(concepts: list, stock_changes: dict) -> list: """计算概念统计""" stats = [] for concept in concepts: concept_id = concept['concept_id'] concept_name = concept['concept_name'] stock_codes = concept['stocks'] concept_type = concept.get('concept_type', 'leaf') changes = [] limit_up_count = 0 limit_down_count = 0 limit_up_stocks = [] limit_down_stocks = [] for code in stock_codes: if code in stock_changes: change_info = stock_changes[code] change_pct = change_info['change_pct'] changes.append(change_pct) # 涨停判断(涨幅 >= 9.8%) if change_pct >= 9.8: limit_up_count += 1 limit_up_stocks.append(code) # 跌停判断(跌幅 <= -9.8%) elif change_pct <= -9.8: limit_down_count += 1 limit_down_stocks.append(code) if not changes: continue avg_change_pct = round(np.mean(changes), 4) std_change_pct = round(np.std(changes), 4) if len(changes) > 1 else 0 stats.append({ 'concept_id': concept_id, 'concept_name': concept_name, 'avg_change_pct': avg_change_pct, 'std_change_pct': std_change_pct, 'stock_count': len(changes), 'concept_type': concept_type, 'limit_up_count': limit_up_count, 'limit_down_count': limit_down_count, 'limit_up_stocks': limit_up_stocks, 'limit_down_stocks': limit_down_stocks, 'limit_up_ratio': limit_up_count / len(changes) if changes else 0, 'limit_down_ratio': limit_down_count / len(changes) if changes else 0, }) # 按涨幅排序并添加排名 stats.sort(key=lambda x: x['avg_change_pct'], reverse=True) for i, item in enumerate(stats): item['rank'] = i + 1 return stats # ==================== Z-Score 计算 ==================== def load_historical_stats(concept_id: str, lookback_days: int = 20) -> dict: """ 加载概念的历史统计数据(用于计算Z-Score) 从历史分钟数据中计算每日的平均涨跌幅变化 """ if concept_id in stats_cache and stats_cache[concept_id].get('loaded'): return stats_cache[concept_id] # 查询历史异动记录计算统计 try: with MYSQL_ENGINE.connect() as conn: # 获取历史每日变化统计 result = conn.execute(text(""" SELECT trade_date, AVG(change_delta) as avg_delta, MAX(change_delta) as max_delta, MIN(change_delta) as min_delta FROM concept_minute_alert WHERE concept_id = :concept_id AND trade_date >= DATE_SUB(CURDATE(), INTERVAL :days DAY) AND alert_type = 'surge' GROUP BY trade_date """), {'concept_id': concept_id, 'days': lookback_days}) rows = list(result) if len(rows) >= SMART_ALERT_CONFIG['zscore']['min_data_points']: deltas = [float(r[1]) for r in rows if r[1]] stats_cache[concept_id] = { 'mean': np.mean(deltas), 'std': np.std(deltas) if len(deltas) > 1 else 1.0, 'count': len(deltas), 'loaded': True } else: # 数据不足,使用默认值 stats_cache[concept_id] = { 'mean': 0.5, # 默认平均变化0.5% 'std': 0.8, # 默认标准差0.8% 'count': 0, 'loaded': True } except Exception as e: logger.error(f"加载历史统计失败: {e}") stats_cache[concept_id] = { 'mean': 0.5, 'std': 0.8, 'count': 0, 'loaded': True } return stats_cache[concept_id] def calculate_zscore(concept_id: str, change_delta: float) -> float: """ 计算涨跌幅变化的Z-Score Z = (X - μ) / σ """ stats = load_historical_stats(concept_id) mean = stats['mean'] std = stats['std'] # 避免除零 if std < 0.1: std = 0.1 zscore = (change_delta - mean) / std return round(zscore, 4) def update_minute_cache(concept_id: str, timestamp: datetime, data: dict): """更新分钟级缓存""" if concept_id not in minute_cache: minute_cache[concept_id] = deque(maxlen=MINUTE_WINDOW) minute_cache[concept_id].append({ 'timestamp': timestamp, **data }) def get_minute_history(concept_id: str, minutes_ago: int) -> Optional[dict]: """获取N分钟前的数据""" if concept_id not in minute_cache: return None history = minute_cache[concept_id] if not history: return None # 获取当前最新时间 current_time = history[-1]['timestamp'] if history else datetime.now() target_time = current_time - timedelta(minutes=minutes_ago) # 找到最接近目标时间的记录 for record in reversed(list(history)): if record['timestamp'] <= target_time: return record return None # ==================== SVM 异常检测 ==================== def extract_features(stat: dict, index_data: dict) -> np.ndarray: """ 提取用于SVM的特征向量 """ concept_id = stat['concept_id'] # 获取历史数据 prev_5min = get_minute_history(concept_id, 5) prev_10min = get_minute_history(concept_id, 10) features = [] # 1. 当前涨跌幅 features.append(stat['avg_change_pct']) # 2. 5分钟涨跌幅变化 if prev_5min: features.append(stat['avg_change_pct'] - prev_5min.get('change_pct', 0)) else: features.append(0) # 3. 10分钟涨跌幅变化 if prev_10min: features.append(stat['avg_change_pct'] - prev_10min.get('change_pct', 0)) else: features.append(0) # 4. 5分钟排名变化 if prev_5min: features.append(prev_5min.get('rank', stat['rank']) - stat['rank']) else: features.append(0) # 5. 涨停股占比 features.append(stat.get('limit_up_ratio', 0) * 100) # 6. 成交量比率(预留,暂用0) features.append(0) # 7. 与指数相关性(简化:涨跌方向一致性) if index_data and index_data.get('change_pct'): index_change = index_data['change_pct'] concept_change = stat['avg_change_pct'] # 同向为正,反向为负 correlation = 1 if (index_change * concept_change > 0) else -1 features.append(correlation * abs(concept_change - index_change)) else: features.append(0) return np.array(features) def train_svm_model(training_data: List[np.ndarray]): """ 训练 OneClass SVM 模型 用于检测异常模式 """ global svm_model, svm_scaler, svm_last_train if not SKLEARN_AVAILABLE: return False if len(training_data) < 100: logger.warning(f"训练数据不足: {len(training_data)} 条") return False try: X = np.array(training_data) # 标准化 svm_scaler = StandardScaler() X_scaled = svm_scaler.fit_transform(X) # 训练 OneClass SVM svm_model = OneClassSVM( nu=SMART_ALERT_CONFIG['svm']['nu'], kernel=SMART_ALERT_CONFIG['svm']['kernel'], gamma=SMART_ALERT_CONFIG['svm']['gamma'] ) svm_model.fit(X_scaled) svm_last_train = datetime.now() # 保存模型 model_path = os.path.join(MODEL_DIR, 'svm_model.pkl') scaler_path = os.path.join(MODEL_DIR, 'svm_scaler.pkl') with open(model_path, 'wb') as f: pickle.dump(svm_model, f) with open(scaler_path, 'wb') as f: pickle.dump(svm_scaler, f) logger.info(f"SVM模型训练完成,使用 {len(training_data)} 条数据") return True except Exception as e: logger.error(f"SVM模型训练失败: {e}") return False def load_svm_model(): """加载已保存的SVM模型""" global svm_model, svm_scaler if not SKLEARN_AVAILABLE: return False model_path = os.path.join(MODEL_DIR, 'svm_model.pkl') scaler_path = os.path.join(MODEL_DIR, 'svm_scaler.pkl') if os.path.exists(model_path) and os.path.exists(scaler_path): try: with open(model_path, 'rb') as f: svm_model = pickle.load(f) with open(scaler_path, 'rb') as f: svm_scaler = pickle.load(f) logger.info("SVM模型加载成功") return True except Exception as e: logger.error(f"SVM模型加载失败: {e}") return False def predict_anomaly(features: np.ndarray) -> Tuple[bool, float]: """ 使用SVM预测是否为异常 返回: (是否异常, 异常分数) """ global svm_model, svm_scaler if svm_model is None or svm_scaler is None: return False, 0.0 try: X_scaled = svm_scaler.transform(features.reshape(1, -1)) prediction = svm_model.predict(X_scaled)[0] score = svm_model.decision_function(X_scaled)[0] # prediction: 1 = 正常, -1 = 异常 is_anomaly = prediction == -1 return is_anomaly, float(score) except Exception as e: logger.error(f"SVM预测失败: {e}") return False, 0.0 # ==================== 重要性评分 ==================== def calculate_importance_score( zscore: float, rank: int, limit_up_count: int, stock_count: int, change_pct: float, total_concepts: int ) -> float: """ 计算异动的重要性分数(0-1) 综合多个因素判断这条异动是否值得显示 """ weights = SMART_ALERT_CONFIG['importance_weights'] scores = {} # 1. Z-Score 绝对值(越大越重要) zscore_score = min(abs(zscore) / 5.0, 1.0) # 5倍标准差为满分 scores['zscore_abs'] = zscore_score # 2. 排名位置(越靠前越重要) rank_score = max(0, 1 - (rank - 1) / min(100, total_concepts)) scores['rank_position'] = rank_score # 3. 涨停数(越多越重要) limit_score = min(limit_up_count / 5.0, 1.0) # 5个涨停为满分 scores['limit_up_count'] = limit_score # 4. 概念股票数(适中最好) if stock_count < 10: stock_score = stock_count / 10.0 elif stock_count > 100: stock_score = max(0.5, 1 - (stock_count - 100) / 200) else: stock_score = 1.0 scores['stock_count'] = stock_score # 5. 涨跌幅度 change_score = min(abs(change_pct) / 5.0, 1.0) # 5%为满分 scores['change_magnitude'] = change_score # 加权求和 total_score = sum(scores[k] * weights[k] for k in weights) return round(total_score, 4) # ==================== 智能异动检测 ==================== def check_cooldown(concept_id: str, alert_type: str) -> bool: """检查是否在冷却期""" key = (concept_id, alert_type) cooldown_minutes = SMART_ALERT_CONFIG['display']['cooldown_minutes'] if key in cooldown_cache: last_alert = cooldown_cache[key] if datetime.now() - last_alert < timedelta(minutes=cooldown_minutes): return True return False def set_cooldown(concept_id: str, alert_type: str, alert_time: datetime = None): """设置冷却""" cooldown_cache[(concept_id, alert_type)] = alert_time or datetime.now() def detect_smart_alerts( current_stats: list, index_data: dict, trade_date: str, current_time: datetime = None ) -> list: """ 智能异动检测 结合 Z-Score + SVM 进行检测 """ alerts = [] current_time = current_time or datetime.now() total_concepts = len(current_stats) for stat in current_stats: concept_id = stat['concept_id'] concept_name = stat['concept_name'] change_pct = stat['avg_change_pct'] rank = stat['rank'] limit_up_count = stat['limit_up_count'] limit_down_count = stat['limit_down_count'] stock_count = stat['stock_count'] concept_type = stat['concept_type'] # 更新分钟缓存 update_minute_cache(concept_id, current_time, { 'change_pct': change_pct, 'rank': rank, 'limit_up_count': limit_up_count, 'limit_down_count': limit_down_count }) # 获取历史数据计算变化 prev_5min = get_minute_history(concept_id, 5) if not prev_5min: continue change_delta = change_pct - prev_5min.get('change_pct', 0) # ========== Z-Score 检测 ========== zscore = calculate_zscore(concept_id, abs(change_delta)) # 判断是涨还是跌 is_surge_up = change_delta > 0 and zscore >= SMART_ALERT_CONFIG['zscore']['threshold_up'] is_surge_down = change_delta < 0 and zscore >= abs(SMART_ALERT_CONFIG['zscore']['threshold_down']) if not (is_surge_up or is_surge_down): continue alert_type = 'surge_up' if is_surge_up else 'surge_down' # 检查冷却 if check_cooldown(concept_id, alert_type): continue # ========== SVM 验证(可选)========== svm_is_anomaly = False svm_score = 0.0 if SMART_ALERT_CONFIG['svm']['enabled'] and svm_model is not None: features = extract_features(stat, index_data) svm_is_anomaly, svm_score = predict_anomaly(features) # 如果SVM认为不是异常,降低Z-Score要求 if not svm_is_anomaly and abs(zscore) < 3.5: continue # ========== 计算重要性分数 ========== importance = calculate_importance_score( zscore=zscore, rank=rank, limit_up_count=limit_up_count if is_surge_up else limit_down_count, stock_count=stock_count, change_pct=change_pct, total_concepts=total_concepts ) # 过滤低重要性 if importance < SMART_ALERT_CONFIG['display']['min_importance_score']: continue # ========== 创建异动记录 ========== alert = { 'concept_id': concept_id, 'concept_name': concept_name, 'alert_type': alert_type, 'alert_time': current_time, 'trade_date': trade_date, 'change_pct': change_pct, 'prev_change_pct': prev_5min.get('change_pct'), 'change_delta': round(change_delta, 4), 'zscore': zscore, 'svm_score': svm_score, 'importance_score': importance, 'limit_up_count': limit_up_count, 'limit_down_count': limit_down_count, 'prev_limit_up_count': prev_5min.get('limit_up_count', 0), 'rank_position': rank, 'prev_rank_position': prev_5min.get('rank'), 'rank_delta': (prev_5min.get('rank', rank) - rank) if prev_5min else 0, 'stock_count': stock_count, 'concept_type': concept_type, 'index_code': REFERENCE_INDEX, 'index_price': index_data['price'] if index_data else None, 'index_change_pct': index_data['change_pct'] if index_data else None, 'extra_info': { 'limit_up_stocks': stat.get('limit_up_stocks', []), 'limit_down_stocks': stat.get('limit_down_stocks', []), } } alerts.append(alert) set_cooldown(concept_id, alert_type, current_time) # 日志 direction = "🔥 暴涨" if is_surge_up else "💧 暴跌" logger.info( f"{direction}: {concept_name} " f"涨幅 {prev_5min.get('change_pct', 0):.2f}% -> {change_pct:.2f}% " f"(Δ{change_delta:+.2f}%, Z={zscore:.2f}, 重要性={importance:.2f})" ) # 按重要性排序,限制数量 alerts.sort(key=lambda x: x['importance_score'], reverse=True) max_alerts = SMART_ALERT_CONFIG['display']['max_alerts_per_hour'] return alerts[:max_alerts] # ==================== 数据持久化 ==================== def save_alerts_to_mysql(alerts: list): """保存异动数据到MySQL""" if not alerts: return 0 saved = 0 with MYSQL_ENGINE.begin() as conn: for alert in alerts: try: insert_sql = text(""" INSERT INTO concept_minute_alert (concept_id, concept_name, alert_time, alert_type, trade_date, change_pct, prev_change_pct, change_delta, limit_up_count, prev_limit_up_count, limit_up_delta, rank_position, prev_rank_position, rank_delta, index_code, index_price, index_change_pct, stock_count, concept_type, extra_info) VALUES (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date, :change_pct, :prev_change_pct, :change_delta, :limit_up_count, :prev_limit_up_count, :limit_up_delta, :rank_position, :prev_rank_position, :rank_delta, :index_code, :index_price, :index_change_pct, :stock_count, :concept_type, :extra_info) """) # 计算 limit_up_delta limit_up_delta = alert.get('limit_up_count', 0) - alert.get('prev_limit_up_count', 0) params = { 'concept_id': alert['concept_id'], 'concept_name': alert['concept_name'], 'alert_time': alert['alert_time'], 'alert_type': alert['alert_type'], 'trade_date': alert['trade_date'], 'change_pct': alert.get('change_pct'), 'prev_change_pct': alert.get('prev_change_pct'), 'change_delta': alert.get('change_delta'), 'limit_up_count': alert.get('limit_up_count', 0), 'prev_limit_up_count': alert.get('prev_limit_up_count', 0), 'limit_up_delta': limit_up_delta, 'rank_position': alert.get('rank_position'), 'prev_rank_position': alert.get('prev_rank_position'), 'rank_delta': alert.get('rank_delta'), 'index_code': alert.get('index_code', REFERENCE_INDEX), 'index_price': alert.get('index_price'), 'index_change_pct': alert.get('index_change_pct'), 'stock_count': alert.get('stock_count'), 'concept_type': alert.get('concept_type', 'leaf'), 'extra_info': json.dumps({ **alert.get('extra_info', {}), 'zscore': alert.get('zscore'), 'svm_score': alert.get('svm_score'), 'importance_score': alert.get('importance_score'), }, ensure_ascii=False) if alert.get('extra_info') else None } conn.execute(insert_sql, params) saved += 1 except Exception as e: logger.error(f"保存异动失败: {alert['concept_name']} - {e}") return saved def save_index_snapshot(index_data: dict, trade_date: str): """保存指数快照""" if not index_data: return try: with MYSQL_ENGINE.begin() as conn: upsert_sql = text(""" REPLACE INTO index_minute_snapshot (index_code, trade_date, snapshot_time, price, prev_close, change_pct) VALUES (:index_code, :trade_date, :snapshot_time, :price, :prev_close, :change_pct) """) conn.execute(upsert_sql, { 'index_code': index_data['code'], 'trade_date': trade_date, 'snapshot_time': index_data['timestamp'], 'price': index_data['price'], 'prev_close': index_data.get('prev_close'), 'change_pct': index_data.get('change_pct') }) except Exception as e: logger.error(f"保存指数快照失败: {e}") # ==================== 交易时间判断 ==================== def is_trading_time() -> bool: """判断当前是否为交易时间""" now = datetime.now() weekday = now.weekday() if weekday >= 5: return False hour, minute = now.hour, now.minute current_time = hour * 60 + minute morning_start = 9 * 60 + 30 morning_end = 11 * 60 + 30 afternoon_start = 13 * 60 afternoon_end = 15 * 60 return (morning_start <= current_time <= morning_end) or \ (afternoon_start <= current_time <= afternoon_end) def get_next_update_time() -> int: """获取距离下次更新的秒数""" now = datetime.now() if is_trading_time(): return 60 - now.second else: hour, minute = now.hour, now.minute if hour < 9 or (hour == 9 and minute < 30): target = now.replace(hour=9, minute=30, second=0, microsecond=0) elif (hour == 11 and minute >= 30) or hour == 12: target = now.replace(hour=13, minute=0, second=0, microsecond=0) elif hour >= 15: target = (now + timedelta(days=1)).replace(hour=9, minute=30, second=0, microsecond=0) else: target = now + timedelta(minutes=1) wait_seconds = (target - now).total_seconds() return max(60, int(wait_seconds)) # ==================== 主运行逻辑 ==================== def run_once(concepts: list, all_stocks: list) -> tuple: """执行一次检测""" now = datetime.now() trade_date = now.strftime('%Y-%m-%d') # 获取基准价格 base_prices = get_base_prices(all_stocks, trade_date) if not base_prices: logger.warning("无法获取基准价格") return 0, 0 # 获取最新价格 latest_prices = get_latest_prices(all_stocks) if not latest_prices: logger.warning("无法获取最新价格") return 0, 0 # 获取指数数据 index_data = get_index_realtime(REFERENCE_INDEX) if index_data: save_index_snapshot(index_data, trade_date) # 计算涨跌幅 stock_changes = calculate_change_pct(base_prices, latest_prices) if not stock_changes: logger.warning("无涨跌幅数据") return 0, 0 logger.info(f"获取到 {len(stock_changes)} 只股票的涨跌幅") # 计算概念统计 stats = calculate_concept_stats(concepts, stock_changes) logger.info(f"计算了 {len(stats)} 个概念的涨跌幅") # 智能异动检测 alerts = detect_smart_alerts(stats, index_data, trade_date, now) # 保存异动 if alerts: saved = save_alerts_to_mysql(alerts) logger.info(f"💾 保存了 {saved} 条异动记录") return len(stats), len(alerts) def run_realtime(): """实时检测主循环""" logger.info("=" * 60) logger.info("🚀 启动智能概念异动检测服务 (Z-Score + SVM)") logger.info("=" * 60) logger.info(f"配置: {json.dumps(SMART_ALERT_CONFIG, indent=2, ensure_ascii=False, default=str)}") # 尝试加载SVM模型 if SKLEARN_AVAILABLE: load_svm_model() # 加载概念数据 logger.info("加载概念数据...") leaf_concepts = get_all_concepts() logger.info(f"获取到 {len(leaf_concepts)} 个叶子概念") parent_concepts = load_hierarchy_concepts(leaf_concepts) logger.info(f"生成了 {len(parent_concepts)} 个母概念") all_concepts = leaf_concepts + parent_concepts logger.info(f"总计 {len(all_concepts)} 个概念") # 收集所有股票代码 all_stocks = set() for c in all_concepts: all_stocks.update(c['stocks']) all_stocks = list(all_stocks) logger.info(f"监控 {len(all_stocks)} 只股票") last_concept_update = datetime.now() total_alerts = 0 while True: try: now = datetime.now() # 每小时重新加载概念数据 if (now - last_concept_update).total_seconds() > 3600: logger.info("重新加载概念数据...") leaf_concepts = get_all_concepts() parent_concepts = load_hierarchy_concepts(leaf_concepts) all_concepts = leaf_concepts + parent_concepts all_stocks = set() for c in all_concepts: all_stocks.update(c['stocks']) all_stocks = list(all_stocks) last_concept_update = now logger.info(f"更新完成: {len(all_concepts)} 个概念, {len(all_stocks)} 只股票") # 检查是否交易时间 if not is_trading_time(): wait_sec = get_next_update_time() wait_min = wait_sec // 60 logger.info(f"⏰ 非交易时间,等待 {wait_min} 分钟后重试...") time.sleep(min(wait_sec, 300)) continue # 执行检测 logger.info(f"\n{'=' * 40}") logger.info(f"🔍 检测时间: {now.strftime('%Y-%m-%d %H:%M:%S')}") updated, alert_count = run_once(all_concepts, all_stocks) total_alerts += alert_count if alert_count > 0: logger.info(f"📊 本次检测到 {alert_count} 条异动,累计 {total_alerts} 条") # 等待下一分钟 sleep_sec = 60 - datetime.now().second logger.info(f"⏳ 等待 {sleep_sec} 秒后继续...") time.sleep(sleep_sec) except KeyboardInterrupt: logger.info("\n收到退出信号,停止服务...") break except Exception as e: logger.error(f"发生错误: {e}") import traceback traceback.print_exc() time.sleep(60) def run_single(): """单次运行""" logger.info("单次检测模式") if SKLEARN_AVAILABLE: load_svm_model() leaf_concepts = get_all_concepts() parent_concepts = load_hierarchy_concepts(leaf_concepts) all_concepts = leaf_concepts + parent_concepts all_stocks = set() for c in all_concepts: all_stocks.update(c['stocks']) all_stocks = list(all_stocks) logger.info(f"概念数: {len(all_concepts)}, 股票数: {len(all_stocks)}") updated, alerts = run_once(all_concepts, all_stocks) logger.info(f"检测完成: {updated} 个概念, {alerts} 条异动") # ==================== 回测功能 ==================== def get_minute_timestamps(trade_date: str) -> list: """获取指定交易日的所有分钟时间戳""" client = get_ch_client() query = f""" SELECT DISTINCT timestamp FROM stock_minute WHERE toDate(timestamp) = '{trade_date}' ORDER BY timestamp """ result = client.execute(query) return [row[0] for row in result] def run_backtest(trade_date: str, clear_existing: bool = True): """ 回测指定日期的异动检测 """ global minute_cache, cooldown_cache, stats_cache logger.info("=" * 60) logger.info(f"🔄 开始智能回测: {trade_date}") logger.info("=" * 60) # 清空缓存 minute_cache = {} cooldown_cache = {} stats_cache = {} # 清除已有数据 if clear_existing: with MYSQL_ENGINE.begin() as conn: conn.execute(text("DELETE FROM concept_minute_alert WHERE trade_date = :date"), {'date': trade_date}) conn.execute(text("DELETE FROM index_minute_snapshot WHERE trade_date = :date"), {'date': trade_date}) logger.info(f"已清除 {trade_date} 的已有数据") # 加载概念数据 logger.info("加载概念数据...") leaf_concepts = get_all_concepts() logger.info(f"获取到 {len(leaf_concepts)} 个叶子概念") parent_concepts = load_hierarchy_concepts(leaf_concepts) logger.info(f"生成了 {len(parent_concepts)} 个母概念") all_concepts = leaf_concepts + parent_concepts logger.info(f"总计 {len(all_concepts)} 个概念") # 收集所有股票代码 all_stocks = set() for c in all_concepts: all_stocks.update(c['stocks']) all_stocks = list(all_stocks) logger.info(f"监控 {len(all_stocks)} 只股票") # 获取基准价格(昨收价) base_prices = get_base_prices(all_stocks, trade_date) if not base_prices: logger.error("无法获取基准价格,退出回测") return logger.info(f"获取到 {len(base_prices)} 个基准价格") # 获取指数昨收价 index_prev_close = get_index_prev_close(REFERENCE_INDEX, trade_date) logger.info(f"指数昨收价: {index_prev_close}") # 获取所有分钟时间戳 timestamps = get_minute_timestamps(trade_date) if not timestamps: logger.error(f"未找到 {trade_date} 的分钟数据") return logger.info(f"找到 {len(timestamps)} 个分钟时间点") total_alerts = 0 processed = 0 # 逐分钟处理 for ts in timestamps: processed += 1 # 获取该时间点的价格 latest_prices = get_prices_at_time(all_stocks, ts) if not latest_prices: continue # 获取指数数据 index_data = get_index_at_time(REFERENCE_INDEX, ts, index_prev_close) if index_data: save_index_snapshot(index_data, trade_date) # 计算涨跌幅 stock_changes = calculate_change_pct(base_prices, latest_prices) if not stock_changes: continue # 计算概念统计 stats = calculate_concept_stats(all_concepts, stock_changes) # 智能异动检测 alerts = detect_smart_alerts(stats, index_data, trade_date, ts) # 保存异动 if alerts: saved = save_alerts_to_mysql(alerts) total_alerts += saved # 进度显示 if processed % 30 == 0: logger.info(f"进度: {processed}/{len(timestamps)} ({processed*100//len(timestamps)}%), 已检测到 {total_alerts} 条异动") logger.info("=" * 60) logger.info(f"✅ 回测完成!") logger.info(f" 处理分钟数: {processed}") logger.info(f" 检测到异动: {total_alerts} 条") logger.info("=" * 60) def show_status(): """显示状态""" print("\n" + "=" * 60) print("智能概念异动检测服务 (Z-Score + SVM) - 状态") print("=" * 60) now = datetime.now() print(f"\n当前时间: {now.strftime('%Y-%m-%d %H:%M:%S')}") print(f"是否交易时间: {'是' if is_trading_time() else '否'}") print(f"sklearn可用: {'是' if SKLEARN_AVAILABLE else '否'}") # 模型状态 model_path = os.path.join(MODEL_DIR, 'svm_model.pkl') print(f"SVM模型: {'已加载' if os.path.exists(model_path) else '未训练'}") # 今日异动统计 print("\n今日异动统计:") try: with MYSQL_ENGINE.connect() as conn: result = conn.execute(text(""" SELECT alert_type, COUNT(*) as cnt, AVG(change_delta) as avg_delta FROM concept_minute_alert WHERE trade_date = CURDATE() GROUP BY alert_type """)) rows = list(result) if rows: for row in rows: alert_type_name = { 'surge_up': '暴涨', 'surge_down': '暴跌', 'surge': '急涨(旧)', 'limit_up': '涨停增加', 'rank_jump': '排名跃升' }.get(row[0], row[0]) avg_delta = f"{row[2]:.2f}%" if row[2] else "-" print(f" {alert_type_name}: {row[1]} 条 (平均变化: {avg_delta})") else: print(" 今日暂无异动") # 最新异动 print("\n最新异动 (前10条):") result = conn.execute(text(""" SELECT concept_name, alert_type, alert_time, change_pct, change_delta, extra_info FROM concept_minute_alert WHERE trade_date = CURDATE() ORDER BY alert_time DESC LIMIT 10 """)) rows = list(result) if rows: print(f" {'概念':<20} | {'类型':<6} | {'时间':<8} | {'涨幅':>6} | {'变化':>6} | {'Z分':>5}") print(" " + "-" * 70) for row in rows: name = row[0][:18] if len(row[0]) > 18 else row[0] alert_type = {'surge_up': '暴涨', 'surge_down': '暴跌'}.get(row[1], row[1][:4]) time_str = row[2].strftime('%H:%M') if row[2] else '-' change = f"{row[3]:.2f}%" if row[3] else '-' delta = f"{row[4]:+.2f}%" if row[4] else '-' # 解析extra_info获取zscore zscore = '-' if row[5]: try: extra = json.loads(row[5]) if isinstance(row[5], str) else row[5] zscore = f"{extra.get('zscore', 0):.1f}" except: pass print(f" {name:<20} | {alert_type:<6} | {time_str:<8} | {change:>6} | {delta:>6} | {zscore:>5}") else: print(" 暂无异动记录") except Exception as e: print(f" 查询失败: {e}") def train_model(): """训练SVM模型""" if not SKLEARN_AVAILABLE: print("错误: sklearn未安装,无法训练模型") print("安装命令: pip install scikit-learn") return logger.info("=" * 60) logger.info("🎓 开始训练SVM模型") logger.info("=" * 60) # 加载概念数据 leaf_concepts = get_all_concepts() parent_concepts = load_hierarchy_concepts(leaf_concepts) all_concepts = leaf_concepts + parent_concepts all_stocks = set() for c in all_concepts: all_stocks.update(c['stocks']) all_stocks = list(all_stocks) logger.info(f"概念数: {len(all_concepts)}, 股票数: {len(all_stocks)}") # 收集训练数据(使用历史异动数据) training_features = [] # 从最近N天的数据中提取特征 lookback_days = 30 try: with MYSQL_ENGINE.connect() as conn: result = conn.execute(text(""" SELECT change_pct, change_delta, rank_position, limit_up_count, stock_count, index_change_pct, extra_info FROM concept_minute_alert WHERE trade_date >= DATE_SUB(CURDATE(), INTERVAL :days DAY) """), {'days': lookback_days}) for row in result: features = [ float(row[0]) if row[0] else 0, # change_pct float(row[1]) if row[1] else 0, # change_delta 0, # change_delta_10min (not available) -float(row[2]) if row[2] else 0, # rank_delta (approximation) float(row[3]) / max(1, float(row[4])) * 100 if row[3] and row[4] else 0, # limit_up_ratio 0, # volume_ratio float(row[0]) - float(row[5]) if row[0] and row[5] else 0, # index_correlation ] training_features.append(features) logger.info(f"收集到 {len(training_features)} 条训练数据") if len(training_features) >= 100: success = train_svm_model(training_features) if success: logger.info("✅ 模型训练成功!") else: logger.error("❌ 模型训练失败") else: logger.warning("训练数据不足100条,跳过训练") except Exception as e: logger.error(f"训练失败: {e}") import traceback traceback.print_exc() # ==================== 主函数 ==================== def main(): parser = argparse.ArgumentParser(description='智能概念异动检测服务 (Z-Score + SVM)') parser.add_argument('command', nargs='?', default='realtime', choices=['realtime', 'once', 'status', 'backtest', 'train'], help='命令: realtime(实时运行), once(单次运行), status(状态), backtest(回测), train(训练模型)') parser.add_argument('--date', '-d', type=str, default=None, help='回测日期,格式: YYYY-MM-DD,默认为今天') parser.add_argument('--keep', '-k', action='store_true', help='回测时保留已有数据(默认会清除)') args = parser.parse_args() if args.command == 'realtime': run_realtime() elif args.command == 'once': run_single() elif args.command == 'status': show_status() elif args.command == 'backtest': trade_date = args.date or datetime.now().strftime('%Y-%m-%d') clear_existing = not args.keep run_backtest(trade_date, clear_existing) elif args.command == 'train': train_model() if __name__ == "__main__": main()