update pay ui

app.py (273 lines changed)

@@ -12458,6 +12458,279 @@ def get_daily_top_concepts():
        }), 500


# ==================== Hotspot Overview API ====================

@app.route('/api/market/hotspot-overview', methods=['GET'])
def get_hotspot_overview():
    """
    Hotspot overview data (for the hotspot chart in the stock center).
    Returns: index minute data + concept alert annotations.
    """
    try:
        trade_date = request.args.get('date')
        index_code = request.args.get('index', '000001.SH')

        # If no date is given, use the most recent trading day
        if not trade_date:
            today = date.today()
            if today in trading_days_set:
                trade_date = today.strftime('%Y-%m-%d')
            else:
                target_date = get_trading_day_near_date(today)
                trade_date = target_date.strftime('%Y-%m-%d') if target_date else today.strftime('%Y-%m-%d')

        # 1. Fetch index minute data
        client = get_clickhouse_client()
        target_date_obj = datetime.strptime(trade_date, '%Y-%m-%d').date()

        index_data = client.execute(
            """
            SELECT timestamp, open, high, low, close, volume
            FROM index_minute
            WHERE code = %(code)s
              AND toDate(timestamp) = %(date)s
            ORDER BY timestamp
            """,
            {
                'code': index_code,
                'date': target_date_obj
            }
        )

        # Fetch the previous close
        code_no_suffix = index_code.split('.')[0]
        prev_close = None
        with engine.connect() as conn:
            prev_result = conn.execute(text("""
                SELECT F006N FROM ea_exchangetrade
                WHERE INDEXCODE = :code
                  AND TRADEDATE < :today
                ORDER BY TRADEDATE DESC LIMIT 1
            """), {
                'code': code_no_suffix,
                'today': target_date_obj
            }).fetchone()
            if prev_result and prev_result[0]:
                prev_close = float(prev_result[0])

        # Format the index data
        index_timeline = []
        for row in index_data:
            ts, open_p, high_p, low_p, close_p, vol = row
            change_pct = None
            if prev_close and close_p:
                change_pct = round((float(close_p) - prev_close) / prev_close * 100, 4)

            index_timeline.append({
                'time': ts.strftime('%H:%M'),
                'timestamp': ts.isoformat(),
                'price': float(close_p) if close_p else None,
                'open': float(open_p) if open_p else None,
                'high': float(high_p) if high_p else None,
                'low': float(low_p) if low_p else None,
                'volume': int(vol) if vol else 0,
                'change_pct': change_pct
            })

        # 2. Fetch concept alert data
        alerts = []
        with engine.connect() as conn:
            alert_result = conn.execute(text("""
                SELECT
                    concept_id, concept_name, alert_time, alert_type,
                    change_pct, change_delta, limit_up_count, limit_up_delta,
                    rank_position, index_price, index_change_pct,
                    stock_count, concept_type, extra_info,
                    prev_change_pct, zscore, importance_score
                FROM concept_minute_alert
                WHERE trade_date = :trade_date
                ORDER BY alert_time
            """), {'trade_date': trade_date})

            for row in alert_result:
                alert_time = row[2]
                extra_info = None
                if row[13]:
                    try:
                        extra_info = json.loads(row[13]) if isinstance(row[13], str) else row[13]
                    except (TypeError, ValueError):
                        pass

                # Pull zscore and importance_score from extra_info (backward compatible with old rows)
                zscore = None
                importance_score = None
                if len(row) > 16:
                    zscore = float(row[15]) if row[15] else None
                    importance_score = float(row[16]) if row[16] else None
                if extra_info:
                    zscore = zscore or extra_info.get('zscore')
                    importance_score = importance_score or extra_info.get('importance_score')

                alerts.append({
                    'concept_id': row[0],
                    'concept_name': row[1],
                    'time': alert_time.strftime('%H:%M') if alert_time else None,
                    'timestamp': alert_time.isoformat() if alert_time else None,
                    'alert_type': row[3],
                    'change_pct': float(row[4]) if row[4] else None,
                    'change_delta': float(row[5]) if row[5] else None,
                    'limit_up_count': row[6],
                    'limit_up_delta': row[7],
                    'rank_position': row[8],
                    'index_price': float(row[9]) if row[9] else None,
                    'index_change_pct': float(row[10]) if row[10] else None,
                    'stock_count': row[11],
                    'concept_type': row[12],
                    'extra_info': extra_info,
                    'prev_change_pct': float(row[14]) if len(row) > 14 and row[14] else None,
                    'zscore': zscore,
                    'importance_score': importance_score
                })

        # Compute summary statistics
        day_high = max([d['price'] for d in index_timeline if d['price']], default=None)
        day_low = min([d['price'] for d in index_timeline if d['price']], default=None)
        latest_price = index_timeline[-1]['price'] if index_timeline else None
        latest_change_pct = index_timeline[-1]['change_pct'] if index_timeline else None

        return jsonify({
            'success': True,
            'data': {
                'trade_date': trade_date,
                'index': {
                    'code': index_code,
                    'name': '上证指数' if index_code == '000001.SH' else index_code,  # '上证指数' = SSE Composite Index
                    'prev_close': prev_close,
                    'latest_price': latest_price,
                    'change_pct': latest_change_pct,
                    'high': day_high,
                    'low': day_low,
                    'timeline': index_timeline
                },
                'alerts': alerts,
                'alert_count': len(alerts),
                'alert_summary': {
                    'surge': len([a for a in alerts if a['alert_type'] == 'surge']),
                    'surge_up': len([a for a in alerts if a['alert_type'] == 'surge_up']),
                    'surge_down': len([a for a in alerts if a['alert_type'] == 'surge_down']),
                    'limit_up': len([a for a in alerts if a['alert_type'] == 'limit_up']),
                    'rank_jump': len([a for a in alerts if a['alert_type'] == 'rank_jump'])
                }
            }
        })

    except Exception as e:
        import traceback
        logger.error(f"Failed to fetch hotspot overview data: {traceback.format_exc()}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500

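
# Illustrative client call (editor's sketch, not part of this commit);
# assumes the Flask app is served on localhost:5000:
#
#   import requests
#   resp = requests.get(
#       'http://localhost:5000/api/market/hotspot-overview',
#       params={'date': '2025-12-08', 'index': '000001.SH'},
#   )
#   data = resp.json()['data']
#   print(data['alert_count'], data['index']['change_pct'])
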
@app.route('/api/market/concept-alerts', methods=['GET'])
def get_concept_alerts():
    """
    Concept alert list (supports pagination and filtering).
    """
    try:
        trade_date = request.args.get('date')
        alert_type = request.args.get('type')  # surge/limit_up/rank_jump
        concept_type = request.args.get('concept_type')  # leaf/lv1/lv2/lv3
        limit = request.args.get('limit', 50, type=int)
        offset = request.args.get('offset', 0, type=int)

        # Build the query conditions
        conditions = []
        params = {'limit': limit, 'offset': offset}

        if trade_date:
            conditions.append("trade_date = :trade_date")
            params['trade_date'] = trade_date
        else:
            conditions.append("trade_date = CURDATE()")

        if alert_type:
            conditions.append("alert_type = :alert_type")
            params['alert_type'] = alert_type

        if concept_type:
            conditions.append("concept_type = :concept_type")
            params['concept_type'] = concept_type

        where_clause = " AND ".join(conditions) if conditions else "1=1"

        with engine.connect() as conn:
            # Total count
            count_sql = text(f"SELECT COUNT(*) FROM concept_minute_alert WHERE {where_clause}")
            total = conn.execute(count_sql, params).scalar()

            # Fetch the rows
            query_sql = text(f"""
                SELECT
                    id, concept_id, concept_name, alert_time, alert_type, trade_date,
                    change_pct, prev_change_pct, change_delta,
                    limit_up_count, prev_limit_up_count, limit_up_delta,
                    rank_position, prev_rank_position, rank_delta,
                    index_price, index_change_pct,
                    stock_count, concept_type, extra_info
                FROM concept_minute_alert
                WHERE {where_clause}
                ORDER BY alert_time DESC
                LIMIT :limit OFFSET :offset
            """)

            result = conn.execute(query_sql, params)

            alerts = []
            for row in result:
                extra_info = None
                if row[19]:
                    try:
                        extra_info = json.loads(row[19]) if isinstance(row[19], str) else row[19]
                    except (TypeError, ValueError):
                        pass

                alerts.append({
                    'id': row[0],
                    'concept_id': row[1],
                    'concept_name': row[2],
                    'alert_time': row[3].isoformat() if row[3] else None,
                    'alert_type': row[4],
                    'trade_date': row[5].isoformat() if row[5] else None,
                    'change_pct': float(row[6]) if row[6] else None,
                    'prev_change_pct': float(row[7]) if row[7] else None,
                    'change_delta': float(row[8]) if row[8] else None,
                    'limit_up_count': row[9],
                    'prev_limit_up_count': row[10],
                    'limit_up_delta': row[11],
                    'rank_position': row[12],
                    'prev_rank_position': row[13],
                    'rank_delta': row[14],
                    'index_price': float(row[15]) if row[15] else None,
                    'index_change_pct': float(row[16]) if row[16] else None,
                    'stock_count': row[17],
                    'concept_type': row[18],
                    'extra_info': extra_info
                })

        return jsonify({
            'success': True,
            'data': alerts,
            'total': total,
            'limit': limit,
            'offset': offset
        })

    except Exception as e:
        import traceback
        logger.error(f"Failed to fetch concept alert list: {traceback.format_exc()}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500

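
# Illustrative query (editor's sketch, not part of this commit):
#
#   GET /api/market/concept-alerts?date=2025-12-08&type=surge_up&concept_type=leaf&limit=20
#
# which responds with {'success': True, 'data': [...], 'total': N, 'limit': 20, 'offset': 0}.
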
@app.route('/api/market/rise-analysis/<seccode>', methods=['GET'])
def get_rise_analysis(seccode):
    """Stock rise-analysis data (fetched from Elasticsearch)."""

concept_alert_20251208.log (2823 lines, new file)
File diff suppressed because it is too large

concept_alert_alpha.py (1078 lines, new file)
File diff suppressed because it is too large
concept_alert_alpha_20251208.log (28 lines, new file)

@@ -0,0 +1,28 @@
2025-12-08 16:40:41,567 - INFO - ============================================================
2025-12-08 16:40:41,567 - INFO - 🔄 Backtest: 2025-12-08 (Alpha Z-Score method)
2025-12-08 16:40:41,569 - INFO - ============================================================
2025-12-08 16:40:41,679 - INFO - Cleared data for 2025-12-08
2025-12-08 16:40:41,903 - INFO - POST http://222.128.1.157:19200/concept_library_v3/_search?scroll=2m [status:200 duration:0.224s]
2025-12-08 16:40:42,105 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.197s]
2025-12-08 16:40:42,330 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.178s]
2025-12-08 16:40:42,518 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.183s]
2025-12-08 16:40:42,704 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.182s]
2025-12-08 16:40:42,894 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.186s]
2025-12-08 16:40:43,060 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.162s]
2025-12-08 16:40:43,234 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.171s]
2025-12-08 16:40:43,383 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.145s]
2025-12-08 16:40:43,394 - INFO - POST http://222.128.1.157:19200/_search/scroll [status:200 duration:0.008s]
2025-12-08 16:40:43,399 - INFO - DELETE http://222.128.1.157:19200/_search/scroll [status:200 duration:0.005s]
2025-12-08 16:40:43,409 - INFO - Concepts: 968, stocks: 5938
2025-12-08 16:40:43,505 - INFO - Time points: 241
2025-12-08 16:41:02,028 - INFO - Progress: 30/241 (12%), alerts: 0
2025-12-08 16:41:20,851 - INFO - Progress: 60/241 (24%), alerts: 0
2025-12-08 16:41:39,396 - INFO - Progress: 90/241 (37%), alerts: 0
2025-12-08 16:41:58,687 - INFO - Progress: 120/241 (49%), alerts: 0
2025-12-08 16:43:08,124 - INFO - Progress: 150/241 (62%), alerts: 0
2025-12-08 16:43:26,973 - INFO - Progress: 180/241 (74%), alerts: 0
2025-12-08 16:43:45,746 - INFO - Progress: 210/241 (87%), alerts: 0
2025-12-08 16:44:04,479 - INFO - Progress: 240/241 (99%), alerts: 0
2025-12-08 16:44:05,123 - INFO - ============================================================
2025-12-08 16:44:05,123 - INFO - ✅ Backtest complete! Detected 0 alerts
2025-12-08 16:44:05,125 - INFO - ============================================================
concept_alert_ml.py (1625 lines, new file)
File diff suppressed because it is too large

concept_alert_realtime.py (1366 lines, new file)
File diff suppressed because it is too large

concept_quota_realtime.py (681 lines, new file)

@@ -0,0 +1,681 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Realtime concept change-percentage update service
- During trading hours, fetches the latest minute bars from ClickHouse every minute
- Computes change percentages and updates the MySQL concept_daily_stats table
- Supports leaf concepts and parent concepts (lv1/lv2/lv3)
"""

import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from sqlalchemy import create_engine, text
from elasticsearch import Elasticsearch
from clickhouse_driver import Client
import time
import logging
import json
import os
import hashlib
import argparse

# ==================== Configuration ====================

# MySQL
MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock",
    echo=False
)

# Elasticsearch
ES_CLIENT = Elasticsearch(['http://222.128.1.157:19200'])
INDEX_NAME = 'concept_library_v3'

# ClickHouse
CLICKHOUSE_CONFIG = {
    'host': '222.128.1.157',
    'port': 18000,
    'user': 'default',
    'password': 'Zzl33818!',
    'database': 'stock'
}

# Concept hierarchy file
HIERARCHY_FILE = 'concept_hierarchy_v3.json'

# Trading hours
TRADING_HOURS = {
    'morning_start': (9, 30),
    'morning_end': (11, 30),
    'afternoon_start': (13, 0),
    'afternoon_end': (15, 0),
}

# ==================== Logging ====================

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f'concept_realtime_{datetime.now().strftime("%Y%m%d")}.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# ClickHouse client (lazily initialized)
ch_client = None


def get_ch_client():
    """Return a shared ClickHouse client, creating it on first use."""
    global ch_client
    if ch_client is None:
        ch_client = Client(**CLICKHOUSE_CONFIG)
    return ch_client


def generate_id(name: str) -> str:
    """Generate a concept ID from its name."""
    return hashlib.md5(name.encode('utf-8')).hexdigest()[:16]


def code_to_ch_format(code: str) -> str:
    """Convert a 6-digit stock code to ClickHouse format (with exchange suffix).

    Rules:
    - starts with 6 -> .SH (Shanghai)
    - starts with 0 or 3 -> .SZ (Shenzhen)
    - otherwise -> .BJ (Beijing)
    - non-6-digit codes are ignored (possibly HK stocks)
    """
    if not code or len(code) != 6 or not code.isdigit():
        return None

    if code.startswith('6'):
        return f"{code}.SH"
    elif code.startswith('0') or code.startswith('3'):
        return f"{code}.SZ"
    else:
        return f"{code}.BJ"


def ch_code_to_pure(ch_code: str) -> str:
    """Convert a ClickHouse-formatted stock code back to the plain 6-digit code."""
    if not ch_code:
        return None
    return ch_code.split('.')[0]


# ==================== Concept data loading ====================

def get_all_concepts():
    """Fetch all leaf concepts and their stock lists from ES."""
    concepts = []

    query = {
        "query": {"match_all": {}},
        "size": 100,
        "_source": ["concept_id", "concept", "stocks"]
    }

    resp = ES_CLIENT.search(index=INDEX_NAME, body=query, scroll='2m')
    scroll_id = resp['_scroll_id']
    hits = resp['hits']['hits']

    while len(hits) > 0:
        for hit in hits:
            source = hit['_source']
            concept_info = {
                'concept_id': source.get('concept_id'),
                'concept_name': source.get('concept'),
                'stocks': [],
                'concept_type': 'leaf'
            }

            # In the v3 index the stocks field is [{name, code}, ...]
            if 'stocks' in source and isinstance(source['stocks'], list):
                for stock in source['stocks']:
                    if isinstance(stock, dict) and 'code' in stock and stock['code']:
                        concept_info['stocks'].append(stock['code'])

            if concept_info['stocks']:
                concepts.append(concept_info)

        resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
        scroll_id = resp['_scroll_id']
        hits = resp['hits']['hits']

    ES_CLIENT.clear_scroll(scroll_id=scroll_id)
    return concepts


def load_hierarchy_concepts(leaf_concepts: list) -> list:
    """Load the hierarchy file and build parent concepts (lv1/lv2/lv3)."""
    hierarchy_path = os.path.join(os.path.dirname(__file__), HIERARCHY_FILE)
    if not os.path.exists(hierarchy_path):
        logger.warning(f"Hierarchy file not found: {hierarchy_path}")
        return []

    with open(hierarchy_path, 'r', encoding='utf-8') as f:
        hierarchy_data = json.load(f)

    # Map concept name -> stock set
    concept_to_stocks = {}
    for c in leaf_concepts:
        concept_to_stocks[c['concept_name']] = set(c['stocks'])

    parent_concepts = []

    for lv1 in hierarchy_data.get('hierarchy', []):
        lv1_name = lv1.get('lv1', '')
        lv1_stocks = set()

        for child in lv1.get('children', []):
            lv2_name = child.get('lv2', '')
            lv2_stocks = set()

            if 'children' in child:
                for lv3_child in child.get('children', []):
                    lv3_name = lv3_child.get('lv3', '')
                    lv3_stocks = set()

                    for concept_name in lv3_child.get('concepts', []):
                        if concept_name in concept_to_stocks:
                            lv3_stocks.update(concept_to_stocks[concept_name])

                    if lv3_stocks:
                        parent_concepts.append({
                            'concept_id': generate_id(f"lv3_{lv3_name}"),
                            'concept_name': f"[三级] {lv3_name}",  # '[三级]' = level-3 display prefix
                            'stocks': list(lv3_stocks),
                            'concept_type': 'lv3'
                        })

                    lv2_stocks.update(lv3_stocks)
            else:
                for concept_name in child.get('concepts', []):
                    if concept_name in concept_to_stocks:
                        lv2_stocks.update(concept_to_stocks[concept_name])

            if lv2_stocks:
                parent_concepts.append({
                    'concept_id': generate_id(f"lv2_{lv2_name}"),
                    'concept_name': f"[二级] {lv2_name}",  # '[二级]' = level-2 display prefix
                    'stocks': list(lv2_stocks),
                    'concept_type': 'lv2'
                })

            lv1_stocks.update(lv2_stocks)

        if lv1_stocks:
            parent_concepts.append({
                'concept_id': generate_id(f"lv1_{lv1_name}"),
                'concept_name': f"[一级] {lv1_name}",  # '[一级]' = level-1 display prefix
                'stocks': list(lv1_stocks),
                'concept_type': 'lv1'
            })

    return parent_concepts


# ==================== Base price loading ====================

def get_base_prices(stock_codes: list, current_date: str) -> dict:
    """Fetch the day's previous-close prices as the baseline (ea_trade.F002N).

    ea_trade column notes:
    - F002N: previous close
    - F007N: latest traded price (close)
    - F010N: change percentage
    """
    if not stock_codes:
        return {}

    # Keep only valid 6-digit stock codes
    valid_codes = [code for code in stock_codes if code and len(code) == 6 and code.isdigit()]
    if not valid_codes:
        return {}

    stock_codes_str = "','".join(valid_codes)

    # Previous close (F002N) from the day's rows
    query = f"""
        SELECT SECCODE, F002N
        FROM ea_trade
        WHERE SECCODE IN ('{stock_codes_str}')
          AND TRADEDATE = (
              SELECT MAX(TRADEDATE)
              FROM ea_trade
              WHERE TRADEDATE <= '{current_date}'
          )
          AND F002N IS NOT NULL AND F002N > 0
    """

    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text(query))
            base_prices = {row[0]: float(row[1]) for row in result if row[1] and float(row[1]) > 0}
            logger.info(f"Fetched {len(base_prices)} base prices")
            return base_prices
    except Exception as e:
        logger.error(f"Failed to fetch base prices: {e}")
        return {}


# ==================== Latest price loading ====================

def get_latest_prices(stock_codes: list) -> dict:
    """Fetch the latest minute-bar close prices from ClickHouse.

    Args:
        stock_codes: plain 6-digit stock codes (e.g. ['000001', '600000'])

    Returns:
        dict: {plain 6-digit code: {'close': price, 'timestamp': time}}
    """
    if not stock_codes:
        return {}

    client = get_ch_client()

    # Convert to ClickHouse-format codes (with exchange suffix)
    ch_codes = []
    code_mapping = {}  # ch_code -> pure_code
    for code in stock_codes:
        ch_code = code_to_ch_format(code)
        if ch_code:
            ch_codes.append(ch_code)
            code_mapping[ch_code] = code

    if not ch_codes:
        logger.warning("No valid stock codes to query")
        return {}

    ch_codes_str = "','".join(ch_codes)

    # Latest minute bar of the day per code
    query = f"""
        SELECT code, close, timestamp
        FROM (
            SELECT code, close, timestamp,
                   ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn
            FROM stock_minute
            WHERE code IN ('{ch_codes_str}')
              AND toDate(timestamp) = today()
        )
        WHERE rn = 1
    """

    try:
        result = client.execute(query)
        if not result:
            return {}

        latest_prices = {}
        for row in result:
            ch_code, close, ts = row
            if close and close > 0:
                # Map back to the plain 6-digit code
                pure_code = code_mapping.get(ch_code)
                if pure_code:
                    latest_prices[pure_code] = {
                        'close': float(close),
                        'timestamp': ts
                    }

        return latest_prices
    except Exception as e:
        logger.error(f"Failed to fetch latest prices: {e}")
        return {}


# ==================== Change-percentage computation ====================

def calculate_change_pct(base_prices: dict, latest_prices: dict) -> dict:
    """Compute per-stock change percentages."""
    changes = {}
    for code, latest in latest_prices.items():
        if code in base_prices and base_prices[code] > 0:
            base = base_prices[code]
            close = latest['close']
            change_pct = (close - base) / base * 100
            changes[code] = round(change_pct, 4)
    return changes


def calculate_concept_stats(concepts: list, stock_changes: dict, trade_date: str) -> list:
    """Compute change-percentage statistics for all concepts."""
    stats = []

    for concept in concepts:
        concept_id = concept['concept_id']
        concept_name = concept['concept_name']
        stock_codes = concept['stocks']
        concept_type = concept.get('concept_type', 'leaf')

        # Change percentages of this concept's stocks
        changes = [stock_changes[code] for code in stock_codes if code in stock_changes]

        if not changes:
            continue

        avg_change_pct = round(np.mean(changes), 4)
        stock_count = len(changes)

        stats.append({
            'concept_id': concept_id,
            'concept_name': concept_name,
            'trade_date': trade_date,
            'avg_change_pct': avg_change_pct,
            'stock_count': stock_count,
            'concept_type': concept_type
        })

    return stats


# ==================== MySQL update ====================

def update_mysql_stats(stats: list):
    """Upsert the rows into the MySQL concept_daily_stats table."""
    if not stats:
        return 0

    with MYSQL_ENGINE.begin() as conn:
        updated = 0
        for item in stats:
            upsert_sql = text("""
                REPLACE INTO concept_daily_stats
                (concept_id, concept_name, trade_date, avg_change_pct, stock_count, concept_type)
                VALUES (:concept_id, :concept_name, :trade_date, :avg_change_pct, :stock_count, :concept_type)
            """)
            conn.execute(upsert_sql, item)
            updated += 1

    return updated


# ==================== Trading-hours check ====================

def is_trading_time() -> bool:
    """Return True if we are currently inside trading hours."""
    now = datetime.now()
    weekday = now.weekday()

    # No trading on weekends
    if weekday >= 5:
        return False

    hour, minute = now.hour, now.minute
    current_time = hour * 60 + minute

    # Morning session 9:30 - 11:30
    morning_start = 9 * 60 + 30
    morning_end = 11 * 60 + 30

    # Afternoon session 13:00 - 15:00
    afternoon_start = 13 * 60
    afternoon_end = 15 * 60

    return (morning_start <= current_time <= morning_end) or \
           (afternoon_start <= current_time <= afternoon_end)


def get_next_update_time() -> int:
    """Return the number of seconds until the next update."""
    now = datetime.now()

    if is_trading_time():
        # Inside trading hours: wait until the next minute
        return 60 - now.second
    else:
        # Outside trading hours
        hour, minute = now.hour, now.minute

        # Time until the next session opens
        if hour < 9 or (hour == 9 and minute < 30):
            # Wait until 9:30
            target = now.replace(hour=9, minute=30, second=0, microsecond=0)
        elif (hour == 11 and minute >= 30) or hour == 12:
            # Wait until 13:00
            target = now.replace(hour=13, minute=0, second=0, microsecond=0)
        elif hour >= 15:
            # Wait until 9:30 tomorrow
            target = (now + timedelta(days=1)).replace(hour=9, minute=30, second=0, microsecond=0)
        else:
            target = now + timedelta(minutes=1)

        wait_seconds = (target - now).total_seconds()
        return max(60, int(wait_seconds))


# ==================== Main update logic ====================

def run_once(concepts: list, all_stocks: list) -> int:
    """Run a single update pass."""
    now = datetime.now()
    trade_date = now.strftime('%Y-%m-%d')

    # Baseline prices (yesterday's close)
    base_prices = get_base_prices(all_stocks, trade_date)
    if not base_prices:
        logger.warning("Could not fetch base prices")
        return 0

    # Latest prices
    latest_prices = get_latest_prices(all_stocks)
    if not latest_prices:
        logger.warning("Could not fetch latest prices")
        return 0

    # Change percentages
    stock_changes = calculate_change_pct(base_prices, latest_prices)
    if not stock_changes:
        logger.warning("No change-percentage data")
        return 0

    logger.info(f"Got change percentages for {len(stock_changes)} stocks")

    # Concept statistics
    stats = calculate_concept_stats(concepts, stock_changes, trade_date)
    logger.info(f"Computed change percentages for {len(stats)} concepts")

    # Update MySQL
    updated = update_mysql_stats(stats)
    logger.info(f"Wrote {updated} rows to MySQL")

    return updated


def run_realtime():
    """Main realtime update loop."""
    logger.info("=" * 60)
    logger.info("Starting the realtime concept change-percentage service")
    logger.info("=" * 60)

    # Load concept data
    logger.info("Loading concept data...")
    leaf_concepts = get_all_concepts()
    logger.info(f"Fetched {len(leaf_concepts)} leaf concepts")

    parent_concepts = load_hierarchy_concepts(leaf_concepts)
    logger.info(f"Built {len(parent_concepts)} parent concepts")

    all_concepts = leaf_concepts + parent_concepts
    logger.info(f"{len(all_concepts)} concepts in total")

    # Collect all stock codes
    all_stocks = set()
    for c in all_concepts:
        all_stocks.update(c['stocks'])
    all_stocks = list(all_stocks)
    logger.info(f"Watching {len(all_stocks)} stocks")

    last_concept_update = datetime.now()

    while True:
        try:
            now = datetime.now()

            # Reload concept data every hour
            if (now - last_concept_update).total_seconds() > 3600:
                logger.info("Reloading concept data...")
                leaf_concepts = get_all_concepts()
                parent_concepts = load_hierarchy_concepts(leaf_concepts)
                all_concepts = leaf_concepts + parent_concepts
                all_stocks = set()
                for c in all_concepts:
                    all_stocks.update(c['stocks'])
                all_stocks = list(all_stocks)
                last_concept_update = now
                logger.info(f"Reloaded: {len(all_concepts)} concepts, {len(all_stocks)} stocks")

            # Check trading hours
            if not is_trading_time():
                wait_sec = get_next_update_time()
                wait_min = wait_sec // 60
                logger.info(f"Outside trading hours, retrying in {wait_min} minutes...")
                time.sleep(min(wait_sec, 300))  # re-check at most every 5 minutes
                continue

            # Run an update
            logger.info(f"\n{'=' * 40}")
            logger.info(f"Update time: {now.strftime('%Y-%m-%d %H:%M:%S')}")
            updated = run_once(all_concepts, all_stocks)

            # Wait until the next minute
            sleep_sec = 60 - datetime.now().second
            logger.info(f"Done, sleeping {sleep_sec} seconds...")
            time.sleep(sleep_sec)

        except KeyboardInterrupt:
            logger.info("\nReceived exit signal, stopping the service...")
            break
        except Exception as e:
            logger.error(f"Error: {e}")
            import traceback
            traceback.print_exc()
            time.sleep(60)


def run_single():
    """Single run (no loop)."""
    logger.info("Single-update mode")

    leaf_concepts = get_all_concepts()
    parent_concepts = load_hierarchy_concepts(leaf_concepts)
    all_concepts = leaf_concepts + parent_concepts

    all_stocks = set()
    for c in all_concepts:
        all_stocks.update(c['stocks'])
    all_stocks = list(all_stocks)

    logger.info(f"Concepts: {len(all_concepts)}, stocks: {len(all_stocks)}")

    updated = run_once(all_concepts, all_stocks)
    logger.info(f"Update done: {updated} rows")


def show_status():
    """Print current service status."""
    print("\n" + "=" * 60)
    print("Realtime concept change-percentage service - status")
    print("=" * 60)

    # Current time
    now = datetime.now()
    print(f"\nCurrent time: {now.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"In trading hours: {'yes' if is_trading_time() else 'no'}")

    # MySQL data status
    print("\nMySQL data status:")
    try:
        with MYSQL_ENGINE.connect() as conn:
            # Today's row counts
            result = conn.execute(text("""
                SELECT concept_type, COUNT(*) as cnt
                FROM concept_daily_stats
                WHERE trade_date = CURDATE()
                GROUP BY concept_type
            """))
            rows = list(result)
            if rows:
                print("  Today's rows:")
                for row in rows:
                    print(f"    {row[0]}: {row[1]} rows")
            else:
                print("  No data for today yet")

            # Latest update time
            result = conn.execute(text("""
                SELECT MAX(updated_at) FROM concept_daily_stats WHERE trade_date = CURDATE()
            """))
            row = result.fetchone()
            if row and row[0]:
                print(f"  Last update: {row[0]}")
    except Exception as e:
        print(f"  Query failed: {e}")

    # ClickHouse data status
    print("\nClickHouse data status:")
    try:
        client = get_ch_client()
        result = client.execute("""
            SELECT COUNT(*), MAX(timestamp)
            FROM stock_minute
            WHERE toDate(timestamp) = today()
        """)
        if result:
            count, max_ts = result[0]
            print(f"  Today's minute bars: {count:,} rows")
            print(f"  Latest timestamp: {max_ts}")
    except Exception as e:
        print(f"  Query failed: {e}")

    # Today's TOP 10 by change percentage
    print("\nToday's change-percentage TOP 10:")
    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text("""
                SELECT concept_name, avg_change_pct, stock_count, concept_type
                FROM concept_daily_stats
                WHERE trade_date = CURDATE() AND concept_type = 'leaf'
                ORDER BY avg_change_pct DESC
                LIMIT 10
            """))
            rows = list(result)
            if rows:
                print(f"  {'Concept':<25} | {'Chg %':>8} | {'Stocks':>6}")
                print("  " + "-" * 50)
                for row in rows:
                    name = row[0][:25] if len(row[0]) > 25 else row[0]
                    print(f"  {name:<25} | {row[1]:>7.2f}% | {row[2]:>6}")
            else:
                print("  No data yet")
    except Exception as e:
        print(f"  Query failed: {e}")


# ==================== Entry point ====================

def main():
    parser = argparse.ArgumentParser(description='Realtime concept change-percentage update service')
    parser.add_argument('command', nargs='?', default='realtime',
                        choices=['realtime', 'once', 'status'],
                        help='command: realtime (run continuously), once (single run), status (show status)')

    args = parser.parse_args()

    if args.command == 'realtime':
        run_realtime()
    elif args.command == 'once':
        run_single()
    elif args.command == 'status':
        show_status()


if __name__ == "__main__":
    main()
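
# Illustrative invocations (editor's note, not part of the committed file):
#   python concept_quota_realtime.py            # realtime loop (the default)
#   python concept_quota_realtime.py once       # single update pass
#   python concept_quota_realtime.py status     # print current status
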
create_tables.py (89 lines, new file)

@@ -0,0 +1,89 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Create the database tables needed for anomaly detection."""

from sqlalchemy import create_engine, text

engine = create_engine('mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock', echo=False)

# Drop the old tables
drop_sql1 = 'DROP TABLE IF EXISTS concept_minute_alert'
drop_sql2 = 'DROP TABLE IF EXISTS index_minute_snapshot'

# Create the concept_minute_alert table
# Supports Z-Score + SVM smart detection
sql1 = '''
CREATE TABLE concept_minute_alert (
    id BIGINT AUTO_INCREMENT PRIMARY KEY,
    concept_id VARCHAR(32) NOT NULL,
    concept_name VARCHAR(100) NOT NULL,
    alert_time DATETIME NOT NULL,
    alert_type VARCHAR(20) NOT NULL COMMENT 'surge_up=sharp rise, surge_down=sharp drop, limit_up=more limit-ups, rank_jump=rank jump',
    trade_date DATE NOT NULL,
    change_pct DECIMAL(10,4) COMMENT 'current change pct',
    prev_change_pct DECIMAL(10,4) COMMENT 'previous change pct',
    change_delta DECIMAL(10,4) COMMENT 'change-pct delta',
    limit_up_count INT DEFAULT 0 COMMENT 'number of limit-up stocks',
    prev_limit_up_count INT DEFAULT 0,
    limit_up_delta INT DEFAULT 0,
    limit_down_count INT DEFAULT 0 COMMENT 'number of limit-down stocks',
    rank_position INT COMMENT 'current rank',
    prev_rank_position INT COMMENT 'previous rank',
    rank_delta INT COMMENT 'rank delta (negative = moved up)',
    index_code VARCHAR(20) DEFAULT '000001.SH',
    index_price DECIMAL(12,4),
    index_change_pct DECIMAL(10,4),
    stock_count INT,
    concept_type VARCHAR(20) DEFAULT 'leaf',
    zscore DECIMAL(8,4) COMMENT 'Z-Score value',
    importance_score DECIMAL(6,4) COMMENT 'importance score (0-1)',
    extra_info JSON COMMENT 'extended info (zscore, svm_score, etc.)',
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    INDEX idx_trade_date (trade_date),
    INDEX idx_alert_time (alert_time),
    INDEX idx_concept_id (concept_id),
    INDEX idx_alert_type (alert_type),
    INDEX idx_trade_date_time (trade_date, alert_time),
    INDEX idx_importance (importance_score)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='Concept alert table (smart version)'
'''

# Create the index_minute_snapshot table
sql2 = '''
CREATE TABLE index_minute_snapshot (
    id BIGINT AUTO_INCREMENT PRIMARY KEY,
    index_code VARCHAR(20) NOT NULL,
    trade_date DATE NOT NULL,
    snapshot_time DATETIME NOT NULL,
    price DECIMAL(12,4),
    open_price DECIMAL(12,4),
    high_price DECIMAL(12,4),
    low_price DECIMAL(12,4),
    prev_close DECIMAL(12,4),
    change_pct DECIMAL(10,4),
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    UNIQUE KEY uk_index_time (index_code, snapshot_time),
    INDEX idx_trade_date (trade_date)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
'''

if __name__ == '__main__':
    print('Rebuilding database tables...\n')

    with engine.begin() as conn:
        # Drop the old tables first
        print('Dropping old tables...')
        conn.execute(text(drop_sql1))
        print(' - concept_minute_alert dropped')
        conn.execute(text(drop_sql2))
        print(' - index_minute_snapshot dropped')

        # Create the new tables
        print('\nCreating new tables...')
        conn.execute(text(sql1))
        print(' ✅ concept_minute_alert created')
        conn.execute(text(sql2))
        print(' ✅ index_minute_snapshot created')

    print('\n✅ All tables created!')
ml/README.md (112 lines, new file)

@@ -0,0 +1,112 @@
# Concept Anomaly Detection ML Module

A concept anomaly (abnormal-move) detection system based on a Transformer autoencoder.

## Requirements

- Python 3.8+
- PyTorch 2.0+ (CUDA 12.x for a 5090 GPU)
- ClickHouse, MySQL, Elasticsearch

## Database configuration

Current settings (`prepare_data.py`):
- MySQL: `192.168.1.5:3306`
- Elasticsearch: `127.0.0.1:9200`
- ClickHouse: `127.0.0.1:9000`

## Quick start

```bash
# 1. Install dependencies
pip install -r ml/requirements.txt

# 2. Install PyTorch (the 5090 needs CUDA 12.4)
pip install torch --index-url https://download.pytorch.org/whl/cu124

# 3. Run training
chmod +x ml/run_training.sh
./ml/run_training.sh
```

## Files

| File | Description |
|------|------|
| `model.py` | Transformer autoencoder model definition |
| `prepare_data.py` | Data extraction and feature computation |
| `train.py` | Model training script |
| `inference.py` | Inference service |
| `enhanced_detector.py` | Enhanced detector (fuses Alpha + ML) |

## Training options

```bash
# Full set of options
./ml/run_training.sh --start 2022-01-01 --end 2024-12-01 --epochs 100 --batch_size 256

# Prepare data only
python ml/prepare_data.py --start 2022-01-01

# Train only (data already prepared)
python ml/train.py --epochs 100 --batch_size 256 --lr 1e-4
```

## Model architecture

```
Input: (batch, 30, 6)   # 30-minute sequence, 6 features
    ↓
Positional Encoding
    ↓
Transformer Encoder (4 layers, 8 heads, d=128)
    ↓
Bottleneck (compressed to 32 dims)
    ↓
Transformer Decoder (4 layers)
    ↓
Output: (batch, 30, 6)  # reconstructed sequence

Anomaly rule: reconstruction_error > threshold
```
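
For orientation, here is an editor's sketch of that shape in plain PyTorch, with the positional encoding omitted and the decoder approximated by a second encoder stack; the real definition lives in `model.py` and may differ in detail:

```python
import torch
import torch.nn as nn

class TinyTransformerAE(nn.Module):
    def __init__(self, n_features=6, d_model=128, n_heads=8, n_layers=4, bottleneck=32):
        super().__init__()
        self.in_proj = nn.Linear(n_features, d_model)
        enc = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, n_layers)
        self.down = nn.Linear(d_model, bottleneck)   # bottleneck: 128 -> 32
        self.up = nn.Linear(bottleneck, d_model)
        dec = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.decoder = nn.TransformerEncoder(dec, n_layers)
        self.out_proj = nn.Linear(d_model, n_features)

    def forward(self, x):                      # x: (batch, 30, 6)
        h = self.encoder(self.in_proj(x))      # (batch, 30, 128)
        h = self.up(self.down(h))              # squeeze through the bottleneck
        return self.out_proj(self.decoder(h))  # (batch, 30, 6)

x = torch.randn(4, 30, 6)
recon = TinyTransformerAE()(x)
err = ((recon - x) ** 2).mean(dim=(1, 2))      # per-sequence reconstruction error
print(err.shape)                               # torch.Size([4])
```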

## The 6 features

1. `alpha` - excess return (concept change pct minus market change pct)
2. `alpha_delta` - 5-minute change in alpha
3. `amt_ratio` - turnover / its 20-minute mean
4. `amt_delta` - turnover change rate
5. `rank_pct` - percentile rank of alpha
6. `limit_up_ratio` - share of limit-up stocks
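
As a rough sketch of how these per-minute features might be derived (editor's illustration; the real logic lives in `prepare_data.py`, and the input column names here are assumptions):

```python
import pandas as pd

def concept_features(df: pd.DataFrame) -> pd.DataFrame:
    """df: one concept's minute bars with assumed columns
    ['change_pct', 'index_change_pct', 'amt', 'limit_up_count', 'stock_count']."""
    out = pd.DataFrame(index=df.index)
    out['alpha'] = df['change_pct'] - df['index_change_pct']     # excess return
    out['alpha_delta'] = out['alpha'].diff(5)                    # 5-minute change
    out['amt_ratio'] = df['amt'] / df['amt'].rolling(20).mean()  # vs 20-minute mean
    out['amt_delta'] = df['amt'].pct_change()                    # turnover change rate
    out['limit_up_ratio'] = df['limit_up_count'] / df['stock_count']
    # rank_pct is cross-sectional: at each minute, rank alpha across all
    # concepts (e.g. alpha_by_concept.rank(pct=True)), so it is not shown here.
    return out
```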

## Training artifacts

After training, `ml/checkpoints/` contains:
- `best_model.pt` - best model weights
- `thresholds.json` - anomaly thresholds (P90/P95/P99)
- `normalization_stats.json` - data normalization parameters
- `config.json` - training configuration

## Usage example

```python
from ml.inference import ConceptAnomalyDetector

detector = ConceptAnomalyDetector('ml/checkpoints')

# Real-time detection
is_anomaly, score = detector.detect(
    concept_name="人工智能",
    features={
        'alpha': 2.5,
        'alpha_delta': 0.8,
        'amt_ratio': 1.5,
        'amt_delta': 0.3,
        'rank_pct': 0.95,
        'limit_up_ratio': 0.15,
    }
)

if is_anomaly:
    print(f"Anomaly detected! score: {score}")
```
ml/__init__.py (10 lines, new file)

@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
"""
Concept anomaly detection ML module

Provides Transformer-autoencoder-based anomaly detection
"""

from .inference import ConceptAnomalyDetector, MLAnomalyService

__all__ = ['ConceptAnomalyDetector', 'MLAnomalyService']
ml/backtest.py (481 lines, new file)

@@ -0,0 +1,481 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Historical anomaly backtest script

Runs the trained model over historical data and generates alert records.

Usage:
    # Backtest a date range
    python backtest.py --start 2024-01-01 --end 2024-12-01

    # Backtest a single day
    python backtest.py --start 2024-11-01 --end 2024-11-01

    # Generate results only, without writing to the database
    python backtest.py --start 2024-01-01 --dry-run
"""

import os
import sys
import argparse
import json
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sqlalchemy import create_engine, text

# Add the parent directory to the path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model import TransformerAutoencoder


# ==================== Configuration ====================

MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)

# Feature list (must match training)
FEATURES = [
    'alpha',
    'alpha_delta',
    'amt_ratio',
    'amt_delta',
    'rank_pct',
    'limit_up_ratio',
]

# Backtest configuration
BACKTEST_CONFIG = {
    'seq_len': 30,                # sequence length
    'threshold_key': 'p95',       # which threshold to use
    'min_alpha_abs': 0.5,         # minimum |alpha| (filters tiny moves)
    'cooldown_minutes': 8,        # cooldown per concept
    'max_alerts_per_minute': 15,  # max alerts per minute
    'clip_value': 10.0,           # clip extreme values
}


# ==================== Model loading ====================

class AnomalyDetector:
    """Anomaly detector."""

    def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
        self.checkpoint_dir = Path(checkpoint_dir)

        # Device
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        # Load the config
        self._load_config()

        # Load the model
        self._load_model()

        # Load the thresholds
        self._load_thresholds()

        print("AnomalyDetector initialized")
        print(f"  device: {self.device}")
        print(f"  threshold ({BACKTEST_CONFIG['threshold_key']}): {self.threshold:.6f}")

    def _load_config(self):
        config_path = self.checkpoint_dir / 'config.json'
        with open(config_path, 'r') as f:
            self.config = json.load(f)

    def _load_model(self):
        model_path = self.checkpoint_dir / 'best_model.pt'
        checkpoint = torch.load(model_path, map_location=self.device)

        model_config = self.config['model'].copy()
        model_config['use_instance_norm'] = self.config.get('use_instance_norm', True)

        self.model = TransformerAutoencoder(**model_config)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

    def _load_thresholds(self):
        thresholds_path = self.checkpoint_dir / 'thresholds.json'
        with open(thresholds_path, 'r') as f:
            thresholds = json.load(f)

        self.threshold = thresholds[BACKTEST_CONFIG['threshold_key']]

    @torch.no_grad()
    def compute_anomaly_scores(self, sequences: np.ndarray) -> np.ndarray:
        """
        Compute anomaly scores.

        Args:
            sequences: (n_sequences, seq_len, n_features)
        Returns:
            scores: (n_sequences,) anomaly score at the last time step of each sequence
        """
        # Clip extreme values
        sequences = np.clip(sequences, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])

        # To tensor
        x = torch.FloatTensor(sequences).to(self.device)

        # Reconstruction errors
        errors = self.model.compute_reconstruction_error(x, reduction='none')

        # Take the error at the last time step
        scores = errors[:, -1].cpu().numpy()

        return scores

    def is_anomaly(self, score: float) -> bool:
        """Return True if the score crosses the threshold."""
        return score > self.threshold


# ==================== Data loading ====================

def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
    """Load one day's feature data."""
    file_path = Path(data_dir) / f"features_{date}.parquet"

    if not file_path.exists():
        return None

    df = pd.read_parquet(file_path)
    return df


def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
    """List the dates with available feature files."""
    data_path = Path(data_dir)
    all_files = sorted(data_path.glob("features_*.parquet"))

    dates = []
    for f in all_files:
        date = f.stem.replace('features_', '')
        if start_date <= date <= end_date:
            dates.append(date)

    return dates


# ==================== Backtest logic ====================

def backtest_single_day(
    detector: AnomalyDetector,
    df: pd.DataFrame,
    date: str,
    seq_len: int = 30
) -> List[Dict]:
    """
    Backtest a single day.

    Args:
        detector: the anomaly detector
        df: the day's feature data
        date: the date
        seq_len: sequence length

    Returns:
        alerts: list of alerts
    """
    alerts = []

    # Group by concept
    grouped = df.groupby('concept_id', sort=False)

    # Cooldown book-keeping {concept_id: last_alert_timestamp}
    cooldown = {}

    # All time points
    all_timestamps = sorted(df['timestamp'].unique())

    if len(all_timestamps) < seq_len:
        return alerts

    # Detect at each time point (starting from the seq_len-th)
    for t_idx in range(seq_len - 1, len(all_timestamps)):
        current_time = all_timestamps[t_idx]
        window_start_time = all_timestamps[t_idx - seq_len + 1]

        minute_alerts = []

        # Collect every concept's sequence at this time point
        concept_sequences = []
        concept_infos = []

        for concept_id, concept_df in grouped:
            # The concept's rows inside the time window
            mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
            window_df = concept_df[mask].sort_values('timestamp')

            if len(window_df) < seq_len:
                continue

            # Keep the last seq_len points
            window_df = window_df.tail(seq_len)

            # Extract features
            features = window_df[FEATURES].values

            # Handle missing values
            features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

            # Info at the current time step
            current_row = window_df.iloc[-1]

            concept_sequences.append(features)
            concept_infos.append({
                'concept_id': concept_id,
                'timestamp': current_time,
                'alpha': current_row.get('alpha', 0),
                'alpha_delta': current_row.get('alpha_delta', 0),
                'amt_ratio': current_row.get('amt_ratio', 1),
                'limit_up_ratio': current_row.get('limit_up_ratio', 0),
                'limit_down_ratio': current_row.get('limit_down_ratio', 0),
                'rank_pct': current_row.get('rank_pct', 0.5),
                'stock_count': current_row.get('stock_count', 0),
                'total_amt': current_row.get('total_amt', 0),
            })

        if not concept_sequences:
            continue

        # Score all sequences in one batch
        sequences_array = np.array(concept_sequences)
        scores = detector.compute_anomaly_scores(sequences_array)

        # Detect anomalies
        for i, (info, score) in enumerate(zip(concept_infos, scores)):
            concept_id = info['concept_id']
            alpha = info['alpha']

            # Filter tiny moves
            if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']:
                continue

            # Cooldown check
            if concept_id in cooldown:
                last_alert = cooldown[concept_id]
                if isinstance(current_time, datetime):
                    time_diff = (current_time - last_alert).total_seconds() / 60
                else:
                    # timestamp is a string or some other format
                    time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1  # skip the cooldown check

                if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
                    continue

            # Anomaly?
            if not detector.is_anomaly(score):
                continue

            # Record the alert
            alert_type = 'surge_up' if alpha > 0 else 'surge_down'

            alert = {
                'concept_id': concept_id,
                'alert_time': current_time,
                'trade_date': date,
                'alert_type': alert_type,
                'anomaly_score': float(score),
                'threshold': detector.threshold,
                **info
            }

            minute_alerts.append(alert)
            cooldown[concept_id] = current_time

        # Sort by score, cap the count
        minute_alerts.sort(key=lambda x: x['anomaly_score'], reverse=True)
        alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])

    return alerts


# ==================== Database writes ====================

def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Save alerts to MySQL."""
    if not alerts:
        return 0

    if dry_run:
        print(f"  [Dry Run] would write {len(alerts)} alerts")
        return len(alerts)

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                # Skip if it already exists
                check_sql = text("""
                    SELECT id FROM concept_minute_alert
                    WHERE concept_id = :concept_id
                      AND alert_time = :alert_time
                      AND trade_date = :trade_date
                """)
                exists = conn.execute(check_sql, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                }).fetchone()

                if exists:
                    continue

                # Insert the new record
                insert_sql = text("""
                    INSERT INTO concept_minute_alert
                    (concept_id, concept_name, alert_time, alert_type, trade_date,
                     change_pct, zscore, importance_score, stock_count, extra_info)
                    VALUES
                    (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
                     :change_pct, :zscore, :importance_score, :stock_count, :extra_info)
                """)

                conn.execute(insert_sql, {
                    'concept_id': alert['concept_id'],
                    'concept_name': alert.get('concept_name', ''),
                    'alert_time': alert['alert_time'],
                    'alert_type': alert['alert_type'],
                    'trade_date': alert['trade_date'],
                    'change_pct': alert.get('alpha', 0),
                    'zscore': alert['anomaly_score'],
                    'importance_score': alert['anomaly_score'],
                    'stock_count': alert.get('stock_count', 0),
                    'extra_info': json.dumps({
                        'detection_method': 'ml_autoencoder',
                        'threshold': alert['threshold'],
                        'alpha': alert.get('alpha', 0),
                        'amt_ratio': alert.get('amt_ratio', 1),
                    }, ensure_ascii=False)
                })

                saved += 1

            except Exception as e:
                print(f"  Save failed: {alert['concept_id']} - {e}")

    return saved


def export_alerts_to_csv(alerts: List[Dict], output_path: str):
    """Export alerts to CSV."""
    if not alerts:
        return

    df = pd.DataFrame(alerts)
    df.to_csv(output_path, index=False, encoding='utf-8-sig')
    print(f"Exported to: {output_path}")


# ==================== Entry point ====================

def main():
    parser = argparse.ArgumentParser(description='Historical anomaly backtest')
    parser.add_argument('--data_dir', type=str, default='ml/data',
                        help='feature data directory')
    parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
                        help='model checkpoint directory')
    parser.add_argument('--start', type=str, required=True,
                        help='start date (YYYY-MM-DD)')
    parser.add_argument('--end', type=str, required=True,
                        help='end date (YYYY-MM-DD)')
    parser.add_argument('--dry-run', action='store_true',
                        help='compute only, do not write to the database')
    parser.add_argument('--export-csv', type=str, default=None,
                        help='CSV export path')
    parser.add_argument('--device', type=str, default='auto',
                        help='device (auto/cuda/cpu)')

    args = parser.parse_args()

    print("=" * 60)
    print("Historical anomaly backtest")
    print("=" * 60)
    print(f"Date range: {args.start} ~ {args.end}")
    print(f"Data dir:   {args.data_dir}")
    print(f"Model dir:  {args.checkpoint_dir}")
    print(f"Dry run:    {args.dry_run}")
    print("=" * 60)

    # Initialize the detector
    detector = AnomalyDetector(args.checkpoint_dir, args.device)

    # Available dates
    dates = get_available_dates(args.data_dir, args.start, args.end)

    if not dates:
        print(f"No data found in {args.start} ~ {args.end}")
        return

    print(f"\nFound {len(dates)} days of data")

    # Backtest
    all_alerts = []
    total_saved = 0

    for date in tqdm(dates, desc="Backtest progress"):
        # Load the data
        df = load_daily_features(args.data_dir, date)

        if df is None or df.empty:
            continue

        # Backtest the day
        alerts = backtest_single_day(
            detector, df, date,
            seq_len=BACKTEST_CONFIG['seq_len']
        )

        if alerts:
            all_alerts.extend(alerts)

            # Write to the database
            saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
            total_saved += saved

            if not args.dry_run:
                tqdm.write(f"  {date}: detected {len(alerts)} alerts, saved {saved}")

    # Export CSV
    if args.export_csv and all_alerts:
        export_alerts_to_csv(all_alerts, args.export_csv)

    # Summary
    print("\n" + "=" * 60)
    print("Backtest complete!")
    print("=" * 60)
    print(f"Total alerts detected: {len(all_alerts)}")
    print(f"Saved to database: {total_saved}")

    # Statistics
    if all_alerts:
        df_alerts = pd.DataFrame(all_alerts)
        print("\nAlert type distribution:")
        print(df_alerts['alert_type'].value_counts())

        print("\nAnomaly score statistics:")
        print(f"  Mean: {df_alerts['anomaly_score'].mean():.4f}")
        print(f"  Max:  {df_alerts['anomaly_score'].max():.4f}")
        print(f"  Min:  {df_alerts['anomaly_score'].min():.4f}")

    print("=" * 60)


if __name__ == "__main__":
    main()
ml/backtest_hybrid.py (418 lines, new file)

@@ -0,0 +1,418 @@
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
融合异动回测脚本
|
||||
|
||||
使用 HybridAnomalyDetector 进行回测:
|
||||
- 规则评分 + LSTM Autoencoder 融合判断
|
||||
- 输出更丰富的异动信息
|
||||
|
||||
使用方法:
|
||||
python backtest_hybrid.py --start 2024-01-01 --end 2024-12-01
|
||||
python backtest_hybrid.py --start 2024-11-01 --dry-run
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
# 添加父目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from detector import HybridAnomalyDetector, create_detector
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
MYSQL_ENGINE = create_engine(
|
||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
||||
echo=False
|
||||
)
|
||||
|
||||
FEATURES = [
|
||||
'alpha',
|
||||
'alpha_delta',
|
||||
'amt_ratio',
|
||||
'amt_delta',
|
||||
'rank_pct',
|
||||
'limit_up_ratio',
|
||||
]
|
||||
|
||||
BACKTEST_CONFIG = {
|
||||
'seq_len': 30,
|
||||
'min_alpha_abs': 0.3, # 降低阈值,让规则也能发挥作用
|
||||
'cooldown_minutes': 8,
|
||||
'max_alerts_per_minute': 20,
|
||||
'clip_value': 10.0,
|
||||
}
|
||||
|
||||
|
||||
# ==================== 数据加载 ====================
|
||||
|
||||
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
|
||||
"""加载单天的特征数据"""
|
||||
file_path = Path(data_dir) / f"features_{date}.parquet"
|
||||
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
df = pd.read_parquet(file_path)
|
||||
return df
|
||||
|
||||
|
||||
def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
|
||||
"""获取可用的日期列表"""
|
||||
data_path = Path(data_dir)
|
||||
all_files = sorted(data_path.glob("features_*.parquet"))
|
||||
|
||||
dates = []
|
||||
for f in all_files:
|
||||
date = f.stem.replace('features_', '')
|
||||
if start_date <= date <= end_date:
|
||||
dates.append(date)
|
||||
|
||||
return dates

# ==================== Hybrid backtest ====================

def backtest_single_day_hybrid(
    detector: HybridAnomalyDetector,
    df: pd.DataFrame,
    date: str,
    seq_len: int = 30
) -> List[Dict]:
    """
    Backtest a single day with the hybrid detector.
    """
    alerts = []

    # Group by concept
    grouped = df.groupby('concept_id', sort=False)

    # Cooldown bookkeeping
    cooldown = {}

    # All timestamps of the day
    all_timestamps = sorted(df['timestamp'].unique())

    if len(all_timestamps) < seq_len:
        return alerts

    # Run detection at every timestamp
    for t_idx in range(seq_len - 1, len(all_timestamps)):
        current_time = all_timestamps[t_idx]
        window_start_time = all_timestamps[t_idx - seq_len + 1]

        minute_alerts = []

        for concept_id, concept_df in grouped:
            # Slice the data inside the time window
            mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
            window_df = concept_df[mask].sort_values('timestamp')

            if len(window_df) < seq_len:
                continue

            window_df = window_df.tail(seq_len)

            # Feature sequence for the ML model
            sequence = window_df[FEATURES].values
            sequence = np.nan_to_num(sequence, nan=0.0, posinf=0.0, neginf=0.0)
            sequence = np.clip(sequence, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])

            # Current-minute features for the rule system
            current_row = window_df.iloc[-1]
            current_features = {
                'alpha': current_row.get('alpha', 0),
                'alpha_delta': current_row.get('alpha_delta', 0),
                'amt_ratio': current_row.get('amt_ratio', 1),
                'amt_delta': current_row.get('amt_delta', 0),
                'rank_pct': current_row.get('rank_pct', 0.5),
                'limit_up_ratio': current_row.get('limit_up_ratio', 0),
            }

            # Skip tiny moves
            if abs(current_features['alpha']) < BACKTEST_CONFIG['min_alpha_abs']:
                continue

            # Cooldown check
            if concept_id in cooldown:
                last_alert = cooldown[concept_id]
                if isinstance(current_time, datetime):
                    time_diff = (current_time - last_alert).total_seconds() / 60
                else:
                    time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1

                if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
                    continue

            # Hybrid detection
            result = detector.detect(current_features, sequence)

            if not result.is_anomaly:
                continue

            # Record the alert
            alert = {
                'concept_id': concept_id,
                'alert_time': current_time,
                'trade_date': date,
                'alert_type': result.anomaly_type,
                'final_score': result.final_score,
                'rule_score': result.rule_score,
                'ml_score': result.ml_score,
                'trigger_reason': result.trigger_reason,
                'triggered_rules': list(result.rule_details.keys()),
                **current_features,
                'stock_count': current_row.get('stock_count', 0),
                'total_amt': current_row.get('total_amt', 0),
            }

            minute_alerts.append(alert)
            cooldown[concept_id] = current_time

        # Keep the highest-scoring alerts of this minute
        minute_alerts.sort(key=lambda x: x['final_score'], reverse=True)
        alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])

    return alerts
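
# Editor's sketch (not part of the original commit): driving the single-day
# backtest on a synthetic one-concept frame. Column names follow FEATURES plus
# concept_id/timestamp; the ramp values are hypothetical.
def _example_single_day_run(detector: HybridAnomalyDetector) -> List[Dict]:
    times = pd.date_range('2024-01-02 09:31', periods=60, freq='min')
    df = pd.DataFrame({
        'timestamp': times,
        'concept_id': ['demo'] * 60,
        'alpha': np.linspace(0.0, 4.0, 60),      # ramping excess return
        'alpha_delta': [0.07] * 60,
        'amt_ratio': np.linspace(1.0, 5.0, 60),  # ramping turnover ratio
        'amt_delta': [0.1] * 60,
        'rank_pct': np.linspace(0.5, 0.99, 60),
        'limit_up_ratio': [0.0] * 60,
    })
    return backtest_single_day_hybrid(detector, df, '2024-01-02', seq_len=30)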

# ==================== Database writes ====================

def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Persist alerts to MySQL (extended field set)."""
    if not alerts:
        return 0

    if dry_run:
        print(f"  [Dry Run] would write {len(alerts)} alerts")
        return len(alerts)

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                # Skip if the alert already exists
                check_sql = text("""
                    SELECT id FROM concept_minute_alert
                    WHERE concept_id = :concept_id
                      AND alert_time = :alert_time
                      AND trade_date = :trade_date
                """)
                exists = conn.execute(check_sql, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                }).fetchone()

                if exists:
                    continue

                # Insert a new record
                insert_sql = text("""
                    INSERT INTO concept_minute_alert
                    (concept_id, concept_name, alert_time, alert_type, trade_date,
                     change_pct, zscore, importance_score, stock_count, extra_info)
                    VALUES
                    (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
                     :change_pct, :zscore, :importance_score, :stock_count, :extra_info)
                """)

                extra_info = {
                    'detection_method': 'hybrid',
                    'final_score': alert['final_score'],
                    'rule_score': alert['rule_score'],
                    'ml_score': alert['ml_score'],
                    'trigger_reason': alert['trigger_reason'],
                    'triggered_rules': alert['triggered_rules'],
                    'alpha': alert.get('alpha', 0),
                    'alpha_delta': alert.get('alpha_delta', 0),
                    'amt_ratio': alert.get('amt_ratio', 1),
                }

                conn.execute(insert_sql, {
                    'concept_id': alert['concept_id'],
                    'concept_name': alert.get('concept_name', ''),
                    'alert_time': alert['alert_time'],
                    'alert_type': alert['alert_type'],
                    'trade_date': alert['trade_date'],
                    'change_pct': alert.get('alpha', 0),
                    'zscore': alert['final_score'],  # reuse the final score as zscore
                    'importance_score': alert['final_score'],
                    'stock_count': alert.get('stock_count', 0),
                    'extra_info': json.dumps(extra_info, ensure_ascii=False)
                })

                saved += 1

            except Exception as e:
                print(f"  save failed: {alert['concept_id']} - {e}")

    return saved


def export_alerts_to_csv(alerts: List[Dict], output_path: str):
    """Export alerts to CSV."""
    if not alerts:
        return

    df = pd.DataFrame(alerts)
    df.to_csv(output_path, index=False, encoding='utf-8-sig')
    print(f"Exported to: {output_path}")


# ==================== Statistics ====================

def analyze_alerts(alerts: List[Dict]):
    """Analyze the backtest alerts."""
    if not alerts:
        print("No alerts")
        return

    df = pd.DataFrame(alerts)

    print("\n" + "=" * 60)
    print("Alert statistics")
    print("=" * 60)

    # 1. Basic counts
    print(f"\nTotal alerts: {len(alerts)}")

    # 2. By alert type
    print("\nAlert type distribution:")
    print(df['alert_type'].value_counts())

    # 3. Score statistics
    print("\nScore statistics:")
    print(f"  final score - mean: {df['final_score'].mean():.1f}, max: {df['final_score'].max():.1f}")
    print(f"  rule score  - mean: {df['rule_score'].mean():.1f}, max: {df['rule_score'].max():.1f}")
    print(f"  ML score    - mean: {df['ml_score'].mean():.1f}, max: {df['ml_score'].max():.1f}")

    # 4. Trigger source breakdown
    print("\nTrigger source breakdown:")
    trigger_counts = df['trigger_reason'].apply(
        lambda x: 'rule' if 'rule' in x else ('ML' if 'ML' in x else 'fusion')
    ).value_counts()
    print(trigger_counts)

    # 5. Rule trigger frequency
    all_rules = []
    for rules in df['triggered_rules']:
        if isinstance(rules, list):
            all_rules.extend(rules)

    if all_rules:
        print("\nMost frequently triggered rules (top 10):")
        from collections import Counter
        rule_counts = Counter(all_rules)
        for rule, count in rule_counts.most_common(10):
            print(f"  {rule}: {count}")


# ==================== Main ====================

def main():
    parser = argparse.ArgumentParser(description='Hybrid anomaly backtest')
    parser.add_argument('--data_dir', type=str, default='ml/data',
                        help='feature data directory')
    parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
                        help='model checkpoint directory')
    parser.add_argument('--start', type=str, required=True,
                        help='start date (YYYY-MM-DD)')
    parser.add_argument('--end', type=str, default=None,
                        help='end date (YYYY-MM-DD); defaults to --start')
    parser.add_argument('--dry-run', action='store_true',
                        help='compute only, do not write to the database')
    parser.add_argument('--export-csv', type=str, default=None,
                        help='CSV export path')
    parser.add_argument('--rule-weight', type=float, default=0.6,
                        help='rule weight (0-1)')
    parser.add_argument('--ml-weight', type=float, default=0.4,
                        help='ML weight (0-1)')

    args = parser.parse_args()

    if args.end is None:
        args.end = args.start

    print("=" * 60)
    print("Hybrid anomaly backtest (rules + LSTM)")
    print("=" * 60)
    print(f"Date range: {args.start} ~ {args.end}")
    print(f"Data dir: {args.data_dir}")
    print(f"Model dir: {args.checkpoint_dir}")
    print(f"Rule weight: {args.rule_weight}")
    print(f"ML weight: {args.ml_weight}")
    print(f"Dry run: {args.dry_run}")
    print("=" * 60)

    # Initialize the hybrid detector
    config = {
        'rule_weight': args.rule_weight,
        'ml_weight': args.ml_weight,
    }
    detector = create_detector(args.checkpoint_dir, config)

    # Available dates
    dates = get_available_dates(args.data_dir, args.start, args.end)

    if not dates:
        print(f"No data found in the range {args.start} ~ {args.end}")
        return

    print(f"\nFound {len(dates)} day(s) of data")

    # Backtest
    all_alerts = []
    total_saved = 0

    for date in tqdm(dates, desc="backtest"):
        df = load_daily_features(args.data_dir, date)

        if df is None or df.empty:
            continue

        alerts = backtest_single_day_hybrid(
            detector, df, date,
            seq_len=BACKTEST_CONFIG['seq_len']
        )

        if alerts:
            all_alerts.extend(alerts)

            saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
            total_saved += saved

            if not args.dry_run:
                tqdm.write(f"  {date}: detected {len(alerts)} alerts, saved {saved}")

    # CSV export
    if args.export_csv and all_alerts:
        export_alerts_to_csv(all_alerts, args.export_csv)

    # Statistics
    analyze_alerts(all_alerts)

    # Summary
    print("\n" + "=" * 60)
    print("Backtest finished!")
    print("=" * 60)
    print(f"Total alerts detected: {len(all_alerts)}")
    print(f"Saved to database: {total_saved}")
    print("=" * 60)


if __name__ == "__main__":
    main()
571
ml/detector.py
Normal file
@@ -0,0 +1,571 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Concept anomaly detector - hybrid edition.

Combines the strengths of two approaches:
1. Rule-based scoring: interpretable, stable, covers known patterns
2. LSTM Autoencoder: discovers unknown anomaly patterns

Fusion strategy:
    ┌─────────────────────────────────────────────────────────┐
    │ Input features                                          │
    │ (alpha, alpha_delta, amt_ratio, amt_delta, rank_pct,    │
    │  limit_up_ratio)                                        │
    ├─────────────────────────────────────────────────────────┤
    │                                                         │
    │   ┌──────────────┐        ┌──────────────────┐          │
    │   │ Rule scorer  │        │ LSTM Autoencoder │          │
    │   │ (0-100 pts)  │        │ (recon. error)   │          │
    │   └──────┬───────┘        └────────┬─────────┘          │
    │          │                         │                    │
    │          ▼                         ▼                    │
    │   rule_score (0-100)     ml_score (normalized 0-100)    │
    │                                                         │
    ├─────────────────────────────────────────────────────────┤
    │ Fusion                                                  │
    │                                                         │
    │   final_score = w1 * rule_score + w2 * ml_score         │
    │                                                         │
    │ Alert decision:                                         │
    │   - rule_score >= 60  -> fire directly (strong rule)    │
    │   - ml_score >= 80    -> fire directly (strong ML)      │
    │   - final_score >= 50 -> fusion trigger                 │
    │                                                         │
    └─────────────────────────────────────────────────────────┘

Advantages:
- the rule system guarantees recall on known patterns
- the ML model catches anomalies the rules do not cover
- the two cross-validate each other, reducing false positives
"""

import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass

import numpy as np
import torch

# Try to import the model (it may not exist yet)
try:
    from model import LSTMAutoencoder, create_model
    HAS_MODEL = True
except ImportError:
    HAS_MODEL = False


@dataclass
class AnomalyResult:
    """Anomaly detection result."""
    is_anomaly: bool
    final_score: float   # final score (0-100)
    rule_score: float    # rule score (0-100)
    ml_score: float      # ML score (0-100)
    trigger_reason: str  # why the alert fired
    rule_details: Dict   # per-rule breakdown
    anomaly_type: str    # surge_up / surge_down / volume_spike / unknown


class RuleBasedScorer:
    """
    Rule-based scoring system.

    Design principles:
    - each rule scores independently
    - scores are additive
    - thresholds are configurable
    """

    # Default rule set
    DEFAULT_RULES = {
        # Alpha (excess return)
        'alpha_strong': {
            'condition': lambda r: abs(r.get('alpha', 0)) >= 3.0,
            'score': 35,
            'description': 'strong alpha (|α|≥3%)'
        },
        'alpha_medium': {
            'condition': lambda r: 2.0 <= abs(r.get('alpha', 0)) < 3.0,
            'score': 25,
            'description': 'medium alpha (2%≤|α|<3%)'
        },
        'alpha_weak': {
            'condition': lambda r: 1.5 <= abs(r.get('alpha', 0)) < 2.0,
            'score': 15,
            'description': 'mild alpha (1.5%≤|α|<2%)'
        },

        # Alpha rate of change (acceleration)
        'alpha_delta_strong': {
            'condition': lambda r: abs(r.get('alpha_delta', 0)) >= 1.0,
            'score': 30,
            'description': 'strong alpha acceleration (|Δα|≥1%)'
        },
        'alpha_delta_medium': {
            'condition': lambda r: 0.5 <= abs(r.get('alpha_delta', 0)) < 1.0,
            'score': 20,
            'description': 'medium alpha acceleration (0.5%≤|Δα|<1%)'
        },

        # Turnover ratio (volume spikes)
        'volume_spike_strong': {
            'condition': lambda r: r.get('amt_ratio', 1) >= 5.0,
            'score': 30,
            'description': 'extreme volume (≥5x)'
        },
        'volume_spike_medium': {
            'condition': lambda r: 3.0 <= r.get('amt_ratio', 1) < 5.0,
            'score': 20,
            'description': 'significant volume (3-5x)'
        },
        'volume_spike_weak': {
            'condition': lambda r: 2.0 <= r.get('amt_ratio', 1) < 3.0,
            'score': 10,
            'description': 'mild volume (2-3x)'
        },

        # Turnover rate of change
        'amt_delta_strong': {
            'condition': lambda r: abs(r.get('amt_delta', 0)) >= 1.0,
            'score': 15,
            'description': 'sharp turnover change (|Δamt|≥100%)'
        },

        # Rank jumps
        'rank_top': {
            'condition': lambda r: r.get('rank_pct', 0.5) >= 0.95,
            'score': 25,
            'description': 'rank in top 5%'
        },
        'rank_bottom': {
            'condition': lambda r: r.get('rank_pct', 0.5) <= 0.05,
            'score': 25,
            'description': 'rank in bottom 5%'
        },
        'rank_high': {
            'condition': lambda r: 0.9 <= r.get('rank_pct', 0.5) < 0.95,
            'score': 15,
            'description': 'rank in top 10%'
        },

        # Limit-up ratio
        'limit_up_high': {
            'condition': lambda r: r.get('limit_up_ratio', 0) >= 0.2,
            'score': 25,
            'description': 'limit-up ratio ≥20%'
        },
        'limit_up_medium': {
            'condition': lambda r: 0.1 <= r.get('limit_up_ratio', 0) < 0.2,
            'score': 15,
            'description': 'limit-up ratio 10-20%'
        },

        # Combined conditions (more reliable signals)
        'alpha_with_volume': {
            'condition': lambda r: abs(r.get('alpha', 0)) >= 1.5 and r.get('amt_ratio', 1) >= 2.0,
            'score': 20,  # bonus points
            'description': 'alpha + volume combo'
        },
        'acceleration_with_rank': {
            'condition': lambda r: abs(r.get('alpha_delta', 0)) >= 0.5 and (r.get('rank_pct', 0.5) >= 0.9 or r.get('rank_pct', 0.5) <= 0.1),
            'score': 15,  # bonus points
            'description': 'acceleration + abnormal rank combo'
        },
    }

    def __init__(self, rules: Dict = None):
        """
        Initialize the rule scorer.

        Args:
            rules: custom rules in the same format as DEFAULT_RULES
        """
        self.rules = rules or self.DEFAULT_RULES

    def score(self, features: Dict) -> Tuple[float, Dict]:
        """
        Compute the rule score.

        Args:
            features: feature dict with alpha, alpha_delta, amt_ratio, etc.
        Returns:
            score: total score (0-100)
            details: triggered rule breakdown
        """
        total_score = 0
        triggered_rules = {}

        for rule_name, rule_config in self.rules.items():
            try:
                if rule_config['condition'](features):
                    total_score += rule_config['score']
                    triggered_rules[rule_name] = {
                        'score': rule_config['score'],
                        'description': rule_config['description']
                    }
            except Exception:
                # Ignore rule evaluation errors
                pass

        # Clamp to 0-100
        total_score = min(100, max(0, total_score))

        return total_score, triggered_rules

    def get_anomaly_type(self, features: Dict) -> str:
        """Classify the anomaly type."""
        alpha = features.get('alpha', 0)
        amt_ratio = features.get('amt_ratio', 1)

        if alpha >= 1.5:
            return 'surge_up'
        elif alpha <= -1.5:
            return 'surge_down'
        elif amt_ratio >= 3.0:
            return 'volume_spike'
        else:
            return 'unknown'
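
# Editor's sketch (not part of the original commit): scoring one hypothetical
# snapshot with RuleBasedScorer to show how the additive rules combine and cap.
def _example_rule_scoring() -> Tuple[float, Dict]:
    scorer = RuleBasedScorer()
    features = {'alpha': 3.2, 'alpha_delta': 0.6, 'amt_ratio': 4.5,
                'amt_delta': 0.8, 'rank_pct': 0.97, 'limit_up_ratio': 0.12}
    # alpha_strong (35) + alpha_delta_medium (20) + volume_spike_medium (20)
    # + rank_top (25) + limit_up_medium (15) + alpha_with_volume (20)
    # + acceleration_with_rank (15) = 150 -> capped at 100
    return scorer.score(features)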

class MLScorer:
    """
    LSTM-Autoencoder-based scorer.

    Converts reconstruction error into a 0-100 score.
    """

    def __init__(
        self,
        checkpoint_dir: str = 'ml/checkpoints',
        device: str = 'auto'
    ):
        self.checkpoint_dir = Path(checkpoint_dir)

        # Device
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        self.model = None
        self.thresholds = None
        self.config = None

        # Try to load the model
        self._load_model()

    def _load_model(self):
        """Load the model and thresholds."""
        if not HAS_MODEL:
            print("Warning: cannot import the model module")
            return

        model_path = self.checkpoint_dir / 'best_model.pt'
        thresholds_path = self.checkpoint_dir / 'thresholds.json'
        config_path = self.checkpoint_dir / 'config.json'

        if not model_path.exists():
            print(f"Warning: model file does not exist: {model_path}")
            return

        try:
            # Load the config
            if config_path.exists():
                with open(config_path, 'r') as f:
                    self.config = json.load(f)

            # Load the model
            checkpoint = torch.load(model_path, map_location=self.device)

            model_config = self.config.get('model', {}) if self.config else {}
            self.model = create_model(model_config)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.to(self.device)
            self.model.eval()

            # Load the thresholds
            if thresholds_path.exists():
                with open(thresholds_path, 'r') as f:
                    self.thresholds = json.load(f)

            print(f"MLScorer loaded (device: {self.device})")

        except Exception as e:
            print(f"Warning: model loading failed - {e}")
            self.model = None

    def is_ready(self) -> bool:
        """Whether the model is ready."""
        return self.model is not None

    @torch.no_grad()
    def score(self, sequence: np.ndarray) -> float:
        """
        Compute the ML score.

        Args:
            sequence: (seq_len, n_features) or (batch, seq_len, n_features)
        Returns:
            score: 0-100; higher means more anomalous
        """
        if not self.is_ready():
            return 0.0

        # Ensure 3D input
        if sequence.ndim == 2:
            sequence = sequence[np.newaxis, ...]

        # To tensor
        x = torch.FloatTensor(sequence).to(self.device)

        # Reconstruction error
        output, _ = self.model(x)
        mse = ((output - x) ** 2).mean(dim=-1)  # (batch, seq_len)

        # Error at the last timestep
        error = mse[:, -1].cpu().numpy()

        # Map to a 0-100 score using p95 as the reference:
        # error=0 -> 0, error=p95 -> 50, error>=2*p95 saturates at 100.
        # (p99 is loaded for reference but unused in this mapping.)
        if self.thresholds:
            p95 = self.thresholds.get('p95', 0.1)
            p99 = self.thresholds.get('p99', 0.2)
        else:
            p95, p99 = 0.1, 0.2

        score = np.clip(error / p95 * 50, 0, 100)

        return float(score[0]) if len(score) == 1 else score.tolist()
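
# Editor's note (not part of the original commit): the p95-relative mapping
# used by MLScorer.score, isolated for clarity. With the default p95 = 0.1,
# an error of 0.1 maps to 50 and anything >= 0.2 saturates at 100.
def _example_error_to_score(error: float, p95: float = 0.1) -> float:
    return float(np.clip(error / p95 * 50, 0, 100))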

class HybridAnomalyDetector:
    """
    Hybrid anomaly detector.

    Combines the rule system with the ML model.
    """

    # Default configuration
    DEFAULT_CONFIG = {
        # Weights
        'rule_weight': 0.6,    # rule weight
        'ml_weight': 0.4,      # ML weight

        # Trigger thresholds
        'rule_trigger': 60,    # direct rule trigger
        'ml_trigger': 80,      # direct ML trigger
        'fusion_trigger': 50,  # fused trigger

        # Feature list
        'features': [
            'alpha', 'alpha_delta', 'amt_ratio',
            'amt_delta', 'rank_pct', 'limit_up_ratio'
        ],

        # Sequence length (needed by the ML model)
        'seq_len': 30,
    }

    def __init__(
        self,
        config: Dict = None,
        checkpoint_dir: str = 'ml/checkpoints',
        device: str = 'auto'
    ):
        self.config = {**self.DEFAULT_CONFIG, **(config or {})}

        # Scorers
        self.rule_scorer = RuleBasedScorer()
        self.ml_scorer = MLScorer(checkpoint_dir, device)

        print("HybridAnomalyDetector initialized")
        print(f"  rule weight: {self.config['rule_weight']}")
        print(f"  ML weight: {self.config['ml_weight']}")
        print(f"  ML model: {'ready' if self.ml_scorer.is_ready() else 'not loaded'}")

    def detect(
        self,
        features: Dict,
        sequence: np.ndarray = None
    ) -> AnomalyResult:
        """
        Detect an anomaly.

        Args:
            features: current-minute feature dict
            sequence: history sequence (seq_len, n_features), needed by the ML model
        Returns:
            AnomalyResult
        """
        # 1. Rule score
        rule_score, rule_details = self.rule_scorer.score(features)

        # 2. ML score
        ml_score = 0.0
        if sequence is not None and self.ml_scorer.is_ready():
            ml_score = self.ml_scorer.score(sequence)

        # 3. Fused score
        w1 = self.config['rule_weight']
        w2 = self.config['ml_weight']

        # If ML is unavailable, give all weight to the rules
        if not self.ml_scorer.is_ready():
            w1, w2 = 1.0, 0.0

        final_score = w1 * rule_score + w2 * ml_score

        # 4. Alert decision
        is_anomaly = False
        trigger_reason = ''

        if rule_score >= self.config['rule_trigger']:
            is_anomaly = True
            trigger_reason = f'strong rule signal ({rule_score:.0f})'
        elif ml_score >= self.config['ml_trigger']:
            is_anomaly = True
            trigger_reason = f'strong ML signal ({ml_score:.0f})'
        elif final_score >= self.config['fusion_trigger']:
            is_anomaly = True
            trigger_reason = f'fusion trigger ({final_score:.0f})'

        # 5. Anomaly type
        anomaly_type = self.rule_scorer.get_anomaly_type(features) if is_anomaly else ''

        return AnomalyResult(
            is_anomaly=is_anomaly,
            final_score=final_score,
            rule_score=rule_score,
            ml_score=ml_score,
            trigger_reason=trigger_reason,
            rule_details=rule_details,
            anomaly_type=anomaly_type
        )

    def detect_batch(
        self,
        features_list: List[Dict],
        sequences: np.ndarray = None
    ) -> List[AnomalyResult]:
        """
        Batch detection.

        Args:
            features_list: list of feature dicts
            sequences: (batch, seq_len, n_features)
        Returns:
            List[AnomalyResult]
        """
        results = []

        for i, features in enumerate(features_list):
            seq = sequences[i] if sequences is not None else None
            result = self.detect(features, seq)
            results.append(result)

        return results
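
# Editor's sketch (not part of the original commit): the fusion arithmetic in
# detect() with the default weights. Neither single trigger fires, but the
# weighted sum clears fusion_trigger (50).
def _example_fusion_arithmetic() -> float:
    rule_score, ml_score = 55.0, 60.0   # below 60 / below 80
    final_score = 0.6 * rule_score + 0.4 * ml_score
    assert final_score == 57.0          # >= 50 -> fusion-triggered alert
    return final_score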

# ==================== Convenience functions ====================

def create_detector(
    checkpoint_dir: str = 'ml/checkpoints',
    config: Dict = None
) -> HybridAnomalyDetector:
    """Create a hybrid detector."""
    return HybridAnomalyDetector(config, checkpoint_dir)


def quick_detect(features: Dict) -> bool:
    """
    Quick detection (rules only, no ML model required).

    Suitable for:
    - real-time detection
    - while the ML model is still being trained
    """
    scorer = RuleBasedScorer()
    score, _ = scorer.score(features)
    return score >= 50


# ==================== Tests ====================

if __name__ == "__main__":
    print("=" * 60)
    print("Hybrid anomaly detector test")
    print("=" * 60)

    # Create the detector
    detector = create_detector()

    # Test cases
    test_cases = [
        {
            'name': 'normal',
            'features': {
                'alpha': 0.5,
                'alpha_delta': 0.1,
                'amt_ratio': 1.2,
                'amt_delta': 0.1,
                'rank_pct': 0.5,
                'limit_up_ratio': 0.02
            }
        },
        {
            'name': 'alpha anomaly',
            'features': {
                'alpha': 3.5,
                'alpha_delta': 0.8,
                'amt_ratio': 2.5,
                'amt_delta': 0.5,
                'rank_pct': 0.92,
                'limit_up_ratio': 0.05
            }
        },
        {
            'name': 'volume anomaly',
            'features': {
                'alpha': 1.2,
                'alpha_delta': 0.3,
                'amt_ratio': 6.0,
                'amt_delta': 1.5,
                'rank_pct': 0.85,
                'limit_up_ratio': 0.08
            }
        },
        {
            'name': 'limit-up wave',
            'features': {
                'alpha': 2.5,
                'alpha_delta': 0.6,
                'amt_ratio': 3.5,
                'amt_delta': 0.8,
                'rank_pct': 0.98,
                'limit_up_ratio': 0.25
            }
        },
    ]

    print("\nTest results:")
    print("-" * 60)

    for case in test_cases:
        result = detector.detect(case['features'])

        print(f"\n{case['name']}:")
        print(f"  anomaly: {'yes' if result.is_anomaly else 'no'}")
        print(f"  final score: {result.final_score:.1f}")
        print(f"  rule score: {result.rule_score:.1f}")
        print(f"  ML score: {result.ml_score:.1f}")
        if result.is_anomaly:
            print(f"  trigger reason: {result.trigger_reason}")
            print(f"  anomaly type: {result.anomaly_type}")
            print(f"  triggered rules: {list(result.rule_details.keys())}")

    print("\n" + "=" * 60)
    print("Test finished!")
526
ml/enhanced_detector.py
Normal file
@@ -0,0 +1,526 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Enhanced concept anomaly detector.

Fuses two detection approaches:
1. Alpha-based Z-Score (rule method, good real-time behavior)
2. Transformer Autoencoder (ML method, more accurate)

Strategy:
- prefer the ML method when the model is available and enough history exists
- otherwise fall back to the Alpha-based method
- the fusion weights of the two methods are configurable
"""

import os
import sys
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, field
from collections import deque
import numpy as np

# Add the parent directory to the import path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

logger = logging.getLogger(__name__)


# ==================== Configuration ====================

ENHANCED_CONFIG = {
    # Fusion strategy
    'fusion_mode': 'adaptive',  # 'ml_only', 'alpha_only', 'adaptive', 'ensemble'

    # Weights (ensemble mode)
    'ml_weight': 0.6,
    'alpha_weight': 0.4,

    # ML model configuration
    'ml_checkpoint_dir': 'ml/checkpoints',
    'ml_threshold_key': 'p95',  # p90, p95, p99

    # Alpha configuration (consistent with concept_alert_alpha.py)
    'alpha_zscore_threshold': 2.0,
    'alpha_absolute_threshold': 1.5,
    'alpha_history_window': 60,
    'alpha_min_history': 5,

    # Shared configuration
    'cooldown_minutes': 8,
    'max_alerts_per_minute': 15,
    'min_alpha_abs': 0.5,
}

# Feature layout (must match training)
FEATURE_NAMES = [
    'alpha',
    'alpha_delta',
    'amt_ratio',
    'amt_delta',
    'rank_pct',
    'limit_up_ratio',
]


# ==================== Data structures ====================

@dataclass
class AlphaStats:
    """Rolling alpha statistics for one concept."""
    history: deque = field(default_factory=lambda: deque(maxlen=ENHANCED_CONFIG['alpha_history_window']))
    mean: float = 0.0
    std: float = 1.0

    def update(self, alpha: float):
        self.history.append(alpha)
        if len(self.history) >= 2:
            self.mean = np.mean(self.history)
            self.std = max(np.std(self.history), 0.1)

    def get_zscore(self, alpha: float) -> float:
        if len(self.history) < ENHANCED_CONFIG['alpha_min_history']:
            return 0.0
        return (alpha - self.mean) / self.std

    def is_ready(self) -> bool:
        return len(self.history) >= ENHANCED_CONFIG['alpha_min_history']
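
# Editor's sketch (not part of the original commit): once alpha_min_history
# observations exist, new alphas are z-scored against the rolling window,
# with the std floored at 0.1 to avoid exploding scores on quiet concepts.
def _example_alpha_zscore() -> float:
    stats = AlphaStats()
    for a in [0.1, -0.2, 0.0, 0.3, -0.1]:
        stats.update(a)
    return stats.get_zscore(2.0)  # roughly (2.0 - 0.02) / 0.17 ≈ 11.5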

@dataclass
class ConceptFeatures:
    """Real-time features of one concept."""
    alpha: float = 0.0
    alpha_delta: float = 0.0
    amt_ratio: float = 1.0
    amt_delta: float = 0.0
    rank_pct: float = 0.5
    limit_up_ratio: float = 0.0

    def to_dict(self) -> Dict[str, float]:
        return {
            'alpha': self.alpha,
            'alpha_delta': self.alpha_delta,
            'amt_ratio': self.amt_ratio,
            'amt_delta': self.amt_delta,
            'rank_pct': self.rank_pct,
            'limit_up_ratio': self.limit_up_ratio,
        }


# ==================== Enhanced detector ====================

class EnhancedAnomalyDetector:
    """
    Enhanced anomaly detector.

    Fuses the Alpha-based and ML methods.
    """

    def __init__(
        self,
        config: Dict = None,
        ml_enabled: bool = True
    ):
        self.config = config or ENHANCED_CONFIG
        self.ml_enabled = ml_enabled
        self.ml_detector = None

        # Alpha statistics
        self.alpha_stats: Dict[str, AlphaStats] = {}

        # Feature history (for delta computation)
        self.feature_history: Dict[str, deque] = {}

        # Cooldown bookkeeping
        self.cooldown_cache: Dict[str, datetime] = {}

        # Try to load the ML model
        if ml_enabled:
            self._load_ml_model()

        logger.info("EnhancedAnomalyDetector initialized")
        logger.info(f"  fusion mode: {self.config['fusion_mode']}")
        logger.info(f"  ML available: {self.ml_detector is not None}")

    def _load_ml_model(self):
        """Load the ML model."""
        try:
            from inference import ConceptAnomalyDetector
            checkpoint_dir = Path(__file__).parent / 'checkpoints'

            if (checkpoint_dir / 'best_model.pt').exists():
                self.ml_detector = ConceptAnomalyDetector(
                    checkpoint_dir=str(checkpoint_dir),
                    threshold_key=self.config['ml_threshold_key']
                )
                logger.info("ML model loaded")
            else:
                logger.warning(f"ML model not found: {checkpoint_dir / 'best_model.pt'}")
        except Exception as e:
            logger.warning(f"ML model loading failed: {e}")
            self.ml_detector = None

    def _get_alpha_stats(self, concept_id: str) -> AlphaStats:
        """Get or create alpha statistics."""
        if concept_id not in self.alpha_stats:
            self.alpha_stats[concept_id] = AlphaStats()
        return self.alpha_stats[concept_id]

    def _get_feature_history(self, concept_id: str) -> deque:
        """Get the feature history."""
        if concept_id not in self.feature_history:
            self.feature_history[concept_id] = deque(maxlen=10)
        return self.feature_history[concept_id]

    def _check_cooldown(self, concept_id: str, current_time: datetime) -> bool:
        """Whether the concept is still cooling down."""
        if concept_id not in self.cooldown_cache:
            return False

        last_alert = self.cooldown_cache[concept_id]
        cooldown_td = (current_time - last_alert).total_seconds() / 60

        return cooldown_td < self.config['cooldown_minutes']

    def _set_cooldown(self, concept_id: str, current_time: datetime):
        """Start a cooldown."""
        self.cooldown_cache[concept_id] = current_time

    def compute_features(
        self,
        concept_id: str,
        alpha: float,
        amt_ratio: float,
        rank_pct: float,
        limit_up_ratio: float
    ) -> ConceptFeatures:
        """
        Compute the full feature set for a concept.

        Args:
            concept_id: concept ID
            alpha: current excess return
            amt_ratio: turnover ratio
            rank_pct: rank percentile
            limit_up_ratio: share of limit-up stocks

        Returns:
            full feature set
        """
        history = self._get_feature_history(concept_id)

        # Rates of change
        alpha_delta = 0.0
        amt_delta = 0.0

        if len(history) > 0:
            last_features = history[-1]
            alpha_delta = alpha - last_features.alpha
            if last_features.amt_ratio > 0:
                amt_delta = (amt_ratio - last_features.amt_ratio) / last_features.amt_ratio

        features = ConceptFeatures(
            alpha=alpha,
            alpha_delta=alpha_delta,
            amt_ratio=amt_ratio,
            amt_delta=amt_delta,
            rank_pct=rank_pct,
            limit_up_ratio=limit_up_ratio,
        )

        # Update the history
        history.append(features)

        return features

    def detect_alpha_anomaly(
        self,
        concept_id: str,
        alpha: float
    ) -> Tuple[bool, float, str]:
        """
        Alpha-based anomaly detection.

        Returns:
            is_anomaly: whether the move is anomalous
            score: anomaly score (absolute Z-Score)
            reason: trigger reason
        """
        stats = self._get_alpha_stats(concept_id)

        # Compute the Z-Score (before updating the stats)
        zscore = stats.get_zscore(alpha)

        # Update the statistics
        stats.update(alpha)

        # Decide
        if stats.is_ready():
            if abs(zscore) >= self.config['alpha_zscore_threshold']:
                return True, abs(zscore), f"Z={zscore:.2f}"
        else:
            if abs(alpha) >= self.config['alpha_absolute_threshold']:
                fake_zscore = alpha / 0.5
                return True, abs(fake_zscore), f"Alpha={alpha:+.2f}%"

        return False, abs(zscore) if zscore else 0.0, ""

    def detect_ml_anomaly(
        self,
        concept_id: str,
        features: ConceptFeatures
    ) -> Tuple[bool, float]:
        """
        ML-based anomaly detection.

        Returns:
            is_anomaly: whether the move is anomalous
            score: anomaly score (reconstruction error)
        """
        if self.ml_detector is None:
            return False, 0.0

        try:
            is_anomaly, score = self.ml_detector.detect(
                concept_id,
                features.to_dict()
            )
            return is_anomaly, score or 0.0
        except Exception as e:
            logger.warning(f"ML detection failed: {e}")
            return False, 0.0

    def detect(
        self,
        concept_id: str,
        concept_name: str,
        alpha: float,
        amt_ratio: float,
        rank_pct: float,
        limit_up_ratio: float,
        change_pct: float,
        index_change: float,
        current_time: datetime,
        **extra_data
    ) -> Optional[Dict]:
        """
        Fused detection.

        Args:
            concept_id: concept ID
            concept_name: concept name
            alpha: excess return
            amt_ratio: turnover ratio
            rank_pct: rank percentile
            limit_up_ratio: share of limit-up stocks
            change_pct: concept change percentage
            index_change: index change percentage
            current_time: current time
            **extra_data: extra fields (limit_up_count, stock_count, ...)

        Returns:
            the alert dict if triggered, otherwise None
        """
        # Alpha too small: ignore
        if abs(alpha) < self.config['min_alpha_abs']:
            return None

        # Cooldown check
        if self._check_cooldown(concept_id, current_time):
            return None

        # Compute features
        features = self.compute_features(
            concept_id, alpha, amt_ratio, rank_pct, limit_up_ratio
        )

        # Run detection
        fusion_mode = self.config['fusion_mode']

        alpha_anomaly, alpha_score, alpha_reason = self.detect_alpha_anomaly(concept_id, alpha)
        ml_anomaly, ml_score = False, 0.0

        if fusion_mode in ('ml_only', 'adaptive', 'ensemble'):
            ml_anomaly, ml_score = self.detect_ml_anomaly(concept_id, features)

        # Decide according to the fusion mode
        is_anomaly = False
        final_score = 0.0
        detection_method = ''

        if fusion_mode == 'alpha_only':
            is_anomaly = alpha_anomaly
            final_score = alpha_score
            detection_method = 'alpha'

        elif fusion_mode == 'ml_only':
            is_anomaly = ml_anomaly
            final_score = ml_score
            detection_method = 'ml'

        elif fusion_mode == 'adaptive':
            # Prefer ML, fall back to alpha
            if self.ml_detector and ml_score > 0:
                is_anomaly = ml_anomaly
                final_score = ml_score
                detection_method = 'ml'
            else:
                is_anomaly = alpha_anomaly
                final_score = alpha_score
                detection_method = 'alpha'

        elif fusion_mode == 'ensemble':
            # Weighted fusion of normalized scores
            norm_alpha = min(alpha_score / 5.0, 1.0)  # Z > 5 counts as 1.0
            norm_ml = min(ml_score / (self.ml_detector.threshold if self.ml_detector else 1.0), 1.0)

            final_score = (
                self.config['alpha_weight'] * norm_alpha +
                self.config['ml_weight'] * norm_ml
            )
            is_anomaly = final_score > 0.5 or alpha_anomaly or ml_anomaly
            detection_method = 'ensemble'

        if not is_anomaly:
            return None

        # Build the alert record
        self._set_cooldown(concept_id, current_time)

        alert_type = 'surge_up' if alpha > 0 else 'surge_down'

        alert = {
            'concept_id': concept_id,
            'concept_name': concept_name,
            'alert_type': alert_type,
            'alert_time': current_time,
            'change_pct': change_pct,
            'alpha': alpha,
            'alpha_zscore': alpha_score,
            'index_change_pct': index_change,
            'detection_method': detection_method,
            'alpha_score': alpha_score,
            'ml_score': ml_score,
            'final_score': final_score,
            **extra_data
        }

        return alert

    def batch_detect(
        self,
        concepts_data: List[Dict],
        current_time: datetime
    ) -> List[Dict]:
        """
        Batch detection.

        Args:
            concepts_data: list of concept snapshots
            current_time: current time

        Returns:
            alerts sorted by score, truncated to the per-minute cap
        """
        alerts = []

        for data in concepts_data:
            alert = self.detect(
                concept_id=data['concept_id'],
                concept_name=data['concept_name'],
                alpha=data.get('alpha', 0),
                amt_ratio=data.get('amt_ratio', 1.0),
                rank_pct=data.get('rank_pct', 0.5),
                limit_up_ratio=data.get('limit_up_ratio', 0),
                change_pct=data.get('change_pct', 0),
                index_change=data.get('index_change', 0),
                current_time=current_time,
                limit_up_count=data.get('limit_up_count', 0),
                limit_down_count=data.get('limit_down_count', 0),
                stock_count=data.get('stock_count', 0),
                concept_type=data.get('concept_type', 'leaf'),
            )

            if alert:
                alerts.append(alert)

        # Sort and cap
        alerts.sort(key=lambda x: x['final_score'], reverse=True)
        return alerts[:self.config['max_alerts_per_minute']]

    def reset(self):
        """Reset all state (new trading day)."""
        self.alpha_stats.clear()
        self.feature_history.clear()
        self.cooldown_cache.clear()

        if self.ml_detector:
            self.ml_detector.clear_history()

        logger.info("Detector state reset")

# ==================== Tests ====================

if __name__ == "__main__":
    import random

    print("Testing EnhancedAnomalyDetector...")

    # Initialize without ML (the model may not exist)
    detector = EnhancedAnomalyDetector(ml_enabled=False)

    # Synthetic data
    concepts = [
        {'concept_id': 'ai_001', 'concept_name': '人工智能'},
        {'concept_id': 'chip_002', 'concept_name': '芯片半导体'},
        {'concept_id': 'car_003', 'concept_name': '新能源汽车'},
    ]

    print("\nSimulating real-time detection...")
    current_time = datetime.now()

    for minute in range(50):
        concepts_data = []

        for c in concepts:
            # Random features
            alpha = random.gauss(0, 0.8)
            amt_ratio = max(0.3, random.gauss(1, 0.3))
            rank_pct = random.random()
            limit_up_ratio = random.random() * 0.1

            # Inject an anomaly (the AI concept surges at minute 30)
            if minute == 30 and c['concept_id'] == 'ai_001':
                alpha = 4.5
                amt_ratio = 2.5
                limit_up_ratio = 0.3

            concepts_data.append({
                **c,
                'alpha': alpha,
                'amt_ratio': amt_ratio,
                'rank_pct': rank_pct,
                'limit_up_ratio': limit_up_ratio,
                'change_pct': alpha + 0.5,
                'index_change': 0.5,
            })

        # Detect
        alerts = detector.batch_detect(concepts_data, current_time)

        if alerts:
            for alert in alerts:
                print(f"  t={minute:02d} 🔥 {alert['concept_name']} "
                      f"Alpha={alert['alpha']:+.2f}% "
                      f"Score={alert['final_score']:.2f} "
                      f"Method={alert['detection_method']}")

        # Naive minute increment (wraps at :59 without touching the hour)
        current_time = current_time.replace(minute=current_time.minute + 1 if current_time.minute < 59 else 0)

    print("\nTest finished!")
455
ml/inference.py
Normal file
@@ -0,0 +1,455 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Concept anomaly inference service.

Uses the trained Transformer Autoencoder for real-time anomaly detection.

Usage:
    from ml.inference import ConceptAnomalyDetector

    detector = ConceptAnomalyDetector('ml/checkpoints')

    # Detect an anomaly
    features = {...}  # real-time feature data
    is_anomaly, score = detector.detect(concept_name, features)
"""

import os
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from collections import deque

import numpy as np
import torch

from model import TransformerAutoencoder


class ConceptAnomalyDetector:
    """
    Concept anomaly detector.

    Real-time anomaly detection with the trained Transformer Autoencoder.
    """

    def __init__(
        self,
        checkpoint_dir: str = 'ml/checkpoints',
        device: str = 'auto',
        threshold_key: str = 'p95'
    ):
        """
        Initialize the detector.

        Args:
            checkpoint_dir: model checkpoint directory
            device: device (auto/cuda/cpu)
            threshold_key: threshold to use (p90/p95/p99)
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.threshold_key = threshold_key

        # Device selection
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        # Load the config
        self._load_config()

        # Load the model
        self._load_model()

        # Load the thresholds
        self._load_thresholds()

        # Load the normalization statistics
        self._load_normalization_stats()

        # Per-concept history cache
        # {concept_name: deque(maxlen=seq_len)}
        self.history_cache: Dict[str, deque] = {}

        print("ConceptAnomalyDetector initialized")
        print(f"  device: {self.device}")
        print(f"  threshold: {self.threshold_key} = {self.threshold:.6f}")
        print(f"  sequence length: {self.seq_len}")

    def _load_config(self):
        """Load the config."""
        config_path = self.checkpoint_dir / 'config.json'
        if not config_path.exists():
            raise FileNotFoundError(f"Config file does not exist: {config_path}")

        with open(config_path, 'r') as f:
            self.config = json.load(f)

        self.features = self.config['features']
        self.seq_len = self.config['seq_len']
        self.model_config = self.config['model']

    def _load_model(self):
        """Load the model."""
        model_path = self.checkpoint_dir / 'best_model.pt'
        if not model_path.exists():
            raise FileNotFoundError(f"Model file does not exist: {model_path}")

        # Build the model
        self.model = TransformerAutoencoder(**self.model_config)

        # Load the weights
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

        print(f"Model loaded: {model_path}")

    def _load_thresholds(self):
        """Load the thresholds."""
        thresholds_path = self.checkpoint_dir / 'thresholds.json'
        if not thresholds_path.exists():
            raise FileNotFoundError(f"Thresholds file does not exist: {thresholds_path}")

        with open(thresholds_path, 'r') as f:
            self.thresholds = json.load(f)

        if self.threshold_key not in self.thresholds:
            available_keys = list(self.thresholds.keys())
            raise KeyError(f"Threshold key '{self.threshold_key}' not found; available: {available_keys}")

        self.threshold = self.thresholds[self.threshold_key]

    def _load_normalization_stats(self):
        """Load the normalization statistics."""
        stats_path = self.checkpoint_dir / 'normalization_stats.json'
        if not stats_path.exists():
            raise FileNotFoundError(f"Normalization stats file does not exist: {stats_path}")

        with open(stats_path, 'r') as f:
            stats = json.load(f)

        self.norm_mean = np.array(stats['mean'])
        self.norm_std = np.array(stats['std'])

    def normalize(self, features: np.ndarray) -> np.ndarray:
        """Standardize features."""
        return (features - self.norm_mean) / self.norm_std

    def update_history(
        self,
        concept_name: str,
        features: Dict[str, float]
    ):
        """
        Append one snapshot to a concept's history.

        Args:
            concept_name: concept name
            features: current-minute feature dict
        """
        # Initialize the cache
        if concept_name not in self.history_cache:
            self.history_cache[concept_name] = deque(maxlen=self.seq_len)

        # Extract the feature vector
        feature_vector = np.array([
            features.get(f, 0.0) for f in self.features
        ])

        # Sanitize outliers
        feature_vector = np.nan_to_num(feature_vector, nan=0.0, posinf=0.0, neginf=0.0)

        # Append to the history
        self.history_cache[concept_name].append(feature_vector)

    def get_history_length(self, concept_name: str) -> int:
        """History length for a concept."""
        if concept_name not in self.history_cache:
            return 0
        return len(self.history_cache[concept_name])

    @torch.no_grad()
    def detect(
        self,
        concept_name: str,
        features: Dict[str, float] = None,
        return_score: bool = True
    ) -> Tuple[bool, Optional[float]]:
        """
        Detect whether a concept is moving anomalously.

        Args:
            concept_name: concept name
            features: current-minute features (if given, the history is updated first)
            return_score: whether to return the anomaly score

        Returns:
            is_anomaly: whether the move is anomalous
            score: anomaly score (if return_score=True)
        """
        # Update the history
        if features is not None:
            self.update_history(concept_name, features)

        # Enough history?
        if concept_name not in self.history_cache:
            return False, None

        history = self.history_cache[concept_name]
        if len(history) < self.seq_len:
            return False, None

        # Build the input sequence
        sequence = np.array(list(history))  # (seq_len, n_features)

        # Standardize
        sequence = self.normalize(sequence)

        # To tensor
        x = torch.FloatTensor(sequence).unsqueeze(0)  # (1, seq_len, n_features)
        x = x.to(self.device)

        # Reconstruction error
        error = self.model.compute_reconstruction_error(x, reduction='none')

        # The last timestep's error is the current score
        score = error[0, -1].item()

        # Decide
        is_anomaly = score > self.threshold

        if return_score:
            return is_anomaly, score
        else:
            return is_anomaly, None

    @torch.no_grad()
    def batch_detect(
        self,
        concept_features: Dict[str, Dict[str, float]]
    ) -> Dict[str, Tuple[bool, float]]:
        """
        Detect multiple concepts in one call.

        Args:
            concept_features: {concept_name: {feature_name: value}}

        Returns:
            results: {concept_name: (is_anomaly, score)}
        """
        results = {}

        for concept_name, features in concept_features.items():
            is_anomaly, score = self.detect(concept_name, features)
            results[concept_name] = (is_anomaly, score)

        return results

    def get_anomaly_type(
        self,
        concept_name: str,
        features: Dict[str, float]
    ) -> str:
        """
        Classify the anomaly type.

        Args:
            concept_name: concept name
            features: current features

        Returns:
            anomaly_type: 'surge_up' / 'surge_down' / 'normal'
        """
        is_anomaly, score = self.detect(concept_name, features)

        if not is_anomaly:
            return 'normal'

        # Direction from alpha
        alpha = features.get('alpha', 0.0)

        if alpha > 0:
            return 'surge_up'
        else:
            return 'surge_down'

    def get_top_anomalies(
        self,
        concept_features: Dict[str, Dict[str, float]],
        top_k: int = 10
    ) -> List[Tuple[str, float, str]]:
        """
        Return the top_k concepts with the highest anomaly scores.

        Args:
            concept_features: {concept_name: {feature_name: value}}
            top_k: number of results

        Returns:
            anomalies: [(concept_name, score, anomaly_type), ...]
        """
        results = self.batch_detect(concept_features)

        # Sort by score
        sorted_results = sorted(
            [(name, is_anomaly, score) for name, (is_anomaly, score) in results.items() if score is not None],
            key=lambda x: x[2],
            reverse=True
        )

        # Take the top_k
        top_anomalies = []
        for name, is_anomaly, score in sorted_results[:top_k]:
            if is_anomaly:
                alpha = concept_features[name].get('alpha', 0.0)
                anomaly_type = 'surge_up' if alpha > 0 else 'surge_down'
                top_anomalies.append((name, score, anomaly_type))

        return top_anomalies

    def clear_history(self, concept_name: str = None):
        """
        Clear the history cache.

        Args:
            concept_name: concept name (None clears everything)
        """
        if concept_name is None:
            self.history_cache.clear()
        elif concept_name in self.history_cache:
            del self.history_cache[concept_name]
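
# Editor's sketch (not part of the original commit): the detector only scores
# once it has seq_len consecutive snapshots for a concept; earlier calls
# return (False, None). The flat warm-up values are hypothetical.
def _example_warmup(detector: ConceptAnomalyDetector) -> Tuple[bool, Optional[float]]:
    flat = {f: 0.0 for f in detector.features}
    for _ in range(detector.seq_len):
        detector.detect('demo_concept', flat)  # fills the history cache
    return detector.detect('demo_concept', {**flat, 'alpha': 5.0})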

# ==================== Integration with the existing system ====================

class MLAnomalyService:
    """
    ML anomaly detection service.

    Replaces or complements the existing Alpha-based detection.
    """

    def __init__(
        self,
        checkpoint_dir: str = 'ml/checkpoints',
        fallback_to_alpha: bool = True
    ):
        """
        Args:
            checkpoint_dir: model checkpoint directory
            fallback_to_alpha: whether to fall back to the Alpha method when the ML model is unavailable
        """
        self.fallback_to_alpha = fallback_to_alpha
        self.ml_detector = None

        try:
            self.ml_detector = ConceptAnomalyDetector(checkpoint_dir)
            print("ML anomaly service initialized")
        except Exception as e:
            print(f"ML model loading failed: {e}")
            if not fallback_to_alpha:
                raise
            print("Falling back to Alpha-based detection")

    def is_ml_available(self) -> bool:
        """Whether the ML model is available."""
        return self.ml_detector is not None

    def detect_anomaly(
        self,
        concept_name: str,
        features: Dict[str, float],
        alpha_threshold: float = 2.0
    ) -> Tuple[bool, float, str]:
        """
        Detect an anomaly.

        Args:
            concept_name: concept name
            features: feature dict (must contain alpha, amt_ratio, ...)
            alpha_threshold: Alpha Z-Score threshold (used for the fallback)

        Returns:
            is_anomaly: whether the move is anomalous
            score: anomaly score
            method: detection method ('ml' / 'alpha')
        """
        # Prefer ML detection
        if self.ml_detector is not None:
            history_len = self.ml_detector.get_history_length(concept_name)

            # Use ML once enough history has accumulated
            if history_len >= self.ml_detector.seq_len - 1:
                is_anomaly, score = self.ml_detector.detect(concept_name, features)
                if score is not None:
                    return is_anomaly, score, 'ml'
            else:
                # Keep the history warm but use the Alpha method
                self.ml_detector.update_history(concept_name, features)

        # Fall back to the Alpha method
        if self.fallback_to_alpha:
            alpha = features.get('alpha', 0.0)
            alpha_zscore = features.get('alpha_zscore', 0.0)

            is_anomaly = abs(alpha_zscore) > alpha_threshold
            score = abs(alpha_zscore)

            return is_anomaly, score, 'alpha'

        return False, 0.0, 'none'


# ==================== Tests ====================

if __name__ == "__main__":
    import random

    print("Testing ConceptAnomalyDetector...")

    # Is the model available?
    checkpoint_dir = Path('ml/checkpoints')
    if not (checkpoint_dir / 'best_model.pt').exists():
        print("Model file not found, skipping the test")
        print("Run train.py first to train a model")
        exit(0)

    # Initialize the detector
    detector = ConceptAnomalyDetector('ml/checkpoints')

    # Synthetic data
    print("\nSimulating real-time detection...")
    concept_name = "人工智能"

    for i in range(40):
        # Random features
        features = {
            'alpha': random.gauss(0, 1),
            'alpha_delta': random.gauss(0, 0.5),
            'amt_ratio': random.gauss(1, 0.3),
            'amt_delta': random.gauss(0, 0.2),
            'rank_pct': random.random(),
            'limit_up_ratio': random.random() * 0.1,
        }

        # Inject an anomaly at minute 35
        if i == 35:
            features['alpha'] = 5.0
            features['alpha_delta'] = 2.0
            features['amt_ratio'] = 3.0

        is_anomaly, score = detector.detect(concept_name, features)

        history_len = detector.get_history_length(concept_name)

        if score is not None:
            status = "🔥 anomaly!" if is_anomaly else "normal"
            print(f"  t={i:02d} | history={history_len} | score={score:.4f} | {status}")
        else:
            print(f"  t={i:02d} | history={history_len} | not enough data")

    print("\nTest finished!")
390
ml/model.py
Normal file
@@ -0,0 +1,390 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
LSTM Autoencoder 模型定义
|
||||
|
||||
用于概念异动检测:
|
||||
- 学习"正常"市场模式
|
||||
- 重构误差大的时刻 = 异动
|
||||
|
||||
模型结构(简洁有效):
|
||||
┌─────────────────────────────────────┐
|
||||
│ 输入: (batch, seq_len, n_features) │
|
||||
│ 过去30分钟的特征序列 │
|
||||
├─────────────────────────────────────┤
|
||||
│ LSTM Encoder │
|
||||
│ - 双向 LSTM │
|
||||
│ - 输出最后隐藏状态 │
|
||||
├─────────────────────────────────────┤
|
||||
│ Bottleneck (压缩层) │
|
||||
│ 降维到 latent_dim(关键!) │
|
||||
├─────────────────────────────────────┤
|
||||
│ LSTM Decoder │
|
||||
│ - 单向 LSTM │
|
||||
│ - 重构序列 │
|
||||
├─────────────────────────────────────┤
|
||||
│ 输出: (batch, seq_len, n_features) │
|
||||
│ 重构的特征序列 │
|
||||
└─────────────────────────────────────┘
|
||||
|
||||
为什么用 LSTM 而不是 Transformer:
|
||||
1. 参数更少,不容易过拟合
|
||||
2. 对于 6 维特征足够用
|
||||
3. 训练更稳定
|
||||
4. 瓶颈约束更容易控制
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class LSTMAutoencoder(nn.Module):
|
||||
"""
|
||||
LSTM Autoencoder for Anomaly Detection
|
||||
|
||||
设计原则:
|
||||
- 足够简单,避免过拟合
|
||||
- 瓶颈层严格限制,迫使模型只学习主要模式
|
||||
- 异常难以通过狭窄瓶颈,重构误差大
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_features: int = 6,
|
||||
hidden_dim: int = 32, # LSTM 隐藏维度(小!)
|
||||
latent_dim: int = 4, # 瓶颈维度(非常小!关键参数)
|
||||
num_layers: int = 1, # LSTM 层数
|
||||
dropout: float = 0.2,
|
||||
bidirectional: bool = True, # 双向编码器
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.n_features = n_features
|
||||
self.hidden_dim = hidden_dim
|
||||
self.latent_dim = latent_dim
|
||||
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        # Encoder: bidirectional LSTM
        self.encoder = nn.LSTM(
            input_size=n_features,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )

        # Bottleneck: compress down to a tiny latent space
        encoder_output_dim = hidden_dim * self.num_directions
        self.bottleneck_down = nn.Sequential(
            nn.Linear(encoder_output_dim, latent_dim),
            nn.Tanh(),  # bounds the range, adding an extra constraint
        )

        self.bottleneck_up = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
        )

        # Decoder: unidirectional LSTM
        self.decoder = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=False  # the decoder is unidirectional
        )

        # Output projection
        self.output_layer = nn.Linear(hidden_dim, n_features)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Weight initialization
        self._init_weights()

    def _init_weights(self):
        """Initialize weights."""
        for name, param in self.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encoder.

        Args:
            x: (batch, seq_len, n_features)
        Returns:
            latent: (batch, seq_len, latent_dim) compressed representation per timestep
            encoder_outputs: (batch, seq_len, hidden_dim * num_directions)
        """
        # LSTM encoding
        encoder_outputs, (h_n, c_n) = self.encoder(x)
        # encoder_outputs: (batch, seq_len, hidden_dim * num_directions)

        encoder_outputs = self.dropout(encoder_outputs)

        # Compress into the latent space (per timestep)
        latent = self.bottleneck_down(encoder_outputs)
        # latent: (batch, seq_len, latent_dim)

        return latent, encoder_outputs

    def decode(self, latent: torch.Tensor, seq_len: int) -> torch.Tensor:
        """
        Decoder.

        Args:
            latent: (batch, seq_len, latent_dim)
            seq_len: sequence length
        Returns:
            output: (batch, seq_len, n_features)
        """
        # Expand back out of the latent space
        decoder_input = self.bottleneck_up(latent)
        # decoder_input: (batch, seq_len, hidden_dim)

        # LSTM decoding
        decoder_outputs, _ = self.decoder(decoder_input)
        # decoder_outputs: (batch, seq_len, hidden_dim)

        decoder_outputs = self.dropout(decoder_outputs)

        # Project back to the original feature space
        output = self.output_layer(decoder_outputs)
        # output: (batch, seq_len, n_features)

        return output

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: (batch, seq_len, n_features)
        Returns:
            output: (batch, seq_len, n_features) reconstruction
            latent: (batch, seq_len, latent_dim) latent vectors
        """
        batch_size, seq_len, _ = x.shape

        # Encode
        latent, _ = self.encode(x)

        # Decode
        output = self.decode(latent, seq_len)

        return output, latent

    def compute_reconstruction_error(
        self,
        x: torch.Tensor,
        reduction: str = 'none'
    ) -> torch.Tensor:
        """
        Compute the reconstruction error.

        Args:
            x: (batch, seq_len, n_features)
            reduction: 'none' | 'mean' | 'sum'
        Returns:
            error: reconstruction error
        """
        output, _ = self.forward(x)

        # MSE per feature per timestep
        error = F.mse_loss(output, x, reduction='none')

        if reduction == 'none':
            # (batch, seq_len, n_features) -> (batch, seq_len)
            return error.mean(dim=-1)
        elif reduction == 'mean':
            return error.mean()
        elif reduction == 'sum':
            return error.sum()
        else:
            raise ValueError(f"Unknown reduction: {reduction}")

    def detect_anomaly(
        self,
        x: torch.Tensor,
        threshold: float = None,
        return_scores: bool = True
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Detect anomalous moves.

        Args:
            x: (batch, seq_len, n_features)
            threshold: anomaly threshold (if None, only scores are returned)
            return_scores: whether to return anomaly scores
        Returns:
            is_anomaly: (batch, seq_len) bool tensor (if threshold is not None)
            scores: (batch, seq_len) anomaly scores (if return_scores)
        """
        scores = self.compute_reconstruction_error(x, reduction='none')

        is_anomaly = None
        if threshold is not None:
            is_anomaly = scores > threshold

        if return_scores:
            return is_anomaly, scores
        else:
            return is_anomaly, None
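
# Usage sketch (not part of the original file): derive a threshold from a
# quantile of held-out reconstruction errors, then reuse it for detection.
# The 0.95 quantile is an illustrative choice, not a value this module mandates.
def pick_threshold_and_detect(model: 'LSTMAutoencoder', batch: torch.Tensor, q: float = 0.95):
    with torch.no_grad():
        scores = model.compute_reconstruction_error(batch, reduction='none')  # (batch, seq_len)
        thr = torch.quantile(scores.flatten(), q).item()
        return model.detect_anomaly(batch, threshold=thr)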


# Alias kept for backward compatibility with older imports
TransformerAutoencoder = LSTMAutoencoder


# ==================== Loss function ====================

class AnomalyDetectionLoss(nn.Module):
    """
    Anomaly-detection loss.

    A plain MSE reconstruction loss.
    """

    def __init__(
        self,
        feature_weights: torch.Tensor = None,
    ):
        super().__init__()
        self.feature_weights = feature_weights

    def forward(
        self,
        output: torch.Tensor,
        target: torch.Tensor,
        latent: torch.Tensor = None
    ) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            output: (batch, seq_len, n_features) reconstruction
            target: (batch, seq_len, n_features) original input
            latent: (batch, seq_len, latent_dim) latent vectors (unused)
        Returns:
            loss: total loss
            loss_dict: breakdown of the loss terms
        """
        # Reconstruction loss (MSE)
        mse = F.mse_loss(output, target, reduction='none')

        # Optional per-feature weighting
        if self.feature_weights is not None:
            weights = self.feature_weights.to(mse.device)
            mse = mse * weights

        reconstruction_loss = mse.mean()

        loss_dict = {
            'total': reconstruction_loss.item(),
            'reconstruction': reconstruction_loss.item(),
        }

        return reconstruction_loss, loss_dict
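
# Example sketch (not part of the original file): a criterion where alpha --
# the first of the six features in the training config -- counts twice as much
# as the others. The weights broadcast over (batch, seq_len, n_features); the
# values are illustrative only.
def make_weighted_criterion() -> 'AnomalyDetectionLoss':
    return AnomalyDetectionLoss(
        feature_weights=torch.tensor([2.0, 1.0, 1.0, 1.0, 1.0, 1.0])
    )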


# ==================== Utilities ====================

def count_parameters(model: nn.Module) -> int:
    """Count trainable model parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def create_model(config: dict = None) -> LSTMAutoencoder:
    """
    Create the model.

    Defaults to a small LSTM configuration suited to anomaly detection.
    """
    default_config = {
        'n_features': 6,
        'hidden_dim': 32,   # small!
        'latent_dim': 4,    # very small -- this is the key constraint
        'num_layers': 1,
        'dropout': 0.2,
        'bidirectional': True,
    }

    if config:
        # Map legacy Transformer config keys onto the LSTM ones
        if 'd_model' in config:
            config['hidden_dim'] = config.pop('d_model') // 2
        if 'num_encoder_layers' in config:
            config['num_layers'] = config.pop('num_encoder_layers')
        if 'num_decoder_layers' in config:
            config.pop('num_decoder_layers')
        if 'nhead' in config:
            config.pop('nhead')
        if 'dim_feedforward' in config:
            config.pop('dim_feedforward')
        if 'max_seq_len' in config:
            config.pop('max_seq_len')
        if 'use_instance_norm' in config:
            config.pop('use_instance_norm')

        default_config.update(config)

    model = LSTMAutoencoder(**default_config)
    param_count = count_parameters(model)
    print(f"Model parameters: {param_count:,}")

    if param_count > 100000:
        print(f"⚠️ Warning: large parameter count ({param_count:,}); the model may overfit")
    else:
        print(f"✓ Parameter count is moderate (LSTM Autoencoder)")

    return model
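
# Example sketch (not part of the original file): create_model() accepts a
# legacy Transformer-style config and maps it onto the LSTM keys above, so an
# old checkpoint config such as this one still works. The values are
# illustrative.
def create_model_from_legacy_config() -> LSTMAutoencoder:
    legacy = {
        'n_features': 6,
        'd_model': 64,            # becomes hidden_dim = 64 // 2 = 32
        'num_encoder_layers': 1,  # becomes num_layers
        'nhead': 4,               # dropped
        'dim_feedforward': 128,   # dropped
    }
    return create_model(legacy)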


if __name__ == "__main__":
    # Smoke-test the model
    print("Testing LSTM Autoencoder...")

    # Create the model
    model = create_model()

    # Test input
    batch_size = 32
    seq_len = 30
    n_features = 6

    x = torch.randn(batch_size, seq_len, n_features)

    # Forward pass
    output, latent = model(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Latent shape: {latent.shape}")

    # Reconstruction error
    error = model.compute_reconstruction_error(x)
    print(f"Reconstruction error shape: {error.shape}")
    print(f"Mean reconstruction error: {error.mean().item():.4f}")

    # Anomaly detection
    is_anomaly, scores = model.detect_anomaly(x, threshold=0.5)
    print(f"Anomaly mask shape: {is_anomaly.shape if is_anomaly is not None else 'None'}")
    print(f"Anomaly score shape: {scores.shape}")

    # Loss function
    criterion = AnomalyDetectionLoss()
    loss, loss_dict = criterion(output, x, latent)
    print(f"Loss: {loss.item():.4f}")

    print("\nAll tests passed!")
501
ml/prepare_data.py
Normal file
@@ -0,0 +1,501 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Data-preparation script: build training data for the Transformer Autoencoder.

Pulls historical minute data from ClickHouse and computes these features:
1. alpha          - excess return (concept change % minus index change %)
2. alpha_delta    - change in alpha over 5 minutes
3. amt_ratio      - turnover relative to its mean (current / 20-minute mean)
4. amt_delta      - change in turnover
5. rank_pct       - alpha rank percentile
6. limit_up_ratio - share of limit-up stocks

Output: per-trading-day feature files (parquet format).
"""

import os
import sys
import numpy as np
import pandas as pd
from datetime import datetime, timedelta, date
from sqlalchemy import create_engine, text
from elasticsearch import Elasticsearch
from clickhouse_driver import Client
import hashlib
import json
import logging
from typing import Dict, List, Set, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import warnings
warnings.filterwarnings('ignore')

# ==================== Configuration ====================

MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)

ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
ES_INDEX = 'concept_library_v3'

CLICKHOUSE_CONFIG = {
    'host': '127.0.0.1',
    'port': 9000,
    'user': 'default',
    'password': 'Zzl33818!',
    'database': 'stock'
}

# Output directory
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data')
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Feature-computation parameters
FEATURE_CONFIG = {
    'alpha_delta_window': 5,       # window for alpha change (minutes)
    'amt_ma_window': 20,           # window for turnover mean (minutes)
    'limit_up_threshold': 9.8,     # limit-up threshold (%)
    'limit_down_threshold': -9.8,  # limit-down threshold (%)
}

REFERENCE_INDEX = '000001.SH'

# ==================== Logging ====================

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# ==================== Utilities ====================

def get_ch_client():
    return Client(**CLICKHOUSE_CONFIG)


def generate_id(name: str) -> str:
    return hashlib.md5(name.encode('utf-8')).hexdigest()[:16]


def code_to_ch_format(code: str) -> str:
    """Map a bare 6-digit stock code to its exchange-suffixed ClickHouse form."""
    if not code or len(code) != 6 or not code.isdigit():
        return None
    if code.startswith('6'):
        return f"{code}.SH"
    elif code.startswith('0') or code.startswith('3'):
        return f"{code}.SZ"
    else:
        return f"{code}.BJ"


# ==================== Fetch concept list ====================

def get_all_concepts() -> List[dict]:
    """Fetch all leaf concepts from Elasticsearch."""
    concepts = []

    query = {
        "query": {"match_all": {}},
        "size": 100,
        "_source": ["concept_id", "concept", "stocks"]
    }

    resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
    scroll_id = resp['_scroll_id']
    hits = resp['hits']['hits']

    while len(hits) > 0:
        for hit in hits:
            source = hit['_source']
            stocks = []
            if 'stocks' in source and isinstance(source['stocks'], list):
                for stock in source['stocks']:
                    if isinstance(stock, dict) and 'code' in stock and stock['code']:
                        stocks.append(stock['code'])

            if stocks:
                concepts.append({
                    'concept_id': source.get('concept_id'),
                    'concept_name': source.get('concept'),
                    'stocks': stocks
                })

        resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
        scroll_id = resp['_scroll_id']
        hits = resp['hits']['hits']

    ES_CLIENT.clear_scroll(scroll_id=scroll_id)
    logger.info(f"Fetched {len(concepts)} concepts")
    return concepts
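
# Alternative sketch (not part of the original file): the manual scroll loop
# above could also be written with elasticsearch.helpers.scan, which handles
# the scroll bookkeeping and cleanup itself. Shown for comparison only.
def iter_concept_hits():
    from elasticsearch.helpers import scan
    return scan(
        ES_CLIENT,
        index=ES_INDEX,
        query={"query": {"match_all": {}},
               "_source": ["concept_id", "concept", "stocks"]},
        scroll='2m',
    )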


# ==================== Fetch trading days ====================

def get_trading_days(start_date: str, end_date: str) -> List[str]:
    """Fetch the list of trading days."""
    client = get_ch_client()

    query = f"""
        SELECT DISTINCT toDate(timestamp) as trade_date
        FROM stock_minute
        WHERE toDate(timestamp) >= '{start_date}'
          AND toDate(timestamp) <= '{end_date}'
        ORDER BY trade_date
    """

    result = client.execute(query)
    days = [row[0].strftime('%Y-%m-%d') for row in result]
    logger.info(f"Found {len(days)} trading days: {days[0]} ~ {days[-1]}")
    return days


# ==================== Fetch single-day data ====================

def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
    """Fetch minute data for all stocks on one trading day."""
    client = get_ch_client()

    # Convert code formats
    ch_codes = []
    code_map = {}
    for code in stock_codes:
        ch_code = code_to_ch_format(code)
        if ch_code:
            ch_codes.append(ch_code)
            code_map[ch_code] = code

    if not ch_codes:
        return pd.DataFrame()

    ch_codes_str = "','".join(ch_codes)

    query = f"""
        SELECT
            code,
            timestamp,
            close,
            volume,
            amt
        FROM stock_minute
        WHERE toDate(timestamp) = '{trade_date}'
          AND code IN ('{ch_codes_str}')
        ORDER BY code, timestamp
    """

    result = client.execute(query)

    if not result:
        return pd.DataFrame()

    df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
    df['code'] = df['ch_code'].map(code_map)
    df = df.dropna(subset=['code'])

    return df[['code', 'timestamp', 'close', 'volume', 'amt']]


def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
    """Fetch one day of index minute data."""
    client = get_ch_client()

    query = f"""
        SELECT
            timestamp,
            close,
            volume,
            amt
        FROM index_minute
        WHERE toDate(timestamp) = '{trade_date}'
          AND code = '{index_code}'
        ORDER BY timestamp
    """

    result = client.execute(query)

    if not result:
        return pd.DataFrame()

    df = pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt'])
    return df


def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
    """Fetch previous-close prices."""
    valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
    if not valid_codes:
        return {}

    codes_str = "','".join(valid_codes)

    query = f"""
        SELECT SECCODE, F002N
        FROM ea_trade
        WHERE SECCODE IN ('{codes_str}')
          AND TRADEDATE = (
              SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
          )
          AND F002N IS NOT NULL AND F002N > 0
    """

    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text(query))
            return {row[0]: float(row[1]) for row in result if row[1]}
    except Exception as e:
        logger.error(f"Failed to fetch previous closes: {e}")
        return {}


def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float:
    """Fetch the index's previous close."""
    code_no_suffix = index_code.split('.')[0]

    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text("""
                SELECT F006N FROM ea_exchangetrade
                WHERE INDEXCODE = :code AND TRADEDATE < :today
                ORDER BY TRADEDATE DESC LIMIT 1
            """), {'code': code_no_suffix, 'today': trade_date}).fetchone()

            if result and result[0]:
                return float(result[0])
    except Exception as e:
        logger.error(f"Failed to fetch index previous close: {e}")

    return None


# ==================== Compute features ====================

def compute_daily_features(
    trade_date: str,
    concepts: List[dict],
    all_stocks: List[str]
) -> pd.DataFrame:
    """
    Compute features for all concepts on one trading day.

    Returns a DataFrame with:
    - index: (timestamp, concept_id)
    - columns: alpha, alpha_delta, amt_ratio, amt_delta, rank_pct, limit_up_ratio
    """

    # 1. Fetch data
    logger.info(f"  Fetching stock data...")
    stock_df = get_daily_stock_data(trade_date, all_stocks)
    if stock_df.empty:
        logger.warning(f"  No stock data")
        return pd.DataFrame()

    logger.info(f"  Fetching index data...")
    index_df = get_daily_index_data(trade_date)
    if index_df.empty:
        logger.warning(f"  No index data")
        return pd.DataFrame()

    # 2. Fetch previous closes
    logger.info(f"  Fetching previous closes...")
    prev_close = get_prev_close(all_stocks, trade_date)
    index_prev_close = get_index_prev_close(trade_date)

    if not prev_close or not index_prev_close:
        logger.warning(f"  No previous-close data")
        return pd.DataFrame()

    # 3. Per-stock change % and turnover
    stock_df['prev_close'] = stock_df['code'].map(prev_close)
    stock_df = stock_df.dropna(subset=['prev_close'])
    stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100

    # 4. Index change %
    index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
    index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))

    # 5. All timestamps for the day
    timestamps = sorted(stock_df['timestamp'].unique())
    logger.info(f"  Timestamps: {len(timestamps)}")

    # 6. Concept features, per timestamp
    results = []

    # Concept -> stock-set mapping
    concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
    concept_names = {c['concept_id']: c['concept_name'] for c in concepts}

    # Rolling history cache (used for the change-rate features)
    concept_history = {cid: {'alpha': [], 'amt': []} for cid in concept_stocks}

    for ts in timestamps:
        ts_stock_data = stock_df[stock_df['timestamp'] == ts]
        index_change = index_change_map.get(ts, 0)

        # Per-stock change % and turnover lookups
        stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
        stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))

        concept_features = []

        for concept_id, stocks in concept_stocks.items():
            # This concept's stock data
            concept_changes = [stock_change[s] for s in stocks if s in stock_change]
            concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]

            if not concept_changes:
                continue

            # Basic statistics
            avg_change = np.mean(concept_changes)
            total_amt = sum(concept_amts)

            # Alpha = concept change % minus index change %
            alpha = avg_change - index_change

            # Share of limit-up / limit-down stocks
            limit_up_count = sum(1 for c in concept_changes if c >= FEATURE_CONFIG['limit_up_threshold'])
            limit_down_count = sum(1 for c in concept_changes if c <= FEATURE_CONFIG['limit_down_threshold'])
            limit_up_ratio = limit_up_count / len(concept_changes)
            limit_down_ratio = limit_down_count / len(concept_changes)

            # Update the history
            history = concept_history[concept_id]
            history['alpha'].append(alpha)
            history['amt'].append(total_amt)

            # Change rate: alpha now vs. alpha `alpha_delta_window` minutes ago
            alpha_delta = 0
            if len(history['alpha']) > FEATURE_CONFIG['alpha_delta_window']:
                alpha_delta = alpha - history['alpha'][-FEATURE_CONFIG['alpha_delta_window']-1]

            # Turnover relative to its rolling mean
            amt_ratio = 1.0
            amt_delta = 0
            if len(history['amt']) > FEATURE_CONFIG['amt_ma_window']:
                amt_ma = np.mean(history['amt'][-FEATURE_CONFIG['amt_ma_window']-1:-1])
                if amt_ma > 0:
                    amt_ratio = total_amt / amt_ma
                amt_delta = total_amt - history['amt'][-2] if len(history['amt']) > 1 else 0

            concept_features.append({
                'concept_id': concept_id,
                'alpha': alpha,
                'alpha_delta': alpha_delta,
                'amt_ratio': amt_ratio,
                'amt_delta': amt_delta,
                'limit_up_ratio': limit_up_ratio,
                'limit_down_ratio': limit_down_ratio,
                'total_amt': total_amt,
                'stock_count': len(concept_changes),
            })

        if not concept_features:
            continue

        # Rank percentile
        concept_df = pd.DataFrame(concept_features)
        concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)

        # Attach the timestamp
        concept_df['timestamp'] = ts
        results.append(concept_df)

    if not results:
        return pd.DataFrame()

    # Concatenate all timestamps
    final_df = pd.concat(results, ignore_index=True)

    # Normalize the turnover change rate
    if 'amt_delta' in final_df.columns:
        amt_delta_std = final_df['amt_delta'].std()
        if amt_delta_std > 0:
            final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std

    logger.info(f"  Done: {len(final_df)} rows")
    return final_df
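
# Worked example (not part of the original file): demonstrates the off-by-one
# indexing used above. With window = 5 and values 0..9 appended so far, the
# current value is history[-1] == 9 and `[-window-1]` reaches back exactly
# 5 steps, so alpha_delta = alpha[t] - alpha[t-5].
def _window_indexing_example():
    window = 5
    history = list(range(10))          # history[-1] == 9 is "now"
    assert history[-window - 1] == 4   # the value from `window` steps earlier
    assert history[-1] - history[-window - 1] == window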


# ==================== Main flow ====================

def process_single_day(trade_date: str, concepts: List[dict], all_stocks: List[str]) -> str:
    """Process a single trading day."""
    output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')

    # Skip if already processed
    if os.path.exists(output_file):
        logger.info(f"[{trade_date}] Already exists, skipping")
        return output_file

    logger.info(f"[{trade_date}] Processing...")

    try:
        df = compute_daily_features(trade_date, concepts, all_stocks)

        if df.empty:
            logger.warning(f"[{trade_date}] No data")
            return None

        # Save
        df.to_parquet(output_file, index=False)
        logger.info(f"[{trade_date}] Saved: {output_file}")
        return output_file

    except Exception as e:
        logger.error(f"[{trade_date}] Failed: {e}")
        import traceback
        traceback.print_exc()
        return None


def main():
    import argparse

    parser = argparse.ArgumentParser(description='Prepare training data')
    parser.add_argument('--start', type=str, default='2022-01-01', help='start date')
    parser.add_argument('--end', type=str, default=None, help='end date (defaults to today)')
    parser.add_argument('--workers', type=int, default=1, help='parallel workers (1 recommended, to limit database load)')

    args = parser.parse_args()

    end_date = args.end or datetime.now().strftime('%Y-%m-%d')

    logger.info("=" * 60)
    logger.info("Data preparation - Transformer Autoencoder training data")
    logger.info("=" * 60)
    logger.info(f"Date range: {args.start} ~ {end_date}")

    # 1. Fetch the concept list
    concepts = get_all_concepts()

    # Collect all stocks
    all_stocks = list(set(s for c in concepts for s in c['stocks']))
    logger.info(f"Total stocks: {len(all_stocks)}")

    # 2. Fetch the trading-day list
    trading_days = get_trading_days(args.start, end_date)

    if not trading_days:
        logger.error("No trading days found")
        return

    # 3. Process each trading day
    logger.info(f"\nProcessing {len(trading_days)} trading days...")

    success_count = 0
    for i, trade_date in enumerate(trading_days):
        logger.info(f"\n[{i+1}/{len(trading_days)}] {trade_date}")
        result = process_single_day(trade_date, concepts, all_stocks)
        if result:
            success_count += 1

    logger.info("\n" + "=" * 60)
    logger.info(f"Done: {success_count}/{len(trading_days)} trading days")
    logger.info(f"Data saved under: {OUTPUT_DIR}")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()
25
ml/requirements.txt
Normal file
@@ -0,0 +1,25 @@
# Dependencies for the concept-anomaly-detection ML module
# Install: pip install -r ml/requirements.txt

# PyTorch (pick the build matching your CUDA version)
# A 5090 GPU needs CUDA 12.x:
# pip install torch --index-url https://download.pytorch.org/whl/cu124
torch>=2.0.0

# Data handling
numpy>=1.24.0
pandas>=2.0.0
pyarrow>=14.0.0

# Databases
clickhouse-driver>=0.2.6
elasticsearch>=7.0.0,<8.0.0
sqlalchemy>=2.0.0
pymysql>=1.1.0

# Training utilities
tqdm>=4.65.0

# Optional: visualization
# matplotlib>=3.7.0
# tensorboard>=2.14.0
99
ml/run_training.sh
Normal file
@@ -0,0 +1,99 @@
#!/bin/bash
# Training script for the concept-anomaly-detection model (Linux)
#
# Usage:
#   chmod +x run_training.sh
#   ./run_training.sh
#
# Or with arguments:
#   ./run_training.sh --start 2022-01-01 --epochs 100

set -e

echo "============================================================"
echo "Concept anomaly-detection model: training pipeline"
echo "============================================================"
echo ""

# Resolve the script's own directory
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR/.."

echo "[1/4] Checking environment..."
python3 --version || { echo "Python3 not found!"; exit 1; }

# Check for a GPU
if python3 -c "import torch; print(f'CUDA: {torch.cuda.is_available()}')" 2>/dev/null; then
    echo "PyTorch GPU check complete"
else
    echo "Warning: PyTorch not installed or GPU not detectable"
fi

echo ""
echo "[2/4] Checking dependencies..."
pip3 install -q torch pandas numpy pyarrow tqdm clickhouse-driver elasticsearch sqlalchemy pymysql

echo ""
echo "[3/4] Preparing training data..."
echo "Extracting historical data from ClickHouse; this can take a while..."
echo ""

# Parse arguments
START_DATE="2022-01-01"
END_DATE=""
EPOCHS=100
BATCH_SIZE=256
TRAIN_END="2025-06-30"
VAL_END="2025-09-30"

while [[ $# -gt 0 ]]; do
    case $1 in
        --start)
            START_DATE="$2"
            shift 2
            ;;
        --end)
            END_DATE="$2"
            shift 2
            ;;
        --epochs)
            EPOCHS="$2"
            shift 2
            ;;
        --batch_size)
            BATCH_SIZE="$2"
            shift 2
            ;;
        --train_end)
            TRAIN_END="$2"
            shift 2
            ;;
        --val_end)
            VAL_END="$2"
            shift 2
            ;;
        *)
            shift
            ;;
    esac
done
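
# Example invocation exercising every flag the loop above parses (the dates
# and sizes are illustrative, not recommendations):
#   ./run_training.sh --start 2023-01-01 --end 2025-10-31 \
#       --epochs 50 --batch_size 2048 --train_end 2025-06-30 --val_end 2025-09-30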

# Data preparation
if [ -n "$END_DATE" ]; then
    python3 ml/prepare_data.py --start "$START_DATE" --end "$END_DATE"
else
    python3 ml/prepare_data.py --start "$START_DATE"
fi

echo ""
echo "[4/4] Training the model..."
echo "Training with GPU acceleration..."
echo ""

python3 ml/train.py --epochs "$EPOCHS" --batch_size "$BATCH_SIZE" --train_end "$TRAIN_END" --val_end "$VAL_END"

echo ""
echo "============================================================"
echo "Training finished!"
echo "Models saved under: ml/checkpoints/"
echo "============================================================"
808
ml/train.py
Normal file
@@ -0,0 +1,808 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Transformer Autoencoder training script (fixed version).

Fixes:
1. Build sequences grouped by concept, so no window straddles two concepts
2. Split datasets by date (time), avoiding data leakage
3. Handle non-stationarity with RobustScaler + clipping
4. Compute thresholds on the validation set

Training flow:
1. Load the precomputed feature data (parquet files)
2. Group by concept and build sequences within each concept
3. Split train/val/test by date
4. Train the autoencoder (minimize reconstruction error)
5. Save the model and thresholds

Usage:
    python train.py --data_dir ml/data --epochs 100 --batch_size 256
"""

import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm

from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters

# Performance: enable cuDNN benchmark (auto-selects the fastest kernels for fixed input sizes)
torch.backends.cudnn.benchmark = True
# Enable TF32 (available on RTX 30/40-series GPUs; roughly 3x faster matmuls)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Visualization (optional)
try:
    import matplotlib
    matplotlib.use('Agg')  # headless mode, no display needed
    import matplotlib.pyplot as plt
    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False


# ==================== Configuration ====================

TRAIN_CONFIG = {
    # Data
    'seq_len': 30,  # input sequence length (30 minutes)
    'stride': 5,    # sliding-window step

    # Time-based split (by date)
    'train_end_date': '2024-06-30',  # last day of the training set
    'val_end_date': '2024-09-30',    # last day of the validation set (test set afterwards)

    # Features
    'features': [
        'alpha',           # excess return
        'alpha_delta',     # change in alpha
        'amt_ratio',       # turnover ratio
        'amt_delta',       # change in turnover
        'rank_pct',        # alpha rank percentile
        'limit_up_ratio',  # share of limit-up stocks
    ],

    # Training (tuned for 4x RTX 4090)
    'batch_size': 4096,     # 256 -> 4096 (much larger, to saturate GPU memory)
    'epochs': 100,
    'learning_rate': 3e-4,  # 1e-4 -> 3e-4 (large batches want a larger LR)
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,

    # Early stopping
    'patience': 10,
    'min_delta': 1e-6,

    # Model (LSTM Autoencoder: simple and effective)
    'model': {
        'n_features': 6,
        'hidden_dim': 32,       # LSTM hidden size (small)
        'latent_dim': 4,        # bottleneck size (very small -- the key constraint)
        'num_layers': 1,        # LSTM layers
        'dropout': 0.2,
        'bidirectional': True,  # bidirectional encoder
    },

    # Normalization
    'use_instance_norm': True,  # instance-norm inside the model (recommended)
    'clip_value': 10.0,         # plain clipping of extreme values

    # Thresholds
    'threshold_percentiles': [90, 95, 99],
}
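
# Worked example (not part of the original file): an A-share session has 240
# trading minutes, so one concept-day yields
# floor((240 - seq_len) / stride) + 1 = floor((240 - 30) / 5) + 1 = 43
# overlapping windows under this config.
def windows_per_day(n_minutes: int = 240,
                    seq_len: int = TRAIN_CONFIG['seq_len'],
                    stride: int = TRAIN_CONFIG['stride']) -> int:
    """Sketch: window count per concept-day under the sliding-window scheme."""
    return max(0, (n_minutes - seq_len) // stride + 1)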


# ==================== Data loading (fixed) ====================

def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
    """
    Load data by date and return a {date: DataFrame} dict.

    Each DataFrame holds every concept's data at every timestamp of that day.
    """
    data_path = Path(data_dir)
    parquet_files = sorted(data_path.glob("features_*.parquet"))

    if not parquet_files:
        raise FileNotFoundError(f"No parquet files found: {data_dir}")

    print(f"Found {len(parquet_files)} data files")

    date_data = {}

    for pf in tqdm(parquet_files, desc="Loading data"):
        # Extract the date
        date = pf.stem.replace('features_', '')

        df = pd.read_parquet(pf)

        # Check required columns
        required_cols = features + ['concept_id', 'timestamp']
        missing_cols = [c for c in required_cols if c not in df.columns]
        if missing_cols:
            print(f"Warning: {date} is missing columns: {missing_cols}, skipping")
            continue

        date_data[date] = df

    print(f"Loaded {len(date_data)} days of data")
    return date_data


def split_data_by_date(
    date_data: Dict[str, pd.DataFrame],
    train_end: str,
    val_end: str
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """
    Split the dataset strictly by date.

    - train: <= train_end
    - val:   train_end < date <= val_end
    - test:  > val_end

    (ISO yyyy-mm-dd strings order correctly under plain string comparison.)
    """
    train_data = {}
    val_data = {}
    test_data = {}

    for date, df in date_data.items():
        if date <= train_end:
            train_data[date] = df
        elif date <= val_end:
            val_data[date] = df
        else:
            test_data[date] = df

    print(f"Dataset split (by date):")
    print(f"  train: {len(train_data)} days (<= {train_end})")
    print(f"  val:   {len(val_data)} days ({train_end} ~ {val_end})")
    print(f"  test:  {len(test_data)} days (> {val_end})")

    return train_data, val_data, test_data


def build_sequences_by_concept(
    date_data: Dict[str, pd.DataFrame],
    features: List[str],
    seq_len: int,
    stride: int
) -> np.ndarray:
    """
    Build sequences grouped by concept (performance-optimized).

    Uses a single groupby pass instead of rescanning the big array repeatedly:

    1. Concatenate the data across all dates
    2. Group by concept_id with groupby
    3. Within each concept, sort by time and slide a window
    4. Concatenate all sequences
    """
    # Concatenate all dates
    all_dfs = []
    for date, df in sorted(date_data.items()):
        df = df.copy()
        df['date'] = date
        all_dfs.append(df)

    if not all_dfs:
        return np.array([])

    combined = pd.concat(all_dfs, ignore_index=True)

    # Pre-sort by (concept, date, time) so the groupby is faster
    combined = combined.sort_values(['concept_id', 'date', 'timestamp'])

    # One groupby pass over all concepts (the performance-critical step!)
    all_sequences = []
    grouped = combined.groupby('concept_id', sort=False)
    n_concepts = len(grouped)

    for concept_id, concept_df in tqdm(grouped, desc="Building sequences", total=n_concepts, leave=False):
        # Already sorted; extract the features directly
        feature_data = concept_df[features].values

        # Handle missing values
        feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0)

        # Slide the window within this concept
        n_points = len(feature_data)
        for start in range(0, n_points - seq_len + 1, stride):
            seq = feature_data[start:start + seq_len]
            all_sequences.append(seq)

    if not all_sequences:
        return np.array([])

    sequences = np.array(all_sequences)
    print(f"  Built {len(sequences):,} sequences (from {n_concepts} concepts)")

    return sequences


# ==================== Dataset ====================

class SequenceDataset(Dataset):
    """Sequence dataset (over prebuilt sequences)."""

    def __init__(self, sequences: np.ndarray):
        self.sequences = torch.FloatTensor(sequences)

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.sequences[idx]


# ==================== Trainer ====================

class EarlyStopping:
    """Early-stopping helper."""

    def __init__(self, patience: int = 10, min_delta: float = 1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop


class Trainer:
    """Model trainer (with AMP mixed-precision acceleration)."""

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict,
        device: torch.device,
        save_dir: str = 'ml/checkpoints'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Optimizer
        self.optimizer = AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        # LR scheduler
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=10,
            T_mult=2,
            eta_min=1e-6
        )

        # Loss (simplified: MSE only)
        self.criterion = AnomalyDetectionLoss()

        # Early stopping
        self.early_stopping = EarlyStopping(
            patience=config['patience'],
            min_delta=config['min_delta']
        )

        # AMP mixed precision (big speedup + lower memory use)
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
        if self.use_amp:
            print("  ✓ AMP mixed-precision training enabled")

        # Training history
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'learning_rate': [],
        }

        self.best_val_loss = float('inf')

    def train_epoch(self) -> float:
        """Train one epoch (with AMP mixed precision)."""
        self.model.train()
        total_loss = 0.0
        n_batches = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for batch in pbar:
            batch = batch.to(self.device, non_blocking=True)  # async transfer

            self.optimizer.zero_grad(set_to_none=True)  # faster gradient reset

            # AMP mixed-precision forward pass
            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, loss_dict = self.criterion(output, batch, latent)

                # AMP backward pass
                self.scaler.scale(loss).backward()

                # Gradient clipping (requires unscaling first)
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config['gradient_clip']
                )

                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # Non-AMP path
                output, latent = self.model(batch)
                loss, loss_dict = self.criterion(output, batch, latent)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config['gradient_clip']
                )
                self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1

            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        return total_loss / n_batches

    @torch.no_grad()
    def validate(self) -> float:
        """Validate (with AMP)."""
        self.model.eval()
        total_loss = 0.0
        n_batches = 0

        for batch in self.val_loader:
            batch = batch.to(self.device, non_blocking=True)

            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, _ = self.criterion(output, batch, latent)
            else:
                output, latent = self.model(batch)
                loss, _ = self.criterion(output, batch, latent)

            total_loss += loss.item()
            n_batches += 1

        return total_loss / n_batches

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Save a checkpoint."""
        # Unwrap DataParallel if present
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
        }

        # Save the latest checkpoint
        torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')

        # Save the best model
        if is_best:
            torch.save(checkpoint, self.save_dir / 'best_model.pt')
            print(f"  ✓ Saved best model (val_loss: {val_loss:.6f})")

    def train(self, epochs: int):
        """Full training loop."""
        print(f"\nStarting training ({epochs} epochs)...")
        print(f"Device: {self.device}")
        print(f"Model parameters: {count_parameters(self.model):,}")

        for epoch in range(1, epochs + 1):
            print(f"\nEpoch {epoch}/{epochs}")

            # Train
            train_loss = self.train_epoch()

            # Validate
            val_loss = self.validate()

            # Update the learning rate
            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['learning_rate'].append(current_lr)

            # Progress
            print(f"  Train Loss: {train_loss:.6f}")
            print(f"  Val Loss:   {val_loss:.6f}")
            print(f"  LR: {current_lr:.2e}")

            # Checkpoint
            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss
            self.save_checkpoint(epoch, val_loss, is_best)

            # Early-stopping check
            if self.early_stopping(val_loss):
                print(f"\nEarly stopping: no val-loss improvement for {self.early_stopping.patience} epochs")
                break

        print(f"\nTraining done! Best val loss: {self.best_val_loss:.6f}")

        # Save the training history
        self.save_history()

        return self.history

    def save_history(self):
        """Save the training history."""
        history_path = self.save_dir / 'training_history.json'
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=2)
        print(f"Training history saved: {history_path}")

        # Plot the training curves
        self.plot_training_curves()

    def plot_training_curves(self):
        """Plot the training curves."""
        if not HAS_MATPLOTLIB:
            print("matplotlib not installed, skipping plots")
            return

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        epochs = range(1, len(self.history['train_loss']) + 1)

        # 1. Loss curves
        ax1 = axes[0]
        ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        ax1.set_xlabel('Epoch', fontsize=12)
        ax1.set_ylabel('Loss', fontsize=12)
        ax1.set_title('Training & Validation Loss', fontsize=14)
        ax1.legend(fontsize=11)
        ax1.grid(True, alpha=0.3)

        # Mark the best epoch
        best_epoch = np.argmin(self.history['val_loss']) + 1
        best_val_loss = min(self.history['val_loss'])
        ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}')
        ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)
        ax1.annotate(f'Best: {best_val_loss:.6f}', xy=(best_epoch, best_val_loss),
                     xytext=(best_epoch + 2, best_val_loss + 0.0005),
                     fontsize=10, color='green')

        # 2. Learning-rate curve
        ax2 = axes[1]
        ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Learning Rate', fontsize=12)
        ax2.set_title('Learning Rate Schedule', fontsize=14)
        ax2.set_yscale('log')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()

        # Save the figure
        plot_path = self.save_dir / 'training_curves.png'
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.close()
        print(f"Training curves saved: {plot_path}")


# ==================== Threshold computation (on the validation set) ====================

@torch.no_grad()
def compute_thresholds(
    model: nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    percentiles: List[float] = [90, 95, 99]
) -> Dict[str, float]:
    """
    Compute percentile thresholds over reconstruction errors on the validation set.

    Note: the validation set is used rather than the test set, to avoid leakage.
    """
    model.eval()
    all_errors = []

    print("Computing anomaly thresholds (on the validation set)...")
    for batch in tqdm(data_loader, desc="Computing thresholds"):
        batch = batch.to(device)
        errors = model.compute_reconstruction_error(batch, reduction='none')

        # Take each sequence's final-timestep error (predicting "now")
        seq_errors = errors[:, -1]  # (batch,)
        all_errors.append(seq_errors.cpu().numpy())

    all_errors = np.concatenate(all_errors)

    thresholds = {}
    for p in percentiles:
        threshold = np.percentile(all_errors, p)
        thresholds[f'p{p}'] = float(threshold)
        print(f"  P{p}: {threshold:.6f}")

    # Extra statistics
    thresholds['mean'] = float(np.mean(all_errors))
    thresholds['std'] = float(np.std(all_errors))
    thresholds['median'] = float(np.median(all_errors))

    print(f"  Mean:   {thresholds['mean']:.6f}")
    print(f"  Median: {thresholds['median']:.6f}")
    print(f"  Std:    {thresholds['std']:.6f}")

    return thresholds
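
# Inference sketch (not part of the original file): how the thresholds saved
# below are meant to be consumed later. `checkpoint_dir` and the p99 choice are
# illustrative assumptions; only compute_reconstruction_error and the
# thresholds.json layout come from this script.
def load_threshold(checkpoint_dir: str = 'ml/checkpoints', key: str = 'p99') -> float:
    with open(Path(checkpoint_dir) / 'thresholds.json') as f:
        return json.load(f)[key]


@torch.no_grad()
def score_batch(model: nn.Module, batch: torch.Tensor, threshold: float) -> torch.Tensor:
    """Bool mask of sequences whose final-timestep error exceeds the threshold."""
    errors = model.compute_reconstruction_error(batch, reduction='none')  # (batch, seq_len)
    return errors[:, -1] > threshold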


# ==================== Main ====================

def main():
    parser = argparse.ArgumentParser(description='Train the concept anomaly-detection model')
    parser.add_argument('--data_dir', type=str, default='ml/data',
                        help='data directory')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=4096,
                        help='batch size (4096~8192 recommended on 4x RTX 4090)')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='learning rate (3e-4 recommended for large batches)')
    parser.add_argument('--device', type=str, default='auto',
                        help='device (auto/cuda/cpu)')
    parser.add_argument('--save_dir', type=str, default='ml/checkpoints',
                        help='model save directory')
    parser.add_argument('--train_end', type=str, default='2024-06-30',
                        help='last day of the training set')
    parser.add_argument('--val_end', type=str, default='2024-09-30',
                        help='last day of the validation set')

    args = parser.parse_args()

    # Merge into the config
    config = TRAIN_CONFIG.copy()
    config['batch_size'] = args.batch_size
    config['epochs'] = args.epochs
    config['learning_rate'] = args.lr
    config['train_end_date'] = args.train_end
    config['val_end_date'] = args.val_end

    # Device selection
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    print("=" * 60)
    print("Concept anomaly-detection model training (fixed version)")
    print("=" * 60)
    print(f"Config:")
    print(f"  data dir:   {args.data_dir}")
    print(f"  device:     {device}")
    print(f"  batch size: {config['batch_size']}")
    print(f"  LR:         {config['learning_rate']}")
    print(f"  epochs:     {config['epochs']}")
    print(f"  train end:  {config['train_end_date']}")
    print(f"  val end:    {config['val_end_date']}")
    print("=" * 60)

    # 1. Load data by date
    print("\n[1/6] Loading data...")
    date_data = load_data_by_date(args.data_dir, config['features'])

    # 2. Split by date
    print("\n[2/6] Splitting datasets by date...")
    train_data, val_data, test_data = split_data_by_date(
        date_data,
        config['train_end_date'],
        config['val_end_date']
    )

    # 3. Build sequences per concept
    print("\n[3/6] Building sequences per concept...")
    print("Train:")
    train_sequences = build_sequences_by_concept(
        train_data, config['features'], config['seq_len'], config['stride']
    )
    print("Val:")
    val_sequences = build_sequences_by_concept(
        val_data, config['features'], config['seq_len'], config['stride']
    )
    print("Test:")
    test_sequences = build_sequences_by_concept(
        test_data, config['features'], config['seq_len'], config['stride']
    )

    if len(train_sequences) == 0:
        print("Error: the training set is empty! Check the data and date range")
        return

    # 4. Preprocessing (plain clipping of extremes; normalization happens
    #    inside the model via Instance Norm)
    print("\n[4/6] Preprocessing...")
    print("  Note: Instance Norm normalizes each sequence inside the model,")
    print("  absorbing volatility differences across concepts (banks vs. semis)")

    clip_value = config['clip_value']
    print(f"  Clipping extremes: ±{clip_value}")

    # Plain clipping of extreme values (keeps bad data from derailing training)
    train_sequences = np.clip(train_sequences, -clip_value, clip_value)
    if len(val_sequences) > 0:
        val_sequences = np.clip(val_sequences, -clip_value, clip_value)
    if len(test_sequences) > 0:
        test_sequences = np.clip(test_sequences, -clip_value, clip_value)

    # Persist the config
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    preprocess_params = {
        'features': config['features'],
        'normalization': 'instance_norm',  # done inside the model
        'clip_value': clip_value,
        'note': 'Normalization is done in-model via InstanceNorm1d; no external scaler needed'
    }

    with open(save_dir / 'normalization_stats.json', 'w') as f:
        json.dump(preprocess_params, f, indent=2)
    print(f"  Preprocessing parameters saved")

    # 5. Datasets and loaders
    print("\n[5/6] Creating data loaders...")
    train_dataset = SequenceDataset(train_sequences)
    val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None
    test_dataset = SequenceDataset(test_sequences) if len(test_sequences) > 0 else None

    print(f"  train sequences: {len(train_dataset):,}")
    print(f"  val sequences:   {len(val_dataset) if val_dataset else 0:,}")
    print(f"  test sequences:  {len(test_dataset) if test_dataset else 0:,}")

    # More workers on multi-GPU boxes (Linux can afford many)
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0
    print(f"  DataLoader workers: {num_workers}")
    print(f"  Batch size: {config['batch_size']}")

    # Big batches + many workers + prefetch for throughput
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,  # prefetch more batches
        persistent_workers=True if num_workers > 0 else False,  # keep workers alive
        drop_last=True  # drop the ragged final batch so batch sizes stay uniform
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'] * 2,  # bigger batches are fine without gradients
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=True if num_workers > 0 else False,
    ) if val_dataset else None

    test_loader = DataLoader(
        test_dataset,
        batch_size=config['batch_size'] * 2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=True if num_workers > 0 else False,
    ) if test_dataset else None

    # 6. Train
    print("\n[6/6] Training the model...")
    model_config = config['model'].copy()
    model = TransformerAutoencoder(**model_config)

    # Multi-GPU
    if torch.cuda.device_count() > 1:
        print(f"  Training on {torch.cuda.device_count()} GPUs in parallel")
        model = nn.DataParallel(model)

    if val_loader is None:
        print("Warning: the validation set is empty; holding out part of the training set instead")
        # Simple fallback: the last 10% of the training set becomes the validation set
        split_idx = int(len(train_dataset) * 0.9)
        train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
        val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))

        train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
        val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device,
        save_dir=args.save_dir
    )

    history = trainer.train(config['epochs'])

    # 7. Thresholds (from the validation set)
    print("\n[extra] Computing anomaly thresholds...")

    # Load the best model. Unwrap DataParallel first: the checkpoint stores the
    # unwrapped state dict, and compute_reconstruction_error lives on the
    # underlying module.
    model_to_eval = model.module if hasattr(model, 'module') else model
    best_checkpoint = torch.load(
        save_dir / 'best_model.pt',
        map_location=device
    )
    model_to_eval.load_state_dict(best_checkpoint['model_state_dict'])
    model_to_eval.to(device)

    # Compute thresholds on the validation set (avoids data leakage)
    thresholds = compute_thresholds(
        model_to_eval,
        val_loader,
        device,
        config['threshold_percentiles']
    )

    # Save the thresholds
    with open(save_dir / 'thresholds.json', 'w') as f:
        json.dump(thresholds, f, indent=2)
    print(f"Thresholds saved")

    # Save the full config
    with open(save_dir / 'config.json', 'w') as f:
        json.dump(config, f, indent=2)

    print("\n" + "=" * 60)
    print("Training complete!")
    print("=" * 60)
    print(f"Model artifacts: {args.save_dir}")
    print(f"  - best_model.pt: best model weights")
    print(f"  - thresholds.json: anomaly thresholds")
    print(f"  - normalization_stats.json: normalization parameters")
    print(f"  - config.json: training config")
    print("=" * 60)


if __name__ == "__main__":
    main()
1017
public/htmls/卫星能源太阳翼.html
Normal file
File diff suppressed because it is too large
588
public/htmls/厦门“十五五规划”.html
Normal file
File diff suppressed because one or more lines are too long
68
sql/concept_minute_alert.sql
Normal file
@@ -0,0 +1,68 @@
-- Concept minute-level alert table
-- Stores real-time concept-sector alert events backing the hotspot-overview chart

CREATE TABLE IF NOT EXISTS concept_minute_alert (
    id BIGINT AUTO_INCREMENT PRIMARY KEY,
    concept_id VARCHAR(32) NOT NULL COMMENT 'concept ID',
    concept_name VARCHAR(100) NOT NULL COMMENT 'concept name',
    alert_time DATETIME NOT NULL COMMENT 'alert time (minute precision)',
    alert_type VARCHAR(20) NOT NULL COMMENT 'alert type: surge / limit_up (more limit-ups) / rank_jump',
    trade_date DATE NOT NULL COMMENT 'trading date',

    -- Change-percentage fields
    change_pct DECIMAL(10,4) COMMENT 'change % at alert time',
    prev_change_pct DECIMAL(10,4) COMMENT 'prior change %',
    change_delta DECIMAL(10,4) COMMENT 'change in change % (pp)',

    -- Limit-up fields
    limit_up_count INT DEFAULT 0 COMMENT 'current limit-up count',
    prev_limit_up_count INT DEFAULT 0 COMMENT 'prior limit-up count',
    limit_up_delta INT DEFAULT 0 COMMENT 'change in limit-up count',

    -- Ranking fields
    rank_position INT COMMENT 'current rank by change %',
    prev_rank_position INT COMMENT 'prior rank',
    rank_delta INT COMMENT 'rank change (negative = moved up)',

    -- Index position (for chart Y-axis placement)
    index_code VARCHAR(20) DEFAULT '000001.SH' COMMENT 'reference index code',
    index_price DECIMAL(12,4) COMMENT 'index level at alert time',
    index_change_pct DECIMAL(10,4) COMMENT 'index change % at alert time',

    -- Concept details
    stock_count INT COMMENT 'number of stocks in the concept',
    concept_type VARCHAR(20) DEFAULT 'leaf' COMMENT 'concept type: leaf/lv1/lv2/lv3',

    -- Extra info (JSON; e.g. the list of limit-up stocks)
    extra_info JSON COMMENT 'extra info',

    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,

    -- Indexes
    INDEX idx_trade_date (trade_date),
    INDEX idx_alert_time (alert_time),
    INDEX idx_concept_id (concept_id),
    INDEX idx_alert_type (alert_type),
    INDEX idx_trade_date_time (trade_date, alert_time)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='concept minute-level alert table';
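
-- Example (a sketch, kept commented so sourcing this migration stays
-- side-effect free): fetch one day's alerts in time order, roughly the shape
-- of query a hotspot-overview endpoint would run. The date literal is
-- illustrative.
--   SELECT concept_id, concept_name, alert_time, alert_type,
--          change_pct, change_delta, limit_up_count, rank_position
--   FROM concept_minute_alert
--   WHERE trade_date = '2025-01-02'
--   ORDER BY alert_time;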


-- Index minute-snapshot table (used to look up the index level at alert time)
CREATE TABLE IF NOT EXISTS index_minute_snapshot (
    id BIGINT AUTO_INCREMENT PRIMARY KEY,
    index_code VARCHAR(20) NOT NULL COMMENT 'index code',
    trade_date DATE NOT NULL COMMENT 'trading date',
    snapshot_time DATETIME NOT NULL COMMENT 'snapshot time',

    price DECIMAL(12,4) COMMENT 'index level',
    open_price DECIMAL(12,4) COMMENT 'open',
    high_price DECIMAL(12,4) COMMENT 'high',
    low_price DECIMAL(12,4) COMMENT 'low',
    prev_close DECIMAL(12,4) COMMENT 'previous close',
    change_pct DECIMAL(10,4) COMMENT 'change %',

    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,

    UNIQUE KEY uk_index_time (index_code, snapshot_time),
    INDEX idx_trade_date (trade_date)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='index minute-snapshot table';
539
src/views/StockOverview/components/HotspotOverview/index.js
Normal file
@@ -0,0 +1,539 @@
/**
 * Hotspot Overview component
 * Renders the market index intraday curve with concept-alert annotations
 */
import React, { useState, useEffect, useRef, useCallback } from 'react';
import {
  Box,
  Card,
  CardBody,
  Heading,
  Text,
  HStack,
  VStack,
  Badge,
  Spinner,
  Center,
  Icon,
  Flex,
  Spacer,
  Tooltip,
  useColorModeValue,
  Stat,
  StatLabel,
  StatNumber,
  StatHelpText,
  StatArrow,
  SimpleGrid,
} from '@chakra-ui/react';
import { FaFire, FaRocket, FaChartLine, FaBolt, FaArrowDown } from 'react-icons/fa';
import { InfoIcon } from '@chakra-ui/icons';
import * as echarts from 'echarts';
import { logger } from '@utils/logger';

const HotspotOverview = ({ selectedDate }) => {
  const chartRef = useRef(null);
  const chartInstance = useRef(null);
  const [loading, setLoading] = useState(true);
  const [data, setData] = useState(null);
  const [error, setError] = useState(null);

  // Color theme
  const cardBg = useColorModeValue('white', '#1a1a1a');
  const borderColor = useColorModeValue('gray.200', '#333333');
  const textColor = useColorModeValue('gray.800', 'white');
  const subTextColor = useColorModeValue('gray.600', 'gray.400');

  // Fetch data
  const fetchData = useCallback(async () => {
    setLoading(true);
    setError(null);

    try {
      const dateParam = selectedDate
        ? `?date=${selectedDate.toISOString().split('T')[0]}`
        : '';
      const response = await fetch(`/api/market/hotspot-overview${dateParam}`);
      const result = await response.json();

      if (result.success) {
        setData(result.data);
      } else {
        setError(result.error || 'Failed to fetch data');
      }
    } catch (err) {
      logger.error('HotspotOverview', 'fetchData', err);
      setError('Network request failed');
    } finally {
      setLoading(false);
    }
  }, [selectedDate]);

  useEffect(() => {
    fetchData();
  }, [fetchData]);
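
  // Shape of `data` as this component consumes it below (inferred from usage;
  // fields beyond those actually read here are not guaranteed):
  //
  //   {
  //     index: { timeline: [{ time, price, volume, change_pct, ... }] },
  //     alerts: [{ concept_name, alert_type, time, change_delta,
  //                importance_score, index_price, zscore, ... }]
  //   }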
|
||||
// 渲染图表
|
||||
const renderChart = useCallback(() => {
|
||||
if (!chartRef.current || !data) return;
|
||||
|
||||
if (!chartInstance.current) {
|
||||
chartInstance.current = echarts.init(chartRef.current);
|
||||
}
|
||||
|
||||
const { index, alerts } = data;
|
||||
const timeline = index.timeline || [];
|
||||
|
||||
// 准备数据
|
||||
const times = timeline.map((d) => d.time);
|
||||
const prices = timeline.map((d) => d.price);
|
||||
const volumes = timeline.map((d) => d.volume);
|
||||
const changePcts = timeline.map((d) => d.change_pct);
|
||||
|
||||
// 计算Y轴范围
|
||||
const priceMin = Math.min(...prices.filter(Boolean));
|
||||
const priceMax = Math.max(...prices.filter(Boolean));
|
||||
const priceRange = priceMax - priceMin;
|
||||
const yAxisMin = priceMin - priceRange * 0.1;
|
||||
const yAxisMax = priceMax + priceRange * 0.2; // 上方留更多空间给标注
|
||||
|
||||
// 准备异动标注 - 按重要性排序,限制显示数量
|
||||
const sortedAlerts = [...alerts]
|
||||
.sort((a, b) => (b.importance_score || 0) - (a.importance_score || 0))
|
||||
.slice(0, 15); // 最多显示15个标注,避免图表过于密集
|
||||
|
||||
const markPoints = sortedAlerts.map((alert) => {
|
||||
// 找到对应时间的价格
|
||||
const timeIndex = times.indexOf(alert.time);
|
||||
const price = timeIndex >= 0 ? prices[timeIndex] : (alert.index_price || priceMax);
|
||||
|
||||
// 根据异动类型设置颜色和符号
|
||||
let color = '#ff6b6b';
|
||||
let symbol = 'pin';
|
||||
let symbolSize = 35;
|
||||
|
||||
// 暴涨
|
||||
if (alert.alert_type === 'surge_up' || alert.alert_type === 'surge') {
|
||||
color = '#ff4757';
|
||||
symbol = 'triangle';
|
||||
symbolSize = 30 + Math.min((alert.importance_score || 0.5) * 20, 15); // 根据重要性调整大小
|
||||
}
|
||||
// 暴跌
|
||||
else if (alert.alert_type === 'surge_down') {
|
||||
color = '#2ed573';
|
||||
symbol = 'path://M0,0 L10,0 L5,10 Z'; // 向下三角形
|
||||
symbolSize = 30 + Math.min((alert.importance_score || 0.5) * 20, 15);
|
||||
}
|
||||
// 涨停增加
|
||||
else if (alert.alert_type === 'limit_up') {
|
||||
color = '#ff6348';
|
||||
symbol = 'diamond';
|
||||
symbolSize = 28;
|
||||
}
|
||||
// 排名跃升
|
||||
else if (alert.alert_type === 'rank_jump') {
|
||||
color = '#3742fa';
|
||||
symbol = 'circle';
|
||||
symbolSize = 25;
|
||||
}

    // Format the label: keep it compact
    let label = alert.concept_name;
    // Truncate overly long names
    if (label.length > 8) {
      label = label.substring(0, 7) + '...';
    }

    // Append the change delta
    const changeDelta = alert.change_delta;
    if (changeDelta) {
      const sign = changeDelta > 0 ? '+' : '';
      label += `\n${sign}${changeDelta.toFixed(1)}%`;
    }

    return {
      name: alert.concept_name,
      coord: [alert.time, price],
      value: label,
      symbol: symbol,
      symbolSize: symbolSize,
      itemStyle: {
        color: color,
        borderColor: '#fff',
        borderWidth: 1,
        shadowBlur: 3,
        shadowColor: 'rgba(0,0,0,0.2)',
      },
      label: {
        show: true,
        position: alert.alert_type === 'surge_down' ? 'bottom' : 'top', // surge-down labels sit below the point
        formatter: '{c}', // '{c}' = data value (the compact label built above); '{b}' would show the raw name
        fontSize: 9,
        color: textColor,
        backgroundColor: alert.alert_type === 'surge_down'
          ? 'rgba(46, 213, 115, 0.9)'
          : 'rgba(255,255,255,0.9)',
        padding: [2, 4],
        borderRadius: 2,
        borderColor: color,
        borderWidth: 1,
      },
      // Stash the full alert record for tooltip use
      alertData: alert,
    };
  });
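
  // Note: with a category x-axis, ECharts matches markPoint coords by category
  // value, so [alert.time, price] only lands on the chart when alert.time
  // exactly equals one of the 'HH:MM' strings in `times`.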

  // Area gradient keyed to the day's direction
  const latestChangePct = changePcts[changePcts.length - 1] || 0;
  const areaColorStops = latestChangePct >= 0
    ? [
        { offset: 0, color: 'rgba(255, 77, 77, 0.4)' },
        { offset: 1, color: 'rgba(255, 77, 77, 0.05)' },
      ]
    : [
        { offset: 0, color: 'rgba(34, 197, 94, 0.4)' },
        { offset: 1, color: 'rgba(34, 197, 94, 0.05)' },
      ];

  const lineColor = latestChangePct >= 0 ? '#ff4d4d' : '#22c55e';
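
  // Colors follow the A-share convention: red for gains, green for losses
  // (the reverse of the usual Western scheme).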

  const option = {
    backgroundColor: 'transparent',
    tooltip: {
      trigger: 'axis',
      axisPointer: {
        type: 'cross',
        crossStyle: {
          color: '#999',
        },
      },
      formatter: function (params) {
        if (!params || params.length === 0) return '';

        const dataIndex = params[0].dataIndex;
        const time = times[dataIndex];
        const price = prices[dataIndex];
        const changePct = changePcts[dataIndex];
        const volume = volumes[dataIndex];

        let html = `
          <div style="padding: 8px;">
            <div style="font-weight: bold; margin-bottom: 4px;">${time}</div>
            <div>Index: <span style="color: ${changePct >= 0 ? '#ff4d4d' : '#22c55e'}; font-weight: bold;">${price?.toFixed(2)}</span></div>
            <div>Change: <span style="color: ${changePct >= 0 ? '#ff4d4d' : '#22c55e'};">${changePct >= 0 ? '+' : ''}${changePct?.toFixed(2)}%</span></div>
            <div>Volume: ${(volume / 10000).toFixed(0)} x10k lots</div>
          </div>
        `;

        // Append any concept alerts at this minute
        const alertsAtTime = alerts.filter((a) => a.time === time);
        if (alertsAtTime.length > 0) {
          html += '<div style="border-top: 1px solid #eee; margin-top: 4px; padding-top: 4px;">';
          html += '<div style="font-weight: bold; color: #ff6b6b;">Concept alerts:</div>';
          alertsAtTime.forEach((alert) => {
            const typeLabel = {
              surge: 'Rapid rise',
              surge_up: 'Surge up',
              surge_down: 'Surge down',
              limit_up: 'More limit-ups',
              rank_jump: 'Rank jump',
            }[alert.alert_type] || alert.alert_type;
            const typeColor = alert.alert_type === 'surge_down' ? '#2ed573' : '#ff6b6b';
            const delta = alert.change_delta ? ` (${alert.change_delta > 0 ? '+' : ''}${alert.change_delta.toFixed(2)}%)` : '';
            const zscore = alert.zscore ? ` Z=${alert.zscore.toFixed(1)}` : '';
            html += `<div style="color: ${typeColor}">• ${alert.concept_name} (${typeLabel}${delta}${zscore})</div>`;
          });
          html += '</div>';
        }

        return html;
      },
    },
    legend: {
      show: false,
    },
    grid: [
      {
        left: '8%',
        right: '3%',
        top: '8%',
        height: '55%',
      },
      {
        left: '8%',
        right: '3%',
        top: '70%',
        height: '20%',
      },
    ],
    xAxis: [
      {
        type: 'category',
        data: times,
        axisLine: { lineStyle: { color: '#ddd' } },
        axisLabel: {
          color: subTextColor,
          fontSize: 10,
          interval: Math.floor(times.length / 6),
        },
        axisTick: { show: false },
        splitLine: { show: false },
      },
      {
        type: 'category',
        gridIndex: 1,
        data: times,
        axisLine: { lineStyle: { color: '#ddd' } },
        axisLabel: { show: false },
        axisTick: { show: false },
        splitLine: { show: false },
      },
    ],
    yAxis: [
      {
        type: 'value',
        min: yAxisMin,
        max: yAxisMax,
        axisLine: { show: false },
        axisLabel: {
          color: subTextColor,
          fontSize: 10,
          formatter: (val) => val.toFixed(0),
        },
        splitLine: {
          lineStyle: { color: '#eee', type: 'dashed' },
        },
        // Axis-pointer label shows price with percent change
        axisPointer: {
          label: {
            formatter: function (params) {
              if (!index.prev_close) return params.value.toFixed(2); // no baseline available
              const pct = ((params.value - index.prev_close) / index.prev_close) * 100;
              return `${params.value.toFixed(2)} (${pct >= 0 ? '+' : ''}${pct.toFixed(2)}%)`;
            },
          },
        },
      },
      {
        type: 'value',
        gridIndex: 1,
        axisLine: { show: false },
        axisLabel: { show: false },
        splitLine: { show: false },
      },
    ],
    series: [
      // Intraday price line
      {
        name: 'Shanghai Composite',
        type: 'line',
        data: prices,
        smooth: true,
        symbol: 'none',
        lineStyle: {
          color: lineColor,
          width: 1.5,
        },
        areaStyle: {
          color: new echarts.graphic.LinearGradient(0, 0, 0, 1, areaColorStops),
        },
        markPoint: {
          symbol: 'pin',
          symbolSize: 40,
          data: markPoints,
          animation: true,
        },
      },
      // Volume bars
      {
        name: 'Volume',
        type: 'bar',
        xAxisIndex: 1,
        yAxisIndex: 1,
        data: volumes.map((v, i) => ({
          value: v,
          itemStyle: {
            color: changePcts[i] >= 0 ? 'rgba(255, 77, 77, 0.6)' : 'rgba(34, 197, 94, 0.6)',
          },
        })),
        barWidth: '60%',
      },
    ],
  };

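  // notMerge = true replaces the previous option wholesale, so markPoints from
  // an earlier date don't linger after the user switches days.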
  chartInstance.current.setOption(option, true);
}, [data, textColor, subTextColor]);

// Re-render when the data changes
useEffect(() => {
  if (data) {
    renderChart();
  }
}, [data, renderChart]);

// Resize the chart with the window; dispose it on unmount
useEffect(() => {
  const handleResize = () => {
    if (chartInstance.current) {
      chartInstance.current.resize();
    }
  };

  window.addEventListener('resize', handleResize);
  return () => {
    window.removeEventListener('resize', handleResize);
    if (chartInstance.current) {
      chartInstance.current.dispose();
      chartInstance.current = null;
    }
  };
}, []);
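
// Disposing the chart in the cleanup detaches ECharts' internal listeners and
// releases the canvas, so unmounting the component doesn't leak.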

// Alert-type badge
const AlertTypeBadge = ({ type, count }) => {
  const config = {
    surge: { label: 'Rapid rise', color: 'red', icon: FaBolt },
    surge_up: { label: 'Surge up', color: 'red', icon: FaBolt },
    surge_down: { label: 'Surge down', color: 'green', icon: FaArrowDown },
    limit_up: { label: 'Limit-up', color: 'orange', icon: FaRocket },
    rank_jump: { label: 'Rank jump', color: 'blue', icon: FaChartLine },
  };

  const cfg = config[type] || { label: type, color: 'gray', icon: FaFire };

  return (
    <Badge colorScheme={cfg.color} variant="subtle" px={2} py={1} borderRadius="md">
      <HStack spacing={1}>
        <Icon as={cfg.icon} boxSize={3} />
        <Text>{cfg.label}</Text>
        <Text fontWeight="bold">{count}</Text>
      </HStack>
    </Badge>
  );
};
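
// Render states: spinner while loading, error card on failure, and nothing at
// all until the first payload arrives.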

if (loading) {
  return (
    <Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
      <CardBody>
        <Center h="400px">
          <VStack spacing={4}>
            <Spinner size="xl" color="purple.500" thickness="4px" />
            <Text color={subTextColor}>Loading hotspot overview...</Text>
          </VStack>
        </Center>
      </CardBody>
    </Card>
  );
}

if (error) {
  return (
    <Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
      <CardBody>
        <Center h="400px">
          <VStack spacing={4}>
            <Icon as={InfoIcon} boxSize={10} color="red.400" />
            <Text color="red.500">{error}</Text>
          </VStack>
        </Center>
      </CardBody>
    </Card>
  );
}

if (!data) {
  return null;
}

const { index, alerts, alert_summary } = data;

return (
  <Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
    <CardBody>
      {/* Header */}
      <Flex align="center" mb={4}>
        <HStack spacing={3}>
          <Icon as={FaFire} boxSize={6} color="orange.500" />
          <Heading size="md" color={textColor}>
            Hotspot Overview
          </Heading>
        </HStack>
        <Spacer />
        <Tooltip label="Shows how concept alerts line up with the index's intraday move">
          <Icon as={InfoIcon} color={subTextColor} />
        </Tooltip>
      </Flex>

      {/* Index stats */}
      <SimpleGrid columns={{ base: 2, md: 4 }} spacing={4} mb={4}>
        <Stat size="sm">
          <StatLabel color={subTextColor}>{index.name}</StatLabel>
          <StatNumber
            fontSize="xl"
            color={index.change_pct >= 0 ? 'red.500' : 'green.500'}
          >
            {index.latest_price?.toFixed(2)}
          </StatNumber>
          <StatHelpText mb={0}>
            <StatArrow type={index.change_pct >= 0 ? 'increase' : 'decrease'} />
            {index.change_pct?.toFixed(2)}%
          </StatHelpText>
        </Stat>

        <Stat size="sm">
          <StatLabel color={subTextColor}>High</StatLabel>
          <StatNumber fontSize="xl" color="red.500">
            {index.high?.toFixed(2)}
          </StatNumber>
        </Stat>

        <Stat size="sm">
          <StatLabel color={subTextColor}>Low</StatLabel>
          <StatNumber fontSize="xl" color="green.500">
            {index.low?.toFixed(2)}
          </StatNumber>
        </Stat>

        <Stat size="sm">
          <StatLabel color={subTextColor}>Alert count</StatLabel>
          <StatNumber fontSize="xl" color="orange.500">
            {alerts.length}
          </StatNumber>
        </Stat>
      </SimpleGrid>

      {/* Alert-type summary */}
      {alerts.length > 0 && (
        <HStack spacing={2} mb={4} flexWrap="wrap">
          {(alert_summary.surge_up > 0 || alert_summary.surge > 0) && (
            <AlertTypeBadge type="surge_up" count={(alert_summary.surge_up || 0) + (alert_summary.surge || 0)} />
          )}
          {alert_summary.surge_down > 0 && (
            <AlertTypeBadge type="surge_down" count={alert_summary.surge_down} />
          )}
          {alert_summary.limit_up > 0 && (
            <AlertTypeBadge type="limit_up" count={alert_summary.limit_up} />
          )}
          {alert_summary.rank_jump > 0 && (
            <AlertTypeBadge type="rank_jump" count={alert_summary.rank_jump} />
          )}
        </HStack>
      )}

      {/* Chart */}
      <Box ref={chartRef} h="400px" w="100%" />

      {/* Empty state */}
      {alerts.length === 0 && (
        <Center py={4}>
          <Text color={subTextColor} fontSize="sm">
            No concept alert data for this day
          </Text>
        </Center>
      )}
    </CardBody>
  </Card>
);
};

export default HotspotOverview;

@@ -53,6 +53,7 @@ import { SearchIcon, CloseIcon, ArrowForwardIcon, TrendingUpIcon, InfoIcon, Chev
import { FaChartLine, FaFire, FaRocket, FaBrain, FaCalendarAlt, FaChevronRight, FaArrowUp, FaArrowDown, FaChartBar, FaTag, FaLayerGroup, FaBolt } from 'react-icons/fa';
import ConceptStocksModal from '@components/ConceptStocksModal';
import TradeDatePicker from '@components/TradeDatePicker';
import HotspotOverview from './components/HotspotOverview';
import { BsGraphUp, BsLightningFill } from 'react-icons/bs';
import * as echarts from 'echarts';
import { logger } from '../../utils/logger';

@@ -840,6 +841,11 @@ const StockOverview = () => {
        )}
      </Box>

      {/* Hotspot overview: index trend + concept alerts */}
      <Box mb={10}>
        <HotspotOverview selectedDate={selectedDate} />
      </Box>

      {/* Today's hot concepts */}
      <Box mb={10}>
        <Flex align="center" mb={6}>