update pay ui

This commit is contained in:
2025-12-09 08:31:18 +08:00
parent e4937c2719
commit 25492caf15
26 changed files with 15577 additions and 1061 deletions

273
app.py
View File

@@ -12458,6 +12458,279 @@ def get_daily_top_concepts():
}), 500 }), 500
# ==================== 热点概览 API ====================
@app.route('/api/market/hotspot-overview', methods=['GET'])
def get_hotspot_overview():
"""
获取热点概览数据(用于个股中心的热点概览图表)
返回:指数分时数据 + 概念异动标注
"""
try:
trade_date = request.args.get('date')
index_code = request.args.get('index', '000001.SH')
# 如果没有指定日期,使用最新交易日
if not trade_date:
today = date.today()
if today in trading_days_set:
trade_date = today.strftime('%Y-%m-%d')
else:
target_date = get_trading_day_near_date(today)
trade_date = target_date.strftime('%Y-%m-%d') if target_date else today.strftime('%Y-%m-%d')
# 1. 获取指数分时数据
client = get_clickhouse_client()
target_date_obj = datetime.strptime(trade_date, '%Y-%m-%d').date()
index_data = client.execute(
"""
SELECT timestamp, open, high, low, close, volume
FROM index_minute
WHERE code = %(code)s
AND toDate(timestamp) = %(date)s
ORDER BY timestamp
""",
{
'code': index_code,
'date': target_date_obj
}
)
# 获取昨收价
code_no_suffix = index_code.split('.')[0]
prev_close = None
with engine.connect() as conn:
prev_result = conn.execute(text("""
SELECT F006N FROM ea_exchangetrade
WHERE INDEXCODE = :code
AND TRADEDATE < :today
ORDER BY TRADEDATE DESC LIMIT 1
"""), {
'code': code_no_suffix,
'today': target_date_obj
}).fetchone()
if prev_result and prev_result[0]:
prev_close = float(prev_result[0])
# 格式化指数数据
index_timeline = []
for row in index_data:
ts, open_p, high_p, low_p, close_p, vol = row
change_pct = None
if prev_close and close_p:
change_pct = round((float(close_p) - prev_close) / prev_close * 100, 4)
index_timeline.append({
'time': ts.strftime('%H:%M'),
'timestamp': ts.isoformat(),
'price': float(close_p) if close_p else None,
'open': float(open_p) if open_p else None,
'high': float(high_p) if high_p else None,
'low': float(low_p) if low_p else None,
'volume': int(vol) if vol else 0,
'change_pct': change_pct
})
# 2. 获取概念异动数据
alerts = []
with engine.connect() as conn:
alert_result = conn.execute(text("""
SELECT
concept_id, concept_name, alert_time, alert_type,
change_pct, change_delta, limit_up_count, limit_up_delta,
rank_position, index_price, index_change_pct,
stock_count, concept_type, extra_info,
prev_change_pct, zscore, importance_score
FROM concept_minute_alert
WHERE trade_date = :trade_date
ORDER BY alert_time
"""), {'trade_date': trade_date})
for row in alert_result:
alert_time = row[2]
extra_info = None
if row[13]:
try:
extra_info = json.loads(row[13]) if isinstance(row[13], str) else row[13]
except:
pass
# 从 extra_info 提取 zscore 和 importance_score兼容旧数据
zscore = None
importance_score = None
if len(row) > 15:
zscore = float(row[15]) if row[15] else None
importance_score = float(row[16]) if row[16] else None
if extra_info:
zscore = zscore or extra_info.get('zscore')
importance_score = importance_score or extra_info.get('importance_score')
alerts.append({
'concept_id': row[0],
'concept_name': row[1],
'time': alert_time.strftime('%H:%M') if alert_time else None,
'timestamp': alert_time.isoformat() if alert_time else None,
'alert_type': row[3],
'change_pct': float(row[4]) if row[4] else None,
'change_delta': float(row[5]) if row[5] else None,
'limit_up_count': row[6],
'limit_up_delta': row[7],
'rank_position': row[8],
'index_price': float(row[9]) if row[9] else None,
'index_change_pct': float(row[10]) if row[10] else None,
'stock_count': row[11],
'concept_type': row[12],
'extra_info': extra_info,
'prev_change_pct': float(row[14]) if len(row) > 14 and row[14] else None,
'zscore': zscore,
'importance_score': importance_score
})
# 计算统计信息
day_high = max([d['price'] for d in index_timeline if d['price']], default=None)
day_low = min([d['price'] for d in index_timeline if d['price']], default=None)
latest_price = index_timeline[-1]['price'] if index_timeline else None
latest_change_pct = index_timeline[-1]['change_pct'] if index_timeline else None
return jsonify({
'success': True,
'data': {
'trade_date': trade_date,
'index': {
'code': index_code,
'name': '上证指数' if index_code == '000001.SH' else index_code,
'prev_close': prev_close,
'latest_price': latest_price,
'change_pct': latest_change_pct,
'high': day_high,
'low': day_low,
'timeline': index_timeline
},
'alerts': alerts,
'alert_count': len(alerts),
'alert_summary': {
'surge': len([a for a in alerts if a['alert_type'] == 'surge']),
'surge_up': len([a for a in alerts if a['alert_type'] == 'surge_up']),
'surge_down': len([a for a in alerts if a['alert_type'] == 'surge_down']),
'limit_up': len([a for a in alerts if a['alert_type'] == 'limit_up']),
'rank_jump': len([a for a in alerts if a['alert_type'] == 'rank_jump'])
}
}
})
except Exception as e:
import traceback
logger.error(f"获取热点概览数据失败: {traceback.format_exc()}")
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/market/concept-alerts', methods=['GET'])
def get_concept_alerts():
"""
获取概念异动列表(支持分页和筛选)
"""
try:
trade_date = request.args.get('date')
alert_type = request.args.get('type') # surge/limit_up/rank_jump
concept_type = request.args.get('concept_type') # leaf/lv1/lv2/lv3
limit = request.args.get('limit', 50, type=int)
offset = request.args.get('offset', 0, type=int)
# 构建查询条件
conditions = []
params = {'limit': limit, 'offset': offset}
if trade_date:
conditions.append("trade_date = :trade_date")
params['trade_date'] = trade_date
else:
conditions.append("trade_date = CURDATE()")
if alert_type:
conditions.append("alert_type = :alert_type")
params['alert_type'] = alert_type
if concept_type:
conditions.append("concept_type = :concept_type")
params['concept_type'] = concept_type
where_clause = " AND ".join(conditions) if conditions else "1=1"
with engine.connect() as conn:
# 获取总数
count_sql = text(f"SELECT COUNT(*) FROM concept_minute_alert WHERE {where_clause}")
total = conn.execute(count_sql, params).scalar()
# 获取数据
query_sql = text(f"""
SELECT
id, concept_id, concept_name, alert_time, alert_type, trade_date,
change_pct, prev_change_pct, change_delta,
limit_up_count, prev_limit_up_count, limit_up_delta,
rank_position, prev_rank_position, rank_delta,
index_price, index_change_pct,
stock_count, concept_type, extra_info
FROM concept_minute_alert
WHERE {where_clause}
ORDER BY alert_time DESC
LIMIT :limit OFFSET :offset
""")
result = conn.execute(query_sql, params)
alerts = []
for row in result:
extra_info = None
if row[19]:
try:
extra_info = json.loads(row[19]) if isinstance(row[19], str) else row[19]
except:
pass
alerts.append({
'id': row[0],
'concept_id': row[1],
'concept_name': row[2],
'alert_time': row[3].isoformat() if row[3] else None,
'alert_type': row[4],
'trade_date': row[5].isoformat() if row[5] else None,
'change_pct': float(row[6]) if row[6] else None,
'prev_change_pct': float(row[7]) if row[7] else None,
'change_delta': float(row[8]) if row[8] else None,
'limit_up_count': row[9],
'prev_limit_up_count': row[10],
'limit_up_delta': row[11],
'rank_position': row[12],
'prev_rank_position': row[13],
'rank_delta': row[14],
'index_price': float(row[15]) if row[15] else None,
'index_change_pct': float(row[16]) if row[16] else None,
'stock_count': row[17],
'concept_type': row[18],
'extra_info': extra_info
})
return jsonify({
'success': True,
'data': alerts,
'total': total,
'limit': limit,
'offset': offset
})
except Exception as e:
import traceback
logger.error(f"获取概念异动列表失败: {traceback.format_exc()}")
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/market/rise-analysis/<seccode>', methods=['GET']) @app.route('/api/market/rise-analysis/<seccode>', methods=['GET'])
def get_rise_analysis(seccode): def get_rise_analysis(seccode):
"""获取股票涨幅分析数据(从 Elasticsearch 获取)""" """获取股票涨幅分析数据(从 Elasticsearch 获取)"""

2823
concept_alert_20251208.log Normal file

File diff suppressed because it is too large Load Diff

1078
concept_alert_alpha.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,28 @@
2025-12-08 16:40:41,567 - INFO - ============================================================
2025-12-08 16:40:41,567 - INFO - 🔄 回测: 2025-12-08 (Alpha Z-Score 方法)
2025-12-08 16:40:41,569 - INFO - ============================================================
2025-12-08 16:40:41,679 - INFO - 已清除 2025-12-08 的数据
2025-12-08 16:40:41,903 - INFO - POST http://222.128.1.157:19200/concept_library_v3/_search?scroll=2m [status:200 duration:0.224s]
2025-12-08 16:40:42,105 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.197s]
2025-12-08 16:40:42,330 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.178s]
2025-12-08 16:40:42,518 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.183s]
2025-12-08 16:40:42,704 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.182s]
2025-12-08 16:40:42,894 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.186s]
2025-12-08 16:40:43,060 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.162s]
2025-12-08 16:40:43,234 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.171s]
2025-12-08 16:40:43,383 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.145s]
2025-12-08 16:40:43,394 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.008s]
2025-12-08 16:40:43,399 - INFO - DELETE http://222.128.1.157:19200/_search/scroll [status:200 duration:0.005s]
2025-12-08 16:40:43,409 - INFO - 概念: 968, 股票: 5938
2025-12-08 16:40:43,505 - INFO - 时间点: 241
2025-12-08 16:41:02,028 - INFO - 进度: 30/241 (12%), 异动: 0
2025-12-08 16:41:20,851 - INFO - 进度: 60/241 (24%), 异动: 0
2025-12-08 16:41:39,396 - INFO - 进度: 90/241 (37%), 异动: 0
2025-12-08 16:41:58,687 - INFO - 进度: 120/241 (49%), 异动: 0
2025-12-08 16:43:08,124 - INFO - 进度: 150/241 (62%), 异动: 0
2025-12-08 16:43:26,973 - INFO - 进度: 180/241 (74%), 异动: 0
2025-12-08 16:43:45,746 - INFO - 进度: 210/241 (87%), 异动: 0
2025-12-08 16:44:04,479 - INFO - 进度: 240/241 (99%), 异动: 0
2025-12-08 16:44:05,123 - INFO - ============================================================
2025-12-08 16:44:05,123 - INFO - ✅ 回测完成! 检测到 0 条异动
2025-12-08 16:44:05,125 - INFO - ============================================================

1625
concept_alert_ml.py Normal file

File diff suppressed because it is too large Load Diff

1366
concept_alert_realtime.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

681
concept_quota_realtime.py Normal file
View File

@@ -0,0 +1,681 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
概念涨跌幅实时更新服务
- 在交易时间段每分钟从ClickHouse获取最新分钟数据
- 计算涨跌幅后更新MySQL的concept_daily_stats表
- 支持叶子概念和母概念lv1/lv2/lv3
"""
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from sqlalchemy import create_engine, text
from elasticsearch import Elasticsearch
from clickhouse_driver import Client
import time
import logging
import json
import os
import hashlib
import argparse
# ==================== 配置 ====================
# MySQL配置
MYSQL_ENGINE = create_engine(
"mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock",
echo=False
)
# Elasticsearch配置
ES_CLIENT = Elasticsearch(['http://222.128.1.157:19200'])
INDEX_NAME = 'concept_library_v3'
# ClickHouse配置
CLICKHOUSE_CONFIG = {
'host': '222.128.1.157',
'port': 18000,
'user': 'default',
'password': 'Zzl33818!',
'database': 'stock'
}
# 层级结构文件
HIERARCHY_FILE = 'concept_hierarchy_v3.json'
# 交易时间配置
TRADING_HOURS = {
'morning_start': (9, 30),
'morning_end': (11, 30),
'afternoon_start': (13, 0),
'afternoon_end': (15, 0),
}
# ==================== 日志配置 ====================
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(f'concept_realtime_{datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
# ClickHouse客户端
ch_client = None
def get_ch_client():
"""获取ClickHouse客户端"""
global ch_client
if ch_client is None:
ch_client = Client(**CLICKHOUSE_CONFIG)
return ch_client
def generate_id(name: str) -> str:
"""生成概念ID"""
return hashlib.md5(name.encode('utf-8')).hexdigest()[:16]
def code_to_ch_format(code: str) -> str:
"""将6位股票代码转换为ClickHouse格式带后缀
规则:
- 6开头 -> .SH上海
- 0或3开头 -> .SZ深圳
- 其他 -> .BJ北京
- 非6位数字的忽略可能是港股
"""
if not code or len(code) != 6 or not code.isdigit():
return None
if code.startswith('6'):
return f"{code}.SH"
elif code.startswith('0') or code.startswith('3'):
return f"{code}.SZ"
else:
return f"{code}.BJ"
def ch_code_to_pure(ch_code: str) -> str:
"""将ClickHouse格式的股票代码转回纯6位代码"""
if not ch_code:
return None
return ch_code.split('.')[0]
# ==================== 概念数据获取 ====================
def get_all_concepts():
"""从ES获取所有叶子概念及其股票列表"""
concepts = []
query = {
"query": {"match_all": {}},
"size": 100,
"_source": ["concept_id", "concept", "stocks"]
}
resp = ES_CLIENT.search(index=INDEX_NAME, body=query, scroll='2m')
scroll_id = resp['_scroll_id']
hits = resp['hits']['hits']
while len(hits) > 0:
for hit in hits:
source = hit['_source']
concept_info = {
'concept_id': source.get('concept_id'),
'concept_name': source.get('concept'),
'stocks': [],
'concept_type': 'leaf'
}
# v3索引的stocks字段是 [{name, code}, ...]
if 'stocks' in source and isinstance(source['stocks'], list):
for stock in source['stocks']:
if isinstance(stock, dict) and 'code' in stock and stock['code']:
concept_info['stocks'].append(stock['code'])
if concept_info['stocks']:
concepts.append(concept_info)
resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
scroll_id = resp['_scroll_id']
hits = resp['hits']['hits']
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
return concepts
def load_hierarchy_concepts(leaf_concepts: list) -> list:
"""加载层级结构生成母概念lv1/lv2/lv3"""
hierarchy_path = os.path.join(os.path.dirname(__file__), HIERARCHY_FILE)
if not os.path.exists(hierarchy_path):
logger.warning(f"层级文件不存在: {hierarchy_path}")
return []
with open(hierarchy_path, 'r', encoding='utf-8') as f:
hierarchy_data = json.load(f)
# 建立概念名称到股票的映射
concept_to_stocks = {}
for c in leaf_concepts:
concept_to_stocks[c['concept_name']] = set(c['stocks'])
parent_concepts = []
for lv1 in hierarchy_data.get('hierarchy', []):
lv1_name = lv1.get('lv1', '')
lv1_stocks = set()
for child in lv1.get('children', []):
lv2_name = child.get('lv2', '')
lv2_stocks = set()
if 'children' in child:
for lv3_child in child.get('children', []):
lv3_name = lv3_child.get('lv3', '')
lv3_stocks = set()
for concept_name in lv3_child.get('concepts', []):
if concept_name in concept_to_stocks:
lv3_stocks.update(concept_to_stocks[concept_name])
if lv3_stocks:
parent_concepts.append({
'concept_id': generate_id(f"lv3_{lv3_name}"),
'concept_name': f"[三级] {lv3_name}",
'stocks': list(lv3_stocks),
'concept_type': 'lv3'
})
lv2_stocks.update(lv3_stocks)
else:
for concept_name in child.get('concepts', []):
if concept_name in concept_to_stocks:
lv2_stocks.update(concept_to_stocks[concept_name])
if lv2_stocks:
parent_concepts.append({
'concept_id': generate_id(f"lv2_{lv2_name}"),
'concept_name': f"[二级] {lv2_name}",
'stocks': list(lv2_stocks),
'concept_type': 'lv2'
})
lv1_stocks.update(lv2_stocks)
if lv1_stocks:
parent_concepts.append({
'concept_id': generate_id(f"lv1_{lv1_name}"),
'concept_name': f"[一级] {lv1_name}",
'stocks': list(lv1_stocks),
'concept_type': 'lv1'
})
return parent_concepts
# ==================== 基准价格获取 ====================
def get_base_prices(stock_codes: list, current_date: str) -> dict:
"""获取当日的昨收价作为基准从ea_trade的F002N字段
ea_trade表字段说明
- F002N: 昨日收盘价
- F007N: 最近成交价(收盘价)
- F010N: 涨跌幅
"""
if not stock_codes:
return {}
# 过滤出有效的6位股票代码
valid_codes = [code for code in stock_codes if code and len(code) == 6 and code.isdigit()]
if not valid_codes:
return {}
stock_codes_str = "','".join(valid_codes)
# 获取当日数据中的昨收价F002N
query = f"""
SELECT SECCODE, F002N
FROM ea_trade
WHERE SECCODE IN ('{stock_codes_str}')
AND TRADEDATE = (
SELECT MAX(TRADEDATE)
FROM ea_trade
WHERE TRADEDATE <= '{current_date}'
)
AND F002N IS NOT NULL AND F002N > 0
"""
try:
with MYSQL_ENGINE.connect() as conn:
result = conn.execute(text(query))
base_prices = {row[0]: float(row[1]) for row in result if row[1] and float(row[1]) > 0}
logger.info(f"获取到 {len(base_prices)} 个基准价格")
return base_prices
except Exception as e:
logger.error(f"获取基准价格失败: {e}")
return {}
# ==================== 实时价格获取 ====================
def get_latest_prices(stock_codes: list) -> dict:
"""从ClickHouse获取最新分钟数据的收盘价
Args:
stock_codes: 纯6位股票代码列表如 ['000001', '600000']
Returns:
dict: {纯6位代码: {'close': 价格, 'timestamp': 时间}}
"""
if not stock_codes:
return {}
client = get_ch_client()
# 转换为ClickHouse格式的代码带后缀
ch_codes = []
code_mapping = {} # ch_code -> pure_code
for code in stock_codes:
ch_code = code_to_ch_format(code)
if ch_code:
ch_codes.append(ch_code)
code_mapping[ch_code] = code
if not ch_codes:
logger.warning("没有有效的股票代码可查询")
return {}
ch_codes_str = "','".join(ch_codes)
# 获取今日最新的分钟数据
query = f"""
SELECT code, close, timestamp
FROM (
SELECT code, close, timestamp,
ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn
FROM stock_minute
WHERE code IN ('{ch_codes_str}')
AND toDate(timestamp) = today()
)
WHERE rn = 1
"""
try:
result = client.execute(query)
if not result:
return {}
latest_prices = {}
for row in result:
ch_code, close, ts = row
if close and close > 0:
# 转回纯6位代码
pure_code = code_mapping.get(ch_code)
if pure_code:
latest_prices[pure_code] = {
'close': float(close),
'timestamp': ts
}
return latest_prices
except Exception as e:
logger.error(f"获取最新价格失败: {e}")
return {}
# ==================== 涨跌幅计算 ====================
def calculate_change_pct(base_prices: dict, latest_prices: dict) -> dict:
"""计算涨跌幅"""
changes = {}
for code, latest in latest_prices.items():
if code in base_prices and base_prices[code] > 0:
base = base_prices[code]
close = latest['close']
change_pct = (close - base) / base * 100
changes[code] = round(change_pct, 4)
return changes
def calculate_concept_stats(concepts: list, stock_changes: dict, trade_date: str) -> list:
"""计算所有概念的涨跌幅统计"""
stats = []
for concept in concepts:
concept_id = concept['concept_id']
concept_name = concept['concept_name']
stock_codes = concept['stocks']
concept_type = concept.get('concept_type', 'leaf')
# 获取该概念股票的涨跌幅
changes = [stock_changes[code] for code in stock_codes if code in stock_changes]
if not changes:
continue
avg_change_pct = round(np.mean(changes), 4)
stock_count = len(changes)
stats.append({
'concept_id': concept_id,
'concept_name': concept_name,
'trade_date': trade_date,
'avg_change_pct': avg_change_pct,
'stock_count': stock_count,
'concept_type': concept_type
})
return stats
# ==================== MySQL更新 ====================
def update_mysql_stats(stats: list):
"""更新MySQL的concept_daily_stats表"""
if not stats:
return 0
with MYSQL_ENGINE.begin() as conn:
updated = 0
for item in stats:
upsert_sql = text("""
REPLACE INTO concept_daily_stats
(concept_id, concept_name, trade_date, avg_change_pct, stock_count, concept_type)
VALUES (:concept_id, :concept_name, :trade_date, :avg_change_pct, :stock_count, :concept_type)
""")
conn.execute(upsert_sql, item)
updated += 1
return updated
# ==================== 交易时间判断 ====================
def is_trading_time() -> bool:
"""判断当前是否为交易时间"""
now = datetime.now()
weekday = now.weekday()
# 周末不交易
if weekday >= 5:
return False
hour, minute = now.hour, now.minute
current_time = hour * 60 + minute
# 上午 9:30 - 11:30
morning_start = 9 * 60 + 30
morning_end = 11 * 60 + 30
# 下午 13:00 - 15:00
afternoon_start = 13 * 60
afternoon_end = 15 * 60
return (morning_start <= current_time <= morning_end) or \
(afternoon_start <= current_time <= afternoon_end)
def get_next_update_time() -> int:
"""获取距离下次更新的秒数"""
now = datetime.now()
if is_trading_time():
# 交易时间内,等到下一分钟
return 60 - now.second
else:
# 非交易时间
hour, minute = now.hour, now.minute
# 计算距离下次交易开始的时间
if hour < 9 or (hour == 9 and minute < 30):
# 等到9:30
target = now.replace(hour=9, minute=30, second=0, microsecond=0)
elif (hour == 11 and minute >= 30) or hour == 12:
# 等到13:00
target = now.replace(hour=13, minute=0, second=0, microsecond=0)
elif hour >= 15:
# 等到明天9:30
target = (now + timedelta(days=1)).replace(hour=9, minute=30, second=0, microsecond=0)
else:
target = now + timedelta(minutes=1)
wait_seconds = (target - now).total_seconds()
return max(60, int(wait_seconds))
# ==================== 主运行逻辑 ====================
def run_once(concepts: list, all_stocks: list) -> int:
"""执行一次更新"""
now = datetime.now()
trade_date = now.strftime('%Y-%m-%d')
# 获取基准价格(昨日收盘价)
base_prices = get_base_prices(all_stocks, trade_date)
if not base_prices:
logger.warning("无法获取基准价格")
return 0
# 获取最新价格
latest_prices = get_latest_prices(all_stocks)
if not latest_prices:
logger.warning("无法获取最新价格")
return 0
# 计算涨跌幅
stock_changes = calculate_change_pct(base_prices, latest_prices)
if not stock_changes:
logger.warning("无涨跌幅数据")
return 0
logger.info(f"获取到 {len(stock_changes)} 只股票的涨跌幅")
# 计算概念统计
stats = calculate_concept_stats(concepts, stock_changes, trade_date)
logger.info(f"计算了 {len(stats)} 个概念的涨跌幅")
# 更新MySQL
updated = update_mysql_stats(stats)
logger.info(f"更新了 {updated} 条记录到MySQL")
return updated
def run_realtime():
"""实时更新主循环"""
logger.info("=" * 60)
logger.info("启动概念涨跌幅实时更新服务")
logger.info("=" * 60)
# 加载概念数据
logger.info("加载概念数据...")
leaf_concepts = get_all_concepts()
logger.info(f"获取到 {len(leaf_concepts)} 个叶子概念")
parent_concepts = load_hierarchy_concepts(leaf_concepts)
logger.info(f"生成了 {len(parent_concepts)} 个母概念")
all_concepts = leaf_concepts + parent_concepts
logger.info(f"总计 {len(all_concepts)} 个概念")
# 收集所有股票代码
all_stocks = set()
for c in all_concepts:
all_stocks.update(c['stocks'])
all_stocks = list(all_stocks)
logger.info(f"监控 {len(all_stocks)} 只股票")
last_concept_update = datetime.now()
while True:
try:
now = datetime.now()
# 每小时重新加载概念数据
if (now - last_concept_update).total_seconds() > 3600:
logger.info("重新加载概念数据...")
leaf_concepts = get_all_concepts()
parent_concepts = load_hierarchy_concepts(leaf_concepts)
all_concepts = leaf_concepts + parent_concepts
all_stocks = set()
for c in all_concepts:
all_stocks.update(c['stocks'])
all_stocks = list(all_stocks)
last_concept_update = now
logger.info(f"更新完成: {len(all_concepts)} 个概念, {len(all_stocks)} 只股票")
# 检查是否交易时间
if not is_trading_time():
wait_sec = get_next_update_time()
wait_min = wait_sec // 60
logger.info(f"非交易时间,等待 {wait_min} 分钟后重试...")
time.sleep(min(wait_sec, 300)) # 最多等5分钟再检查
continue
# 执行更新
logger.info(f"\n{'=' * 40}")
logger.info(f"更新时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
updated = run_once(all_concepts, all_stocks)
# 等待下一分钟
sleep_sec = 60 - datetime.now().second
logger.info(f"完成,等待 {sleep_sec} 秒后继续...")
time.sleep(sleep_sec)
except KeyboardInterrupt:
logger.info("\n收到退出信号,停止服务...")
break
except Exception as e:
logger.error(f"发生错误: {e}")
import traceback
traceback.print_exc()
time.sleep(60)
def run_single():
"""单次运行(不循环)"""
logger.info("单次更新模式")
leaf_concepts = get_all_concepts()
parent_concepts = load_hierarchy_concepts(leaf_concepts)
all_concepts = leaf_concepts + parent_concepts
all_stocks = set()
for c in all_concepts:
all_stocks.update(c['stocks'])
all_stocks = list(all_stocks)
logger.info(f"概念数: {len(all_concepts)}, 股票数: {len(all_stocks)}")
updated = run_once(all_concepts, all_stocks)
logger.info(f"更新完成: {updated} 条记录")
def show_status():
"""显示当前状态"""
print("\n" + "=" * 60)
print("概念涨跌幅实时更新服务 - 状态")
print("=" * 60)
# 当前时间
now = datetime.now()
print(f"\n当前时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"是否交易时间: {'' if is_trading_time() else ''}")
# MySQL数据状态
print("\nMySQL数据状态:")
try:
with MYSQL_ENGINE.connect() as conn:
# 今日数据量
result = conn.execute(text("""
SELECT concept_type, COUNT(*) as cnt
FROM concept_daily_stats
WHERE trade_date = CURDATE()
GROUP BY concept_type
"""))
rows = list(result)
if rows:
print(" 今日数据:")
for row in rows:
print(f" {row[0]}: {row[1]}")
else:
print(" 今日暂无数据")
# 最新更新时间
result = conn.execute(text("""
SELECT MAX(updated_at) FROM concept_daily_stats WHERE trade_date = CURDATE()
"""))
row = result.fetchone()
if row and row[0]:
print(f" 最后更新: {row[0]}")
except Exception as e:
print(f" 查询失败: {e}")
# ClickHouse数据状态
print("\nClickHouse数据状态:")
try:
client = get_ch_client()
result = client.execute("""
SELECT COUNT(*), MAX(timestamp)
FROM stock_minute
WHERE toDate(timestamp) = today()
""")
if result:
count, max_ts = result[0]
print(f" 今日分钟数据: {count:,}")
print(f" 最新时间戳: {max_ts}")
except Exception as e:
print(f" 查询失败: {e}")
# 今日涨跌幅TOP10
print("\n今日涨跌幅 TOP10:")
try:
with MYSQL_ENGINE.connect() as conn:
result = conn.execute(text("""
SELECT concept_name, avg_change_pct, stock_count, concept_type
FROM concept_daily_stats
WHERE trade_date = CURDATE() AND concept_type = 'leaf'
ORDER BY avg_change_pct DESC
LIMIT 10
"""))
rows = list(result)
if rows:
print(f" {'概念':<25} | {'涨跌幅':>8} | {'股票数':>6}")
print(" " + "-" * 50)
for row in rows:
name = row[0][:25] if len(row[0]) > 25 else row[0]
print(f" {name:<25} | {row[1]:>7.2f}% | {row[2]:>6}")
else:
print(" 暂无数据")
except Exception as e:
print(f" 查询失败: {e}")
# ==================== 主函数 ====================
def main():
parser = argparse.ArgumentParser(description='概念涨跌幅实时更新服务')
parser.add_argument('command', nargs='?', default='realtime',
choices=['realtime', 'once', 'status'],
help='命令: realtime(实时运行), once(单次运行), status(状态查看)')
args = parser.parse_args()
if args.command == 'realtime':
run_realtime()
elif args.command == 'once':
run_single()
elif args.command == 'status':
show_status()
if __name__ == "__main__":
main()

89
create_tables.py Normal file
View File

@@ -0,0 +1,89 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""创建异动检测所需的数据库表"""
import sys
from sqlalchemy import create_engine, text
engine = create_engine('mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock', echo=False)
# 删除旧表
drop_sql1 = 'DROP TABLE IF EXISTS concept_minute_alert'
drop_sql2 = 'DROP TABLE IF EXISTS index_minute_snapshot'
# 创建 concept_minute_alert 表
# 支持 Z-Score + SVM 智能检测
sql1 = '''
CREATE TABLE concept_minute_alert (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
concept_id VARCHAR(32) NOT NULL,
concept_name VARCHAR(100) NOT NULL,
alert_time DATETIME NOT NULL,
alert_type VARCHAR(20) NOT NULL COMMENT 'surge_up=暴涨, surge_down=暴跌, limit_up=涨停增加, rank_jump=排名跃升',
trade_date DATE NOT NULL,
change_pct DECIMAL(10,4) COMMENT '当前涨跌幅',
prev_change_pct DECIMAL(10,4) COMMENT '之前涨跌幅',
change_delta DECIMAL(10,4) COMMENT '涨跌幅变化',
limit_up_count INT DEFAULT 0 COMMENT '涨停数',
prev_limit_up_count INT DEFAULT 0,
limit_up_delta INT DEFAULT 0,
limit_down_count INT DEFAULT 0 COMMENT '跌停数',
rank_position INT COMMENT '当前排名',
prev_rank_position INT COMMENT '之前排名',
rank_delta INT COMMENT '排名变化(负数表示上升)',
index_code VARCHAR(20) DEFAULT '000001.SH',
index_price DECIMAL(12,4),
index_change_pct DECIMAL(10,4),
stock_count INT,
concept_type VARCHAR(20) DEFAULT 'leaf',
zscore DECIMAL(8,4) COMMENT 'Z-Score值',
importance_score DECIMAL(6,4) COMMENT '重要性评分(0-1)',
extra_info JSON COMMENT '扩展信息(包含zscore,svm_score等)',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_trade_date (trade_date),
INDEX idx_alert_time (alert_time),
INDEX idx_concept_id (concept_id),
INDEX idx_alert_type (alert_type),
INDEX idx_trade_date_time (trade_date, alert_time),
INDEX idx_importance (importance_score)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动记录表(智能版)'
'''
# 创建 index_minute_snapshot 表
sql2 = '''
CREATE TABLE index_minute_snapshot (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
index_code VARCHAR(20) NOT NULL,
trade_date DATE NOT NULL,
snapshot_time DATETIME NOT NULL,
price DECIMAL(12,4),
open_price DECIMAL(12,4),
high_price DECIMAL(12,4),
low_price DECIMAL(12,4),
prev_close DECIMAL(12,4),
change_pct DECIMAL(10,4),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE KEY uk_index_time (index_code, snapshot_time),
INDEX idx_trade_date (trade_date)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
'''
if __name__ == '__main__':
print('正在重建数据库表...\n')
with engine.begin() as conn:
# 先删除旧表
print('删除旧表...')
conn.execute(text(drop_sql1))
print(' - concept_minute_alert 已删除')
conn.execute(text(drop_sql2))
print(' - index_minute_snapshot 已删除')
# 创建新表
print('\n创建新表...')
conn.execute(text(sql1))
print(' ✅ concept_minute_alert 表创建成功')
conn.execute(text(sql2))
print(' ✅ index_minute_snapshot 表创建成功')
print('\n✅ 所有表创建完成!')

112
ml/README.md Normal file
View File

@@ -0,0 +1,112 @@
# 概念异动检测 ML 模块
基于 Transformer Autoencoder 的概念异动检测系统。
## 环境要求
- Python 3.8+
- PyTorch 2.0+ (CUDA 12.x for 5090 GPU)
- ClickHouse, MySQL, Elasticsearch
## 数据库配置
当前配置(`prepare_data.py`:
- MySQL: `192.168.1.5:3306`
- Elasticsearch: `127.0.0.1:9200`
- ClickHouse: `127.0.0.1:9000`
## 快速开始
```bash
# 1. 安装依赖
pip install -r ml/requirements.txt
# 2. 安装 PyTorch (5090 需要 CUDA 12.4)
pip install torch --index-url https://download.pytorch.org/whl/cu124
# 3. 运行训练
chmod +x ml/run_training.sh
./ml/run_training.sh
```
## 文件说明
| 文件 | 说明 |
|------|------|
| `model.py` | Transformer Autoencoder 模型定义 |
| `prepare_data.py` | 数据提取和特征计算 |
| `train.py` | 模型训练脚本 |
| `inference.py` | 推理服务 |
| `enhanced_detector.py` | 增强版检测器(融合 Alpha + ML |
## 训练参数
```bash
# 完整参数
./ml/run_training.sh --start 2022-01-01 --end 2024-12-01 --epochs 100 --batch_size 256
# 只准备数据
python ml/prepare_data.py --start 2022-01-01
# 只训练(数据已准备好)
python ml/train.py --epochs 100 --batch_size 256 --lr 1e-4
```
## 模型架构
```
输入: (batch, 30, 6) # 30分钟序列6个特征
Positional Encoding
Transformer Encoder (4层, 8头, d=128)
Bottleneck (压缩到 32 维)
Transformer Decoder (4层)
输出: (batch, 30, 6) # 重构序列
异动判断: reconstruction_error > threshold
```
## 6维特征
1. `alpha` - 超额收益(概念涨幅 - 大盘涨幅)
2. `alpha_delta` - Alpha 5分钟变化
3. `amt_ratio` - 成交额 / 20分钟均值
4. `amt_delta` - 成交额变化率
5. `rank_pct` - Alpha 排名百分位
6. `limit_up_ratio` - 涨停股占比
## 训练产出
训练完成后,`ml/checkpoints/` 包含:
- `best_model.pt` - 最佳模型权重
- `thresholds.json` - 异动阈值 (P90/P95/P99)
- `normalization_stats.json` - 数据标准化参数
- `config.json` - 训练配置
## 使用示例
```python
from ml.inference import ConceptAnomalyDetector
detector = ConceptAnomalyDetector('ml/checkpoints')
# 实时检测
is_anomaly, score = detector.detect(
concept_name="人工智能",
features={
'alpha': 2.5,
'alpha_delta': 0.8,
'amt_ratio': 1.5,
'amt_delta': 0.3,
'rank_pct': 0.95,
'limit_up_ratio': 0.15,
}
)
if is_anomaly:
print(f"检测到异动!分数: {score}")
```

10
ml/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
"""
概念异动检测 ML 模块
提供基于 Transformer Autoencoder 的异动检测功能
"""
from .inference import ConceptAnomalyDetector, MLAnomalyService
__all__ = ['ConceptAnomalyDetector', 'MLAnomalyService']

481
ml/backtest.py Normal file
View File

@@ -0,0 +1,481 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
历史异动回测脚本
使用训练好的模型,对历史数据进行异动检测,生成异动记录
使用方法:
# 回测指定日期范围
python backtest.py --start 2024-01-01 --end 2024-12-01
# 回测单天
python backtest.py --start 2024-11-01 --end 2024-11-01
# 只生成结果,不写入数据库
python backtest.py --start 2024-01-01 --dry-run
"""
import os
import sys
import argparse
import json
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from collections import defaultdict
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sqlalchemy import create_engine, text
# 添加父目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model import TransformerAutoencoder
# ==================== 配置 ====================
MYSQL_ENGINE = create_engine(
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
echo=False
)
# 特征列表(与训练一致)
FEATURES = [
'alpha',
'alpha_delta',
'amt_ratio',
'amt_delta',
'rank_pct',
'limit_up_ratio',
]
# 回测配置
BACKTEST_CONFIG = {
'seq_len': 30, # 序列长度
'threshold_key': 'p95', # 使用的阈值
'min_alpha_abs': 0.5, # 最小 Alpha 绝对值(过滤微小波动)
'cooldown_minutes': 8, # 同一概念冷却时间
'max_alerts_per_minute': 15, # 每分钟最多异动数
'clip_value': 10.0, # 极端值截断
}
# ==================== 模型加载 ====================
class AnomalyDetector:
"""异动检测器"""
def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
self.checkpoint_dir = Path(checkpoint_dir)
# 设备
if device == 'auto':
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = torch.device(device)
# 加载配置
self._load_config()
# 加载模型
self._load_model()
# 加载阈值
self._load_thresholds()
print(f"AnomalyDetector 初始化完成")
print(f" 设备: {self.device}")
print(f" 阈值 ({BACKTEST_CONFIG['threshold_key']}): {self.threshold:.6f}")
def _load_config(self):
config_path = self.checkpoint_dir / 'config.json'
with open(config_path, 'r') as f:
self.config = json.load(f)
def _load_model(self):
model_path = self.checkpoint_dir / 'best_model.pt'
checkpoint = torch.load(model_path, map_location=self.device)
model_config = self.config['model'].copy()
model_config['use_instance_norm'] = self.config.get('use_instance_norm', True)
self.model = TransformerAutoencoder(**model_config)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.to(self.device)
self.model.eval()
def _load_thresholds(self):
thresholds_path = self.checkpoint_dir / 'thresholds.json'
with open(thresholds_path, 'r') as f:
thresholds = json.load(f)
self.threshold = thresholds[BACKTEST_CONFIG['threshold_key']]
@torch.no_grad()
def compute_anomaly_scores(self, sequences: np.ndarray) -> np.ndarray:
"""
计算异动分数
Args:
sequences: (n_sequences, seq_len, n_features)
Returns:
scores: (n_sequences,) 每个序列最后时刻的异动分数
"""
# 截断极端值
sequences = np.clip(sequences, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])
# 转为 tensor
x = torch.FloatTensor(sequences).to(self.device)
# 计算重构误差
errors = self.model.compute_reconstruction_error(x, reduction='none')
# 取最后一个时刻的误差
scores = errors[:, -1].cpu().numpy()
return scores
def is_anomaly(self, score: float) -> bool:
"""判断是否异动"""
return score > self.threshold
# ==================== 数据加载 ====================
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
"""加载单天的特征数据"""
file_path = Path(data_dir) / f"features_{date}.parquet"
if not file_path.exists():
return None
df = pd.read_parquet(file_path)
return df
def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
"""获取可用的日期列表"""
data_path = Path(data_dir)
all_files = sorted(data_path.glob("features_*.parquet"))
dates = []
for f in all_files:
date = f.stem.replace('features_', '')
if start_date <= date <= end_date:
dates.append(date)
return dates
# ==================== 回测逻辑 ====================
def backtest_single_day(
detector: AnomalyDetector,
df: pd.DataFrame,
date: str,
seq_len: int = 30
) -> List[Dict]:
"""
回测单天数据
Args:
detector: 异动检测器
df: 当天的特征数据
date: 日期
seq_len: 序列长度
Returns:
alerts: 异动列表
"""
alerts = []
# 按概念分组
grouped = df.groupby('concept_id', sort=False)
# 冷却记录 {concept_id: last_alert_timestamp}
cooldown = {}
# 获取所有时间点
all_timestamps = sorted(df['timestamp'].unique())
if len(all_timestamps) < seq_len:
return alerts
# 对每个时间点进行检测(从第 seq_len 个开始)
for t_idx in range(seq_len - 1, len(all_timestamps)):
current_time = all_timestamps[t_idx]
window_start_time = all_timestamps[t_idx - seq_len + 1]
minute_alerts = []
# 收集该时刻所有概念的序列
concept_sequences = []
concept_infos = []
for concept_id, concept_df in grouped:
# 获取该概念在时间窗口内的数据
mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
window_df = concept_df[mask].sort_values('timestamp')
if len(window_df) < seq_len:
continue
# 取最后 seq_len 个点
window_df = window_df.tail(seq_len)
# 提取特征
features = window_df[FEATURES].values
# 处理缺失值
features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
# 获取当前时刻的信息
current_row = window_df.iloc[-1]
concept_sequences.append(features)
concept_infos.append({
'concept_id': concept_id,
'timestamp': current_time,
'alpha': current_row.get('alpha', 0),
'alpha_delta': current_row.get('alpha_delta', 0),
'amt_ratio': current_row.get('amt_ratio', 1),
'limit_up_ratio': current_row.get('limit_up_ratio', 0),
'limit_down_ratio': current_row.get('limit_down_ratio', 0),
'rank_pct': current_row.get('rank_pct', 0.5),
'stock_count': current_row.get('stock_count', 0),
'total_amt': current_row.get('total_amt', 0),
})
if not concept_sequences:
continue
# 批量计算异动分数
sequences_array = np.array(concept_sequences)
scores = detector.compute_anomaly_scores(sequences_array)
# 检测异动
for i, (info, score) in enumerate(zip(concept_infos, scores)):
concept_id = info['concept_id']
alpha = info['alpha']
# 过滤小波动
if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']:
continue
# 检查冷却
if concept_id in cooldown:
last_alert = cooldown[concept_id]
if isinstance(current_time, datetime):
time_diff = (current_time - last_alert).total_seconds() / 60
else:
# timestamp 是字符串或其他格式
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1 # 跳过冷却检查
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
continue
# 判断是否异动
if not detector.is_anomaly(score):
continue
# 记录异动
alert_type = 'surge_up' if alpha > 0 else 'surge_down'
alert = {
'concept_id': concept_id,
'alert_time': current_time,
'trade_date': date,
'alert_type': alert_type,
'anomaly_score': float(score),
'threshold': detector.threshold,
**info
}
minute_alerts.append(alert)
cooldown[concept_id] = current_time
# 按分数排序,限制数量
minute_alerts.sort(key=lambda x: x['anomaly_score'], reverse=True)
alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])
return alerts
# ==================== 数据库写入 ====================
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
"""保存异动到 MySQL"""
if not alerts:
return 0
if dry_run:
print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
return len(alerts)
saved = 0
with MYSQL_ENGINE.begin() as conn:
for alert in alerts:
try:
# 检查是否已存在
check_sql = text("""
SELECT id FROM concept_minute_alert
WHERE concept_id = :concept_id
AND alert_time = :alert_time
AND trade_date = :trade_date
""")
exists = conn.execute(check_sql, {
'concept_id': alert['concept_id'],
'alert_time': alert['alert_time'],
'trade_date': alert['trade_date'],
}).fetchone()
if exists:
continue
# 插入新记录
insert_sql = text("""
INSERT INTO concept_minute_alert
(concept_id, concept_name, alert_time, alert_type, trade_date,
change_pct, zscore, importance_score, stock_count, extra_info)
VALUES
(:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
:change_pct, :zscore, :importance_score, :stock_count, :extra_info)
""")
conn.execute(insert_sql, {
'concept_id': alert['concept_id'],
'concept_name': alert.get('concept_name', ''),
'alert_time': alert['alert_time'],
'alert_type': alert['alert_type'],
'trade_date': alert['trade_date'],
'change_pct': alert.get('alpha', 0),
'zscore': alert['anomaly_score'],
'importance_score': alert['anomaly_score'],
'stock_count': alert.get('stock_count', 0),
'extra_info': json.dumps({
'detection_method': 'ml_autoencoder',
'threshold': alert['threshold'],
'alpha': alert.get('alpha', 0),
'amt_ratio': alert.get('amt_ratio', 1),
}, ensure_ascii=False)
})
saved += 1
except Exception as e:
print(f" 保存失败: {alert['concept_id']} - {e}")
return saved
def export_alerts_to_csv(alerts: List[Dict], output_path: str):
"""导出异动到 CSV"""
if not alerts:
return
df = pd.DataFrame(alerts)
df.to_csv(output_path, index=False, encoding='utf-8-sig')
print(f"已导出到: {output_path}")
# ==================== 主函数 ====================
def main():
parser = argparse.ArgumentParser(description='历史异动回测')
parser.add_argument('--data_dir', type=str, default='ml/data',
help='特征数据目录')
parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
help='模型检查点目录')
parser.add_argument('--start', type=str, required=True,
help='开始日期 (YYYY-MM-DD)')
parser.add_argument('--end', type=str, required=True,
help='结束日期 (YYYY-MM-DD)')
parser.add_argument('--dry-run', action='store_true',
help='只计算,不写入数据库')
parser.add_argument('--export-csv', type=str, default=None,
help='导出 CSV 文件路径')
parser.add_argument('--device', type=str, default='auto',
help='设备 (auto/cuda/cpu)')
args = parser.parse_args()
print("=" * 60)
print("历史异动回测")
print("=" * 60)
print(f"日期范围: {args.start} ~ {args.end}")
print(f"数据目录: {args.data_dir}")
print(f"模型目录: {args.checkpoint_dir}")
print(f"Dry Run: {args.dry_run}")
print("=" * 60)
# 初始化检测器
detector = AnomalyDetector(args.checkpoint_dir, args.device)
# 获取可用日期
dates = get_available_dates(args.data_dir, args.start, args.end)
if not dates:
print(f"未找到 {args.start} ~ {args.end} 范围内的数据")
return
print(f"\n找到 {len(dates)} 天的数据")
# 回测
all_alerts = []
total_saved = 0
for date in tqdm(dates, desc="回测进度"):
# 加载数据
df = load_daily_features(args.data_dir, date)
if df is None or df.empty:
continue
# 回测单天
alerts = backtest_single_day(
detector, df, date,
seq_len=BACKTEST_CONFIG['seq_len']
)
if alerts:
all_alerts.extend(alerts)
# 写入数据库
saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
total_saved += saved
if not args.dry_run:
tqdm.write(f" {date}: 检测到 {len(alerts)} 个异动,保存 {saved}")
# 导出 CSV
if args.export_csv and all_alerts:
export_alerts_to_csv(all_alerts, args.export_csv)
# 汇总
print("\n" + "=" * 60)
print("回测完成!")
print("=" * 60)
print(f"总计检测到: {len(all_alerts)} 个异动")
print(f"保存到数据库: {total_saved}")
# 统计
if all_alerts:
df_alerts = pd.DataFrame(all_alerts)
print(f"\n异动类型分布:")
print(df_alerts['alert_type'].value_counts())
print(f"\n异动分数统计:")
print(f" Mean: {df_alerts['anomaly_score'].mean():.4f}")
print(f" Max: {df_alerts['anomaly_score'].max():.4f}")
print(f" Min: {df_alerts['anomaly_score'].min():.4f}")
print("=" * 60)
if __name__ == "__main__":
main()

418
ml/backtest_hybrid.py Normal file
View File

@@ -0,0 +1,418 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
融合异动回测脚本
使用 HybridAnomalyDetector 进行回测:
- 规则评分 + LSTM Autoencoder 融合判断
- 输出更丰富的异动信息
使用方法:
python backtest_hybrid.py --start 2024-01-01 --end 2024-12-01
python backtest_hybrid.py --start 2024-11-01 --dry-run
"""
import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from collections import defaultdict
import numpy as np
import pandas as pd
from tqdm import tqdm
from sqlalchemy import create_engine, text
# 添加父目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from detector import HybridAnomalyDetector, create_detector
# ==================== 配置 ====================
MYSQL_ENGINE = create_engine(
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
echo=False
)
FEATURES = [
'alpha',
'alpha_delta',
'amt_ratio',
'amt_delta',
'rank_pct',
'limit_up_ratio',
]
BACKTEST_CONFIG = {
'seq_len': 30,
'min_alpha_abs': 0.3, # 降低阈值,让规则也能发挥作用
'cooldown_minutes': 8,
'max_alerts_per_minute': 20,
'clip_value': 10.0,
}
# ==================== 数据加载 ====================
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
"""加载单天的特征数据"""
file_path = Path(data_dir) / f"features_{date}.parquet"
if not file_path.exists():
return None
df = pd.read_parquet(file_path)
return df
def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
"""获取可用的日期列表"""
data_path = Path(data_dir)
all_files = sorted(data_path.glob("features_*.parquet"))
dates = []
for f in all_files:
date = f.stem.replace('features_', '')
if start_date <= date <= end_date:
dates.append(date)
return dates
# ==================== 融合回测 ====================
def backtest_single_day_hybrid(
detector: HybridAnomalyDetector,
df: pd.DataFrame,
date: str,
seq_len: int = 30
) -> List[Dict]:
"""
使用融合检测器回测单天数据
"""
alerts = []
# 按概念分组
grouped = df.groupby('concept_id', sort=False)
# 冷却记录
cooldown = {}
# 获取所有时间点
all_timestamps = sorted(df['timestamp'].unique())
if len(all_timestamps) < seq_len:
return alerts
# 对每个时间点进行检测
for t_idx in range(seq_len - 1, len(all_timestamps)):
current_time = all_timestamps[t_idx]
window_start_time = all_timestamps[t_idx - seq_len + 1]
minute_alerts = []
for concept_id, concept_df in grouped:
# 获取时间窗口内的数据
mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
window_df = concept_df[mask].sort_values('timestamp')
if len(window_df) < seq_len:
continue
window_df = window_df.tail(seq_len)
# 提取特征序列(给 ML 模型)
sequence = window_df[FEATURES].values
sequence = np.nan_to_num(sequence, nan=0.0, posinf=0.0, neginf=0.0)
sequence = np.clip(sequence, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])
# 当前时刻特征(给规则系统)
current_row = window_df.iloc[-1]
current_features = {
'alpha': current_row.get('alpha', 0),
'alpha_delta': current_row.get('alpha_delta', 0),
'amt_ratio': current_row.get('amt_ratio', 1),
'amt_delta': current_row.get('amt_delta', 0),
'rank_pct': current_row.get('rank_pct', 0.5),
'limit_up_ratio': current_row.get('limit_up_ratio', 0),
}
# 过滤微小波动
if abs(current_features['alpha']) < BACKTEST_CONFIG['min_alpha_abs']:
continue
# 检查冷却
if concept_id in cooldown:
last_alert = cooldown[concept_id]
if isinstance(current_time, datetime):
time_diff = (current_time - last_alert).total_seconds() / 60
else:
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
continue
# 融合检测
result = detector.detect(current_features, sequence)
if not result.is_anomaly:
continue
# 记录异动
alert = {
'concept_id': concept_id,
'alert_time': current_time,
'trade_date': date,
'alert_type': result.anomaly_type,
'final_score': result.final_score,
'rule_score': result.rule_score,
'ml_score': result.ml_score,
'trigger_reason': result.trigger_reason,
'triggered_rules': list(result.rule_details.keys()),
**current_features,
'stock_count': current_row.get('stock_count', 0),
'total_amt': current_row.get('total_amt', 0),
}
minute_alerts.append(alert)
cooldown[concept_id] = current_time
# 按最终得分排序
minute_alerts.sort(key=lambda x: x['final_score'], reverse=True)
alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])
return alerts
# ==================== 数据库写入 ====================
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
"""保存异动到 MySQL增强版字段"""
if not alerts:
return 0
if dry_run:
print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
return len(alerts)
saved = 0
with MYSQL_ENGINE.begin() as conn:
for alert in alerts:
try:
# 检查是否已存在
check_sql = text("""
SELECT id FROM concept_minute_alert
WHERE concept_id = :concept_id
AND alert_time = :alert_time
AND trade_date = :trade_date
""")
exists = conn.execute(check_sql, {
'concept_id': alert['concept_id'],
'alert_time': alert['alert_time'],
'trade_date': alert['trade_date'],
}).fetchone()
if exists:
continue
# 插入新记录
insert_sql = text("""
INSERT INTO concept_minute_alert
(concept_id, concept_name, alert_time, alert_type, trade_date,
change_pct, zscore, importance_score, stock_count, extra_info)
VALUES
(:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
:change_pct, :zscore, :importance_score, :stock_count, :extra_info)
""")
extra_info = {
'detection_method': 'hybrid',
'final_score': alert['final_score'],
'rule_score': alert['rule_score'],
'ml_score': alert['ml_score'],
'trigger_reason': alert['trigger_reason'],
'triggered_rules': alert['triggered_rules'],
'alpha': alert.get('alpha', 0),
'alpha_delta': alert.get('alpha_delta', 0),
'amt_ratio': alert.get('amt_ratio', 1),
}
conn.execute(insert_sql, {
'concept_id': alert['concept_id'],
'concept_name': alert.get('concept_name', ''),
'alert_time': alert['alert_time'],
'alert_type': alert['alert_type'],
'trade_date': alert['trade_date'],
'change_pct': alert.get('alpha', 0),
'zscore': alert['final_score'], # 用最终得分作为 zscore
'importance_score': alert['final_score'],
'stock_count': alert.get('stock_count', 0),
'extra_info': json.dumps(extra_info, ensure_ascii=False)
})
saved += 1
except Exception as e:
print(f" 保存失败: {alert['concept_id']} - {e}")
return saved
def export_alerts_to_csv(alerts: List[Dict], output_path: str):
"""导出异动到 CSV"""
if not alerts:
return
df = pd.DataFrame(alerts)
df.to_csv(output_path, index=False, encoding='utf-8-sig')
print(f"已导出到: {output_path}")
# ==================== 统计分析 ====================
def analyze_alerts(alerts: List[Dict]):
"""分析异动结果"""
if not alerts:
print("无异动数据")
return
df = pd.DataFrame(alerts)
print("\n" + "=" * 60)
print("异动统计分析")
print("=" * 60)
# 1. 基本统计
print(f"\n总异动数: {len(alerts)}")
# 2. 按类型统计
print(f"\n异动类型分布:")
print(df['alert_type'].value_counts())
# 3. 得分统计
print(f"\n得分统计:")
print(f" 最终得分 - Mean: {df['final_score'].mean():.1f}, Max: {df['final_score'].max():.1f}")
print(f" 规则得分 - Mean: {df['rule_score'].mean():.1f}, Max: {df['rule_score'].max():.1f}")
print(f" ML得分 - Mean: {df['ml_score'].mean():.1f}, Max: {df['ml_score'].max():.1f}")
# 4. 触发来源分析
print(f"\n触发来源分析:")
trigger_counts = df['trigger_reason'].apply(
lambda x: '规则' if '规则' in x else ('ML' if 'ML' in x else '融合')
).value_counts()
print(trigger_counts)
# 5. 规则触发频率
all_rules = []
for rules in df['triggered_rules']:
if isinstance(rules, list):
all_rules.extend(rules)
if all_rules:
print(f"\n最常触发的规则 (Top 10):")
from collections import Counter
rule_counts = Counter(all_rules)
for rule, count in rule_counts.most_common(10):
print(f" {rule}: {count}")
# ==================== 主函数 ====================
def main():
parser = argparse.ArgumentParser(description='融合异动回测')
parser.add_argument('--data_dir', type=str, default='ml/data',
help='特征数据目录')
parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
help='模型检查点目录')
parser.add_argument('--start', type=str, required=True,
help='开始日期 (YYYY-MM-DD)')
parser.add_argument('--end', type=str, default=None,
help='结束日期 (YYYY-MM-DD),默认=start')
parser.add_argument('--dry-run', action='store_true',
help='只计算,不写入数据库')
parser.add_argument('--export-csv', type=str, default=None,
help='导出 CSV 文件路径')
parser.add_argument('--rule-weight', type=float, default=0.6,
help='规则权重 (0-1)')
parser.add_argument('--ml-weight', type=float, default=0.4,
help='ML权重 (0-1)')
args = parser.parse_args()
if args.end is None:
args.end = args.start
print("=" * 60)
print("融合异动回测 (规则 + LSTM)")
print("=" * 60)
print(f"日期范围: {args.start} ~ {args.end}")
print(f"数据目录: {args.data_dir}")
print(f"模型目录: {args.checkpoint_dir}")
print(f"规则权重: {args.rule_weight}")
print(f"ML权重: {args.ml_weight}")
print(f"Dry Run: {args.dry_run}")
print("=" * 60)
# 初始化融合检测器
config = {
'rule_weight': args.rule_weight,
'ml_weight': args.ml_weight,
}
detector = create_detector(args.checkpoint_dir, config)
# 获取可用日期
dates = get_available_dates(args.data_dir, args.start, args.end)
if not dates:
print(f"未找到 {args.start} ~ {args.end} 范围内的数据")
return
print(f"\n找到 {len(dates)} 天的数据")
# 回测
all_alerts = []
total_saved = 0
for date in tqdm(dates, desc="回测进度"):
df = load_daily_features(args.data_dir, date)
if df is None or df.empty:
continue
alerts = backtest_single_day_hybrid(
detector, df, date,
seq_len=BACKTEST_CONFIG['seq_len']
)
if alerts:
all_alerts.extend(alerts)
saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
total_saved += saved
if not args.dry_run:
tqdm.write(f" {date}: 检测到 {len(alerts)} 个异动,保存 {saved}")
# 导出 CSV
if args.export_csv and all_alerts:
export_alerts_to_csv(all_alerts, args.export_csv)
# 统计分析
analyze_alerts(all_alerts)
# 汇总
print("\n" + "=" * 60)
print("回测完成!")
print("=" * 60)
print(f"总计检测到: {len(all_alerts)} 个异动")
print(f"保存到数据库: {total_saved}")
print("=" * 60)
if __name__ == "__main__":
main()

571
ml/detector.py Normal file
View File

@@ -0,0 +1,571 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
概念异动检测器 - 融合版
结合两种方法的优势:
1. 规则评分系统:可解释、稳定、覆盖已知模式
2. LSTM Autoencoder发现未知的异常模式
融合策略:
┌─────────────────────────────────────────────────────────┐
│ 输入特征 │
│ (alpha, alpha_delta, amt_ratio, amt_delta, rank_pct, │
│ limit_up_ratio) │
├─────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ 规则评分系统 │ │ LSTM Autoencoder │ │
│ │ (0-100分) │ │ (重构误差) │ │
│ └──────┬───────┘ └──────┬───────┘ │
│ │ │ │
│ ▼ ▼ │
│ rule_score (0-100) ml_score (标准化后 0-100) │
│ │
├─────────────────────────────────────────────────────────┤
│ 融合策略 │
│ │
│ final_score = w1 * rule_score + w2 * ml_score │
│ │
│ 异动判定: │
│ - rule_score >= 60 → 直接触发(规则强信号) │
│ - ml_score >= 80 → 直接触发ML强信号
│ - final_score >= 50 → 融合触发 │
│ │
└─────────────────────────────────────────────────────────┘
优势:
- 规则系统保证已知模式的检出率
- ML模型捕捉规则未覆盖的异常
- 两者互相验证,减少误报
"""
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
import numpy as np
import torch
# 尝试导入模型(可能不存在)
try:
from model import LSTMAutoencoder, create_model
HAS_MODEL = True
except ImportError:
HAS_MODEL = False
@dataclass
class AnomalyResult:
"""异动检测结果"""
is_anomaly: bool
final_score: float # 最终得分 (0-100)
rule_score: float # 规则得分 (0-100)
ml_score: float # ML得分 (0-100)
trigger_reason: str # 触发原因
rule_details: Dict # 规则明细
anomaly_type: str # 异动类型: surge_up / surge_down / volume_spike / unknown
class RuleBasedScorer:
"""
基于规则的评分系统
设计原则:
- 每个规则独立打分
- 分数可叠加
- 阈值可配置
"""
# 默认规则配置
DEFAULT_RULES = {
# Alpha 相关(超额收益)
'alpha_strong': {
'condition': lambda r: abs(r.get('alpha', 0)) >= 3.0,
'score': 35,
'description': 'Alpha强信号(|α|≥3%)'
},
'alpha_medium': {
'condition': lambda r: 2.0 <= abs(r.get('alpha', 0)) < 3.0,
'score': 25,
'description': 'Alpha中等(2%≤|α|<3%)'
},
'alpha_weak': {
'condition': lambda r: 1.5 <= abs(r.get('alpha', 0)) < 2.0,
'score': 15,
'description': 'Alpha轻微(1.5%≤|α|<2%)'
},
# Alpha 变化率(加速度)
'alpha_delta_strong': {
'condition': lambda r: abs(r.get('alpha_delta', 0)) >= 1.0,
'score': 30,
'description': 'Alpha加速强(|Δα|≥1%)'
},
'alpha_delta_medium': {
'condition': lambda r: 0.5 <= abs(r.get('alpha_delta', 0)) < 1.0,
'score': 20,
'description': 'Alpha加速中(0.5%≤|Δα|<1%)'
},
# 成交额比率(放量)
'volume_spike_strong': {
'condition': lambda r: r.get('amt_ratio', 1) >= 5.0,
'score': 30,
'description': '极度放量(≥5倍)'
},
'volume_spike_medium': {
'condition': lambda r: 3.0 <= r.get('amt_ratio', 1) < 5.0,
'score': 20,
'description': '显著放量(3-5倍)'
},
'volume_spike_weak': {
'condition': lambda r: 2.0 <= r.get('amt_ratio', 1) < 3.0,
'score': 10,
'description': '轻微放量(2-3倍)'
},
# 成交额变化率
'amt_delta_strong': {
'condition': lambda r: abs(r.get('amt_delta', 0)) >= 1.0,
'score': 15,
'description': '成交额急变(|Δamt|≥100%)'
},
# 排名跳变
'rank_top': {
'condition': lambda r: r.get('rank_pct', 0.5) >= 0.95,
'score': 25,
'description': '排名前5%'
},
'rank_bottom': {
'condition': lambda r: r.get('rank_pct', 0.5) <= 0.05,
'score': 25,
'description': '排名后5%'
},
'rank_high': {
'condition': lambda r: 0.9 <= r.get('rank_pct', 0.5) < 0.95,
'score': 15,
'description': '排名前10%'
},
# 涨停比例
'limit_up_high': {
'condition': lambda r: r.get('limit_up_ratio', 0) >= 0.2,
'score': 25,
'description': '涨停比例≥20%'
},
'limit_up_medium': {
'condition': lambda r: 0.1 <= r.get('limit_up_ratio', 0) < 0.2,
'score': 15,
'description': '涨停比例10-20%'
},
# 组合条件(更可靠的信号)
'alpha_with_volume': {
'condition': lambda r: abs(r.get('alpha', 0)) >= 1.5 and r.get('amt_ratio', 1) >= 2.0,
'score': 20, # 额外加分
'description': 'Alpha+放量组合'
},
'acceleration_with_rank': {
'condition': lambda r: abs(r.get('alpha_delta', 0)) >= 0.5 and (r.get('rank_pct', 0.5) >= 0.9 or r.get('rank_pct', 0.5) <= 0.1),
'score': 15, # 额外加分
'description': '加速+排名异常组合'
},
}
def __init__(self, rules: Dict = None):
"""
初始化规则评分器
Args:
rules: 自定义规则,格式同 DEFAULT_RULES
"""
self.rules = rules or self.DEFAULT_RULES
def score(self, features: Dict) -> Tuple[float, Dict]:
"""
计算规则得分
Args:
features: 特征字典,包含 alpha, alpha_delta, amt_ratio 等
Returns:
score: 总分 (0-100)
details: 触发的规则明细
"""
total_score = 0
triggered_rules = {}
for rule_name, rule_config in self.rules.items():
try:
if rule_config['condition'](features):
total_score += rule_config['score']
triggered_rules[rule_name] = {
'score': rule_config['score'],
'description': rule_config['description']
}
except Exception:
# 忽略规则计算错误
pass
# 限制在 0-100
total_score = min(100, max(0, total_score))
return total_score, triggered_rules
def get_anomaly_type(self, features: Dict) -> str:
"""判断异动类型"""
alpha = features.get('alpha', 0)
amt_ratio = features.get('amt_ratio', 1)
if alpha >= 1.5:
return 'surge_up'
elif alpha <= -1.5:
return 'surge_down'
elif amt_ratio >= 3.0:
return 'volume_spike'
else:
return 'unknown'
class MLScorer:
"""
基于 LSTM Autoencoder 的评分器
将重构误差转换为 0-100 的分数
"""
def __init__(
self,
checkpoint_dir: str = 'ml/checkpoints',
device: str = 'auto'
):
self.checkpoint_dir = Path(checkpoint_dir)
# 设备
if device == 'auto':
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = torch.device(device)
self.model = None
self.thresholds = None
self.config = None
# 尝试加载模型
self._load_model()
def _load_model(self):
"""加载模型和阈值"""
if not HAS_MODEL:
print("警告: 无法导入模型模块")
return
model_path = self.checkpoint_dir / 'best_model.pt'
thresholds_path = self.checkpoint_dir / 'thresholds.json'
config_path = self.checkpoint_dir / 'config.json'
if not model_path.exists():
print(f"警告: 模型文件不存在 {model_path}")
return
try:
# 加载配置
if config_path.exists():
with open(config_path, 'r') as f:
self.config = json.load(f)
# 加载模型
checkpoint = torch.load(model_path, map_location=self.device)
model_config = self.config.get('model', {}) if self.config else {}
self.model = create_model(model_config)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.to(self.device)
self.model.eval()
# 加载阈值
if thresholds_path.exists():
with open(thresholds_path, 'r') as f:
self.thresholds = json.load(f)
print(f"MLScorer 加载成功 (设备: {self.device})")
except Exception as e:
print(f"警告: 模型加载失败 - {e}")
self.model = None
def is_ready(self) -> bool:
"""检查模型是否就绪"""
return self.model is not None
@torch.no_grad()
def score(self, sequence: np.ndarray) -> float:
"""
计算 ML 得分
Args:
sequence: (seq_len, n_features) 或 (batch, seq_len, n_features)
Returns:
score: 0-100 的分数,越高越异常
"""
if not self.is_ready():
return 0.0
# 确保是 3D
if sequence.ndim == 2:
sequence = sequence[np.newaxis, ...]
# 转为 tensor
x = torch.FloatTensor(sequence).to(self.device)
# 计算重构误差
output, _ = self.model(x)
mse = ((output - x) ** 2).mean(dim=-1) # (batch, seq_len)
# 取最后时刻的误差
error = mse[:, -1].cpu().numpy()
# 转换为 0-100 分数
# 使用 p95 阈值作为参考
if self.thresholds:
p95 = self.thresholds.get('p95', 0.1)
p99 = self.thresholds.get('p99', 0.2)
else:
p95, p99 = 0.1, 0.2
# 线性映射p95 -> 50分, p99 -> 80分
# error=0 -> 0分, error>=p99*1.5 -> 100分
score = np.clip(error / p95 * 50, 0, 100)
return float(score[0]) if len(score) == 1 else score.tolist()
class HybridAnomalyDetector:
"""
融合异动检测器
结合规则系统和 ML 模型
"""
# 默认配置
DEFAULT_CONFIG = {
# 权重配置
'rule_weight': 0.6, # 规则权重
'ml_weight': 0.4, # ML权重
# 触发阈值
'rule_trigger': 60, # 规则直接触发阈值
'ml_trigger': 80, # ML直接触发阈值
'fusion_trigger': 50, # 融合触发阈值
# 特征列表
'features': [
'alpha', 'alpha_delta', 'amt_ratio',
'amt_delta', 'rank_pct', 'limit_up_ratio'
],
# 序列长度ML模型需要
'seq_len': 30,
}
def __init__(
self,
config: Dict = None,
checkpoint_dir: str = 'ml/checkpoints',
device: str = 'auto'
):
self.config = {**self.DEFAULT_CONFIG, **(config or {})}
# 初始化评分器
self.rule_scorer = RuleBasedScorer()
self.ml_scorer = MLScorer(checkpoint_dir, device)
print(f"HybridAnomalyDetector 初始化完成")
print(f" 规则权重: {self.config['rule_weight']}")
print(f" ML权重: {self.config['ml_weight']}")
print(f" ML模型: {'就绪' if self.ml_scorer.is_ready() else '未加载'}")
def detect(
self,
features: Dict,
sequence: np.ndarray = None
) -> AnomalyResult:
"""
检测异动
Args:
features: 当前时刻的特征字典
sequence: 历史序列 (seq_len, n_features)ML模型需要
Returns:
AnomalyResult: 检测结果
"""
# 1. 规则评分
rule_score, rule_details = self.rule_scorer.score(features)
# 2. ML评分
ml_score = 0.0
if sequence is not None and self.ml_scorer.is_ready():
ml_score = self.ml_scorer.score(sequence)
# 3. 融合得分
w1 = self.config['rule_weight']
w2 = self.config['ml_weight']
# 如果ML不可用全部权重给规则
if not self.ml_scorer.is_ready():
w1, w2 = 1.0, 0.0
final_score = w1 * rule_score + w2 * ml_score
# 4. 判断是否异动
is_anomaly = False
trigger_reason = ''
if rule_score >= self.config['rule_trigger']:
is_anomaly = True
trigger_reason = f'规则强信号({rule_score:.0f}分)'
elif ml_score >= self.config['ml_trigger']:
is_anomaly = True
trigger_reason = f'ML强信号({ml_score:.0f}分)'
elif final_score >= self.config['fusion_trigger']:
is_anomaly = True
trigger_reason = f'融合触发({final_score:.0f}分)'
# 5. 判断异动类型
anomaly_type = self.rule_scorer.get_anomaly_type(features) if is_anomaly else ''
return AnomalyResult(
is_anomaly=is_anomaly,
final_score=final_score,
rule_score=rule_score,
ml_score=ml_score,
trigger_reason=trigger_reason,
rule_details=rule_details,
anomaly_type=anomaly_type
)
def detect_batch(
self,
features_list: List[Dict],
sequences: np.ndarray = None
) -> List[AnomalyResult]:
"""
批量检测
Args:
features_list: 特征字典列表
sequences: (batch, seq_len, n_features)
Returns:
List[AnomalyResult]
"""
results = []
for i, features in enumerate(features_list):
seq = sequences[i] if sequences is not None else None
result = self.detect(features, seq)
results.append(result)
return results
# ==================== 便捷函数 ====================
def create_detector(
checkpoint_dir: str = 'ml/checkpoints',
config: Dict = None
) -> HybridAnomalyDetector:
"""创建融合检测器"""
return HybridAnomalyDetector(config, checkpoint_dir)
def quick_detect(features: Dict) -> bool:
"""
快速检测只用规则不需要ML模型
适用于:
- 实时检测
- ML模型未训练完成时
"""
scorer = RuleBasedScorer()
score, _ = scorer.score(features)
return score >= 50
# ==================== 测试 ====================
if __name__ == "__main__":
print("=" * 60)
print("融合异动检测器测试")
print("=" * 60)
# 创建检测器
detector = create_detector()
# 测试用例
test_cases = [
{
'name': '正常情况',
'features': {
'alpha': 0.5,
'alpha_delta': 0.1,
'amt_ratio': 1.2,
'amt_delta': 0.1,
'rank_pct': 0.5,
'limit_up_ratio': 0.02
}
},
{
'name': 'Alpha异动',
'features': {
'alpha': 3.5,
'alpha_delta': 0.8,
'amt_ratio': 2.5,
'amt_delta': 0.5,
'rank_pct': 0.92,
'limit_up_ratio': 0.05
}
},
{
'name': '放量异动',
'features': {
'alpha': 1.2,
'alpha_delta': 0.3,
'amt_ratio': 6.0,
'amt_delta': 1.5,
'rank_pct': 0.85,
'limit_up_ratio': 0.08
}
},
{
'name': '涨停潮',
'features': {
'alpha': 2.5,
'alpha_delta': 0.6,
'amt_ratio': 3.5,
'amt_delta': 0.8,
'rank_pct': 0.98,
'limit_up_ratio': 0.25
}
},
]
print("\n测试结果:")
print("-" * 60)
for case in test_cases:
result = detector.detect(case['features'])
print(f"\n{case['name']}:")
print(f" 异动: {'' if result.is_anomaly else ''}")
print(f" 最终得分: {result.final_score:.1f}")
print(f" 规则得分: {result.rule_score:.1f}")
print(f" ML得分: {result.ml_score:.1f}")
if result.is_anomaly:
print(f" 触发原因: {result.trigger_reason}")
print(f" 异动类型: {result.anomaly_type}")
print(f" 触发规则: {list(result.rule_details.keys())}")
print("\n" + "=" * 60)
print("测试完成!")

526
ml/enhanced_detector.py Normal file
View File

@@ -0,0 +1,526 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
增强版概念异动检测器
融合两种检测方法:
1. Alpha-based Z-Score规则方法实时性好
2. Transformer AutoencoderML方法更准确
使用策略:
- 当 ML 模型可用且历史数据足够时,优先使用 ML 方法
- 否则回退到 Alpha-based 方法
- 可以配置两种方法的融合权重
"""
import os
import sys
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, field
from collections import deque
import numpy as np
# 添加父目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
logger = logging.getLogger(__name__)
# ==================== 配置 ====================
ENHANCED_CONFIG = {
# 融合策略
'fusion_mode': 'adaptive', # 'ml_only', 'alpha_only', 'adaptive', 'ensemble'
# ML 权重(在 ensemble 模式下)
'ml_weight': 0.6,
'alpha_weight': 0.4,
# ML 模型配置
'ml_checkpoint_dir': 'ml/checkpoints',
'ml_threshold_key': 'p95', # p90, p95, p99
# Alpha 配置(与 concept_alert_alpha.py 一致)
'alpha_zscore_threshold': 2.0,
'alpha_absolute_threshold': 1.5,
'alpha_history_window': 60,
'alpha_min_history': 5,
# 共享配置
'cooldown_minutes': 8,
'max_alerts_per_minute': 15,
'min_alpha_abs': 0.5,
}
# 特征配置(与训练一致)
FEATURE_NAMES = [
'alpha',
'alpha_delta',
'amt_ratio',
'amt_delta',
'rank_pct',
'limit_up_ratio',
]
# ==================== 数据结构 ====================
@dataclass
class AlphaStats:
"""概念的Alpha统计信息"""
history: deque = field(default_factory=lambda: deque(maxlen=ENHANCED_CONFIG['alpha_history_window']))
mean: float = 0.0
std: float = 1.0
def update(self, alpha: float):
self.history.append(alpha)
if len(self.history) >= 2:
self.mean = np.mean(self.history)
self.std = max(np.std(self.history), 0.1)
def get_zscore(self, alpha: float) -> float:
if len(self.history) < ENHANCED_CONFIG['alpha_min_history']:
return 0.0
return (alpha - self.mean) / self.std
def is_ready(self) -> bool:
return len(self.history) >= ENHANCED_CONFIG['alpha_min_history']
@dataclass
class ConceptFeatures:
"""概念的实时特征"""
alpha: float = 0.0
alpha_delta: float = 0.0
amt_ratio: float = 1.0
amt_delta: float = 0.0
rank_pct: float = 0.5
limit_up_ratio: float = 0.0
def to_dict(self) -> Dict[str, float]:
return {
'alpha': self.alpha,
'alpha_delta': self.alpha_delta,
'amt_ratio': self.amt_ratio,
'amt_delta': self.amt_delta,
'rank_pct': self.rank_pct,
'limit_up_ratio': self.limit_up_ratio,
}
# ==================== 增强检测器 ====================
class EnhancedAnomalyDetector:
"""
增强版异动检测器
融合 Alpha-based 和 ML 两种方法
"""
def __init__(
self,
config: Dict = None,
ml_enabled: bool = True
):
self.config = config or ENHANCED_CONFIG
self.ml_enabled = ml_enabled
self.ml_detector = None
# Alpha 统计
self.alpha_stats: Dict[str, AlphaStats] = {}
# 特征历史(用于计算 delta
self.feature_history: Dict[str, deque] = {}
# 冷却记录
self.cooldown_cache: Dict[str, datetime] = {}
# 尝试加载 ML 模型
if ml_enabled:
self._load_ml_model()
logger.info(f"EnhancedAnomalyDetector 初始化完成")
logger.info(f" 融合模式: {self.config['fusion_mode']}")
logger.info(f" ML 可用: {self.ml_detector is not None}")
def _load_ml_model(self):
"""加载 ML 模型"""
try:
from inference import ConceptAnomalyDetector
checkpoint_dir = Path(__file__).parent / 'checkpoints'
if (checkpoint_dir / 'best_model.pt').exists():
self.ml_detector = ConceptAnomalyDetector(
checkpoint_dir=str(checkpoint_dir),
threshold_key=self.config['ml_threshold_key']
)
logger.info("ML 模型加载成功")
else:
logger.warning(f"ML 模型不存在: {checkpoint_dir / 'best_model.pt'}")
except Exception as e:
logger.warning(f"ML 模型加载失败: {e}")
self.ml_detector = None
def _get_alpha_stats(self, concept_id: str) -> AlphaStats:
"""获取或创建 Alpha 统计"""
if concept_id not in self.alpha_stats:
self.alpha_stats[concept_id] = AlphaStats()
return self.alpha_stats[concept_id]
def _get_feature_history(self, concept_id: str) -> deque:
"""获取特征历史"""
if concept_id not in self.feature_history:
self.feature_history[concept_id] = deque(maxlen=10)
return self.feature_history[concept_id]
def _check_cooldown(self, concept_id: str, current_time: datetime) -> bool:
"""检查冷却"""
if concept_id not in self.cooldown_cache:
return False
last_alert = self.cooldown_cache[concept_id]
cooldown_td = (current_time - last_alert).total_seconds() / 60
return cooldown_td < self.config['cooldown_minutes']
def _set_cooldown(self, concept_id: str, current_time: datetime):
"""设置冷却"""
self.cooldown_cache[concept_id] = current_time
def compute_features(
self,
concept_id: str,
alpha: float,
amt_ratio: float,
rank_pct: float,
limit_up_ratio: float
) -> ConceptFeatures:
"""
计算概念的完整特征
Args:
concept_id: 概念ID
alpha: 当前超额收益
amt_ratio: 成交额比率
rank_pct: 排名百分位
limit_up_ratio: 涨停股占比
Returns:
完整特征
"""
history = self._get_feature_history(concept_id)
# 计算变化率
alpha_delta = 0.0
amt_delta = 0.0
if len(history) > 0:
last_features = history[-1]
alpha_delta = alpha - last_features.alpha
if last_features.amt_ratio > 0:
amt_delta = (amt_ratio - last_features.amt_ratio) / last_features.amt_ratio
features = ConceptFeatures(
alpha=alpha,
alpha_delta=alpha_delta,
amt_ratio=amt_ratio,
amt_delta=amt_delta,
rank_pct=rank_pct,
limit_up_ratio=limit_up_ratio,
)
# 更新历史
history.append(features)
return features
def detect_alpha_anomaly(
self,
concept_id: str,
alpha: float
) -> Tuple[bool, float, str]:
"""
Alpha-based 异动检测
Returns:
is_anomaly: 是否异动
score: 异动分数Z-Score 绝对值)
reason: 触发原因
"""
stats = self._get_alpha_stats(concept_id)
# 计算 Z-Score在更新前
zscore = stats.get_zscore(alpha)
# 更新统计
stats.update(alpha)
# 判断
if stats.is_ready():
if abs(zscore) >= self.config['alpha_zscore_threshold']:
return True, abs(zscore), f"Z={zscore:.2f}"
else:
if abs(alpha) >= self.config['alpha_absolute_threshold']:
fake_zscore = alpha / 0.5
return True, abs(fake_zscore), f"Alpha={alpha:+.2f}%"
return False, abs(zscore) if zscore else 0.0, ""
def detect_ml_anomaly(
self,
concept_id: str,
features: ConceptFeatures
) -> Tuple[bool, float]:
"""
ML-based 异动检测
Returns:
is_anomaly: 是否异动
score: 异动分数(重构误差)
"""
if self.ml_detector is None:
return False, 0.0
try:
is_anomaly, score = self.ml_detector.detect(
concept_id,
features.to_dict()
)
return is_anomaly, score or 0.0
except Exception as e:
logger.warning(f"ML 检测失败: {e}")
return False, 0.0
def detect(
self,
concept_id: str,
concept_name: str,
alpha: float,
amt_ratio: float,
rank_pct: float,
limit_up_ratio: float,
change_pct: float,
index_change: float,
current_time: datetime,
**extra_data
) -> Optional[Dict]:
"""
融合检测
Args:
concept_id: 概念ID
concept_name: 概念名称
alpha: 超额收益
amt_ratio: 成交额比率
rank_pct: 排名百分位
limit_up_ratio: 涨停股占比
change_pct: 概念涨跌幅
index_change: 大盘涨跌幅
current_time: 当前时间
**extra_data: 其他数据limit_up_count, stock_count 等)
Returns:
异动信息(如果触发),否则 None
"""
# Alpha 太小,不关注
if abs(alpha) < self.config['min_alpha_abs']:
return None
# 检查冷却
if self._check_cooldown(concept_id, current_time):
return None
# 计算特征
features = self.compute_features(
concept_id, alpha, amt_ratio, rank_pct, limit_up_ratio
)
# 执行检测
fusion_mode = self.config['fusion_mode']
alpha_anomaly, alpha_score, alpha_reason = self.detect_alpha_anomaly(concept_id, alpha)
ml_anomaly, ml_score = False, 0.0
if fusion_mode in ('ml_only', 'adaptive', 'ensemble'):
ml_anomaly, ml_score = self.detect_ml_anomaly(concept_id, features)
# 根据融合模式判断
is_anomaly = False
final_score = 0.0
detection_method = ''
if fusion_mode == 'alpha_only':
is_anomaly = alpha_anomaly
final_score = alpha_score
detection_method = 'alpha'
elif fusion_mode == 'ml_only':
is_anomaly = ml_anomaly
final_score = ml_score
detection_method = 'ml'
elif fusion_mode == 'adaptive':
# 优先 ML回退 Alpha
if self.ml_detector and ml_score > 0:
is_anomaly = ml_anomaly
final_score = ml_score
detection_method = 'ml'
else:
is_anomaly = alpha_anomaly
final_score = alpha_score
detection_method = 'alpha'
elif fusion_mode == 'ensemble':
# 加权融合
# 归一化分数
norm_alpha = min(alpha_score / 5.0, 1.0) # Z > 5 视为 1.0
norm_ml = min(ml_score / (self.ml_detector.threshold if self.ml_detector else 1.0), 1.0)
final_score = (
self.config['alpha_weight'] * norm_alpha +
self.config['ml_weight'] * norm_ml
)
is_anomaly = final_score > 0.5 or alpha_anomaly or ml_anomaly
detection_method = 'ensemble'
if not is_anomaly:
return None
# 构建异动记录
self._set_cooldown(concept_id, current_time)
alert_type = 'surge_up' if alpha > 0 else 'surge_down'
alert = {
'concept_id': concept_id,
'concept_name': concept_name,
'alert_type': alert_type,
'alert_time': current_time,
'change_pct': change_pct,
'alpha': alpha,
'alpha_zscore': alpha_score,
'index_change_pct': index_change,
'detection_method': detection_method,
'alpha_score': alpha_score,
'ml_score': ml_score,
'final_score': final_score,
**extra_data
}
return alert
def batch_detect(
self,
concepts_data: List[Dict],
current_time: datetime
) -> List[Dict]:
"""
批量检测
Args:
concepts_data: 概念数据列表
current_time: 当前时间
Returns:
异动列表(按分数排序,限制数量)
"""
alerts = []
for data in concepts_data:
alert = self.detect(
concept_id=data['concept_id'],
concept_name=data['concept_name'],
alpha=data.get('alpha', 0),
amt_ratio=data.get('amt_ratio', 1.0),
rank_pct=data.get('rank_pct', 0.5),
limit_up_ratio=data.get('limit_up_ratio', 0),
change_pct=data.get('change_pct', 0),
index_change=data.get('index_change', 0),
current_time=current_time,
limit_up_count=data.get('limit_up_count', 0),
limit_down_count=data.get('limit_down_count', 0),
stock_count=data.get('stock_count', 0),
concept_type=data.get('concept_type', 'leaf'),
)
if alert:
alerts.append(alert)
# 排序并限制数量
alerts.sort(key=lambda x: x['final_score'], reverse=True)
return alerts[:self.config['max_alerts_per_minute']]
def reset(self):
"""重置所有状态(新交易日)"""
self.alpha_stats.clear()
self.feature_history.clear()
self.cooldown_cache.clear()
if self.ml_detector:
self.ml_detector.clear_history()
logger.info("检测器状态已重置")
# ==================== 测试 ====================
if __name__ == "__main__":
import random
print("测试 EnhancedAnomalyDetector...")
# 初始化
detector = EnhancedAnomalyDetector(ml_enabled=False) # 不加载 ML可能不存在
# 模拟数据
concepts = [
{'concept_id': 'ai_001', 'concept_name': '人工智能'},
{'concept_id': 'chip_002', 'concept_name': '芯片半导体'},
{'concept_id': 'car_003', 'concept_name': '新能源汽车'},
]
print("\n模拟实时检测...")
current_time = datetime.now()
for minute in range(50):
concepts_data = []
for c in concepts:
# 生成随机数据
alpha = random.gauss(0, 0.8)
amt_ratio = max(0.3, random.gauss(1, 0.3))
rank_pct = random.random()
limit_up_ratio = random.random() * 0.1
# 模拟异动第30分钟人工智能暴涨
if minute == 30 and c['concept_id'] == 'ai_001':
alpha = 4.5
amt_ratio = 2.5
limit_up_ratio = 0.3
concepts_data.append({
**c,
'alpha': alpha,
'amt_ratio': amt_ratio,
'rank_pct': rank_pct,
'limit_up_ratio': limit_up_ratio,
'change_pct': alpha + 0.5,
'index_change': 0.5,
})
# 检测
alerts = detector.batch_detect(concepts_data, current_time)
if alerts:
for alert in alerts:
print(f" t={minute:02d} 🔥 {alert['concept_name']} "
f"Alpha={alert['alpha']:+.2f}% "
f"Score={alert['final_score']:.2f} "
f"Method={alert['detection_method']}")
current_time = current_time.replace(minute=current_time.minute + 1 if current_time.minute < 59 else 0)
print("\n测试完成!")

455
ml/inference.py Normal file
View File

@@ -0,0 +1,455 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
概念异动检测推理服务
在实时场景中使用训练好的 Transformer Autoencoder 进行异动检测
使用方法:
from ml.inference import ConceptAnomalyDetector
detector = ConceptAnomalyDetector('ml/checkpoints')
# 检测异动
features = {...} # 实时特征数据
is_anomaly, score = detector.detect(features)
"""
import os
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from collections import deque
import numpy as np
import torch
from model import TransformerAutoencoder
class ConceptAnomalyDetector:
"""
概念异动检测器
使用训练好的 Transformer Autoencoder 进行实时异动检测
"""
def __init__(
self,
checkpoint_dir: str = 'ml/checkpoints',
device: str = 'auto',
threshold_key: str = 'p95'
):
"""
初始化检测器
Args:
checkpoint_dir: 模型检查点目录
device: 设备 (auto/cuda/cpu)
threshold_key: 使用的阈值键 (p90/p95/p99)
"""
self.checkpoint_dir = Path(checkpoint_dir)
self.threshold_key = threshold_key
# 设备选择
if device == 'auto':
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = torch.device(device)
# 加载配置
self._load_config()
# 加载模型
self._load_model()
# 加载阈值
self._load_thresholds()
# 加载标准化统计量
self._load_normalization_stats()
# 概念历史数据缓存
# {concept_name: deque(maxlen=seq_len)}
self.history_cache: Dict[str, deque] = {}
print(f"ConceptAnomalyDetector 初始化完成")
print(f" 设备: {self.device}")
print(f" 阈值: {self.threshold_key} = {self.threshold:.6f}")
print(f" 序列长度: {self.seq_len}")
def _load_config(self):
"""加载配置"""
config_path = self.checkpoint_dir / 'config.json'
if not config_path.exists():
raise FileNotFoundError(f"配置文件不存在: {config_path}")
with open(config_path, 'r') as f:
self.config = json.load(f)
self.features = self.config['features']
self.seq_len = self.config['seq_len']
self.model_config = self.config['model']
def _load_model(self):
"""加载模型"""
model_path = self.checkpoint_dir / 'best_model.pt'
if not model_path.exists():
raise FileNotFoundError(f"模型文件不存在: {model_path}")
# 创建模型
self.model = TransformerAutoencoder(**self.model_config)
# 加载权重
checkpoint = torch.load(model_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.to(self.device)
self.model.eval()
print(f"模型已加载: {model_path}")
def _load_thresholds(self):
"""加载阈值"""
thresholds_path = self.checkpoint_dir / 'thresholds.json'
if not thresholds_path.exists():
raise FileNotFoundError(f"阈值文件不存在: {thresholds_path}")
with open(thresholds_path, 'r') as f:
self.thresholds = json.load(f)
if self.threshold_key not in self.thresholds:
available_keys = list(self.thresholds.keys())
raise KeyError(f"阈值键 '{self.threshold_key}' 不存在,可用: {available_keys}")
self.threshold = self.thresholds[self.threshold_key]
def _load_normalization_stats(self):
"""加载标准化统计量"""
stats_path = self.checkpoint_dir / 'normalization_stats.json'
if not stats_path.exists():
raise FileNotFoundError(f"标准化统计量文件不存在: {stats_path}")
with open(stats_path, 'r') as f:
stats = json.load(f)
self.norm_mean = np.array(stats['mean'])
self.norm_std = np.array(stats['std'])
def normalize(self, features: np.ndarray) -> np.ndarray:
"""标准化特征"""
return (features - self.norm_mean) / self.norm_std
def update_history(
self,
concept_name: str,
features: Dict[str, float]
):
"""
更新概念历史数据
Args:
concept_name: 概念名称
features: 当前时刻的特征字典
"""
# 初始化历史缓存
if concept_name not in self.history_cache:
self.history_cache[concept_name] = deque(maxlen=self.seq_len)
# 提取特征向量
feature_vector = np.array([
features.get(f, 0.0) for f in self.features
])
# 处理异常值
feature_vector = np.nan_to_num(feature_vector, nan=0.0, posinf=0.0, neginf=0.0)
# 添加到历史
self.history_cache[concept_name].append(feature_vector)
def get_history_length(self, concept_name: str) -> int:
"""获取概念的历史数据长度"""
if concept_name not in self.history_cache:
return 0
return len(self.history_cache[concept_name])
@torch.no_grad()
def detect(
self,
concept_name: str,
features: Dict[str, float] = None,
return_score: bool = True
) -> Tuple[bool, Optional[float]]:
"""
检测概念是否异动
Args:
concept_name: 概念名称
features: 当前时刻的特征(如果提供,会先更新历史)
return_score: 是否返回异动分数
Returns:
is_anomaly: 是否异动
score: 异动分数(如果 return_score=True
"""
# 更新历史
if features is not None:
self.update_history(concept_name, features)
# 检查历史数据是否足够
if concept_name not in self.history_cache:
return False, None
history = self.history_cache[concept_name]
if len(history) < self.seq_len:
return False, None
# 构建输入序列
sequence = np.array(list(history)) # (seq_len, n_features)
# 标准化
sequence = self.normalize(sequence)
# 转为 tensor
x = torch.FloatTensor(sequence).unsqueeze(0) # (1, seq_len, n_features)
x = x.to(self.device)
# 计算重构误差
error = self.model.compute_reconstruction_error(x, reduction='none')
# 取最后一个时刻的误差作为当前分数
score = error[0, -1].item()
# 判断是否异动
is_anomaly = score > self.threshold
if return_score:
return is_anomaly, score
else:
return is_anomaly, None
@torch.no_grad()
def batch_detect(
self,
concept_features: Dict[str, Dict[str, float]]
) -> Dict[str, Tuple[bool, float]]:
"""
批量检测多个概念
Args:
concept_features: {concept_name: {feature_name: value}}
Returns:
results: {concept_name: (is_anomaly, score)}
"""
results = {}
for concept_name, features in concept_features.items():
is_anomaly, score = self.detect(concept_name, features)
results[concept_name] = (is_anomaly, score)
return results
def get_anomaly_type(
self,
concept_name: str,
features: Dict[str, float]
) -> str:
"""
判断异动类型
Args:
concept_name: 概念名称
features: 当前特征
Returns:
anomaly_type: 'surge_up' / 'surge_down' / 'normal'
"""
is_anomaly, score = self.detect(concept_name, features)
if not is_anomaly:
return 'normal'
# 根据 alpha 判断涨跌
alpha = features.get('alpha', 0.0)
if alpha > 0:
return 'surge_up'
else:
return 'surge_down'
def get_top_anomalies(
self,
concept_features: Dict[str, Dict[str, float]],
top_k: int = 10
) -> List[Tuple[str, float, str]]:
"""
获取异动分数最高的 top_k 个概念
Args:
concept_features: {concept_name: {feature_name: value}}
top_k: 返回数量
Returns:
anomalies: [(concept_name, score, anomaly_type), ...]
"""
results = self.batch_detect(concept_features)
# 按分数排序
sorted_results = sorted(
[(name, is_anomaly, score) for name, (is_anomaly, score) in results.items() if score is not None],
key=lambda x: x[2],
reverse=True
)
# 取 top_k
top_anomalies = []
for name, is_anomaly, score in sorted_results[:top_k]:
if is_anomaly:
alpha = concept_features[name].get('alpha', 0.0)
anomaly_type = 'surge_up' if alpha > 0 else 'surge_down'
top_anomalies.append((name, score, anomaly_type))
return top_anomalies
def clear_history(self, concept_name: str = None):
"""
清除历史缓存
Args:
concept_name: 概念名称(如果为 None清除所有
"""
if concept_name is None:
self.history_cache.clear()
elif concept_name in self.history_cache:
del self.history_cache[concept_name]
# ==================== 集成到现有系统 ====================
class MLAnomalyService:
"""
ML 异动检测服务
用于替换或增强现有的 Alpha-based 检测
"""
def __init__(
self,
checkpoint_dir: str = 'ml/checkpoints',
fallback_to_alpha: bool = True
):
"""
Args:
checkpoint_dir: 模型检查点目录
fallback_to_alpha: 当 ML 模型不可用时是否回退到 Alpha 方法
"""
self.fallback_to_alpha = fallback_to_alpha
self.ml_detector = None
try:
self.ml_detector = ConceptAnomalyDetector(checkpoint_dir)
print("ML 异动检测服务初始化成功")
except Exception as e:
print(f"ML 模型加载失败: {e}")
if not fallback_to_alpha:
raise
print("将回退到 Alpha-based 检测")
def is_ml_available(self) -> bool:
"""检查 ML 模型是否可用"""
return self.ml_detector is not None
def detect_anomaly(
self,
concept_name: str,
features: Dict[str, float],
alpha_threshold: float = 2.0
) -> Tuple[bool, float, str]:
"""
检测异动
Args:
concept_name: 概念名称
features: 特征字典(需包含 alpha, amt_ratio 等)
alpha_threshold: Alpha Z-Score 阈值(用于回退)
Returns:
is_anomaly: 是否异动
score: 异动分数
method: 检测方法 ('ml' / 'alpha')
"""
# 优先使用 ML 检测
if self.ml_detector is not None:
history_len = self.ml_detector.get_history_length(concept_name)
# 历史数据足够时使用 ML
if history_len >= self.ml_detector.seq_len - 1:
is_anomaly, score = self.ml_detector.detect(concept_name, features)
if score is not None:
return is_anomaly, score, 'ml'
else:
# 更新历史但使用 Alpha 方法
self.ml_detector.update_history(concept_name, features)
# 回退到 Alpha 方法
if self.fallback_to_alpha:
alpha = features.get('alpha', 0.0)
alpha_zscore = features.get('alpha_zscore', 0.0)
is_anomaly = abs(alpha_zscore) > alpha_threshold
score = abs(alpha_zscore)
return is_anomaly, score, 'alpha'
return False, 0.0, 'none'
# ==================== 测试 ====================
if __name__ == "__main__":
import random
print("测试 ConceptAnomalyDetector...")
# 检查模型是否存在
checkpoint_dir = Path('ml/checkpoints')
if not (checkpoint_dir / 'best_model.pt').exists():
print("模型文件不存在,跳过测试")
print("请先运行 train.py 训练模型")
exit(0)
# 初始化检测器
detector = ConceptAnomalyDetector('ml/checkpoints')
# 模拟数据
print("\n模拟实时检测...")
concept_name = "人工智能"
for i in range(40):
# 生成随机特征
features = {
'alpha': random.gauss(0, 1),
'alpha_delta': random.gauss(0, 0.5),
'amt_ratio': random.gauss(1, 0.3),
'amt_delta': random.gauss(0, 0.2),
'rank_pct': random.random(),
'limit_up_ratio': random.random() * 0.1,
}
# 在第 35 分钟模拟异动
if i == 35:
features['alpha'] = 5.0
features['alpha_delta'] = 2.0
features['amt_ratio'] = 3.0
is_anomaly, score = detector.detect(concept_name, features)
history_len = detector.get_history_length(concept_name)
if score is not None:
status = "🔥 异动!" if is_anomaly else "正常"
print(f" t={i:02d} | 历史={history_len} | 分数={score:.4f} | {status}")
else:
print(f" t={i:02d} | 历史={history_len} | 数据不足")
print("\n测试完成!")

390
ml/model.py Normal file
View File

@@ -0,0 +1,390 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LSTM Autoencoder 模型定义
用于概念异动检测:
- 学习"正常"市场模式
- 重构误差大的时刻 = 异动
模型结构(简洁有效):
┌─────────────────────────────────────┐
│ 输入: (batch, seq_len, n_features) │
│ 过去30分钟的特征序列 │
├─────────────────────────────────────┤
│ LSTM Encoder │
│ - 双向 LSTM │
│ - 输出最后隐藏状态 │
├─────────────────────────────────────┤
│ Bottleneck (压缩层) │
│ 降维到 latent_dim关键
├─────────────────────────────────────┤
│ LSTM Decoder │
│ - 单向 LSTM │
│ - 重构序列 │
├─────────────────────────────────────┤
│ 输出: (batch, seq_len, n_features) │
│ 重构的特征序列 │
└─────────────────────────────────────┘
为什么用 LSTM 而不是 Transformer:
1. 参数更少,不容易过拟合
2. 对于 6 维特征足够用
3. 训练更稳定
4. 瓶颈约束更容易控制
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
class LSTMAutoencoder(nn.Module):
"""
LSTM Autoencoder for Anomaly Detection
设计原则:
- 足够简单,避免过拟合
- 瓶颈层严格限制,迫使模型只学习主要模式
- 异常难以通过狭窄瓶颈,重构误差大
"""
def __init__(
self,
n_features: int = 6,
hidden_dim: int = 32, # LSTM 隐藏维度(小!)
latent_dim: int = 4, # 瓶颈维度(非常小!关键参数)
num_layers: int = 1, # LSTM 层数
dropout: float = 0.2,
bidirectional: bool = True, # 双向编码器
):
super().__init__()
self.n_features = n_features
self.hidden_dim = hidden_dim
self.latent_dim = latent_dim
self.num_layers = num_layers
self.bidirectional = bidirectional
self.num_directions = 2 if bidirectional else 1
# Encoder: 双向 LSTM
self.encoder = nn.LSTM(
input_size=n_features,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0,
bidirectional=bidirectional
)
# Bottleneck: 压缩到极小的 latent space
encoder_output_dim = hidden_dim * self.num_directions
self.bottleneck_down = nn.Sequential(
nn.Linear(encoder_output_dim, latent_dim),
nn.Tanh(), # 限制范围,增加约束
)
self.bottleneck_up = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
)
# Decoder: 单向 LSTM
self.decoder = nn.LSTM(
input_size=hidden_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0,
bidirectional=False # 解码器用单向
)
# 输出层
self.output_layer = nn.Linear(hidden_dim, n_features)
# Dropout
self.dropout = nn.Dropout(dropout)
# 初始化
self._init_weights()
def _init_weights(self):
"""初始化权重"""
for name, param in self.named_parameters():
if 'weight_ih' in name:
nn.init.xavier_uniform_(param)
elif 'weight_hh' in name:
nn.init.orthogonal_(param)
elif 'bias' in name:
nn.init.zeros_(param)
def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
编码器
Args:
x: (batch, seq_len, n_features)
Returns:
latent: (batch, seq_len, latent_dim) 每个时间步的压缩表示
encoder_outputs: (batch, seq_len, hidden_dim * num_directions)
"""
# LSTM 编码
encoder_outputs, (h_n, c_n) = self.encoder(x)
# encoder_outputs: (batch, seq_len, hidden_dim * num_directions)
encoder_outputs = self.dropout(encoder_outputs)
# 压缩到 latent space对每个时间步
latent = self.bottleneck_down(encoder_outputs)
# latent: (batch, seq_len, latent_dim)
return latent, encoder_outputs
def decode(self, latent: torch.Tensor, seq_len: int) -> torch.Tensor:
"""
解码器
Args:
latent: (batch, seq_len, latent_dim)
seq_len: 序列长度
Returns:
output: (batch, seq_len, n_features)
"""
# 从 latent space 恢复
decoder_input = self.bottleneck_up(latent)
# decoder_input: (batch, seq_len, hidden_dim)
# LSTM 解码
decoder_outputs, _ = self.decoder(decoder_input)
# decoder_outputs: (batch, seq_len, hidden_dim)
decoder_outputs = self.dropout(decoder_outputs)
# 投影到原始特征空间
output = self.output_layer(decoder_outputs)
# output: (batch, seq_len, n_features)
return output
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
前向传播
Args:
x: (batch, seq_len, n_features)
Returns:
output: (batch, seq_len, n_features) 重构结果
latent: (batch, seq_len, latent_dim) 隐向量
"""
batch_size, seq_len, _ = x.shape
# 编码
latent, _ = self.encode(x)
# 解码
output = self.decode(latent, seq_len)
return output, latent
def compute_reconstruction_error(
self,
x: torch.Tensor,
reduction: str = 'none'
) -> torch.Tensor:
"""
计算重构误差
Args:
x: (batch, seq_len, n_features)
reduction: 'none' | 'mean' | 'sum'
Returns:
error: 重构误差
"""
output, _ = self.forward(x)
# MSE per feature per timestep
error = F.mse_loss(output, x, reduction='none')
if reduction == 'none':
# (batch, seq_len, n_features) -> (batch, seq_len)
return error.mean(dim=-1)
elif reduction == 'mean':
return error.mean()
elif reduction == 'sum':
return error.sum()
else:
raise ValueError(f"Unknown reduction: {reduction}")
def detect_anomaly(
self,
x: torch.Tensor,
threshold: float = None,
return_scores: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
检测异动
Args:
x: (batch, seq_len, n_features)
threshold: 异动阈值(如果为 None只返回分数
return_scores: 是否返回异动分数
Returns:
is_anomaly: (batch, seq_len) bool tensor (if threshold is not None)
scores: (batch, seq_len) 异动分数 (if return_scores)
"""
scores = self.compute_reconstruction_error(x, reduction='none')
is_anomaly = None
if threshold is not None:
is_anomaly = scores > threshold
if return_scores:
return is_anomaly, scores
else:
return is_anomaly, None
# 为了兼容性,创建别名
TransformerAutoencoder = LSTMAutoencoder
# ==================== 损失函数 ====================
class AnomalyDetectionLoss(nn.Module):
"""
异动检测损失函数
简单的 MSE 重构损失
"""
def __init__(
self,
feature_weights: torch.Tensor = None,
):
super().__init__()
self.feature_weights = feature_weights
def forward(
self,
output: torch.Tensor,
target: torch.Tensor,
latent: torch.Tensor = None
) -> Tuple[torch.Tensor, dict]:
"""
Args:
output: (batch, seq_len, n_features) 重构结果
target: (batch, seq_len, n_features) 原始输入
latent: (batch, seq_len, latent_dim) 隐向量(未使用)
Returns:
loss: 总损失
loss_dict: 各项损失详情
"""
# 重构损失 (MSE)
mse = F.mse_loss(output, target, reduction='none')
# 特征加权(可选)
if self.feature_weights is not None:
weights = self.feature_weights.to(mse.device)
mse = mse * weights
reconstruction_loss = mse.mean()
loss_dict = {
'total': reconstruction_loss.item(),
'reconstruction': reconstruction_loss.item(),
}
return reconstruction_loss, loss_dict
# ==================== 工具函数 ====================
def count_parameters(model: nn.Module) -> int:
"""统计模型参数量"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def create_model(config: dict = None) -> LSTMAutoencoder:
"""
创建模型
默认使用小型 LSTM 配置,适合异动检测
"""
default_config = {
'n_features': 6,
'hidden_dim': 32, # 小!
'latent_dim': 4, # 非常小!关键
'num_layers': 1,
'dropout': 0.2,
'bidirectional': True,
}
if config:
# 兼容旧的 Transformer 配置键名
if 'd_model' in config:
config['hidden_dim'] = config.pop('d_model') // 2
if 'num_encoder_layers' in config:
config['num_layers'] = config.pop('num_encoder_layers')
if 'num_decoder_layers' in config:
config.pop('num_decoder_layers')
if 'nhead' in config:
config.pop('nhead')
if 'dim_feedforward' in config:
config.pop('dim_feedforward')
if 'max_seq_len' in config:
config.pop('max_seq_len')
if 'use_instance_norm' in config:
config.pop('use_instance_norm')
default_config.update(config)
model = LSTMAutoencoder(**default_config)
param_count = count_parameters(model)
print(f"模型参数量: {param_count:,}")
if param_count > 100000:
print(f"⚠️ 警告: 参数量较大({param_count:,}),可能过拟合")
else:
print(f"✓ 参数量适中LSTM Autoencoder")
return model
if __name__ == "__main__":
# 测试模型
print("测试 LSTM Autoencoder...")
# 创建模型
model = create_model()
# 测试输入
batch_size = 32
seq_len = 30
n_features = 6
x = torch.randn(batch_size, seq_len, n_features)
# 前向传播
output, latent = model(x)
print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
print(f"隐向量形状: {latent.shape}")
# 计算重构误差
error = model.compute_reconstruction_error(x)
print(f"重构误差形状: {error.shape}")
print(f"平均重构误差: {error.mean().item():.4f}")
# 测试异动检测
is_anomaly, scores = model.detect_anomaly(x, threshold=0.5)
print(f"异动检测结果形状: {is_anomaly.shape if is_anomaly is not None else 'None'}")
print(f"异动分数形状: {scores.shape}")
# 测试损失函数
criterion = AnomalyDetectionLoss()
loss, loss_dict = criterion(output, x, latent)
print(f"损失: {loss.item():.4f}")
print("\n测试通过!")

501
ml/prepare_data.py Normal file
View File

@@ -0,0 +1,501 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据准备脚本 - 为 Transformer Autoencoder 准备训练数据
从 ClickHouse 提取历史分钟数据,计算以下特征:
1. alpha - 超额收益(概念涨幅 - 大盘涨幅)
2. alpha_delta - Alpha 变化率5分钟
3. amt_ratio - 成交额相对均值(当前/过去20分钟均值
4. amt_delta - 成交额变化率
5. rank_pct - Alpha 排名百分位
6. limit_up_ratio - 涨停股占比
输出按交易日存储的特征文件parquet格式
"""
import os
import sys
import numpy as np
import pandas as pd
from datetime import datetime, timedelta, date
from sqlalchemy import create_engine, text
from elasticsearch import Elasticsearch
from clickhouse_driver import Client
import hashlib
import json
import logging
from typing import Dict, List, Set, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import warnings
warnings.filterwarnings('ignore')
# ==================== 配置 ====================
MYSQL_ENGINE = create_engine(
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
echo=False
)
ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
ES_INDEX = 'concept_library_v3'
CLICKHOUSE_CONFIG = {
'host': '127.0.0.1',
'port': 9000,
'user': 'default',
'password': 'Zzl33818!',
'database': 'stock'
}
# 输出目录
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data')
os.makedirs(OUTPUT_DIR, exist_ok=True)
# 特征计算参数
FEATURE_CONFIG = {
'alpha_delta_window': 5, # Alpha变化窗口分钟
'amt_ma_window': 20, # 成交额均值窗口(分钟)
'limit_up_threshold': 9.8, # 涨停阈值(%
'limit_down_threshold': -9.8, # 跌停阈值(%
}
REFERENCE_INDEX = '000001.SH'
# ==================== 日志 ====================
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# ==================== 工具函数 ====================
def get_ch_client():
return Client(**CLICKHOUSE_CONFIG)
def generate_id(name: str) -> str:
return hashlib.md5(name.encode('utf-8')).hexdigest()[:16]
def code_to_ch_format(code: str) -> str:
if not code or len(code) != 6 or not code.isdigit():
return None
if code.startswith('6'):
return f"{code}.SH"
elif code.startswith('0') or code.startswith('3'):
return f"{code}.SZ"
else:
return f"{code}.BJ"
# ==================== 获取概念列表 ====================
def get_all_concepts() -> List[dict]:
"""从ES获取所有叶子概念"""
concepts = []
query = {
"query": {"match_all": {}},
"size": 100,
"_source": ["concept_id", "concept", "stocks"]
}
resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
scroll_id = resp['_scroll_id']
hits = resp['hits']['hits']
while len(hits) > 0:
for hit in hits:
source = hit['_source']
stocks = []
if 'stocks' in source and isinstance(source['stocks'], list):
for stock in source['stocks']:
if isinstance(stock, dict) and 'code' in stock and stock['code']:
stocks.append(stock['code'])
if stocks:
concepts.append({
'concept_id': source.get('concept_id'),
'concept_name': source.get('concept'),
'stocks': stocks
})
resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
scroll_id = resp['_scroll_id']
hits = resp['hits']['hits']
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
logger.info(f"获取到 {len(concepts)} 个概念")
return concepts
# ==================== 获取交易日列表 ====================
def get_trading_days(start_date: str, end_date: str) -> List[str]:
"""获取交易日列表"""
client = get_ch_client()
query = f"""
SELECT DISTINCT toDate(timestamp) as trade_date
FROM stock_minute
WHERE toDate(timestamp) >= '{start_date}'
AND toDate(timestamp) <= '{end_date}'
ORDER BY trade_date
"""
result = client.execute(query)
days = [row[0].strftime('%Y-%m-%d') for row in result]
logger.info(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
return days
# ==================== 获取单日数据 ====================
def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
"""获取单日所有股票的分钟数据"""
client = get_ch_client()
# 转换代码格式
ch_codes = []
code_map = {}
for code in stock_codes:
ch_code = code_to_ch_format(code)
if ch_code:
ch_codes.append(ch_code)
code_map[ch_code] = code
if not ch_codes:
return pd.DataFrame()
ch_codes_str = "','".join(ch_codes)
query = f"""
SELECT
code,
timestamp,
close,
volume,
amt
FROM stock_minute
WHERE toDate(timestamp) = '{trade_date}'
AND code IN ('{ch_codes_str}')
ORDER BY code, timestamp
"""
result = client.execute(query)
if not result:
return pd.DataFrame()
df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
df['code'] = df['ch_code'].map(code_map)
df = df.dropna(subset=['code'])
return df[['code', 'timestamp', 'close', 'volume', 'amt']]
def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
"""获取单日指数分钟数据"""
client = get_ch_client()
query = f"""
SELECT
timestamp,
close,
volume,
amt
FROM index_minute
WHERE toDate(timestamp) = '{trade_date}'
AND code = '{index_code}'
ORDER BY timestamp
"""
result = client.execute(query)
if not result:
return pd.DataFrame()
df = pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt'])
return df
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
"""获取昨收价"""
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
if not valid_codes:
return {}
codes_str = "','".join(valid_codes)
query = f"""
SELECT SECCODE, F002N
FROM ea_trade
WHERE SECCODE IN ('{codes_str}')
AND TRADEDATE = (
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
)
AND F002N IS NOT NULL AND F002N > 0
"""
try:
with MYSQL_ENGINE.connect() as conn:
result = conn.execute(text(query))
return {row[0]: float(row[1]) for row in result if row[1]}
except Exception as e:
logger.error(f"获取昨收价失败: {e}")
return {}
def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float:
"""获取指数昨收价"""
code_no_suffix = index_code.split('.')[0]
try:
with MYSQL_ENGINE.connect() as conn:
result = conn.execute(text("""
SELECT F006N FROM ea_exchangetrade
WHERE INDEXCODE = :code AND TRADEDATE < :today
ORDER BY TRADEDATE DESC LIMIT 1
"""), {'code': code_no_suffix, 'today': trade_date}).fetchone()
if result and result[0]:
return float(result[0])
except Exception as e:
logger.error(f"获取指数昨收失败: {e}")
return None
# ==================== 计算特征 ====================
def compute_daily_features(
trade_date: str,
concepts: List[dict],
all_stocks: List[str]
) -> pd.DataFrame:
"""
计算单日所有概念的特征
返回 DataFrame:
- index: (timestamp, concept_id)
- columns: alpha, alpha_delta, amt_ratio, amt_delta, rank_pct, limit_up_ratio
"""
# 1. 获取数据
logger.info(f" 获取股票数据...")
stock_df = get_daily_stock_data(trade_date, all_stocks)
if stock_df.empty:
logger.warning(f" 无股票数据")
return pd.DataFrame()
logger.info(f" 获取指数数据...")
index_df = get_daily_index_data(trade_date)
if index_df.empty:
logger.warning(f" 无指数数据")
return pd.DataFrame()
# 2. 获取昨收价
logger.info(f" 获取昨收价...")
prev_close = get_prev_close(all_stocks, trade_date)
index_prev_close = get_index_prev_close(trade_date)
if not prev_close or not index_prev_close:
logger.warning(f" 无昨收价数据")
return pd.DataFrame()
# 3. 计算股票涨跌幅和成交额
stock_df['prev_close'] = stock_df['code'].map(prev_close)
stock_df = stock_df.dropna(subset=['prev_close'])
stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
# 4. 计算指数涨跌幅
index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))
# 5. 获取所有时间点
timestamps = sorted(stock_df['timestamp'].unique())
logger.info(f" 时间点数: {len(timestamps)}")
# 6. 按时间点计算概念特征
results = []
# 概念到股票的映射
concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
concept_names = {c['concept_id']: c['concept_name'] for c in concepts}
# 历史数据缓存(用于计算变化率)
concept_history = {cid: {'alpha': [], 'amt': []} for cid in concept_stocks}
for ts in timestamps:
ts_stock_data = stock_df[stock_df['timestamp'] == ts]
index_change = index_change_map.get(ts, 0)
# 股票涨跌幅和成交额字典
stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))
concept_features = []
for concept_id, stocks in concept_stocks.items():
# 该概念的股票数据
concept_changes = [stock_change[s] for s in stocks if s in stock_change]
concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]
if not concept_changes:
continue
# 基础统计
avg_change = np.mean(concept_changes)
total_amt = sum(concept_amts)
# Alpha = 概念涨幅 - 指数涨幅
alpha = avg_change - index_change
# 涨停/跌停股占比
limit_up_count = sum(1 for c in concept_changes if c >= FEATURE_CONFIG['limit_up_threshold'])
limit_down_count = sum(1 for c in concept_changes if c <= FEATURE_CONFIG['limit_down_threshold'])
limit_up_ratio = limit_up_count / len(concept_changes)
limit_down_ratio = limit_down_count / len(concept_changes)
# 更新历史
history = concept_history[concept_id]
history['alpha'].append(alpha)
history['amt'].append(total_amt)
# 计算变化率
alpha_delta = 0
if len(history['alpha']) > FEATURE_CONFIG['alpha_delta_window']:
alpha_delta = alpha - history['alpha'][-FEATURE_CONFIG['alpha_delta_window']-1]
# 成交额相对均值
amt_ratio = 1.0
amt_delta = 0
if len(history['amt']) > FEATURE_CONFIG['amt_ma_window']:
amt_ma = np.mean(history['amt'][-FEATURE_CONFIG['amt_ma_window']-1:-1])
if amt_ma > 0:
amt_ratio = total_amt / amt_ma
amt_delta = total_amt - history['amt'][-2] if len(history['amt']) > 1 else 0
concept_features.append({
'concept_id': concept_id,
'alpha': alpha,
'alpha_delta': alpha_delta,
'amt_ratio': amt_ratio,
'amt_delta': amt_delta,
'limit_up_ratio': limit_up_ratio,
'limit_down_ratio': limit_down_ratio,
'total_amt': total_amt,
'stock_count': len(concept_changes),
})
if not concept_features:
continue
# 计算排名百分位
concept_df = pd.DataFrame(concept_features)
concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)
# 添加时间戳
concept_df['timestamp'] = ts
results.append(concept_df)
if not results:
return pd.DataFrame()
# 合并所有时间点
final_df = pd.concat(results, ignore_index=True)
# 标准化成交额变化率
if 'amt_delta' in final_df.columns:
amt_delta_std = final_df['amt_delta'].std()
if amt_delta_std > 0:
final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std
logger.info(f" 计算完成: {len(final_df)} 条记录")
return final_df
# ==================== 主流程 ====================
def process_single_day(trade_date: str, concepts: List[dict], all_stocks: List[str]) -> str:
"""处理单个交易日"""
output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
# 检查是否已处理
if os.path.exists(output_file):
logger.info(f"[{trade_date}] 已存在,跳过")
return output_file
logger.info(f"[{trade_date}] 开始处理...")
try:
df = compute_daily_features(trade_date, concepts, all_stocks)
if df.empty:
logger.warning(f"[{trade_date}] 无数据")
return None
# 保存
df.to_parquet(output_file, index=False)
logger.info(f"[{trade_date}] 保存完成: {output_file}")
return output_file
except Exception as e:
logger.error(f"[{trade_date}] 处理失败: {e}")
import traceback
traceback.print_exc()
return None
def main():
import argparse
parser = argparse.ArgumentParser(description='准备训练数据')
parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
parser.add_argument('--workers', type=int, default=1, help='并行数建议1避免数据库压力')
args = parser.parse_args()
end_date = args.end or datetime.now().strftime('%Y-%m-%d')
logger.info("=" * 60)
logger.info("数据准备 - Transformer Autoencoder 训练数据")
logger.info("=" * 60)
logger.info(f"日期范围: {args.start} ~ {end_date}")
# 1. 获取概念列表
concepts = get_all_concepts()
# 收集所有股票
all_stocks = list(set(s for c in concepts for s in c['stocks']))
logger.info(f"股票总数: {len(all_stocks)}")
# 2. 获取交易日列表
trading_days = get_trading_days(args.start, end_date)
if not trading_days:
logger.error("无交易日数据")
return
# 3. 处理每个交易日
logger.info(f"\n开始处理 {len(trading_days)} 个交易日...")
success_count = 0
for i, trade_date in enumerate(trading_days):
logger.info(f"\n[{i+1}/{len(trading_days)}] {trade_date}")
result = process_single_day(trade_date, concepts, all_stocks)
if result:
success_count += 1
logger.info("\n" + "=" * 60)
logger.info(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
logger.info(f"数据保存在: {OUTPUT_DIR}")
logger.info("=" * 60)
if __name__ == "__main__":
main()

25
ml/requirements.txt Normal file
View File

@@ -0,0 +1,25 @@
# 概念异动检测 ML 模块依赖
# 安装: pip install -r ml/requirements.txt
# PyTorch (根据 CUDA 版本选择)
# 5090 显卡需要 CUDA 12.x
# pip install torch --index-url https://download.pytorch.org/whl/cu124
torch>=2.0.0
# 数据处理
numpy>=1.24.0
pandas>=2.0.0
pyarrow>=14.0.0
# 数据库
clickhouse-driver>=0.2.6
elasticsearch>=7.0.0,<8.0.0
sqlalchemy>=2.0.0
pymysql>=1.1.0
# 训练工具
tqdm>=4.65.0
# 可选: 可视化
# matplotlib>=3.7.0
# tensorboard>=2.14.0

99
ml/run_training.sh Normal file
View File

@@ -0,0 +1,99 @@
#!/bin/bash
# 概念异动检测模型训练脚本 (Linux)
#
# 使用方法:
# chmod +x run_training.sh
# ./run_training.sh
#
# 或指定参数:
# ./run_training.sh --start 2022-01-01 --epochs 100
set -e
echo "============================================================"
echo "概念异动检测模型训练流程"
echo "============================================================"
echo ""
# 获取脚本所在目录
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR/.."
echo "[1/4] 检查环境..."
python3 --version || { echo "Python3 未找到!"; exit 1; }
# 检查 GPU
if python3 -c "import torch; print(f'CUDA: {torch.cuda.is_available()}')" 2>/dev/null; then
echo "PyTorch GPU 检测完成"
else
echo "警告: PyTorch 未安装或无法检测 GPU"
fi
echo ""
echo "[2/4] 检查依赖..."
pip3 install -q torch pandas numpy pyarrow tqdm clickhouse-driver elasticsearch sqlalchemy pymysql
echo ""
echo "[3/4] 准备训练数据..."
echo "从 ClickHouse 提取历史数据,这可能需要较长时间..."
echo ""
# 解析参数
START_DATE="2022-01-01"
END_DATE=""
EPOCHS=100
BATCH_SIZE=256
TRAIN_END="2025-06-30"
VAL_END="2025-09-30"
while [[ $# -gt 0 ]]; do
case $1 in
--start)
START_DATE="$2"
shift 2
;;
--end)
END_DATE="$2"
shift 2
;;
--epochs)
EPOCHS="$2"
shift 2
;;
--batch_size)
BATCH_SIZE="$2"
shift 2
;;
--train_end)
TRAIN_END="$2"
shift 2
;;
--val_end)
VAL_END="$2"
shift 2
;;
*)
shift
;;
esac
done
# 数据准备
if [ -n "$END_DATE" ]; then
python3 ml/prepare_data.py --start "$START_DATE" --end "$END_DATE"
else
python3 ml/prepare_data.py --start "$START_DATE"
fi
echo ""
echo "[4/4] 训练模型..."
echo "使用 GPU 加速训练..."
echo ""
python3 ml/train.py --epochs "$EPOCHS" --batch_size "$BATCH_SIZE" --train_end "$TRAIN_END" --val_end "$VAL_END"
echo ""
echo "============================================================"
echo "训练完成!"
echo "模型保存在: ml/checkpoints/"
echo "============================================================"

808
ml/train.py Normal file
View File

@@ -0,0 +1,808 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Transformer Autoencoder 训练脚本 (修复版)
修复问题:
1. 按概念分组构建序列,避免跨概念切片
2. 按时间(日期)切分数据集,避免数据泄露
3. 使用 RobustScaler + Clipping 处理非平稳性
4. 使用验证集计算阈值
训练流程:
1. 加载预处理好的特征数据parquet 文件)
2. 按概念分组,在每个概念内部构建序列
3. 按日期划分训练/验证/测试集
4. 训练 Autoencoder最小化重构误差
5. 保存模型和阈值
使用方法:
python train.py --data_dir ml/data --epochs 100 --batch_size 256
"""
import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import List, Tuple, Dict
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm
from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters
# 性能优化:启用 cuDNN benchmark对固定输入尺寸自动选择最快算法
torch.backends.cudnn.benchmark = True
# 启用 TF32RTX 30/40 系列特有,提速约 3 倍)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# 可视化(可选)
try:
import matplotlib
matplotlib.use('Agg') # 无头模式,不需要显示器
import matplotlib.pyplot as plt
HAS_MATPLOTLIB = True
except ImportError:
HAS_MATPLOTLIB = False
# ==================== 配置 ====================
TRAIN_CONFIG = {
# 数据配置
'seq_len': 30, # 输入序列长度30分钟
'stride': 5, # 滑动窗口步长
# 时间切分(按日期)
'train_end_date': '2024-06-30', # 训练集截止日期
'val_end_date': '2024-09-30', # 验证集截止日期(之后为测试集)
# 特征配置
'features': [
'alpha', # 超额收益
'alpha_delta', # Alpha 变化率
'amt_ratio', # 成交额比率
'amt_delta', # 成交额变化率
'rank_pct', # Alpha 排名百分位
'limit_up_ratio', # 涨停比例
],
# 训练配置(针对 4x RTX 4090 优化)
'batch_size': 4096, # 256 -> 4096大幅增加充分利用显存
'epochs': 100,
'learning_rate': 3e-4, # 1e-4 -> 3e-4大 batch 需要更大学习率)
'weight_decay': 1e-5,
'gradient_clip': 1.0,
# 早停配置
'patience': 10,
'min_delta': 1e-6,
# 模型配置LSTM Autoencoder简洁有效
'model': {
'n_features': 6,
'hidden_dim': 32, # LSTM 隐藏维度(小)
'latent_dim': 4, # 瓶颈维度(非常小!关键)
'num_layers': 1, # LSTM 层数
'dropout': 0.2,
'bidirectional': True, # 双向编码器
},
# 标准化配置
'use_instance_norm': True, # 模型内部使用 Instance Norm推荐
'clip_value': 10.0, # 简单截断极端值
# 阈值配置
'threshold_percentiles': [90, 95, 99],
}
# ==================== 数据加载(修复版)====================
def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
"""
按日期加载数据,返回 {date: DataFrame} 字典
每个 DataFrame 包含该日所有概念的所有时间点数据
"""
data_path = Path(data_dir)
parquet_files = sorted(data_path.glob("features_*.parquet"))
if not parquet_files:
raise FileNotFoundError(f"未找到 parquet 文件: {data_dir}")
print(f"找到 {len(parquet_files)} 个数据文件")
date_data = {}
for pf in tqdm(parquet_files, desc="加载数据"):
# 提取日期
date = pf.stem.replace('features_', '')
df = pd.read_parquet(pf)
# 检查必要列
required_cols = features + ['concept_id', 'timestamp']
missing_cols = [c for c in required_cols if c not in df.columns]
if missing_cols:
print(f"警告: {date} 缺少列: {missing_cols}, 跳过")
continue
date_data[date] = df
print(f"成功加载 {len(date_data)} 天的数据")
return date_data
def split_data_by_date(
date_data: Dict[str, pd.DataFrame],
train_end: str,
val_end: str
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
"""
按日期严格划分数据集
- 训练集: <= train_end
- 验证集: train_end < date <= val_end
- 测试集: > val_end
"""
train_data = {}
val_data = {}
test_data = {}
for date, df in date_data.items():
if date <= train_end:
train_data[date] = df
elif date <= val_end:
val_data[date] = df
else:
test_data[date] = df
print(f"数据集划分(按日期):")
print(f" 训练集: {len(train_data)} 天 (<= {train_end})")
print(f" 验证集: {len(val_data)} 天 ({train_end} ~ {val_end})")
print(f" 测试集: {len(test_data)} 天 (> {val_end})")
return train_data, val_data, test_data
def build_sequences_by_concept(
date_data: Dict[str, pd.DataFrame],
features: List[str],
seq_len: int,
stride: int
) -> np.ndarray:
"""
按概念分组构建序列(性能优化版)
使用 groupby 一次性分组,避免重复扫描大数组
1. 将所有日期的数据合并
2. 使用 groupby 按 concept_id 分组
3. 在每个概念内部,按时间排序并滑动窗口
4. 合并所有序列
"""
# 合并所有日期的数据
all_dfs = []
for date, df in sorted(date_data.items()):
df = df.copy()
df['date'] = date
all_dfs.append(df)
if not all_dfs:
return np.array([])
combined = pd.concat(all_dfs, ignore_index=True)
# 预先排序(按概念、日期、时间),这样 groupby 会更快
combined = combined.sort_values(['concept_id', 'date', 'timestamp'])
# 使用 groupby 一次性分组(性能关键!)
all_sequences = []
grouped = combined.groupby('concept_id', sort=False)
n_concepts = len(grouped)
for concept_id, concept_df in tqdm(grouped, desc="构建序列", total=n_concepts, leave=False):
# 已经排序过了,直接提取特征
feature_data = concept_df[features].values
# 处理缺失值
feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0)
# 在该概念内部滑动窗口
n_points = len(feature_data)
for start in range(0, n_points - seq_len + 1, stride):
seq = feature_data[start:start + seq_len]
all_sequences.append(seq)
if not all_sequences:
return np.array([])
sequences = np.array(all_sequences)
print(f" 构建序列: {len(sequences):,} 条 (来自 {n_concepts} 个概念)")
return sequences
# ==================== 数据集 ====================
class SequenceDataset(Dataset):
"""序列数据集(已经构建好的序列)"""
def __init__(self, sequences: np.ndarray):
self.sequences = torch.FloatTensor(sequences)
def __len__(self) -> int:
return len(self.sequences)
def __getitem__(self, idx: int) -> torch.Tensor:
return self.sequences[idx]
# ==================== 训练器 ====================
class EarlyStopping:
"""早停机制"""
def __init__(self, patience: int = 10, min_delta: float = 1e-6):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_loss = float('inf')
self.early_stop = False
def __call__(self, val_loss: float) -> bool:
if val_loss < self.best_loss - self.min_delta:
self.best_loss = val_loss
self.counter = 0
else:
self.counter += 1
if self.counter >= self.patience:
self.early_stop = True
return self.early_stop
class Trainer:
"""模型训练器(支持 AMP 混合精度加速)"""
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
config: Dict,
device: torch.device,
save_dir: str = 'ml/checkpoints'
):
self.model = model.to(device)
self.train_loader = train_loader
self.val_loader = val_loader
self.config = config
self.device = device
self.save_dir = Path(save_dir)
self.save_dir.mkdir(parents=True, exist_ok=True)
# 优化器
self.optimizer = AdamW(
model.parameters(),
lr=config['learning_rate'],
weight_decay=config['weight_decay']
)
# 学习率调度器
self.scheduler = CosineAnnealingWarmRestarts(
self.optimizer,
T_0=10,
T_mult=2,
eta_min=1e-6
)
# 损失函数(简化版,只用 MSE
self.criterion = AnomalyDetectionLoss()
# 早停
self.early_stopping = EarlyStopping(
patience=config['patience'],
min_delta=config['min_delta']
)
# AMP 混合精度训练(大幅提速 + 省显存)
self.use_amp = torch.cuda.is_available()
self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
if self.use_amp:
print(" ✓ 启用 AMP 混合精度训练")
# 训练历史
self.history = {
'train_loss': [],
'val_loss': [],
'learning_rate': [],
}
self.best_val_loss = float('inf')
def train_epoch(self) -> float:
"""训练一个 epoch使用 AMP 混合精度)"""
self.model.train()
total_loss = 0.0
n_batches = 0
pbar = tqdm(self.train_loader, desc="Training", leave=False)
for batch in pbar:
batch = batch.to(self.device, non_blocking=True) # 异步传输
self.optimizer.zero_grad(set_to_none=True) # 更快的梯度清零
# AMP 混合精度前向传播
if self.use_amp:
with torch.cuda.amp.autocast():
output, latent = self.model(batch)
loss, loss_dict = self.criterion(output, batch, latent)
# AMP 反向传播
self.scaler.scale(loss).backward()
# 梯度裁剪(需要 unscale
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config['gradient_clip']
)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# 非 AMP 模式
output, latent = self.model(batch)
loss, loss_dict = self.criterion(output, batch, latent)
loss.backward()
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config['gradient_clip']
)
self.optimizer.step()
total_loss += loss.item()
n_batches += 1
pbar.set_postfix({'loss': f"{loss.item():.4f}"})
return total_loss / n_batches
@torch.no_grad()
def validate(self) -> float:
"""验证(使用 AMP"""
self.model.eval()
total_loss = 0.0
n_batches = 0
for batch in self.val_loader:
batch = batch.to(self.device, non_blocking=True)
if self.use_amp:
with torch.cuda.amp.autocast():
output, latent = self.model(batch)
loss, _ = self.criterion(output, batch, latent)
else:
output, latent = self.model(batch)
loss, _ = self.criterion(output, batch, latent)
total_loss += loss.item()
n_batches += 1
return total_loss / n_batches
def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
"""保存检查点"""
# 处理 DataParallel 包装
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
checkpoint = {
'epoch': epoch,
'model_state_dict': model_to_save.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'val_loss': val_loss,
'config': self.config,
}
# 保存最新检查点
torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')
# 保存最佳模型
if is_best:
torch.save(checkpoint, self.save_dir / 'best_model.pt')
print(f" ✓ 保存最佳模型 (val_loss: {val_loss:.6f})")
def train(self, epochs: int):
"""完整训练流程"""
print(f"\n开始训练 ({epochs} epochs)...")
print(f"设备: {self.device}")
print(f"模型参数量: {count_parameters(self.model):,}")
for epoch in range(1, epochs + 1):
print(f"\nEpoch {epoch}/{epochs}")
# 训练
train_loss = self.train_epoch()
# 验证
val_loss = self.validate()
# 更新学习率
self.scheduler.step()
current_lr = self.optimizer.param_groups[0]['lr']
# 记录历史
self.history['train_loss'].append(train_loss)
self.history['val_loss'].append(val_loss)
self.history['learning_rate'].append(current_lr)
# 打印进度
print(f" Train Loss: {train_loss:.6f}")
print(f" Val Loss: {val_loss:.6f}")
print(f" LR: {current_lr:.2e}")
# 保存检查点
is_best = val_loss < self.best_val_loss
if is_best:
self.best_val_loss = val_loss
self.save_checkpoint(epoch, val_loss, is_best)
# 早停检查
if self.early_stopping(val_loss):
print(f"\n早停触发!验证损失已 {self.early_stopping.patience} 个 epoch 未改善")
break
print(f"\n训练完成!最佳验证损失: {self.best_val_loss:.6f}")
# 保存训练历史
self.save_history()
return self.history
def save_history(self):
"""保存训练历史"""
history_path = self.save_dir / 'training_history.json'
with open(history_path, 'w') as f:
json.dump(self.history, f, indent=2)
print(f"训练历史已保存: {history_path}")
# 绘制训练曲线
self.plot_training_curves()
def plot_training_curves(self):
"""绘制训练曲线"""
if not HAS_MATPLOTLIB:
print("matplotlib 未安装,跳过绘图")
return
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
epochs = range(1, len(self.history['train_loss']) + 1)
# 1. Loss 曲线
ax1 = axes[0]
ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Training & Validation Loss', fontsize=14)
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)
# 标记最佳点
best_epoch = np.argmin(self.history['val_loss']) + 1
best_val_loss = min(self.history['val_loss'])
ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}')
ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)
ax1.annotate(f'Best: {best_val_loss:.6f}', xy=(best_epoch, best_val_loss),
xytext=(best_epoch + 2, best_val_loss + 0.0005),
fontsize=10, color='green')
# 2. 学习率曲线
ax2 = axes[1]
ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Learning Rate', fontsize=12)
ax2.set_title('Learning Rate Schedule', fontsize=14)
ax2.set_yscale('log')
ax2.grid(True, alpha=0.3)
plt.tight_layout()
# 保存图片
plot_path = self.save_dir / 'training_curves.png'
plt.savefig(plot_path, dpi=150, bbox_inches='tight')
plt.close()
print(f"训练曲线已保存: {plot_path}")
# ==================== 阈值计算(使用验证集)====================
@torch.no_grad()
def compute_thresholds(
model: nn.Module,
data_loader: DataLoader,
device: torch.device,
percentiles: List[float] = [90, 95, 99]
) -> Dict[str, float]:
"""
在验证集上计算重构误差的百分位数阈值
注:使用验证集而非测试集,避免数据泄露
"""
model.eval()
all_errors = []
print("计算异动阈值(使用验证集)...")
for batch in tqdm(data_loader, desc="Computing thresholds"):
batch = batch.to(device)
errors = model.compute_reconstruction_error(batch, reduction='none')
# 取每个序列的最后一个时刻误差(预测当前时刻)
seq_errors = errors[:, -1] # (batch,)
all_errors.append(seq_errors.cpu().numpy())
all_errors = np.concatenate(all_errors)
thresholds = {}
for p in percentiles:
threshold = np.percentile(all_errors, p)
thresholds[f'p{p}'] = float(threshold)
print(f" P{p}: {threshold:.6f}")
# 额外统计
thresholds['mean'] = float(np.mean(all_errors))
thresholds['std'] = float(np.std(all_errors))
thresholds['median'] = float(np.median(all_errors))
print(f" Mean: {thresholds['mean']:.6f}")
print(f" Median: {thresholds['median']:.6f}")
print(f" Std: {thresholds['std']:.6f}")
return thresholds
# ==================== 主函数 ====================
def main():
parser = argparse.ArgumentParser(description='训练概念异动检测模型')
parser.add_argument('--data_dir', type=str, default='ml/data',
help='数据目录路径')
parser.add_argument('--epochs', type=int, default=100,
help='训练轮数')
parser.add_argument('--batch_size', type=int, default=4096,
help='批次大小4x RTX 4090 推荐 4096~8192')
parser.add_argument('--lr', type=float, default=3e-4,
help='学习率(大 batch 推荐 3e-4')
parser.add_argument('--device', type=str, default='auto',
help='设备 (auto/cuda/cpu)')
parser.add_argument('--save_dir', type=str, default='ml/checkpoints',
help='模型保存目录')
parser.add_argument('--train_end', type=str, default='2024-06-30',
help='训练集截止日期')
parser.add_argument('--val_end', type=str, default='2024-09-30',
help='验证集截止日期')
args = parser.parse_args()
# 更新配置
config = TRAIN_CONFIG.copy()
config['batch_size'] = args.batch_size
config['epochs'] = args.epochs
config['learning_rate'] = args.lr
config['train_end_date'] = args.train_end
config['val_end_date'] = args.val_end
# 设备选择
if args.device == 'auto':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
device = torch.device(args.device)
print("=" * 60)
print("概念异动检测模型训练(修复版)")
print("=" * 60)
print(f"配置:")
print(f" 数据目录: {args.data_dir}")
print(f" 设备: {device}")
print(f" 批次大小: {config['batch_size']}")
print(f" 学习率: {config['learning_rate']}")
print(f" 训练轮数: {config['epochs']}")
print(f" 训练集截止: {config['train_end_date']}")
print(f" 验证集截止: {config['val_end_date']}")
print("=" * 60)
# 1. 按日期加载数据
print("\n[1/6] 加载数据...")
date_data = load_data_by_date(args.data_dir, config['features'])
# 2. 按日期划分
print("\n[2/6] 按日期划分数据集...")
train_data, val_data, test_data = split_data_by_date(
date_data,
config['train_end_date'],
config['val_end_date']
)
# 3. 按概念构建序列
print("\n[3/6] 按概念构建序列...")
print("训练集:")
train_sequences = build_sequences_by_concept(
train_data, config['features'], config['seq_len'], config['stride']
)
print("验证集:")
val_sequences = build_sequences_by_concept(
val_data, config['features'], config['seq_len'], config['stride']
)
print("测试集:")
test_sequences = build_sequences_by_concept(
test_data, config['features'], config['seq_len'], config['stride']
)
if len(train_sequences) == 0:
print("错误: 训练集为空!请检查数据和日期范围")
return
# 4. 数据预处理(简单截断极端值,标准化在模型内部通过 Instance Norm 完成)
print("\n[4/6] 数据预处理...")
print(" 注意: 使用 Instance Norm每个序列在模型内部单独标准化")
print(" 这样可以处理不同概念波动率差异(银行 vs 半导体)")
clip_value = config['clip_value']
print(f" 截断极端值: ±{clip_value}")
# 简单截断极端值(防止异常数据影响训练)
train_sequences = np.clip(train_sequences, -clip_value, clip_value)
if len(val_sequences) > 0:
val_sequences = np.clip(val_sequences, -clip_value, clip_value)
if len(test_sequences) > 0:
test_sequences = np.clip(test_sequences, -clip_value, clip_value)
# 保存配置
save_dir = Path(args.save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
preprocess_params = {
'features': config['features'],
'normalization': 'instance_norm', # 在模型内部完成
'clip_value': clip_value,
'note': '标准化在模型内部通过 InstanceNorm1d 完成,无需外部 Scaler'
}
with open(save_dir / 'normalization_stats.json', 'w') as f:
json.dump(preprocess_params, f, indent=2)
print(f" 预处理参数已保存")
# 5. 创建数据集和加载器
print("\n[5/6] 创建数据加载器...")
train_dataset = SequenceDataset(train_sequences)
val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None
test_dataset = SequenceDataset(test_sequences) if len(test_sequences) > 0 else None
print(f" 训练序列: {len(train_dataset):,}")
print(f" 验证序列: {len(val_dataset) if val_dataset else 0:,}")
print(f" 测试序列: {len(test_dataset) if test_dataset else 0:,}")
# 多卡时增加 num_workersLinux 上可以用更多)
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0
print(f" DataLoader workers: {num_workers}")
print(f" Batch size: {config['batch_size']}")
# 大 batch + 多 worker + prefetch 提速
train_loader = DataLoader(
train_dataset,
batch_size=config['batch_size'],
shuffle=True,
num_workers=num_workers,
pin_memory=True,
prefetch_factor=4 if num_workers > 0 else None, # 预取更多 batch
persistent_workers=True if num_workers > 0 else False, # 保持 worker 存活
drop_last=True # 丢弃不完整的最后一批,避免 batch 大小不一致
)
val_loader = DataLoader(
val_dataset,
batch_size=config['batch_size'] * 2, # 验证时可以用更大 batch无梯度
shuffle=False,
num_workers=num_workers,
pin_memory=True,
prefetch_factor=4 if num_workers > 0 else None,
persistent_workers=True if num_workers > 0 else False,
) if val_dataset else None
test_loader = DataLoader(
test_dataset,
batch_size=config['batch_size'] * 2,
shuffle=False,
num_workers=num_workers,
pin_memory=True,
prefetch_factor=4 if num_workers > 0 else None,
persistent_workers=True if num_workers > 0 else False,
) if test_dataset else None
# 6. 训练
print("\n[6/6] 训练模型...")
model_config = config['model'].copy()
model = TransformerAutoencoder(**model_config)
# 多卡并行
if torch.cuda.device_count() > 1:
print(f" 使用 {torch.cuda.device_count()} 张 GPU 并行训练")
model = nn.DataParallel(model)
if val_loader is None:
print("警告: 验证集为空,将使用训练集的一部分作为验证")
# 简单处理:用训练集的后 10% 作为验证
split_idx = int(len(train_dataset) * 0.9)
train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))
train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)
trainer = Trainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
config=config,
device=device,
save_dir=args.save_dir
)
history = trainer.train(config['epochs'])
# 7. 计算阈值(使用验证集)
print("\n[额外] 计算异动阈值...")
# 加载最佳模型
best_checkpoint = torch.load(
save_dir / 'best_model.pt',
map_location=device
)
model.load_state_dict(best_checkpoint['model_state_dict'])
model.to(device)
# 使用验证集计算阈值(避免数据泄露)
thresholds = compute_thresholds(
model,
val_loader,
device,
config['threshold_percentiles']
)
# 保存阈值
with open(save_dir / 'thresholds.json', 'w') as f:
json.dump(thresholds, f, indent=2)
print(f"阈值已保存")
# 保存完整配置
with open(save_dir / 'config.json', 'w') as f:
json.dump(config, f, indent=2)
print("\n" + "=" * 60)
print("训练完成!")
print("=" * 60)
print(f"模型保存位置: {args.save_dir}")
print(f" - best_model.pt: 最佳模型权重")
print(f" - thresholds.json: 异动阈值")
print(f" - normalization_stats.json: 标准化参数")
print(f" - config.json: 训练配置")
print("=" * 60)
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,68 @@
-- 概念分钟级异动数据表
-- 用于存储概念板块的实时异动信息,支持热点概览图表展示
CREATE TABLE IF NOT EXISTS concept_minute_alert (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
concept_id VARCHAR(32) NOT NULL COMMENT '概念ID',
concept_name VARCHAR(100) NOT NULL COMMENT '概念名称',
alert_time DATETIME NOT NULL COMMENT '异动时间(精确到分钟)',
alert_type VARCHAR(20) NOT NULL COMMENT '异动类型surge(急涨)/limit_up(涨停增加)/rank_jump(排名跃升)',
trade_date DATE NOT NULL COMMENT '交易日期',
-- 涨跌幅相关
change_pct DECIMAL(10,4) COMMENT '当时涨跌幅(%)',
prev_change_pct DECIMAL(10,4) COMMENT '之前涨跌幅(%)',
change_delta DECIMAL(10,4) COMMENT '涨幅变化量(%)',
-- 涨停相关
limit_up_count INT DEFAULT 0 COMMENT '当前涨停数量',
prev_limit_up_count INT DEFAULT 0 COMMENT '之前涨停数量',
limit_up_delta INT DEFAULT 0 COMMENT '涨停变化数量',
-- 排名相关
rank_position INT COMMENT '当前涨幅排名',
prev_rank_position INT COMMENT '之前涨幅排名',
rank_delta INT COMMENT '排名变化(负数表示上升)',
-- 指数位置用于图表Y轴定位
index_code VARCHAR(20) DEFAULT '000001.SH' COMMENT '参考指数代码',
index_price DECIMAL(12,4) COMMENT '异动时的指数点位',
index_change_pct DECIMAL(10,4) COMMENT '异动时的指数涨跌幅(%)',
-- 概念详情
stock_count INT COMMENT '概念包含股票数',
concept_type VARCHAR(20) DEFAULT 'leaf' COMMENT '概念类型leaf/lv1/lv2/lv3',
-- 额外信息JSON格式存储涨停股票列表等
extra_info JSON COMMENT '额外信息',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-- 索引
INDEX idx_trade_date (trade_date),
INDEX idx_alert_time (alert_time),
INDEX idx_concept_id (concept_id),
INDEX idx_alert_type (alert_type),
INDEX idx_trade_date_time (trade_date, alert_time)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念分钟级异动数据表';
-- 创建指数分时快照表(用于异动时获取指数位置)
CREATE TABLE IF NOT EXISTS index_minute_snapshot (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
index_code VARCHAR(20) NOT NULL COMMENT '指数代码',
trade_date DATE NOT NULL COMMENT '交易日期',
snapshot_time DATETIME NOT NULL COMMENT '快照时间',
price DECIMAL(12,4) COMMENT '指数点位',
open_price DECIMAL(12,4) COMMENT '开盘价',
high_price DECIMAL(12,4) COMMENT '最高价',
low_price DECIMAL(12,4) COMMENT '最低价',
prev_close DECIMAL(12,4) COMMENT '昨收价',
change_pct DECIMAL(10,4) COMMENT '涨跌幅(%)',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE KEY uk_index_time (index_code, snapshot_time),
INDEX idx_trade_date (trade_date)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='指数分时快照表';

View File

@@ -0,0 +1,539 @@
/**
* 热点概览组件
* 展示大盘分时走势 + 概念异动标注
*/
import React, { useState, useEffect, useRef, useCallback } from 'react';
import {
Box,
Card,
CardBody,
Heading,
Text,
HStack,
VStack,
Badge,
Spinner,
Center,
Icon,
Flex,
Spacer,
Tooltip,
useColorModeValue,
Stat,
StatLabel,
StatNumber,
StatHelpText,
StatArrow,
SimpleGrid,
} from '@chakra-ui/react';
import { FaFire, FaRocket, FaChartLine, FaBolt, FaArrowDown } from 'react-icons/fa';
import { InfoIcon } from '@chakra-ui/icons';
import * as echarts from 'echarts';
import { logger } from '@utils/logger';
const HotspotOverview = ({ selectedDate }) => {
const chartRef = useRef(null);
const chartInstance = useRef(null);
const [loading, setLoading] = useState(true);
const [data, setData] = useState(null);
const [error, setError] = useState(null);
// 颜色主题
const cardBg = useColorModeValue('white', '#1a1a1a');
const borderColor = useColorModeValue('gray.200', '#333333');
const textColor = useColorModeValue('gray.800', 'white');
const subTextColor = useColorModeValue('gray.600', 'gray.400');
// 获取数据
const fetchData = useCallback(async () => {
setLoading(true);
setError(null);
try {
const dateParam = selectedDate
? `?date=${selectedDate.toISOString().split('T')[0]}`
: '';
const response = await fetch(`/api/market/hotspot-overview${dateParam}`);
const result = await response.json();
if (result.success) {
setData(result.data);
} else {
setError(result.error || '获取数据失败');
}
} catch (err) {
logger.error('HotspotOverview', 'fetchData', err);
setError('网络请求失败');
} finally {
setLoading(false);
}
}, [selectedDate]);
useEffect(() => {
fetchData();
}, [fetchData]);
// 渲染图表
const renderChart = useCallback(() => {
if (!chartRef.current || !data) return;
if (!chartInstance.current) {
chartInstance.current = echarts.init(chartRef.current);
}
const { index, alerts } = data;
const timeline = index.timeline || [];
// 准备数据
const times = timeline.map((d) => d.time);
const prices = timeline.map((d) => d.price);
const volumes = timeline.map((d) => d.volume);
const changePcts = timeline.map((d) => d.change_pct);
// 计算Y轴范围
const priceMin = Math.min(...prices.filter(Boolean));
const priceMax = Math.max(...prices.filter(Boolean));
const priceRange = priceMax - priceMin;
const yAxisMin = priceMin - priceRange * 0.1;
const yAxisMax = priceMax + priceRange * 0.2; // 上方留更多空间给标注
// 准备异动标注 - 按重要性排序,限制显示数量
const sortedAlerts = [...alerts]
.sort((a, b) => (b.importance_score || 0) - (a.importance_score || 0))
.slice(0, 15); // 最多显示15个标注避免图表过于密集
const markPoints = sortedAlerts.map((alert) => {
// 找到对应时间的价格
const timeIndex = times.indexOf(alert.time);
const price = timeIndex >= 0 ? prices[timeIndex] : (alert.index_price || priceMax);
// 根据异动类型设置颜色和符号
let color = '#ff6b6b';
let symbol = 'pin';
let symbolSize = 35;
// 暴涨
if (alert.alert_type === 'surge_up' || alert.alert_type === 'surge') {
color = '#ff4757';
symbol = 'triangle';
symbolSize = 30 + Math.min((alert.importance_score || 0.5) * 20, 15); // 根据重要性调整大小
}
// 暴跌
else if (alert.alert_type === 'surge_down') {
color = '#2ed573';
symbol = 'path://M0,0 L10,0 L5,10 Z'; // 向下三角形
symbolSize = 30 + Math.min((alert.importance_score || 0.5) * 20, 15);
}
// 涨停增加
else if (alert.alert_type === 'limit_up') {
color = '#ff6348';
symbol = 'diamond';
symbolSize = 28;
}
// 排名跃升
else if (alert.alert_type === 'rank_jump') {
color = '#3742fa';
symbol = 'circle';
symbolSize = 25;
}
// 格式化标签 - 简化显示
let label = alert.concept_name;
// 截断过长的名称
if (label.length > 8) {
label = label.substring(0, 7) + '...';
}
// 添加变化信息
const changeDelta = alert.change_delta;
if (changeDelta) {
const sign = changeDelta > 0 ? '+' : '';
label += `\n${sign}${changeDelta.toFixed(1)}%`;
}
return {
name: alert.concept_name,
coord: [alert.time, price],
value: label,
symbol: symbol,
symbolSize: symbolSize,
itemStyle: {
color: color,
borderColor: '#fff',
borderWidth: 1,
shadowBlur: 3,
shadowColor: 'rgba(0,0,0,0.2)',
},
label: {
show: true,
position: alert.alert_type === 'surge_down' ? 'bottom' : 'top', // 暴跌标签在下方
formatter: '{b}',
fontSize: 9,
color: textColor,
backgroundColor: alert.alert_type === 'surge_down'
? 'rgba(46, 213, 115, 0.9)'
: 'rgba(255,255,255,0.9)',
padding: [2, 4],
borderRadius: 2,
borderColor: color,
borderWidth: 1,
},
// 存储额外信息用于 tooltip
alertData: alert,
};
});
// 渐变色 - 根据涨跌
const latestChangePct = changePcts[changePcts.length - 1] || 0;
const areaColorStops = latestChangePct >= 0
? [
{ offset: 0, color: 'rgba(255, 77, 77, 0.4)' },
{ offset: 1, color: 'rgba(255, 77, 77, 0.05)' },
]
: [
{ offset: 0, color: 'rgba(34, 197, 94, 0.4)' },
{ offset: 1, color: 'rgba(34, 197, 94, 0.05)' },
];
const lineColor = latestChangePct >= 0 ? '#ff4d4d' : '#22c55e';
const option = {
backgroundColor: 'transparent',
tooltip: {
trigger: 'axis',
axisPointer: {
type: 'cross',
crossStyle: {
color: '#999',
},
},
formatter: function (params) {
if (!params || params.length === 0) return '';
const dataIndex = params[0].dataIndex;
const time = times[dataIndex];
const price = prices[dataIndex];
const changePct = changePcts[dataIndex];
const volume = volumes[dataIndex];
let html = `
<div style="padding: 8px;">
<div style="font-weight: bold; margin-bottom: 4px;">${time}</div>
<div>指数: <span style="color: ${changePct >= 0 ? '#ff4d4d' : '#22c55e'}; font-weight: bold;">${price?.toFixed(2)}</span></div>
<div>涨跌: <span style="color: ${changePct >= 0 ? '#ff4d4d' : '#22c55e'};">${changePct >= 0 ? '+' : ''}${changePct?.toFixed(2)}%</span></div>
<div>成交量: ${(volume / 10000).toFixed(0)}万手</div>
</div>
`;
// 检查是否有异动
const alertsAtTime = alerts.filter((a) => a.time === time);
if (alertsAtTime.length > 0) {
html += '<div style="border-top: 1px solid #eee; margin-top: 4px; padding-top: 4px;">';
html += '<div style="font-weight: bold; color: #ff6b6b;">概念异动:</div>';
alertsAtTime.forEach((alert) => {
const typeLabel = {
surge: '急涨',
surge_up: '暴涨',
surge_down: '暴跌',
limit_up: '涨停增加',
rank_jump: '排名跃升',
}[alert.alert_type] || alert.alert_type;
const typeColor = alert.alert_type === 'surge_down' ? '#2ed573' : '#ff6b6b';
const delta = alert.change_delta ? ` (${alert.change_delta > 0 ? '+' : ''}${alert.change_delta.toFixed(2)}%)` : '';
const zscore = alert.zscore ? ` Z=${alert.zscore.toFixed(1)}` : '';
html += `<div style="color: ${typeColor}">• ${alert.concept_name} (${typeLabel}${delta}${zscore})</div>`;
});
html += '</div>';
}
return html;
},
},
legend: {
show: false,
},
grid: [
{
left: '8%',
right: '3%',
top: '8%',
height: '55%',
},
{
left: '8%',
right: '3%',
top: '70%',
height: '20%',
},
],
xAxis: [
{
type: 'category',
data: times,
axisLine: { lineStyle: { color: '#ddd' } },
axisLabel: {
color: subTextColor,
fontSize: 10,
interval: Math.floor(times.length / 6),
},
axisTick: { show: false },
splitLine: { show: false },
},
{
type: 'category',
gridIndex: 1,
data: times,
axisLine: { lineStyle: { color: '#ddd' } },
axisLabel: { show: false },
axisTick: { show: false },
splitLine: { show: false },
},
],
yAxis: [
{
type: 'value',
min: yAxisMin,
max: yAxisMax,
axisLine: { show: false },
axisLabel: {
color: subTextColor,
fontSize: 10,
formatter: (val) => val.toFixed(0),
},
splitLine: {
lineStyle: { color: '#eee', type: 'dashed' },
},
// 右侧显示涨跌幅
axisPointer: {
label: {
formatter: function (params) {
const pct = ((params.value - index.prev_close) / index.prev_close) * 100;
return `${params.value.toFixed(2)} (${pct >= 0 ? '+' : ''}${pct.toFixed(2)}%)`;
},
},
},
},
{
type: 'value',
gridIndex: 1,
axisLine: { show: false },
axisLabel: { show: false },
splitLine: { show: false },
},
],
series: [
// 分时线
{
name: '上证指数',
type: 'line',
data: prices,
smooth: true,
symbol: 'none',
lineStyle: {
color: lineColor,
width: 1.5,
},
areaStyle: {
color: new echarts.graphic.LinearGradient(0, 0, 0, 1, areaColorStops),
},
markPoint: {
symbol: 'pin',
symbolSize: 40,
data: markPoints,
animation: true,
},
},
// 成交量
{
name: '成交量',
type: 'bar',
xAxisIndex: 1,
yAxisIndex: 1,
data: volumes.map((v, i) => ({
value: v,
itemStyle: {
color: changePcts[i] >= 0 ? 'rgba(255, 77, 77, 0.6)' : 'rgba(34, 197, 94, 0.6)',
},
})),
barWidth: '60%',
},
],
};
chartInstance.current.setOption(option, true);
}, [data, textColor, subTextColor]);
// 数据变化时重新渲染
useEffect(() => {
if (data) {
renderChart();
}
}, [data, renderChart]);
// 窗口大小变化时重新渲染
useEffect(() => {
const handleResize = () => {
if (chartInstance.current) {
chartInstance.current.resize();
}
};
window.addEventListener('resize', handleResize);
return () => {
window.removeEventListener('resize', handleResize);
if (chartInstance.current) {
chartInstance.current.dispose();
chartInstance.current = null;
}
};
}, []);
// 异动类型标签
const AlertTypeBadge = ({ type, count }) => {
const config = {
surge: { label: '急涨', color: 'red', icon: FaBolt },
surge_up: { label: '暴涨', color: 'red', icon: FaBolt },
surge_down: { label: '暴跌', color: 'green', icon: FaArrowDown },
limit_up: { label: '涨停', color: 'orange', icon: FaRocket },
rank_jump: { label: '排名跃升', color: 'blue', icon: FaChartLine },
};
const cfg = config[type] || { label: type, color: 'gray', icon: FaFire };
return (
<Badge colorScheme={cfg.color} variant="subtle" px={2} py={1} borderRadius="md">
<HStack spacing={1}>
<Icon as={cfg.icon} boxSize={3} />
<Text>{cfg.label}</Text>
<Text fontWeight="bold">{count}</Text>
</HStack>
</Badge>
);
};
if (loading) {
return (
<Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
<CardBody>
<Center h="400px">
<VStack spacing={4}>
<Spinner size="xl" color="purple.500" thickness="4px" />
<Text color={subTextColor}>加载热点概览数据...</Text>
</VStack>
</Center>
</CardBody>
</Card>
);
}
if (error) {
return (
<Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
<CardBody>
<Center h="400px">
<VStack spacing={4}>
<Icon as={InfoIcon} boxSize={10} color="red.400" />
<Text color="red.500">{error}</Text>
</VStack>
</Center>
</CardBody>
</Card>
);
}
if (!data) {
return null;
}
const { index, alerts, alert_summary } = data;
return (
<Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
<CardBody>
{/* 头部信息 */}
<Flex align="center" mb={4}>
<HStack spacing={3}>
<Icon as={FaFire} boxSize={6} color="orange.500" />
<Heading size="md" color={textColor}>
热点概览
</Heading>
</HStack>
<Spacer />
<Tooltip label="展示大盘走势与概念异动的关联">
<Icon as={InfoIcon} color={subTextColor} />
</Tooltip>
</Flex>
{/* 指数统计 */}
<SimpleGrid columns={{ base: 2, md: 4 }} spacing={4} mb={4}>
<Stat size="sm">
<StatLabel color={subTextColor}>{index.name}</StatLabel>
<StatNumber
fontSize="xl"
color={index.change_pct >= 0 ? 'red.500' : 'green.500'}
>
{index.latest_price?.toFixed(2)}
</StatNumber>
<StatHelpText mb={0}>
<StatArrow type={index.change_pct >= 0 ? 'increase' : 'decrease'} />
{index.change_pct?.toFixed(2)}%
</StatHelpText>
</Stat>
<Stat size="sm">
<StatLabel color={subTextColor}>最高</StatLabel>
<StatNumber fontSize="xl" color="red.500">
{index.high?.toFixed(2)}
</StatNumber>
</Stat>
<Stat size="sm">
<StatLabel color={subTextColor}>最低</StatLabel>
<StatNumber fontSize="xl" color="green.500">
{index.low?.toFixed(2)}
</StatNumber>
</Stat>
<Stat size="sm">
<StatLabel color={subTextColor}>异动次数</StatLabel>
<StatNumber fontSize="xl" color="orange.500">
{alerts.length}
</StatNumber>
</Stat>
</SimpleGrid>
{/* 异动类型统计 */}
{alerts.length > 0 && (
<HStack spacing={2} mb={4} flexWrap="wrap">
{(alert_summary.surge_up > 0 || alert_summary.surge > 0) && (
<AlertTypeBadge type="surge_up" count={(alert_summary.surge_up || 0) + (alert_summary.surge || 0)} />
)}
{alert_summary.surge_down > 0 && (
<AlertTypeBadge type="surge_down" count={alert_summary.surge_down} />
)}
{alert_summary.limit_up > 0 && (
<AlertTypeBadge type="limit_up" count={alert_summary.limit_up} />
)}
{alert_summary.rank_jump > 0 && (
<AlertTypeBadge type="rank_jump" count={alert_summary.rank_jump} />
)}
</HStack>
)}
{/* 图表 */}
<Box ref={chartRef} h="400px" w="100%" />
{/* 无异动提示 */}
{alerts.length === 0 && (
<Center py={4}>
<Text color={subTextColor} fontSize="sm">
当日暂无概念异动数据
</Text>
</Center>
)}
</CardBody>
</Card>
);
};
export default HotspotOverview;

View File

@@ -53,6 +53,7 @@ import { SearchIcon, CloseIcon, ArrowForwardIcon, TrendingUpIcon, InfoIcon, Chev
import { FaChartLine, FaFire, FaRocket, FaBrain, FaCalendarAlt, FaChevronRight, FaArrowUp, FaArrowDown, FaChartBar, FaTag, FaLayerGroup, FaBolt } from 'react-icons/fa'; import { FaChartLine, FaFire, FaRocket, FaBrain, FaCalendarAlt, FaChevronRight, FaArrowUp, FaArrowDown, FaChartBar, FaTag, FaLayerGroup, FaBolt } from 'react-icons/fa';
import ConceptStocksModal from '@components/ConceptStocksModal'; import ConceptStocksModal from '@components/ConceptStocksModal';
import TradeDatePicker from '@components/TradeDatePicker'; import TradeDatePicker from '@components/TradeDatePicker';
import HotspotOverview from './components/HotspotOverview';
import { BsGraphUp, BsLightningFill } from 'react-icons/bs'; import { BsGraphUp, BsLightningFill } from 'react-icons/bs';
import * as echarts from 'echarts'; import * as echarts from 'echarts';
import { logger } from '../../utils/logger'; import { logger } from '../../utils/logger';
@@ -840,6 +841,11 @@ const StockOverview = () => {
)} )}
</Box> </Box>
{/* 热点概览 - 大盘走势 + 概念异动 */}
<Box mb={10}>
<HotspotOverview selectedDate={selectedDate} />
</Box>
{/* 今日热门概念 */} {/* 今日热门概念 */}
<Box mb={10}> <Box mb={10}>
<Flex align="center" mb={6}> <Flex align="center" mb={6}>