Merge branch 'feature_bugfix/251201_py_h5_ui' into feature_2025/251209_stock_pref
* feature_bugfix/251201_py_h5_ui: feat: Company 页面搜索框添加股票模糊搜索功能 update pay ui update pay ui fix: 个股中心bug修复 update pay ui update pay ui update pay ui update pay ui update pay ui update pay ui update pay ui update pay ui update pay ui update pay ui update pay ui update pay ui update pay ui update pay ui feat: 替换公众号文件 update pay ui
This commit is contained in:
760
app.py
760
app.py
@@ -6412,6 +6412,10 @@ def get_stock_kline(stock_code):
|
||||
except ValueError:
|
||||
return jsonify({'error': 'Invalid event_time format'}), 400
|
||||
|
||||
# 确保股票代码包含后缀(ClickHouse 中数据带后缀)
|
||||
if '.' not in stock_code:
|
||||
stock_code = f"{stock_code}.SH" if stock_code.startswith('6') else f"{stock_code}.SZ"
|
||||
|
||||
# 获取股票名称
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text(
|
||||
@@ -7819,7 +7823,7 @@ def get_index_realtime(index_code):
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取指数实时行情失败: {index_code}, 错误: {str(e)}")
|
||||
app.logger.error(f"获取指数实时行情失败: {index_code}, 错误: {str(e)}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
@@ -7837,8 +7841,13 @@ def get_index_kline(index_code):
|
||||
except ValueError:
|
||||
return jsonify({'error': 'Invalid event_time format'}), 400
|
||||
|
||||
# 确保指数代码包含后缀(ClickHouse 中数据带后缀)
|
||||
# 399xxx -> 深交所, 其他(000xxx等)-> 上交所
|
||||
if '.' not in index_code:
|
||||
index_code = f"{index_code}.SZ" if index_code.startswith('39') else f"{index_code}.SH"
|
||||
|
||||
# 指数名称(暂无索引表,先返回代码本身)
|
||||
index_name = index_code
|
||||
index_name = index_code.split('.')[0]
|
||||
|
||||
if chart_type == 'minute':
|
||||
return get_index_minute_kline(index_code, event_datetime, index_name)
|
||||
@@ -12044,10 +12053,11 @@ def get_market_summary(seccode):
|
||||
|
||||
@app.route('/api/stocks/search', methods=['GET'])
|
||||
def search_stocks():
|
||||
"""搜索股票(支持股票代码、股票简称、拼音首字母)"""
|
||||
"""搜索股票和指数(支持代码、名称搜索)"""
|
||||
try:
|
||||
query = request.args.get('q', '').strip()
|
||||
limit = request.args.get('limit', 20, type=int)
|
||||
search_type = request.args.get('type', 'all') # all, stock, index
|
||||
|
||||
if not query:
|
||||
return jsonify({
|
||||
@@ -12055,73 +12065,132 @@ def search_stocks():
|
||||
'error': '请输入搜索关键词'
|
||||
}), 400
|
||||
|
||||
results = []
|
||||
|
||||
with engine.connect() as conn:
|
||||
test_sql = text("""
|
||||
SELECT SECCODE, SECNAME, F001V, F003V, F010V, F011V
|
||||
FROM ea_stocklist
|
||||
WHERE SECCODE = '300750'
|
||||
OR F001V LIKE '%ndsd%' LIMIT 5
|
||||
""")
|
||||
test_result = conn.execute(test_sql).fetchall()
|
||||
# 搜索指数(优先显示指数,因为通常用户搜索代码时指数更常用)
|
||||
if search_type in ('all', 'index'):
|
||||
index_sql = text("""
|
||||
SELECT DISTINCT
|
||||
INDEXCODE as stock_code,
|
||||
SECNAME as stock_name,
|
||||
INDEXNAME as full_name,
|
||||
F018V as exchange
|
||||
FROM ea_exchangeindex
|
||||
WHERE (
|
||||
UPPER(INDEXCODE) LIKE UPPER(:query_pattern)
|
||||
OR UPPER(SECNAME) LIKE UPPER(:query_pattern)
|
||||
OR UPPER(INDEXNAME) LIKE UPPER(:query_pattern)
|
||||
)
|
||||
ORDER BY CASE
|
||||
WHEN UPPER(INDEXCODE) = UPPER(:exact_query) THEN 1
|
||||
WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2
|
||||
WHEN UPPER(INDEXCODE) LIKE UPPER(:prefix_pattern) THEN 3
|
||||
WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 4
|
||||
ELSE 5
|
||||
END,
|
||||
INDEXCODE
|
||||
LIMIT :limit
|
||||
""")
|
||||
|
||||
# 构建搜索SQL - 支持股票代码、股票简称、拼音简称搜索
|
||||
search_sql = text("""
|
||||
SELECT DISTINCT SECCODE as stock_code,
|
||||
SECNAME as stock_name,
|
||||
F001V as pinyin_abbr,
|
||||
F003V as security_type,
|
||||
F005V as exchange,
|
||||
F011V as listing_status
|
||||
FROM ea_stocklist
|
||||
WHERE (
|
||||
UPPER(SECCODE) LIKE UPPER(:query_pattern)
|
||||
OR UPPER(SECNAME) LIKE UPPER(:query_pattern)
|
||||
OR UPPER(F001V) LIKE UPPER(:query_pattern)
|
||||
)
|
||||
-- 基本过滤条件:只搜索正常的A股和B股
|
||||
AND (F011V = '正常上市' OR F010V = '013001') -- 正常上市状态
|
||||
AND F003V IN ('A股', 'B股') -- 只搜索A股和B股
|
||||
ORDER BY CASE
|
||||
WHEN UPPER(SECCODE) = UPPER(:exact_query) THEN 1
|
||||
WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2
|
||||
WHEN UPPER(F001V) = UPPER(:exact_query) THEN 3
|
||||
WHEN UPPER(SECCODE) LIKE UPPER(:prefix_pattern) THEN 4
|
||||
WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 5
|
||||
WHEN UPPER(F001V) LIKE UPPER(:prefix_pattern) THEN 6
|
||||
ELSE 7
|
||||
END,
|
||||
SECCODE LIMIT :limit
|
||||
""")
|
||||
index_result = conn.execute(index_sql, {
|
||||
'query_pattern': f'%{query}%',
|
||||
'exact_query': query,
|
||||
'prefix_pattern': f'{query}%',
|
||||
'limit': limit
|
||||
}).fetchall()
|
||||
|
||||
result = conn.execute(search_sql, {
|
||||
'query_pattern': f'%{query}%',
|
||||
'exact_query': query,
|
||||
'prefix_pattern': f'{query}%',
|
||||
'limit': limit
|
||||
}).fetchall()
|
||||
for row in index_result:
|
||||
results.append({
|
||||
'stock_code': row.stock_code,
|
||||
'stock_name': row.stock_name,
|
||||
'full_name': row.full_name,
|
||||
'exchange': row.exchange,
|
||||
'isIndex': True,
|
||||
'security_type': '指数'
|
||||
})
|
||||
|
||||
stocks = []
|
||||
for row in result:
|
||||
# 获取当前价格
|
||||
current_price, _ = get_latest_price_from_clickhouse(row.stock_code)
|
||||
# 搜索股票
|
||||
if search_type in ('all', 'stock'):
|
||||
stock_sql = text("""
|
||||
SELECT DISTINCT SECCODE as stock_code,
|
||||
SECNAME as stock_name,
|
||||
F001V as pinyin_abbr,
|
||||
F003V as security_type,
|
||||
F005V as exchange,
|
||||
F011V as listing_status
|
||||
FROM ea_stocklist
|
||||
WHERE (
|
||||
UPPER(SECCODE) LIKE UPPER(:query_pattern)
|
||||
OR UPPER(SECNAME) LIKE UPPER(:query_pattern)
|
||||
OR UPPER(F001V) LIKE UPPER(:query_pattern)
|
||||
)
|
||||
AND (F011V = '正常上市' OR F010V = '013001')
|
||||
AND F003V IN ('A股', 'B股')
|
||||
ORDER BY CASE
|
||||
WHEN UPPER(SECCODE) = UPPER(:exact_query) THEN 1
|
||||
WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2
|
||||
WHEN UPPER(F001V) = UPPER(:exact_query) THEN 3
|
||||
WHEN UPPER(SECCODE) LIKE UPPER(:prefix_pattern) THEN 4
|
||||
WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 5
|
||||
WHEN UPPER(F001V) LIKE UPPER(:prefix_pattern) THEN 6
|
||||
ELSE 7
|
||||
END,
|
||||
SECCODE
|
||||
LIMIT :limit
|
||||
""")
|
||||
|
||||
stocks.append({
|
||||
'stock_code': row.stock_code,
|
||||
'stock_name': row.stock_name,
|
||||
'current_price': current_price or 0, # 添加当前价格
|
||||
'pinyin_abbr': row.pinyin_abbr,
|
||||
'security_type': row.security_type,
|
||||
'exchange': row.exchange,
|
||||
'listing_status': row.listing_status
|
||||
})
|
||||
stock_result = conn.execute(stock_sql, {
|
||||
'query_pattern': f'%{query}%',
|
||||
'exact_query': query,
|
||||
'prefix_pattern': f'{query}%',
|
||||
'limit': limit
|
||||
}).fetchall()
|
||||
|
||||
for row in stock_result:
|
||||
results.append({
|
||||
'stock_code': row.stock_code,
|
||||
'stock_name': row.stock_name,
|
||||
'pinyin_abbr': row.pinyin_abbr,
|
||||
'security_type': row.security_type,
|
||||
'exchange': row.exchange,
|
||||
'listing_status': row.listing_status,
|
||||
'isIndex': False
|
||||
})
|
||||
|
||||
# 如果搜索全部,按相关性重新排序(精确匹配优先)
|
||||
if search_type == 'all':
|
||||
def sort_key(item):
|
||||
code = item['stock_code'].upper()
|
||||
name = item['stock_name'].upper()
|
||||
q = query.upper()
|
||||
# 精确匹配代码优先
|
||||
if code == q:
|
||||
return (0, not item['isIndex'], code) # 指数优先
|
||||
# 精确匹配名称
|
||||
if name == q:
|
||||
return (1, not item['isIndex'], code)
|
||||
# 前缀匹配代码
|
||||
if code.startswith(q):
|
||||
return (2, not item['isIndex'], code)
|
||||
# 前缀匹配名称
|
||||
if name.startswith(q):
|
||||
return (3, not item['isIndex'], code)
|
||||
return (4, not item['isIndex'], code)
|
||||
|
||||
results.sort(key=sort_key)
|
||||
|
||||
# 限制总数
|
||||
results = results[:limit]
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': stocks,
|
||||
'count': len(stocks)
|
||||
'data': results,
|
||||
'count': len(results)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
app.logger.error(f"搜索股票/指数错误: {e}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
@@ -12403,7 +12472,21 @@ def get_daily_top_concepts():
|
||||
top_concepts = []
|
||||
|
||||
for concept in data.get('results', []):
|
||||
# 保持与 /concept-api/search 相同的字段结构
|
||||
# 处理 stocks 字段:兼容 {name, code} 和 {stock_name, stock_code} 两种格式
|
||||
raw_stocks = concept.get('stocks', [])
|
||||
formatted_stocks = []
|
||||
for stock in raw_stocks:
|
||||
# 优先使用 stock_name,其次使用 name
|
||||
stock_name = stock.get('stock_name') or stock.get('name', '')
|
||||
stock_code = stock.get('stock_code') or stock.get('code', '')
|
||||
formatted_stocks.append({
|
||||
'stock_name': stock_name,
|
||||
'stock_code': stock_code,
|
||||
'name': stock_name, # 兼容旧格式
|
||||
'code': stock_code # 兼容旧格式
|
||||
})
|
||||
|
||||
# 保持与 /concept-api/search 相同的字段结构,并添加新字段
|
||||
top_concepts.append({
|
||||
'concept_id': concept.get('concept_id'),
|
||||
'concept': concept.get('concept'), # 原始字段名
|
||||
@@ -12414,8 +12497,10 @@ def get_daily_top_concepts():
|
||||
'match_type': concept.get('match_type'),
|
||||
'price_info': concept.get('price_info', {}), # 完整的价格信息
|
||||
'change_percent': concept.get('price_info', {}).get('avg_change_pct', 0), # 兼容旧字段
|
||||
'happened_times': concept.get('happened_times', []), # 历史触发时间
|
||||
'stocks': concept.get('stocks', []), # 返回完整股票列表
|
||||
'tags': concept.get('tags', []), # 标签列表
|
||||
'outbreak_dates': concept.get('outbreak_dates', []), # 爆发日期列表
|
||||
'hierarchy': concept.get('hierarchy'), # 层级信息 {lv1, lv2, lv3}
|
||||
'stocks': formatted_stocks, # 返回格式化后的股票列表
|
||||
'hot_score': concept.get('hot_score')
|
||||
})
|
||||
|
||||
@@ -12442,6 +12527,557 @@ def get_daily_top_concepts():
|
||||
}), 500
|
||||
|
||||
|
||||
# ==================== 热点概览 API ====================
|
||||
|
||||
@app.route('/api/market/hotspot-overview', methods=['GET'])
|
||||
def get_hotspot_overview():
|
||||
"""
|
||||
获取热点概览数据(用于个股中心的热点概览图表)
|
||||
返回:指数分时数据 + 概念异动标注
|
||||
|
||||
数据来源:
|
||||
- 指数分时:ClickHouse index_minute 表
|
||||
- 概念异动:MySQL concept_anomaly_hybrid 表(来自 realtime_detector.py)
|
||||
"""
|
||||
try:
|
||||
trade_date = request.args.get('date')
|
||||
index_code = request.args.get('index', '000001.SH')
|
||||
|
||||
# 如果没有指定日期,使用最新交易日
|
||||
if not trade_date:
|
||||
today = date.today()
|
||||
if today in trading_days_set:
|
||||
trade_date = today.strftime('%Y-%m-%d')
|
||||
else:
|
||||
target_date = get_trading_day_near_date(today)
|
||||
trade_date = target_date.strftime('%Y-%m-%d') if target_date else today.strftime('%Y-%m-%d')
|
||||
|
||||
# 1. 获取指数分时数据
|
||||
client = get_clickhouse_client()
|
||||
target_date_obj = datetime.strptime(trade_date, '%Y-%m-%d').date()
|
||||
|
||||
index_data = client.execute(
|
||||
"""
|
||||
SELECT timestamp, open, high, low, close, volume
|
||||
FROM index_minute
|
||||
WHERE code = %(code)s
|
||||
AND toDate(timestamp) = %(date)s
|
||||
ORDER BY timestamp
|
||||
""",
|
||||
{
|
||||
'code': index_code,
|
||||
'date': target_date_obj
|
||||
}
|
||||
)
|
||||
|
||||
# 获取昨收价
|
||||
code_no_suffix = index_code.split('.')[0]
|
||||
prev_close = None
|
||||
with engine.connect() as conn:
|
||||
prev_result = conn.execute(text("""
|
||||
SELECT F006N FROM ea_exchangetrade
|
||||
WHERE INDEXCODE = :code
|
||||
AND TRADEDATE < :today
|
||||
ORDER BY TRADEDATE DESC LIMIT 1
|
||||
"""), {
|
||||
'code': code_no_suffix,
|
||||
'today': target_date_obj
|
||||
}).fetchone()
|
||||
if prev_result and prev_result[0]:
|
||||
prev_close = float(prev_result[0])
|
||||
|
||||
# 格式化指数数据
|
||||
index_timeline = []
|
||||
for row in index_data:
|
||||
ts, open_p, high_p, low_p, close_p, vol = row
|
||||
change_pct = None
|
||||
if prev_close and close_p:
|
||||
change_pct = round((float(close_p) - prev_close) / prev_close * 100, 4)
|
||||
|
||||
index_timeline.append({
|
||||
'time': ts.strftime('%H:%M'),
|
||||
'timestamp': ts.isoformat(),
|
||||
'price': float(close_p) if close_p else None,
|
||||
'open': float(open_p) if open_p else None,
|
||||
'high': float(high_p) if high_p else None,
|
||||
'low': float(low_p) if low_p else None,
|
||||
'volume': int(vol) if vol else 0,
|
||||
'change_pct': change_pct
|
||||
})
|
||||
|
||||
# 2. 获取概念异动数据(优先从 V2 表,fallback 到旧表)
|
||||
alerts = []
|
||||
use_v2 = False
|
||||
|
||||
with engine.connect() as conn:
|
||||
# 尝试查询 V2 表(时间片对齐 + 持续确认版本)
|
||||
try:
|
||||
v2_result = conn.execute(text("""
|
||||
SELECT
|
||||
concept_id, alert_time, trade_date, alert_type,
|
||||
final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
|
||||
alpha, alpha_zscore, amt_zscore, rank_zscore,
|
||||
momentum_3m, momentum_5m, limit_up_ratio, triggered_rules
|
||||
FROM concept_anomaly_v2
|
||||
WHERE trade_date = :trade_date
|
||||
ORDER BY alert_time
|
||||
"""), {'trade_date': trade_date})
|
||||
v2_rows = v2_result.fetchall()
|
||||
if v2_rows:
|
||||
use_v2 = True
|
||||
for row in v2_rows:
|
||||
triggered_rules = None
|
||||
if row[16]:
|
||||
try:
|
||||
triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16]
|
||||
except:
|
||||
pass
|
||||
|
||||
alerts.append({
|
||||
'concept_id': row[0],
|
||||
'concept_name': row[0], # 后面会填充
|
||||
'time': row[1].strftime('%H:%M') if row[1] else None,
|
||||
'timestamp': row[1].isoformat() if row[1] else None,
|
||||
'alert_type': row[3],
|
||||
'final_score': float(row[4]) if row[4] else None,
|
||||
'rule_score': float(row[5]) if row[5] else None,
|
||||
'ml_score': float(row[6]) if row[6] else None,
|
||||
'trigger_reason': row[7],
|
||||
# V2 新增字段
|
||||
'confirm_ratio': float(row[8]) if row[8] else None,
|
||||
'alpha': float(row[9]) if row[9] else None,
|
||||
'alpha_zscore': float(row[10]) if row[10] else None,
|
||||
'amt_zscore': float(row[11]) if row[11] else None,
|
||||
'rank_zscore': float(row[12]) if row[12] else None,
|
||||
'momentum_3m': float(row[13]) if row[13] else None,
|
||||
'momentum_5m': float(row[14]) if row[14] else None,
|
||||
'limit_up_ratio': float(row[15]) if row[15] else 0,
|
||||
'triggered_rules': triggered_rules,
|
||||
# 兼容字段
|
||||
'importance_score': float(row[4]) / 100 if row[4] else None,
|
||||
'is_v2': True,
|
||||
})
|
||||
except Exception as v2_err:
|
||||
app.logger.debug(f"V2 表查询失败,使用旧表: {v2_err}")
|
||||
|
||||
# Fallback: 查询旧表
|
||||
if not use_v2:
|
||||
try:
|
||||
alert_result = conn.execute(text("""
|
||||
SELECT
|
||||
a.concept_id, a.alert_time, a.trade_date, a.alert_type,
|
||||
a.final_score, a.rule_score, a.ml_score, a.trigger_reason,
|
||||
a.alpha, a.alpha_delta, a.amt_ratio, a.amt_delta,
|
||||
a.rank_pct, a.limit_up_ratio, a.stock_count, a.total_amt,
|
||||
a.triggered_rules
|
||||
FROM concept_anomaly_hybrid a
|
||||
WHERE a.trade_date = :trade_date
|
||||
ORDER BY a.alert_time
|
||||
"""), {'trade_date': trade_date})
|
||||
|
||||
for row in alert_result:
|
||||
triggered_rules = None
|
||||
if row[16]:
|
||||
try:
|
||||
triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16]
|
||||
except:
|
||||
pass
|
||||
|
||||
limit_up_ratio = float(row[13]) if row[13] else 0
|
||||
stock_count = int(row[14]) if row[14] else 0
|
||||
limit_up_count = int(limit_up_ratio * stock_count) if stock_count > 0 else 0
|
||||
|
||||
alerts.append({
|
||||
'concept_id': row[0],
|
||||
'concept_name': row[0],
|
||||
'time': row[1].strftime('%H:%M') if row[1] else None,
|
||||
'timestamp': row[1].isoformat() if row[1] else None,
|
||||
'alert_type': row[3],
|
||||
'final_score': float(row[4]) if row[4] else None,
|
||||
'rule_score': float(row[5]) if row[5] else None,
|
||||
'ml_score': float(row[6]) if row[6] else None,
|
||||
'trigger_reason': row[7],
|
||||
'alpha': float(row[8]) if row[8] else None,
|
||||
'alpha_delta': float(row[9]) if row[9] else None,
|
||||
'amt_ratio': float(row[10]) if row[10] else None,
|
||||
'amt_delta': float(row[11]) if row[11] else None,
|
||||
'rank_pct': float(row[12]) if row[12] else None,
|
||||
'limit_up_ratio': limit_up_ratio,
|
||||
'limit_up_count': limit_up_count,
|
||||
'stock_count': stock_count,
|
||||
'total_amt': float(row[15]) if row[15] else None,
|
||||
'triggered_rules': triggered_rules,
|
||||
'importance_score': float(row[4]) / 100 if row[4] else None,
|
||||
'is_v2': False,
|
||||
})
|
||||
except Exception as old_err:
|
||||
app.logger.debug(f"旧表查询也失败: {old_err}")
|
||||
|
||||
# 尝试批量获取概念名称
|
||||
if alerts:
|
||||
concept_ids = list(set(a['concept_id'] for a in alerts))
|
||||
concept_names = {} # 初始化 concept_names 字典
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
es_client = Elasticsearch(["http://222.128.1.157:19200"])
|
||||
es_result = es_client.mget(
|
||||
index='concept_library_v3',
|
||||
body={'ids': concept_ids},
|
||||
_source=['concept']
|
||||
)
|
||||
for doc in es_result.get('docs', []):
|
||||
if doc.get('found') and doc.get('_source'):
|
||||
concept_names[doc['_id']] = doc['_source'].get('concept', doc['_id'])
|
||||
# 更新 alerts 中的概念名称
|
||||
for alert in alerts:
|
||||
if alert['concept_id'] in concept_names:
|
||||
alert['concept_name'] = concept_names[alert['concept_id']]
|
||||
except Exception as e:
|
||||
app.logger.warning(f"获取概念名称失败: {e}")
|
||||
|
||||
# 计算统计信息
|
||||
day_high = max([d['price'] for d in index_timeline if d['price']], default=None)
|
||||
day_low = min([d['price'] for d in index_timeline if d['price']], default=None)
|
||||
latest_price = index_timeline[-1]['price'] if index_timeline else None
|
||||
latest_change_pct = index_timeline[-1]['change_pct'] if index_timeline else None
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': {
|
||||
'trade_date': trade_date,
|
||||
'index': {
|
||||
'code': index_code,
|
||||
'name': '上证指数' if index_code == '000001.SH' else index_code,
|
||||
'prev_close': prev_close,
|
||||
'latest_price': latest_price,
|
||||
'change_pct': latest_change_pct,
|
||||
'high': day_high,
|
||||
'low': day_low,
|
||||
'timeline': index_timeline
|
||||
},
|
||||
'alerts': alerts,
|
||||
'alert_count': len(alerts),
|
||||
'alert_summary': {
|
||||
'surge': len([a for a in alerts if a['alert_type'] == 'surge']),
|
||||
'surge_up': len([a for a in alerts if a['alert_type'] == 'surge_up']),
|
||||
'surge_down': len([a for a in alerts if a['alert_type'] == 'surge_down']),
|
||||
'limit_up': len([a for a in alerts if a['alert_type'] == 'limit_up']),
|
||||
'volume_spike': len([a for a in alerts if a['alert_type'] == 'volume_spike']),
|
||||
'rank_jump': len([a for a in alerts if a['alert_type'] == 'rank_jump'])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_trace = traceback.format_exc()
|
||||
app.logger.error(f"获取热点概览数据失败: {error_trace}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'traceback': error_trace # 临时返回完整错误信息用于调试
|
||||
}), 500
|
||||
|
||||
|
||||
@app.route('/api/concept/<concept_id>/stocks', methods=['GET'])
|
||||
def get_concept_stocks(concept_id):
|
||||
"""
|
||||
获取概念的相关股票列表(带实时涨跌幅)
|
||||
|
||||
Args:
|
||||
concept_id: 概念 ID(来自 ES concept_library_v3)
|
||||
|
||||
Returns:
|
||||
- stocks: 股票列表 [{code, name, reason, change_pct}, ...]
|
||||
"""
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
|
||||
# 1. 从 ES 获取概念的股票列表
|
||||
es_client = Elasticsearch(["http://222.128.1.157:19200"])
|
||||
es_result = es_client.get(index='concept_library_v3', id=concept_id)
|
||||
|
||||
if not es_result.get('found'):
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': f'概念 {concept_id} 不存在'
|
||||
}), 404
|
||||
|
||||
source = es_result.get('_source', {})
|
||||
concept_name = source.get('concept', concept_id)
|
||||
raw_stocks = source.get('stocks', [])
|
||||
|
||||
if not raw_stocks:
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': {
|
||||
'concept_id': concept_id,
|
||||
'concept_name': concept_name,
|
||||
'stocks': []
|
||||
}
|
||||
})
|
||||
|
||||
# 提取股票代码和原因
|
||||
stocks_info = []
|
||||
stock_codes = []
|
||||
for s in raw_stocks:
|
||||
if isinstance(s, dict):
|
||||
code = s.get('code', '')
|
||||
if code and len(code) == 6:
|
||||
stocks_info.append({
|
||||
'code': code,
|
||||
'name': s.get('name', ''),
|
||||
'reason': s.get('reason', '')
|
||||
})
|
||||
stock_codes.append(code)
|
||||
|
||||
if not stock_codes:
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': {
|
||||
'concept_id': concept_id,
|
||||
'concept_name': concept_name,
|
||||
'stocks': stocks_info
|
||||
}
|
||||
})
|
||||
|
||||
# 2. 获取最新交易日和前一交易日
|
||||
today = datetime.now().date()
|
||||
trading_day = None
|
||||
prev_trading_day = None
|
||||
|
||||
with engine.connect() as conn:
|
||||
# 获取最新交易日
|
||||
result = conn.execute(text("""
|
||||
SELECT EXCHANGE_DATE FROM trading_days
|
||||
WHERE EXCHANGE_DATE <= :today
|
||||
ORDER BY EXCHANGE_DATE DESC LIMIT 1
|
||||
"""), {"today": today}).fetchone()
|
||||
if result:
|
||||
trading_day = result[0].date() if hasattr(result[0], 'date') else result[0]
|
||||
|
||||
# 获取前一交易日
|
||||
if trading_day:
|
||||
result = conn.execute(text("""
|
||||
SELECT EXCHANGE_DATE FROM trading_days
|
||||
WHERE EXCHANGE_DATE < :date
|
||||
ORDER BY EXCHANGE_DATE DESC LIMIT 1
|
||||
"""), {"date": trading_day}).fetchone()
|
||||
if result:
|
||||
prev_trading_day = result[0].date() if hasattr(result[0], 'date') else result[0]
|
||||
|
||||
# 3. 从 MySQL ea_trade 获取前一交易日收盘价(F007N)
|
||||
prev_close_map = {}
|
||||
if prev_trading_day and stock_codes:
|
||||
with engine.connect() as conn:
|
||||
placeholders = ','.join([f':code{i}' for i in range(len(stock_codes))])
|
||||
params = {f'code{i}': code for i, code in enumerate(stock_codes)}
|
||||
params['trade_date'] = prev_trading_day
|
||||
|
||||
result = conn.execute(text(f"""
|
||||
SELECT SECCODE, F007N
|
||||
FROM ea_trade
|
||||
WHERE SECCODE IN ({placeholders})
|
||||
AND TRADEDATE = :trade_date
|
||||
AND F007N > 0
|
||||
"""), params).fetchall()
|
||||
|
||||
prev_close_map = {row[0]: float(row[1]) for row in result if row[1]}
|
||||
|
||||
# 4. 从 ClickHouse 获取最新价格
|
||||
current_price_map = {}
|
||||
if stock_codes:
|
||||
try:
|
||||
ch_client = Client(
|
||||
host='127.0.0.1',
|
||||
port=9000,
|
||||
user='default',
|
||||
password='Zzl33818!',
|
||||
database='stock'
|
||||
)
|
||||
|
||||
# 转换为 ClickHouse 格式
|
||||
ch_codes = []
|
||||
code_mapping = {}
|
||||
for code in stock_codes:
|
||||
if code.startswith('6'):
|
||||
ch_code = f"{code}.SH"
|
||||
elif code.startswith('0') or code.startswith('3'):
|
||||
ch_code = f"{code}.SZ"
|
||||
else:
|
||||
ch_code = f"{code}.BJ"
|
||||
ch_codes.append(ch_code)
|
||||
code_mapping[ch_code] = code
|
||||
|
||||
ch_codes_str = "','".join(ch_codes)
|
||||
|
||||
# 查询当天最新价格
|
||||
query = f"""
|
||||
SELECT code, close
|
||||
FROM stock_minute
|
||||
WHERE code IN ('{ch_codes_str}')
|
||||
AND toDate(timestamp) = today()
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1 BY code
|
||||
"""
|
||||
result = ch_client.execute(query)
|
||||
|
||||
for row in result:
|
||||
ch_code, close_price = row
|
||||
if ch_code in code_mapping and close_price:
|
||||
original_code = code_mapping[ch_code]
|
||||
current_price_map[original_code] = float(close_price)
|
||||
|
||||
except Exception as ch_err:
|
||||
app.logger.warning(f"ClickHouse 获取价格失败: {ch_err}")
|
||||
|
||||
# 5. 计算涨跌幅并合并数据
|
||||
result_stocks = []
|
||||
for stock in stocks_info:
|
||||
code = stock['code']
|
||||
prev_close = prev_close_map.get(code)
|
||||
current_price = current_price_map.get(code)
|
||||
|
||||
change_pct = None
|
||||
if prev_close and current_price and prev_close > 0:
|
||||
change_pct = round((current_price - prev_close) / prev_close * 100, 2)
|
||||
|
||||
result_stocks.append({
|
||||
'code': code,
|
||||
'name': stock['name'],
|
||||
'reason': stock['reason'],
|
||||
'change_pct': change_pct,
|
||||
'price': current_price,
|
||||
'prev_close': prev_close
|
||||
})
|
||||
|
||||
# 按涨跌幅排序(涨停优先)
|
||||
result_stocks.sort(key=lambda x: x.get('change_pct') if x.get('change_pct') is not None else -999, reverse=True)
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': {
|
||||
'concept_id': concept_id,
|
||||
'concept_name': concept_name,
|
||||
'stock_count': len(result_stocks),
|
||||
'trading_day': str(trading_day) if trading_day else None,
|
||||
'stocks': result_stocks
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
app.logger.error(f"获取概念股票失败: {traceback.format_exc()}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}), 500
|
||||
|
||||
|
||||
@app.route('/api/market/concept-alerts', methods=['GET'])
|
||||
def get_concept_alerts():
|
||||
"""
|
||||
获取概念异动列表(支持分页和筛选)
|
||||
"""
|
||||
try:
|
||||
trade_date = request.args.get('date')
|
||||
alert_type = request.args.get('type') # surge/limit_up/rank_jump
|
||||
concept_type = request.args.get('concept_type') # leaf/lv1/lv2/lv3
|
||||
limit = request.args.get('limit', 50, type=int)
|
||||
offset = request.args.get('offset', 0, type=int)
|
||||
|
||||
# 构建查询条件
|
||||
conditions = []
|
||||
params = {'limit': limit, 'offset': offset}
|
||||
|
||||
if trade_date:
|
||||
conditions.append("trade_date = :trade_date")
|
||||
params['trade_date'] = trade_date
|
||||
else:
|
||||
conditions.append("trade_date = CURDATE()")
|
||||
|
||||
if alert_type:
|
||||
conditions.append("alert_type = :alert_type")
|
||||
params['alert_type'] = alert_type
|
||||
|
||||
if concept_type:
|
||||
conditions.append("concept_type = :concept_type")
|
||||
params['concept_type'] = concept_type
|
||||
|
||||
where_clause = " AND ".join(conditions) if conditions else "1=1"
|
||||
|
||||
with engine.connect() as conn:
|
||||
# 获取总数
|
||||
count_sql = text(f"SELECT COUNT(*) FROM concept_minute_alert WHERE {where_clause}")
|
||||
total = conn.execute(count_sql, params).scalar()
|
||||
|
||||
# 获取数据
|
||||
query_sql = text(f"""
|
||||
SELECT
|
||||
id, concept_id, concept_name, alert_time, alert_type, trade_date,
|
||||
change_pct, prev_change_pct, change_delta,
|
||||
limit_up_count, prev_limit_up_count, limit_up_delta,
|
||||
rank_position, prev_rank_position, rank_delta,
|
||||
index_price, index_change_pct,
|
||||
stock_count, concept_type, extra_info
|
||||
FROM concept_minute_alert
|
||||
WHERE {where_clause}
|
||||
ORDER BY alert_time DESC
|
||||
LIMIT :limit OFFSET :offset
|
||||
""")
|
||||
|
||||
result = conn.execute(query_sql, params)
|
||||
|
||||
alerts = []
|
||||
for row in result:
|
||||
extra_info = None
|
||||
if row[19]:
|
||||
try:
|
||||
extra_info = json.loads(row[19]) if isinstance(row[19], str) else row[19]
|
||||
except:
|
||||
pass
|
||||
|
||||
alerts.append({
|
||||
'id': row[0],
|
||||
'concept_id': row[1],
|
||||
'concept_name': row[2],
|
||||
'alert_time': row[3].isoformat() if row[3] else None,
|
||||
'alert_type': row[4],
|
||||
'trade_date': row[5].isoformat() if row[5] else None,
|
||||
'change_pct': float(row[6]) if row[6] else None,
|
||||
'prev_change_pct': float(row[7]) if row[7] else None,
|
||||
'change_delta': float(row[8]) if row[8] else None,
|
||||
'limit_up_count': row[9],
|
||||
'prev_limit_up_count': row[10],
|
||||
'limit_up_delta': row[11],
|
||||
'rank_position': row[12],
|
||||
'prev_rank_position': row[13],
|
||||
'rank_delta': row[14],
|
||||
'index_price': float(row[15]) if row[15] else None,
|
||||
'index_change_pct': float(row[16]) if row[16] else None,
|
||||
'stock_count': row[17],
|
||||
'concept_type': row[18],
|
||||
'extra_info': extra_info
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': alerts,
|
||||
'total': total,
|
||||
'limit': limit,
|
||||
'offset': offset
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
app.logger.error(f"获取概念异动列表失败: {traceback.format_exc()}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}), 500
|
||||
|
||||
|
||||
@app.route('/api/market/rise-analysis/<seccode>', methods=['GET'])
|
||||
def get_rise_analysis(seccode):
|
||||
"""获取股票涨幅分析数据(从 Elasticsearch 获取)"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
112
ml/README.md
Normal file
112
ml/README.md
Normal file
@@ -0,0 +1,112 @@
|
||||
# 概念异动检测 ML 模块
|
||||
|
||||
基于 Transformer Autoencoder 的概念异动检测系统。
|
||||
|
||||
## 环境要求
|
||||
|
||||
- Python 3.8+
|
||||
- PyTorch 2.0+ (CUDA 12.x for 5090 GPU)
|
||||
- ClickHouse, MySQL, Elasticsearch
|
||||
|
||||
## 数据库配置
|
||||
|
||||
当前配置(`prepare_data.py`):
|
||||
- MySQL: `192.168.1.5:3306`
|
||||
- Elasticsearch: `127.0.0.1:9200`
|
||||
- ClickHouse: `127.0.0.1:9000`
|
||||
|
||||
## 快速开始
|
||||
|
||||
```bash
|
||||
# 1. 安装依赖
|
||||
pip install -r ml/requirements.txt
|
||||
|
||||
# 2. 安装 PyTorch (5090 需要 CUDA 12.4)
|
||||
pip install torch --index-url https://download.pytorch.org/whl/cu124
|
||||
|
||||
# 3. 运行训练
|
||||
chmod +x ml/run_training.sh
|
||||
./ml/run_training.sh
|
||||
```
|
||||
|
||||
## 文件说明
|
||||
|
||||
| 文件 | 说明 |
|
||||
|------|------|
|
||||
| `model.py` | Transformer Autoencoder 模型定义 |
|
||||
| `prepare_data.py` | 数据提取和特征计算 |
|
||||
| `train.py` | 模型训练脚本 |
|
||||
| `inference.py` | 推理服务 |
|
||||
| `enhanced_detector.py` | 增强版检测器(融合 Alpha + ML) |
|
||||
|
||||
## 训练参数
|
||||
|
||||
```bash
|
||||
# 完整参数
|
||||
./ml/run_training.sh --start 2022-01-01 --end 2024-12-01 --epochs 100 --batch_size 256
|
||||
|
||||
# 只准备数据
|
||||
python ml/prepare_data.py --start 2022-01-01
|
||||
|
||||
# 只训练(数据已准备好)
|
||||
python ml/train.py --epochs 100 --batch_size 256 --lr 1e-4
|
||||
```
|
||||
|
||||
## 模型架构
|
||||
|
||||
```
|
||||
输入: (batch, 30, 6) # 30分钟序列,6个特征
|
||||
↓
|
||||
Positional Encoding
|
||||
↓
|
||||
Transformer Encoder (4层, 8头, d=128)
|
||||
↓
|
||||
Bottleneck (压缩到 32 维)
|
||||
↓
|
||||
Transformer Decoder (4层)
|
||||
↓
|
||||
输出: (batch, 30, 6) # 重构序列
|
||||
|
||||
异动判断: reconstruction_error > threshold
|
||||
```
|
||||
|
||||
## 6维特征
|
||||
|
||||
1. `alpha` - 超额收益(概念涨幅 - 大盘涨幅)
|
||||
2. `alpha_delta` - Alpha 5分钟变化
|
||||
3. `amt_ratio` - 成交额 / 20分钟均值
|
||||
4. `amt_delta` - 成交额变化率
|
||||
5. `rank_pct` - Alpha 排名百分位
|
||||
6. `limit_up_ratio` - 涨停股占比
|
||||
|
||||
## 训练产出
|
||||
|
||||
训练完成后,`ml/checkpoints/` 包含:
|
||||
- `best_model.pt` - 最佳模型权重
|
||||
- `thresholds.json` - 异动阈值 (P90/P95/P99)
|
||||
- `normalization_stats.json` - 数据标准化参数
|
||||
- `config.json` - 训练配置
|
||||
|
||||
## 使用示例
|
||||
|
||||
```python
|
||||
from ml.inference import ConceptAnomalyDetector
|
||||
|
||||
detector = ConceptAnomalyDetector('ml/checkpoints')
|
||||
|
||||
# 实时检测
|
||||
is_anomaly, score = detector.detect(
|
||||
concept_name="人工智能",
|
||||
features={
|
||||
'alpha': 2.5,
|
||||
'alpha_delta': 0.8,
|
||||
'amt_ratio': 1.5,
|
||||
'amt_delta': 0.3,
|
||||
'rank_pct': 0.95,
|
||||
'limit_up_ratio': 0.15,
|
||||
}
|
||||
)
|
||||
|
||||
if is_anomaly:
|
||||
print(f"检测到异动!分数: {score}")
|
||||
```
|
||||
10
ml/__init__.py
Normal file
10
ml/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
概念异动检测 ML 模块
|
||||
|
||||
提供基于 Transformer Autoencoder 的异动检测功能
|
||||
"""
|
||||
|
||||
from .inference import ConceptAnomalyDetector, MLAnomalyService
|
||||
|
||||
__all__ = ['ConceptAnomalyDetector', 'MLAnomalyService']
|
||||
BIN
ml/__pycache__/realtime_detector.cpython-310.pyc
Normal file
BIN
ml/__pycache__/realtime_detector.cpython-310.pyc
Normal file
Binary file not shown.
481
ml/backtest.py
Normal file
481
ml/backtest.py
Normal file
@@ -0,0 +1,481 @@
|
||||
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Historical anomaly backtest script.

Runs the trained autoencoder over historical feature data to detect
concept "anomaly" events and generate alert records.

Usage:
    # Backtest a date range
    python backtest.py --start 2024-01-01 --end 2024-12-01

    # Backtest a single day
    python backtest.py --start 2024-11-01 --end 2024-11-01

    # Compute results only, without writing to the database
    python backtest.py --start 2024-01-01 --dry-run
"""

import os
import sys
import argparse
import json
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sqlalchemy import create_engine, text

# Make the repository root importable (this file lives in ml/).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model import TransformerAutoencoder


# ==================== Configuration ====================

# NOTE(review): database credentials are hard-coded in plain text here;
# consider reading them from an environment variable or a config file.
MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)

# Feature columns (must match the training pipeline).
FEATURES = [
    'alpha',           # excess return (concept change - market change)
    'alpha_delta',     # 5-minute change of alpha
    'amt_ratio',       # turnover / 20-minute mean
    'amt_delta',       # turnover change rate
    'rank_pct',        # alpha rank percentile
    'limit_up_ratio',  # share of limit-up stocks
]

# Backtest configuration
BACKTEST_CONFIG = {
    'seq_len': 30,                # input sequence length
    'threshold_key': 'p95',       # which percentile threshold to use
    'min_alpha_abs': 0.5,         # minimum |alpha| (filters tiny moves)
    'cooldown_minutes': 8,        # per-concept cooldown between alerts
    'max_alerts_per_minute': 15,  # cap on alerts per minute
    'clip_value': 10.0,           # clip bound for extreme feature values
}
|
||||
|
||||
|
||||
# ==================== 模型加载 ====================
|
||||
|
||||
class AnomalyDetector:
    """Wraps the trained autoencoder, its config, and its score threshold."""

    def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
        self.checkpoint_dir = Path(checkpoint_dir)

        # Resolve compute device ('auto' prefers CUDA when available).
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        # Load training config
        self._load_config()

        # Load model weights
        self._load_model()

        # Load anomaly-score thresholds
        self._load_thresholds()

        print(f"AnomalyDetector 初始化完成")
        print(f"  设备: {self.device}")
        print(f"  阈值 ({BACKTEST_CONFIG['threshold_key']}): {self.threshold:.6f}")

    def _load_config(self):
        # config.json is produced by the training run.
        config_path = self.checkpoint_dir / 'config.json'
        with open(config_path, 'r') as f:
            self.config = json.load(f)

    def _load_model(self):
        model_path = self.checkpoint_dir / 'best_model.pt'
        # NOTE(review): torch.load without weights_only=True deserializes
        # arbitrary pickled objects; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=self.device)

        model_config = self.config['model'].copy()
        # Older configs stored use_instance_norm at the top level; default True.
        model_config['use_instance_norm'] = self.config.get('use_instance_norm', True)

        self.model = TransformerAutoencoder(**model_config)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

    def _load_thresholds(self):
        thresholds_path = self.checkpoint_dir / 'thresholds.json'
        with open(thresholds_path, 'r') as f:
            thresholds = json.load(f)

        self.threshold = thresholds[BACKTEST_CONFIG['threshold_key']]

    @torch.no_grad()
    def compute_anomaly_scores(self, sequences: np.ndarray) -> np.ndarray:
        """
        Compute anomaly scores for a batch of sequences.

        Args:
            sequences: (n_sequences, seq_len, n_features)
        Returns:
            scores: (n_sequences,) reconstruction error at the final timestep
        """
        # Clip extreme values before feeding the model.
        sequences = np.clip(sequences, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])

        # To tensor on the target device.
        x = torch.FloatTensor(sequences).to(self.device)

        # Per-timestep reconstruction error.
        errors = self.model.compute_reconstruction_error(x, reduction='none')

        # Keep only the most recent timestep's error.
        scores = errors[:, -1].cpu().numpy()

        return scores

    def is_anomaly(self, score: float) -> bool:
        """True when the score exceeds the configured percentile threshold."""
        return score > self.threshold
|
||||
|
||||
|
||||
# ==================== 数据加载 ====================
|
||||
|
||||
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
    """Load the feature snapshot for a single trading day.

    Returns None when no parquet file exists for `date`.
    """
    snapshot = Path(data_dir) / f"features_{date}.parquet"
    if not snapshot.exists():
        return None
    return pd.read_parquet(snapshot)
|
||||
|
||||
|
||||
def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
    """Sorted dates in [start_date, end_date] for which a feature file exists."""
    stems = [p.stem.replace('features_', '')
             for p in sorted(Path(data_dir).glob("features_*.parquet"))]
    return [day for day in stems if start_date <= day <= end_date]
|
||||
|
||||
|
||||
# ==================== 回测逻辑 ====================
|
||||
|
||||
def backtest_single_day(
    detector: AnomalyDetector,
    df: pd.DataFrame,
    date: str,
    seq_len: int = 30
) -> List[Dict]:
    """
    Backtest a single trading day.

    Args:
        detector: anomaly detector
        df: the day's feature rows
        date: trade date (YYYY-MM-DD)
        seq_len: sequence length fed to the model

    Returns:
        alerts: list of alert dicts
    """
    alerts = []

    # Group rows by concept.
    grouped = df.groupby('concept_id', sort=False)

    # Cooldown bookkeeping {concept_id: last_alert_timestamp}
    cooldown = {}

    # All distinct timestamps of the day, in order.
    all_timestamps = sorted(df['timestamp'].unique())

    if len(all_timestamps) < seq_len:
        return alerts

    # Scan each timestamp, starting at the first full window.
    for t_idx in range(seq_len - 1, len(all_timestamps)):
        current_time = all_timestamps[t_idx]
        window_start_time = all_timestamps[t_idx - seq_len + 1]

        minute_alerts = []

        # Collect every concept's sequence ending at this timestamp.
        concept_sequences = []
        concept_infos = []

        for concept_id, concept_df in grouped:
            # Rows of this concept inside the time window.
            mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
            window_df = concept_df[mask].sort_values('timestamp')

            if len(window_df) < seq_len:
                continue

            # Keep the trailing seq_len points.
            window_df = window_df.tail(seq_len)

            # Feature matrix for the window.
            features = window_df[FEATURES].values

            # Replace NaN/inf before inference.
            features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

            # Snapshot of the most recent row.
            current_row = window_df.iloc[-1]

            concept_sequences.append(features)
            concept_infos.append({
                'concept_id': concept_id,
                'timestamp': current_time,
                'alpha': current_row.get('alpha', 0),
                'alpha_delta': current_row.get('alpha_delta', 0),
                'amt_ratio': current_row.get('amt_ratio', 1),
                'limit_up_ratio': current_row.get('limit_up_ratio', 0),
                'limit_down_ratio': current_row.get('limit_down_ratio', 0),
                'rank_pct': current_row.get('rank_pct', 0.5),
                'stock_count': current_row.get('stock_count', 0),
                'total_amt': current_row.get('total_amt', 0),
            })

        if not concept_sequences:
            continue

        # Batch anomaly scoring.
        sequences_array = np.array(concept_sequences)
        scores = detector.compute_anomaly_scores(sequences_array)

        # Detect anomalies.
        for i, (info, score) in enumerate(zip(concept_infos, scores)):
            concept_id = info['concept_id']
            alpha = info['alpha']

            # Filter tiny moves.
            if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']:
                continue

            # Cooldown check.
            if concept_id in cooldown:
                last_alert = cooldown[concept_id]
                if isinstance(current_time, datetime):
                    time_diff = (current_time - last_alert).total_seconds() / 60
                else:
                    # timestamp is a string or other format: skip cooldown check
                    time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1

                if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
                    continue

            # Threshold test.
            if not detector.is_anomaly(score):
                continue

            # Record the alert.
            alert_type = 'surge_up' if alpha > 0 else 'surge_down'

            alert = {
                'concept_id': concept_id,
                'alert_time': current_time,
                'trade_date': date,
                'alert_type': alert_type,
                'anomaly_score': float(score),
                'threshold': detector.threshold,
                **info
            }

            minute_alerts.append(alert)
            # NOTE(review): the cooldown is stamped here even if this alert is
            # later dropped by the per-minute cap below, so dropped alerts still
            # suppress follow-ups — confirm this is intended.
            cooldown[concept_id] = current_time

        # Keep only the highest-scoring alerts for this minute.
        minute_alerts.sort(key=lambda x: x['anomaly_score'], reverse=True)
        alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])

    return alerts
|
||||
|
||||
|
||||
# ==================== 数据库写入 ====================
|
||||
|
||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Persist alerts to MySQL and return the number of rows inserted.

    With dry_run=True nothing is written and len(alerts) is returned.
    Duplicate (concept_id, alert_time, trade_date) rows are skipped.
    """
    if not alerts:
        return 0

    if dry_run:
        print(f"  [Dry Run] 将写入 {len(alerts)} 条异动")
        return len(alerts)

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                # Skip rows that already exist (dedupe on concept/time/date).
                check_sql = text("""
                    SELECT id FROM concept_minute_alert
                    WHERE concept_id = :concept_id
                    AND alert_time = :alert_time
                    AND trade_date = :trade_date
                """)
                exists = conn.execute(check_sql, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                }).fetchone()

                if exists:
                    continue

                # Insert the new alert row.
                insert_sql = text("""
                    INSERT INTO concept_minute_alert
                    (concept_id, concept_name, alert_time, alert_type, trade_date,
                    change_pct, zscore, importance_score, stock_count, extra_info)
                    VALUES
                    (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
                    :change_pct, :zscore, :importance_score, :stock_count, :extra_info)
                """)

                conn.execute(insert_sql, {
                    'concept_id': alert['concept_id'],
                    'concept_name': alert.get('concept_name', ''),
                    'alert_time': alert['alert_time'],
                    'alert_type': alert['alert_type'],
                    'trade_date': alert['trade_date'],
                    'change_pct': alert.get('alpha', 0),
                    # The ML anomaly score doubles as both zscore and importance.
                    'zscore': alert['anomaly_score'],
                    'importance_score': alert['anomaly_score'],
                    'stock_count': alert.get('stock_count', 0),
                    'extra_info': json.dumps({
                        'detection_method': 'ml_autoencoder',
                        'threshold': alert['threshold'],
                        'alpha': alert.get('alpha', 0),
                        'amt_ratio': alert.get('amt_ratio', 1),
                    }, ensure_ascii=False)
                })

                saved += 1

            except Exception as e:
                # Best-effort: log the failure and continue with remaining rows.
                print(f"  保存失败: {alert['concept_id']} - {e}")

    return saved
|
||||
|
||||
|
||||
def export_alerts_to_csv(alerts: List[Dict], output_path: str):
    """Write the collected alerts to a CSV file (no-op for an empty list)."""
    if alerts:
        pd.DataFrame(alerts).to_csv(output_path, index=False, encoding='utf-8-sig')
        print(f"已导出到: {output_path}")
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
    """CLI entry point: parse args, backtest each available day, report stats."""
    parser = argparse.ArgumentParser(description='历史异动回测')
    parser.add_argument('--data_dir', type=str, default='ml/data',
                        help='特征数据目录')
    parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
                        help='模型检查点目录')
    parser.add_argument('--start', type=str, required=True,
                        help='开始日期 (YYYY-MM-DD)')
    parser.add_argument('--end', type=str, required=True,
                        help='结束日期 (YYYY-MM-DD)')
    parser.add_argument('--dry-run', action='store_true',
                        help='只计算,不写入数据库')
    parser.add_argument('--export-csv', type=str, default=None,
                        help='导出 CSV 文件路径')
    parser.add_argument('--device', type=str, default='auto',
                        help='设备 (auto/cuda/cpu)')

    args = parser.parse_args()

    print("=" * 60)
    print("历史异动回测")
    print("=" * 60)
    print(f"日期范围: {args.start} ~ {args.end}")
    print(f"数据目录: {args.data_dir}")
    print(f"模型目录: {args.checkpoint_dir}")
    print(f"Dry Run: {args.dry_run}")
    print("=" * 60)

    # Build the detector (loads model + thresholds).
    detector = AnomalyDetector(args.checkpoint_dir, args.device)

    # Dates that actually have feature files.
    dates = get_available_dates(args.data_dir, args.start, args.end)

    if not dates:
        print(f"未找到 {args.start} ~ {args.end} 范围内的数据")
        return

    print(f"\n找到 {len(dates)} 天的数据")

    # Run the backtest day by day.
    all_alerts = []
    total_saved = 0

    for date in tqdm(dates, desc="回测进度"):
        # Load that day's features.
        df = load_daily_features(args.data_dir, date)

        if df is None or df.empty:
            continue

        # Detect anomalies for the day.
        alerts = backtest_single_day(
            detector, df, date,
            seq_len=BACKTEST_CONFIG['seq_len']
        )

        if alerts:
            all_alerts.extend(alerts)

            # Persist (or just count with --dry-run).
            saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
            total_saved += saved

            if not args.dry_run:
                tqdm.write(f"  {date}: 检测到 {len(alerts)} 个异动,保存 {saved} 条")

    # Optional CSV export.
    if args.export_csv and all_alerts:
        export_alerts_to_csv(all_alerts, args.export_csv)

    # Summary.
    print("\n" + "=" * 60)
    print("回测完成!")
    print("=" * 60)
    print(f"总计检测到: {len(all_alerts)} 个异动")
    print(f"保存到数据库: {total_saved} 条")

    # Distribution statistics.
    if all_alerts:
        df_alerts = pd.DataFrame(all_alerts)
        print(f"\n异动类型分布:")
        print(df_alerts['alert_type'].value_counts())

        print(f"\n异动分数统计:")
        print(f"  Mean: {df_alerts['anomaly_score'].mean():.4f}")
        print(f"  Max: {df_alerts['anomaly_score'].max():.4f}")
        print(f"  Min: {df_alerts['anomaly_score'].min():.4f}")

    print("=" * 60)
|
||||
859
ml/backtest_fast.py
Normal file
859
ml/backtest_fast.py
Normal file
@@ -0,0 +1,859 @@
|
||||
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fast hybrid (rules + ML) anomaly backtest script.

Optimizations:
1. Pre-build all sequences (vectorized) instead of re-slicing inside loops
2. Batched ML inference (score all candidates at once)
3. NumPy vectorized operations replace Python loops

Performance:
- original: ~5 min/day
- optimized: ~10-30 s/day expected
"""

import os
import sys
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sqlalchemy import create_engine, text

# Make the repository root importable (this file lives in ml/).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


# ==================== Configuration ====================

# NOTE(review): plaintext DB credentials — prefer env vars or a config file.
MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)

# Feature columns (must match training).
FEATURES = ['alpha', 'alpha_delta', 'amt_ratio', 'amt_delta', 'rank_pct', 'limit_up_ratio']

CONFIG = {
    'seq_len': 15,                # sequence length (cross-day support allows detection from 9:30)
    'min_alpha_abs': 0.3,         # minimum |alpha| filter
    'cooldown_minutes': 8,
    'max_alerts_per_minute': 20,
    'clip_value': 10.0,
    # === fusion weights: balanced ===
    'rule_weight': 0.5,
    'ml_weight': 0.5,
    # === trigger thresholds ===
    'rule_trigger': 65,           # 60 -> 65, slightly stricter rules gate
    'ml_trigger': 70,             # 75 -> 70, slightly looser ML gate
    'fusion_trigger': 45,
}
|
||||
|
||||
|
||||
# ==================== 规则评分(向量化版)====================
|
||||
|
||||
def get_size_adjusted_thresholds(stock_count: np.ndarray) -> np.ndarray:
    """
    Compute per-concept threshold multipliers from concept size.

    Rationale:
    - Small concepts (<10 stocks): big swings are normal, so raise thresholds
    - Mid-size concepts (10-50): standard thresholds
    - Large concepts (>50): a visible move is a real signal, lower thresholds

    Args:
        stock_count: array of stock counts, one per concept row

    Returns:
        Array of multipliers (same length as stock_count) applied to the
        base rule thresholds.

    Fix vs. original: the return annotation said ``dict`` but an ndarray is
    returned; the five-band assignment is now a single ``np.select`` whose
    first-match semantics reproduce the original band boundaries exactly.
    """
    # Bands are checked in order; the first matching condition wins.
    conditions = [
        stock_count < 5,     # micro concept: x1.8
        stock_count < 10,    # small: x1.4
        stock_count < 20,    # small-mid: x1.2
        stock_count < 50,    # mid: x1.0
        stock_count < 100,   # large: x0.85
    ]
    factors = [1.8, 1.4, 1.2, 1.0, 0.85]
    # Anything >= 100 stocks falls through to the extra-large factor.
    return np.select(conditions, factors, default=0.7)
|
||||
|
||||
|
||||
def score_rules_batch(df: pd.DataFrame) -> Tuple[np.ndarray, List[List[str]]]:
    """
    Vectorized rule scoring with concept-size awareness.

    Design principles:
    - Rules are an auxiliary signal and should not dominate decisions alone
    - Thresholds scale dynamically with the number of stocks in the concept
    - Moves in large concepts are more valuable; small ones need bigger moves

    Args:
        df: DataFrame with all feature columns (should include stock_count)
    Returns:
        scores: (n,) rule score per row, clipped to [0, 100]
        triggered_rules: list of triggered-rule names per row
    """
    n = len(df)
    scores = np.zeros(n)
    triggered = [[] for _ in range(n)]

    alpha = df['alpha'].values
    alpha_delta = df['alpha_delta'].values
    amt_ratio = df['amt_ratio'].values
    amt_delta = df['amt_delta'].values
    rank_pct = df['rank_pct'].values
    limit_up_ratio = df['limit_up_ratio'].values
    # Fall back to a mid-sized concept (20 stocks) when the column is absent.
    stock_count = df['stock_count'].values if 'stock_count' in df.columns else np.full(n, 20)

    alpha_abs = np.abs(alpha)
    alpha_delta_abs = np.abs(alpha_delta)

    # Size-based threshold multipliers.
    size_factor = get_size_adjusted_thresholds(stock_count)

    # ========== Alpha rules (dynamic thresholds) ==========
    # Base thresholds: extreme 5%, strong 4%, medium 3%
    # effective threshold = base * size_factor

    # Extreme signal
    alpha_extreme_thresh = 5.0 * size_factor
    mask = alpha_abs >= alpha_extreme_thresh
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('alpha_extreme')

    # Strong signal
    alpha_strong_thresh = 4.0 * size_factor
    mask = (alpha_abs >= alpha_strong_thresh) & (alpha_abs < alpha_extreme_thresh)
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('alpha_strong')

    # Medium signal
    alpha_medium_thresh = 3.0 * size_factor
    mask = (alpha_abs >= alpha_medium_thresh) & (alpha_abs < alpha_strong_thresh)
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('alpha_medium')

    # ========== Alpha acceleration rules (dynamic thresholds) ==========
    delta_strong_thresh = 2.0 * size_factor
    mask = alpha_delta_abs >= delta_strong_thresh
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('alpha_delta_strong')

    delta_medium_thresh = 1.5 * size_factor
    mask = (alpha_delta_abs >= delta_medium_thresh) & (alpha_delta_abs < delta_strong_thresh)
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('alpha_delta_medium')

    # ========== Turnover rules (size-independent: volume is volume) ==========
    mask = amt_ratio >= 10.0
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('volume_extreme')

    mask = (amt_ratio >= 6.0) & (amt_ratio < 10.0)
    scores[mask] += 12
    for i in np.where(mask)[0]: triggered[i].append('volume_strong')

    # ========== Rank rules ==========
    mask = rank_pct >= 0.98
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('rank_top')

    mask = rank_pct <= 0.02
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('rank_bottom')

    # ========== Limit-up rules (dynamic thresholds) ==========
    # Limit-ups inside a large concept carry more meaning.
    limit_high_thresh = 0.30 * size_factor
    mask = limit_up_ratio >= limit_high_thresh
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('limit_up_high')

    limit_medium_thresh = 0.20 * size_factor
    mask = (limit_up_ratio >= limit_medium_thresh) & (limit_up_ratio < limit_high_thresh)
    scores[mask] += 12
    for i in np.where(mask)[0]: triggered[i].append('limit_up_medium')

    # ========== Size bonus (large-concept moves are more valuable) ==========
    # Large concepts (50+) get extra points.
    large_concept = stock_count >= 50
    # NOTE: evaluated here, before combo rules — combo-only hits get no bonus.
    has_signal = scores > 0
    mask = large_concept & has_signal
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('large_concept_bonus')

    # Extra bonus for very large concepts (100+).
    xlarge_concept = stock_count >= 100
    mask = xlarge_concept & has_signal
    scores[mask] += 10
    for i in np.where(mask)[0]: triggered[i].append('xlarge_concept_bonus')

    # ========== Combo rules (dynamic thresholds) ==========
    combo_alpha_thresh = 3.0 * size_factor

    # Alpha + volume + rank (triple confirmation)
    mask = (alpha_abs >= combo_alpha_thresh) & (amt_ratio >= 5.0) & ((rank_pct >= 0.95) | (rank_pct <= 0.05))
    scores[mask] += 20
    for i in np.where(mask)[0]: triggered[i].append('triple_signal')

    # Alpha + limit-ups (strong combo)
    mask = (alpha_abs >= combo_alpha_thresh) & (limit_up_ratio >= 0.15 * size_factor)
    scores[mask] += 15
    for i in np.where(mask)[0]: triggered[i].append('alpha_with_limit')

    # ========== Tiny-concept penalty (noise filter) ==========
    # Micro concepts (<5 stocks) with only a single fired rule get halved.
    tiny_concept = stock_count < 5
    single_rule = np.array([len(t) <= 1 for t in triggered])
    mask = tiny_concept & single_rule & (scores > 0)
    scores[mask] *= 0.5
    for i in np.where(mask)[0]: triggered[i].append('tiny_concept_penalty')

    scores = np.clip(scores, 0, 100)
    return scores, triggered
|
||||
|
||||
|
||||
# ==================== ML 评分器 ====================
|
||||
|
||||
class FastMLScorer:
    """Fast ML scorer: maps autoencoder reconstruction error to 0-100."""

    def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'):
        self.checkpoint_dir = Path(checkpoint_dir)

        # Device resolution with a graceful CPU fallback.
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        elif device == 'cuda' and not torch.cuda.is_available():
            print("警告: CUDA 不可用,使用 CPU")
            self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)

        self.model = None       # set by _load_model on success
        self.thresholds = None  # percentile thresholds from training
        self._load_model()

    def _load_model(self):
        # Loads weights, thresholds and config; failures leave self.model=None
        # so score_batch degrades to all-zero scores instead of crashing.
        model_path = self.checkpoint_dir / 'best_model.pt'
        thresholds_path = self.checkpoint_dir / 'thresholds.json'
        config_path = self.checkpoint_dir / 'config.json'

        if not model_path.exists():
            print(f"警告: 模型不存在 {model_path}")
            return

        try:
            from model import LSTMAutoencoder

            config = {}
            if config_path.exists():
                with open(config_path) as f:
                    config = json.load(f).get('model', {})

            # Translate legacy (Transformer-era) config keys to LSTM ones.
            if 'd_model' in config:
                config['hidden_dim'] = config.pop('d_model') // 2
            for key in ['num_encoder_layers', 'num_decoder_layers', 'nhead', 'dim_feedforward', 'max_seq_len', 'use_instance_norm']:
                config.pop(key, None)
            if 'num_layers' not in config:
                config['num_layers'] = 1

            checkpoint = torch.load(model_path, map_location='cpu')
            self.model = LSTMAutoencoder(**config)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.to(self.device)
            self.model.eval()

            if thresholds_path.exists():
                with open(thresholds_path) as f:
                    self.thresholds = json.load(f)

            print(f"ML模型加载成功 (设备: {self.device})")
        except Exception as e:
            # Best-effort: log and keep the scorer usable without a model.
            print(f"ML模型加载失败: {e}")
            self.model = None

    def is_ready(self):
        """True when a model was loaded successfully."""
        return self.model is not None

    @torch.no_grad()
    def score_batch(self, sequences: np.ndarray) -> np.ndarray:
        """
        Batch ML scoring.

        Args:
            sequences: (batch, seq_len, n_features)
        Returns:
            scores: (batch,) values in 0-100 (all zeros when no model loaded)
        """
        if not self.is_ready() or len(sequences) == 0:
            return np.zeros(len(sequences))

        x = torch.FloatTensor(sequences).to(self.device)
        output, _ = self.model(x)
        # Per-timestep MSE; keep only the final step's error.
        mse = ((output - x) ** 2).mean(dim=-1)
        errors = mse[:, -1].cpu().numpy()

        # Scale so that an error equal to the P95 threshold maps to 50.
        p95 = self.thresholds.get('p95', 0.1) if self.thresholds else 0.1
        scores = np.clip(errors / p95 * 50, 0, 100)
        return scores
|
||||
|
||||
|
||||
# ==================== 快速回测 ====================
|
||||
|
||||
def build_sequences_fast(
    df: pd.DataFrame,
    seq_len: int = 30,
    prev_df: pd.DataFrame = None
) -> Tuple[np.ndarray, pd.DataFrame]:
    """
    Quickly build all valid sequences.

    Supports cross-day sequences: the previous day's closing rows are
    prepended to today's rows so detection can start right at 9:30.

    Args:
        df: today's rows
        seq_len: sequence length
        prev_df: previous day's rows (optional, enables open-time sequences)

    Returns:
        sequences: (n_valid, seq_len, n_features) all valid sequences
        info_df: metadata DataFrame aligned row-for-row with sequences
    """
    # Sort by concept then time.
    df = df.sort_values(['concept_id', 'timestamp']).reset_index(drop=True)

    # Cache each concept's trailing seq_len-1 rows from the previous day.
    prev_cache = {}
    if prev_df is not None and len(prev_df) > 0:
        prev_df = prev_df.sort_values(['concept_id', 'timestamp'])
        for concept_id, gdf in prev_df.groupby('concept_id'):
            tail_data = gdf.tail(seq_len - 1)
            if len(tail_data) > 0:
                feat_matrix = tail_data[FEATURES].values
                feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
                feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])
                prev_cache[concept_id] = feat_matrix

    # Group today's rows by concept.
    groups = df.groupby('concept_id')

    sequences = []
    infos = []

    for concept_id, gdf in groups:
        gdf = gdf.reset_index(drop=True)

        # Cleaned feature matrix for this concept.
        feat_matrix = gdf[FEATURES].values
        feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0)
        feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value'])

        # Prepend the previous day's cached tail when available.
        if concept_id in prev_cache:
            prev_data = prev_cache[concept_id]
            combined_matrix = np.vstack([prev_data, feat_matrix])
            # offset = number of prepended previous-day rows
            offset = len(prev_data)
        else:
            combined_matrix = feat_matrix
            offset = 0

        # Sliding-window sequence construction.
        n_total = len(combined_matrix)
        if n_total < seq_len:
            continue

        for i in range(n_total - seq_len + 1):
            seq = combined_matrix[i:i + seq_len]

            # Index of the window's last point within today's rows:
            #   last position in combined = i + seq_len - 1
            #   today's index = (i + seq_len - 1) - offset
            today_idx = i + seq_len - 1 - offset

            # Keep only windows whose final point is a today-row.
            if today_idx < 0 or today_idx >= len(gdf):
                continue

            sequences.append(seq)

            # Metadata snapshot of the window's final (today) timestep.
            row = gdf.iloc[today_idx]
            infos.append({
                'concept_id': concept_id,
                'timestamp': row['timestamp'],
                'alpha': row['alpha'],
                'alpha_delta': row.get('alpha_delta', 0),
                'amt_ratio': row.get('amt_ratio', 1),
                'amt_delta': row.get('amt_delta', 0),
                'rank_pct': row.get('rank_pct', 0.5),
                'limit_up_ratio': row.get('limit_up_ratio', 0),
                'stock_count': row.get('stock_count', 0),
                'total_amt': row.get('total_amt', 0),
            })

    if not sequences:
        return np.array([]), pd.DataFrame()

    return np.array(sequences), pd.DataFrame(infos)
|
||||
|
||||
|
||||
def backtest_single_day_fast(
    ml_scorer: FastMLScorer,
    df: pd.DataFrame,
    date: str,
    config: Dict,
    prev_df: pd.DataFrame = None
) -> List[Dict]:
    """
    Fast single-day backtest (vectorized rules + batched ML + fusion).

    Args:
        ml_scorer: ML scorer
        df: today's rows
        date: trade date
        config: CONFIG-like dict of weights/thresholds
        prev_df: previous day's rows (enables detection from 9:30)
    """
    seq_len = config.get('seq_len', 30)

    # 1. Build all sequences (cross-day aware).
    sequences, info_df = build_sequences_fast(df, seq_len, prev_df)

    if len(sequences) == 0:
        return []

    # 2. Drop tiny moves up front.
    alpha_abs = np.abs(info_df['alpha'].values)
    valid_mask = alpha_abs >= config['min_alpha_abs']

    sequences = sequences[valid_mask]
    info_df = info_df[valid_mask].reset_index(drop=True)

    if len(sequences) == 0:
        return []

    # 3. Vectorized rule scores.
    rule_scores, triggered_rules = score_rules_batch(info_df)

    # 4. Batched ML scores (chunked to bound device memory).
    batch_size = 2048
    ml_scores = []
    for i in range(0, len(sequences), batch_size):
        batch_seq = sequences[i:i+batch_size]
        batch_scores = ml_scorer.score_batch(batch_seq)
        ml_scores.append(batch_scores)
    ml_scores = np.concatenate(ml_scores) if ml_scores else np.zeros(len(sequences))

    # 5. Weighted score fusion.
    w1, w2 = config['rule_weight'], config['ml_weight']
    final_scores = w1 * rule_scores + w2 * ml_scores

    # 6. Anomaly decision: any of the three gates may fire.
    is_anomaly = (
        (rule_scores >= config['rule_trigger']) |
        (ml_scores >= config['ml_trigger']) |
        (final_scores >= config['fusion_trigger'])
    )

    # 7. Attach scores, then enforce the per-concept cooldown.
    info_df['rule_score'] = rule_scores
    info_df['ml_score'] = ml_scores
    info_df['final_score'] = final_scores
    info_df['is_anomaly'] = is_anomaly
    info_df['triggered_rules'] = triggered_rules

    # Keep anomalies only.
    anomaly_df = info_df[info_df['is_anomaly']].copy()

    if len(anomaly_df) == 0:
        return []

    # Cooldown pass (per concept, chronological order).
    anomaly_df = anomaly_df.sort_values(['concept_id', 'timestamp'])
    cooldown = {}
    keep_mask = []

    for _, row in anomaly_df.iterrows():
        cid = row['concept_id']
        ts = row['timestamp']

        if cid in cooldown:
            try:
                diff = (ts - cooldown[cid]).total_seconds() / 60
            except:
                # Non-datetime timestamps: effectively disable the cooldown.
                # NOTE(review): bare except also swallows unrelated errors.
                diff = config['cooldown_minutes'] + 1

            if diff < config['cooldown_minutes']:
                keep_mask.append(False)
                continue

        cooldown[cid] = ts
        keep_mask.append(True)

    anomaly_df = anomaly_df[keep_mask]

    # 8. Per-minute cap: keep the top max_alerts_per_minute by final score.
    alerts = []
    for ts, group in anomaly_df.groupby('timestamp'):
        group = group.nlargest(config['max_alerts_per_minute'], 'final_score')

        for _, row in group.iterrows():
            # Coarse alert typing from the dominant raw signal.
            alpha = row['alpha']
            if alpha >= 1.5:
                atype = 'surge_up'
            elif alpha <= -1.5:
                atype = 'surge_down'
            elif row['amt_ratio'] >= 3.0:
                atype = 'volume_spike'
            else:
                atype = 'unknown'

            rule_score = row['rule_score']
            ml_score = row['ml_score']
            final_score = row['final_score']

            # Human-readable trigger reason (same priority as the gates).
            if rule_score >= config['rule_trigger']:
                trigger = f'规则强信号({rule_score:.0f}分)'
            elif ml_score >= config['ml_trigger']:
                trigger = f'ML强信号({ml_score:.0f}分)'
            else:
                trigger = f'融合触发({final_score:.0f}分)'

            alerts.append({
                'concept_id': row['concept_id'],
                'alert_time': row['timestamp'],
                'trade_date': date,
                'alert_type': atype,
                'final_score': final_score,
                'rule_score': rule_score,
                'ml_score': ml_score,
                'trigger_reason': trigger,
                'triggered_rules': row['triggered_rules'],
                'alpha': alpha,
                'alpha_delta': row['alpha_delta'],
                'amt_ratio': row['amt_ratio'],
                'amt_delta': row['amt_delta'],
                'rank_pct': row['rank_pct'],
                'limit_up_ratio': row['limit_up_ratio'],
                'stock_count': row['stock_count'],
                'total_amt': row['total_amt'],
            })

    return alerts
|
||||
|
||||
|
||||
# ==================== 数据加载 ====================
|
||||
|
||||
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
    """Load the per-minute feature parquet for one trading day.

    Returns the DataFrame, or None when no file exists for that date.
    """
    parquet_path = Path(data_dir) / f"features_{date}.parquet"
    if parquet_path.exists():
        return pd.read_parquet(parquet_path)
    return None
|
||||
|
||||
|
||||
def get_available_dates(data_dir: str, start: str, end: str) -> List[str]:
    """Return sorted feature-file dates falling within [start, end] (inclusive)."""
    files = sorted(Path(data_dir).glob("features_*.parquet"))
    stems = (f.stem.replace('features_', '') for f in files)
    return [d for d in stems if start <= d <= end]
|
||||
|
||||
|
||||
def get_prev_trading_day(data_dir: str, date: str) -> Optional[str]:
    """Return the closest earlier trading day that has a feature file.

    Returns None when `date` has no feature file itself, or when it is
    the earliest available date.
    """
    available = sorted(
        p.stem.replace('features_', '')
        for p in Path(data_dir).glob("features_*.parquet")
    )
    try:
        pos = available.index(date)
    except ValueError:
        return None
    return available[pos - 1] if pos > 0 else None
|
||||
|
||||
|
||||
def export_to_csv(alerts: List[Dict], path: str):
    """Write alerts to CSV (UTF-8 with BOM, Excel-friendly); no-op when empty."""
    if not alerts:
        return
    frame = pd.DataFrame(alerts)
    frame.to_csv(path, index=False, encoding='utf-8-sig')
    print(f"已导出: {path}")
|
||||
|
||||
|
||||
# ==================== 数据库写入 ====================
|
||||
|
||||
def init_db_table():
    """Create the result table if it does not already exist (idempotent).

    Column semantics:
    - concept_id: concept identifier
    - alert_time: anomaly timestamp (minute precision)
    - trade_date: trading date
    - alert_type: anomaly type (surge_up/surge_down/volume_spike/unknown)
    - final_score: fused score (0-100)
    - rule_score: rule-based score (0-100)
    - ml_score: ML score (0-100)
    - trigger_reason: human-readable trigger reason
    - alpha: excess return
    - alpha_delta: rate of change of alpha
    - amt_ratio: turnover amplification ratio
    - rank_pct: rank percentile
    - stock_count: number of stocks in the concept
    - triggered_rules: list of triggered rules (JSON)

    The unique key uk_concept_time makes INSERT IGNORE deduplication work
    in save_alerts_to_mysql.
    """
    create_sql = text("""
        CREATE TABLE IF NOT EXISTS concept_anomaly_hybrid (
            id INT AUTO_INCREMENT PRIMARY KEY,
            concept_id VARCHAR(64) NOT NULL,
            alert_time DATETIME NOT NULL,
            trade_date DATE NOT NULL,
            alert_type VARCHAR(32) NOT NULL,
            final_score FLOAT NOT NULL,
            rule_score FLOAT NOT NULL,
            ml_score FLOAT NOT NULL,
            trigger_reason VARCHAR(64),
            alpha FLOAT,
            alpha_delta FLOAT,
            amt_ratio FLOAT,
            amt_delta FLOAT,
            rank_pct FLOAT,
            limit_up_ratio FLOAT,
            stock_count INT,
            total_amt FLOAT,
            triggered_rules JSON,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
            INDEX idx_trade_date (trade_date),
            INDEX idx_concept_id (concept_id),
            INDEX idx_final_score (final_score),
            INDEX idx_alert_type (alert_type)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动检测结果(融合版)'
    """)

    with MYSQL_ENGINE.begin() as conn:
        conn.execute(create_sql)
        print("数据库表已就绪: concept_anomaly_hybrid")
|
||||
|
||||
|
||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Persist alerts into concept_anomaly_hybrid.

    Args:
        alerts: alert dicts produced by the backtest.
        dry_run: if True, only report what would be written and return
            len(alerts) without touching the database.

    Returns:
        Number of rows actually inserted. Duplicates (same concept_id,
        alert_time, trade_date) are skipped via INSERT IGNORE + the
        table's unique key, and reported as "skipped".
    """
    if not alerts:
        return 0

    if dry_run:
        print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
        return len(alerts)

    # Fix: the INSERT statement is loop-invariant — build it once instead
    # of recompiling text() for every alert.
    insert_sql = text("""
        INSERT IGNORE INTO concept_anomaly_hybrid
        (concept_id, alert_time, trade_date, alert_type,
         final_score, rule_score, ml_score, trigger_reason,
         alpha, alpha_delta, amt_ratio, amt_delta,
         rank_pct, limit_up_ratio, stock_count, total_amt,
         triggered_rules)
        VALUES
        (:concept_id, :alert_time, :trade_date, :alert_type,
         :final_score, :rule_score, :ml_score, :trigger_reason,
         :alpha, :alpha_delta, :amt_ratio, :amt_delta,
         :rank_pct, :limit_up_ratio, :stock_count, :total_amt,
         :triggered_rules)
    """)

    saved = 0
    skipped = 0

    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                result = conn.execute(insert_sql, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                    'alert_type': alert['alert_type'],
                    'final_score': alert['final_score'],
                    'rule_score': alert['rule_score'],
                    'ml_score': alert['ml_score'],
                    'trigger_reason': alert['trigger_reason'],
                    # .get defaults keep partially-populated alerts insertable.
                    'alpha': alert.get('alpha', 0),
                    'alpha_delta': alert.get('alpha_delta', 0),
                    'amt_ratio': alert.get('amt_ratio', 1),
                    'amt_delta': alert.get('amt_delta', 0),
                    'rank_pct': alert.get('rank_pct', 0.5),
                    'limit_up_ratio': alert.get('limit_up_ratio', 0),
                    'stock_count': alert.get('stock_count', 0),
                    'total_amt': alert.get('total_amt', 0),
                    'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False),
                })

                # INSERT IGNORE reports rowcount 0 when the unique key matched.
                if result.rowcount > 0:
                    saved += 1
                else:
                    skipped += 1

            except Exception as e:
                # Best-effort: log and continue with the remaining alerts.
                print(f" 保存失败: {alert['concept_id']} @ {alert['alert_time']} - {e}")

    if skipped > 0:
        print(f" 跳过 {skipped} 条重复记录")

    return saved
|
||||
|
||||
|
||||
def clear_alerts_by_date(trade_date: str) -> int:
    """Delete all hybrid alerts for one trading date; return rows removed."""
    delete_stmt = text(
        "DELETE FROM concept_anomaly_hybrid WHERE trade_date = :trade_date"
    )
    with MYSQL_ENGINE.begin() as conn:
        return conn.execute(delete_stmt, {'trade_date': trade_date}).rowcount
|
||||
|
||||
|
||||
def analyze_alerts(alerts: List[Dict]):
    """Print a summary of alert counts, type distribution, scores and trigger sources."""
    if not alerts:
        print("无异动")
        return

    df = pd.DataFrame(alerts)
    print(f"\n总异动: {len(alerts)}")
    print(f"\n类型分布:\n{df['alert_type'].value_counts()}")
    print(f"\n得分统计:")
    print(f" 最终: {df['final_score'].mean():.1f} (max: {df['final_score'].max():.1f})")
    print(f" 规则: {df['rule_score'].mean():.1f} (max: {df['rule_score'].max():.1f})")
    print(f" ML: {df['ml_score'].mean():.1f} (max: {df['ml_score'].max():.1f})")

    def _source(reason: str) -> str:
        # Classify the trigger reason by which marker it contains.
        if '规则' in reason:
            return '规则'
        if 'ML' in reason:
            return 'ML'
        return '融合'

    trigger_type = df['trigger_reason'].apply(_source)
    print(f"\n触发来源:\n{trigger_type.value_counts()}")
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
    """CLI entry point: run the fast fused backtest over a date range.

    Flow: parse args -> (optionally) init DB table -> build ML scorer ->
    iterate trading days, carrying the previous day's features so sequences
    can span the day boundary -> optionally persist/export -> print summary.
    """
    parser = argparse.ArgumentParser(description='快速融合异动回测')
    parser.add_argument('--data_dir', default='ml/data')
    parser.add_argument('--checkpoint_dir', default='ml/checkpoints')
    parser.add_argument('--start', required=True)
    parser.add_argument('--end', default=None)
    parser.add_argument('--dry-run', action='store_true', help='模拟运行,不写入数据库')
    parser.add_argument('--export-csv', default=None, help='导出 CSV 文件路径')
    parser.add_argument('--save-db', action='store_true', help='保存结果到数据库')
    parser.add_argument('--clear-first', action='store_true', help='写入前先清除该日期的旧数据')
    parser.add_argument('--device', default='auto')

    args = parser.parse_args()
    # Single-day run when --end is omitted.
    if args.end is None:
        args.end = args.start

    print("=" * 60)
    print("快速融合异动回测")
    print("=" * 60)
    print(f"日期: {args.start} ~ {args.end}")
    print(f"设备: {args.device}")
    print(f"保存数据库: {args.save_db}")
    print("=" * 60)

    # Create the result table up-front when persisting for real.
    if args.save_db and not args.dry_run:
        init_db_table()

    # Initialize the ML scorer from the checkpoint directory.
    ml_scorer = FastMLScorer(args.checkpoint_dir, args.device)

    # Collect trading days that actually have feature files.
    dates = get_available_dates(args.data_dir, args.start, args.end)
    if not dates:
        print("无数据")
        return

    print(f"找到 {len(dates)} 天数据\n")

    # Backtest loop (supports cross-day sequences via prev_df).
    all_alerts = []
    total_saved = 0
    prev_df = None  # previous day's features, cached for the next iteration

    for i, date in enumerate(tqdm(dates, desc="回测")):
        df = load_daily_features(args.data_dir, date)
        if df is None or df.empty:
            prev_df = None  # gap in data: don't bridge sequences across it
            continue

        # For the very first day, try to preload the preceding trading day
        # so the sequence window is warm from the start.
        if i == 0 and prev_df is None:
            prev_date = get_prev_trading_day(args.data_dir, date)
            if prev_date:
                prev_df = load_daily_features(args.data_dir, prev_date)
                if prev_df is not None:
                    tqdm.write(f" 加载前一天数据: {prev_date}")

        alerts = backtest_single_day_fast(ml_scorer, df, date, CONFIG, prev_df)
        all_alerts.extend(alerts)

        # Persist the day's alerts if requested.
        if args.save_db and alerts:
            # Optionally wipe stale rows for this date before re-inserting.
            if args.clear_first and not args.dry_run:
                cleared = clear_alerts_by_date(date)
                if cleared > 0:
                    tqdm.write(f" 清除 {date} 旧数据: {cleared} 条")

            saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
            total_saved += saved
            tqdm.write(f" {date}: {len(alerts)} 个异动, 保存 {saved} 条")
        elif alerts:
            tqdm.write(f" {date}: {len(alerts)} 个异动")

        # Today's features become tomorrow's prev_df.
        prev_df = df

    # Optional CSV export of everything detected.
    if args.export_csv:
        export_to_csv(all_alerts, args.export_csv)

    # Summary statistics.
    analyze_alerts(all_alerts)

    print(f"\n总计: {len(all_alerts)} 个异动")
    if args.save_db:
        print(f"已保存到数据库: {total_saved} 条")
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
481
ml/backtest_hybrid.py
Normal file
481
ml/backtest_hybrid.py
Normal file
@@ -0,0 +1,481 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
融合异动回测脚本
|
||||
|
||||
使用 HybridAnomalyDetector 进行回测:
|
||||
- 规则评分 + LSTM Autoencoder 融合判断
|
||||
- 输出更丰富的异动信息
|
||||
|
||||
使用方法:
|
||||
python backtest_hybrid.py --start 2024-01-01 --end 2024-12-01
|
||||
python backtest_hybrid.py --start 2024-11-01 --dry-run
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
# 添加父目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from detector import HybridAnomalyDetector, create_detector
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
# Shared SQLAlchemy engine for result persistence.
# SECURITY NOTE(review): database credentials are hardcoded in source;
# move them to an environment variable or config file before sharing
# or committing this code.
MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)
|
||||
|
||||
# Feature columns extracted from each minute row and fed to the ML scorer.
# NOTE(review): the column order presumably must match the order used at
# model training time — confirm against the training pipeline.
FEATURES = [
    'alpha',
    'alpha_delta',
    'amt_ratio',
    'amt_delta',
    'rank_pct',
    'limit_up_ratio',
]

# Backtest tuning knobs.
BACKTEST_CONFIG = {
    'seq_len': 30,                # sequence window length (minutes)
    'min_alpha_abs': 0.3,         # lowered threshold so rules can also contribute
    'cooldown_minutes': 8,        # minimum gap between alerts per concept
    'max_alerts_per_minute': 20,  # cap on alerts emitted per minute
    'clip_value': 10.0,           # features clipped to [-clip_value, clip_value]
}
|
||||
|
||||
|
||||
# ==================== 数据加载 ====================
|
||||
|
||||
def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]:
    """Load one trading day's feature parquet, or None if the file is absent."""
    candidate = Path(data_dir) / f"features_{date}.parquet"
    return pd.read_parquet(candidate) if candidate.exists() else None
|
||||
|
||||
|
||||
def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]:
    """List feature-file dates within [start_date, end_date], sorted ascending."""
    stems = (
        p.stem.replace('features_', '')
        for p in sorted(Path(data_dir).glob("features_*.parquet"))
    )
    return [d for d in stems if start_date <= d <= end_date]
|
||||
|
||||
|
||||
# ==================== 融合回测 ====================
|
||||
|
||||
def backtest_single_day_hybrid(
    detector: HybridAnomalyDetector,
    df: pd.DataFrame,
    date: str,
    seq_len: int = 30
) -> List[Dict]:
    """Backtest one day with the hybrid detector (batched for speed).

    For every minute from seq_len onward, collects all concepts whose
    trailing seq_len-minute window is complete and passes the cooldown and
    min-alpha filters, scores them in one ML batch, fuses with the rule
    score, and keeps at most max_alerts_per_minute alerts per minute.

    Args:
        detector: hybrid detector exposing ml_scorer, rule_scorer and config.
        df: one day's per-minute features with at least 'concept_id',
            'timestamp' and the FEATURES columns.
        date: trading date string, copied into each alert.
        seq_len: window length (minutes) required before a concept is scored.

    Returns:
        List of alert dicts ordered by minute, then by final_score desc.
    """
    alerts = []

    # Group once up-front; groupby(sort=False) keeps first-seen concept order.
    grouped_dict = {cid: cdf for cid, cdf in df.groupby('concept_id', sort=False)}

    # Per-concept timestamp of the last emitted alert (cooldown tracking).
    cooldown = {}

    # All distinct minutes of the day, ascending.
    all_timestamps = sorted(df['timestamp'].unique())

    if len(all_timestamps) < seq_len:
        return alerts

    # Detect at every minute that has a full trailing window.
    for t_idx in range(seq_len - 1, len(all_timestamps)):
        current_time = all_timestamps[t_idx]
        window_start_time = all_timestamps[t_idx - seq_len + 1]

        # Candidate buffers for this minute's ML batch.
        batch_sequences = []
        batch_features = []
        batch_infos = []

        for concept_id, concept_df in grouped_dict.items():
            # Cooldown filter (applied early, before any window slicing).
            if concept_id in cooldown:
                last_alert = cooldown[concept_id]
                if isinstance(current_time, datetime):
                    time_diff = (current_time - last_alert).total_seconds() / 60
                else:
                    # Non-datetime timestamps can't be compared in minutes;
                    # treat the cooldown as expired.
                    time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1
                if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
                    continue

            # Slice this concept's rows inside the current window.
            mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
            window_df = concept_df.loc[mask]

            # Require a complete window.
            if len(window_df) < seq_len:
                continue

            window_df = window_df.sort_values('timestamp').tail(seq_len)

            # Features at the current minute.
            current_row = window_df.iloc[-1]
            alpha = current_row.get('alpha', 0)

            # Skip negligible moves (early filter before expensive scoring).
            if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']:
                continue

            # Build the model input: sanitize NaN/inf, then clip outliers.
            sequence = window_df[FEATURES].values
            sequence = np.nan_to_num(sequence, nan=0.0, posinf=0.0, neginf=0.0)
            sequence = np.clip(sequence, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])

            current_features = {
                'alpha': alpha,
                'alpha_delta': current_row.get('alpha_delta', 0),
                'amt_ratio': current_row.get('amt_ratio', 1),
                'amt_delta': current_row.get('amt_delta', 0),
                'rank_pct': current_row.get('rank_pct', 0.5),
                'limit_up_ratio': current_row.get('limit_up_ratio', 0),
            }

            batch_sequences.append(sequence)
            batch_features.append(current_features)
            batch_infos.append({
                'concept_id': concept_id,
                'stock_count': current_row.get('stock_count', 0),
                'total_amt': current_row.get('total_amt', 0),
            })

        if not batch_sequences:
            continue

        # Batched ML inference for all candidates of this minute.
        sequences_array = np.array(batch_sequences)
        ml_scores = detector.ml_scorer.score(sequences_array) if detector.ml_scorer.is_ready() else [0.0] * len(batch_sequences)
        # A single-candidate batch may come back as a bare float.
        if isinstance(ml_scores, float):
            ml_scores = [ml_scores]

        # Rule scoring + fusion for each candidate.
        minute_alerts = []
        for i, (features, info) in enumerate(zip(batch_features, batch_infos)):
            concept_id = info['concept_id']

            # Rule-based score and the per-rule breakdown.
            rule_score, rule_details = detector.rule_scorer.score(features)

            # ML score (defensive index guard against a short result list).
            ml_score = ml_scores[i] if i < len(ml_scores) else 0.0

            # Weighted fusion of the two scores.
            w1 = detector.config['rule_weight']
            w2 = detector.config['ml_weight']
            final_score = w1 * rule_score + w2 * ml_score

            # Trigger logic: strong rule signal, strong ML signal, or fusion.
            is_anomaly = False
            trigger_reason = ''

            if rule_score >= detector.config['rule_trigger']:
                is_anomaly = True
                trigger_reason = f'规则强信号({rule_score:.0f}分)'
            elif ml_score >= detector.config['ml_trigger']:
                is_anomaly = True
                trigger_reason = f'ML强信号({ml_score:.0f}分)'
            elif final_score >= detector.config['fusion_trigger']:
                is_anomaly = True
                trigger_reason = f'融合触发({final_score:.0f}分)'

            if not is_anomaly:
                continue

            # Classify the anomaly by alpha direction / volume spike.
            alpha = features.get('alpha', 0)
            if alpha >= 1.5:
                anomaly_type = 'surge_up'
            elif alpha <= -1.5:
                anomaly_type = 'surge_down'
            elif features.get('amt_ratio', 1) >= 3.0:
                anomaly_type = 'volume_spike'
            else:
                anomaly_type = 'unknown'

            alert = {
                'concept_id': concept_id,
                'alert_time': current_time,
                'trade_date': date,
                'alert_type': anomaly_type,
                'final_score': final_score,
                'rule_score': rule_score,
                'ml_score': ml_score,
                'trigger_reason': trigger_reason,
                'triggered_rules': list(rule_details.keys()),
                **features,
                **info,
            }

            minute_alerts.append(alert)
            # Start this concept's cooldown window.
            cooldown[concept_id] = current_time

        # Keep only the top-N alerts of the minute by fused score.
        minute_alerts.sort(key=lambda x: x['final_score'], reverse=True)
        alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']])

    return alerts
|
||||
|
||||
|
||||
# ==================== 数据库写入 ====================
|
||||
|
||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Persist hybrid alerts into concept_minute_alert.

    Args:
        alerts: alert dicts produced by backtest_single_day_hybrid.
        dry_run: if True, report what would be written and return
            len(alerts) without touching the database.

    Returns:
        Number of rows inserted. Rows whose (concept_id, alert_time,
        trade_date) already exist are skipped via a SELECT check.

    NOTE(review): SELECT-then-INSERT is race-prone under concurrent
    writers; a unique key plus INSERT IGNORE (as used by the fast
    script's concept_anomaly_hybrid table) would be safer — confirm
    the schema of concept_minute_alert before changing this.
    """
    if not alerts:
        return 0

    if dry_run:
        print(f" [Dry Run] 将写入 {len(alerts)} 条异动")
        return len(alerts)

    # Fix: both statements are loop-invariant — build them once instead of
    # recompiling text() for every alert.
    check_sql = text("""
        SELECT id FROM concept_minute_alert
        WHERE concept_id = :concept_id
          AND alert_time = :alert_time
          AND trade_date = :trade_date
    """)
    insert_sql = text("""
        INSERT INTO concept_minute_alert
        (concept_id, concept_name, alert_time, alert_type, trade_date,
         change_pct, zscore, importance_score, stock_count, extra_info)
        VALUES
        (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date,
         :change_pct, :zscore, :importance_score, :stock_count, :extra_info)
    """)

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                # Skip rows that already exist for this concept/time/date.
                exists = conn.execute(check_sql, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                }).fetchone()

                if exists:
                    continue

                # Hybrid-specific details go into the JSON extra_info column.
                extra_info = {
                    'detection_method': 'hybrid',
                    'final_score': alert['final_score'],
                    'rule_score': alert['rule_score'],
                    'ml_score': alert['ml_score'],
                    'trigger_reason': alert['trigger_reason'],
                    'triggered_rules': alert['triggered_rules'],
                    'alpha': alert.get('alpha', 0),
                    'alpha_delta': alert.get('alpha_delta', 0),
                    'amt_ratio': alert.get('amt_ratio', 1),
                }

                conn.execute(insert_sql, {
                    'concept_id': alert['concept_id'],
                    'concept_name': alert.get('concept_name', ''),
                    'alert_time': alert['alert_time'],
                    'alert_type': alert['alert_type'],
                    'trade_date': alert['trade_date'],
                    'change_pct': alert.get('alpha', 0),
                    'zscore': alert['final_score'],  # fused score reused as zscore
                    'importance_score': alert['final_score'],
                    'stock_count': alert.get('stock_count', 0),
                    'extra_info': json.dumps(extra_info, ensure_ascii=False)
                })

                saved += 1

            except Exception as e:
                # Best-effort: log and continue with the remaining alerts.
                print(f" 保存失败: {alert['concept_id']} - {e}")

    return saved
|
||||
|
||||
|
||||
def export_alerts_to_csv(alerts: List[Dict], output_path: str):
    """Dump alerts to a CSV file (UTF-8 with BOM); silently skips an empty list."""
    if alerts:
        pd.DataFrame(alerts).to_csv(output_path, index=False, encoding='utf-8-sig')
        print(f"已导出到: {output_path}")
|
||||
|
||||
|
||||
# ==================== 统计分析 ====================
|
||||
|
||||
def analyze_alerts(alerts: List[Dict]):
    """Print alert statistics: counts, types, scores, trigger sources, top rules."""
    if not alerts:
        print("无异动数据")
        return

    df = pd.DataFrame(alerts)

    print("\n" + "=" * 60)
    print("异动统计分析")
    print("=" * 60)

    # 1. Totals.
    print(f"\n总异动数: {len(alerts)}")

    # 2. Type distribution.
    print(f"\n异动类型分布:")
    print(df['alert_type'].value_counts())

    # 3. Score statistics.
    print(f"\n得分统计:")
    print(f" 最终得分 - Mean: {df['final_score'].mean():.1f}, Max: {df['final_score'].max():.1f}")
    print(f" 规则得分 - Mean: {df['rule_score'].mean():.1f}, Max: {df['rule_score'].max():.1f}")
    print(f" ML得分 - Mean: {df['ml_score'].mean():.1f}, Max: {df['ml_score'].max():.1f}")

    # 4. Trigger-source breakdown.
    def _source(reason: str) -> str:
        if '规则' in reason:
            return '规则'
        if 'ML' in reason:
            return 'ML'
        return '融合'

    print(f"\n触发来源分析:")
    trigger_counts = df['trigger_reason'].apply(_source).value_counts()
    print(trigger_counts)

    # 5. Most frequently triggered rules.
    all_rules = [
        rule
        for rules in df['triggered_rules']
        if isinstance(rules, list)
        for rule in rules
    ]

    if all_rules:
        print(f"\n最常触发的规则 (Top 10):")
        from collections import Counter
        for rule, count in Counter(all_rules).most_common(10):
            print(f" {rule}: {count}")
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
    """CLI entry point: run the hybrid (rules + LSTM) backtest over a date range.

    Flow: parse args -> build the hybrid detector with the requested fusion
    weights -> iterate available trading days -> persist each day's alerts
    (honoring --dry-run) -> optional CSV export -> print summary statistics.
    """
    parser = argparse.ArgumentParser(description='融合异动回测')
    parser.add_argument('--data_dir', type=str, default='ml/data',
                        help='特征数据目录')
    parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
                        help='模型检查点目录')
    parser.add_argument('--start', type=str, required=True,
                        help='开始日期 (YYYY-MM-DD)')
    parser.add_argument('--end', type=str, default=None,
                        help='结束日期 (YYYY-MM-DD),默认=start')
    parser.add_argument('--dry-run', action='store_true',
                        help='只计算,不写入数据库')
    parser.add_argument('--export-csv', type=str, default=None,
                        help='导出 CSV 文件路径')
    parser.add_argument('--rule-weight', type=float, default=0.6,
                        help='规则权重 (0-1)')
    parser.add_argument('--ml-weight', type=float, default=0.4,
                        help='ML权重 (0-1)')
    parser.add_argument('--device', type=str, default='cuda',
                        help='设备 (cuda/cpu),默认 cuda')

    args = parser.parse_args()

    # Single-day run when --end is omitted.
    if args.end is None:
        args.end = args.start

    print("=" * 60)
    print("融合异动回测 (规则 + LSTM)")
    print("=" * 60)
    print(f"日期范围: {args.start} ~ {args.end}")
    print(f"数据目录: {args.data_dir}")
    print(f"模型目录: {args.checkpoint_dir}")
    print(f"规则权重: {args.rule_weight}")
    print(f"ML权重: {args.ml_weight}")
    print(f"设备: {args.device}")
    print(f"Dry Run: {args.dry_run}")
    print("=" * 60)

    # Fusion weights passed to the hybrid detector.
    config = {
        'rule_weight': args.rule_weight,
        'ml_weight': args.ml_weight,
    }

    # Local import so the detector module is loaded with the device argument
    # available (module is already imported at file top as well).
    from detector import HybridAnomalyDetector
    detector = HybridAnomalyDetector(config, args.checkpoint_dir, device=args.device)

    # Collect trading days that actually have feature files.
    dates = get_available_dates(args.data_dir, args.start, args.end)

    if not dates:
        print(f"未找到 {args.start} ~ {args.end} 范围内的数据")
        return

    print(f"\n找到 {len(dates)} 天的数据")

    # Backtest loop.
    all_alerts = []
    total_saved = 0

    for date in tqdm(dates, desc="回测进度"):
        df = load_daily_features(args.data_dir, date)

        if df is None or df.empty:
            continue

        alerts = backtest_single_day_hybrid(
            detector, df, date,
            seq_len=BACKTEST_CONFIG['seq_len']
        )

        if alerts:
            all_alerts.extend(alerts)

            # save_alerts_to_mysql itself honors dry_run.
            saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
            total_saved += saved

            if not args.dry_run:
                tqdm.write(f" {date}: 检测到 {len(alerts)} 个异动,保存 {saved} 条")

    # Optional CSV export.
    if args.export_csv and all_alerts:
        export_alerts_to_csv(all_alerts, args.export_csv)

    # Summary statistics.
    analyze_alerts(all_alerts)

    # Final report.
    print("\n" + "=" * 60)
    print("回测完成!")
    print("=" * 60)
    print(f"总计检测到: {len(all_alerts)} 个异动")
    print(f"保存到数据库: {total_saved} 条")
    print("=" * 60)
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
294
ml/backtest_v2.py
Normal file
294
ml/backtest_v2.py
Normal file
@@ -0,0 +1,294 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
V2 回测脚本 - 验证时间片对齐 + 持续性确认的效果
|
||||
|
||||
回测指标:
|
||||
1. 准确率:异动后 N 分钟内 alpha 是否继续上涨/下跌
|
||||
2. 虚警率:多少异动是噪音
|
||||
3. 持续性:平均异动持续时长
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from ml.detector_v2 import AnomalyDetectorV2, CONFIG
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
# Shared SQLAlchemy engine for result persistence.
# SECURITY NOTE(review): database credentials are hardcoded in source;
# move them to an environment variable or config file before sharing
# or committing this code.
MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)
|
||||
|
||||
|
||||
# ==================== 回测评估 ====================
|
||||
|
||||
def evaluate_alerts(
    alerts: List[Dict],
    raw_data: pd.DataFrame,
    lookahead_minutes: int = 10
) -> Dict:
    """Evaluate alert quality over a forward-looking window.

    Metrics:
    1. accuracy: fraction of alerts whose future mean alpha keeps the same
       sign as the alert's alpha.
    2. sustained_rate: average fraction of future minutes that stay on the
       alert's side of zero.
    3. avg_peak: mean absolute extreme alpha reached in the window.

    Bug fix: the early-return dicts now include 'evaluated_alerts', which
    the caller indexes unconditionally (previously raised KeyError on days
    with no alerts or no evaluable alerts).

    Args:
        alerts: alert dicts with 'concept_id', 'alert_time', 'alpha'.
        raw_data: per-minute features with 'concept_id', 'timestamp', 'alpha'.
        lookahead_minutes: number of future rows to inspect per alert.

    Returns:
        Dict of metric name -> value; 'evaluated_alerts' counts alerts with
        at least 3 future data points.
    """
    # Shared shape for both "nothing to evaluate" exits.
    empty_stats = {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0,
                   'total_alerts': 0, 'evaluated_alerts': 0}

    if not alerts:
        return empty_stats

    results = []

    for alert in alerts:
        concept_id = alert['concept_id']
        alert_time = alert['alert_time']
        alert_alpha = alert['alpha']
        is_up = alert_alpha > 0

        # Future rows of this concept after the alert.
        # NOTE(review): assumes raw_data rows are time-ordered within each
        # concept; otherwise head() does not yield the earliest minutes.
        concept_data = raw_data[
            (raw_data['concept_id'] == concept_id) &
            (raw_data['timestamp'] > alert_time)
        ].head(lookahead_minutes)

        # Need a minimum of future observations to judge the alert.
        if len(concept_data) < 3:
            continue

        future_alphas = concept_data['alpha'].values

        # Direction correct: future mean alpha shares the alert's sign.
        avg_future_alpha = np.mean(future_alphas)
        direction_correct = (is_up and avg_future_alpha > 0) or (not is_up and avg_future_alpha < 0)

        # Sustained rate: share of future minutes on the alert's side of zero.
        if is_up:
            sustained_count = sum(1 for a in future_alphas if a > 0)
        else:
            sustained_count = sum(1 for a in future_alphas if a < 0)
        sustained_rate = sustained_count / len(future_alphas)

        # Peak move in the alert's direction.
        if is_up:
            peak = max(future_alphas)
        else:
            peak = min(future_alphas)

        results.append({
            'direction_correct': direction_correct,
            'sustained_rate': sustained_rate,
            'peak': peak,
            'alert_alpha': alert_alpha,
        })

    if not results:
        return empty_stats

    return {
        'accuracy': np.mean([r['direction_correct'] for r in results]),
        'sustained_rate': np.mean([r['sustained_rate'] for r in results]),
        'avg_peak': np.mean([abs(r['peak']) for r in results]),
        'total_alerts': len(alerts),
        'evaluated_alerts': len(results),
    }
|
||||
|
||||
|
||||
def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int:
    """Persist V2 alerts into concept_anomaly_v2 (table created on demand).

    Args:
        alerts: alert dicts from the V2 detector.
        dry_run: if True, nothing is written and 0 is returned.

    Returns:
        Number of rows actually inserted. Duplicates on the unique key
        (concept_id, alert_time, trade_date) are ignored by INSERT IGNORE
        and — fix — no longer counted as saved (previously every execute
        bumped the counter even when the row was ignored).
    """
    if not alerts or dry_run:
        return 0

    # Ensure the target table exists (idempotent via IF NOT EXISTS).
    with MYSQL_ENGINE.begin() as conn:
        conn.execute(text("""
            CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
                id INT AUTO_INCREMENT PRIMARY KEY,
                concept_id VARCHAR(64) NOT NULL,
                alert_time DATETIME NOT NULL,
                trade_date DATE NOT NULL,
                alert_type VARCHAR(32) NOT NULL,
                final_score FLOAT NOT NULL,
                rule_score FLOAT NOT NULL,
                ml_score FLOAT NOT NULL,
                trigger_reason VARCHAR(128),
                confirm_ratio FLOAT,
                alpha FLOAT,
                alpha_zscore FLOAT,
                amt_zscore FLOAT,
                rank_zscore FLOAT,
                momentum_3m FLOAT,
                momentum_5m FLOAT,
                limit_up_ratio FLOAT,
                triggered_rules JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date),
                INDEX idx_trade_date (trade_date),
                INDEX idx_final_score (final_score)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动 V2(时间片对齐+持续确认)'
        """))

    # Fix: the INSERT statement is loop-invariant — build it once.
    insert_sql = text("""
        INSERT IGNORE INTO concept_anomaly_v2
        (concept_id, alert_time, trade_date, alert_type,
         final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
         alpha, alpha_zscore, amt_zscore, rank_zscore,
         momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
        VALUES
        (:concept_id, :alert_time, :trade_date, :alert_type,
         :final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
         :alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
         :momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
    """)

    saved = 0
    with MYSQL_ENGINE.begin() as conn:
        for alert in alerts:
            try:
                result = conn.execute(insert_sql, {
                    'concept_id': alert['concept_id'],
                    'alert_time': alert['alert_time'],
                    'trade_date': alert['trade_date'],
                    'alert_type': alert['alert_type'],
                    'final_score': alert['final_score'],
                    'rule_score': alert['rule_score'],
                    'ml_score': alert['ml_score'],
                    'trigger_reason': alert['trigger_reason'],
                    'confirm_ratio': alert.get('confirm_ratio', 0),
                    'alpha': alert['alpha'],
                    'alpha_zscore': alert.get('alpha_zscore', 0),
                    'amt_zscore': alert.get('amt_zscore', 0),
                    'rank_zscore': alert.get('rank_zscore', 0),
                    'momentum_3m': alert.get('momentum_3m', 0),
                    'momentum_5m': alert.get('momentum_5m', 0),
                    'limit_up_ratio': alert.get('limit_up_ratio', 0),
                    'triggered_rules': json.dumps(alert.get('triggered_rules', [])),
                })
                # INSERT IGNORE reports rowcount 0 when the unique key matched.
                if result.rowcount > 0:
                    saved += 1
            except Exception as e:
                # Best-effort: log and continue with the remaining alerts.
                print(f"保存失败: {e}")

    return saved
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='V2 回测')
|
||||
parser.add_argument('--start', type=str, required=True, help='开始日期')
|
||||
parser.add_argument('--end', type=str, default=None, help='结束日期')
|
||||
parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
|
||||
parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
|
||||
parser.add_argument('--save', action='store_true', help='保存到数据库')
|
||||
parser.add_argument('--lookahead', type=int, default=10, help='评估前瞻时间(分钟)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
end_date = args.end or args.start
|
||||
|
||||
print("=" * 60)
|
||||
print("V2 回测 - 时间片对齐 + 持续性确认")
|
||||
print("=" * 60)
|
||||
print(f"日期范围: {args.start} ~ {end_date}")
|
||||
print(f"模型目录: {args.model_dir}")
|
||||
print(f"评估前瞻: {args.lookahead} 分钟")
|
||||
|
||||
# 初始化检测器
|
||||
detector = AnomalyDetectorV2(
|
||||
model_dir=args.model_dir,
|
||||
baseline_dir=args.baseline_dir
|
||||
)
|
||||
|
||||
# 获取交易日
|
||||
from prepare_data_v2 import get_trading_days
|
||||
trading_days = get_trading_days(args.start, end_date)
|
||||
|
||||
if not trading_days:
|
||||
print("无交易日")
|
||||
return
|
||||
|
||||
print(f"交易日数: {len(trading_days)}")
|
||||
|
||||
# 回测统计
|
||||
total_stats = {
|
||||
'total_alerts': 0,
|
||||
'accuracy_sum': 0,
|
||||
'sustained_sum': 0,
|
||||
'peak_sum': 0,
|
||||
'day_count': 0,
|
||||
}
|
||||
|
||||
all_alerts = []
|
||||
|
||||
for trade_date in tqdm(trading_days, desc="回测进度"):
|
||||
# 检测异动
|
||||
alerts = detector.detect(trade_date)
|
||||
|
||||
if not alerts:
|
||||
continue
|
||||
|
||||
all_alerts.extend(alerts)
|
||||
|
||||
# 评估
|
||||
raw_data = detector._compute_raw_features(trade_date)
|
||||
if raw_data.empty:
|
||||
continue
|
||||
|
||||
stats = evaluate_alerts(alerts, raw_data, args.lookahead)
|
||||
|
||||
if stats['evaluated_alerts'] > 0:
|
||||
total_stats['total_alerts'] += stats['total_alerts']
|
||||
total_stats['accuracy_sum'] += stats['accuracy'] * stats['evaluated_alerts']
|
||||
total_stats['sustained_sum'] += stats['sustained_rate'] * stats['evaluated_alerts']
|
||||
total_stats['peak_sum'] += stats['avg_peak'] * stats['evaluated_alerts']
|
||||
total_stats['day_count'] += 1
|
||||
|
||||
print(f"\n[{trade_date}] 异动: {stats['total_alerts']}, "
|
||||
f"准确率: {stats['accuracy']:.1%}, "
|
||||
f"持续率: {stats['sustained_rate']:.1%}, "
|
||||
f"峰值: {stats['avg_peak']:.2f}%")
|
||||
|
||||
# 汇总
|
||||
print("\n" + "=" * 60)
|
||||
print("回测汇总")
|
||||
print("=" * 60)
|
||||
|
||||
if total_stats['total_alerts'] > 0:
|
||||
avg_accuracy = total_stats['accuracy_sum'] / total_stats['total_alerts']
|
||||
avg_sustained = total_stats['sustained_sum'] / total_stats['total_alerts']
|
||||
avg_peak = total_stats['peak_sum'] / total_stats['total_alerts']
|
||||
|
||||
print(f"总异动数: {total_stats['total_alerts']}")
|
||||
print(f"回测天数: {total_stats['day_count']}")
|
||||
print(f"平均每天: {total_stats['total_alerts'] / max(1, total_stats['day_count']):.1f} 个")
|
||||
print(f"方向准确率: {avg_accuracy:.1%}")
|
||||
print(f"持续率: {avg_sustained:.1%}")
|
||||
print(f"平均峰值: {avg_peak:.2f}%")
|
||||
else:
|
||||
print("无异动检测结果")
|
||||
|
||||
# 保存
|
||||
if args.save and all_alerts:
|
||||
print(f"\n保存 {len(all_alerts)} 条异动到数据库...")
|
||||
saved = save_alerts_to_mysql(all_alerts)
|
||||
print(f"保存完成: {saved} 条")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
31
ml/checkpoints_v2/config.json
Normal file
31
ml/checkpoints_v2/config.json
Normal file
@@ -0,0 +1,31 @@
|
||||
{
|
||||
"seq_len": 10,
|
||||
"stride": 2,
|
||||
"train_end_date": "2025-06-30",
|
||||
"val_end_date": "2025-09-30",
|
||||
"features": [
|
||||
"alpha_zscore",
|
||||
"amt_zscore",
|
||||
"rank_zscore",
|
||||
"momentum_3m",
|
||||
"momentum_5m",
|
||||
"limit_up_ratio"
|
||||
],
|
||||
"batch_size": 32768,
|
||||
"epochs": 150,
|
||||
"learning_rate": 0.0006,
|
||||
"weight_decay": 1e-05,
|
||||
"gradient_clip": 1.0,
|
||||
"patience": 15,
|
||||
"min_delta": 1e-06,
|
||||
"model": {
|
||||
"n_features": 6,
|
||||
"hidden_dim": 32,
|
||||
"latent_dim": 4,
|
||||
"num_layers": 1,
|
||||
"dropout": 0.2,
|
||||
"bidirectional": true
|
||||
},
|
||||
"clip_value": 5.0,
|
||||
"threshold_percentiles": [90, 95, 99]
|
||||
}
|
||||
8
ml/checkpoints_v2/thresholds.json
Normal file
8
ml/checkpoints_v2/thresholds.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"p90": 0.15,
|
||||
"p95": 0.25,
|
||||
"p99": 0.50,
|
||||
"mean": 0.08,
|
||||
"std": 0.12,
|
||||
"median": 0.06
|
||||
}
|
||||
635
ml/detector.py
Normal file
635
ml/detector.py
Normal file
@@ -0,0 +1,635 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
概念异动检测器 - 融合版
|
||||
|
||||
结合两种方法的优势:
|
||||
1. 规则评分系统:可解释、稳定、覆盖已知模式
|
||||
2. LSTM Autoencoder:发现未知的异常模式
|
||||
|
||||
融合策略:
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ 输入特征 │
|
||||
│ (alpha, alpha_delta, amt_ratio, amt_delta, rank_pct, │
|
||||
│ limit_up_ratio) │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌──────────────┐ ┌──────────────┐ │
|
||||
│ │ 规则评分系统 │ │ LSTM Autoencoder │ │
|
||||
│ │ (0-100分) │ │ (重构误差) │ │
|
||||
│ └──────┬───────┘ └──────┬───────┘ │
|
||||
│ │ │ │
|
||||
│ ▼ ▼ │
|
||||
│ rule_score (0-100) ml_score (标准化后 0-100) │
|
||||
│ │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ 融合策略 │
|
||||
│ │
|
||||
│ final_score = w1 * rule_score + w2 * ml_score │
|
||||
│ │
|
||||
│ 异动判定: │
|
||||
│ - rule_score >= 60 → 直接触发(规则强信号) │
|
||||
│ - ml_score >= 80 → 直接触发(ML强信号) │
|
||||
│ - final_score >= 50 → 融合触发 │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
|
||||
优势:
|
||||
- 规则系统保证已知模式的检出率
|
||||
- ML模型捕捉规则未覆盖的异常
|
||||
- 两者互相验证,减少误报
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# 尝试导入模型(可能不存在)
|
||||
try:
|
||||
from model import LSTMAutoencoder, create_model
|
||||
HAS_MODEL = True
|
||||
except ImportError:
|
||||
HAS_MODEL = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnomalyResult:
|
||||
"""异动检测结果"""
|
||||
is_anomaly: bool
|
||||
final_score: float # 最终得分 (0-100)
|
||||
rule_score: float # 规则得分 (0-100)
|
||||
ml_score: float # ML得分 (0-100)
|
||||
trigger_reason: str # 触发原因
|
||||
rule_details: Dict # 规则明细
|
||||
anomaly_type: str # 异动类型: surge_up / surge_down / volume_spike / unknown
|
||||
|
||||
|
||||
class RuleBasedScorer:
|
||||
"""
|
||||
基于规则的评分系统
|
||||
|
||||
设计原则:
|
||||
- 每个规则独立打分
|
||||
- 分数可叠加
|
||||
- 阈值可配置
|
||||
"""
|
||||
|
||||
# 默认规则配置
|
||||
DEFAULT_RULES = {
|
||||
# Alpha 相关(超额收益)
|
||||
'alpha_strong': {
|
||||
'condition': lambda r: abs(r.get('alpha', 0)) >= 3.0,
|
||||
'score': 35,
|
||||
'description': 'Alpha强信号(|α|≥3%)'
|
||||
},
|
||||
'alpha_medium': {
|
||||
'condition': lambda r: 2.0 <= abs(r.get('alpha', 0)) < 3.0,
|
||||
'score': 25,
|
||||
'description': 'Alpha中等(2%≤|α|<3%)'
|
||||
},
|
||||
'alpha_weak': {
|
||||
'condition': lambda r: 1.5 <= abs(r.get('alpha', 0)) < 2.0,
|
||||
'score': 15,
|
||||
'description': 'Alpha轻微(1.5%≤|α|<2%)'
|
||||
},
|
||||
|
||||
# Alpha 变化率(加速度)
|
||||
'alpha_delta_strong': {
|
||||
'condition': lambda r: abs(r.get('alpha_delta', 0)) >= 1.0,
|
||||
'score': 30,
|
||||
'description': 'Alpha加速强(|Δα|≥1%)'
|
||||
},
|
||||
'alpha_delta_medium': {
|
||||
'condition': lambda r: 0.5 <= abs(r.get('alpha_delta', 0)) < 1.0,
|
||||
'score': 20,
|
||||
'description': 'Alpha加速中(0.5%≤|Δα|<1%)'
|
||||
},
|
||||
|
||||
# 成交额比率(放量)
|
||||
'volume_spike_strong': {
|
||||
'condition': lambda r: r.get('amt_ratio', 1) >= 5.0,
|
||||
'score': 30,
|
||||
'description': '极度放量(≥5倍)'
|
||||
},
|
||||
'volume_spike_medium': {
|
||||
'condition': lambda r: 3.0 <= r.get('amt_ratio', 1) < 5.0,
|
||||
'score': 20,
|
||||
'description': '显著放量(3-5倍)'
|
||||
},
|
||||
'volume_spike_weak': {
|
||||
'condition': lambda r: 2.0 <= r.get('amt_ratio', 1) < 3.0,
|
||||
'score': 10,
|
||||
'description': '轻微放量(2-3倍)'
|
||||
},
|
||||
|
||||
# 成交额变化率
|
||||
'amt_delta_strong': {
|
||||
'condition': lambda r: abs(r.get('amt_delta', 0)) >= 1.0,
|
||||
'score': 15,
|
||||
'description': '成交额急变(|Δamt|≥100%)'
|
||||
},
|
||||
|
||||
# 排名跳变
|
||||
'rank_top': {
|
||||
'condition': lambda r: r.get('rank_pct', 0.5) >= 0.95,
|
||||
'score': 25,
|
||||
'description': '排名前5%'
|
||||
},
|
||||
'rank_bottom': {
|
||||
'condition': lambda r: r.get('rank_pct', 0.5) <= 0.05,
|
||||
'score': 25,
|
||||
'description': '排名后5%'
|
||||
},
|
||||
'rank_high': {
|
||||
'condition': lambda r: 0.9 <= r.get('rank_pct', 0.5) < 0.95,
|
||||
'score': 15,
|
||||
'description': '排名前10%'
|
||||
},
|
||||
|
||||
# 涨停比例
|
||||
'limit_up_high': {
|
||||
'condition': lambda r: r.get('limit_up_ratio', 0) >= 0.2,
|
||||
'score': 25,
|
||||
'description': '涨停比例≥20%'
|
||||
},
|
||||
'limit_up_medium': {
|
||||
'condition': lambda r: 0.1 <= r.get('limit_up_ratio', 0) < 0.2,
|
||||
'score': 15,
|
||||
'description': '涨停比例10-20%'
|
||||
},
|
||||
|
||||
# 组合条件(更可靠的信号)
|
||||
'alpha_with_volume': {
|
||||
'condition': lambda r: abs(r.get('alpha', 0)) >= 1.5 and r.get('amt_ratio', 1) >= 2.0,
|
||||
'score': 20, # 额外加分
|
||||
'description': 'Alpha+放量组合'
|
||||
},
|
||||
'acceleration_with_rank': {
|
||||
'condition': lambda r: abs(r.get('alpha_delta', 0)) >= 0.5 and (r.get('rank_pct', 0.5) >= 0.9 or r.get('rank_pct', 0.5) <= 0.1),
|
||||
'score': 15, # 额外加分
|
||||
'description': '加速+排名异常组合'
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, rules: Dict = None):
|
||||
"""
|
||||
初始化规则评分器
|
||||
|
||||
Args:
|
||||
rules: 自定义规则,格式同 DEFAULT_RULES
|
||||
"""
|
||||
self.rules = rules or self.DEFAULT_RULES
|
||||
|
||||
def score(self, features: Dict) -> Tuple[float, Dict]:
|
||||
"""
|
||||
计算规则得分
|
||||
|
||||
Args:
|
||||
features: 特征字典,包含 alpha, alpha_delta, amt_ratio 等
|
||||
Returns:
|
||||
score: 总分 (0-100)
|
||||
details: 触发的规则明细
|
||||
"""
|
||||
total_score = 0
|
||||
triggered_rules = {}
|
||||
|
||||
for rule_name, rule_config in self.rules.items():
|
||||
try:
|
||||
if rule_config['condition'](features):
|
||||
total_score += rule_config['score']
|
||||
triggered_rules[rule_name] = {
|
||||
'score': rule_config['score'],
|
||||
'description': rule_config['description']
|
||||
}
|
||||
except Exception:
|
||||
# 忽略规则计算错误
|
||||
pass
|
||||
|
||||
# 限制在 0-100
|
||||
total_score = min(100, max(0, total_score))
|
||||
|
||||
return total_score, triggered_rules
|
||||
|
||||
def get_anomaly_type(self, features: Dict) -> str:
|
||||
"""判断异动类型"""
|
||||
alpha = features.get('alpha', 0)
|
||||
amt_ratio = features.get('amt_ratio', 1)
|
||||
|
||||
if alpha >= 1.5:
|
||||
return 'surge_up'
|
||||
elif alpha <= -1.5:
|
||||
return 'surge_down'
|
||||
elif amt_ratio >= 3.0:
|
||||
return 'volume_spike'
|
||||
else:
|
||||
return 'unknown'
|
||||
|
||||
|
||||
class MLScorer:
|
||||
"""
|
||||
基于 LSTM Autoencoder 的评分器
|
||||
|
||||
将重构误差转换为 0-100 的分数
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_dir: str = 'ml/checkpoints',
|
||||
device: str = 'auto'
|
||||
):
|
||||
self.checkpoint_dir = Path(checkpoint_dir)
|
||||
|
||||
# 设备检测
|
||||
if device == 'auto':
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
elif device == 'cuda' and not torch.cuda.is_available():
|
||||
print("警告: CUDA 不可用,使用 CPU")
|
||||
self.device = torch.device('cpu')
|
||||
else:
|
||||
self.device = torch.device(device)
|
||||
|
||||
self.model = None
|
||||
self.thresholds = None
|
||||
self.config = None
|
||||
|
||||
# 尝试加载模型
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""加载模型和阈值"""
|
||||
if not HAS_MODEL:
|
||||
print("警告: 无法导入模型模块")
|
||||
return
|
||||
|
||||
model_path = self.checkpoint_dir / 'best_model.pt'
|
||||
thresholds_path = self.checkpoint_dir / 'thresholds.json'
|
||||
config_path = self.checkpoint_dir / 'config.json'
|
||||
|
||||
if not model_path.exists():
|
||||
print(f"警告: 模型文件不存在 {model_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
# 加载配置
|
||||
if config_path.exists():
|
||||
with open(config_path, 'r') as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
# 先用 CPU 加载模型(避免 CUDA 不可用问题),再移动到目标设备
|
||||
checkpoint = torch.load(model_path, map_location='cpu')
|
||||
|
||||
model_config = self.config.get('model', {}) if self.config else {}
|
||||
self.model = create_model(model_config)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
# 加载阈值
|
||||
if thresholds_path.exists():
|
||||
with open(thresholds_path, 'r') as f:
|
||||
self.thresholds = json.load(f)
|
||||
|
||||
print(f"MLScorer 加载成功 (设备: {self.device})")
|
||||
|
||||
except Exception as e:
|
||||
print(f"警告: 模型加载失败 - {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
self.model = None
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""检查模型是否就绪"""
|
||||
return self.model is not None
|
||||
|
||||
@torch.no_grad()
|
||||
def score(self, sequence: np.ndarray) -> float:
|
||||
"""
|
||||
计算 ML 得分
|
||||
|
||||
Args:
|
||||
sequence: (seq_len, n_features) 或 (batch, seq_len, n_features)
|
||||
Returns:
|
||||
score: 0-100 的分数,越高越异常
|
||||
"""
|
||||
if not self.is_ready():
|
||||
return 0.0
|
||||
|
||||
# 确保是 3D
|
||||
if sequence.ndim == 2:
|
||||
sequence = sequence[np.newaxis, ...]
|
||||
|
||||
# 转为 tensor
|
||||
x = torch.FloatTensor(sequence).to(self.device)
|
||||
|
||||
# 计算重构误差
|
||||
output, _ = self.model(x)
|
||||
mse = ((output - x) ** 2).mean(dim=-1) # (batch, seq_len)
|
||||
|
||||
# 取最后时刻的误差
|
||||
error = mse[:, -1].cpu().numpy()
|
||||
|
||||
# 转换为 0-100 分数
|
||||
# 使用 p95 阈值作为参考
|
||||
if self.thresholds:
|
||||
p95 = self.thresholds.get('p95', 0.1)
|
||||
p99 = self.thresholds.get('p99', 0.2)
|
||||
else:
|
||||
p95, p99 = 0.1, 0.2
|
||||
|
||||
# 线性映射:p95 -> 50分, p99 -> 80分
|
||||
# error=0 -> 0分, error>=p99*1.5 -> 100分
|
||||
score = np.clip(error / p95 * 50, 0, 100)
|
||||
|
||||
return float(score[0]) if len(score) == 1 else score.tolist()
|
||||
|
||||
|
||||
class HybridAnomalyDetector:
|
||||
"""
|
||||
融合异动检测器
|
||||
|
||||
结合规则系统和 ML 模型
|
||||
"""
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
# 权重配置
|
||||
'rule_weight': 0.6, # 规则权重
|
||||
'ml_weight': 0.4, # ML权重
|
||||
|
||||
# 触发阈值
|
||||
'rule_trigger': 60, # 规则直接触发阈值
|
||||
'ml_trigger': 80, # ML直接触发阈值
|
||||
'fusion_trigger': 50, # 融合触发阈值
|
||||
|
||||
# 特征列表
|
||||
'features': [
|
||||
'alpha', 'alpha_delta', 'amt_ratio',
|
||||
'amt_delta', 'rank_pct', 'limit_up_ratio'
|
||||
],
|
||||
|
||||
# 序列长度(ML模型需要)
|
||||
'seq_len': 30,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Dict = None,
|
||||
checkpoint_dir: str = 'ml/checkpoints',
|
||||
device: str = 'auto'
|
||||
):
|
||||
self.config = {**self.DEFAULT_CONFIG, **(config or {})}
|
||||
|
||||
# 初始化评分器
|
||||
self.rule_scorer = RuleBasedScorer()
|
||||
self.ml_scorer = MLScorer(checkpoint_dir, device)
|
||||
|
||||
print(f"HybridAnomalyDetector 初始化完成")
|
||||
print(f" 规则权重: {self.config['rule_weight']}")
|
||||
print(f" ML权重: {self.config['ml_weight']}")
|
||||
print(f" ML模型: {'就绪' if self.ml_scorer.is_ready() else '未加载'}")
|
||||
|
||||
def detect(
|
||||
self,
|
||||
features: Dict,
|
||||
sequence: np.ndarray = None
|
||||
) -> AnomalyResult:
|
||||
"""
|
||||
检测异动
|
||||
|
||||
Args:
|
||||
features: 当前时刻的特征字典
|
||||
sequence: 历史序列 (seq_len, n_features),ML模型需要
|
||||
Returns:
|
||||
AnomalyResult: 检测结果
|
||||
"""
|
||||
# 1. 规则评分
|
||||
rule_score, rule_details = self.rule_scorer.score(features)
|
||||
|
||||
# 2. ML评分
|
||||
ml_score = 0.0
|
||||
if sequence is not None and self.ml_scorer.is_ready():
|
||||
ml_score = self.ml_scorer.score(sequence)
|
||||
|
||||
# 3. 融合得分
|
||||
w1 = self.config['rule_weight']
|
||||
w2 = self.config['ml_weight']
|
||||
|
||||
# 如果ML不可用,全部权重给规则
|
||||
if not self.ml_scorer.is_ready():
|
||||
w1, w2 = 1.0, 0.0
|
||||
|
||||
final_score = w1 * rule_score + w2 * ml_score
|
||||
|
||||
# 4. 判断是否异动
|
||||
is_anomaly = False
|
||||
trigger_reason = ''
|
||||
|
||||
if rule_score >= self.config['rule_trigger']:
|
||||
is_anomaly = True
|
||||
trigger_reason = f'规则强信号({rule_score:.0f}分)'
|
||||
elif ml_score >= self.config['ml_trigger']:
|
||||
is_anomaly = True
|
||||
trigger_reason = f'ML强信号({ml_score:.0f}分)'
|
||||
elif final_score >= self.config['fusion_trigger']:
|
||||
is_anomaly = True
|
||||
trigger_reason = f'融合触发({final_score:.0f}分)'
|
||||
|
||||
# 5. 判断异动类型
|
||||
anomaly_type = self.rule_scorer.get_anomaly_type(features) if is_anomaly else ''
|
||||
|
||||
return AnomalyResult(
|
||||
is_anomaly=is_anomaly,
|
||||
final_score=final_score,
|
||||
rule_score=rule_score,
|
||||
ml_score=ml_score,
|
||||
trigger_reason=trigger_reason,
|
||||
rule_details=rule_details,
|
||||
anomaly_type=anomaly_type
|
||||
)
|
||||
|
||||
def detect_batch(
|
||||
self,
|
||||
features_list: List[Dict],
|
||||
sequences: np.ndarray = None
|
||||
) -> List[AnomalyResult]:
|
||||
"""
|
||||
批量检测
|
||||
|
||||
Args:
|
||||
features_list: 特征字典列表
|
||||
sequences: (batch, seq_len, n_features)
|
||||
Returns:
|
||||
List[AnomalyResult]
|
||||
"""
|
||||
results = []
|
||||
|
||||
for i, features in enumerate(features_list):
|
||||
seq = sequences[i] if sequences is not None else None
|
||||
result = self.detect(features, seq)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
def create_detector(
|
||||
checkpoint_dir: str = 'ml/checkpoints',
|
||||
config: Dict = None
|
||||
) -> HybridAnomalyDetector:
|
||||
"""创建融合检测器"""
|
||||
return HybridAnomalyDetector(config, checkpoint_dir)
|
||||
|
||||
|
||||
def quick_detect(features: Dict) -> bool:
|
||||
"""
|
||||
快速检测(只用规则,不需要ML模型)
|
||||
|
||||
适用于:
|
||||
- 实时检测
|
||||
- ML模型未训练完成时
|
||||
"""
|
||||
scorer = RuleBasedScorer()
|
||||
score, _ = scorer.score(features)
|
||||
return score >= 50
|
||||
|
||||
|
||||
# ==================== 测试 ====================
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("融合异动检测器测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建检测器
|
||||
detector = create_detector()
|
||||
|
||||
# 测试用例
|
||||
test_cases = [
|
||||
{
|
||||
'name': '正常情况',
|
||||
'features': {
|
||||
'alpha': 0.5,
|
||||
'alpha_delta': 0.1,
|
||||
'amt_ratio': 1.2,
|
||||
'amt_delta': 0.1,
|
||||
'rank_pct': 0.5,
|
||||
'limit_up_ratio': 0.02
|
||||
}
|
||||
},
|
||||
{
|
||||
'name': 'Alpha异动',
|
||||
'features': {
|
||||
'alpha': 3.5,
|
||||
'alpha_delta': 0.8,
|
||||
'amt_ratio': 2.5,
|
||||
'amt_delta': 0.5,
|
||||
'rank_pct': 0.92,
|
||||
'limit_up_ratio': 0.05
|
||||
}
|
||||
},
|
||||
{
|
||||
'name': '放量异动',
|
||||
'features': {
|
||||
'alpha': 1.2,
|
||||
'alpha_delta': 0.3,
|
||||
'amt_ratio': 6.0,
|
||||
'amt_delta': 1.5,
|
||||
'rank_pct': 0.85,
|
||||
'limit_up_ratio': 0.08
|
||||
}
|
||||
},
|
||||
{
|
||||
'name': '涨停潮',
|
||||
'features': {
|
||||
'alpha': 2.5,
|
||||
'alpha_delta': 0.6,
|
||||
'amt_ratio': 3.5,
|
||||
'amt_delta': 0.8,
|
||||
'rank_pct': 0.98,
|
||||
'limit_up_ratio': 0.25
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
print("\n" + "-" * 60)
|
||||
print("测试1: 只用规则(无序列数据)")
|
||||
print("-" * 60)
|
||||
|
||||
for case in test_cases:
|
||||
result = detector.detect(case['features'])
|
||||
|
||||
print(f"\n{case['name']}:")
|
||||
print(f" 异动: {'是' if result.is_anomaly else '否'}")
|
||||
print(f" 最终得分: {result.final_score:.1f}")
|
||||
print(f" 规则得分: {result.rule_score:.1f}")
|
||||
print(f" ML得分: {result.ml_score:.1f}")
|
||||
if result.is_anomaly:
|
||||
print(f" 触发原因: {result.trigger_reason}")
|
||||
print(f" 异动类型: {result.anomaly_type}")
|
||||
print(f" 触发规则: {list(result.rule_details.keys())}")
|
||||
|
||||
# 测试2: 带序列数据的融合检测
|
||||
print("\n" + "-" * 60)
|
||||
print("测试2: 融合检测(规则 + ML)")
|
||||
print("-" * 60)
|
||||
|
||||
# 生成模拟序列数据
|
||||
seq_len = 30
|
||||
n_features = 6
|
||||
|
||||
# 正常序列:小幅波动
|
||||
normal_sequence = np.random.randn(seq_len, n_features) * 0.3
|
||||
normal_sequence[:, 0] = np.linspace(0, 0.5, seq_len) # alpha 缓慢上升
|
||||
normal_sequence[:, 2] = np.abs(normal_sequence[:, 2]) + 1 # amt_ratio > 0
|
||||
|
||||
# 异常序列:最后几个时间步突然变化
|
||||
anomaly_sequence = np.random.randn(seq_len, n_features) * 0.3
|
||||
anomaly_sequence[-5:, 0] = np.linspace(1, 4, 5) # alpha 突然飙升
|
||||
anomaly_sequence[-5:, 1] = np.linspace(0.2, 1.5, 5) # alpha_delta 加速
|
||||
anomaly_sequence[-5:, 2] = np.linspace(2, 6, 5) # amt_ratio 放量
|
||||
anomaly_sequence[:, 2] = np.abs(anomaly_sequence[:, 2]) + 1
|
||||
|
||||
# 测试正常序列
|
||||
normal_features = {
|
||||
'alpha': float(normal_sequence[-1, 0]),
|
||||
'alpha_delta': float(normal_sequence[-1, 1]),
|
||||
'amt_ratio': float(normal_sequence[-1, 2]),
|
||||
'amt_delta': float(normal_sequence[-1, 3]),
|
||||
'rank_pct': 0.5,
|
||||
'limit_up_ratio': 0.02
|
||||
}
|
||||
|
||||
result = detector.detect(normal_features, normal_sequence)
|
||||
print(f"\n正常序列:")
|
||||
print(f" 异动: {'是' if result.is_anomaly else '否'}")
|
||||
print(f" 最终得分: {result.final_score:.1f}")
|
||||
print(f" 规则得分: {result.rule_score:.1f}")
|
||||
print(f" ML得分: {result.ml_score:.1f}")
|
||||
|
||||
# 测试异常序列
|
||||
anomaly_features = {
|
||||
'alpha': float(anomaly_sequence[-1, 0]),
|
||||
'alpha_delta': float(anomaly_sequence[-1, 1]),
|
||||
'amt_ratio': float(anomaly_sequence[-1, 2]),
|
||||
'amt_delta': float(anomaly_sequence[-1, 3]),
|
||||
'rank_pct': 0.95,
|
||||
'limit_up_ratio': 0.15
|
||||
}
|
||||
|
||||
result = detector.detect(anomaly_features, anomaly_sequence)
|
||||
print(f"\n异常序列:")
|
||||
print(f" 异动: {'是' if result.is_anomaly else '否'}")
|
||||
print(f" 最终得分: {result.final_score:.1f}")
|
||||
print(f" 规则得分: {result.rule_score:.1f}")
|
||||
print(f" ML得分: {result.ml_score:.1f}")
|
||||
if result.is_anomaly:
|
||||
print(f" 触发原因: {result.trigger_reason}")
|
||||
print(f" 异动类型: {result.anomaly_type}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试完成!")
|
||||
716
ml/detector_v2.py
Normal file
716
ml/detector_v2.py
Normal file
@@ -0,0 +1,716 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
异动检测器 V2 - 基于时间片对齐 + 持续性确认
|
||||
|
||||
核心改进:
|
||||
1. Z-Score 特征:相对于同时间片历史的偏离
|
||||
2. 短序列 LSTM:10分钟序列,开盘即可用
|
||||
3. 持续性确认:5分钟窗口内60%时刻超标才确认为异动
|
||||
|
||||
检测流程:
|
||||
1. 计算当前时刻的 Z-Score(对比同时间片历史基线)
|
||||
2. 构建最近10分钟的 Z-Score 序列
|
||||
3. LSTM 计算重构误差(ML分数)
|
||||
4. 规则评分(基于 Z-Score 的规则)
|
||||
5. 滑动窗口确认:最近5分钟内是否有足够多的时刻超标
|
||||
6. 只有通过持续性确认的才输出为异动
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from sqlalchemy import create_engine, text
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from ml.model import TransformerAutoencoder
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
MYSQL_ENGINE = create_engine(
|
||||
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
|
||||
echo=False
|
||||
)
|
||||
|
||||
ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
|
||||
ES_INDEX = 'concept_library_v3'
|
||||
|
||||
CLICKHOUSE_CONFIG = {
|
||||
'host': '127.0.0.1',
|
||||
'port': 9000,
|
||||
'user': 'default',
|
||||
'password': 'Zzl33818!',
|
||||
'database': 'stock'
|
||||
}
|
||||
|
||||
REFERENCE_INDEX = '000001.SH'
|
||||
|
||||
# 检测配置
|
||||
CONFIG = {
|
||||
# 序列配置
|
||||
'seq_len': 10, # LSTM 序列长度(分钟)
|
||||
|
||||
# 持续性确认配置(核心!)
|
||||
'confirm_window': 5, # 确认窗口(分钟)
|
||||
'confirm_ratio': 0.6, # 确认比例(60%时刻需要超标)
|
||||
|
||||
# Z-Score 阈值
|
||||
'alpha_zscore_threshold': 2.0, # Alpha Z-Score 阈值
|
||||
'amt_zscore_threshold': 2.5, # 成交额 Z-Score 阈值
|
||||
|
||||
# 融合权重
|
||||
'rule_weight': 0.5,
|
||||
'ml_weight': 0.5,
|
||||
|
||||
# 触发阈值
|
||||
'rule_trigger': 60,
|
||||
'ml_trigger': 70,
|
||||
'fusion_trigger': 50,
|
||||
|
||||
# 冷却期
|
||||
'cooldown_minutes': 10,
|
||||
'max_alerts_per_minute': 15,
|
||||
|
||||
# Z-Score 截断
|
||||
'zscore_clip': 5.0,
|
||||
}
|
||||
|
||||
# V2 特征列表
|
||||
FEATURES_V2 = [
|
||||
'alpha_zscore', 'amt_zscore', 'rank_zscore',
|
||||
'momentum_3m', 'momentum_5m', 'limit_up_ratio'
|
||||
]
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
def get_ch_client():
|
||||
return Client(**CLICKHOUSE_CONFIG)
|
||||
|
||||
|
||||
def code_to_ch_format(code: str) -> str:
|
||||
if not code or len(code) != 6 or not code.isdigit():
|
||||
return None
|
||||
if code.startswith('6'):
|
||||
return f"{code}.SH"
|
||||
elif code.startswith('0') or code.startswith('3'):
|
||||
return f"{code}.SZ"
|
||||
else:
|
||||
return f"{code}.BJ"
|
||||
|
||||
|
||||
def time_to_slot(ts) -> str:
|
||||
"""时间戳转时间片(HH:MM)"""
|
||||
if isinstance(ts, str):
|
||||
return ts
|
||||
return ts.strftime('%H:%M')
|
||||
|
||||
|
||||
# ==================== 基线加载 ====================
|
||||
|
||||
def load_baselines(baseline_dir: str = 'ml/data_v2/baselines') -> Dict[str, pd.DataFrame]:
|
||||
"""加载时间片基线"""
|
||||
baseline_file = os.path.join(baseline_dir, 'baselines.pkl')
|
||||
if os.path.exists(baseline_file):
|
||||
with open(baseline_file, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
return {}
|
||||
|
||||
|
||||
# ==================== 规则评分(基于 Z-Score)====================
|
||||
|
||||
def score_rules_zscore(row: Dict) -> Tuple[float, List[str]]:
|
||||
"""
|
||||
基于 Z-Score 的规则评分
|
||||
|
||||
设计思路:Z-Score 已经标准化,直接用阈值判断
|
||||
"""
|
||||
score = 0.0
|
||||
triggered = []
|
||||
|
||||
alpha_zscore = row.get('alpha_zscore', 0)
|
||||
amt_zscore = row.get('amt_zscore', 0)
|
||||
rank_zscore = row.get('rank_zscore', 0)
|
||||
momentum_3m = row.get('momentum_3m', 0)
|
||||
momentum_5m = row.get('momentum_5m', 0)
|
||||
limit_up_ratio = row.get('limit_up_ratio', 0)
|
||||
|
||||
alpha_zscore_abs = abs(alpha_zscore)
|
||||
amt_zscore_abs = abs(amt_zscore)
|
||||
|
||||
# ========== Alpha Z-Score 规则 ==========
|
||||
if alpha_zscore_abs >= 4.0:
|
||||
score += 25
|
||||
triggered.append('alpha_zscore_extreme')
|
||||
elif alpha_zscore_abs >= 3.0:
|
||||
score += 18
|
||||
triggered.append('alpha_zscore_strong')
|
||||
elif alpha_zscore_abs >= 2.0:
|
||||
score += 10
|
||||
triggered.append('alpha_zscore_moderate')
|
||||
|
||||
# ========== 成交额 Z-Score 规则 ==========
|
||||
if amt_zscore >= 4.0:
|
||||
score += 20
|
||||
triggered.append('amt_zscore_extreme')
|
||||
elif amt_zscore >= 3.0:
|
||||
score += 12
|
||||
triggered.append('amt_zscore_strong')
|
||||
elif amt_zscore >= 2.0:
|
||||
score += 6
|
||||
triggered.append('amt_zscore_moderate')
|
||||
|
||||
# ========== 排名 Z-Score 规则 ==========
|
||||
if abs(rank_zscore) >= 3.0:
|
||||
score += 15
|
||||
triggered.append('rank_zscore_extreme')
|
||||
elif abs(rank_zscore) >= 2.0:
|
||||
score += 8
|
||||
triggered.append('rank_zscore_strong')
|
||||
|
||||
# ========== 动量规则 ==========
|
||||
if momentum_3m >= 1.0:
|
||||
score += 12
|
||||
triggered.append('momentum_3m_strong')
|
||||
elif momentum_3m >= 0.5:
|
||||
score += 6
|
||||
triggered.append('momentum_3m_moderate')
|
||||
|
||||
if momentum_5m >= 1.5:
|
||||
score += 10
|
||||
triggered.append('momentum_5m_strong')
|
||||
|
||||
# ========== 涨停比例规则 ==========
|
||||
if limit_up_ratio >= 0.3:
|
||||
score += 20
|
||||
triggered.append('limit_up_extreme')
|
||||
elif limit_up_ratio >= 0.15:
|
||||
score += 12
|
||||
triggered.append('limit_up_strong')
|
||||
elif limit_up_ratio >= 0.08:
|
||||
score += 5
|
||||
triggered.append('limit_up_moderate')
|
||||
|
||||
# ========== 组合规则 ==========
|
||||
# Alpha Z-Score + 成交额放大
|
||||
if alpha_zscore_abs >= 2.0 and amt_zscore >= 2.0:
|
||||
score += 15
|
||||
triggered.append('combo_alpha_amt')
|
||||
|
||||
# Alpha Z-Score + 涨停
|
||||
if alpha_zscore_abs >= 2.0 and limit_up_ratio >= 0.1:
|
||||
score += 12
|
||||
triggered.append('combo_alpha_limitup')
|
||||
|
||||
return min(score, 100), triggered
|
||||
|
||||
|
||||
# ==================== ML 评分器 ====================
|
||||
|
||||
class MLScorerV2:
|
||||
"""V2 ML 评分器"""
|
||||
|
||||
def __init__(self, model_dir: str = 'ml/checkpoints_v2'):
|
||||
self.model_dir = model_dir
|
||||
self.model = None
|
||||
self.thresholds = None
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""加载模型和阈值"""
|
||||
model_path = os.path.join(self.model_dir, 'best_model.pt')
|
||||
threshold_path = os.path.join(self.model_dir, 'thresholds.json')
|
||||
config_path = os.path.join(self.model_dir, 'config.json')
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"警告: 模型文件不存在: {model_path}")
|
||||
return
|
||||
|
||||
# 加载配置
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
# 创建模型
|
||||
model_config = config.get('model', {})
|
||||
self.model = TransformerAutoencoder(**model_config)
|
||||
|
||||
# 加载权重
|
||||
checkpoint = torch.load(model_path, map_location=self.device)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
# 加载阈值
|
||||
if os.path.exists(threshold_path):
|
||||
with open(threshold_path, 'r') as f:
|
||||
self.thresholds = json.load(f)
|
||||
|
||||
print(f"V2 模型加载完成: {model_path}")
|
||||
|
||||
@torch.no_grad()
|
||||
def score_batch(self, sequences: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
批量计算 ML 分数
|
||||
|
||||
返回 0-100 的分数,越高越异常
|
||||
"""
|
||||
if self.model is None:
|
||||
return np.zeros(len(sequences))
|
||||
|
||||
# 转换为 tensor
|
||||
x = torch.FloatTensor(sequences).to(self.device)
|
||||
|
||||
# 计算重构误差
|
||||
errors = self.model.compute_reconstruction_error(x, reduction='none')
|
||||
# 取最后一个时刻的误差
|
||||
last_errors = errors[:, -1].cpu().numpy()
|
||||
|
||||
# 转换为 0-100 分数
|
||||
if self.thresholds:
|
||||
p50 = self.thresholds.get('median', 0.1)
|
||||
p99 = self.thresholds.get('p99', 1.0)
|
||||
|
||||
# 线性映射:p50 -> 50分,p99 -> 99分
|
||||
scores = 50 + (last_errors - p50) / (p99 - p50) * 49
|
||||
scores = np.clip(scores, 0, 100)
|
||||
else:
|
||||
# 没有阈值时,简单归一化
|
||||
scores = last_errors * 100
|
||||
scores = np.clip(scores, 0, 100)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
# ==================== 实时数据管理器 ====================
|
||||
|
||||
class RealtimeDataManagerV2:
|
||||
"""
|
||||
V2 实时数据管理器
|
||||
|
||||
维护:
|
||||
1. 每个概念的历史 Z-Score 序列(用于 LSTM 输入)
|
||||
2. 每个概念的异动候选队列(用于持续性确认)
|
||||
"""
|
||||
|
||||
def __init__(self, concepts: List[dict], baselines: Dict[str, pd.DataFrame]):
|
||||
self.concepts = {c['concept_id']: c for c in concepts}
|
||||
self.baselines = baselines
|
||||
|
||||
# 概念到股票的映射
|
||||
self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
|
||||
|
||||
# 历史 Z-Score 序列(每个概念)
|
||||
# {concept_id: deque([(timestamp, features_dict), ...], maxlen=seq_len)}
|
||||
self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))
|
||||
|
||||
# 异动候选队列(用于持续性确认)
|
||||
# {concept_id: deque([(timestamp, score), ...], maxlen=confirm_window)}
|
||||
self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))
|
||||
|
||||
# 冷却期记录
|
||||
self.cooldown = {}
|
||||
|
||||
# 上一次更新的时间戳
|
||||
self.last_timestamp = None
|
||||
|
||||
def compute_zscore_features(
|
||||
self,
|
||||
concept_id: str,
|
||||
timestamp,
|
||||
alpha: float,
|
||||
total_amt: float,
|
||||
rank_pct: float,
|
||||
limit_up_ratio: float
|
||||
) -> Optional[Dict]:
|
||||
"""计算单个概念单个时刻的 Z-Score 特征"""
|
||||
if concept_id not in self.baselines:
|
||||
return None
|
||||
|
||||
baseline = self.baselines[concept_id]
|
||||
time_slot = time_to_slot(timestamp)
|
||||
|
||||
# 查找对应时间片的基线
|
||||
bl_row = baseline[baseline['time_slot'] == time_slot]
|
||||
if bl_row.empty:
|
||||
return None
|
||||
|
||||
bl = bl_row.iloc[0]
|
||||
|
||||
# 检查样本量
|
||||
if bl.get('sample_count', 0) < 10:
|
||||
return None
|
||||
|
||||
# 计算 Z-Score
|
||||
alpha_zscore = (alpha - bl['alpha_mean']) / bl['alpha_std']
|
||||
amt_zscore = (total_amt - bl['amt_mean']) / bl['amt_std']
|
||||
rank_zscore = (rank_pct - bl['rank_mean']) / bl['rank_std']
|
||||
|
||||
# 截断
|
||||
clip = CONFIG['zscore_clip']
|
||||
alpha_zscore = np.clip(alpha_zscore, -clip, clip)
|
||||
amt_zscore = np.clip(amt_zscore, -clip, clip)
|
||||
rank_zscore = np.clip(rank_zscore, -clip, clip)
|
||||
|
||||
# 计算动量(需要历史)
|
||||
history = self.zscore_history[concept_id]
|
||||
momentum_3m = 0
|
||||
momentum_5m = 0
|
||||
|
||||
if len(history) >= 3:
|
||||
recent_alphas = [h[1]['alpha'] for h in list(history)[-3:]]
|
||||
older_alphas = [h[1]['alpha'] for h in list(history)[-6:-3]] if len(history) >= 6 else [alpha]
|
||||
momentum_3m = np.mean(recent_alphas) - np.mean(older_alphas)
|
||||
|
||||
if len(history) >= 5:
|
||||
recent_alphas = [h[1]['alpha'] for h in list(history)[-5:]]
|
||||
older_alphas = [h[1]['alpha'] for h in list(history)[-10:-5]] if len(history) >= 10 else [alpha]
|
||||
momentum_5m = np.mean(recent_alphas) - np.mean(older_alphas)
|
||||
|
||||
return {
|
||||
'alpha': alpha,
|
||||
'alpha_zscore': alpha_zscore,
|
||||
'amt_zscore': amt_zscore,
|
||||
'rank_zscore': rank_zscore,
|
||||
'momentum_3m': momentum_3m,
|
||||
'momentum_5m': momentum_5m,
|
||||
'limit_up_ratio': limit_up_ratio,
|
||||
'total_amt': total_amt,
|
||||
'rank_pct': rank_pct,
|
||||
}
|
||||
|
||||
def update(self, concept_id: str, timestamp, features: Dict):
    """Append one (timestamp, features) observation to the concept's rolling history."""
    per_concept = self.zscore_history[concept_id]
    per_concept.append((timestamp, features))
|
||||
|
||||
def get_sequence(self, concept_id: str) -> Optional[np.ndarray]:
    """Build the model input sequence for one concept.

    Returns an array of shape (history_len, 6) with the Z-Score feature
    columns in fixed order, or None until the history reaches the
    configured sequence length.
    """
    history = self.zscore_history[concept_id]

    if len(history) < CONFIG['seq_len']:
        return None

    # Column order must stay stable — it is what the model was trained on.
    columns = (
        'alpha_zscore',
        'amt_zscore',
        'rank_zscore',
        'momentum_3m',
        'momentum_5m',
        'limit_up_ratio',
    )
    rows = [[feats[name] for name in columns] for _, feats in history]
    return np.array(rows)
|
||||
|
||||
def add_anomaly_candidate(self, concept_id: str, timestamp, score: float):
    """Record an anomaly candidate (timestamp, score) for later sustained-anomaly confirmation."""
    queue = self.anomaly_candidates[concept_id]
    queue.append((timestamp, score))
|
||||
|
||||
def check_sustained_anomaly(self, concept_id: str, threshold: float) -> Tuple[bool, float]:
    """Check whether a concept's anomaly is sustained rather than a one-off spike.

    Returns:
        (confirmed, ratio) where ratio is the fraction of queued candidate
        scores at or above `threshold`. Confirmation requires both a full
        candidate window and the ratio reaching CONFIG['confirm_ratio'].
    """
    candidates = self.anomaly_candidates[concept_id]

    # Not enough observations yet to judge persistence.
    if len(candidates) < CONFIG['confirm_window']:
        return False, 0.0

    # Fraction of recent moments whose score cleared the threshold.
    exceed_count = sum(score >= threshold for _, score in candidates)
    ratio = exceed_count / len(candidates)

    return ratio >= CONFIG['confirm_ratio'], ratio
|
||||
|
||||
def check_cooldown(self, concept_id: str, timestamp) -> bool:
    """Return True while the concept is still inside its alert cooldown window.

    A concept with no recorded alert is never in cooldown. The subtraction
    assumes datetime-like timestamps; incompatible types (e.g. a string
    mixed with a datetime) are treated as "not in cooldown" rather than
    crashing the detection loop.
    """
    if concept_id not in self.cooldown:
        return False

    last_alert = self.cooldown[concept_id]
    try:
        diff = (timestamp - last_alert).total_seconds() / 60
        return diff < CONFIG['cooldown_minutes']
    # Narrowed from a bare `except:` — only swallow the expected failure
    # modes of the timestamp arithmetic (unsupported `-`, or a result
    # without .total_seconds()), so real bugs (e.g. KeyError in CONFIG)
    # still surface.
    except (TypeError, AttributeError):
        return False
|
||||
|
||||
def set_cooldown(self, concept_id: str, timestamp):
    """Start/refresh the cooldown window for a concept.

    Overwrites any previous alert time; check_cooldown compares against
    this value.
    """
    self.cooldown.update({concept_id: timestamp})
|
||||
|
||||
|
||||
# ==================== 异动检测器 V2 ====================
|
||||
|
||||
class AnomalyDetectorV2:
    """
    V2 anomaly detector.

    Core pipeline:
      1. fetch realtime/raw data
      2. compute Z-Score features against per-time-slot baselines
      3. rule scoring + ML scoring (fused)
      4. sustained-anomaly confirmation with per-concept cooldown
      5. emit alert records
    """

    def __init__(
        self,
        model_dir: str = 'ml/checkpoints_v2',
        baseline_dir: str = 'ml/data_v2/baselines'
    ):
        # Concept definitions (id, name, member stock codes) from Elasticsearch.
        self.concepts = self._load_concepts()

        # Per-concept intraday baselines (mean/std per time slot) for Z-Scores.
        self.baselines = load_baselines(baseline_dir)
        print(f"加载了 {len(self.baselines)} 个概念的基线")

        # ML scorer over Z-Score feature sequences.
        self.ml_scorer = MLScorerV2(model_dir)

        # Realtime state: Z-Score history, candidate queues, cooldowns.
        self.data_manager = RealtimeDataManagerV2(self.concepts, self.baselines)

        # De-duplicated union of all member stocks across concepts.
        self.all_stocks = list(set(s for c in self.concepts for s in c['stocks']))

    def _load_concepts(self) -> List[dict]:
        """Scroll the full concept index from ES; keep only concepts that have stock codes."""
        concepts = []
        query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}

        resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
        scroll_id = resp['_scroll_id']
        hits = resp['hits']['hits']

        while len(hits) > 0:
            for hit in hits:
                source = hit['_source']
                stocks = []
                # Each entry is expected to be a dict with a non-empty 'code'.
                if 'stocks' in source and isinstance(source['stocks'], list):
                    for stock in source['stocks']:
                        if isinstance(stock, dict) and 'code' in stock and stock['code']:
                            stocks.append(stock['code'])
                # Concepts without any resolvable stock codes are dropped.
                if stocks:
                    concepts.append({
                        'concept_id': source.get('concept_id'),
                        'concept_name': source.get('concept'),
                        'stocks': stocks
                    })

            resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
            scroll_id = resp['_scroll_id']
            hits = resp['hits']['hits']

        ES_CLIENT.clear_scroll(scroll_id=scroll_id)
        print(f"加载了 {len(concepts)} 个概念")
        return concepts

    def detect(self, trade_date: str) -> List[Dict]:
        """
        Detect anomalies for one trading day.

        Replays the day timestamp by timestamp and returns the list of
        confirmed alert dicts in time order.
        """
        print(f"\n检测 {trade_date} 的异动...")

        # Raw per-concept features for the whole day.
        raw_features = self._compute_raw_features(trade_date)
        if raw_features.empty:
            print("无数据")
            return []

        # Process timestamps in ascending order so rolling state is built correctly.
        timestamps = sorted(raw_features['timestamp'].unique())
        print(f"时间点数: {len(timestamps)}")

        all_alerts = []

        for ts in timestamps:
            ts_data = raw_features[raw_features['timestamp'] == ts]
            ts_alerts = self._process_timestamp(ts, ts_data, trade_date)
            all_alerts.extend(ts_alerts)

        print(f"共检测到 {len(all_alerts)} 个异动")
        return all_alerts

    def _compute_raw_features(self, trade_date: str) -> pd.DataFrame:
        """Compute raw concept features (same logic as prepare_data_v2)."""
        # Delegates to the data-preparation module so train/serve paths stay identical.
        from prepare_data_v2 import compute_raw_concept_features
        return compute_raw_concept_features(trade_date, self.concepts, self.all_stocks)

    def _process_timestamp(self, timestamp, ts_data: pd.DataFrame, trade_date: str) -> List[Dict]:
        """Score every concept at a single timestamp and return confirmed alerts."""
        alerts = []
        candidates = []  # (concept_id, features, rule_score, triggered_rules)

        for _, row in ts_data.iterrows():
            concept_id = row['concept_id']

            # Z-Score features vs. the concept's time-slot baseline
            # (None when the baseline/sample count is insufficient).
            features = self.data_manager.compute_zscore_features(
                concept_id, timestamp,
                row['alpha'], row['total_amt'], row['rank_pct'], row['limit_up_ratio']
            )

            if features is None:
                continue

            # Update rolling history (feeds momentum and the ML sequences).
            self.data_manager.update(concept_id, timestamp, features)

            # Rule-based score.
            rule_score, triggered_rules = score_rules_zscore(features)

            candidates.append((concept_id, features, rule_score, triggered_rules))

        if not candidates:
            return []

        # Batch ML scoring — only concepts with a full sequence are eligible.
        sequences = []
        valid_candidates = []

        for concept_id, features, rule_score, triggered_rules in candidates:
            seq = self.data_manager.get_sequence(concept_id)
            if seq is not None:
                sequences.append(seq)
                valid_candidates.append((concept_id, features, rule_score, triggered_rules))

        if not sequences:
            return []

        sequences = np.array(sequences)
        ml_scores = self.ml_scorer.score_batch(sequences)

        # Fuse scores and apply sustained-anomaly confirmation.
        for i, (concept_id, features, rule_score, triggered_rules) in enumerate(valid_candidates):
            ml_score = ml_scores[i]
            final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score

            # Triggered if any single score or the fused score crosses its threshold.
            is_triggered = (
                rule_score >= CONFIG['rule_trigger'] or
                ml_score >= CONFIG['ml_trigger'] or
                final_score >= CONFIG['fusion_trigger']
            )

            # Always feed the candidate queue (triggered or not) so the
            # persistence check sees the full score history.
            self.data_manager.add_anomaly_candidate(concept_id, timestamp, final_score)

            if not is_triggered:
                continue

            # Suppress repeat alerts during the cooldown window.
            if self.data_manager.check_cooldown(concept_id, timestamp):
                continue

            # Require the fused score to stay elevated across the confirm window.
            is_sustained, confirm_ratio = self.data_manager.check_sustained_anomaly(
                concept_id, CONFIG['fusion_trigger']
            )

            if not is_sustained:
                continue

            # Confirmed anomaly — start the cooldown for this concept.
            self.data_manager.set_cooldown(concept_id, timestamp)

            # Classify the anomaly type from alpha / volume Z-Score.
            alpha = features['alpha']
            if alpha >= 1.5:
                alert_type = 'surge_up'
            elif alpha <= -1.5:
                alert_type = 'surge_down'
            elif features['amt_zscore'] >= 3.0:
                alert_type = 'volume_spike'
            else:
                alert_type = 'surge'

            # Human-readable trigger reason (mirrors the trigger priority above).
            if rule_score >= CONFIG['rule_trigger']:
                trigger_reason = f'规则({rule_score:.0f})+持续确认({confirm_ratio:.0%})'
            elif ml_score >= CONFIG['ml_trigger']:
                trigger_reason = f'ML({ml_score:.0f})+持续确认({confirm_ratio:.0%})'
            else:
                trigger_reason = f'融合({final_score:.0f})+持续确认({confirm_ratio:.0%})'

            alerts.append({
                'concept_id': concept_id,
                # NOTE(review): self.concepts is built as a *list* in __init__, yet
                # data_manager.concepts is used here as a dict keyed by concept_id —
                # presumably RealtimeDataManagerV2 converts it internally; confirm.
                'concept_name': self.data_manager.concepts.get(concept_id, {}).get('concept_name', concept_id),
                'alert_time': timestamp,
                'trade_date': trade_date,
                'alert_type': alert_type,
                'final_score': final_score,
                'rule_score': rule_score,
                'ml_score': ml_score,
                'trigger_reason': trigger_reason,
                'confirm_ratio': confirm_ratio,
                'alpha': alpha,
                'alpha_zscore': features['alpha_zscore'],
                'amt_zscore': features['amt_zscore'],
                'rank_zscore': features['rank_zscore'],
                'momentum_3m': features['momentum_3m'],
                'momentum_5m': features['momentum_5m'],
                'limit_up_ratio': features['limit_up_ratio'],
                'triggered_rules': triggered_rules,
            })

        # Cap the number of alerts per minute, keeping the highest scores.
        if len(alerts) > CONFIG['max_alerts_per_minute']:
            alerts = sorted(alerts, key=lambda x: x['final_score'], reverse=True)
            alerts = alerts[:CONFIG['max_alerts_per_minute']]

        return alerts
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
    """CLI entry point: run the V2 detector for one day and print a summary."""
    import argparse

    parser = argparse.ArgumentParser(description='V2 异动检测器')
    parser.add_argument('--date', type=str, default=None, help='检测日期(默认今天)')
    parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2')
    parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines')
    args = parser.parse_args()

    # Default to today's date when --date is omitted.
    trade_date = args.date or datetime.now().strftime('%Y-%m-%d')

    detector = AnomalyDetectorV2(
        model_dir=args.model_dir,
        baseline_dir=args.baseline_dir
    )
    alerts = detector.detect(trade_date)

    print(f"\n检测结果:")
    # Show at most the first 20 alerts.
    for alert in alerts[:20]:
        raw_time = alert['alert_time']
        shown_time = raw_time.strftime('%H:%M') if hasattr(raw_time, 'strftime') else raw_time
        print(f" [{shown_time}] "
              f"{alert['concept_name']} ({alert['alert_type']}) "
              f"分数={alert['final_score']:.0f} "
              f"确认率={alert['confirm_ratio']:.0%}")

    if len(alerts) > 20:
        print(f" ... 共 {len(alerts)} 个异动")


if __name__ == "__main__":
    main()
|
||||
526
ml/enhanced_detector.py
Normal file
526
ml/enhanced_detector.py
Normal file
@@ -0,0 +1,526 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强版概念异动检测器
|
||||
|
||||
融合两种检测方法:
|
||||
1. Alpha-based Z-Score(规则方法,实时性好)
|
||||
2. Transformer Autoencoder(ML方法,更准确)
|
||||
|
||||
使用策略:
|
||||
- 当 ML 模型可用且历史数据足够时,优先使用 ML 方法
|
||||
- 否则回退到 Alpha-based 方法
|
||||
- 可以配置两种方法的融合权重
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from collections import deque
|
||||
import numpy as np
|
||||
|
||||
# 添加父目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
# Tunables for the enhanced (fused) detector.
ENHANCED_CONFIG = {
    # Fusion strategy
    'fusion_mode': 'adaptive',  # 'ml_only', 'alpha_only', 'adaptive', 'ensemble'

    # Method weights (used only in 'ensemble' mode)
    'ml_weight': 0.6,
    'alpha_weight': 0.4,

    # ML model configuration
    'ml_checkpoint_dir': 'ml/checkpoints',
    'ml_threshold_key': 'p95',  # p90, p95, p99

    # Alpha configuration (kept in sync with concept_alert_alpha.py)
    'alpha_zscore_threshold': 2.0,    # |Z| needed once enough history exists
    'alpha_absolute_threshold': 1.5,  # |alpha| fallback before stats are ready
    'alpha_history_window': 60,       # rolling window length (samples)
    'alpha_min_history': 5,           # min samples before Z-Score is trusted

    # Shared configuration
    'cooldown_minutes': 8,
    'max_alerts_per_minute': 15,
    'min_alpha_abs': 0.5,             # concepts below this |alpha| are ignored
}

# Feature order — must match what the model was trained on.
FEATURE_NAMES = [
    'alpha',
    'alpha_delta',
    'amt_ratio',
    'amt_delta',
    'rank_pct',
    'limit_up_ratio',
]
|
||||
|
||||
|
||||
# ==================== 数据结构 ====================
|
||||
|
||||
@dataclass
class AlphaStats:
    """Rolling alpha statistics for one concept (drives the rule-based Z-Score)."""
    history: deque = field(default_factory=lambda: deque(maxlen=ENHANCED_CONFIG['alpha_history_window']))
    mean: float = 0.0
    std: float = 1.0

    def update(self, alpha: float):
        """Push a new alpha sample and refresh the rolling mean/std."""
        self.history.append(alpha)
        if len(self.history) < 2:
            return
        samples = self.history
        self.mean = np.mean(samples)
        # Floor the std so a flat history cannot blow up the Z-Score.
        self.std = max(np.std(samples), 0.1)

    def get_zscore(self, alpha: float) -> float:
        """Z-score of `alpha` against the rolling stats; 0.0 until enough history."""
        if not self.is_ready():
            return 0.0
        return (alpha - self.mean) / self.std

    def is_ready(self) -> bool:
        """True once the minimum number of samples has accumulated."""
        return len(self.history) >= ENHANCED_CONFIG['alpha_min_history']
|
||||
|
||||
|
||||
@dataclass
class ConceptFeatures:
    """Realtime feature snapshot for one concept."""
    alpha: float = 0.0
    alpha_delta: float = 0.0
    amt_ratio: float = 1.0
    amt_delta: float = 0.0
    rank_pct: float = 0.5
    limit_up_ratio: float = 0.0

    def to_dict(self) -> Dict[str, float]:
        """Return the features as a plain dict (insertion order matches field order)."""
        names = ('alpha', 'alpha_delta', 'amt_ratio', 'amt_delta', 'rank_pct', 'limit_up_ratio')
        return {name: getattr(self, name) for name in names}
|
||||
|
||||
|
||||
# ==================== 增强检测器 ====================
|
||||
|
||||
class EnhancedAnomalyDetector:
    """
    Enhanced anomaly detector.

    Fuses two methods:
      - Alpha-based Z-Score (rule method, responsive, no warm-up dependency)
      - Transformer Autoencoder (ML method, more accurate, needs history)
    The fusion behavior is controlled by config['fusion_mode'].
    """

    def __init__(
        self,
        config: Dict = None,
        ml_enabled: bool = True
    ):
        self.config = config or ENHANCED_CONFIG
        self.ml_enabled = ml_enabled
        self.ml_detector = None

        # Per-concept rolling alpha statistics (rule method).
        self.alpha_stats: Dict[str, AlphaStats] = {}

        # Per-concept feature history (used to compute deltas).
        self.feature_history: Dict[str, deque] = {}

        # Last alert time per concept (cooldown suppression).
        self.cooldown_cache: Dict[str, datetime] = {}

        # Try to load the ML model; failure degrades to alpha-only detection.
        if ml_enabled:
            self._load_ml_model()

        logger.info(f"EnhancedAnomalyDetector 初始化完成")
        logger.info(f" 融合模式: {self.config['fusion_mode']}")
        logger.info(f" ML 可用: {self.ml_detector is not None}")

    def _load_ml_model(self):
        """Load the ML detector if a checkpoint exists; never raises."""
        try:
            from inference import ConceptAnomalyDetector
            checkpoint_dir = Path(__file__).parent / 'checkpoints'

            if (checkpoint_dir / 'best_model.pt').exists():
                self.ml_detector = ConceptAnomalyDetector(
                    checkpoint_dir=str(checkpoint_dir),
                    threshold_key=self.config['ml_threshold_key']
                )
                logger.info("ML 模型加载成功")
            else:
                logger.warning(f"ML 模型不存在: {checkpoint_dir / 'best_model.pt'}")
        except Exception as e:
            # e.g. torch missing or checkpoint incompatible — fall back to alpha-only.
            logger.warning(f"ML 模型加载失败: {e}")
            self.ml_detector = None

    def _get_alpha_stats(self, concept_id: str) -> AlphaStats:
        """Get (lazily creating) the AlphaStats for a concept."""
        if concept_id not in self.alpha_stats:
            self.alpha_stats[concept_id] = AlphaStats()
        return self.alpha_stats[concept_id]

    def _get_feature_history(self, concept_id: str) -> deque:
        """Get (lazily creating) the bounded feature history for a concept."""
        if concept_id not in self.feature_history:
            self.feature_history[concept_id] = deque(maxlen=10)
        return self.feature_history[concept_id]

    def _check_cooldown(self, concept_id: str, current_time: datetime) -> bool:
        """Return True while the concept is still inside its cooldown window."""
        if concept_id not in self.cooldown_cache:
            return False

        last_alert = self.cooldown_cache[concept_id]
        cooldown_td = (current_time - last_alert).total_seconds() / 60

        return cooldown_td < self.config['cooldown_minutes']

    def _set_cooldown(self, concept_id: str, current_time: datetime):
        """Start/refresh the cooldown window for a concept."""
        self.cooldown_cache[concept_id] = current_time

    def compute_features(
        self,
        concept_id: str,
        alpha: float,
        amt_ratio: float,
        rank_pct: float,
        limit_up_ratio: float
    ) -> ConceptFeatures:
        """
        Compute the full feature vector for a concept, including deltas
        vs. the previous observation.

        Args:
            concept_id: concept identifier
            alpha: current excess return
            amt_ratio: turnover ratio
            rank_pct: rank percentile
            limit_up_ratio: fraction of member stocks at limit-up

        Returns:
            Populated ConceptFeatures (also appended to the history).
        """
        history = self._get_feature_history(concept_id)

        # Deltas default to 0 when there is no previous observation.
        alpha_delta = 0.0
        amt_delta = 0.0

        if len(history) > 0:
            last_features = history[-1]
            alpha_delta = alpha - last_features.alpha
            # Relative change; guarded against division by zero.
            if last_features.amt_ratio > 0:
                amt_delta = (amt_ratio - last_features.amt_ratio) / last_features.amt_ratio

        features = ConceptFeatures(
            alpha=alpha,
            alpha_delta=alpha_delta,
            amt_ratio=amt_ratio,
            amt_delta=amt_delta,
            rank_pct=rank_pct,
            limit_up_ratio=limit_up_ratio,
        )

        # Record for the next delta computation.
        history.append(features)

        return features

    def detect_alpha_anomaly(
        self,
        concept_id: str,
        alpha: float
    ) -> Tuple[bool, float, str]:
        """
        Alpha-based (rule) anomaly detection.

        Returns:
            is_anomaly: whether the concept is anomalous
            score: anomaly score (|Z-Score|, or a scaled pseudo-score pre-warm-up)
            reason: human-readable trigger description ('' when not anomalous)
        """
        stats = self._get_alpha_stats(concept_id)

        # Z-Score is computed *before* updating, so the current sample
        # doesn't dilute its own baseline.
        zscore = stats.get_zscore(alpha)

        stats.update(alpha)

        if stats.is_ready():
            if abs(zscore) >= self.config['alpha_zscore_threshold']:
                return True, abs(zscore), f"Z={zscore:.2f}"
        else:
            # Cold start: fall back to an absolute-alpha threshold with a
            # pseudo Z-Score (alpha / 0.5) so scores stay comparable.
            if abs(alpha) >= self.config['alpha_absolute_threshold']:
                fake_zscore = alpha / 0.5
                return True, abs(fake_zscore), f"Alpha={alpha:+.2f}%"

        return False, abs(zscore) if zscore else 0.0, ""

    def detect_ml_anomaly(
        self,
        concept_id: str,
        features: ConceptFeatures
    ) -> Tuple[bool, float]:
        """
        ML-based anomaly detection (reconstruction error).

        Returns:
            is_anomaly: whether the concept is anomalous
            score: reconstruction-error score (0.0 when ML unavailable/fails)
        """
        if self.ml_detector is None:
            return False, 0.0

        try:
            is_anomaly, score = self.ml_detector.detect(
                concept_id,
                features.to_dict()
            )
            # score may be None while the ML history is warming up.
            return is_anomaly, score or 0.0
        except Exception as e:
            logger.warning(f"ML 检测失败: {e}")
            return False, 0.0

    def detect(
        self,
        concept_id: str,
        concept_name: str,
        alpha: float,
        amt_ratio: float,
        rank_pct: float,
        limit_up_ratio: float,
        change_pct: float,
        index_change: float,
        current_time: datetime,
        **extra_data
    ) -> Optional[Dict]:
        """
        Fused detection for one concept at one moment.

        Args:
            concept_id: concept identifier
            concept_name: concept display name
            alpha: excess return
            amt_ratio: turnover ratio
            rank_pct: rank percentile
            limit_up_ratio: fraction of member stocks at limit-up
            change_pct: concept price change
            index_change: market index change
            current_time: timestamp of the observation
            **extra_data: pass-through fields (limit_up_count, stock_count, ...)

        Returns:
            Alert dict when triggered, otherwise None.
        """
        # Alpha too small to be interesting — skip early.
        if abs(alpha) < self.config['min_alpha_abs']:
            return None

        # Still cooling down from a previous alert.
        if self._check_cooldown(concept_id, current_time):
            return None

        # Build feature vector (updates delta history as a side effect).
        features = self.compute_features(
            concept_id, alpha, amt_ratio, rank_pct, limit_up_ratio
        )

        fusion_mode = self.config['fusion_mode']

        # Alpha detection always runs (it also maintains the rolling stats).
        alpha_anomaly, alpha_score, alpha_reason = self.detect_alpha_anomaly(concept_id, alpha)
        ml_anomaly, ml_score = False, 0.0

        if fusion_mode in ('ml_only', 'adaptive', 'ensemble'):
            ml_anomaly, ml_score = self.detect_ml_anomaly(concept_id, features)

        # Combine per the configured fusion mode.
        is_anomaly = False
        final_score = 0.0
        detection_method = ''

        if fusion_mode == 'alpha_only':
            is_anomaly = alpha_anomaly
            final_score = alpha_score
            detection_method = 'alpha'

        elif fusion_mode == 'ml_only':
            is_anomaly = ml_anomaly
            final_score = ml_score
            detection_method = 'ml'

        elif fusion_mode == 'adaptive':
            # Prefer ML when it produced a score; otherwise fall back to alpha.
            if self.ml_detector and ml_score > 0:
                is_anomaly = ml_anomaly
                final_score = ml_score
                detection_method = 'ml'
            else:
                is_anomaly = alpha_anomaly
                final_score = alpha_score
                detection_method = 'alpha'

        elif fusion_mode == 'ensemble':
            # Weighted fusion of normalized scores.
            norm_alpha = min(alpha_score / 5.0, 1.0)  # |Z| > 5 saturates at 1.0
            norm_ml = min(ml_score / (self.ml_detector.threshold if self.ml_detector else 1.0), 1.0)

            final_score = (
                self.config['alpha_weight'] * norm_alpha +
                self.config['ml_weight'] * norm_ml
            )
            # Either method firing on its own is also sufficient.
            is_anomaly = final_score > 0.5 or alpha_anomaly or ml_anomaly
            detection_method = 'ensemble'

        if not is_anomaly:
            return None

        # Triggered — start cooldown and build the alert record.
        self._set_cooldown(concept_id, current_time)

        alert_type = 'surge_up' if alpha > 0 else 'surge_down'

        alert = {
            'concept_id': concept_id,
            'concept_name': concept_name,
            'alert_type': alert_type,
            'alert_time': current_time,
            'change_pct': change_pct,
            'alpha': alpha,
            'alpha_zscore': alpha_score,
            'index_change_pct': index_change,
            'detection_method': detection_method,
            'alpha_score': alpha_score,
            'ml_score': ml_score,
            'final_score': final_score,
            **extra_data
        }

        return alert

    def batch_detect(
        self,
        concepts_data: List[Dict],
        current_time: datetime
    ) -> List[Dict]:
        """
        Run detection over a batch of concepts.

        Args:
            concepts_data: list of per-concept data dicts
            current_time: shared observation timestamp

        Returns:
            Alerts sorted by final_score descending, capped at
            config['max_alerts_per_minute'].
        """
        alerts = []

        for data in concepts_data:
            alert = self.detect(
                concept_id=data['concept_id'],
                concept_name=data['concept_name'],
                alpha=data.get('alpha', 0),
                amt_ratio=data.get('amt_ratio', 1.0),
                rank_pct=data.get('rank_pct', 0.5),
                limit_up_ratio=data.get('limit_up_ratio', 0),
                change_pct=data.get('change_pct', 0),
                index_change=data.get('index_change', 0),
                current_time=current_time,
                limit_up_count=data.get('limit_up_count', 0),
                limit_down_count=data.get('limit_down_count', 0),
                stock_count=data.get('stock_count', 0),
                concept_type=data.get('concept_type', 'leaf'),
            )

            if alert:
                alerts.append(alert)

        # Highest scores first, capped per minute.
        alerts.sort(key=lambda x: x['final_score'], reverse=True)
        return alerts[:self.config['max_alerts_per_minute']]

    def reset(self):
        """Clear all rolling state (call at the start of a new trading day)."""
        self.alpha_stats.clear()
        self.feature_history.clear()
        self.cooldown_cache.clear()

        if self.ml_detector:
            self.ml_detector.clear_history()

        logger.info("检测器状态已重置")
|
||||
|
||||
|
||||
# ==================== 测试 ====================
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: drive the detector with synthetic data and check that the
    # injected spike is flagged.
    import random

    print("测试 EnhancedAnomalyDetector...")

    # Alpha-only mode: the ML checkpoint may not exist in a dev checkout.
    detector = EnhancedAnomalyDetector(ml_enabled=False)

    # Synthetic concept universe.
    concepts = [
        {'concept_id': 'ai_001', 'concept_name': '人工智能'},
        {'concept_id': 'chip_002', 'concept_name': '芯片半导体'},
        {'concept_id': 'car_003', 'concept_name': '新能源汽车'},
    ]

    print("\n模拟实时检测...")
    current_time = datetime.now()

    for minute in range(50):
        concepts_data = []

        for c in concepts:
            # Random baseline noise.
            alpha = random.gauss(0, 0.8)
            amt_ratio = max(0.3, random.gauss(1, 0.3))
            rank_pct = random.random()
            limit_up_ratio = random.random() * 0.1

            # Inject an anomaly: 'ai_001' surges at minute 30.
            if minute == 30 and c['concept_id'] == 'ai_001':
                alpha = 4.5
                amt_ratio = 2.5
                limit_up_ratio = 0.3

            concepts_data.append({
                **c,
                'alpha': alpha,
                'amt_ratio': amt_ratio,
                'rank_pct': rank_pct,
                'limit_up_ratio': limit_up_ratio,
                'change_pct': alpha + 0.5,
                'index_change': 0.5,
            })

        # Detect and print anything flagged this "minute".
        alerts = detector.batch_detect(concepts_data, current_time)

        if alerts:
            for alert in alerts:
                print(f" t={minute:02d} 🔥 {alert['concept_name']} "
                      f"Alpha={alert['alpha']:+.2f}% "
                      f"Score={alert['final_score']:.2f} "
                      f"Method={alert['detection_method']}")

        # NOTE(review): replace() wraps minute 59 -> 0 without bumping the hour,
        # so the simulated clock jumps backwards once per hour — harmless for a
        # smoke test but would break real cooldown timing; confirm intent.
        current_time = current_time.replace(minute=current_time.minute + 1 if current_time.minute < 59 else 0)

    print("\n测试完成!")
|
||||
455
ml/inference.py
Normal file
455
ml/inference.py
Normal file
@@ -0,0 +1,455 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
概念异动检测推理服务
|
||||
|
||||
在实时场景中使用训练好的 Transformer Autoencoder 进行异动检测
|
||||
|
||||
使用方法:
|
||||
from ml.inference import ConceptAnomalyDetector
|
||||
|
||||
detector = ConceptAnomalyDetector('ml/checkpoints')
|
||||
|
||||
# 检测异动
|
||||
features = {...} # 实时特征数据
|
||||
is_anomaly, score = detector.detect(features)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from model import TransformerAutoencoder
|
||||
|
||||
|
||||
class ConceptAnomalyDetector:
|
||||
"""
|
||||
概念异动检测器
|
||||
|
||||
使用训练好的 Transformer Autoencoder 进行实时异动检测
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_dir: str = 'ml/checkpoints',
|
||||
device: str = 'auto',
|
||||
threshold_key: str = 'p95'
|
||||
):
|
||||
"""
|
||||
初始化检测器
|
||||
|
||||
Args:
|
||||
checkpoint_dir: 模型检查点目录
|
||||
device: 设备 (auto/cuda/cpu)
|
||||
threshold_key: 使用的阈值键 (p90/p95/p99)
|
||||
"""
|
||||
self.checkpoint_dir = Path(checkpoint_dir)
|
||||
self.threshold_key = threshold_key
|
||||
|
||||
# 设备选择
|
||||
if device == 'auto':
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
else:
|
||||
self.device = torch.device(device)
|
||||
|
||||
# 加载配置
|
||||
self._load_config()
|
||||
|
||||
# 加载模型
|
||||
self._load_model()
|
||||
|
||||
# 加载阈值
|
||||
self._load_thresholds()
|
||||
|
||||
# 加载标准化统计量
|
||||
self._load_normalization_stats()
|
||||
|
||||
# 概念历史数据缓存
|
||||
# {concept_name: deque(maxlen=seq_len)}
|
||||
self.history_cache: Dict[str, deque] = {}
|
||||
|
||||
print(f"ConceptAnomalyDetector 初始化完成")
|
||||
print(f" 设备: {self.device}")
|
||||
print(f" 阈值: {self.threshold_key} = {self.threshold:.6f}")
|
||||
print(f" 序列长度: {self.seq_len}")
|
||||
|
||||
def _load_config(self):
|
||||
"""加载配置"""
|
||||
config_path = self.checkpoint_dir / 'config.json'
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"配置文件不存在: {config_path}")
|
||||
|
||||
with open(config_path, 'r') as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
self.features = self.config['features']
|
||||
self.seq_len = self.config['seq_len']
|
||||
self.model_config = self.config['model']
|
||||
|
||||
def _load_model(self):
|
||||
"""加载模型"""
|
||||
model_path = self.checkpoint_dir / 'best_model.pt'
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
|
||||
# 创建模型
|
||||
self.model = TransformerAutoencoder(**self.model_config)
|
||||
|
||||
# 加载权重
|
||||
checkpoint = torch.load(model_path, map_location=self.device)
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
print(f"模型已加载: {model_path}")
|
||||
|
||||
def _load_thresholds(self):
|
||||
"""加载阈值"""
|
||||
thresholds_path = self.checkpoint_dir / 'thresholds.json'
|
||||
if not thresholds_path.exists():
|
||||
raise FileNotFoundError(f"阈值文件不存在: {thresholds_path}")
|
||||
|
||||
with open(thresholds_path, 'r') as f:
|
||||
self.thresholds = json.load(f)
|
||||
|
||||
if self.threshold_key not in self.thresholds:
|
||||
available_keys = list(self.thresholds.keys())
|
||||
raise KeyError(f"阈值键 '{self.threshold_key}' 不存在,可用: {available_keys}")
|
||||
|
||||
self.threshold = self.thresholds[self.threshold_key]
|
||||
|
||||
def _load_normalization_stats(self):
|
||||
"""加载标准化统计量"""
|
||||
stats_path = self.checkpoint_dir / 'normalization_stats.json'
|
||||
if not stats_path.exists():
|
||||
raise FileNotFoundError(f"标准化统计量文件不存在: {stats_path}")
|
||||
|
||||
with open(stats_path, 'r') as f:
|
||||
stats = json.load(f)
|
||||
|
||||
self.norm_mean = np.array(stats['mean'])
|
||||
self.norm_std = np.array(stats['std'])
|
||||
|
||||
def normalize(self, features: np.ndarray) -> np.ndarray:
|
||||
"""标准化特征"""
|
||||
return (features - self.norm_mean) / self.norm_std
|
||||
|
||||
def update_history(
|
||||
self,
|
||||
concept_name: str,
|
||||
features: Dict[str, float]
|
||||
):
|
||||
"""
|
||||
更新概念历史数据
|
||||
|
||||
Args:
|
||||
concept_name: 概念名称
|
||||
features: 当前时刻的特征字典
|
||||
"""
|
||||
# 初始化历史缓存
|
||||
if concept_name not in self.history_cache:
|
||||
self.history_cache[concept_name] = deque(maxlen=self.seq_len)
|
||||
|
||||
# 提取特征向量
|
||||
feature_vector = np.array([
|
||||
features.get(f, 0.0) for f in self.features
|
||||
])
|
||||
|
||||
# 处理异常值
|
||||
feature_vector = np.nan_to_num(feature_vector, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
# 添加到历史
|
||||
self.history_cache[concept_name].append(feature_vector)
|
||||
|
||||
def get_history_length(self, concept_name: str) -> int:
|
||||
"""获取概念的历史数据长度"""
|
||||
if concept_name not in self.history_cache:
|
||||
return 0
|
||||
return len(self.history_cache[concept_name])
|
||||
|
||||
@torch.no_grad()
|
||||
def detect(
|
||||
self,
|
||||
concept_name: str,
|
||||
features: Dict[str, float] = None,
|
||||
return_score: bool = True
|
||||
) -> Tuple[bool, Optional[float]]:
|
||||
"""
|
||||
检测概念是否异动
|
||||
|
||||
Args:
|
||||
concept_name: 概念名称
|
||||
features: 当前时刻的特征(如果提供,会先更新历史)
|
||||
return_score: 是否返回异动分数
|
||||
|
||||
Returns:
|
||||
is_anomaly: 是否异动
|
||||
score: 异动分数(如果 return_score=True)
|
||||
"""
|
||||
# 更新历史
|
||||
if features is not None:
|
||||
self.update_history(concept_name, features)
|
||||
|
||||
# 检查历史数据是否足够
|
||||
if concept_name not in self.history_cache:
|
||||
return False, None
|
||||
|
||||
history = self.history_cache[concept_name]
|
||||
if len(history) < self.seq_len:
|
||||
return False, None
|
||||
|
||||
# 构建输入序列
|
||||
sequence = np.array(list(history)) # (seq_len, n_features)
|
||||
|
||||
# 标准化
|
||||
sequence = self.normalize(sequence)
|
||||
|
||||
# 转为 tensor
|
||||
x = torch.FloatTensor(sequence).unsqueeze(0) # (1, seq_len, n_features)
|
||||
x = x.to(self.device)
|
||||
|
||||
# 计算重构误差
|
||||
error = self.model.compute_reconstruction_error(x, reduction='none')
|
||||
|
||||
# 取最后一个时刻的误差作为当前分数
|
||||
score = error[0, -1].item()
|
||||
|
||||
# 判断是否异动
|
||||
is_anomaly = score > self.threshold
|
||||
|
||||
if return_score:
|
||||
return is_anomaly, score
|
||||
else:
|
||||
return is_anomaly, None
|
||||
|
||||
@torch.no_grad()
|
||||
def batch_detect(
|
||||
self,
|
||||
concept_features: Dict[str, Dict[str, float]]
|
||||
) -> Dict[str, Tuple[bool, float]]:
|
||||
"""
|
||||
批量检测多个概念
|
||||
|
||||
Args:
|
||||
concept_features: {concept_name: {feature_name: value}}
|
||||
|
||||
Returns:
|
||||
results: {concept_name: (is_anomaly, score)}
|
||||
"""
|
||||
results = {}
|
||||
|
||||
for concept_name, features in concept_features.items():
|
||||
is_anomaly, score = self.detect(concept_name, features)
|
||||
results[concept_name] = (is_anomaly, score)
|
||||
|
||||
return results
|
||||
|
||||
def get_anomaly_type(
    self,
    concept_name: str,
    features: Dict[str, float]
) -> str:
    """Classify the direction of a detected anomaly.

    Args:
        concept_name: concept name
        features: current features; the 'alpha' key decides the direction

    Returns:
        'surge_up' / 'surge_down' / 'normal'
    """
    flagged, _score = self.detect(concept_name, features)
    if not flagged:
        return 'normal'

    # Positive excess return => upward surge; zero or negative => downward.
    return 'surge_up' if features.get('alpha', 0.0) > 0 else 'surge_down'
|
||||
|
||||
def get_top_anomalies(
    self,
    concept_features: Dict[str, Dict[str, float]],
    top_k: int = 10
) -> List[Tuple[str, float, str]]:
    """Return the flagged concepts among the top_k highest anomaly scores.

    Args:
        concept_features: {concept_name: {feature_name: value}}
        top_k: number of top-scored candidates to consider

    Returns:
        [(concept_name, score, anomaly_type), ...] for flagged concepts only
    """
    detections = self.batch_detect(concept_features)

    # Keep only scored concepts, ranked best-first (stable sort).
    scored = [
        (name, flagged, score)
        for name, (flagged, score) in detections.items()
        if score is not None
    ]
    scored.sort(key=lambda item: item[2], reverse=True)

    top = []
    for name, flagged, score in scored[:top_k]:
        if not flagged:
            continue
        direction = 'surge_up' if concept_features[name].get('alpha', 0.0) > 0 else 'surge_down'
        top.append((name, score, direction))
    return top
|
||||
|
||||
def clear_history(self, concept_name: str = None):
    """Drop cached history, for one concept or for all of them.

    Args:
        concept_name: concept to forget; None wipes the whole cache.
            Unknown names are silently ignored.
    """
    if concept_name is None:
        self.history_cache.clear()
    else:
        self.history_cache.pop(concept_name, None)
|
||||
|
||||
|
||||
# ==================== 集成到现有系统 ====================
|
||||
|
||||
class MLAnomalyService:
    """
    ML anomaly-detection service.

    Wraps ConceptAnomalyDetector and optionally falls back to the legacy
    Alpha Z-Score rule when the ML model is unavailable or when a concept
    has not yet accumulated enough history for a full model window.
    """

    def __init__(
        self,
        checkpoint_dir: str = 'ml/checkpoints',
        fallback_to_alpha: bool = True
    ):
        """
        Args:
            checkpoint_dir: directory holding the trained model checkpoint
            fallback_to_alpha: when True, a failed model load degrades to
                the Alpha rule; when False the load error is re-raised
        """
        self.fallback_to_alpha = fallback_to_alpha
        self.ml_detector = None

        try:
            self.ml_detector = ConceptAnomalyDetector(checkpoint_dir)
            print("ML 异动检测服务初始化成功")
        except Exception as e:
            # Best-effort: a missing/broken checkpoint is tolerated when
            # the Alpha fallback is enabled.
            print(f"ML 模型加载失败: {e}")
            if not fallback_to_alpha:
                raise
            print("将回退到 Alpha-based 检测")

    def is_ml_available(self) -> bool:
        """Return True when the ML detector loaded successfully."""
        return self.ml_detector is not None

    def detect_anomaly(
        self,
        concept_name: str,
        features: Dict[str, float],
        alpha_threshold: float = 2.0
    ) -> Tuple[bool, float, str]:
        """
        Detect an anomaly for one concept.

        Args:
            concept_name: concept name
            features: feature dict (expects keys such as 'alpha_zscore')
            alpha_threshold: |alpha_zscore| cutoff used by the fallback rule

        Returns:
            (is_anomaly, score, method) where method is 'ml', 'alpha' or 'none'
        """
        # Prefer the ML detector once it has (almost) a full history window.
        if self.ml_detector is not None:
            history_len = self.ml_detector.get_history_length(concept_name)

            if history_len >= self.ml_detector.seq_len - 1:
                is_anomaly, score = self.ml_detector.detect(concept_name, features)
                if score is not None:
                    return is_anomaly, score, 'ml'
            else:
                # Not enough history yet: keep feeding the cache, but let
                # the Alpha rule answer for now.
                self.ml_detector.update_history(concept_name, features)

        # Fallback: |alpha_zscore| threshold rule. (The raw 'alpha' value is
        # not needed here -- only its Z-Score; the unused lookup was removed.)
        if self.fallback_to_alpha:
            score = abs(features.get('alpha_zscore', 0.0))
            return score > alpha_threshold, score, 'alpha'

        return False, 0.0, 'none'
|
||||
|
||||
|
||||
# ==================== 测试 ====================
|
||||
|
||||
if __name__ == "__main__":
    import random

    print("测试 ConceptAnomalyDetector...")

    # Skip gracefully when no trained checkpoint is available yet.
    checkpoint_dir = Path('ml/checkpoints')
    if not (checkpoint_dir / 'best_model.pt').exists():
        print("模型文件不存在,跳过测试")
        print("请先运行 train.py 训练模型")
        exit(0)

    detector = ConceptAnomalyDetector('ml/checkpoints')

    print("\n模拟实时检测...")
    concept_name = "人工智能"

    for minute in range(40):
        # Random baseline features for one simulated minute.
        features = {
            'alpha': random.gauss(0, 1),
            'alpha_delta': random.gauss(0, 0.5),
            'amt_ratio': random.gauss(1, 0.3),
            'amt_delta': random.gauss(0, 0.2),
            'rank_pct': random.random(),
            'limit_up_ratio': random.random() * 0.1,
        }

        # Inject a synthetic surge at minute 35.
        if minute == 35:
            features.update(alpha=5.0, alpha_delta=2.0, amt_ratio=3.0)

        is_anomaly, score = detector.detect(concept_name, features)
        history_len = detector.get_history_length(concept_name)

        if score is None:
            print(f"  t={minute:02d} | 历史={history_len} | 数据不足")
        else:
            status = "🔥 异动!" if is_anomaly else "正常"
            print(f"  t={minute:02d} | 历史={history_len} | 分数={score:.4f} | {status}")

    print("\n测试完成!")
|
||||
393
ml/model.py
Normal file
393
ml/model.py
Normal file
@@ -0,0 +1,393 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
LSTM Autoencoder 模型定义
|
||||
|
||||
用于概念异动检测:
|
||||
- 学习"正常"市场模式
|
||||
- 重构误差大的时刻 = 异动
|
||||
|
||||
模型结构(简洁有效):
|
||||
┌─────────────────────────────────────┐
|
||||
│ 输入: (batch, seq_len, n_features) │
|
||||
│ 过去30分钟的特征序列 │
|
||||
├─────────────────────────────────────┤
|
||||
│ LSTM Encoder │
|
||||
│ - 双向 LSTM │
|
||||
│ - 输出最后隐藏状态 │
|
||||
├─────────────────────────────────────┤
|
||||
│ Bottleneck (压缩层) │
|
||||
│ 降维到 latent_dim(关键!) │
|
||||
├─────────────────────────────────────┤
|
||||
│ LSTM Decoder │
|
||||
│ - 单向 LSTM │
|
||||
│ - 重构序列 │
|
||||
├─────────────────────────────────────┤
|
||||
│ 输出: (batch, seq_len, n_features) │
|
||||
│ 重构的特征序列 │
|
||||
└─────────────────────────────────────┘
|
||||
|
||||
为什么用 LSTM 而不是 Transformer:
|
||||
1. 参数更少,不容易过拟合
|
||||
2. 对于 6 维特征足够用
|
||||
3. 训练更稳定
|
||||
4. 瓶颈约束更容易控制
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class LSTMAutoencoder(nn.Module):
    """
    LSTM Autoencoder for anomaly detection.

    Design principles:
    - Keep the model simple to avoid overfitting.
    - A tightly constrained bottleneck forces it to learn only the dominant
      "normal" patterns.
    - Anomalies do not survive the narrow bottleneck, so their
      reconstruction error is large.

    NOTE(review): submodule attribute names (encoder, bottleneck_down, ...)
    are part of the checkpoint's state_dict keys -- do not rename them.
    """

    def __init__(
        self,
        n_features: int = 6,
        hidden_dim: int = 32,        # LSTM hidden size (deliberately small)
        latent_dim: int = 4,         # bottleneck size (very small -- the key knob)
        num_layers: int = 1,         # LSTM depth
        dropout: float = 0.2,
        bidirectional: bool = True,  # bidirectional encoder
    ):
        super().__init__()

        self.n_features = n_features
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        # Encoder: (optionally bidirectional) LSTM over the input window.
        self.encoder = nn.LSTM(
            input_size=n_features,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,  # torch warns on 1-layer dropout
            bidirectional=bidirectional
        )

        # Bottleneck: squeeze every timestep down to a tiny latent vector.
        encoder_output_dim = hidden_dim * self.num_directions
        self.bottleneck_down = nn.Sequential(
            nn.Linear(encoder_output_dim, latent_dim),
            nn.Tanh(),  # bound the latent range -> additional constraint
        )

        # LeakyReLU instead of ReLU on the way back up:
        # Z-Score inputs live roughly in [-5, +5]; ReLU would clip negative
        # values and lose the "falling" half of the signal. LeakyReLU keeps
        # negatives (scaled by 0.1).
        self.bottleneck_up = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(negative_slope=0.1),
        )

        # Decoder: unidirectional LSTM reconstructing the sequence.
        self.decoder = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=False  # decoder is one-directional
        )

        # Projection back to the original feature space.
        self.output_layer = nn.Linear(hidden_dim, n_features)

        # Shared dropout, applied after both encoder and decoder outputs.
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        """Initialize weights: Xavier for input-to-hidden, orthogonal for
        hidden-to-hidden (recurrent), zeros for biases."""
        for name, param in self.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encoder pass.

        Args:
            x: (batch, seq_len, n_features)
        Returns:
            latent: (batch, seq_len, latent_dim) per-timestep compressed codes
            encoder_outputs: (batch, seq_len, hidden_dim * num_directions)
        """
        # LSTM encoding
        encoder_outputs, (h_n, c_n) = self.encoder(x)
        # encoder_outputs: (batch, seq_len, hidden_dim * num_directions)

        encoder_outputs = self.dropout(encoder_outputs)

        # Compress each timestep into the latent space.
        latent = self.bottleneck_down(encoder_outputs)
        # latent: (batch, seq_len, latent_dim)

        return latent, encoder_outputs

    def decode(self, latent: torch.Tensor, seq_len: int) -> torch.Tensor:
        """
        Decoder pass.

        Args:
            latent: (batch, seq_len, latent_dim)
            seq_len: sequence length (currently unused; the decoder infers
                length from the latent tensor -- kept for interface stability)
        Returns:
            output: (batch, seq_len, n_features)
        """
        # Expand back out of the latent space.
        decoder_input = self.bottleneck_up(latent)
        # decoder_input: (batch, seq_len, hidden_dim)

        # LSTM decoding
        decoder_outputs, _ = self.decoder(decoder_input)
        # decoder_outputs: (batch, seq_len, hidden_dim)

        decoder_outputs = self.dropout(decoder_outputs)

        # Project back to the original feature space.
        output = self.output_layer(decoder_outputs)
        # output: (batch, seq_len, n_features)

        return output

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Full autoencoder pass.

        Args:
            x: (batch, seq_len, n_features)
        Returns:
            output: (batch, seq_len, n_features) reconstruction
            latent: (batch, seq_len, latent_dim) latent codes
        """
        batch_size, seq_len, _ = x.shape

        # Encode
        latent, _ = self.encode(x)

        # Decode
        output = self.decode(latent, seq_len)

        return output, latent

    def compute_reconstruction_error(
        self,
        x: torch.Tensor,
        reduction: str = 'none'
    ) -> torch.Tensor:
        """
        Compute the reconstruction error of x.

        Args:
            x: (batch, seq_len, n_features)
            reduction: 'none' | 'mean' | 'sum'
        Returns:
            error: (batch, seq_len) feature-averaged error when
                reduction='none'; otherwise a scalar tensor
        Raises:
            ValueError: for an unknown reduction mode
        """
        output, _ = self.forward(x)

        # MSE per feature per timestep
        error = F.mse_loss(output, x, reduction='none')

        if reduction == 'none':
            # (batch, seq_len, n_features) -> (batch, seq_len)
            return error.mean(dim=-1)
        elif reduction == 'mean':
            return error.mean()
        elif reduction == 'sum':
            return error.sum()
        else:
            raise ValueError(f"Unknown reduction: {reduction}")

    def detect_anomaly(
        self,
        x: torch.Tensor,
        threshold: float = None,
        return_scores: bool = True
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Flag anomalous timesteps by reconstruction error.

        Args:
            x: (batch, seq_len, n_features)
            threshold: anomaly cutoff; when None no flags are produced
            return_scores: whether to also return the raw scores
        Returns:
            is_anomaly: (batch, seq_len) bool tensor, or None when threshold is None
            scores: (batch, seq_len) anomaly scores, or None when return_scores is False
        """
        scores = self.compute_reconstruction_error(x, reduction='none')

        is_anomaly = None
        if threshold is not None:
            is_anomaly = scores > threshold

        if return_scores:
            return is_anomaly, scores
        else:
            return is_anomaly, None
|
||||
|
||||
|
||||
# 为了兼容性,创建别名
|
||||
TransformerAutoencoder = LSTMAutoencoder
|
||||
|
||||
|
||||
# ==================== 损失函数 ====================
|
||||
|
||||
class AnomalyDetectionLoss(nn.Module):
    """
    Reconstruction loss for the anomaly-detection autoencoder.

    Plain (optionally feature-weighted) MSE between the reconstruction and
    the original input.
    """

    def __init__(
        self,
        feature_weights: torch.Tensor = None,
    ):
        super().__init__()
        # Optional per-feature weights, broadcast over (batch, seq_len, n_features).
        self.feature_weights = feature_weights

    def forward(
        self,
        output: torch.Tensor,
        target: torch.Tensor,
        latent: torch.Tensor = None
    ) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            output: (batch, seq_len, n_features) reconstruction
            target: (batch, seq_len, n_features) original input
            latent: (batch, seq_len, latent_dim) latent codes (unused)

        Returns:
            loss: scalar total loss tensor
            loss_dict: component loss values for logging
        """
        per_element = F.mse_loss(output, target, reduction='none')

        weights = self.feature_weights
        if weights is not None:
            per_element = per_element * weights.to(per_element.device)

        recon = per_element.mean()
        value = recon.item()
        return recon, {'total': value, 'reconstruction': value}
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
def count_parameters(model: nn.Module) -> int:
    """Count the model's trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
||||
|
||||
|
||||
def create_model(config: dict = None) -> LSTMAutoencoder:
    """
    Build an LSTMAutoencoder from an optional config dict.

    Defaults to a small LSTM configuration suited to anomaly detection.
    Legacy Transformer-style keys (d_model, nhead, ...) are translated or
    dropped for backward compatibility.

    Fix: the caller's dict is copied before the pop()s below, so it is no
    longer mutated as a side effect.
    """
    default_config = {
        'n_features': 6,
        'hidden_dim': 32,   # small on purpose
        'latent_dim': 4,    # very small -- the key capacity constraint
        'num_layers': 1,
        'dropout': 0.2,
        'bidirectional': True,
    }

    if config:
        # Work on a copy so the caller's dict survives the pops unchanged.
        config = dict(config)

        # Translate the legacy Transformer keys that have an LSTM analogue.
        if 'd_model' in config:
            config['hidden_dim'] = config.pop('d_model') // 2
        if 'num_encoder_layers' in config:
            config['num_layers'] = config.pop('num_encoder_layers')

        # Drop the legacy keys with no LSTM counterpart.
        for legacy_key in ('num_decoder_layers', 'nhead', 'dim_feedforward',
                           'max_seq_len', 'use_instance_norm'):
            config.pop(legacy_key, None)

        default_config.update(config)

    model = LSTMAutoencoder(**default_config)
    param_count = count_parameters(model)
    print(f"模型参数量: {param_count:,}")

    if param_count > 100000:
        print(f"⚠️ 警告: 参数量较大({param_count:,}),可能过拟合")
    else:
        print(f"✓ 参数量适中(LSTM Autoencoder)")

    return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # End-to-end smoke test on random data.
    print("测试 LSTM Autoencoder...")

    model = create_model()

    batch_size, seq_len, n_features = 32, 30, 6
    x = torch.randn(batch_size, seq_len, n_features)

    # Forward pass.
    output, latent = model(x)
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {output.shape}")
    print(f"隐向量形状: {latent.shape}")

    # Per-timestep reconstruction error.
    error = model.compute_reconstruction_error(x)
    print(f"重构误差形状: {error.shape}")
    print(f"平均重构误差: {error.mean().item():.4f}")

    # Thresholded anomaly detection.
    is_anomaly, scores = model.detect_anomaly(x, threshold=0.5)
    print(f"异动检测结果形状: {is_anomaly.shape if is_anomaly is not None else 'None'}")
    print(f"异动分数形状: {scores.shape}")

    # Loss function.
    criterion = AnomalyDetectionLoss()
    loss, loss_dict = criterion(output, x, latent)
    print(f"损失: {loss.item():.4f}")

    print("\n测试通过!")
|
||||
537
ml/prepare_data.py
Normal file
537
ml/prepare_data.py
Normal file
@@ -0,0 +1,537 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据准备脚本 - 为 Transformer Autoencoder 准备训练数据
|
||||
|
||||
从 ClickHouse 提取历史分钟数据,计算以下特征:
|
||||
1. alpha - 超额收益(概念涨幅 - 大盘涨幅)
|
||||
2. alpha_delta - Alpha 变化率(5分钟)
|
||||
3. amt_ratio - 成交额相对均值(当前/过去20分钟均值)
|
||||
4. amt_delta - 成交额变化率
|
||||
5. rank_pct - Alpha 排名百分位
|
||||
6. limit_up_ratio - 涨停股占比
|
||||
|
||||
输出:按交易日存储的特征文件(parquet格式)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta, date
|
||||
from sqlalchemy import create_engine, text
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from multiprocessing import Manager
|
||||
import multiprocessing
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
# NOTE(review): database/ES credentials are hard-coded here (and duplicated
# across the ml/ scripts); consider loading them from environment variables
# or a config file kept out of version control.
MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)

ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
ES_INDEX = 'concept_library_v3'

CLICKHOUSE_CONFIG = {
    'host': '127.0.0.1',
    'port': 9000,
    'user': 'default',
    'password': 'Zzl33818!',
    'database': 'stock'
}

# Output directory for the per-day parquet feature files.
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data')
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Feature-engineering parameters.
FEATURE_CONFIG = {
    'alpha_delta_window': 5,        # alpha-change lookback (minutes)
    'amt_ma_window': 20,            # turnover moving-average window (minutes)
    'limit_up_threshold': 9.8,      # limit-up threshold (%)
    'limit_down_threshold': -9.8,   # limit-down threshold (%)
}

# Benchmark index used to compute excess return (alpha).
REFERENCE_INDEX = '000001.SH'

# ==================== Logging ====================

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
def get_ch_client():
    """Open a fresh ClickHouse client using the module-level config."""
    client = Client(**CLICKHOUSE_CONFIG)
    return client
|
||||
|
||||
|
||||
def generate_id(name: str) -> str:
    """Derive a stable 16-hex-char id from a name via MD5."""
    digest = hashlib.md5(name.encode('utf-8')).hexdigest()
    return digest[:16]
|
||||
|
||||
|
||||
def code_to_ch_format(code: str) -> str:
    """Map a bare 6-digit A-share code to its exchange-suffixed form.

    6xxxxx -> Shanghai (.SH); 0xxxxx / 3xxxxx -> Shenzhen (.SZ);
    anything else -> Beijing (.BJ). Malformed codes yield None.
    """
    if not code or len(code) != 6 or not code.isdigit():
        return None
    if code[0] == '6':
        suffix = 'SH'
    elif code[0] in ('0', '3'):
        suffix = 'SZ'
    else:
        suffix = 'BJ'
    return f"{code}.{suffix}"
|
||||
|
||||
|
||||
# ==================== 获取概念列表 ====================
|
||||
|
||||
def get_all_concepts() -> List[dict]:
    """Scroll every leaf concept (with its member stock codes) out of ES.

    Concepts without any valid stock code are skipped.
    """
    concepts = []

    query = {
        "query": {"match_all": {}},
        "size": 100,
        "_source": ["concept_id", "concept", "stocks"]
    }

    resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
    scroll_id = resp['_scroll_id']
    hits = resp['hits']['hits']

    while hits:
        for hit in hits:
            source = hit['_source']

            # Collect valid member codes for this concept.
            raw_stocks = source.get('stocks')
            codes = []
            if isinstance(raw_stocks, list):
                codes = [
                    entry['code'] for entry in raw_stocks
                    if isinstance(entry, dict) and entry.get('code')
                ]

            if codes:
                concepts.append({
                    'concept_id': source.get('concept_id'),
                    'concept_name': source.get('concept'),
                    'stocks': codes,
                })

        # Next page of the scroll.
        resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
        scroll_id = resp['_scroll_id']
        hits = resp['hits']['hits']

    ES_CLIENT.clear_scroll(scroll_id=scroll_id)
    print(f"获取到 {len(concepts)} 个概念")
    return concepts
|
||||
|
||||
|
||||
# ==================== 获取交易日列表 ====================
|
||||
|
||||
def get_trading_days(start_date: str, end_date: str) -> List[str]:
    """List trading days with minute data in [start_date, end_date].

    NOTE: raises IndexError on an empty range (days[0]) -- same as before.
    """
    client = get_ch_client()

    sql = f"""
    SELECT DISTINCT toDate(timestamp) as trade_date
    FROM stock_minute
    WHERE toDate(timestamp) >= '{start_date}'
    AND toDate(timestamp) <= '{end_date}'
    ORDER BY trade_date
    """

    rows = client.execute(sql)
    days = [day.strftime('%Y-%m-%d') for (day,) in rows]
    print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
    return days
|
||||
|
||||
|
||||
# ==================== 获取单日数据 ====================
|
||||
|
||||
def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
    """Fetch one day of minute bars for the given stocks from ClickHouse.

    Returns a DataFrame with columns [code, timestamp, close, volume, amt],
    where `code` is the bare 6-digit code (exchange suffix mapped back off).
    Empty DataFrame when no valid codes or no rows.
    """
    client = get_ch_client()

    # suffixed ClickHouse code -> bare code (also drives the IN clause).
    code_map = {}
    for bare in stock_codes:
        suffixed = code_to_ch_format(bare)
        if suffixed:
            code_map[suffixed] = bare

    if not code_map:
        return pd.DataFrame()

    ch_codes_str = "','".join(code_map)

    query = f"""
    SELECT
        code,
        timestamp,
        close,
        volume,
        amt
    FROM stock_minute
    WHERE toDate(timestamp) = '{trade_date}'
    AND code IN ('{ch_codes_str}')
    ORDER BY code, timestamp
    """

    rows = client.execute(query)
    if not rows:
        return pd.DataFrame()

    frame = pd.DataFrame(rows, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
    frame['code'] = frame['ch_code'].map(code_map)
    frame = frame.dropna(subset=['code'])

    return frame[['code', 'timestamp', 'close', 'volume', 'amt']]
|
||||
|
||||
|
||||
def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
    """Fetch one day of minute bars for a single index from ClickHouse.

    Returns an empty DataFrame when the day has no rows.
    """
    client = get_ch_client()

    query = f"""
    SELECT
        timestamp,
        close,
        volume,
        amt
    FROM index_minute
    WHERE toDate(timestamp) = '{trade_date}'
    AND code = '{index_code}'
    ORDER BY timestamp
    """

    rows = client.execute(query)
    if not rows:
        return pd.DataFrame()

    return pd.DataFrame(rows, columns=['timestamp', 'close', 'volume', 'amt'])
|
||||
|
||||
|
||||
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
    """Fetch each stock's previous-day close (prior trading day's F007N).

    NOTE: F007N is a day's last traded price (that day's close); F002N is
    that day's *previous* close. We therefore read the prior trading day's
    F007N as today's prev-close. Codes are validated as 6-digit numerics
    before being interpolated into the SQL.
    """
    valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
    if not valid_codes:
        return {}

    codes_str = "','".join(valid_codes)

    query = f"""
    SELECT SECCODE, F007N
    FROM ea_trade
    WHERE SECCODE IN ('{codes_str}')
    AND TRADEDATE = (
        SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
    )
    AND F007N IS NOT NULL AND F007N > 0
    """

    try:
        with MYSQL_ENGINE.connect() as conn:
            rows = conn.execute(text(query))
            return {code: float(px) for code, px in rows if px}
    except Exception as e:
        # Best-effort: a DB hiccup degrades to "no prev closes" upstream.
        print(f"获取昨收价失败: {e}")
        return {}
|
||||
|
||||
|
||||
def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float:
    """Fetch the index close (F006N) of the last trading day before trade_date.

    Returns None on missing data or DB errors (logged, not raised).
    """
    bare_code = index_code.split('.')[0]

    try:
        with MYSQL_ENGINE.connect() as conn:
            row = conn.execute(text("""
                SELECT F006N FROM ea_exchangetrade
                WHERE INDEXCODE = :code AND TRADEDATE < :today
                ORDER BY TRADEDATE DESC LIMIT 1
            """), {'code': bare_code, 'today': trade_date}).fetchone()

            if row and row[0]:
                return float(row[0])
    except Exception as e:
        print(f"获取指数昨收失败: {e}")

    return None
|
||||
|
||||
|
||||
# ==================== 计算特征 ====================
|
||||
|
||||
def compute_daily_features(
    trade_date: str,
    concepts: List[dict],
    all_stocks: List[str]
) -> pd.DataFrame:
    """
    Compute per-minute features for every concept on one trading day.

    Returns a DataFrame with one row per (timestamp, concept_id) and columns:
    alpha, alpha_delta, amt_ratio, amt_delta, rank_pct, limit_up_ratio,
    limit_down_ratio, total_amt, stock_count, timestamp.
    Returns an empty DataFrame when any required input is missing.
    """

    # 1. Load the day's minute bars (stocks + benchmark index).
    stock_df = get_daily_stock_data(trade_date, all_stocks)
    if stock_df.empty:
        return pd.DataFrame()

    index_df = get_daily_index_data(trade_date)
    if index_df.empty:
        return pd.DataFrame()

    # 2. Previous-day closes, needed to turn prices into intraday change %.
    prev_close = get_prev_close(all_stocks, trade_date)
    index_prev_close = get_index_prev_close(trade_date)

    if not prev_close or not index_prev_close:
        return pd.DataFrame()

    # 3. Per-stock change % vs previous close (stocks without a prev close dropped).
    stock_df['prev_close'] = stock_df['code'].map(prev_close)
    stock_df = stock_df.dropna(subset=['prev_close'])
    stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100

    # 4. Index change % per minute, the baseline for alpha (excess return).
    index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
    index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))

    # 5. All minute timestamps present in the stock data.
    timestamps = sorted(stock_df['timestamp'].unique())

    # 6. Walk the day minute by minute, computing concept-level features.
    results = []

    # concept_id -> member stock codes / display name (names currently unused).
    concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
    concept_names = {c['concept_id']: c['concept_name'] for c in concepts}

    # Intraday running history per concept; drives the delta / moving-average
    # features, so this loop must remain sequential in timestamp order.
    concept_history = {cid: {'alpha': [], 'amt': []} for cid in concept_stocks}

    for ts in timestamps:
        ts_stock_data = stock_df[stock_df['timestamp'] == ts]
        index_change = index_change_map.get(ts, 0)

        # code -> change% / turnover at this minute.
        stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
        stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))

        concept_features = []

        for concept_id, stocks in concept_stocks.items():
            # Member-stock slices for this concept at this minute.
            concept_changes = [stock_change[s] for s in stocks if s in stock_change]
            concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]

            if not concept_changes:
                continue

            # Basic aggregates.
            avg_change = np.mean(concept_changes)
            total_amt = sum(concept_amts)

            # Alpha = concept change - index change (excess return).
            alpha = avg_change - index_change

            # Share of members at limit-up / limit-down.
            limit_up_count = sum(1 for c in concept_changes if c >= FEATURE_CONFIG['limit_up_threshold'])
            limit_down_count = sum(1 for c in concept_changes if c <= FEATURE_CONFIG['limit_down_threshold'])
            limit_up_ratio = limit_up_count / len(concept_changes)
            limit_down_ratio = limit_down_count / len(concept_changes)

            # Append to history BEFORE computing the deltas below: the
            # -window-1 indexing assumes the current bar is already included.
            history = concept_history[concept_id]
            history['alpha'].append(alpha)
            history['amt'].append(total_amt)

            # Alpha change over the configured lookback window.
            alpha_delta = 0
            if len(history['alpha']) > FEATURE_CONFIG['alpha_delta_window']:
                alpha_delta = alpha - history['alpha'][-FEATURE_CONFIG['alpha_delta_window']-1]

            # Turnover vs its trailing moving average (current bar excluded).
            amt_ratio = 1.0
            amt_delta = 0
            if len(history['amt']) > FEATURE_CONFIG['amt_ma_window']:
                amt_ma = np.mean(history['amt'][-FEATURE_CONFIG['amt_ma_window']-1:-1])
                if amt_ma > 0:
                    amt_ratio = total_amt / amt_ma
                    # NOTE(review): the len>1 re-guard is redundant here (len
                    # already exceeds the MA window) -- kept as-is.
                    amt_delta = total_amt - history['amt'][-2] if len(history['amt']) > 1 else 0

            concept_features.append({
                'concept_id': concept_id,
                'alpha': alpha,
                'alpha_delta': alpha_delta,
                'amt_ratio': amt_ratio,
                'amt_delta': amt_delta,
                'limit_up_ratio': limit_up_ratio,
                'limit_down_ratio': limit_down_ratio,
                'total_amt': total_amt,
                'stock_count': len(concept_changes),
            })

        if not concept_features:
            continue

        # Cross-sectional alpha rank (percentile) at this minute.
        concept_df = pd.DataFrame(concept_features)
        concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)

        # Tag the rows with the minute they belong to.
        concept_df['timestamp'] = ts
        results.append(concept_df)

    if not results:
        return pd.DataFrame()

    # Concatenate all minutes into one frame.
    final_df = pd.concat(results, ignore_index=True)

    # Scale amt_delta by its daily std to normalize across concepts.
    if 'amt_delta' in final_df.columns:
        amt_delta_std = final_df['amt_delta'].std()
        if amt_delta_std > 0:
            final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std

    return final_df
|
||||
|
||||
|
||||
# ==================== 主流程 ====================
|
||||
|
||||
def process_single_day(args) -> Tuple[str, bool]:
    """Worker: compute and persist one trading day's features (idempotent).

    Args:
        args: (trade_date, concepts, all_stocks) tuple, packed this way for
            the process-pool submit interface.

    Returns:
        (trade_date, success) tuple.
    """
    trade_date, concepts, all_stocks = args
    output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')

    # Idempotency: a day already on disk counts as done.
    if os.path.exists(output_file):
        print(f"[{trade_date}] 已存在,跳过")
        return (trade_date, True)

    print(f"[{trade_date}] 开始处理...")

    try:
        features_df = compute_daily_features(trade_date, concepts, all_stocks)

        if features_df.empty:
            print(f"[{trade_date}] 无数据")
            return (trade_date, False)

        features_df.to_parquet(output_file, index=False)
        print(f"[{trade_date}] 保存完成")
        return (trade_date, True)

    except Exception as e:
        # Any failure (compute or save) is reported but must not kill the pool.
        print(f"[{trade_date}] 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return (trade_date, False)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: compute features for every trading day in a range,
    fanning the per-day work out over a process pool."""
    import argparse
    from tqdm import tqdm

    parser = argparse.ArgumentParser(description='准备训练数据')
    parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
    parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
    parser.add_argument('--workers', type=int, default=18, help='并行进程数(默认18)')
    parser.add_argument('--force', action='store_true', help='强制重新处理已存在的文件')

    args = parser.parse_args()

    # Default end date: today.
    end_date = args.end or datetime.now().strftime('%Y-%m-%d')

    print("=" * 60)
    print("数据准备 - Transformer Autoencoder 训练数据")
    print("=" * 60)
    print(f"日期范围: {args.start} ~ {end_date}")
    print(f"并行进程数: {args.workers}")

    # 1. Concept universe from ES.
    concepts = get_all_concepts()

    # Union of all member stocks across concepts.
    all_stocks = list(set(s for c in concepts for s in c['stocks']))
    print(f"股票总数: {len(all_stocks)}")

    # 2. Trading days with data in the requested range.
    trading_days = get_trading_days(args.start, end_date)

    if not trading_days:
        print("无交易日数据")
        return

    # --force: delete existing outputs so the workers recompute them.
    if args.force:
        for trade_date in trading_days:
            output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
            if os.path.exists(output_file):
                os.remove(output_file)
                print(f"删除已有文件: {output_file}")

    # 3. One task tuple per day (concepts/all_stocks get pickled per task).
    tasks = [(trade_date, concepts, all_stocks) for trade_date in trading_days]

    print(f"\n开始处理 {len(trading_days)} 个交易日({args.workers} 进程并行)...")

    # 4. Fan out over a process pool, tracking progress with tqdm.
    success_count = 0
    failed_dates = []

    with ProcessPoolExecutor(max_workers=args.workers) as executor:
        # Submit everything up front; map each future back to its trade_date.
        futures = {executor.submit(process_single_day, task): task[0] for task in tasks}

        # tqdm progress over completed futures.
        with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
            for future in as_completed(futures):
                trade_date = futures[future]
                try:
                    result_date, success = future.result()
                    if success:
                        success_count += 1
                    else:
                        failed_dates.append(result_date)
                except Exception as e:
                    # A crashed worker process counts as a failed day.
                    print(f"\n[{trade_date}] 进程异常: {e}")
                    failed_dates.append(trade_date)
                pbar.update(1)

    print("\n" + "=" * 60)
    print(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
    if failed_dates:
        print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
    print(f"数据保存在: {OUTPUT_DIR}")
    print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
715
ml/prepare_data_v2.py
Normal file
715
ml/prepare_data_v2.py
Normal file
@@ -0,0 +1,715 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据准备 V2 - 基于时间片对齐的特征计算(修复版)
|
||||
|
||||
核心改进:
|
||||
1. 时间片对齐:9:35 和历史的 9:35 比,而不是和前30分钟比
|
||||
2. Z-Score 特征:相对于同时间片历史分布的偏离程度
|
||||
3. 滚动窗口基线:每个日期使用它之前 N 天的数据作为基线(不是固定的最后 N 天!)
|
||||
4. 基于 Z-Score 的动量:消除一天内波动率异构性
|
||||
|
||||
修复:
|
||||
- 滚动窗口基线:避免未来数据泄露
|
||||
- Z-Score 动量:消除早盘/尾盘波动率差异
|
||||
- 进程级数据库单例:避免连接池爆炸
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import create_engine, text
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from tqdm import tqdm
|
||||
from collections import defaultdict
|
||||
import warnings
|
||||
import pickle
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# ==================== Configuration ====================

# NOTE(review): credentials are hard-coded here; consider moving them to
# environment variables or a secrets store outside version control.
MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock"
ES_HOST = 'http://127.0.0.1:9200'
ES_INDEX = 'concept_library_v3'

CLICKHOUSE_CONFIG = {
    'host': '127.0.0.1',
    'port': 9000,
    'user': 'default',
    'password': 'Zzl33818!',
    'database': 'stock'
}

# Benchmark index used when computing concept alpha (Shanghai Composite).
REFERENCE_INDEX = '000001.SH'

# Output directories; created eagerly so worker processes can write immediately.
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data_v2')
BASELINE_DIR = os.path.join(OUTPUT_DIR, 'baselines')
RAW_CACHE_DIR = os.path.join(OUTPUT_DIR, 'raw_cache')
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(BASELINE_DIR, exist_ok=True)
os.makedirs(RAW_CACHE_DIR, exist_ok=True)

# Feature configuration
CONFIG = {
    'baseline_days': 20,           # rolling-baseline window size (trading days)
    'min_baseline_samples': 10,    # minimum samples for a slot baseline to be valid
    'limit_up_threshold': 9.8,     # pct change treated as limit-up
    'limit_down_threshold': -9.8,  # pct change treated as limit-down
    'zscore_clip': 5.0,            # z-scores are clipped to [-5, 5]
}

# Feature columns produced by this script.
FEATURES_V2 = [
    'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore',
    'momentum_3m', 'momentum_5m', 'limit_up_ratio',
]

# ==================== Process-level singletons (avoid connection-pool blowup) ====================

# Per-process global connection handles; created lazily by the get_* helpers
# or eagerly via init_process_connections() in each worker process.
_process_mysql_engine = None
_process_es_client = None
_process_ch_client = None
|
||||
|
||||
|
||||
def init_process_connections():
    """Eagerly create this process's MySQL/ES/ClickHouse singletons."""
    global _process_mysql_engine, _process_es_client, _process_ch_client
    _process_mysql_engine = create_engine(
        MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5
    )
    _process_es_client = Elasticsearch([ES_HOST])
    _process_ch_client = Client(**CLICKHOUSE_CONFIG)
|
||||
|
||||
|
||||
def get_mysql_engine():
    """Return the process-wide MySQL engine, creating it on first use."""
    global _process_mysql_engine
    if _process_mysql_engine is not None:
        return _process_mysql_engine
    _process_mysql_engine = create_engine(
        MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5
    )
    return _process_mysql_engine
|
||||
|
||||
|
||||
def get_es_client():
    """Return the process-wide Elasticsearch client, creating it on first use."""
    global _process_es_client
    if _process_es_client is not None:
        return _process_es_client
    _process_es_client = Elasticsearch([ES_HOST])
    return _process_es_client
|
||||
|
||||
|
||||
def get_ch_client():
    """Return the process-wide ClickHouse client, creating it on first use."""
    global _process_ch_client
    if _process_ch_client is not None:
        return _process_ch_client
    _process_ch_client = Client(**CLICKHOUSE_CONFIG)
    return _process_ch_client
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
def code_to_ch_format(code: str) -> Optional[str]:
    """Convert a bare 6-digit A-share code to the suffixed ClickHouse form.

    6xxxxx -> Shanghai (.SH); 0xxxxx / 3xxxxx -> Shenzhen (.SZ);
    any other valid 6-digit code (e.g. 4xxxxx / 8xxxxx) -> Beijing (.BJ).

    Returns:
        The suffixed code, or None when ``code`` is not exactly 6 digits.
        (The return annotation was previously ``str`` despite the None path.)
    """
    if not code or len(code) != 6 or not code.isdigit():
        return None
    if code.startswith('6'):
        return f"{code}.SH"
    if code[0] in ('0', '3'):
        return f"{code}.SZ"
    return f"{code}.BJ"
|
||||
|
||||
|
||||
def time_to_slot(ts) -> str:
    """Normalise a timestamp (or an already-formatted string) to 'HH:MM'."""
    return ts if isinstance(ts, str) else ts.strftime('%H:%M')
|
||||
|
||||
|
||||
# ==================== 获取概念列表 ====================
|
||||
|
||||
def get_all_concepts() -> List[dict]:
    """Fetch all leaf concepts (id, name, member stock codes) from ES.

    Drains the scroll cursor page by page so result sets larger than one
    page are fully read. Concepts with no valid member code are dropped.
    """
    es_client = get_es_client()
    request = {
        "query": {"match_all": {}},
        "size": 100,
        "_source": ["concept_id", "concept", "stocks"],
    }

    concepts: List[dict] = []
    resp = es_client.search(index=ES_INDEX, body=request, scroll='2m')
    scroll_id = resp['_scroll_id']
    hits = resp['hits']['hits']

    while hits:
        for hit in hits:
            source = hit['_source']
            raw_stocks = source.get('stocks')
            codes = []
            if isinstance(raw_stocks, list):
                codes = [
                    entry['code']
                    for entry in raw_stocks
                    if isinstance(entry, dict) and entry.get('code')
                ]
            if codes:
                concepts.append({
                    'concept_id': source.get('concept_id'),
                    'concept_name': source.get('concept'),
                    'stocks': codes,
                })

        resp = es_client.scroll(scroll_id=scroll_id, scroll='2m')
        scroll_id = resp['_scroll_id']
        hits = resp['hits']['hits']

    es_client.clear_scroll(scroll_id=scroll_id)
    print(f"获取到 {len(concepts)} 个概念")
    return concepts
|
||||
|
||||
|
||||
# ==================== 获取交易日列表 ====================
|
||||
|
||||
def get_trading_days(start_date: str, end_date: str) -> List[str]:
    """Return the sorted trading dates (YYYY-MM-DD) found in stock_minute."""
    client = get_ch_client()

    sql = f"""
    SELECT DISTINCT toDate(timestamp) as trade_date
    FROM stock_minute
    WHERE toDate(timestamp) >= '{start_date}'
      AND toDate(timestamp) <= '{end_date}'
    ORDER BY trade_date
    """

    rows = client.execute(sql)
    days = [day.strftime('%Y-%m-%d') for (day,) in rows]
    if days:
        print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
    return days
|
||||
|
||||
|
||||
# ==================== 获取昨收价 ====================
|
||||
|
||||
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
    """Map stock code -> previous trading day's close (ea_trade.F007N).

    Codes that are not exactly 6 digits are ignored (which also keeps the
    inlined IN-list safe); NULL / non-positive closes are filtered in SQL.
    Returns {} on any query failure.
    """
    valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
    if not valid_codes:
        return {}

    in_list = "','".join(valid_codes)
    sql = f"""
    SELECT SECCODE, F007N
    FROM ea_trade
    WHERE SECCODE IN ('{in_list}')
      AND TRADEDATE = (
        SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
      )
      AND F007N IS NOT NULL AND F007N > 0
    """

    try:
        with get_mysql_engine().connect() as conn:
            rows = conn.execute(text(sql))
            return {code: float(px) for code, px in rows if px}
    except Exception as e:
        print(f"获取昨收价失败: {e}")
        return {}
|
||||
|
||||
|
||||
def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> Optional[float]:
    """Return the index close from the last trading day before ``trade_date``.

    Looks up ea_exchangetrade.F006N using the bare index code (suffix such
    as '.SH' is stripped before the query).

    Returns:
        The close as a float, or None when no row is found or the query
        fails. (The return annotation was previously ``float`` despite the
        None path.)
    """
    code_no_suffix = index_code.split('.')[0]

    try:
        engine = get_mysql_engine()
        with engine.connect() as conn:
            result = conn.execute(text("""
                SELECT F006N FROM ea_exchangetrade
                WHERE INDEXCODE = :code AND TRADEDATE < :today
                ORDER BY TRADEDATE DESC LIMIT 1
            """), {'code': code_no_suffix, 'today': trade_date}).fetchone()

            if result and result[0]:
                return float(result[0])
    except Exception as e:
        print(f"获取指数昨收失败: {e}")

    return None
|
||||
|
||||
|
||||
# ==================== 获取分钟数据 ====================
|
||||
|
||||
def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
    """Load one day of minute bars for the given stocks from ClickHouse.

    Returns columns [code, timestamp, close, volume, amt] where ``code`` is
    mapped back to the bare 6-digit form; an empty DataFrame when no codes
    are valid or no rows exist for the day.
    """
    client = get_ch_client()

    # Build the suffixed query codes and remember how to map them back.
    suffixed_to_bare: Dict[str, str] = {}
    for bare in stock_codes:
        suffixed = code_to_ch_format(bare)
        if suffixed:
            suffixed_to_bare[suffixed] = bare

    if not suffixed_to_bare:
        return pd.DataFrame()

    in_list = "','".join(suffixed_to_bare)
    sql = f"""
    SELECT code, timestamp, close, volume, amt
    FROM stock_minute
    WHERE toDate(timestamp) = '{trade_date}'
      AND code IN ('{in_list}')
    ORDER BY code, timestamp
    """

    rows = client.execute(sql)
    if not rows:
        return pd.DataFrame()

    frame = pd.DataFrame(rows, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
    frame['code'] = frame['ch_code'].map(suffixed_to_bare)
    frame = frame.dropna(subset=['code'])

    return frame[['code', 'timestamp', 'close', 'volume', 'amt']]
|
||||
|
||||
|
||||
def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
    """Load one day of minute bars for a single index from ClickHouse."""
    sql = f"""
    SELECT timestamp, close, volume, amt
    FROM index_minute
    WHERE toDate(timestamp) = '{trade_date}'
      AND code = '{index_code}'
    ORDER BY timestamp
    """

    rows = get_ch_client().execute(sql)
    if not rows:
        return pd.DataFrame()

    return pd.DataFrame(rows, columns=['timestamp', 'close', 'volume', 'amt'])
|
||||
|
||||
|
||||
# ==================== 计算原始概念特征(单日)====================
|
||||
|
||||
def compute_raw_concept_features(
    trade_date: str,
    concepts: List[dict],
    all_stocks: List[str]
) -> pd.DataFrame:
    """Compute raw per-minute concept features for one trading day.

    For every minute and every concept this produces: alpha (mean member
    change minus index change), total_amt, limit_up_ratio, stock_count,
    and the cross-sectional percentile rank of alpha. Results are cached
    to RAW_CACHE_DIR as parquet so the rolling-baseline pass can reuse them.

    Returns an empty DataFrame when any required input (stock minute bars,
    index minute bars, or previous closes) is unavailable.
    """
    # Serve from the per-day parquet cache when present.
    cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
    if os.path.exists(cache_file):
        return pd.read_parquet(cache_file)

    # Minute bars for all member stocks and the reference index.
    stock_df = get_daily_stock_data(trade_date, all_stocks)
    if stock_df.empty:
        return pd.DataFrame()

    index_df = get_daily_index_data(trade_date)
    if index_df.empty:
        return pd.DataFrame()

    # Previous closes turn prices into percentage changes.
    prev_close = get_prev_close(all_stocks, trade_date)
    index_prev_close = get_index_prev_close(trade_date)

    if not prev_close or not index_prev_close:
        return pd.DataFrame()

    # Per-minute percentage change per stock; rows without a known
    # previous close are dropped.
    stock_df['prev_close'] = stock_df['code'].map(prev_close)
    stock_df = stock_df.dropna(subset=['prev_close'])
    stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100

    index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
    index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))

    # All minute timestamps present in the stock data, in order.
    timestamps = sorted(stock_df['timestamp'].unique())

    # concept_id -> set of member codes (set for O(1) membership tests).
    concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}

    results = []

    for ts in timestamps:
        ts_stock_data = stock_df[stock_df['timestamp'] == ts]
        # A missing index minute defaults to 0 change so alpha degrades gracefully.
        index_change = index_change_map.get(ts, 0)

        stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
        stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))

        concept_features = []

        for concept_id, stocks in concept_stocks.items():
            # Only members with data at this minute contribute.
            concept_changes = [stock_change[s] for s in stocks if s in stock_change]
            concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]

            if not concept_changes:
                continue

            avg_change = np.mean(concept_changes)
            total_amt = sum(concept_amts)
            # Alpha: concept's mean change relative to the reference index.
            alpha = avg_change - index_change

            limit_up_count = sum(1 for c in concept_changes if c >= CONFIG['limit_up_threshold'])
            limit_up_ratio = limit_up_count / len(concept_changes)

            concept_features.append({
                'concept_id': concept_id,
                'alpha': alpha,
                'total_amt': total_amt,
                'limit_up_ratio': limit_up_ratio,
                'stock_count': len(concept_changes),
            })

        if not concept_features:
            continue

        concept_df = pd.DataFrame(concept_features)
        # Cross-sectional percentile rank of alpha at this minute.
        concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)
        concept_df['timestamp'] = ts
        concept_df['time_slot'] = time_to_slot(ts)
        concept_df['trade_date'] = trade_date

        results.append(concept_df)

    if not results:
        return pd.DataFrame()

    result_df = pd.concat(results, ignore_index=True)

    # Persist so rolling-baseline computation can reuse this day's features.
    result_df.to_parquet(cache_file, index=False)

    return result_df
|
||||
|
||||
|
||||
# ==================== 滚动窗口基线计算 ====================
|
||||
|
||||
def compute_rolling_baseline(
    historical_data: pd.DataFrame,
    concept_id: str
) -> Dict[str, Dict]:
    """Build per-time-slot baseline statistics for one concept.

    Returns ``{time_slot: {alpha_mean, alpha_std, amt_mean, amt_std,
    rank_mean, rank_std, sample_count}}``. Slots with fewer than
    CONFIG['min_baseline_samples'] observations are omitted, and every std
    is floored so later z-score division stays well-behaved.
    """
    if historical_data.empty:
        return {}

    concept_rows = historical_data[historical_data['concept_id'] == concept_id]
    if concept_rows.empty:
        return {}

    min_samples = CONFIG['min_baseline_samples']
    baselines: Dict[str, Dict] = {}

    for slot, slot_rows in concept_rows.groupby('time_slot'):
        if len(slot_rows) < min_samples:
            continue

        alpha_sd = slot_rows['alpha'].std()
        amt_sd = slot_rows['total_amt'].std()
        rank_sd = slot_rows['rank_pct'].std()
        amt_mean = slot_rows['total_amt'].mean()

        baselines[slot] = {
            'alpha_mean': slot_rows['alpha'].mean(),
            'alpha_std': max(alpha_sd if pd.notna(alpha_sd) else 1.0, 0.1),
            'amt_mean': amt_mean,
            'amt_std': max(amt_sd if pd.notna(amt_sd) else amt_mean * 0.5, 1.0),
            'rank_mean': slot_rows['rank_pct'].mean(),
            'rank_std': max(rank_sd if pd.notna(rank_sd) else 0.2, 0.05),
            'sample_count': len(slot_rows),
        }

    return baselines
|
||||
|
||||
|
||||
# ==================== 计算单日 Z-Score 特征(带滚动基线)====================
|
||||
|
||||
def compute_zscore_features_rolling(
    trade_date: str,
    concepts: List[dict],
    all_stocks: List[str],
    historical_raw_data: pd.DataFrame  # raw features from the N days before trade_date
) -> pd.DataFrame:
    """Compute one day's z-score features against a rolling historical baseline.

    Key properties:
    1. The baseline uses only rows from before ``trade_date`` (caller supplies
       them), so there is no future-data leakage.
    2. Momentum is computed on the alpha z-score series, which removes
       intraday volatility heterogeneity between time slots.

    Returns an empty DataFrame when the day's raw features are missing or
    no concept has a usable baseline.
    """
    # Raw (non-normalised) per-minute concept features for the day.
    raw_df = compute_raw_concept_features(trade_date, concepts, all_stocks)

    if raw_df.empty:
        return pd.DataFrame()

    zscore_records = []

    for concept_id, group in raw_df.groupby('concept_id'):
        # Per-slot baseline for this concept, from historical data only.
        baseline_dict = compute_rolling_baseline(historical_raw_data, concept_id)

        if not baseline_dict:
            continue

        # Walk the day's minutes in chronological order.
        group = group.sort_values('timestamp').reset_index(drop=True)

        # Alpha z-score history; feeds the z-score-based momentum below.
        zscore_history = []

        for idx, row in group.iterrows():
            time_slot = row['time_slot']

            # Skip minutes whose slot has no valid baseline.
            if time_slot not in baseline_dict:
                continue

            bl = baseline_dict[time_slot]

            # Standardise each raw feature against its slot baseline.
            alpha_zscore = (row['alpha'] - bl['alpha_mean']) / bl['alpha_std']
            amt_zscore = (row['total_amt'] - bl['amt_mean']) / bl['amt_std']
            rank_zscore = (row['rank_pct'] - bl['rank_mean']) / bl['rank_std']

            # Clip extremes so single outliers cannot dominate training.
            clip = CONFIG['zscore_clip']
            alpha_zscore = np.clip(alpha_zscore, -clip, clip)
            amt_zscore = np.clip(amt_zscore, -clip, clip)
            rank_zscore = np.clip(rank_zscore, -clip, clip)

            # Record the z-score for momentum computation.
            zscore_history.append(alpha_zscore)

            # Momentum on the z-score series (volatility-normalised):
            # mean of the most recent window minus mean of the window before it.
            momentum_3m = 0.0
            momentum_5m = 0.0

            if len(zscore_history) >= 3:
                recent_3 = zscore_history[-3:]
                older_3 = zscore_history[-6:-3] if len(zscore_history) >= 6 else [zscore_history[0]]
                momentum_3m = np.mean(recent_3) - np.mean(older_3)

            if len(zscore_history) >= 5:
                recent_5 = zscore_history[-5:]
                older_5 = zscore_history[-10:-5] if len(zscore_history) >= 10 else [zscore_history[0]]
                momentum_5m = np.mean(recent_5) - np.mean(older_5)

            zscore_records.append({
                'concept_id': concept_id,
                'timestamp': row['timestamp'],
                'time_slot': time_slot,
                'trade_date': trade_date,
                # Raw features
                'alpha': row['alpha'],
                'total_amt': row['total_amt'],
                'limit_up_ratio': row['limit_up_ratio'],
                'stock_count': row['stock_count'],
                'rank_pct': row['rank_pct'],
                # Z-score features
                'alpha_zscore': alpha_zscore,
                'amt_zscore': amt_zscore,
                'rank_zscore': rank_zscore,
                # Z-score-based momentum
                'momentum_3m': momentum_3m,
                'momentum_5m': momentum_5m,
            })

    if not zscore_records:
        return pd.DataFrame()

    return pd.DataFrame(zscore_records)
|
||||
|
||||
|
||||
# ==================== 多进程处理 ====================
|
||||
|
||||
def process_single_day_v2(args) -> Tuple[str, bool]:
    """Worker entry point: compute z-score features for one trading day.

    ``args`` is a tuple (trade_date, day_index, concepts, all_stocks,
    all_trading_days); ``day_index`` is trade_date's position inside
    ``all_trading_days`` so the rolling baseline window can be sliced
    without lookahead.

    Returns (trade_date, success). Days whose output file already exists
    are treated as success and skipped (idempotent re-runs).
    """
    trade_date, day_index, concepts, all_stocks, all_trading_days = args
    output_file = os.path.join(OUTPUT_DIR, f'features_v2_{trade_date}.parquet')

    # Idempotency: skip days whose output already exists.
    if os.path.exists(output_file):
        return (trade_date, True)

    try:
        # Rolling window: the baseline_days trading days before this one.
        baseline_days = CONFIG['baseline_days']

        # Slice [start_idx, end_idx) of the calendar strictly before today.
        start_idx = max(0, day_index - baseline_days)
        end_idx = day_index  # exclusive: never includes the current day

        if end_idx <= start_idx:
            # Not enough history to form a baseline.
            return (trade_date, False)

        historical_days = all_trading_days[start_idx:end_idx]

        # Load (or lazily compute) the raw features of each history day.
        historical_dfs = []
        for hist_date in historical_days:
            cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{hist_date}.parquet')
            if os.path.exists(cache_file):
                historical_dfs.append(pd.read_parquet(cache_file))
            else:
                # Cache miss: compute (this also writes the cache file).
                hist_df = compute_raw_concept_features(hist_date, concepts, all_stocks)
                if not hist_df.empty:
                    historical_dfs.append(hist_df)

        if not historical_dfs:
            return (trade_date, False)

        historical_raw_data = pd.concat(historical_dfs, ignore_index=True)

        # Current day's z-score features against the rolling baseline.
        df = compute_zscore_features_rolling(trade_date, concepts, all_stocks, historical_raw_data)

        if df.empty:
            return (trade_date, False)

        df.to_parquet(output_file, index=False)
        return (trade_date, True)

    except Exception as e:
        # A failed day is reported, traced, and counted — never raised,
        # so one bad day cannot abort the whole pool.
        print(f"[{trade_date}] 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return (trade_date, False)
|
||||
|
||||
|
||||
# ==================== 主流程 ====================
|
||||
|
||||
def main():
    """CLI entry: two-phase training-data build.

    Phase 1 precomputes and caches raw per-day features sequentially;
    phase 2 computes z-score features per day in parallel, with each
    worker reading its rolling baseline from the phase-1 cache.
    """
    import argparse

    parser = argparse.ArgumentParser(description='准备训练数据 V2(滚动窗口基线 + Z-Score 动量)')
    parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
    parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
    parser.add_argument('--workers', type=int, default=18, help='并行进程数')
    parser.add_argument('--baseline-days', type=int, default=20, help='滚动基线窗口大小')
    parser.add_argument('--force', action='store_true', help='强制重新计算(忽略缓存)')

    args = parser.parse_args()

    end_date = args.end or datetime.now().strftime('%Y-%m-%d')
    # Propagate the CLI window size into the module-level config used everywhere.
    CONFIG['baseline_days'] = args.baseline_days

    print("=" * 60)
    print("数据准备 V2 - 滚动窗口基线 + Z-Score 动量")
    print("=" * 60)
    print(f"日期范围: {args.start} ~ {end_date}")
    print(f"并行进程数: {args.workers}")
    print(f"滚动基线窗口: {args.baseline_days} 天")

    # Create connections for the main process itself.
    init_process_connections()

    # 1. Concept universe and the union of all member stocks.
    concepts = get_all_concepts()
    all_stocks = list(set(s for c in concepts for s in c['stocks']))
    print(f"股票总数: {len(all_stocks)}")

    # 2. Trading calendar for the requested range.
    trading_days = get_trading_days(args.start, end_date)

    if not trading_days:
        print("无交易日数据")
        return

    # 3. Phase 1: precompute raw features (these feed the rolling baselines).
    print(f"\n{'='*60}")
    print("第一阶段:预计算原始特征(用于滚动基线)")
    print(f"{'='*60}")

    # --force wipes both the raw cache and previous v2 outputs.
    if args.force:
        import shutil
        if os.path.exists(RAW_CACHE_DIR):
            shutil.rmtree(RAW_CACHE_DIR)
            os.makedirs(RAW_CACHE_DIR, exist_ok=True)
        if os.path.exists(OUTPUT_DIR):
            for f in os.listdir(OUTPUT_DIR):
                if f.startswith('features_v2_'):
                    os.remove(os.path.join(OUTPUT_DIR, f))

    # Raw features are computed single-threaded so cache files are written
    # sequentially before the workers start reading them.
    print(f"预计算 {len(trading_days)} 天的原始特征...")
    for trade_date in tqdm(trading_days, desc="预计算原始特征"):
        cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
        if not os.path.exists(cache_file):
            compute_raw_concept_features(trade_date, concepts, all_stocks)

    # 4. Phase 2: z-score features with rolling baselines (multi-process).
    print(f"\n{'='*60}")
    print("第二阶段:计算 Z-Score 特征(滚动基线)")
    print(f"{'='*60}")

    # The first baseline_days days lack enough history and are skipped.
    start_idx = args.baseline_days
    processable_days = trading_days[start_idx:]

    if not processable_days:
        print(f"错误:需要至少 {args.baseline_days + 1} 天的数据")
        return

    print(f"可处理日期: {processable_days[0]} ~ {processable_days[-1]} ({len(processable_days)} 天)")
    print(f"跳过前 {start_idx} 天(基线预热期)")

    # One task per processable day, carrying its calendar index so the
    # worker can slice the rolling window without lookahead.
    tasks = []
    for i, trade_date in enumerate(trading_days):
        if i >= start_idx:
            tasks.append((trade_date, i, concepts, all_stocks, trading_days))

    print(f"开始处理 {len(tasks)} 个交易日({args.workers} 进程并行)...")

    success_count = 0
    failed_dates = []

    # Each worker gets its own DB/ES/CH connections via the initializer,
    # avoiding a shared connection pool across processes.
    with ProcessPoolExecutor(max_workers=args.workers, initializer=init_process_connections) as executor:
        futures = {executor.submit(process_single_day_v2, task): task[0] for task in tasks}

        with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
            for future in as_completed(futures):
                trade_date = futures[future]
                try:
                    result_date, success = future.result()
                    if success:
                        success_count += 1
                    else:
                        failed_dates.append(result_date)
                except Exception as e:
                    # A crashed worker process counts as a failed date.
                    print(f"\n[{trade_date}] 进程异常: {e}")
                    failed_dates.append(trade_date)
                pbar.update(1)

    print("\n" + "=" * 60)
    print(f"处理完成: {success_count}/{len(tasks)} 个交易日")
    if failed_dates:
        print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
    print(f"数据保存在: {OUTPUT_DIR}")
    print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1520
ml/realtime_detector.py
Normal file
1520
ml/realtime_detector.py
Normal file
File diff suppressed because it is too large
Load Diff
729
ml/realtime_detector_v2.py
Normal file
729
ml/realtime_detector_v2.py
Normal file
@@ -0,0 +1,729 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
V2 实时异动检测器
|
||||
|
||||
使用方法:
|
||||
# 作为模块导入
|
||||
from ml.realtime_detector_v2 import RealtimeDetectorV2
|
||||
|
||||
detector = RealtimeDetectorV2()
|
||||
alerts = detector.detect_realtime() # 检测当前时刻
|
||||
|
||||
# 或命令行测试
|
||||
python ml/realtime_detector_v2.py --date 2025-12-09
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from sqlalchemy import create_engine, text
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from ml.model import TransformerAutoencoder
|
||||
|
||||
# ==================== Configuration ====================

# NOTE(review): credentials are hard-coded here; consider moving them to
# environment variables or a secrets store outside version control.
MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock"
ES_HOST = 'http://127.0.0.1:9200'
ES_INDEX = 'concept_library_v3'

CLICKHOUSE_CONFIG = {
    'host': '127.0.0.1',
    'port': 9000,
    'user': 'default',
    'password': 'Zzl33818!',
    'database': 'stock'
}

# Benchmark index used for alpha computation.
REFERENCE_INDEX = '000001.SH'
# Pickled per-slot baselines (produced by ml/update_baseline.py, per the
# message printed in _load_baselines).
BASELINE_FILE = 'ml/data_v2/baselines/realtime_baseline.pkl'
MODEL_DIR = 'ml/checkpoints_v2'

# Detection configuration
CONFIG = {
    'seq_len': 10,               # model input sequence length (z-score history)
    'confirm_window': 5,         # sliding confirmation window size
    'confirm_ratio': 0.6,        # fraction of window that must be anomalous
    'rule_weight': 0.5,          # weight of the rule score in fusion
    'ml_weight': 0.5,            # weight of the model score in fusion
    'rule_trigger': 60,          # rule-only alert threshold
    'ml_trigger': 70,            # model-only alert threshold
    'fusion_trigger': 50,        # fused-score alert threshold
    'cooldown_minutes': 10,      # suppress repeat alerts for a concept
    'max_alerts_per_minute': 15,
    'zscore_clip': 5.0,          # z-scores clipped to [-5, 5]
    'limit_up_threshold': 9.8,   # pct change treated as limit-up
}

# Feature order expected by the model.
FEATURES = ['alpha_zscore', 'amt_zscore', 'rank_zscore', 'momentum_3m', 'momentum_5m', 'limit_up_ratio']


# ==================== Database connections ====================

# Module-level lazily-created connection singletons.
_mysql_engine = None
_es_client = None
_ch_client = None
|
||||
|
||||
|
||||
def get_mysql_engine():
    """Return the module-wide MySQL engine, creating it on first use."""
    global _mysql_engine
    if _mysql_engine is not None:
        return _mysql_engine
    _mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True)
    return _mysql_engine
|
||||
|
||||
|
||||
def get_es_client():
    """Return the module-wide Elasticsearch client, creating it on first use."""
    global _es_client
    if _es_client is not None:
        return _es_client
    _es_client = Elasticsearch([ES_HOST])
    return _es_client
|
||||
|
||||
|
||||
def get_ch_client():
    """Return the module-wide ClickHouse client, creating it on first use."""
    global _ch_client
    if _ch_client is not None:
        return _ch_client
    _ch_client = Client(**CLICKHOUSE_CONFIG)
    return _ch_client
|
||||
|
||||
|
||||
def code_to_ch_format(code: str) -> Optional[str]:
    """Convert a bare 6-digit A-share code to the suffixed ClickHouse form.

    6xxxxx -> Shanghai (.SH); 0xxxxx / 3xxxxx -> Shenzhen (.SZ);
    any other valid 6-digit code -> Beijing (.BJ).

    Returns:
        The suffixed code, or None when ``code`` is not exactly 6 digits.
        (The return annotation was previously ``str`` despite the None path.)
    """
    if not code or len(code) != 6 or not code.isdigit():
        return None
    if code.startswith('6'):
        return f"{code}.SH"
    if code[0] in ('0', '3'):
        return f"{code}.SZ"
    return f"{code}.BJ"
|
||||
|
||||
|
||||
def time_to_slot(ts) -> str:
    """Normalise a timestamp (or an already-formatted string) to 'HH:MM'."""
    return ts if isinstance(ts, str) else ts.strftime('%H:%M')
|
||||
|
||||
|
||||
# ==================== 规则评分 ====================
|
||||
|
||||
def score_rules_zscore(features: Dict) -> Tuple[float, List[str]]:
    """Rule-based anomaly score from z-score features.

    Each feature is graded against tiered thresholds (only the highest
    matching tier fires); two combo rules add extra points. The amount
    z-score is deliberately signed (only surges count), while alpha and
    rank use absolute deviation.

    Returns:
        (score capped at 100, list of triggered rule labels in order).
    """
    fired: List[str] = []
    total = 0.0

    alpha_dev = abs(features.get('alpha_zscore', 0))
    turnover_dev = features.get('amt_zscore', 0)  # signed on purpose
    rank_dev = abs(features.get('rank_zscore', 0))
    mom3 = features.get('momentum_3m', 0)
    mom5 = features.get('momentum_5m', 0)
    lu_ratio = features.get('limit_up_ratio', 0)

    def grade(value, tiers):
        """Add points for the first (highest) tier that value reaches."""
        nonlocal total
        for threshold, points, label in tiers:
            if value >= threshold:
                total += points
                fired.append(label)
                break

    # Alpha z-score tiers
    grade(alpha_dev, [(4.0, 25, 'alpha_extreme'),
                      (3.0, 18, 'alpha_strong'),
                      (2.0, 10, 'alpha_moderate')])
    # Turnover z-score tiers (signed: only positive surges fire)
    grade(turnover_dev, [(4.0, 20, 'amt_extreme'),
                         (3.0, 12, 'amt_strong'),
                         (2.0, 6, 'amt_moderate')])
    # Rank z-score tiers
    grade(rank_dev, [(3.0, 15, 'rank_extreme'),
                     (2.0, 8, 'rank_strong')])
    # Z-score-based momentum tiers
    grade(mom3, [(1.0, 12, 'momentum_3m_strong'),
                 (0.5, 6, 'momentum_3m_moderate')])
    grade(mom5, [(1.5, 10, 'momentum_5m_strong')])
    # Limit-up ratio tiers
    grade(lu_ratio, [(0.3, 20, 'limit_up_extreme'),
                     (0.15, 12, 'limit_up_strong'),
                     (0.08, 5, 'limit_up_moderate')])

    # Combo rules: co-occurring signals reinforce each other.
    if alpha_dev >= 2.0 and turnover_dev >= 2.0:
        total += 15
        fired.append('combo_alpha_amt')

    if alpha_dev >= 2.0 and lu_ratio >= 0.1:
        total += 12
        fired.append('combo_alpha_limitup')

    return min(total, 100), fired
|
||||
|
||||
|
||||
# ==================== 实时检测器 ====================
|
||||
|
||||
class RealtimeDetectorV2:
|
||||
"""V2 实时异动检测器"""
|
||||
|
||||
def __init__(self, model_dir: str = MODEL_DIR, baseline_file: str = BASELINE_FILE):
    """Load concepts, baselines and the trained model, then set up state.

    Args:
        model_dir: directory holding config.json / best_model.pt / thresholds.json.
        baseline_file: pickle of per-slot baselines (see _load_baselines).
    """
    print("初始化 V2 实时检测器...")

    # Concept universe from Elasticsearch.
    self.concepts = self._load_concepts()
    # concept_id -> set of member codes, for O(1) membership tests.
    self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in self.concepts}
    # Union of every concept's member stocks.
    self.all_stocks = list(set(s for c in self.concepts for s in c['stocks']))

    # Per-time-slot historical baseline statistics.
    self.baselines = self._load_baselines(baseline_file)

    # Trained autoencoder, its anomaly thresholds and the torch device;
    # model may be None when no checkpoint exists (see _load_model).
    self.model, self.thresholds, self.device = self._load_model(model_dir)

    # Rolling per-concept state: recent z-scores (model input window),
    # anomaly confirmations (persistence check) and alert cooldowns.
    self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len']))
    self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window']))
    self.cooldown = {}

    print(f"初始化完成: {len(self.concepts)} 概念, {len(self.baselines)} 基线")
|
||||
|
||||
def _load_concepts(self) -> List[dict]:
    """Fetch every concept (id, name, member stock codes) from ES.

    Drains the scroll cursor page by page; concepts without a single
    valid member code are skipped.
    """
    es = get_es_client()
    request = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]}

    loaded: List[dict] = []
    page = es.search(index=ES_INDEX, body=request, scroll='2m')
    cursor = page['_scroll_id']

    while True:
        batch = page['hits']['hits']
        if not batch:
            break
        for doc in batch:
            src = doc['_source']
            members = [
                entry['code']
                for entry in src.get('stocks', [])
                if isinstance(entry, dict) and entry.get('code')
            ]
            if members:
                loaded.append({
                    'concept_id': src.get('concept_id'),
                    'concept_name': src.get('concept'),
                    'stocks': members,
                })
        page = es.scroll(scroll_id=cursor, scroll='2m')
        cursor = page['_scroll_id']

    es.clear_scroll(scroll_id=cursor)
    return loaded
|
||||
|
||||
def _load_baselines(self, baseline_file: str) -> Dict:
    """Load the pickled realtime baselines; return {} when the file is absent."""
    if not os.path.exists(baseline_file):
        print(f"警告: 基线文件不存在: {baseline_file}")
        print("请先运行: python ml/update_baseline.py")
        return {}

    with open(baseline_file, 'rb') as fh:
        payload = pickle.load(fh)

    print(f"基线日期范围: {payload.get('date_range', 'unknown')}")
    print(f"更新时间: {payload.get('update_time', 'unknown')}")

    return payload.get('baselines', {})
|
||||
|
||||
def _load_model(self, model_dir: str):
    """Load the trained TransformerAutoencoder plus its anomaly thresholds.

    Args:
        model_dir: directory expected to contain config.json,
            best_model.pt and (optionally) thresholds.json.

    Returns:
        (model, thresholds, device); model is None with empty thresholds
        when no checkpoint exists, so callers must handle the
        rule-only fallback.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config_path = os.path.join(model_dir, 'config.json')
    model_path = os.path.join(model_dir, 'best_model.pt')
    threshold_path = os.path.join(model_dir, 'thresholds.json')

    # Missing checkpoint is not fatal — detection degrades to rules only.
    if not os.path.exists(model_path):
        print(f"警告: 模型不存在: {model_path}")
        return None, {}, device

    with open(config_path) as f:
        config = json.load(f)

    # Rebuild the architecture from the saved config, then restore weights.
    model = TransformerAutoencoder(**config['model'])
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Anomaly thresholds are optional.
    thresholds = {}
    if os.path.exists(threshold_path):
        with open(threshold_path) as f:
            thresholds = json.load(f)

    print(f"模型已加载: {model_path}")
    return model, thresholds, device
|
||||
|
||||
def _get_realtime_data(self, trade_date: str) -> pd.DataFrame:
    """Build the per-concept, per-minute raw feature frame for one day.

    Pulls minute bars from ClickHouse (all tracked stocks plus the
    reference index), previous-day closes from MySQL, then aggregates per
    concept and minute: alpha vs the index, total turnover, limit-up
    ratio and a cross-sectional alpha rank percentile.

    Returns an empty DataFrame whenever any required source is missing.
    """
    ch = get_ch_client()

    # Stock minute bars — ClickHouse stores exchange-suffixed codes.
    ch_codes = [code_to_ch_format(c) for c in self.all_stocks if code_to_ch_format(c)]
    ch_codes_str = "','".join(ch_codes)

    stock_query = f"""
    SELECT code, timestamp, close, amt
    FROM stock_minute
    WHERE toDate(timestamp) = '{trade_date}'
    AND code IN ('{ch_codes_str}')
    ORDER BY timestamp
    """
    stock_result = ch.execute(stock_query)
    if not stock_result:
        return pd.DataFrame()

    stock_df = pd.DataFrame(stock_result, columns=['ch_code', 'timestamp', 'close', 'amt'])

    # Map ClickHouse codes back to the original 6-digit codes.
    ch_to_code = {code_to_ch_format(c): c for c in self.all_stocks if code_to_ch_format(c)}
    stock_df['code'] = stock_df['ch_code'].map(ch_to_code)
    stock_df = stock_df.dropna(subset=['code'])

    # Reference index minute bars (benchmark for alpha).
    index_query = f"""
    SELECT timestamp, close
    FROM index_minute
    WHERE toDate(timestamp) = '{trade_date}'
    AND code = '{REFERENCE_INDEX}'
    ORDER BY timestamp
    """
    index_result = ch.execute(index_query)
    if not index_result:
        return pd.DataFrame()

    index_df = pd.DataFrame(index_result, columns=['timestamp', 'close'])

    # Previous closes from MySQL, needed to compute percentage changes.
    engine = get_mysql_engine()
    codes_str = "','".join([c for c in self.all_stocks if c and len(c) == 6])

    with engine.connect() as conn:
        prev_result = conn.execute(text(f"""
        SELECT SECCODE, F007N FROM ea_trade
        WHERE SECCODE IN ('{codes_str}')
        AND TRADEDATE = (SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}')
        AND F007N > 0
        """))
        prev_close = {row[0]: float(row[1]) for row in prev_result if row[1]}

        idx_result = conn.execute(text("""
        SELECT F006N FROM ea_exchangetrade
        WHERE INDEXCODE = '000001' AND TRADEDATE < :today
        ORDER BY TRADEDATE DESC LIMIT 1
        """), {'today': trade_date}).fetchone()
        index_prev_close = float(idx_result[0]) if idx_result else None

    if not prev_close or not index_prev_close:
        return pd.DataFrame()

    # Percentage change vs previous close, for stocks and the index.
    stock_df['prev_close'] = stock_df['code'].map(prev_close)
    stock_df = stock_df.dropna(subset=['prev_close'])
    stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100

    index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
    index_map = dict(zip(index_df['timestamp'], index_df['change_pct']))

    # Aggregate concept-level features minute by minute.
    results = []
    for ts in sorted(stock_df['timestamp'].unique()):
        ts_data = stock_df[stock_df['timestamp'] == ts]
        idx_chg = index_map.get(ts, 0)

        stock_chg = dict(zip(ts_data['code'], ts_data['change_pct']))
        stock_amt = dict(zip(ts_data['code'], ts_data['amt']))

        for cid, stocks in self.concept_stocks.items():
            changes = [stock_chg[s] for s in stocks if s in stock_chg]
            amts = [stock_amt.get(s, 0) for s in stocks if s in stock_chg]

            if not changes:
                continue

            # alpha = mean concept change minus index change (excess return).
            alpha = np.mean(changes) - idx_chg
            total_amt = sum(amts)
            limit_up_ratio = sum(1 for c in changes if c >= CONFIG['limit_up_threshold']) / len(changes)

            results.append({
                'concept_id': cid,
                'timestamp': ts,
                'time_slot': time_to_slot(ts),
                'alpha': alpha,
                'total_amt': total_amt,
                'limit_up_ratio': limit_up_ratio,
                'stock_count': len(changes),
            })

    if not results:
        return pd.DataFrame()

    df = pd.DataFrame(results)

    # Cross-sectional rank: percentile of each concept's alpha per minute.
    for ts in df['timestamp'].unique():
        mask = df['timestamp'] == ts
        df.loc[mask, 'rank_pct'] = df.loc[mask, 'alpha'].rank(pct=True)

    return df
||||
def _compute_zscore(self, concept_id: str, time_slot: str, alpha: float, total_amt: float, rank_pct: float) -> Optional[Dict]:
|
||||
"""计算 Z-Score"""
|
||||
if concept_id not in self.baselines:
|
||||
return None
|
||||
|
||||
baseline = self.baselines[concept_id]
|
||||
if time_slot not in baseline:
|
||||
return None
|
||||
|
||||
bl = baseline[time_slot]
|
||||
|
||||
alpha_z = np.clip((alpha - bl['alpha_mean']) / bl['alpha_std'], -5, 5)
|
||||
amt_z = np.clip((total_amt - bl['amt_mean']) / bl['amt_std'], -5, 5)
|
||||
rank_z = np.clip((rank_pct - bl['rank_mean']) / bl['rank_std'], -5, 5)
|
||||
|
||||
# 动量(基于 Z-Score 历史)
|
||||
history = list(self.zscore_history[concept_id])
|
||||
mom_3m = 0.0
|
||||
mom_5m = 0.0
|
||||
|
||||
if len(history) >= 3:
|
||||
recent = [h['alpha_zscore'] for h in history[-3:]]
|
||||
older = [h['alpha_zscore'] for h in history[-6:-3]] if len(history) >= 6 else [history[0]['alpha_zscore']]
|
||||
mom_3m = np.mean(recent) - np.mean(older)
|
||||
|
||||
if len(history) >= 5:
|
||||
recent = [h['alpha_zscore'] for h in history[-5:]]
|
||||
older = [h['alpha_zscore'] for h in history[-10:-5]] if len(history) >= 10 else [history[0]['alpha_zscore']]
|
||||
mom_5m = np.mean(recent) - np.mean(older)
|
||||
|
||||
return {
|
||||
'alpha_zscore': float(alpha_z),
|
||||
'amt_zscore': float(amt_z),
|
||||
'rank_zscore': float(rank_z),
|
||||
'momentum_3m': float(mom_3m),
|
||||
'momentum_5m': float(mom_5m),
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _ml_score(self, sequences: np.ndarray) -> np.ndarray:
|
||||
"""批量 ML 评分"""
|
||||
if self.model is None or len(sequences) == 0:
|
||||
return np.zeros(len(sequences))
|
||||
|
||||
x = torch.FloatTensor(sequences).to(self.device)
|
||||
errors = self.model.compute_reconstruction_error(x, reduction='none')
|
||||
last_errors = errors[:, -1].cpu().numpy()
|
||||
|
||||
# 转换为 0-100 分数
|
||||
if self.thresholds:
|
||||
p50 = self.thresholds.get('median', 0.001)
|
||||
p99 = self.thresholds.get('p99', 0.05)
|
||||
scores = 50 + (last_errors - p50) / (p99 - p50 + 1e-6) * 49
|
||||
else:
|
||||
scores = last_errors * 1000
|
||||
|
||||
return np.clip(scores, 0, 100)
|
||||
|
||||
def detect(self, trade_date: Optional[str] = None) -> List[Dict]:
    """Run the full anomaly-detection pipeline for one trading day.

    For every minute: z-score each concept against its baseline, score
    with rules and the ML autoencoder, fuse both scores, and emit an
    alert only after sustained confirmation and outside the per-concept
    cooldown window.

    Args:
        trade_date: 'YYYY-MM-DD'; defaults to today.

    Returns:
        List of alert dicts, consumable by ``save_alerts_to_db``.
    """
    trade_date = trade_date or datetime.now().strftime('%Y-%m-%d')
    print(f"\n检测 {trade_date} 的异动...")

    # Reset per-run state so repeated detect() calls don't leak history.
    self.zscore_history.clear()
    self.anomaly_candidates.clear()
    self.cooldown.clear()

    # Raw per-concept, per-minute features for the whole day.
    raw_df = self._get_realtime_data(trade_date)
    if raw_df.empty:
        print("无数据")
        return []

    timestamps = sorted(raw_df['timestamp'].unique())
    print(f"时间点数: {len(timestamps)}")

    all_alerts = []

    for ts in timestamps:
        ts_data = raw_df[raw_df['timestamp'] == ts]
        time_slot = time_to_slot(ts)

        candidates = []

        # Z-score every concept at this minute against its baseline.
        for _, row in ts_data.iterrows():
            cid = row['concept_id']

            zscore = self._compute_zscore(
                cid, time_slot,
                row['alpha'], row['total_amt'], row['rank_pct']
            )

            if zscore is None:
                continue

            # Full feature vector: z-scores plus raw values used by rules.
            features = {
                **zscore,
                'alpha': row['alpha'],
                'limit_up_ratio': row['limit_up_ratio'],
                'total_amt': row['total_amt'],
            }

            # Extend rolling history (feeds momentum on later minutes).
            self.zscore_history[cid].append(zscore)

            # Rule-based score and the list of rules it triggered.
            rule_score, triggered = score_rules_zscore(features)

            candidates.append((cid, features, rule_score, triggered))

        if not candidates:
            continue

        # Batch ML scoring — only concepts with a full history window.
        sequences = []
        valid_candidates = []

        for cid, features, rule_score, triggered in candidates:
            history = list(self.zscore_history[cid])
            if len(history) >= CONFIG['seq_len']:
                seq = np.array([[h['alpha_zscore'], h['amt_zscore'], h['rank_zscore'],
                                 h['momentum_3m'], h['momentum_5m'], features['limit_up_ratio']]
                                for h in history])
                sequences.append(seq)
                valid_candidates.append((cid, features, rule_score, triggered))

        if not sequences:
            continue

        ml_scores = self._ml_score(np.array(sequences))

        # Fuse scores, then apply cooldown + sustained-confirmation gates.
        for i, (cid, features, rule_score, triggered) in enumerate(valid_candidates):
            ml_score = ml_scores[i]
            final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score

            # Trigger when any single score, or the fused score, exceeds
            # its threshold.
            is_triggered = (
                rule_score >= CONFIG['rule_trigger'] or
                ml_score >= CONFIG['ml_trigger'] or
                final_score >= CONFIG['fusion_trigger']
            )

            # Always record the candidate score (used by confirmation).
            self.anomaly_candidates[cid].append((ts, final_score))

            if not is_triggered:
                continue

            # Cooldown: suppress repeat alerts for the same concept.
            if cid in self.cooldown:
                if (ts - self.cooldown[cid]).total_seconds() < CONFIG['cooldown_minutes'] * 60:
                    continue

            # Sustained confirmation over the recent candidate window.
            recent = list(self.anomaly_candidates[cid])
            if len(recent) < CONFIG['confirm_window']:
                continue

            exceed = sum(1 for _, s in recent if s >= CONFIG['fusion_trigger'])
            ratio = exceed / len(recent)

            if ratio < CONFIG['confirm_ratio']:
                continue

            # Confirmed anomaly — start cooldown and record the alert.
            self.cooldown[cid] = ts

            alpha = features['alpha']
            alert_type = 'surge_up' if alpha >= 1.5 else 'surge_down' if alpha <= -1.5 else 'surge'

            concept_name = next((c['concept_name'] for c in self.concepts if c['concept_id'] == cid), cid)

            all_alerts.append({
                'concept_id': cid,
                'concept_name': concept_name,
                'alert_time': ts,
                'trade_date': trade_date,
                'alert_type': alert_type,
                'final_score': float(final_score),
                'rule_score': float(rule_score),
                'ml_score': float(ml_score),
                'confirm_ratio': float(ratio),
                'alpha': float(alpha),
                'alpha_zscore': float(features['alpha_zscore']),
                'amt_zscore': float(features['amt_zscore']),
                'rank_zscore': float(features['rank_zscore']),
                'momentum_3m': float(features['momentum_3m']),
                'momentum_5m': float(features['momentum_5m']),
                'limit_up_ratio': float(features['limit_up_ratio']),
                'triggered_rules': triggered,
                'trigger_reason': f"融合({final_score:.0f})+确认({ratio:.0%})",
            })

    print(f"检测到 {len(all_alerts)} 个异动")
    return all_alerts
||||
|
||||
# ==================== 数据库存储 ====================
|
||||
|
||||
def create_v2_table():
    """Create the concept_anomaly_v2 table if it does not already exist.

    (concept_id, alert_time) carries a UNIQUE key, so re-running
    detection for the same day cannot insert duplicate alerts (the
    writer uses INSERT IGNORE).
    """
    engine = get_mysql_engine()
    with engine.begin() as conn:
        conn.execute(text("""
        CREATE TABLE IF NOT EXISTS concept_anomaly_v2 (
            id INT AUTO_INCREMENT PRIMARY KEY,
            concept_id VARCHAR(50) NOT NULL,
            alert_time DATETIME NOT NULL,
            trade_date DATE NOT NULL,
            alert_type VARCHAR(20) NOT NULL,
            final_score FLOAT,
            rule_score FLOAT,
            ml_score FLOAT,
            trigger_reason VARCHAR(200),
            confirm_ratio FLOAT,
            alpha FLOAT,
            alpha_zscore FLOAT,
            amt_zscore FLOAT,
            rank_zscore FLOAT,
            momentum_3m FLOAT,
            momentum_5m FLOAT,
            limit_up_ratio FLOAT,
            triggered_rules TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE KEY uk_concept_time (concept_id, alert_time),
            INDEX idx_trade_date (trade_date),
            INDEX idx_alert_type (alert_type)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
        """))
    print("concept_anomaly_v2 表已就绪")
||||
def save_alerts_to_db(alerts: List[Dict]) -> int:
    """Persist confirmed alerts into concept_anomaly_v2.

    INSERT IGNORE plus the table's unique (concept_id, alert_time) key
    makes re-runs idempotent. One bad row is logged and skipped rather
    than aborting the batch.

    Args:
        alerts: alert dicts as produced by ``RealtimeDetectorV2.detect``.

    Returns:
        Number of rows actually inserted (ignored duplicates excluded).
    """
    if not alerts:
        return 0

    engine = get_mysql_engine()

    # Build the prepared statement once, not per row (loop-invariant).
    insert_sql = text("""
    INSERT IGNORE INTO concept_anomaly_v2
    (concept_id, alert_time, trade_date, alert_type,
    final_score, rule_score, ml_score, trigger_reason, confirm_ratio,
    alpha, alpha_zscore, amt_zscore, rank_zscore,
    momentum_3m, momentum_5m, limit_up_ratio, triggered_rules)
    VALUES
    (:concept_id, :alert_time, :trade_date, :alert_type,
    :final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio,
    :alpha, :alpha_zscore, :amt_zscore, :rank_zscore,
    :momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules)
    """)

    # Fields copied straight from the alert dict; triggered_rules is
    # serialized separately below.
    scalar_fields = (
        'concept_id', 'alert_time', 'trade_date', 'alert_type',
        'final_score', 'rule_score', 'ml_score', 'trigger_reason', 'confirm_ratio',
        'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore',
        'momentum_3m', 'momentum_5m', 'limit_up_ratio',
    )

    saved = 0
    with engine.begin() as conn:
        for alert in alerts:
            try:
                params = {k: alert[k] for k in scalar_fields}
                params['triggered_rules'] = json.dumps(alert.get('triggered_rules', []), ensure_ascii=False)

                result = conn.execute(insert_sql, params)
                if result.rowcount > 0:
                    saved += 1
            except Exception as e:
                # Best-effort: log the failing row and keep going.
                print(f"保存失败: {alert['concept_id']} - {e}")

    return saved
||||
def main():
    """CLI entry point: run detection for one day, print, optionally persist."""
    import argparse
    parser = argparse.ArgumentParser()
    # --date defaults to today (handled inside detect()).
    parser.add_argument('--date', type=str, default=None)
    parser.add_argument('--no-save', action='store_true', help='不保存到数据库,只打印')
    args = parser.parse_args()

    # Make sure the target table exists before we try to write.
    if not args.no_save:
        create_v2_table()

    detector = RealtimeDetectorV2()
    alerts = detector.detect(args.date)

    print(f"\n{'='*60}")
    print(f"检测结果 ({len(alerts)} 个异动)")
    print('='*60)

    # Print at most the first 20 alerts; alert_time may be a datetime or str.
    for a in alerts[:20]:
        print(f"[{a['alert_time'].strftime('%H:%M') if hasattr(a['alert_time'], 'strftime') else a['alert_time']}] "
              f"{a['concept_name']} | {a['alert_type']} | "
              f"分数={a['final_score']:.0f} 确认={a['confirm_ratio']:.0%} "
              f"α={a['alpha']:.2f}% αZ={a['alpha_zscore']:.1f}")

    if len(alerts) > 20:
        print(f"... 共 {len(alerts)} 个")

    # Persist unless --no-save was requested.
    if not args.no_save and alerts:
        saved = save_alerts_to_db(alerts)
        print(f"\n✅ 已保存 {saved}/{len(alerts)} 条到 concept_anomaly_v2 表")
    elif args.no_save:
        print(f"\n⚠️ --no-save 模式,未保存到数据库")
||||
# Script entry point: run detection (and optionally persist) from the CLI.
if __name__ == "__main__":
    main()
25
ml/requirements.txt
Normal file
25
ml/requirements.txt
Normal file
@@ -0,0 +1,25 @@
|
||||
# 概念异动检测 ML 模块依赖
|
||||
# 安装: pip install -r ml/requirements.txt
|
||||
|
||||
# PyTorch (根据 CUDA 版本选择)
|
||||
# 5090 显卡需要 CUDA 12.x
|
||||
# pip install torch --index-url https://download.pytorch.org/whl/cu124
|
||||
torch>=2.0.0
|
||||
|
||||
# 数据处理
|
||||
numpy>=1.24.0
|
||||
pandas>=2.0.0
|
||||
pyarrow>=14.0.0
|
||||
|
||||
# 数据库
|
||||
clickhouse-driver>=0.2.6
|
||||
elasticsearch>=7.0.0,<8.0.0
|
||||
sqlalchemy>=2.0.0
|
||||
pymysql>=1.1.0
|
||||
|
||||
# 训练工具
|
||||
tqdm>=4.65.0
|
||||
|
||||
# 可选: 可视化
|
||||
# matplotlib>=3.7.0
|
||||
# tensorboard>=2.14.0
|
||||
99
ml/run_training.sh
Normal file
99
ml/run_training.sh
Normal file
@@ -0,0 +1,99 @@
|
||||
#!/bin/bash
# Concept anomaly detection model training pipeline (Linux).
#
# Usage:
#   chmod +x run_training.sh
#   ./run_training.sh
#
# Or with options:
#   ./run_training.sh --start 2022-01-01 --epochs 100

set -e

echo "============================================================"
echo "概念异动检测模型训练流程"
echo "============================================================"
echo ""

# Work from the repository root (parent of this script's directory).
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR/.."

echo "[1/4] 检查环境..."
python3 --version || { echo "Python3 未找到!"; exit 1; }

# GPU availability check (non-fatal: training falls back to CPU).
if python3 -c "import torch; print(f'CUDA: {torch.cuda.is_available()}')" 2>/dev/null; then
    echo "PyTorch GPU 检测完成"
else
    echo "警告: PyTorch 未安装或无法检测 GPU"
fi

echo ""
echo "[2/4] 检查依赖..."
pip3 install -q torch pandas numpy pyarrow tqdm clickhouse-driver elasticsearch sqlalchemy pymysql

echo ""
echo "[3/4] 准备训练数据..."
echo "从 ClickHouse 提取历史数据,这可能需要较长时间..."
echo ""

# Defaults; each can be overridden via the flags parsed below.
START_DATE="2022-01-01"
END_DATE=""
EPOCHS=100
BATCH_SIZE=256
TRAIN_END="2025-06-30"
VAL_END="2025-09-30"

# Parse command-line options (unknown flags are silently skipped).
while [[ $# -gt 0 ]]; do
    case $1 in
        --start)
            START_DATE="$2"
            shift 2
            ;;
        --end)
            END_DATE="$2"
            shift 2
            ;;
        --epochs)
            EPOCHS="$2"
            shift 2
            ;;
        --batch_size)
            BATCH_SIZE="$2"
            shift 2
            ;;
        --train_end)
            TRAIN_END="$2"
            shift 2
            ;;
        --val_end)
            VAL_END="$2"
            shift 2
            ;;
        *)
            shift
            ;;
    esac
done

# Data preparation (--end is optional; defaults to the latest data).
if [ -n "$END_DATE" ]; then
    python3 ml/prepare_data.py --start "$START_DATE" --end "$END_DATE"
else
    python3 ml/prepare_data.py --start "$START_DATE"
fi

echo ""
echo "[4/4] 训练模型..."
echo "使用 GPU 加速训练..."
echo ""

python3 ml/train.py --epochs "$EPOCHS" --batch_size "$BATCH_SIZE" --train_end "$TRAIN_END" --val_end "$VAL_END"

echo ""
echo "============================================================"
echo "训练完成!"
echo "模型保存在: ml/checkpoints/"
echo "============================================================"
808
ml/train.py
Normal file
808
ml/train.py
Normal file
@@ -0,0 +1,808 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Transformer Autoencoder 训练脚本 (修复版)
|
||||
|
||||
修复问题:
|
||||
1. 按概念分组构建序列,避免跨概念切片
|
||||
2. 按时间(日期)切分数据集,避免数据泄露
|
||||
3. 使用 RobustScaler + Clipping 处理非平稳性
|
||||
4. 使用验证集计算阈值
|
||||
|
||||
训练流程:
|
||||
1. 加载预处理好的特征数据(parquet 文件)
|
||||
2. 按概念分组,在每个概念内部构建序列
|
||||
3. 按日期划分训练/验证/测试集
|
||||
4. 训练 Autoencoder(最小化重构误差)
|
||||
5. 保存模型和阈值
|
||||
|
||||
使用方法:
|
||||
python train.py --data_dir ml/data --epochs 100 --batch_size 256
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
||||
from tqdm import tqdm
|
||||
|
||||
from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters
|
||||
|
||||
# Performance: let cuDNN auto-tune kernels for fixed input shapes.
torch.backends.cudnn.benchmark = True
# Enable TF32 matmuls (Ampere+ GPUs) — large speedup, negligible precision loss.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Optional plotting backend; training works without matplotlib installed.
try:
    import matplotlib
    matplotlib.use('Agg')  # headless backend, no display required
    import matplotlib.pyplot as plt
    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False


# ==================== Configuration ====================

TRAIN_CONFIG = {
    # Data
    'seq_len': 30,  # input window length (30 minutes)
    'stride': 5,    # sliding-window step

    # Time-based split (by date, to prevent train/test leakage)
    'train_end_date': '2024-06-30',  # last training date
    'val_end_date': '2024-09-30',    # last validation date (rest = test)

    # Input features
    'features': [
        'alpha',           # excess return vs the reference index
        'alpha_delta',     # rate of change of alpha
        'amt_ratio',       # turnover ratio
        'amt_delta',       # rate of change of turnover
        'rank_pct',        # cross-sectional alpha rank percentile
        'limit_up_ratio',  # fraction of limit-up stocks in the concept
    ],

    # Training (tuned for 4x RTX 4090)
    'batch_size': 4096,     # 256 -> 4096: large batch to saturate GPU memory
    'epochs': 100,
    'learning_rate': 3e-4,  # 1e-4 -> 3e-4: scaled with the larger batch
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,

    # Early stopping
    'patience': 10,
    'min_delta': 1e-6,

    # Model (LSTM autoencoder — deliberately small)
    'model': {
        'n_features': 6,
        'hidden_dim': 32,       # LSTM hidden size (small)
        'latent_dim': 4,        # bottleneck size (very small — the key constraint)
        'num_layers': 1,        # LSTM depth
        'dropout': 0.2,
        'bidirectional': True,  # bidirectional encoder
    },

    # Normalization
    'use_instance_norm': True,  # model-internal instance norm (recommended)
    'clip_value': 10.0,         # hard clip for extreme feature values

    # Percentiles used when computing anomaly-score thresholds
    'threshold_percentiles': [90, 95, 99],
}
||||
|
||||
|
||||
# ==================== 数据加载(修复版)====================
|
||||
|
||||
def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
    """Load per-day feature parquet files into a {date: DataFrame} mapping.

    Files are expected to be named ``features_<date>.parquet``; days
    missing a required column are skipped with a warning.

    Raises:
        FileNotFoundError: when no feature parquet files exist in data_dir.
    """
    files = sorted(Path(data_dir).glob("features_*.parquet"))
    if not files:
        raise FileNotFoundError(f"未找到 parquet 文件: {data_dir}")

    print(f"找到 {len(files)} 个数据文件")

    # Columns every day's frame must provide.
    needed = features + ['concept_id', 'timestamp']
    per_day: Dict[str, pd.DataFrame] = {}

    for pf in tqdm(files, desc="加载数据"):
        # Date is encoded in the filename stem.
        day = pf.stem.replace('features_', '')
        frame = pd.read_parquet(pf)

        missing = [c for c in needed if c not in frame.columns]
        if missing:
            print(f"警告: {day} 缺少列: {missing}, 跳过")
            continue

        per_day[day] = frame

    print(f"成功加载 {len(per_day)} 天的数据")
    return per_day
|
||||
|
||||
def split_data_by_date(
    date_data: Dict[str, pd.DataFrame],
    train_end: str,
    val_end: str
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Partition the per-day data strictly by date.

    - train: date <= train_end
    - val:   train_end < date <= val_end
    - test:  date > val_end

    ISO 'YYYY-MM-DD' strings compare correctly as plain strings, so no
    datetime parsing is needed.
    """
    buckets = ({}, {}, {})
    for day, frame in date_data.items():
        if day <= train_end:
            idx = 0
        elif day <= val_end:
            idx = 1
        else:
            idx = 2
        buckets[idx][day] = frame

    train_data, val_data, test_data = buckets

    print(f"数据集划分(按日期):")
    print(f"  训练集: {len(train_data)} 天 (<= {train_end})")
    print(f"  验证集: {len(val_data)} 天 ({train_end} ~ {val_end})")
    print(f"  测试集: {len(test_data)} 天 (> {val_end})")

    return train_data, val_data, test_data
|
||||
|
||||
def build_sequences_by_concept(
    date_data: Dict[str, pd.DataFrame],
    features: List[str],
    seq_len: int,
    stride: int
) -> np.ndarray:
    """Build sliding-window training sequences, grouped per concept.

    Concatenates all days, groups rows by concept_id (so no window ever
    spans two different concepts), orders each group by date/time, then
    slides a window of ``seq_len`` with step ``stride``.

    Returns:
        Array of shape (n_sequences, seq_len, n_features); an empty array
        when there is no input data or no group yields a full window.
    """
    # Merge all days into one frame, tagging each row with its date.
    all_dfs = []
    for date, df in sorted(date_data.items()):
        df = df.copy()
        df['date'] = date
        all_dfs.append(df)

    if not all_dfs:
        return np.array([])

    combined = pd.concat(all_dfs, ignore_index=True)

    # Pre-sort (concept, date, time): groups come out time-ordered and
    # the single groupby below is faster on sorted data.
    combined = combined.sort_values(['concept_id', 'date', 'timestamp'])

    # One groupby pass over concepts (performance-critical — avoids
    # re-scanning the combined frame per concept).
    all_sequences = []
    grouped = combined.groupby('concept_id', sort=False)
    n_concepts = len(grouped)

    for concept_id, concept_df in tqdm(grouped, desc="构建序列", total=n_concepts, leave=False):
        # Already time-ordered — extract the raw feature matrix.
        feature_data = concept_df[features].values

        # Replace NaN/inf with 0 so the model never sees non-finite input.
        feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0)

        # Sliding windows strictly within this concept.
        n_points = len(feature_data)
        for start in range(0, n_points - seq_len + 1, stride):
            seq = feature_data[start:start + seq_len]
            all_sequences.append(seq)

    if not all_sequences:
        return np.array([])

    sequences = np.array(all_sequences)
    print(f"  构建序列: {len(sequences):,} 条 (来自 {n_concepts} 个概念)")

    return sequences
||||
|
||||
|
||||
# ==================== 数据集 ====================
|
||||
|
||||
class SequenceDataset(Dataset):
    """Torch Dataset over pre-built (n, seq_len, n_features) sequences."""

    def __init__(self, sequences: np.ndarray):
        # Convert to float32 once up front; __getitem__ just indexes.
        self.sequences = torch.FloatTensor(sequences)

    def __len__(self) -> int:
        return self.sequences.shape[0]

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.sequences[idx]
|
||||
|
||||
# ==================== 训练器 ====================
|
||||
|
||||
class EarlyStopping:
    """Stop training after `patience` epochs without `min_delta` improvement."""

    def __init__(self, patience: int = 10, min_delta: float = 1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0                 # epochs since the last improvement
        self.best_loss = float('inf')    # best validation loss seen so far
        self.early_stop = False          # latched once patience is exhausted

    def __call__(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
|
||||
|
||||
class Trainer:
    """Model trainer with AMP mixed-precision support.

    Bundles the optimizer, cosine-restart LR schedule, early stopping,
    checkpointing and training-history export for the autoencoder.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict,
        device: torch.device,
        save_dir: str = 'ml/checkpoints'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Optimizer
        self.optimizer = AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        # Cosine LR schedule with warm restarts (restart period doubles).
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=10,
            T_mult=2,
            eta_min=1e-6
        )

        # Loss (reconstruction-error based).
        self.criterion = AnomalyDetectionLoss()

        # Early stopping on validation loss.
        self.early_stopping = EarlyStopping(
            patience=config['patience'],
            min_delta=config['min_delta']
        )

        # AMP mixed precision: faster and lighter on memory (GPU only).
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
        if self.use_amp:
            print(" ✓ 启用 AMP 混合精度训练")

        # Per-epoch training history (dumped to JSON / plotted at the end).
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'learning_rate': [],
        }

        self.best_val_loss = float('inf')

    def train_epoch(self) -> float:
        """Train one epoch; return the mean training loss."""
        self.model.train()
        total_loss = 0.0
        n_batches = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for batch in pbar:
            batch = batch.to(self.device, non_blocking=True)  # async H2D copy

            self.optimizer.zero_grad(set_to_none=True)  # cheaper than zero-fill

            # Mixed-precision forward/backward when AMP is enabled.
            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, loss_dict = self.criterion(output, batch, latent)

                # Scaled backward pass.
                self.scaler.scale(loss).backward()

                # Unscale first so the clip threshold is in real gradient units.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config['gradient_clip']
                )

                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # Plain FP32 path.
                output, latent = self.model(batch)
                loss, loss_dict = self.criterion(output, batch, latent)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config['gradient_clip']
                )
                self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1

            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        return total_loss / n_batches

    @torch.no_grad()
    def validate(self) -> float:
        """Evaluate on the validation loader; return the mean loss."""
        self.model.eval()
        total_loss = 0.0
        n_batches = 0

        for batch in self.val_loader:
            batch = batch.to(self.device, non_blocking=True)

            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, _ = self.criterion(output, batch, latent)
            else:
                output, latent = self.model(batch)
                loss, _ = self.criterion(output, batch, latent)

            total_loss += loss.item()
            n_batches += 1

        return total_loss / n_batches

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Save the latest checkpoint; also write best_model.pt when is_best."""
        # Unwrap DataParallel so state-dict keys have no 'module.' prefix.
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
        }

        # Always keep the most recent checkpoint (for resuming).
        torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')

        # Keep the best model separately (what inference loads).
        if is_best:
            torch.save(checkpoint, self.save_dir / 'best_model.pt')
            print(f" ✓ 保存最佳模型 (val_loss: {val_loss:.6f})")

    def train(self, epochs: int):
        """Full training loop with checkpointing and early stopping.

        Returns the training-history dict (losses and learning rates).
        """
        print(f"\n开始训练 ({epochs} epochs)...")
        print(f"设备: {self.device}")
        print(f"模型参数量: {count_parameters(self.model):,}")

        for epoch in range(1, epochs + 1):
            print(f"\nEpoch {epoch}/{epochs}")

            # Train
            train_loss = self.train_epoch()

            # Validate
            val_loss = self.validate()

            # Step the LR schedule once per epoch.
            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['learning_rate'].append(current_lr)

            # Progress report
            print(f" Train Loss: {train_loss:.6f}")
            print(f" Val Loss: {val_loss:.6f}")
            print(f" LR: {current_lr:.2e}")

            # Checkpoint every epoch; track the best on validation loss.
            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss
            self.save_checkpoint(epoch, val_loss, is_best)

            # Early-stopping check
            if self.early_stopping(val_loss):
                print(f"\n早停触发!验证损失已 {self.early_stopping.patience} 个 epoch 未改善")
                break

        print(f"\n训练完成!最佳验证损失: {self.best_val_loss:.6f}")

        # Persist history and curves.
        self.save_history()

        return self.history

    def save_history(self):
        """Dump the training history to JSON, then plot the curves."""
        history_path = self.save_dir / 'training_history.json'
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=2)
        print(f"训练历史已保存: {history_path}")

        # Plot loss / LR curves (no-op without matplotlib).
        self.plot_training_curves()

    def plot_training_curves(self):
        """Plot loss and learning-rate curves to training_curves.png."""
        if not HAS_MATPLOTLIB:
            print("matplotlib 未安装,跳过绘图")
            return

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        epochs = range(1, len(self.history['train_loss']) + 1)

        # 1. Loss curves
        ax1 = axes[0]
        ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        ax1.set_xlabel('Epoch', fontsize=12)
        ax1.set_ylabel('Loss', fontsize=12)
        ax1.set_title('Training & Validation Loss', fontsize=14)
        ax1.legend(fontsize=11)
        ax1.grid(True, alpha=0.3)

        # Mark the best epoch on the loss plot.
        best_epoch = np.argmin(self.history['val_loss']) + 1
        best_val_loss = min(self.history['val_loss'])
        ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}')
        ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)
        ax1.annotate(f'Best: {best_val_loss:.6f}', xy=(best_epoch, best_val_loss),
                     xytext=(best_epoch + 2, best_val_loss + 0.0005),
                     fontsize=10, color='green')

        # 2. Learning-rate curve
        ax2 = axes[1]
        ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Learning Rate', fontsize=12)
        ax2.set_title('Learning Rate Schedule', fontsize=14)
        ax2.set_yscale('log')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()

        # Save the figure
        plot_path = self.save_dir / 'training_curves.png'
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.close()
        print(f"训练曲线已保存: {plot_path}")
||||
|
||||
|
||||
# ==================== 阈值计算(使用验证集)====================
|
||||
|
||||
@torch.no_grad()
def compute_thresholds(
    model: nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    percentiles: Tuple[float, ...] = (90, 95, 99)
) -> Dict[str, float]:
    """
    在验证集上计算重构误差的百分位数阈值

    Computes percentile thresholds of the per-sequence reconstruction error,
    plus mean/std/median summary statistics.

    注:使用验证集而非测试集,避免数据泄露
    (uses the validation set, not the test set, to avoid data leakage)

    Fix: the `percentiles` default was a mutable list (`[90, 95, 99]`) —
    a classic shared-mutable-default pitfall; an immutable tuple accepts
    the same values. Also fails fast with a clear message on an empty
    loader instead of `np.concatenate`'s cryptic error.
    """
    model.eval()
    all_errors = []

    print("计算异动阈值(使用验证集)...")
    for batch in tqdm(data_loader, desc="Computing thresholds"):
        batch = batch.to(device)
        errors = model.compute_reconstruction_error(batch, reduction='none')

        # 取每个序列的最后一个时刻误差(预测当前时刻)
        # Only the last time step's error is kept: it scores "now".
        seq_errors = errors[:, -1]  # (batch,)
        all_errors.append(seq_errors.cpu().numpy())

    if not all_errors:
        raise ValueError("data_loader yielded no batches; cannot compute thresholds")

    all_errors = np.concatenate(all_errors)

    thresholds = {}
    for p in percentiles:
        threshold = np.percentile(all_errors, p)
        thresholds[f'p{p}'] = float(threshold)
        print(f" P{p}: {threshold:.6f}")

    # 额外统计 — summary statistics alongside the percentiles.
    thresholds['mean'] = float(np.mean(all_errors))
    thresholds['std'] = float(np.std(all_errors))
    thresholds['median'] = float(np.median(all_errors))

    print(f" Mean: {thresholds['mean']:.6f}")
    print(f" Median: {thresholds['median']:.6f}")
    print(f" Std: {thresholds['std']:.6f}")

    return thresholds
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
    """CLI entry point: load per-day data, split chronologically, build
    per-concept sequences, train the autoencoder, then compute anomaly
    thresholds on the validation set.

    Fix: the best checkpoint stores the *unwrapped* module's state_dict
    (see Trainer.save_checkpoint's `model.module` handling), but this
    function previously loaded it into the possibly DataParallel-wrapped
    `model`, which fails with key mismatches on multi-GPU runs. The
    state_dict is now loaded into the underlying module.
    """
    parser = argparse.ArgumentParser(description='训练概念异动检测模型')
    parser.add_argument('--data_dir', type=str, default='ml/data',
                        help='数据目录路径')
    parser.add_argument('--epochs', type=int, default=100,
                        help='训练轮数')
    parser.add_argument('--batch_size', type=int, default=4096,
                        help='批次大小(4x RTX 4090 推荐 4096~8192)')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='学习率(大 batch 推荐 3e-4)')
    parser.add_argument('--device', type=str, default='auto',
                        help='设备 (auto/cuda/cpu)')
    parser.add_argument('--save_dir', type=str, default='ml/checkpoints',
                        help='模型保存目录')
    parser.add_argument('--train_end', type=str, default='2024-06-30',
                        help='训练集截止日期')
    parser.add_argument('--val_end', type=str, default='2024-09-30',
                        help='验证集截止日期')

    args = parser.parse_args()

    # CLI flags override the static defaults in TRAIN_CONFIG.
    config = TRAIN_CONFIG.copy()
    config['batch_size'] = args.batch_size
    config['epochs'] = args.epochs
    config['learning_rate'] = args.lr
    config['train_end_date'] = args.train_end
    config['val_end_date'] = args.val_end

    # Device selection: 'auto' prefers CUDA when available.
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    print("=" * 60)
    print("概念异动检测模型训练(修复版)")
    print("=" * 60)
    print(f"配置:")
    print(f" 数据目录: {args.data_dir}")
    print(f" 设备: {device}")
    print(f" 批次大小: {config['batch_size']}")
    print(f" 学习率: {config['learning_rate']}")
    print(f" 训练轮数: {config['epochs']}")
    print(f" 训练集截止: {config['train_end_date']}")
    print(f" 验证集截止: {config['val_end_date']}")
    print("=" * 60)

    # 1. Load per-day feature frames.
    print("\n[1/6] 加载数据...")
    date_data = load_data_by_date(args.data_dir, config['features'])

    # 2. Chronological train/val/test split.
    print("\n[2/6] 按日期划分数据集...")
    train_data, val_data, test_data = split_data_by_date(
        date_data,
        config['train_end_date'],
        config['val_end_date']
    )

    # 3. Sliding-window sequences, built per concept.
    print("\n[3/6] 按概念构建序列...")
    print("训练集:")
    train_sequences = build_sequences_by_concept(
        train_data, config['features'], config['seq_len'], config['stride']
    )
    print("验证集:")
    val_sequences = build_sequences_by_concept(
        val_data, config['features'], config['seq_len'], config['stride']
    )
    print("测试集:")
    test_sequences = build_sequences_by_concept(
        test_data, config['features'], config['seq_len'], config['stride']
    )

    if len(train_sequences) == 0:
        print("错误: 训练集为空!请检查数据和日期范围")
        return

    # 4. Preprocessing: only clip outliers — normalization happens inside
    #    the model via Instance Norm, so no external scaler is fitted.
    print("\n[4/6] 数据预处理...")
    print(" 注意: 使用 Instance Norm,每个序列在模型内部单独标准化")
    print(" 这样可以处理不同概念波动率差异(银行 vs 半导体)")

    clip_value = config['clip_value']
    print(f" 截断极端值: ±{clip_value}")

    train_sequences = np.clip(train_sequences, -clip_value, clip_value)
    if len(val_sequences) > 0:
        val_sequences = np.clip(val_sequences, -clip_value, clip_value)
    if len(test_sequences) > 0:
        test_sequences = np.clip(test_sequences, -clip_value, clip_value)

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Record the (trivial) preprocessing so inference can mirror it.
    preprocess_params = {
        'features': config['features'],
        'normalization': 'instance_norm',  # done inside the model
        'clip_value': clip_value,
        'note': '标准化在模型内部通过 InstanceNorm1d 完成,无需外部 Scaler'
    }

    with open(save_dir / 'normalization_stats.json', 'w') as f:
        json.dump(preprocess_params, f, indent=2)
    print(f" 预处理参数已保存")

    # 5. Datasets and loaders.
    print("\n[5/6] 创建数据加载器...")
    train_dataset = SequenceDataset(train_sequences)
    val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None
    test_dataset = SequenceDataset(test_sequences) if len(test_sequences) > 0 else None

    print(f" 训练序列: {len(train_dataset):,}")
    print(f" 验证序列: {len(val_dataset) if val_dataset else 0:,}")
    print(f" 测试序列: {len(test_dataset) if test_dataset else 0:,}")

    # More workers on multi-GPU Linux boxes; Windows gets 0 (spawn issues).
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0
    print(f" DataLoader workers: {num_workers}")
    print(f" Batch size: {config['batch_size']}")

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=True if num_workers > 0 else False,
        drop_last=True  # avoid a ragged final batch
    )

    # Validation/test run without gradients, so a larger batch is fine.
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'] * 2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=True if num_workers > 0 else False,
    ) if val_dataset else None

    # NOTE: test_loader is built for completeness; thresholds deliberately
    # use the validation loader to avoid data leakage.
    test_loader = DataLoader(
        test_dataset,
        batch_size=config['batch_size'] * 2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=True if num_workers > 0 else False,
    ) if test_dataset else None

    # 6. Build and train the model.
    print("\n[6/6] 训练模型...")
    model_config = config['model'].copy()
    model = TransformerAutoencoder(**model_config)

    if torch.cuda.device_count() > 1:
        print(f" 使用 {torch.cuda.device_count()} 张 GPU 并行训练")
        model = nn.DataParallel(model)

    if val_loader is None:
        print("警告: 验证集为空,将使用训练集的一部分作为验证")
        # Fallback: hold out the last 10% of the training set for validation.
        split_idx = int(len(train_dataset) * 0.9)
        train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
        val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))

        train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
        val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device,
        save_dir=args.save_dir
    )

    history = trainer.train(config['epochs'])

    # 7. Compute anomaly thresholds on the validation set.
    print("\n[额外] 计算异动阈值...")

    best_checkpoint = torch.load(
        save_dir / 'best_model.pt',
        map_location=device
    )
    # FIX: the checkpoint stores the unwrapped module's state_dict, so load
    # it into the underlying module, not the DataParallel wrapper.
    base_model = model.module if hasattr(model, 'module') else model
    base_model.load_state_dict(best_checkpoint['model_state_dict'])
    base_model.to(device)

    # Validation set (not test) — avoids data leakage into the thresholds.
    thresholds = compute_thresholds(
        base_model,
        val_loader,
        device,
        config['threshold_percentiles']
    )

    with open(save_dir / 'thresholds.json', 'w') as f:
        json.dump(thresholds, f, indent=2)
    print(f"阈值已保存")

    # Persist the fully-resolved configuration for reproducibility.
    with open(save_dir / 'config.json', 'w') as f:
        json.dump(config, f, indent=2)

    print("\n" + "=" * 60)
    print("训练完成!")
    print("=" * 60)
    print(f"模型保存位置: {args.save_dir}")
    print(f" - best_model.pt: 最佳模型权重")
    print(f" - thresholds.json: 异动阈值")
    print(f" - normalization_stats.json: 标准化参数")
    print(f" - config.json: 训练配置")
    print("=" * 60)


if __name__ == "__main__":
    main()
|
||||
622
ml/train_v2.py
Normal file
622
ml/train_v2.py
Normal file
@@ -0,0 +1,622 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
训练脚本 V2 - 基于 Z-Score 特征的 LSTM Autoencoder
|
||||
|
||||
改进点:
|
||||
1. 使用 Z-Score 特征(相对于同时间片历史的偏离)
|
||||
2. 短序列:10分钟(不需要30分钟预热)
|
||||
3. 开盘即可检测:9:30 直接有特征
|
||||
|
||||
模型输入:
|
||||
- 过去10分钟的 Z-Score 特征序列
|
||||
- 特征:alpha_zscore, amt_zscore, rank_zscore, momentum_3m, momentum_5m, limit_up_ratio
|
||||
|
||||
模型学习:
|
||||
- 学习 Z-Score 序列的"正常演化模式"
|
||||
- 异动 = Z-Score 序列的异常演化(重构误差大)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
||||
from tqdm import tqdm
|
||||
|
||||
from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters
|
||||
|
||||
# Performance knobs: let cuDNN benchmark and pick the fastest kernels for
# fixed-shape batches, and allow TF32 matmuls on Ampere+ GPUs (faster,
# slightly reduced precision — acceptable for training).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Plotting is optional: degrade gracefully when matplotlib is absent.
# The 'Agg' backend renders to files only (works on headless servers).
try:
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False
|
||||
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
# V2 training configuration. These are defaults; batch_size, epochs,
# learning_rate, the split dates and seq_len are overridden by CLI flags
# in main().
TRAIN_CONFIG = {
    # Data windowing (improved vs. V1): short 10-minute sequences, so no
    # 30-minute warm-up is needed and detection works right at the open.
    'seq_len': 10,  # 10-minute window (V1 used 30 minutes)
    'stride': 2,    # slide the window 2 minutes at a time

    # Chronological split boundaries (inclusive upper bounds, ISO dates).
    'train_end_date': '2024-06-30',
    'val_end_date': '2024-09-30',

    # V2 features — z-scores relative to same-time-slot history.
    'features': [
        'alpha_zscore',  # z-score of alpha
        'amt_zscore',  # z-score of turnover
        'rank_zscore',  # z-score of rank
        'momentum_3m',  # 3-minute momentum
        'momentum_5m',  # 5-minute momentum
        'limit_up_ratio',  # fraction of constituents at limit-up
    ],

    # Optimization.
    'batch_size': 4096,
    'epochs': 100,
    'learning_rate': 3e-4,
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,  # max grad-norm passed to clip_grad_norm_

    # Early stopping.
    'patience': 15,
    'min_delta': 1e-6,

    # Model kwargs (small network), splatted into TransformerAutoencoder.
    'model': {
        'n_features': 6,  # must match len(features) above
        'hidden_dim': 32,
        'latent_dim': 4,
        'num_layers': 1,
        'dropout': 0.2,
        'bidirectional': True,
    },

    # Inputs are already z-scored, so clipping at ±5 suffices.
    'clip_value': 5.0,

    # Percentiles used when deriving anomaly thresholds on the val set.
    'threshold_percentiles': [90, 95, 99],
}
|
||||
|
||||
|
||||
# ==================== 数据加载 ====================
|
||||
|
||||
def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]:
    """Load every per-day V2 parquet file into a {date: DataFrame} mapping.

    Files are named `features_v2_<date>.parquet`. Days missing any required
    column are reported and skipped rather than aborting the whole load.
    """
    parquet_paths = sorted(Path(data_dir).glob("features_v2_*.parquet"))

    if not parquet_paths:
        raise FileNotFoundError(f"未找到 V2 数据文件: {data_dir}")

    print(f"找到 {len(parquet_paths)} 个 V2 数据文件")

    date_data = {}
    needed = features + ['concept_id', 'timestamp']

    for path in tqdm(parquet_paths, desc="加载数据"):
        # The date is encoded in the file name after the fixed prefix.
        date = path.stem.replace('features_v2_', '')

        frame = pd.read_parquet(path)

        absent = [col for col in needed if col not in frame.columns]
        if absent:
            print(f"警告: {date} 缺少列: {absent}, 跳过")
            continue

        date_data[date] = frame

    print(f"成功加载 {len(date_data)} 天的数据")
    return date_data
|
||||
|
||||
|
||||
def split_data_by_date(
    date_data: Dict[str, pd.DataFrame],
    train_end: str,
    val_end: str
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Split per-day frames chronologically into train/val/test dicts.

    Dates are ISO-formatted strings, so plain lexicographic comparison
    orders them correctly. Both cut-offs are inclusive upper bounds.
    """
    train_data: Dict[str, pd.DataFrame] = {}
    val_data: Dict[str, pd.DataFrame] = {}
    test_data: Dict[str, pd.DataFrame] = {}

    for date, df in date_data.items():
        if date <= train_end:
            bucket = train_data
        elif date <= val_end:
            bucket = val_data
        else:
            bucket = test_data
        bucket[date] = df

    print(f"数据集划分:")
    print(f" 训练集: {len(train_data)} 天 (<= {train_end})")
    print(f" 验证集: {len(val_data)} 天 ({train_end} ~ {val_end})")
    print(f" 测试集: {len(test_data)} 天 (> {val_end})")

    return train_data, val_data, test_data
|
||||
|
||||
|
||||
def build_sequences_by_concept(
    date_data: Dict[str, pd.DataFrame],
    features: List[str],
    seq_len: int,
    stride: int
) -> np.ndarray:
    """Build sliding-window feature sequences, grouped by concept.

    All days are concatenated and sorted (concept, date, timestamp), then
    for each concept a window of `seq_len` rows is slid with step `stride`.
    Returns an array of shape (n_sequences, seq_len, n_features), or an
    empty array when there is no input / no complete window.

    NOTE(review): windows can span day boundaries within a concept —
    presumably intentional; confirm against the feature semantics.
    """
    frames = []
    for date, df in sorted(date_data.items()):
        day_frame = df.copy()
        day_frame['date'] = date
        frames.append(day_frame)

    if not frames:
        return np.array([])

    combined = pd.concat(frames, ignore_index=True)
    combined = combined.sort_values(['concept_id', 'date', 'timestamp'])

    windows = []
    grouped = combined.groupby('concept_id', sort=False)
    n_concepts = len(grouped)

    for concept_id, concept_df in tqdm(grouped, desc="构建序列", total=n_concepts, leave=False):
        # NaN/inf are mapped to 0 so a single bad tick cannot poison a window.
        values = np.nan_to_num(concept_df[features].values,
                               nan=0.0, posinf=0.0, neginf=0.0)

        last_start = len(values) - seq_len
        for begin in range(0, last_start + 1, stride):
            windows.append(values[begin:begin + seq_len])

    if not windows:
        return np.array([])

    sequences = np.array(windows)
    print(f" 构建序列: {len(sequences):,} 条 (来自 {n_concepts} 个概念)")

    return sequences
|
||||
|
||||
|
||||
# ==================== 数据集 ====================
|
||||
|
||||
class SequenceDataset(Dataset):
    """Wraps a (N, seq_len, n_features) array as a dataset of float32 tensors."""

    def __init__(self, sequences: np.ndarray):
        # One up-front conversion; __getitem__ is then a cheap tensor index.
        self.sequences = torch.FloatTensor(sequences)

    def __len__(self) -> int:
        return self.sequences.shape[0]

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.sequences[idx]
|
||||
|
||||
|
||||
# ==================== 训练器 ====================
|
||||
|
||||
class EarlyStopping:
    """Signals a stop after `patience` consecutive epochs without improvement.

    An epoch counts as an improvement only when the validation loss drops
    below the best seen so far by more than `min_delta`.
    """

    def __init__(self, patience: int = 10, min_delta: float = 1e-6):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0                  # epochs since the last improvement
        self.best_loss = float('inf')
        self.early_stop = False           # latched once triggered

    def __call__(self, val_loss: float) -> bool:
        """Record this epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop
|
||||
|
||||
|
||||
class Trainer:
    """Training harness for the V2 autoencoder.

    Owns the optimizer (AdamW), LR schedule (cosine warm restarts),
    loss, early stopping, optional AMP mixed precision, checkpointing
    and history recording/plotting.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict,
        device: torch.device,
        save_dir: str = 'ml/checkpoints_v2'
    ):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.optimizer = AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        # Cosine schedule that restarts at epoch 10, then doubles the period.
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=2, eta_min=1e-6
        )

        self.criterion = AnomalyDetectionLoss()

        self.early_stopping = EarlyStopping(
            patience=config['patience'],
            min_delta=config['min_delta']
        )

        # AMP only makes sense on CUDA; the scaler guards fp16 gradients.
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
        if self.use_amp:
            print(" ✓ 启用 AMP 混合精度训练")

        self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}
        self.best_val_loss = float('inf')

    def train_epoch(self) -> float:
        """Run one optimization pass over the training loader; return mean loss."""
        self.model.train()
        total_loss = 0.0
        n_batches = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for batch in pbar:
            batch = batch.to(self.device, non_blocking=True)
            self.optimizer.zero_grad(set_to_none=True)

            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, _ = self.criterion(output, batch, latent)

                # Scaled backward, then unscale so clip_grad_norm_ sees
                # true gradient magnitudes before the optimizer step.
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                output, latent = self.model(batch)
                loss, _ = self.criterion(output, batch, latent)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip'])
                self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1
            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        return total_loss / n_batches

    @torch.no_grad()
    def validate(self) -> float:
        """Evaluate on the validation loader; return the mean loss."""
        self.model.eval()
        total_loss = 0.0
        n_batches = 0

        for batch in self.val_loader:
            batch = batch.to(self.device, non_blocking=True)

            if self.use_amp:
                with torch.cuda.amp.autocast():
                    output, latent = self.model(batch)
                    loss, _ = self.criterion(output, batch, latent)
            else:
                output, latent = self.model(batch)
                loss, _ = self.criterion(output, batch, latent)

            total_loss += loss.item()
            n_batches += 1

        return total_loss / n_batches

    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """Write `last_checkpoint.pt` (and `best_model.pt` when `is_best`).

        The *unwrapped* module's state_dict is stored, so checkpoints are
        loadable regardless of DataParallel wrapping at training time.
        """
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_loss': val_loss,
            'config': self.config,
        }

        torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt')

        if is_best:
            torch.save(checkpoint, self.save_dir / 'best_model.pt')
            print(f" ✓ 保存最佳模型 (val_loss: {val_loss:.6f})")

    def train(self, epochs: int):
        """Full training loop: per-epoch train/validate/schedule/checkpoint
        with early stopping; returns the recorded history."""
        print(f"\n开始训练 ({epochs} epochs)...")
        print(f"设备: {self.device}")
        print(f"模型参数量: {count_parameters(self.model):,}")

        for epoch in range(1, epochs + 1):
            print(f"\nEpoch {epoch}/{epochs}")

            train_loss = self.train_epoch()
            val_loss = self.validate()

            # Advance the LR schedule once per epoch.
            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['learning_rate'].append(current_lr)

            print(f" Train Loss: {train_loss:.6f}")
            print(f" Val Loss: {val_loss:.6f}")
            print(f" LR: {current_lr:.2e}")

            # Checkpoint every epoch; flag new validation-loss bests.
            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss
            self.save_checkpoint(epoch, val_loss, is_best)

            if self.early_stopping(val_loss):
                print(f"\n早停触发!")
                break

        print(f"\n训练完成!最佳验证损失: {self.best_val_loss:.6f}")
        self.save_history()

        return self.history

    def save_history(self):
        """Persist the metric history as JSON and (optionally) plot it."""
        history_path = self.save_dir / 'training_history.json'
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=2)
        print(f"训练历史已保存: {history_path}")

        if HAS_MATPLOTLIB:
            self.plot_training_curves()

    def plot_training_curves(self):
        """Render loss and LR curves to `training_curves.png` in save_dir."""
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        epochs = range(1, len(self.history['train_loss']) + 1)

        # Left panel: train vs. validation loss, best epoch highlighted.
        ax1 = axes[0]
        ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training & Validation Loss (V2)')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        best_epoch = np.argmin(self.history['val_loss']) + 1
        best_val_loss = min(self.history['val_loss'])
        ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7)
        ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5)

        # Right panel: learning-rate schedule on a log scale.
        ax2 = axes[1]
        ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Learning Rate')
        ax2.set_title('Learning Rate Schedule')
        ax2.set_yscale('log')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(self.save_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
        plt.close()
        print(f"训练曲线已保存")
|
||||
|
||||
|
||||
# ==================== 阈值计算 ====================
|
||||
|
||||
@torch.no_grad()
def compute_thresholds(
    model: nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    percentiles: Tuple[float, ...] = (90, 95, 99)
) -> Dict[str, float]:
    """在验证集上计算阈值

    Computes percentile thresholds of the per-sequence reconstruction error
    on the validation set, plus mean/std/median statistics.

    Fix: the `percentiles` default was a mutable list (`[90, 95, 99]`) —
    a shared-mutable-default pitfall; an immutable tuple accepts the same
    values. Also fails fast with a clear message on an empty loader.
    """
    model.eval()
    all_errors = []

    print("计算异动阈值...")
    for batch in tqdm(data_loader, desc="Computing thresholds"):
        batch = batch.to(device)
        errors = model.compute_reconstruction_error(batch, reduction='none')
        # Only the last time step's error is kept — it scores "now".
        seq_errors = errors[:, -1]  # 最后一个时刻
        all_errors.append(seq_errors.cpu().numpy())

    if not all_errors:
        raise ValueError("data_loader yielded no batches; cannot compute thresholds")

    all_errors = np.concatenate(all_errors)

    thresholds = {}
    for p in percentiles:
        threshold = np.percentile(all_errors, p)
        thresholds[f'p{p}'] = float(threshold)
        print(f" P{p}: {threshold:.6f}")

    # Summary statistics alongside the percentiles.
    thresholds['mean'] = float(np.mean(all_errors))
    thresholds['std'] = float(np.std(all_errors))
    thresholds['median'] = float(np.median(all_errors))

    return thresholds
|
||||
|
||||
|
||||
# ==================== 主函数 ====================
|
||||
|
||||
def main():
    """CLI entry point for V2: load z-score feature data, split by date,
    build 10-minute sequences, train the autoencoder, and derive anomaly
    thresholds on the validation set."""
    parser = argparse.ArgumentParser(description='训练 V2 模型')
    parser.add_argument('--data_dir', type=str, default='ml/data_v2', help='V2 数据目录')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=4096)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--device', type=str, default='auto')
    parser.add_argument('--save_dir', type=str, default='ml/checkpoints_v2')
    parser.add_argument('--train_end', type=str, default='2024-06-30')
    parser.add_argument('--val_end', type=str, default='2024-09-30')
    parser.add_argument('--seq_len', type=int, default=10, help='序列长度(分钟)')

    args = parser.parse_args()

    # CLI flags override the static defaults in TRAIN_CONFIG.
    config = TRAIN_CONFIG.copy()
    config['batch_size'] = args.batch_size
    config['epochs'] = args.epochs
    config['learning_rate'] = args.lr
    config['train_end_date'] = args.train_end
    config['val_end_date'] = args.val_end
    config['seq_len'] = args.seq_len

    # Device selection: 'auto' prefers CUDA when available.
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    print("=" * 60)
    print("概念异动检测模型训练 V2(Z-Score 特征)")
    print("=" * 60)
    print(f"数据目录: {args.data_dir}")
    print(f"设备: {device}")
    print(f"序列长度: {config['seq_len']} 分钟")
    print(f"批次大小: {config['batch_size']}")
    print(f"特征: {config['features']}")
    print("=" * 60)

    # 1. Load per-day feature frames.
    print("\n[1/6] 加载 V2 数据...")
    date_data = load_data_by_date(args.data_dir, config['features'])

    # 2. Chronological train/val/test split (test unused below).
    print("\n[2/6] 划分数据集...")
    train_data, val_data, test_data = split_data_by_date(
        date_data, config['train_end_date'], config['val_end_date']
    )

    # 3. Sliding-window sequences per concept.
    print("\n[3/6] 构建序列...")
    print("训练集:")
    train_sequences = build_sequences_by_concept(
        train_data, config['features'], config['seq_len'], config['stride']
    )
    print("验证集:")
    val_sequences = build_sequences_by_concept(
        val_data, config['features'], config['seq_len'], config['stride']
    )

    if len(train_sequences) == 0:
        print("错误: 训练集为空!")
        return

    # 4. Preprocessing: inputs are already z-scored; only clip outliers.
    print("\n[4/6] 数据预处理...")
    clip_value = config['clip_value']
    print(f" Z-Score 特征已标准化,截断: ±{clip_value}")

    train_sequences = np.clip(train_sequences, -clip_value, clip_value)
    if len(val_sequences) > 0:
        val_sequences = np.clip(val_sequences, -clip_value, clip_value)

    # Persist the fully-resolved configuration for reproducibility.
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    with open(save_dir / 'config.json', 'w') as f:
        json.dump(config, f, indent=2)

    # 5. Datasets and loaders.
    print("\n[5/6] 创建数据加载器...")
    train_dataset = SequenceDataset(train_sequences)
    val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None

    print(f" 训练序列: {len(train_dataset):,}")
    print(f" 验证序列: {len(val_dataset) if val_dataset else 0:,}")

    # More workers on multi-GPU Linux boxes; Windows gets 0 (spawn issues).
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=True if num_workers > 0 else False,
        drop_last=True
    )

    # Validation runs without gradients, so a larger batch is fine.
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'] * 2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    ) if val_dataset else None

    # 6. Build and train the model.
    print("\n[6/6] 训练模型...")
    model = TransformerAutoencoder(**config['model'])

    if torch.cuda.device_count() > 1:
        print(f" 使用 {torch.cuda.device_count()} 张 GPU 并行训练")
        model = nn.DataParallel(model)

    if val_loader is None:
        # Fallback: hold out the last 10% of the training set for validation.
        print("警告: 验证集为空,使用训练集的 10% 作为验证")
        split_idx = int(len(train_dataset) * 0.9)
        train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
        val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))
        train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
        val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device,
        save_dir=args.save_dir
    )

    trainer.train(config['epochs'])

    # Derive anomaly thresholds from the best checkpoint on the val set.
    print("\n[额外] 计算异动阈值...")
    best_checkpoint = torch.load(save_dir / 'best_model.pt', map_location=device)

    # 创建新的单 GPU 模型用于计算阈值(避免 DataParallel 问题)
    # A fresh single-device model sidesteps DataParallel state_dict issues.
    threshold_model = TransformerAutoencoder(**config['model'])
    threshold_model.load_state_dict(best_checkpoint['model_state_dict'])
    threshold_model.to(device)
    threshold_model.eval()

    thresholds = compute_thresholds(threshold_model, val_loader, device, config['threshold_percentiles'])

    with open(save_dir / 'thresholds.json', 'w') as f:
        json.dump(thresholds, f, indent=2)

    print("\n" + "=" * 60)
    print("训练完成!")
    print(f"模型保存位置: {args.save_dir}")
    print("=" * 60)


if __name__ == "__main__":
    main()
|
||||
132
ml/update_baseline.py
Normal file
132
ml/update_baseline.py
Normal file
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
每日盘后运行:更新滚动基线
|
||||
|
||||
使用方法:
|
||||
python ml/update_baseline.py
|
||||
|
||||
建议加入 crontab,每天 15:30 后运行:
|
||||
30 15 * * 1-5 cd /path/to/project && python ml/update_baseline.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pickle
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from ml.prepare_data_v2 import (
|
||||
get_all_concepts, get_trading_days, compute_raw_concept_features,
|
||||
init_process_connections, CONFIG, RAW_CACHE_DIR, BASELINE_DIR
|
||||
)
|
||||
|
||||
|
||||
def update_rolling_baseline(baseline_days: int = 20):
|
||||
"""
|
||||
更新滚动基线(用于实盘检测)
|
||||
|
||||
基线 = 最近 N 个交易日每个时间片的统计量
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("更新滚动基线(用于实盘)")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化连接
|
||||
init_process_connections()
|
||||
|
||||
# 获取概念列表
|
||||
concepts = get_all_concepts()
|
||||
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
||||
|
||||
# 获取最近的交易日
|
||||
today = datetime.now().strftime('%Y-%m-%d')
|
||||
start_date = (datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d') # 多取一些
|
||||
|
||||
trading_days = get_trading_days(start_date, today)
|
||||
|
||||
if len(trading_days) < baseline_days:
|
||||
print(f"错误:交易日不足 {baseline_days} 天")
|
||||
return
|
||||
|
||||
# 只取最近 N 天
|
||||
recent_days = trading_days[-baseline_days:]
|
||||
print(f"使用 {len(recent_days)} 天数据: {recent_days[0]} ~ {recent_days[-1]}")
|
||||
|
||||
# 加载原始数据
|
||||
all_data = []
|
||||
for trade_date in tqdm(recent_days, desc="加载数据"):
|
||||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
||||
|
||||
if os.path.exists(cache_file):
|
||||
df = pd.read_parquet(cache_file)
|
||||
else:
|
||||
df = compute_raw_concept_features(trade_date, concepts, all_stocks)
|
||||
|
||||
if not df.empty:
|
||||
all_data.append(df)
|
||||
|
||||
if not all_data:
|
||||
print("错误:无数据")
|
||||
return
|
||||
|
||||
combined = pd.concat(all_data, ignore_index=True)
|
||||
print(f"总数据量: {len(combined):,} 条")
|
||||
|
||||
# 按概念计算基线
|
||||
baselines = {}
|
||||
|
||||
for concept_id, group in tqdm(combined.groupby('concept_id'), desc="计算基线"):
|
||||
baseline_dict = {}
|
||||
|
||||
for time_slot, slot_group in group.groupby('time_slot'):
|
||||
if len(slot_group) < CONFIG['min_baseline_samples']:
|
||||
continue
|
||||
|
||||
alpha_std = slot_group['alpha'].std()
|
||||
amt_std = slot_group['total_amt'].std()
|
||||
rank_std = slot_group['rank_pct'].std()
|
||||
|
||||
baseline_dict[time_slot] = {
|
||||
'alpha_mean': float(slot_group['alpha'].mean()),
|
||||
'alpha_std': float(max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1)),
|
||||
'amt_mean': float(slot_group['total_amt'].mean()),
|
||||
'amt_std': float(max(amt_std if pd.notna(amt_std) else slot_group['total_amt'].mean() * 0.5, 1.0)),
|
||||
'rank_mean': float(slot_group['rank_pct'].mean()),
|
||||
'rank_std': float(max(rank_std if pd.notna(rank_std) else 0.2, 0.05)),
|
||||
'sample_count': len(slot_group),
|
||||
}
|
||||
|
||||
if baseline_dict:
|
||||
baselines[concept_id] = baseline_dict
|
||||
|
||||
print(f"计算了 {len(baselines)} 个概念的基线")
|
||||
|
||||
# 保存
|
||||
os.makedirs(BASELINE_DIR, exist_ok=True)
|
||||
baseline_file = os.path.join(BASELINE_DIR, 'realtime_baseline.pkl')
|
||||
|
||||
with open(baseline_file, 'wb') as f:
|
||||
pickle.dump({
|
||||
'baselines': baselines,
|
||||
'update_time': datetime.now().isoformat(),
|
||||
'date_range': [recent_days[0], recent_days[-1]],
|
||||
'baseline_days': baseline_days,
|
||||
}, f)
|
||||
|
||||
print(f"基线已保存: {baseline_file}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--days', type=int, default=20, help='基线天数')
|
||||
args = parser.parse_args()
|
||||
|
||||
update_rolling_baseline(args.days)
|
||||
68
sql/concept_minute_alert.sql
Normal file
68
sql/concept_minute_alert.sql
Normal file
@@ -0,0 +1,68 @@
|
||||
-- 概念分钟级异动数据表
|
||||
-- 用于存储概念板块的实时异动信息,支持热点概览图表展示
|
||||
|
||||
CREATE TABLE IF NOT EXISTS concept_minute_alert (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
concept_id VARCHAR(32) NOT NULL COMMENT '概念ID',
|
||||
concept_name VARCHAR(100) NOT NULL COMMENT '概念名称',
|
||||
alert_time DATETIME NOT NULL COMMENT '异动时间(精确到分钟)',
|
||||
alert_type VARCHAR(20) NOT NULL COMMENT '异动类型:surge(急涨)/limit_up(涨停增加)/rank_jump(排名跃升)',
|
||||
trade_date DATE NOT NULL COMMENT '交易日期',
|
||||
|
||||
-- 涨跌幅相关
|
||||
change_pct DECIMAL(10,4) COMMENT '当时涨跌幅(%)',
|
||||
prev_change_pct DECIMAL(10,4) COMMENT '之前涨跌幅(%)',
|
||||
change_delta DECIMAL(10,4) COMMENT '涨幅变化量(%)',
|
||||
|
||||
-- 涨停相关
|
||||
limit_up_count INT DEFAULT 0 COMMENT '当前涨停数量',
|
||||
prev_limit_up_count INT DEFAULT 0 COMMENT '之前涨停数量',
|
||||
limit_up_delta INT DEFAULT 0 COMMENT '涨停变化数量',
|
||||
|
||||
-- 排名相关
|
||||
rank_position INT COMMENT '当前涨幅排名',
|
||||
prev_rank_position INT COMMENT '之前涨幅排名',
|
||||
rank_delta INT COMMENT '排名变化(负数表示上升)',
|
||||
|
||||
-- 指数位置(用于图表Y轴定位)
|
||||
index_code VARCHAR(20) DEFAULT '000001.SH' COMMENT '参考指数代码',
|
||||
index_price DECIMAL(12,4) COMMENT '异动时的指数点位',
|
||||
index_change_pct DECIMAL(10,4) COMMENT '异动时的指数涨跌幅(%)',
|
||||
|
||||
-- 概念详情
|
||||
stock_count INT COMMENT '概念包含股票数',
|
||||
concept_type VARCHAR(20) DEFAULT 'leaf' COMMENT '概念类型:leaf/lv1/lv2/lv3',
|
||||
|
||||
-- 额外信息(JSON格式,存储涨停股票列表等)
|
||||
extra_info JSON COMMENT '额外信息',
|
||||
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
-- 索引
|
||||
INDEX idx_trade_date (trade_date),
|
||||
INDEX idx_alert_time (alert_time),
|
||||
INDEX idx_concept_id (concept_id),
|
||||
INDEX idx_alert_type (alert_type),
|
||||
INDEX idx_trade_date_time (trade_date, alert_time)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念分钟级异动数据表';
|
||||
|
||||
|
||||
-- 创建指数分时快照表(用于异动时获取指数位置)
|
||||
CREATE TABLE IF NOT EXISTS index_minute_snapshot (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY,
|
||||
index_code VARCHAR(20) NOT NULL COMMENT '指数代码',
|
||||
trade_date DATE NOT NULL COMMENT '交易日期',
|
||||
snapshot_time DATETIME NOT NULL COMMENT '快照时间',
|
||||
|
||||
price DECIMAL(12,4) COMMENT '指数点位',
|
||||
open_price DECIMAL(12,4) COMMENT '开盘价',
|
||||
high_price DECIMAL(12,4) COMMENT '最高价',
|
||||
low_price DECIMAL(12,4) COMMENT '最低价',
|
||||
prev_close DECIMAL(12,4) COMMENT '昨收价',
|
||||
change_pct DECIMAL(10,4) COMMENT '涨跌幅(%)',
|
||||
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
UNIQUE KEY uk_index_time (index_code, snapshot_time),
|
||||
INDEX idx_trade_date (trade_date)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='指数分时快照表';
|
||||
@@ -313,12 +313,29 @@ const StockChartAntdModal = ({
|
||||
axisPointer: { type: 'cross' },
|
||||
formatter: function(params) {
|
||||
const d = params[0]?.dataIndex ?? 0;
|
||||
const priceChangePercent = ((prices[d] - prevClose) / prevClose * 100);
|
||||
const avgChangePercent = ((avgPrices[d] - prevClose) / prevClose * 100);
|
||||
const price = prices[d];
|
||||
const avgPrice = avgPrices[d];
|
||||
const volume = volumes[d];
|
||||
|
||||
// 安全计算涨跌幅,处理 undefined/null/0 的情况
|
||||
const safeCalcPercent = (val, base) => {
|
||||
if (val == null || base == null || base === 0) return 0;
|
||||
return ((val - base) / base * 100);
|
||||
};
|
||||
|
||||
const priceChangePercent = safeCalcPercent(price, prevClose);
|
||||
const avgChangePercent = safeCalcPercent(avgPrice, prevClose);
|
||||
const priceColor = priceChangePercent >= 0 ? '#ef5350' : '#26a69a';
|
||||
const avgColor = avgChangePercent >= 0 ? '#ef5350' : '#26a69a';
|
||||
|
||||
return `时间:${times[d]}<br/>现价:<span style="color: ${priceColor}">¥${prices[d]?.toFixed(2)} (${priceChangePercent >= 0 ? '+' : ''}${priceChangePercent.toFixed(2)}%)</span><br/>均价:<span style="color: ${avgColor}">¥${avgPrices[d]?.toFixed(2)} (${avgChangePercent >= 0 ? '+' : ''}${avgChangePercent.toFixed(2)}%)</span><br/>昨收:¥${prevClose?.toFixed(2)}<br/>成交量:${Math.round(volumes[d]/100)}手`;
|
||||
// 安全格式化数字
|
||||
const safeFixed = (val, digits = 2) => (val != null && !isNaN(val)) ? val.toFixed(digits) : '-';
|
||||
const formatPercent = (val) => {
|
||||
if (val == null || isNaN(val)) return '-';
|
||||
return (val >= 0 ? '+' : '') + val.toFixed(2) + '%';
|
||||
};
|
||||
|
||||
return `时间:${times[d] || '-'}<br/>现价:<span style="color: ${priceColor}">¥${safeFixed(price)} (${formatPercent(priceChangePercent)})</span><br/>均价:<span style="color: ${avgColor}">¥${safeFixed(avgPrice)} (${formatPercent(avgChangePercent)})</span><br/>昨收:¥${safeFixed(prevClose)}<br/>成交量:${volume != null ? Math.round(volume/100) + '手' : '-'}`;
|
||||
}
|
||||
},
|
||||
grid: [
|
||||
@@ -337,6 +354,7 @@ const StockChartAntdModal = ({
|
||||
position: 'left',
|
||||
axisLabel: {
|
||||
formatter: function(value) {
|
||||
if (value == null || isNaN(value)) return '-';
|
||||
return (value >= 0 ? '+' : '') + value.toFixed(2) + '%';
|
||||
}
|
||||
},
|
||||
@@ -354,11 +372,12 @@ const StockChartAntdModal = ({
|
||||
position: 'right',
|
||||
axisLabel: {
|
||||
formatter: function(value) {
|
||||
if (value == null || isNaN(value)) return '-';
|
||||
return (value >= 0 ? '+' : '') + value.toFixed(2) + '%';
|
||||
}
|
||||
}
|
||||
},
|
||||
{ type: 'value', gridIndex: 1, scale: true, axisLabel: { formatter: v => Math.round(v/100) + '手' } }
|
||||
{ type: 'value', gridIndex: 1, scale: true, axisLabel: { formatter: v => (v != null && !isNaN(v)) ? Math.round(v/100) + '手' : '-' } }
|
||||
],
|
||||
dataZoom: [
|
||||
{ type: 'inside', xAxisIndex: [0, 1], start: 0, end: 100 },
|
||||
|
||||
@@ -217,27 +217,34 @@ const TimelineChartModal: React.FC<TimelineChartModalProps> = ({
|
||||
if (dataIndex === undefined) return '';
|
||||
|
||||
const item = data[dataIndex];
|
||||
const changeColor = item.change_percent >= 0 ? '#ef5350' : '#26a69a';
|
||||
const changeSign = item.change_percent >= 0 ? '+' : '';
|
||||
if (!item) return '';
|
||||
|
||||
// 安全格式化数字
|
||||
const safeFixed = (val: any, digits = 2) =>
|
||||
val != null && !isNaN(val) ? Number(val).toFixed(digits) : '-';
|
||||
|
||||
const changePercent = item.change_percent ?? 0;
|
||||
const changeColor = changePercent >= 0 ? '#ef5350' : '#26a69a';
|
||||
const changeSign = changePercent >= 0 ? '+' : '';
|
||||
|
||||
return `
|
||||
<div style="padding: 8px;">
|
||||
<div style="font-weight: bold; margin-bottom: 8px;">${item.time}</div>
|
||||
<div style="font-weight: bold; margin-bottom: 8px;">${item.time || '-'}</div>
|
||||
<div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
|
||||
<span>价格:</span>
|
||||
<span style="color: ${changeColor}; font-weight: bold; margin-left: 20px;">${item.price.toFixed(2)}</span>
|
||||
<span style="color: ${changeColor}; font-weight: bold; margin-left: 20px;">${safeFixed(item.price)}</span>
|
||||
</div>
|
||||
<div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
|
||||
<span>均价:</span>
|
||||
<span style="color: #ffa726; margin-left: 20px;">${item.avg_price.toFixed(2)}</span>
|
||||
<span style="color: #ffa726; margin-left: 20px;">${safeFixed(item.avg_price)}</span>
|
||||
</div>
|
||||
<div style="display: flex; justify-content: space-between; margin-bottom: 4px;">
|
||||
<span>涨跌幅:</span>
|
||||
<span style="color: ${changeColor}; margin-left: 20px;">${changeSign}${item.change_percent.toFixed(2)}%</span>
|
||||
<span style="color: ${changeColor}; margin-left: 20px;">${changeSign}${safeFixed(changePercent)}%</span>
|
||||
</div>
|
||||
<div style="display: flex; justify-content: space-between;">
|
||||
<span>成交量:</span>
|
||||
<span style="margin-left: 20px;">${(item.volume / 100).toFixed(0)}手</span>
|
||||
<span style="margin-left: 20px;">${item.volume != null ? (item.volume / 100).toFixed(0) : '-'}手</span>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
@@ -314,7 +321,7 @@ const TimelineChartModal: React.FC<TimelineChartModalProps> = ({
|
||||
axisLabel: {
|
||||
color: '#999',
|
||||
fontSize: isMobile ? 10 : 12,
|
||||
formatter: (value: number) => value.toFixed(2),
|
||||
formatter: (value: number) => (value != null && !isNaN(value)) ? value.toFixed(2) : '-',
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -333,6 +340,7 @@ const TimelineChartModal: React.FC<TimelineChartModalProps> = ({
|
||||
color: '#999',
|
||||
fontSize: isMobile ? 10 : 12,
|
||||
formatter: (value: number) => {
|
||||
if (value == null || isNaN(value)) return '-';
|
||||
if (value >= 10000) {
|
||||
return (value / 10000).toFixed(1) + '万';
|
||||
}
|
||||
|
||||
@@ -346,7 +346,173 @@ export const marketHandlers = [
|
||||
});
|
||||
}),
|
||||
|
||||
// 11. 市场统计数据(个股中心页面使用)
|
||||
// 11. 热点概览数据(大盘分时 + 概念异动)
|
||||
http.get('/api/market/hotspot-overview', async ({ request }) => {
|
||||
await delay(300);
|
||||
const url = new URL(request.url);
|
||||
const date = url.searchParams.get('date');
|
||||
|
||||
const tradeDate = date || new Date().toISOString().split('T')[0];
|
||||
|
||||
// 生成分时数据(240个点,9:30-11:30 + 13:00-15:00)
|
||||
const timeline = [];
|
||||
const basePrice = 3900 + Math.random() * 100; // 基准价格 3900-4000
|
||||
const prevClose = basePrice;
|
||||
let currentPrice = basePrice;
|
||||
let cumulativeVolume = 0;
|
||||
|
||||
// 上午时段 9:30-11:30 (120分钟)
|
||||
for (let i = 0; i < 120; i++) {
|
||||
const hour = 9 + Math.floor((i + 30) / 60);
|
||||
const minute = (i + 30) % 60;
|
||||
const time = `${hour.toString().padStart(2, '0')}:${minute.toString().padStart(2, '0')}`;
|
||||
|
||||
// 模拟价格波动
|
||||
const volatility = 0.002; // 0.2%波动
|
||||
const drift = (Math.random() - 0.5) * 0.001; // 微小趋势
|
||||
currentPrice = currentPrice * (1 + (Math.random() - 0.5) * volatility + drift);
|
||||
|
||||
const volume = Math.floor(Math.random() * 500000 + 100000); // 成交量
|
||||
cumulativeVolume += volume;
|
||||
|
||||
timeline.push({
|
||||
time,
|
||||
price: parseFloat(currentPrice.toFixed(2)),
|
||||
volume: cumulativeVolume,
|
||||
change_pct: parseFloat(((currentPrice - prevClose) / prevClose * 100).toFixed(2))
|
||||
});
|
||||
}
|
||||
|
||||
// 下午时段 13:00-15:00 (120分钟)
|
||||
for (let i = 0; i < 120; i++) {
|
||||
const hour = 13 + Math.floor(i / 60);
|
||||
const minute = i % 60;
|
||||
const time = `${hour.toString().padStart(2, '0')}:${minute.toString().padStart(2, '0')}`;
|
||||
|
||||
// 下午波动略小
|
||||
const volatility = 0.0015;
|
||||
const drift = (Math.random() - 0.5) * 0.0008;
|
||||
currentPrice = currentPrice * (1 + (Math.random() - 0.5) * volatility + drift);
|
||||
|
||||
const volume = Math.floor(Math.random() * 400000 + 80000);
|
||||
cumulativeVolume += volume;
|
||||
|
||||
timeline.push({
|
||||
time,
|
||||
price: parseFloat(currentPrice.toFixed(2)),
|
||||
volume: cumulativeVolume,
|
||||
change_pct: parseFloat(((currentPrice - prevClose) / prevClose * 100).toFixed(2))
|
||||
});
|
||||
}
|
||||
|
||||
// 生成概念异动数据
|
||||
const conceptNames = [
|
||||
'人工智能', 'AI眼镜', '机器人', '核电', '国企', '卫星导航',
|
||||
'福建自贸区', '两岸融合', 'CRO', '三季报增长', '百货零售',
|
||||
'人形机器人', '央企', '数据中心', 'CPO', '新能源', '电网设备',
|
||||
'氢能源', '算力租赁', '厦门国资', '乳业', '低空安防', '创新药',
|
||||
'商业航天', '控制权变更', '文化传媒', '海峡两岸'
|
||||
];
|
||||
|
||||
const alertTypes = ['surge_up', 'surge_down', 'volume_spike', 'limit_up', 'rank_jump'];
|
||||
|
||||
// 生成 15-25 个异动
|
||||
const alertCount = Math.floor(Math.random() * 10) + 15;
|
||||
const alerts = [];
|
||||
const usedTimes = new Set();
|
||||
|
||||
for (let i = 0; i < alertCount; i++) {
|
||||
// 随机选择一个时间点
|
||||
let timeIdx;
|
||||
let attempts = 0;
|
||||
do {
|
||||
timeIdx = Math.floor(Math.random() * timeline.length);
|
||||
attempts++;
|
||||
} while (usedTimes.has(timeIdx) && attempts < 50);
|
||||
|
||||
if (attempts >= 50) continue;
|
||||
|
||||
// 同一时间可以有多个异动
|
||||
const time = timeline[timeIdx].time;
|
||||
const conceptName = conceptNames[Math.floor(Math.random() * conceptNames.length)];
|
||||
const alertType = alertTypes[Math.floor(Math.random() * alertTypes.length)];
|
||||
|
||||
// 根据类型生成 alpha
|
||||
let alpha;
|
||||
if (alertType === 'surge_up') {
|
||||
alpha = parseFloat((Math.random() * 3 + 2).toFixed(2)); // +2% ~ +5%
|
||||
} else if (alertType === 'surge_down') {
|
||||
alpha = parseFloat((-Math.random() * 3 - 1.5).toFixed(2)); // -1.5% ~ -4.5%
|
||||
} else {
|
||||
alpha = parseFloat((Math.random() * 4 - 1).toFixed(2)); // -1% ~ +3%
|
||||
}
|
||||
|
||||
const finalScore = Math.floor(Math.random() * 40 + 45); // 45-85分
|
||||
const ruleScore = Math.floor(Math.random() * 30 + 40);
|
||||
const mlScore = Math.floor(Math.random() * 30 + 40);
|
||||
|
||||
alerts.push({
|
||||
concept_id: `CONCEPT_${1000 + i}`,
|
||||
concept_name: conceptName,
|
||||
time,
|
||||
alert_type: alertType,
|
||||
alpha,
|
||||
alpha_delta: parseFloat((Math.random() * 2 - 0.5).toFixed(2)),
|
||||
amt_ratio: parseFloat((Math.random() * 5 + 1).toFixed(2)),
|
||||
limit_up_count: alertType === 'limit_up' ? Math.floor(Math.random() * 5 + 1) : 0,
|
||||
limit_up_ratio: parseFloat((Math.random() * 0.3).toFixed(3)),
|
||||
final_score: finalScore,
|
||||
rule_score: ruleScore,
|
||||
ml_score: mlScore,
|
||||
trigger_reason: finalScore >= 65 ? '规则强信号' : (mlScore >= 70 ? 'ML强信号' : '融合触发'),
|
||||
importance_score: parseFloat((finalScore / 100).toFixed(2)),
|
||||
index_price: timeline[timeIdx].price
|
||||
});
|
||||
}
|
||||
|
||||
// 按时间排序
|
||||
alerts.sort((a, b) => a.time.localeCompare(b.time));
|
||||
|
||||
// 统计异动类型
|
||||
const alertSummary = alerts.reduce((acc, alert) => {
|
||||
acc[alert.alert_type] = (acc[alert.alert_type] || 0) + 1;
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
// 计算指数统计
|
||||
const prices = timeline.map(t => t.price);
|
||||
const latestPrice = prices[prices.length - 1];
|
||||
const highPrice = Math.max(...prices);
|
||||
const lowPrice = Math.min(...prices);
|
||||
const changePct = ((latestPrice - prevClose) / prevClose * 100);
|
||||
|
||||
console.log('[Mock Market] 获取热点概览数据:', {
|
||||
date: tradeDate,
|
||||
timelinePoints: timeline.length,
|
||||
alertCount: alerts.length
|
||||
});
|
||||
|
||||
return HttpResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
index: {
|
||||
code: '000001.SH',
|
||||
name: '上证指数',
|
||||
latest_price: latestPrice,
|
||||
prev_close: prevClose,
|
||||
high: highPrice,
|
||||
low: lowPrice,
|
||||
change_pct: parseFloat(changePct.toFixed(2)),
|
||||
timeline
|
||||
},
|
||||
alerts,
|
||||
alert_summary: alertSummary
|
||||
},
|
||||
trade_date: tradeDate
|
||||
});
|
||||
}),
|
||||
|
||||
// 12. 市场统计数据(个股中心页面使用)
|
||||
http.get('/api/market/statistics', async ({ request }) => {
|
||||
await delay(200);
|
||||
const url = new URL(request.url);
|
||||
|
||||
@@ -207,9 +207,12 @@ const CompactIndexCard = ({ indexCode, indexName }) => {
|
||||
const raw = chartData.rawData[idx];
|
||||
if (!raw) return '';
|
||||
|
||||
// 安全格式化数字
|
||||
const safeFixed = (val, digits = 2) => (val != null && !isNaN(val)) ? val.toFixed(digits) : '-';
|
||||
|
||||
// 计算涨跌
|
||||
const prevClose = raw.prev_close || (idx > 0 ? chartData.rawData[idx - 1]?.close : raw.open) || raw.open;
|
||||
const changeAmount = raw.close - prevClose;
|
||||
const changeAmount = (raw.close != null && prevClose != null) ? (raw.close - prevClose) : 0;
|
||||
const changePct = prevClose ? ((changeAmount / prevClose) * 100) : 0;
|
||||
const isUp = changeAmount >= 0;
|
||||
const color = isUp ? '#ef5350' : '#26a69a';
|
||||
@@ -218,22 +221,22 @@ const CompactIndexCard = ({ indexCode, indexName }) => {
|
||||
return `
|
||||
<div style="min-width: 180px;">
|
||||
<div style="font-weight: bold; color: #FFD700; margin-bottom: 10px; font-size: 13px; border-bottom: 1px solid rgba(255,215,0,0.2); padding-bottom: 8px;">
|
||||
📅 ${raw.time}
|
||||
📅 ${raw.time || '-'}
|
||||
</div>
|
||||
<div style="display: grid; grid-template-columns: auto 1fr; gap: 6px 16px; font-size: 12px;">
|
||||
<span style="color: #999;">开盘</span>
|
||||
<span style="text-align: right; font-family: monospace;">${raw.open.toFixed(2)}</span>
|
||||
<span style="text-align: right; font-family: monospace;">${safeFixed(raw.open)}</span>
|
||||
<span style="color: #999;">收盘</span>
|
||||
<span style="text-align: right; font-weight: bold; color: ${color}; font-family: monospace;">${raw.close.toFixed(2)}</span>
|
||||
<span style="text-align: right; font-weight: bold; color: ${color}; font-family: monospace;">${safeFixed(raw.close)}</span>
|
||||
<span style="color: #999;">最高</span>
|
||||
<span style="text-align: right; color: #ef5350; font-family: monospace;">${raw.high.toFixed(2)}</span>
|
||||
<span style="text-align: right; color: #ef5350; font-family: monospace;">${safeFixed(raw.high)}</span>
|
||||
<span style="color: #999;">最低</span>
|
||||
<span style="text-align: right; color: #26a69a; font-family: monospace;">${raw.low.toFixed(2)}</span>
|
||||
<span style="text-align: right; color: #26a69a; font-family: monospace;">${safeFixed(raw.low)}</span>
|
||||
</div>
|
||||
<div style="margin-top: 10px; padding-top: 8px; border-top: 1px solid rgba(255,255,255,0.1); display: flex; justify-content: space-between; align-items: center;">
|
||||
<span style="color: #999; font-size: 11px;">涨跌幅</span>
|
||||
<span style="color: ${color}; font-weight: bold; font-size: 14px; font-family: monospace;">
|
||||
${sign}${changeAmount.toFixed(2)} (${sign}${changePct.toFixed(2)}%)
|
||||
${sign}${safeFixed(changeAmount)} (${sign}${safeFixed(changePct)}%)
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
@@ -529,7 +532,7 @@ const FlowingConcepts = () => {
|
||||
color={colors.text}
|
||||
whiteSpace="nowrap"
|
||||
>
|
||||
{concept.change_pct > 0 ? '+' : ''}{concept.change_pct.toFixed(2)}%
|
||||
{concept.change_pct > 0 ? '+' : ''}{concept.change_pct?.toFixed(2) ?? '-'}%
|
||||
</Text>
|
||||
</HStack>
|
||||
</Box>
|
||||
|
||||
@@ -0,0 +1,280 @@
|
||||
/**
|
||||
* 迷你分时图组件
|
||||
* 用于灵活屏中显示证券的日内走势
|
||||
*/
|
||||
import React, { useEffect, useRef, useState, useMemo } from 'react';
|
||||
import { Box, Spinner, Center, Text } from '@chakra-ui/react';
|
||||
import * as echarts from 'echarts';
|
||||
import type { ECharts, EChartsOption } from 'echarts';
|
||||
|
||||
import type { MiniTimelineChartProps, TimelineDataPoint } from '../types';
|
||||
|
||||
/**
|
||||
* 生成交易时间刻度(用于 X 轴)
|
||||
* A股交易时间:9:30-11:30, 13:00-15:00
|
||||
*/
|
||||
const generateTimeTicks = (): string[] => {
|
||||
const ticks: string[] = [];
|
||||
// 上午
|
||||
for (let h = 9; h <= 11; h++) {
|
||||
for (let m = h === 9 ? 30 : 0; m < 60; m++) {
|
||||
if (h === 11 && m > 30) break;
|
||||
ticks.push(`${String(h).padStart(2, '0')}:${String(m).padStart(2, '0')}`);
|
||||
}
|
||||
}
|
||||
// 下午
|
||||
for (let h = 13; h <= 15; h++) {
|
||||
for (let m = 0; m < 60; m++) {
|
||||
if (h === 15 && m > 0) break;
|
||||
ticks.push(`${String(h).padStart(2, '0')}:${String(m).padStart(2, '0')}`);
|
||||
}
|
||||
}
|
||||
return ticks;
|
||||
};
|
||||
|
||||
const TIME_TICKS = generateTimeTicks();
|
||||
|
||||
/** API 返回的分钟数据结构 */
|
||||
interface MinuteKLineItem {
|
||||
time?: string;
|
||||
timestamp?: string;
|
||||
close?: number;
|
||||
price?: number;
|
||||
}
|
||||
|
||||
/** API 响应结构 */
|
||||
interface KLineApiResponse {
|
||||
success?: boolean;
|
||||
data?: MinuteKLineItem[];
|
||||
error?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* MiniTimelineChart 组件
|
||||
*/
|
||||
const MiniTimelineChart: React.FC<MiniTimelineChartProps> = ({
|
||||
code,
|
||||
isIndex = false,
|
||||
prevClose,
|
||||
currentPrice,
|
||||
height = 120,
|
||||
}) => {
|
||||
const chartRef = useRef<HTMLDivElement>(null);
|
||||
const chartInstance = useRef<ECharts | null>(null);
|
||||
const [timelineData, setTimelineData] = useState<TimelineDataPoint[]>([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
// 获取分钟数据
|
||||
useEffect(() => {
|
||||
if (!code) return;
|
||||
|
||||
const fetchData = async (): Promise<void> => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const apiPath = isIndex
|
||||
? `/api/index/${code}/kline?type=minute`
|
||||
: `/api/stock/${code}/kline?type=minute`;
|
||||
|
||||
const response = await fetch(apiPath);
|
||||
const result: KLineApiResponse = await response.json();
|
||||
|
||||
if (result.success !== false && result.data) {
|
||||
// 格式化数据
|
||||
const formatted: TimelineDataPoint[] = result.data.map(item => ({
|
||||
time: item.time || item.timestamp || '',
|
||||
price: item.close || item.price || 0,
|
||||
}));
|
||||
setTimelineData(formatted);
|
||||
} else {
|
||||
setError(result.error || '暂无数据');
|
||||
}
|
||||
} catch (e) {
|
||||
setError('加载失败');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
fetchData();
|
||||
|
||||
// 交易时间内每分钟刷新
|
||||
const now = new Date();
|
||||
const hours = now.getHours();
|
||||
const minutes = now.getMinutes();
|
||||
const currentMinutes = hours * 60 + minutes;
|
||||
const isTrading =
|
||||
(currentMinutes >= 570 && currentMinutes <= 690) ||
|
||||
(currentMinutes >= 780 && currentMinutes <= 900);
|
||||
|
||||
let intervalId: NodeJS.Timeout | undefined;
|
||||
if (isTrading) {
|
||||
intervalId = setInterval(fetchData, 60000); // 1分钟刷新
|
||||
}
|
||||
|
||||
return () => {
|
||||
if (intervalId) clearInterval(intervalId);
|
||||
};
|
||||
}, [code, isIndex]);
|
||||
|
||||
// 合并实时价格到数据中
|
||||
const chartData = useMemo((): TimelineDataPoint[] => {
|
||||
if (!timelineData.length) return [];
|
||||
|
||||
const data = [...timelineData];
|
||||
|
||||
// 如果有实时价格,添加到最新点
|
||||
if (currentPrice && data.length > 0) {
|
||||
const now = new Date();
|
||||
const timeStr = `${String(now.getHours()).padStart(2, '0')}:${String(now.getMinutes()).padStart(2, '0')}`;
|
||||
const lastItem = data[data.length - 1];
|
||||
|
||||
// 如果实时价格的时间比最后一条数据新,添加新点
|
||||
if (lastItem.time !== timeStr) {
|
||||
data.push({ time: timeStr, price: currentPrice });
|
||||
} else {
|
||||
// 更新最后一条
|
||||
data[data.length - 1] = { ...lastItem, price: currentPrice };
|
||||
}
|
||||
}
|
||||
|
||||
return data;
|
||||
}, [timelineData, currentPrice]);
|
||||
|
||||
// 渲染图表
|
||||
useEffect(() => {
|
||||
if (!chartRef.current || loading || !chartData.length) return;
|
||||
|
||||
if (!chartInstance.current) {
|
||||
chartInstance.current = echarts.init(chartRef.current);
|
||||
}
|
||||
|
||||
const baseLine = prevClose || chartData[0]?.price || 0;
|
||||
|
||||
// 计算价格范围
|
||||
const prices = chartData.map(d => d.price).filter(p => p > 0);
|
||||
const minPrice = Math.min(...prices, baseLine);
|
||||
const maxPrice = Math.max(...prices, baseLine);
|
||||
const range = Math.max(maxPrice - baseLine, baseLine - minPrice) * 1.1;
|
||||
|
||||
// 准备数据
|
||||
const times = chartData.map(d => d.time);
|
||||
const values = chartData.map(d => d.price);
|
||||
|
||||
// 判断涨跌
|
||||
const lastPrice = values[values.length - 1] || baseLine;
|
||||
const isUp = lastPrice >= baseLine;
|
||||
|
||||
const option: EChartsOption = {
|
||||
grid: {
|
||||
top: 5,
|
||||
right: 5,
|
||||
bottom: 5,
|
||||
left: 5,
|
||||
containLabel: false,
|
||||
},
|
||||
xAxis: {
|
||||
type: 'category',
|
||||
data: times,
|
||||
show: false,
|
||||
boundaryGap: false,
|
||||
},
|
||||
yAxis: {
|
||||
type: 'value',
|
||||
min: baseLine - range,
|
||||
max: baseLine + range,
|
||||
show: false,
|
||||
},
|
||||
series: [
|
||||
{
|
||||
type: 'line',
|
||||
data: values,
|
||||
smooth: false,
|
||||
symbol: 'none',
|
||||
lineStyle: {
|
||||
width: 1.5,
|
||||
color: isUp ? '#ef4444' : '#22c55e',
|
||||
},
|
||||
areaStyle: {
|
||||
color: {
|
||||
type: 'linear',
|
||||
x: 0,
|
||||
y: 0,
|
||||
x2: 0,
|
||||
y2: 1,
|
||||
colorStops: [
|
||||
{ offset: 0, color: isUp ? 'rgba(239, 68, 68, 0.3)' : 'rgba(34, 197, 94, 0.3)' },
|
||||
{ offset: 1, color: isUp ? 'rgba(239, 68, 68, 0.05)' : 'rgba(34, 197, 94, 0.05)' },
|
||||
],
|
||||
},
|
||||
},
|
||||
markLine: {
|
||||
silent: true,
|
||||
symbol: 'none',
|
||||
data: [
|
||||
{
|
||||
yAxis: baseLine,
|
||||
lineStyle: {
|
||||
color: '#666',
|
||||
type: 'dashed',
|
||||
width: 1,
|
||||
},
|
||||
label: { show: false },
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
animation: false,
|
||||
};
|
||||
|
||||
chartInstance.current.setOption(option);
|
||||
|
||||
return () => {
|
||||
// 不在这里销毁,只在组件卸载时销毁
|
||||
};
|
||||
}, [chartData, prevClose, loading]);
|
||||
|
||||
// 组件卸载时销毁图表
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (chartInstance.current) {
|
||||
chartInstance.current.dispose();
|
||||
chartInstance.current = null;
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
// 窗口 resize 处理
|
||||
useEffect(() => {
|
||||
const handleResize = (): void => {
|
||||
chartInstance.current?.resize();
|
||||
};
|
||||
window.addEventListener('resize', handleResize);
|
||||
return () => window.removeEventListener('resize', handleResize);
|
||||
}, []);
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<Center h={height}>
|
||||
<Spinner size="sm" color="gray.400" />
|
||||
</Center>
|
||||
);
|
||||
}
|
||||
|
||||
if (error || !chartData.length) {
|
||||
return (
|
||||
<Center h={height}>
|
||||
<Text fontSize="xs" color="gray.400">
|
||||
{error || '暂无数据'}
|
||||
</Text>
|
||||
</Center>
|
||||
);
|
||||
}
|
||||
|
||||
return <Box ref={chartRef} h={`${height}px`} w="100%" />;
|
||||
};
|
||||
|
||||
export default MiniTimelineChart;
|
||||
@@ -0,0 +1,291 @@
|
||||
/**
|
||||
* 盘口行情面板组件
|
||||
* 支持显示 5 档或 10 档买卖盘数据
|
||||
*
|
||||
* 上交所: 5 档行情
|
||||
* 深交所: 10 档行情
|
||||
*/
|
||||
import React, { useState } from 'react';
|
||||
import {
|
||||
Box,
|
||||
VStack,
|
||||
HStack,
|
||||
Text,
|
||||
Button,
|
||||
ButtonGroup,
|
||||
useColorModeValue,
|
||||
Tooltip,
|
||||
Badge,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
import type { OrderBookPanelProps } from '../types';
|
||||
|
||||
/** 格式化价格返回值 */
|
||||
interface FormattedPrice {
|
||||
text: string;
|
||||
color: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化成交量
|
||||
*/
|
||||
const formatVolume = (volume: number): string => {
|
||||
if (!volume || volume === 0) return '-';
|
||||
if (volume >= 10000) {
|
||||
return `${(volume / 10000).toFixed(0)}万`;
|
||||
}
|
||||
if (volume >= 1000) {
|
||||
return `${(volume / 1000).toFixed(1)}k`;
|
||||
}
|
||||
return String(volume);
|
||||
};
|
||||
|
||||
/**
|
||||
* 格式化价格
|
||||
*/
|
||||
const formatPrice = (price: number, prevClose?: number): FormattedPrice => {
|
||||
if (!price || price === 0) {
|
||||
return { text: '-', color: 'gray.400' };
|
||||
}
|
||||
|
||||
const text = price.toFixed(2);
|
||||
|
||||
if (!prevClose || prevClose === 0) {
|
||||
return { text, color: 'gray.600' };
|
||||
}
|
||||
|
||||
if (price > prevClose) {
|
||||
return { text, color: 'red.500' };
|
||||
}
|
||||
if (price < prevClose) {
|
||||
return { text, color: 'green.500' };
|
||||
}
|
||||
return { text, color: 'gray.600' };
|
||||
};
|
||||
|
||||
/** OrderRow 组件 Props */
|
||||
interface OrderRowProps {
|
||||
label: string;
|
||||
price: number;
|
||||
volume: number;
|
||||
prevClose?: number;
|
||||
isBid: boolean;
|
||||
maxVolume: number;
|
||||
isLimitPrice: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* 单行盘口
|
||||
*/
|
||||
const OrderRow: React.FC<OrderRowProps> = ({
|
||||
label,
|
||||
price,
|
||||
volume,
|
||||
prevClose,
|
||||
isBid,
|
||||
maxVolume,
|
||||
isLimitPrice,
|
||||
}) => {
|
||||
const bgColor = useColorModeValue(
|
||||
isBid ? 'red.50' : 'green.50',
|
||||
isBid ? 'rgba(239, 68, 68, 0.1)' : 'rgba(34, 197, 94, 0.1)'
|
||||
);
|
||||
const barColor = useColorModeValue(
|
||||
isBid ? 'red.200' : 'green.200',
|
||||
isBid ? 'rgba(239, 68, 68, 0.3)' : 'rgba(34, 197, 94, 0.3)'
|
||||
);
|
||||
const limitColor = useColorModeValue('orange.500', 'orange.300');
|
||||
|
||||
const priceInfo = formatPrice(price, prevClose);
|
||||
const volumeText = formatVolume(volume);
|
||||
|
||||
// 计算成交量条宽度
|
||||
const barWidth = maxVolume > 0 ? Math.min((volume / maxVolume) * 100, 100) : 0;
|
||||
|
||||
return (
|
||||
<HStack
|
||||
spacing={2}
|
||||
py={0.5}
|
||||
px={1}
|
||||
position="relative"
|
||||
overflow="hidden"
|
||||
fontSize="xs"
|
||||
>
|
||||
{/* 成交量条 */}
|
||||
<Box
|
||||
position="absolute"
|
||||
right={0}
|
||||
top={0}
|
||||
bottom={0}
|
||||
width={`${barWidth}%`}
|
||||
bg={barColor}
|
||||
transition="width 0.2s"
|
||||
/>
|
||||
|
||||
{/* 内容 */}
|
||||
<Text color="gray.500" w="24px" flexShrink={0} zIndex={1}>
|
||||
{label}
|
||||
</Text>
|
||||
<HStack flex={1} justify="flex-end" zIndex={1}>
|
||||
<Text color={isLimitPrice ? limitColor : priceInfo.color} fontWeight="medium">
|
||||
{priceInfo.text}
|
||||
</Text>
|
||||
{isLimitPrice && (
|
||||
<Tooltip label={isBid ? '跌停价' : '涨停价'}>
|
||||
<Badge
|
||||
colorScheme={isBid ? 'green' : 'red'}
|
||||
fontSize="2xs"
|
||||
variant="subtle"
|
||||
>
|
||||
{isBid ? '跌' : '涨'}
|
||||
</Badge>
|
||||
</Tooltip>
|
||||
)}
|
||||
</HStack>
|
||||
<Text color="gray.600" w="40px" textAlign="right" zIndex={1}>
|
||||
{volumeText}
|
||||
</Text>
|
||||
</HStack>
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* OrderBookPanel 组件
|
||||
*/
|
||||
const OrderBookPanel: React.FC<OrderBookPanelProps> = ({
|
||||
bidPrices = [],
|
||||
bidVolumes = [],
|
||||
askPrices = [],
|
||||
askVolumes = [],
|
||||
prevClose,
|
||||
upperLimit,
|
||||
lowerLimit,
|
||||
defaultLevels = 5,
|
||||
}) => {
|
||||
const borderColor = useColorModeValue('gray.200', 'gray.700');
|
||||
const buttonBg = useColorModeValue('gray.100', 'gray.700');
|
||||
const bgColor = useColorModeValue('white', '#1a1a1a');
|
||||
|
||||
// 可切换显示的档位数
|
||||
const maxAvailableLevels = Math.max(bidPrices.length, askPrices.length, 1);
|
||||
const [showLevels, setShowLevels] = useState(Math.min(defaultLevels, maxAvailableLevels));
|
||||
|
||||
// 计算最大成交量(用于条形图比例)
|
||||
const displayBidVolumes = bidVolumes.slice(0, showLevels);
|
||||
const displayAskVolumes = askVolumes.slice(0, showLevels);
|
||||
const allVolumes = [...displayBidVolumes, ...displayAskVolumes].filter(v => v > 0);
|
||||
const maxVolume = allVolumes.length > 0 ? Math.max(...allVolumes) : 0;
|
||||
|
||||
// 判断是否为涨跌停价
|
||||
const isUpperLimit = (price: number): boolean =>
|
||||
!!upperLimit && Math.abs(price - upperLimit) < 0.001;
|
||||
const isLowerLimit = (price: number): boolean =>
|
||||
!!lowerLimit && Math.abs(price - lowerLimit) < 0.001;
|
||||
|
||||
// 卖盘(从卖N到卖1,即价格从高到低)
|
||||
const askRows: React.ReactNode[] = [];
|
||||
for (let i = showLevels - 1; i >= 0; i--) {
|
||||
askRows.push(
|
||||
<OrderRow
|
||||
key={`ask${i + 1}`}
|
||||
label={`卖${i + 1}`}
|
||||
price={askPrices[i]}
|
||||
volume={askVolumes[i]}
|
||||
prevClose={prevClose}
|
||||
isBid={false}
|
||||
maxVolume={maxVolume}
|
||||
isLimitPrice={isUpperLimit(askPrices[i])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// 买盘(从买1到买N,即价格从高到低)
|
||||
const bidRows: React.ReactNode[] = [];
|
||||
for (let i = 0; i < showLevels; i++) {
|
||||
bidRows.push(
|
||||
<OrderRow
|
||||
key={`bid${i + 1}`}
|
||||
label={`买${i + 1}`}
|
||||
price={bidPrices[i]}
|
||||
volume={bidVolumes[i]}
|
||||
prevClose={prevClose}
|
||||
isBid={true}
|
||||
maxVolume={maxVolume}
|
||||
isLimitPrice={isLowerLimit(bidPrices[i])}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// 没有数据时的提示
|
||||
const hasData = bidPrices.length > 0 || askPrices.length > 0;
|
||||
|
||||
if (!hasData) {
|
||||
return (
|
||||
<Box textAlign="center" py={2}>
|
||||
<Text fontSize="xs" color="gray.400">
|
||||
暂无盘口数据
|
||||
</Text>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<VStack spacing={0} align="stretch">
|
||||
{/* 档位切换(只有当有超过5档数据时才显示) */}
|
||||
{maxAvailableLevels > 5 && (
|
||||
<HStack justify="flex-end" mb={1}>
|
||||
<ButtonGroup size="xs" isAttached variant="outline">
|
||||
<Button
|
||||
onClick={() => setShowLevels(5)}
|
||||
bg={showLevels === 5 ? buttonBg : 'transparent'}
|
||||
fontWeight={showLevels === 5 ? 'bold' : 'normal'}
|
||||
>
|
||||
5档
|
||||
</Button>
|
||||
<Button
|
||||
onClick={() => setShowLevels(10)}
|
||||
bg={showLevels === 10 ? buttonBg : 'transparent'}
|
||||
fontWeight={showLevels === 10 ? 'bold' : 'normal'}
|
||||
>
|
||||
10档
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</HStack>
|
||||
)}
|
||||
|
||||
{/* 卖盘 */}
|
||||
{askRows}
|
||||
|
||||
{/* 分隔线 + 当前价信息 */}
|
||||
<Box h="1px" bg={borderColor} my={1} position="relative">
|
||||
{prevClose && (
|
||||
<Text
|
||||
position="absolute"
|
||||
right={0}
|
||||
top="50%"
|
||||
transform="translateY(-50%)"
|
||||
fontSize="2xs"
|
||||
color="gray.400"
|
||||
bg={bgColor}
|
||||
px={1}
|
||||
>
|
||||
昨收 {prevClose.toFixed(2)}
|
||||
</Text>
|
||||
)}
|
||||
</Box>
|
||||
|
||||
{/* 买盘 */}
|
||||
{bidRows}
|
||||
|
||||
{/* 涨跌停价信息 */}
|
||||
{(upperLimit || lowerLimit) && (
|
||||
<HStack justify="space-between" mt={1} fontSize="2xs" color="gray.400">
|
||||
{lowerLimit && <Text>跌停 {lowerLimit.toFixed(2)}</Text>}
|
||||
{upperLimit && <Text>涨停 {upperLimit.toFixed(2)}</Text>}
|
||||
</HStack>
|
||||
)}
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default OrderBookPanel;
|
||||
@@ -0,0 +1,288 @@
|
||||
/**
|
||||
* 行情瓷砖组件
|
||||
* 单个证券的实时行情展示卡片,包含分时图和五档盘口
|
||||
*/
|
||||
import React, { useState } from 'react';
|
||||
import {
|
||||
Box,
|
||||
VStack,
|
||||
HStack,
|
||||
Text,
|
||||
IconButton,
|
||||
Tooltip,
|
||||
useColorModeValue,
|
||||
Collapse,
|
||||
Badge,
|
||||
} from '@chakra-ui/react';
|
||||
import { CloseIcon, ChevronDownIcon, ChevronUpIcon } from '@chakra-ui/icons';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
|
||||
import MiniTimelineChart from './MiniTimelineChart';
|
||||
import OrderBookPanel from './OrderBookPanel';
|
||||
import type { QuoteTileProps, QuoteData } from '../types';
|
||||
|
||||
/**
|
||||
* 格式化价格显示
|
||||
*/
|
||||
const formatPrice = (price?: number): string => {
|
||||
if (!price || isNaN(price)) return '-';
|
||||
return price.toFixed(2);
|
||||
};
|
||||
|
||||
/**
|
||||
* 格式化涨跌幅
|
||||
*/
|
||||
const formatChangePct = (pct?: number): string => {
|
||||
if (!pct || isNaN(pct)) return '0.00%';
|
||||
const sign = pct > 0 ? '+' : '';
|
||||
return `${sign}${pct.toFixed(2)}%`;
|
||||
};
|
||||
|
||||
/**
|
||||
* 格式化涨跌额
|
||||
*/
|
||||
const formatChange = (change?: number): string => {
|
||||
if (!change || isNaN(change)) return '-';
|
||||
const sign = change > 0 ? '+' : '';
|
||||
return `${sign}${change.toFixed(2)}`;
|
||||
};
|
||||
|
||||
/**
|
||||
* 格式化成交额
|
||||
*/
|
||||
const formatAmount = (amount?: number): string => {
|
||||
if (!amount || isNaN(amount)) return '-';
|
||||
if (amount >= 100000000) {
|
||||
return `${(amount / 100000000).toFixed(2)}亿`;
|
||||
}
|
||||
if (amount >= 10000) {
|
||||
return `${(amount / 10000).toFixed(0)}万`;
|
||||
}
|
||||
return amount.toFixed(0);
|
||||
};
|
||||
|
||||
/**
 * QuoteTile component
 *
 * Realtime quote card for a single security: a header with name/code and
 * price/change, plus a collapsible body containing summary stats, a mini
 * timeline chart and (for non-index securities) the order book.
 */
const QuoteTile: React.FC<QuoteTileProps> = ({
  code,
  name,
  quote = {},
  isIndex = false,
  onRemove,
}) => {
  const navigate = useNavigate();
  // Whether the collapsible body is open; starts expanded.
  const [expanded, setExpanded] = useState(true);

  // Narrow the loosely-typed prop for type-safe field access.
  const quoteData = quote as Partial<QuoteData>;

  // Theme colours.
  const cardBg = useColorModeValue('white', '#1a1a1a');
  const borderColor = useColorModeValue('gray.200', '#333');
  const hoverBorderColor = useColorModeValue('purple.300', '#666');
  const textColor = useColorModeValue('gray.800', 'white');
  const subTextColor = useColorModeValue('gray.500', 'gray.400');

  // Up/down price colour (red = up, green = down; A-share convention).
  const { price, prevClose, change, changePct, amount } = quoteData;
  const priceColor = useColorModeValue(
    !prevClose || price === prevClose
      ? 'gray.800'
      : price && price > prevClose
      ? 'red.500'
      : 'green.500',
    !prevClose || price === prevClose
      ? 'gray.200'
      : price && price > prevClose
      ? 'red.400'
      : 'green.400'
  );

  // Background colour for the change-percent badge.
  const changeBgColor = useColorModeValue(
    !changePct || changePct === 0
      ? 'gray.100'
      : changePct > 0
      ? 'red.100'
      : 'green.100',
    !changePct || changePct === 0
      ? 'gray.700'
      : changePct > 0
      ? 'rgba(239, 68, 68, 0.2)'
      : 'rgba(34, 197, 94, 0.2)'
  );

  // Navigate to the company detail page (indexes have no detail page yet).
  const handleNavigate = (): void => {
    if (isIndex) {
      // No detail page for indexes yet.
      return;
    }
    navigate(`/company?scode=${code}`);
  };

  // Extract order-book fields with type-safe guards (absent for indexes).
  const bidPrices = 'bidPrices' in quoteData ? (quoteData.bidPrices as number[]) : [];
  const bidVolumes = 'bidVolumes' in quoteData ? (quoteData.bidVolumes as number[]) : [];
  const askPrices = 'askPrices' in quoteData ? (quoteData.askPrices as number[]) : [];
  const askVolumes = 'askVolumes' in quoteData ? (quoteData.askVolumes as number[]) : [];
  const upperLimit = 'upperLimit' in quoteData ? (quoteData.upperLimit as number | undefined) : undefined;
  const lowerLimit = 'lowerLimit' in quoteData ? (quoteData.lowerLimit as number | undefined) : undefined;
  const openPrice = 'open' in quoteData ? (quoteData.open as number | undefined) : undefined;

  return (
    <Box
      bg={cardBg}
      borderWidth="1px"
      borderColor={borderColor}
      borderRadius="lg"
      overflow="hidden"
      transition="all 0.2s"
      _hover={{
        borderColor: hoverBorderColor,
        boxShadow: 'md',
      }}
    >
      {/* Header (click toggles the collapsible body) */}
      <HStack
        px={3}
        py={2}
        borderBottomWidth={expanded ? '1px' : '0'}
        borderColor={borderColor}
        cursor="pointer"
        onClick={() => setExpanded(!expanded)}
      >
        {/* Name and code */}
        <VStack align="start" spacing={0} flex={1} minW={0}>
          <HStack spacing={2}>
            <Text
              fontWeight="bold"
              fontSize="sm"
              color={textColor}
              noOfLines={1}
              cursor="pointer"
              _hover={{ textDecoration: 'underline' }}
              onClick={(e) => {
                e.stopPropagation();
                handleNavigate();
              }}
            >
              {name || code}
            </Text>
            {isIndex && (
              <Badge colorScheme="purple" fontSize="xs">
                指数
              </Badge>
            )}
          </HStack>
          <Text fontSize="xs" color={subTextColor}>
            {code}
          </Text>
        </VStack>

        {/* Price info */}
        <VStack align="end" spacing={0}>
          <Text fontWeight="bold" fontSize="lg" color={priceColor}>
            {formatPrice(price)}
          </Text>
          <HStack spacing={1}>
            <Box
              px={1.5}
              py={0.5}
              bg={changeBgColor}
              borderRadius="sm"
              fontSize="xs"
              fontWeight="medium"
              color={priceColor}
            >
              {formatChangePct(changePct)}
            </Box>
            <Text fontSize="xs" color={priceColor}>
              {formatChange(change)}
            </Text>
          </HStack>
        </VStack>

        {/* Action buttons (expand/collapse and remove) */}
        <HStack spacing={1} ml={2}>
          <IconButton
            icon={expanded ? <ChevronUpIcon /> : <ChevronDownIcon />}
            size="xs"
            variant="ghost"
            aria-label={expanded ? '收起' : '展开'}
            onClick={(e) => {
              e.stopPropagation();
              setExpanded(!expanded);
            }}
          />
          <Tooltip label="移除">
            <IconButton
              icon={<CloseIcon />}
              size="xs"
              variant="ghost"
              colorScheme="red"
              aria-label="移除"
              onClick={(e) => {
                e.stopPropagation();
                onRemove?.(code);
              }}
            />
          </Tooltip>
        </HStack>
      </HStack>

      {/* Collapsible body */}
      <Collapse in={expanded} animateOpacity>
        <Box p={3}>
          {/* Summary stats */}
          <HStack spacing={4} mb={3} fontSize="xs" color={subTextColor}>
            <HStack>
              <Text>昨收:</Text>
              <Text color={textColor}>{formatPrice(prevClose)}</Text>
            </HStack>
            <HStack>
              <Text>今开:</Text>
              <Text color={textColor}>{formatPrice(openPrice)}</Text>
            </HStack>
            <HStack>
              <Text>成交额:</Text>
              <Text color={textColor}>{formatAmount(amount)}</Text>
            </HStack>
          </HStack>

          {/* Intraday mini chart */}
          <Box mb={3}>
            <MiniTimelineChart
              code={code}
              isIndex={isIndex}
              prevClose={prevClose}
              currentPrice={price}
              height={100}
            />
          </Box>

          {/* Order book (indexes have none) */}
          {!isIndex && (
            <Box>
              <Text fontSize="xs" color={subTextColor} mb={1}>
                盘口 {bidPrices.length > 5 ? '(10档)' : '(5档)'}
              </Text>
              <OrderBookPanel
                bidPrices={bidPrices}
                bidVolumes={bidVolumes}
                askPrices={askPrices}
                askVolumes={askVolumes}
                prevClose={prevClose}
                upperLimit={upperLimit}
                lowerLimit={lowerLimit}
              />
            </Box>
          )}
        </Box>
      </Collapse>
    </Box>
  );
};
|
||||
|
||||
export default QuoteTile;
|
||||
@@ -0,0 +1,7 @@
|
||||
/**
|
||||
* 组件导出文件
|
||||
*/
|
||||
|
||||
export { default as MiniTimelineChart } from './MiniTimelineChart';
|
||||
export { default as OrderBookPanel } from './OrderBookPanel';
|
||||
export { default as QuoteTile } from './QuoteTile';
|
||||
@@ -0,0 +1,46 @@
|
||||
/**
|
||||
* WebSocket 配置常量
|
||||
*/
|
||||
|
||||
import type { Exchange } from '../types';
|
||||
|
||||
/**
|
||||
* 获取 WebSocket 配置
|
||||
* - 生产环境 (HTTPS): 通过 Nginx 代理使用 wss://
|
||||
* - 开发环境 (HTTP): 直连 ws://
|
||||
*/
|
||||
const getWsConfig = (): Record<Exchange, string> => {
|
||||
// 服务端渲染或测试环境使用默认配置
|
||||
if (typeof window === 'undefined') {
|
||||
return {
|
||||
SSE: 'ws://49.232.185.254:8765',
|
||||
SZSE: 'ws://222.128.1.157:8765',
|
||||
};
|
||||
}
|
||||
|
||||
const isHttps = window.location.protocol === 'https:';
|
||||
const host = window.location.host;
|
||||
|
||||
if (isHttps) {
|
||||
// 生产环境:通过 Nginx 代理
|
||||
return {
|
||||
SSE: `wss://${host}/ws/sse`, // 上交所 - Nginx 代理
|
||||
SZSE: `wss://${host}/ws/szse`, // 深交所 - Nginx 代理
|
||||
};
|
||||
}
|
||||
|
||||
// 开发环境:直连
|
||||
return {
|
||||
SSE: 'ws://49.232.185.254:8765', // 上交所
|
||||
SZSE: 'ws://222.128.1.157:8765', // 深交所
|
||||
};
|
||||
};
|
||||
|
||||
/** WebSocket 服务地址 */
|
||||
export const WS_CONFIG: Record<Exchange, string> = getWsConfig();
|
||||
|
||||
/** 心跳间隔 (ms) */
|
||||
export const HEARTBEAT_INTERVAL = 30000;
|
||||
|
||||
/** 重连间隔 (ms) */
|
||||
export const RECONNECT_INTERVAL = 3000;
|
||||
@@ -0,0 +1,7 @@
|
||||
/**
|
||||
* Hooks 导出文件
|
||||
*/
|
||||
|
||||
export { useRealtimeQuote } from './useRealtimeQuote';
|
||||
export * from './constants';
|
||||
export * from './utils';
|
||||
@@ -0,0 +1,722 @@
|
||||
/**
|
||||
* 实时行情 Hook
|
||||
* 管理上交所和深交所 WebSocket 连接,获取实时行情数据
|
||||
*
|
||||
* 连接方式:
|
||||
* - 生产环境 (HTTPS): 通过 Nginx 代理使用 wss:// (如 wss://valuefrontier.cn/ws/sse)
|
||||
* - 开发环境 (HTTP): 直连 ws://
|
||||
*
|
||||
* 上交所 (SSE): 需主动订阅,提供五档行情
|
||||
* 深交所 (SZSE): 自动推送,提供十档行情
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useRef, useCallback } from 'react';
|
||||
import { logger } from '@utils/logger';
|
||||
import { WS_CONFIG, HEARTBEAT_INTERVAL, RECONNECT_INTERVAL } from './constants';
|
||||
import { getExchange, normalizeCode, extractOrderBook, calcChangePct } from './utils';
|
||||
import type {
|
||||
Exchange,
|
||||
ConnectionStatus,
|
||||
QuotesMap,
|
||||
QuoteData,
|
||||
SSEMessage,
|
||||
SSEQuoteItem,
|
||||
SZSEMessage,
|
||||
SZSERealtimeMessage,
|
||||
SZSESnapshotMessage,
|
||||
SZSEStockData,
|
||||
SZSEIndexData,
|
||||
SZSEBondData,
|
||||
SZSEHKStockData,
|
||||
SZSEAfterhoursData,
|
||||
UseRealtimeQuoteReturn,
|
||||
} from '../types';
|
||||
|
||||
/**
|
||||
* 处理上交所消息
|
||||
* 注意:上交所返回的 code 不带后缀,但通过 msg.type 区分 'stock' 和 'index'
|
||||
* 存储时使用带后缀的完整代码作为 key(如 000001.SH)
|
||||
*/
|
||||
const handleSSEMessage = (
|
||||
msg: SSEMessage,
|
||||
subscribedCodes: Set<string>,
|
||||
prevQuotes: QuotesMap
|
||||
): QuotesMap | null => {
|
||||
if (msg.type !== 'stock' && msg.type !== 'index') {
|
||||
return null;
|
||||
}
|
||||
|
||||
const data = msg.data || {};
|
||||
const updated: QuotesMap = { ...prevQuotes };
|
||||
let hasUpdate = false;
|
||||
const isIndex = msg.type === 'index';
|
||||
|
||||
Object.entries(data).forEach(([code, quote]: [string, SSEQuoteItem]) => {
|
||||
// 生成带后缀的完整代码(上交所统一用 .SH)
|
||||
const fullCode = code.includes('.') ? code : `${code}.SH`;
|
||||
|
||||
if (subscribedCodes.has(code) || subscribedCodes.has(fullCode)) {
|
||||
hasUpdate = true;
|
||||
updated[fullCode] = {
|
||||
code: fullCode,
|
||||
name: quote.security_name,
|
||||
price: quote.last_price,
|
||||
prevClose: quote.prev_close,
|
||||
open: quote.open_price,
|
||||
high: quote.high_price,
|
||||
low: quote.low_price,
|
||||
volume: quote.volume,
|
||||
amount: quote.amount,
|
||||
change: quote.last_price - quote.prev_close,
|
||||
changePct: calcChangePct(quote.last_price, quote.prev_close),
|
||||
bidPrices: quote.bid_prices || [],
|
||||
bidVolumes: quote.bid_volumes || [],
|
||||
askPrices: quote.ask_prices || [],
|
||||
askVolumes: quote.ask_volumes || [],
|
||||
updateTime: quote.trade_time,
|
||||
exchange: 'SSE',
|
||||
} as QuoteData;
|
||||
}
|
||||
});
|
||||
|
||||
return hasUpdate ? updated : null;
|
||||
};
|
||||
|
||||
/**
 * Handle a realtime (per-security push) message from the Shenzhen (SZSE) feed.
 *
 * The `security_id` may arrive with or without an exchange suffix; quotes are
 * always stored under the full suffixed code (e.g. 000001.SZ). Each message
 * category ('stock', 'index', 'bond', 'hk_stock', afterhours) maps to a
 * category-specific field layout.
 *
 * @param msg - Parsed SZSE realtime message.
 * @param subscribedCodes - Codes the caller subscribed to (with or without suffix).
 * @param prevQuotes - Current quotes map to merge into.
 * @returns A new quotes map when the message concerned a subscribed code, else null.
 */
const handleSZSERealtimeMessage = (
  msg: SZSERealtimeMessage,
  subscribedCodes: Set<string>,
  prevQuotes: QuotesMap
): QuotesMap | null => {
  const { category, data, timestamp } = msg;
  const rawCode = data.security_id;
  // Build the full suffixed code (SZSE securities all use .SZ).
  const fullCode = rawCode.includes('.') ? rawCode : `${rawCode}.SZ`;

  if (!subscribedCodes.has(rawCode) && !subscribedCodes.has(fullCode)) {
    return null;
  }

  const updated: QuotesMap = { ...prevQuotes };

  switch (category) {
    case 'stock': {
      const stockData = data as SZSEStockData;
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const rawData = data as any; // used to probe alternative field names

      // Debug log: inspect raw SZSE order-book payloads when the bids array
      // is missing/empty (temporarily at warn level to ease debugging).
      if (!stockData.bids || stockData.bids.length === 0) {
        logger.warn('FlexScreen', `SZSE股票数据无盘口 ${fullCode}`, {
          hasBids: !!stockData.bids,
          hasAsks: !!stockData.asks,
          bidsLength: stockData.bids?.length || 0,
          asksLength: stockData.asks?.length || 0,
          // Probe alternative field names.
          hasBidPrices: !!rawData.bid_prices,
          hasAskPrices: !!rawData.ask_prices,
          dataKeys: Object.keys(stockData), // which fields the server actually sent
        });
      }

      // Prefer the bids/asks object-array format; fall back to the split
      // bid_prices/ask_prices array format when it is absent.
      let bidPrices: number[] = [];
      let bidVolumes: number[] = [];
      let askPrices: number[] = [];
      let askVolumes: number[] = [];

      if (stockData.bids && stockData.bids.length > 0) {
        const extracted = extractOrderBook(stockData.bids);
        bidPrices = extracted.prices;
        bidVolumes = extracted.volumes;
      } else if (rawData.bid_prices && Array.isArray(rawData.bid_prices)) {
        // Alternative format: separate bid_prices and bid_volumes arrays.
        bidPrices = rawData.bid_prices;
        bidVolumes = rawData.bid_volumes || [];
      }

      if (stockData.asks && stockData.asks.length > 0) {
        const extracted = extractOrderBook(stockData.asks);
        askPrices = extracted.prices;
        askVolumes = extracted.volumes;
      } else if (rawData.ask_prices && Array.isArray(rawData.ask_prices)) {
        // Alternative format: separate ask_prices and ask_volumes arrays.
        askPrices = rawData.ask_prices;
        askVolumes = rawData.ask_volumes || [];
      }

      updated[fullCode] = {
        code: fullCode,
        name: prevQuotes[fullCode]?.name || '', // keep a previously known name
        price: stockData.last_px,
        prevClose: stockData.prev_close,
        open: stockData.open_px,
        high: stockData.high_px,
        low: stockData.low_px,
        volume: stockData.volume,
        amount: stockData.amount,
        numTrades: stockData.num_trades,
        upperLimit: stockData.upper_limit,
        lowerLimit: stockData.lower_limit,
        change: stockData.last_px - stockData.prev_close,
        changePct: calcChangePct(stockData.last_px, stockData.prev_close),
        bidPrices,
        bidVolumes,
        askPrices,
        askVolumes,
        tradingPhase: stockData.trading_phase,
        updateTime: timestamp,
        exchange: 'SZSE',
      } as QuoteData;
      break;
    }

    case 'index': {
      // Index quotes carry *_index fields and have no order book.
      const indexData = data as SZSEIndexData;
      updated[fullCode] = {
        code: fullCode,
        name: prevQuotes[fullCode]?.name || '',
        price: indexData.current_index,
        prevClose: indexData.prev_close,
        open: indexData.open_index,
        high: indexData.high_index,
        low: indexData.low_index,
        close: indexData.close_index,
        volume: indexData.volume,
        amount: indexData.amount,
        numTrades: indexData.num_trades,
        change: indexData.current_index - indexData.prev_close,
        changePct: calcChangePct(indexData.current_index, indexData.prev_close),
        bidPrices: [],
        bidVolumes: [],
        askPrices: [],
        askVolumes: [],
        tradingPhase: indexData.trading_phase,
        updateTime: timestamp,
        exchange: 'SZSE',
      } as QuoteData;
      break;
    }

    case 'bond': {
      // Bond quotes carry a weighted average price and no order book.
      const bondData = data as SZSEBondData;
      updated[fullCode] = {
        code: fullCode,
        name: prevQuotes[fullCode]?.name || '',
        price: bondData.last_px,
        prevClose: bondData.prev_close,
        open: bondData.open_px,
        high: bondData.high_px,
        low: bondData.low_px,
        volume: bondData.volume,
        amount: bondData.amount,
        numTrades: bondData.num_trades,
        weightedAvgPx: bondData.weighted_avg_px,
        change: bondData.last_px - bondData.prev_close,
        changePct: calcChangePct(bondData.last_px, bondData.prev_close),
        bidPrices: [],
        bidVolumes: [],
        askPrices: [],
        askVolumes: [],
        tradingPhase: bondData.trading_phase,
        updateTime: timestamp,
        exchange: 'SZSE',
        isBond: true,
      } as QuoteData;
      break;
    }

    case 'hk_stock': {
      // Hong Kong (stock connect) quotes: include nominal/reference prices
      // and an order book in the bids/asks object-array format.
      const hkData = data as SZSEHKStockData;
      const { prices: bidPrices, volumes: bidVolumes } = extractOrderBook(hkData.bids);
      const { prices: askPrices, volumes: askVolumes } = extractOrderBook(hkData.asks);

      updated[fullCode] = {
        code: fullCode,
        name: prevQuotes[fullCode]?.name || '',
        price: hkData.last_px,
        prevClose: hkData.prev_close,
        open: hkData.open_px,
        high: hkData.high_px,
        low: hkData.low_px,
        volume: hkData.volume,
        amount: hkData.amount,
        numTrades: hkData.num_trades,
        nominalPx: hkData.nominal_px,
        referencePx: hkData.reference_px,
        change: hkData.last_px - hkData.prev_close,
        changePct: calcChangePct(hkData.last_px, hkData.prev_close),
        bidPrices,
        bidVolumes,
        askPrices,
        askVolumes,
        tradingPhase: hkData.trading_phase,
        updateTime: timestamp,
        exchange: 'SZSE',
        isHK: true,
      } as QuoteData;
      break;
    }

    case 'afterhours_block':
    case 'afterhours_trading': {
      // After-hours data only augments an existing quote; it is dropped when
      // no base quote for this code has been seen yet.
      const afterhoursData = data as SZSEAfterhoursData;
      const existing = prevQuotes[fullCode];
      if (existing) {
        updated[fullCode] = {
          ...existing,
          afterhours: {
            bidPx: afterhoursData.bid_px,
            bidSize: afterhoursData.bid_size,
            offerPx: afterhoursData.offer_px,
            offerSize: afterhoursData.offer_size,
            volume: afterhoursData.volume,
            amount: afterhoursData.amount,
            numTrades: afterhoursData.num_trades || 0,
          },
          updateTime: timestamp,
        } as QuoteData;
      }
      break;
    }

    default:
      return null;
  }

  return updated;
};
|
||||
|
||||
/**
 * Handle a snapshot message from the Shenzhen (SZSE) feed.
 *
 * A snapshot carries bulk arrays of stocks, indexes and bonds. Quotes are
 * stored under the full suffixed code (e.g. 000001.SZ). Snapshot entries do
 * not carry security names, so `name` is left empty here.
 *
 * @param msg - Parsed SZSE snapshot message.
 * @param subscribedCodes - Codes the caller subscribed to (with or without suffix).
 * @param prevQuotes - Current quotes map to merge into.
 * @returns A new quotes map when at least one subscribed code was updated, else null.
 */
const handleSZSESnapshotMessage = (
  msg: SZSESnapshotMessage,
  subscribedCodes: Set<string>,
  prevQuotes: QuotesMap
): QuotesMap | null => {
  const { stocks = [], indexes = [], bonds = [] } = msg.data || {};
  const updated: QuotesMap = { ...prevQuotes };
  let hasUpdate = false;

  stocks.forEach((s: SZSEStockData) => {
    const rawCode = s.security_id;
    const fullCode = rawCode.includes('.') ? rawCode : `${rawCode}.SZ`;

    if (subscribedCodes.has(rawCode) || subscribedCodes.has(fullCode)) {
      hasUpdate = true;
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const rawData = s as any; // used to probe alternative field names

      // Debug log: warn when a snapshot stock entry has no order book.
      if (!s.bids || s.bids.length === 0) {
        logger.warn('FlexScreen', `SZSE快照股票数据无盘口 ${fullCode}`, {
          hasBids: !!s.bids,
          hasAsks: !!s.asks,
          hasBidPrices: !!rawData.bid_prices,
          hasAskPrices: !!rawData.ask_prices,
          dataKeys: Object.keys(s),
        });
      }

      // Prefer the bids/asks object-array format; fall back to the split
      // bid_prices/ask_prices array format when it is absent.
      let bidPrices: number[] = [];
      let bidVolumes: number[] = [];
      let askPrices: number[] = [];
      let askVolumes: number[] = [];

      if (s.bids && s.bids.length > 0) {
        const extracted = extractOrderBook(s.bids);
        bidPrices = extracted.prices;
        bidVolumes = extracted.volumes;
      } else if (rawData.bid_prices && Array.isArray(rawData.bid_prices)) {
        bidPrices = rawData.bid_prices;
        bidVolumes = rawData.bid_volumes || [];
      }

      if (s.asks && s.asks.length > 0) {
        const extracted = extractOrderBook(s.asks);
        askPrices = extracted.prices;
        askVolumes = extracted.volumes;
      } else if (rawData.ask_prices && Array.isArray(rawData.ask_prices)) {
        askPrices = rawData.ask_prices;
        askVolumes = rawData.ask_volumes || [];
      }

      updated[fullCode] = {
        code: fullCode,
        name: '', // snapshots carry no name
        price: s.last_px,
        prevClose: s.prev_close,
        open: s.open_px,
        high: s.high_px,
        low: s.low_px,
        volume: s.volume,
        amount: s.amount,
        numTrades: s.num_trades,
        upperLimit: s.upper_limit,
        lowerLimit: s.lower_limit,
        change: s.last_px - s.prev_close,
        changePct: calcChangePct(s.last_px, s.prev_close),
        bidPrices,
        bidVolumes,
        askPrices,
        askVolumes,
        exchange: 'SZSE',
      } as QuoteData;
    }
  });

  indexes.forEach((i: SZSEIndexData) => {
    const rawCode = i.security_id;
    const fullCode = rawCode.includes('.') ? rawCode : `${rawCode}.SZ`;

    if (subscribedCodes.has(rawCode) || subscribedCodes.has(fullCode)) {
      hasUpdate = true;
      updated[fullCode] = {
        code: fullCode,
        name: '',
        price: i.current_index,
        prevClose: i.prev_close,
        open: i.open_index,
        high: i.high_index,
        low: i.low_index,
        volume: i.volume,
        amount: i.amount,
        numTrades: i.num_trades,
        change: i.current_index - i.prev_close,
        changePct: calcChangePct(i.current_index, i.prev_close),
        bidPrices: [],
        bidVolumes: [],
        askPrices: [],
        askVolumes: [],
        exchange: 'SZSE',
      } as QuoteData;
    }
  });

  bonds.forEach((b: SZSEBondData) => {
    const rawCode = b.security_id;
    const fullCode = rawCode.includes('.') ? rawCode : `${rawCode}.SZ`;

    if (subscribedCodes.has(rawCode) || subscribedCodes.has(fullCode)) {
      hasUpdate = true;
      updated[fullCode] = {
        code: fullCode,
        name: '',
        price: b.last_px,
        prevClose: b.prev_close,
        open: b.open_px,
        high: b.high_px,
        low: b.low_px,
        volume: b.volume,
        amount: b.amount,
        change: b.last_px - b.prev_close,
        changePct: calcChangePct(b.last_px, b.prev_close),
        bidPrices: [],
        bidVolumes: [],
        askPrices: [],
        askVolumes: [],
        exchange: 'SZSE',
        isBond: true,
      } as QuoteData;
    }
  });

  return hasUpdate ? updated : null;
};
|
||||
|
||||
/**
|
||||
* 实时行情 Hook
|
||||
* @param codes - 订阅的证券代码列表
|
||||
*/
|
||||
export const useRealtimeQuote = (codes: string[] = []): UseRealtimeQuoteReturn => {
|
||||
const [quotes, setQuotes] = useState<QuotesMap>({});
|
||||
const [connected, setConnected] = useState<ConnectionStatus>({ SSE: false, SZSE: false });
|
||||
|
||||
const wsRefs = useRef<Record<Exchange, WebSocket | null>>({ SSE: null, SZSE: null });
|
||||
const heartbeatRefs = useRef<Record<Exchange, NodeJS.Timeout | null>>({ SSE: null, SZSE: null });
|
||||
const reconnectRefs = useRef<Record<Exchange, NodeJS.Timeout | null>>({ SSE: null, SZSE: null });
|
||||
const subscribedCodes = useRef<Record<Exchange, Set<string>>>({
|
||||
SSE: new Set(),
|
||||
SZSE: new Set(),
|
||||
});
|
||||
|
||||
const stopHeartbeat = useCallback((exchange: Exchange) => {
|
||||
if (heartbeatRefs.current[exchange]) {
|
||||
clearInterval(heartbeatRefs.current[exchange]!);
|
||||
heartbeatRefs.current[exchange] = null;
|
||||
}
|
||||
}, []);
|
||||
|
||||
const startHeartbeat = useCallback((exchange: Exchange) => {
|
||||
stopHeartbeat(exchange);
|
||||
heartbeatRefs.current[exchange] = setInterval(() => {
|
||||
const ws = wsRefs.current[exchange];
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
const msg = exchange === 'SSE' ? { action: 'ping' } : { type: 'ping' };
|
||||
ws.send(JSON.stringify(msg));
|
||||
}
|
||||
}, HEARTBEAT_INTERVAL);
|
||||
}, [stopHeartbeat]);
|
||||
|
||||
const handleMessage = useCallback((exchange: Exchange, msg: SSEMessage | SZSEMessage) => {
|
||||
if (msg.type === 'pong') return;
|
||||
|
||||
if (exchange === 'SSE') {
|
||||
const result = handleSSEMessage(
|
||||
msg as SSEMessage,
|
||||
subscribedCodes.current.SSE,
|
||||
{} // Will be merged with current state
|
||||
);
|
||||
if (result) {
|
||||
setQuotes(prev => ({ ...prev, ...result }));
|
||||
}
|
||||
} else {
|
||||
if (msg.type === 'realtime') {
|
||||
setQuotes(prev => {
|
||||
const result = handleSZSERealtimeMessage(
|
||||
msg as SZSERealtimeMessage,
|
||||
subscribedCodes.current.SZSE,
|
||||
prev
|
||||
);
|
||||
return result || prev;
|
||||
});
|
||||
} else if (msg.type === 'snapshot') {
|
||||
setQuotes(prev => {
|
||||
const result = handleSZSESnapshotMessage(
|
||||
msg as SZSESnapshotMessage,
|
||||
subscribedCodes.current.SZSE,
|
||||
prev
|
||||
);
|
||||
return result || prev;
|
||||
});
|
||||
}
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Open (or re-open) the WebSocket connection for one exchange, wiring up
// heartbeat, message dispatch, and auto-reconnect handlers.
const createConnection = useCallback((exchange: Exchange) => {
  // Defensive check: an HTTPS page must never connect to ws:// (browser
  // Mixed Content error). WS_CONFIG normally returns a protocol-appropriate
  // URL already; this is a fallback guard.
  const isHttps = typeof window !== 'undefined' && window.location.protocol === 'https:';
  const wsUrl = WS_CONFIG[exchange];
  const isInsecureWs = wsUrl.startsWith('ws://');

  if (isHttps && isInsecureWs) {
    logger.warn(
      'FlexScreen',
      `${exchange} WebSocket 配置错误:HTTPS 页面尝试连接 ws:// 端点,请检查 Nginx 代理配置`
    );
    return;
  }

  // Drop any existing connection for this exchange before reconnecting.
  if (wsRefs.current[exchange]) {
    wsRefs.current[exchange]!.close();
  }

  try {
    const ws = new WebSocket(wsUrl);
    wsRefs.current[exchange] = ws;

    ws.onopen = () => {
      logger.info('FlexScreen', `${exchange} WebSocket 已连接`);
      setConnected(prev => ({ ...prev, [exchange]: true }));

      if (exchange === 'SSE') {
        // subscribedCodes holds fully-suffixed codes; the WS endpoint
        // expects bare codes, so strip the suffix before sending.
        const fullCodes = Array.from(subscribedCodes.current.SSE);
        const baseCodes = fullCodes.map(c => normalizeCode(c));
        if (baseCodes.length > 0) {
          ws.send(JSON.stringify({
            action: 'subscribe',
            channels: ['stock', 'index'],
            codes: baseCodes,
          }));
        }
      }

      startHeartbeat(exchange);
    };

    ws.onmessage = (event: MessageEvent) => {
      try {
        const msg = JSON.parse(event.data);
        handleMessage(exchange, msg);
      } catch (e) {
        logger.warn('FlexScreen', `${exchange} 消息解析失败`, e);
      }
    };

    ws.onerror = (error: Event) => {
      logger.error('FlexScreen', `${exchange} WebSocket 错误`, error);
    };

    ws.onclose = () => {
      logger.info('FlexScreen', `${exchange} WebSocket 断开`);
      setConnected(prev => ({ ...prev, [exchange]: false }));
      stopHeartbeat(exchange);

      // Auto-reconnect while there are still active subscriptions and no
      // reconnect timer is already pending (the HTTPS+ws:// guard above
      // prevents reconnect loops in that misconfiguration case).
      if (!reconnectRefs.current[exchange] && subscribedCodes.current[exchange].size > 0) {
        reconnectRefs.current[exchange] = setTimeout(() => {
          reconnectRefs.current[exchange] = null;
          if (subscribedCodes.current[exchange].size > 0) {
            createConnection(exchange);
          }
        }, RECONNECT_INTERVAL);
      }
    };
  } catch (e) {
    logger.error('FlexScreen', `${exchange} WebSocket 连接失败`, e);
    setConnected(prev => ({ ...prev, [exchange]: false }));
  }
}, [startHeartbeat, stopHeartbeat, handleMessage]);
|
||||
|
||||
const subscribe = useCallback((code: string) => {
|
||||
const exchange = getExchange(code);
|
||||
// 确保使用带后缀的完整代码
|
||||
const fullCode = code.includes('.') ? code : `${code}.${exchange === 'SSE' ? 'SH' : 'SZ'}`;
|
||||
const baseCode = normalizeCode(code);
|
||||
|
||||
subscribedCodes.current[exchange].add(fullCode);
|
||||
|
||||
const ws = wsRefs.current[exchange];
|
||||
if (exchange === 'SSE' && ws && ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({
|
||||
action: 'subscribe',
|
||||
channels: ['stock', 'index'],
|
||||
codes: [baseCode], // 发送给 WS 用不带后缀的代码
|
||||
}));
|
||||
}
|
||||
|
||||
if (!ws || ws.readyState !== WebSocket.OPEN) {
|
||||
createConnection(exchange);
|
||||
}
|
||||
}, [createConnection]);
|
||||
|
||||
const unsubscribe = useCallback((code: string) => {
|
||||
const exchange = getExchange(code);
|
||||
// 确保使用带后缀的完整代码
|
||||
const fullCode = code.includes('.') ? code : `${code}.${exchange === 'SSE' ? 'SH' : 'SZ'}`;
|
||||
|
||||
subscribedCodes.current[exchange].delete(fullCode);
|
||||
|
||||
setQuotes(prev => {
|
||||
const updated = { ...prev };
|
||||
delete updated[fullCode]; // 删除时也用带后缀的 key
|
||||
return updated;
|
||||
});
|
||||
|
||||
if (subscribedCodes.current[exchange].size === 0) {
|
||||
const ws = wsRefs.current[exchange];
|
||||
if (ws) {
|
||||
ws.close();
|
||||
wsRefs.current[exchange] = null;
|
||||
}
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Initialisation and reaction to `codes` changes.
// NOTE: `codes` are fully-suffixed codes (e.g. 000001.SH).
useEffect(() => {
  if (!codes || codes.length === 0) return;

  // Partition incoming codes per exchange, keyed by their suffixed form.
  const newSseCodes = new Set<string>();
  const newSzseCodes = new Set<string>();

  codes.forEach(code => {
    const exchange = getExchange(code);
    // Normalise to the suffixed form.
    const fullCode = code.includes('.') ? code : `${code}.${exchange === 'SSE' ? 'SH' : 'SZ'}`;
    if (exchange === 'SSE') {
      newSseCodes.add(fullCode);
    } else {
      newSzseCodes.add(fullCode);
    }
  });

  // --- Update SSE subscriptions ---
  const oldSseCodes = subscribedCodes.current.SSE;
  const sseToAdd = [...newSseCodes].filter(c => !oldSseCodes.has(c));
  // Codes sent over the wire must not carry the exchange suffix.
  const sseToAddBase = sseToAdd.map(c => normalizeCode(c));

  // Run on any additions OR whenever the set size changed (covers removals).
  if (sseToAdd.length > 0 || newSseCodes.size !== oldSseCodes.size) {
    subscribedCodes.current.SSE = newSseCodes;
    const ws = wsRefs.current.SSE;

    if (ws && ws.readyState === WebSocket.OPEN && sseToAddBase.length > 0) {
      ws.send(JSON.stringify({
        action: 'subscribe',
        channels: ['stock', 'index'],
        codes: sseToAddBase,
      }));
    }

    if (sseToAdd.length > 0 && (!ws || ws.readyState !== WebSocket.OPEN)) {
      createConnection('SSE');
    }

    if (newSseCodes.size === 0 && ws) {
      ws.close();
      wsRefs.current.SSE = null;
    }
  }

  // --- Update SZSE subscriptions ---
  // NOTE(review): no subscribe frame is ever sent to SZSE — presumably the
  // server pushes everything and filtering happens client-side; confirm.
  const oldSzseCodes = subscribedCodes.current.SZSE;
  const szseToAdd = [...newSzseCodes].filter(c => !oldSzseCodes.has(c));

  if (szseToAdd.length > 0 || newSzseCodes.size !== oldSzseCodes.size) {
    subscribedCodes.current.SZSE = newSzseCodes;
    const ws = wsRefs.current.SZSE;

    if (szseToAdd.length > 0 && (!ws || ws.readyState !== WebSocket.OPEN)) {
      createConnection('SZSE');
    }

    if (newSzseCodes.size === 0 && ws) {
      ws.close();
      wsRefs.current.SZSE = null;
    }
  }

  // Prune quotes that are no longer subscribed (keys are suffixed codes).
  const allNewCodes = new Set([...newSseCodes, ...newSzseCodes]);
  setQuotes(prev => {
    const updated: QuotesMap = {};
    Object.keys(prev).forEach(code => {
      if (allNewCodes.has(code)) {
        updated[code] = prev[code];
      }
    });
    return updated;
  });
}, [codes, createConnection]);
|
||||
|
||||
// Teardown on unmount: stop heartbeats, cancel any pending reconnect
// timers, and close both exchange sockets.
useEffect(() => {
  return () => {
    (['SSE', 'SZSE'] as Exchange[]).forEach(exchange => {
      stopHeartbeat(exchange);
      if (reconnectRefs.current[exchange]) {
        clearTimeout(reconnectRefs.current[exchange]!);
      }
      const ws = wsRefs.current[exchange];
      if (ws) {
        ws.close();
      }
    });
  };
}, [stopHeartbeat]);
|
||||
|
||||
return { quotes, connected, subscribe, unsubscribe };
|
||||
};
|
||||
|
||||
export default useRealtimeQuote;
|
||||
148
src/views/StockOverview/components/FlexScreen/hooks/utils.ts
Normal file
148
src/views/StockOverview/components/FlexScreen/hooks/utils.ts
Normal file
@@ -0,0 +1,148 @@
|
||||
/**
|
||||
* 实时行情相关工具函数
|
||||
*/
|
||||
|
||||
import type { Exchange, OrderBookLevel } from '../types';
|
||||
|
||||
/**
|
||||
* 判断证券代码属于哪个交易所
|
||||
* @param code - 证券代码(可带或不带后缀)
|
||||
* @param isIndex - 是否为指数(用于区分同代码的指数和股票,如 000001)
|
||||
* @returns 交易所标识
|
||||
*/
|
||||
export const getExchange = (code: string, isIndex?: boolean): Exchange => {
|
||||
// 如果已带后缀,直接判断
|
||||
if (code.includes('.')) {
|
||||
return code.endsWith('.SH') ? 'SSE' : 'SZSE';
|
||||
}
|
||||
|
||||
const baseCode = code;
|
||||
|
||||
// 6开头为上海股票
|
||||
if (baseCode.startsWith('6')) {
|
||||
return 'SSE';
|
||||
}
|
||||
|
||||
// 5开头是上海 ETF
|
||||
if (baseCode.startsWith('5')) {
|
||||
return 'SSE';
|
||||
}
|
||||
|
||||
// 399开头是深证指数
|
||||
if (baseCode.startsWith('399')) {
|
||||
return 'SZSE';
|
||||
}
|
||||
|
||||
// 000开头:如果是指数则为上交所(上证指数000001),否则为深交所(平安银行000001)
|
||||
if (baseCode.startsWith('000')) {
|
||||
return isIndex ? 'SSE' : 'SZSE';
|
||||
}
|
||||
|
||||
// 0、3开头是深圳股票
|
||||
if (baseCode.startsWith('0') || baseCode.startsWith('3')) {
|
||||
return 'SZSE';
|
||||
}
|
||||
|
||||
// 1开头是深圳 ETF/债券
|
||||
if (baseCode.startsWith('1')) {
|
||||
return 'SZSE';
|
||||
}
|
||||
|
||||
// 默认上海
|
||||
return 'SSE';
|
||||
};
|
||||
|
||||
/**
|
||||
* 获取证券代码的完整格式(带交易所后缀)
|
||||
* @param code - 原始代码
|
||||
* @param isIndex - 是否为指数
|
||||
* @returns 带后缀的代码
|
||||
*/
|
||||
export const getFullCode = (code: string, isIndex?: boolean): string => {
|
||||
if (code.includes('.')) {
|
||||
return code; // 已带后缀
|
||||
}
|
||||
const exchange = getExchange(code, isIndex);
|
||||
return `${code}.${exchange === 'SSE' ? 'SH' : 'SZ'}`;
|
||||
};
|
||||
|
||||
/**
|
||||
* 标准化证券代码为无后缀格式
|
||||
* @param code - 原始代码
|
||||
* @returns 无后缀代码
|
||||
*/
|
||||
export const normalizeCode = (code: string): string => {
|
||||
return code.split('.')[0];
|
||||
};
|
||||
|
||||
/**
 * Possible wire formats for order-book data (varies by WebSocket server
 * implementation); extractOrderBook normalizes all of them.
 */
type OrderBookInput =
  | OrderBookLevel[] // format 1: [{price, volume}, ...]
  | Array<[number, number]> // format 2: [[price, volume], ...]
  | { prices: number[]; volumes: number[] } // format 3: {prices: [...], volumes: [...]}
  | undefined;
|
||||
|
||||
/**
|
||||
* 从深交所 bids/asks 数组提取价格和量数组
|
||||
* 支持多种可能的数据格式
|
||||
* @param orderBook - 盘口数据,支持多种格式
|
||||
* @returns { prices, volumes }
|
||||
*/
|
||||
export const extractOrderBook = (
|
||||
orderBook: OrderBookInput
|
||||
): { prices: number[]; volumes: number[] } => {
|
||||
if (!orderBook) {
|
||||
return { prices: [], volumes: [] };
|
||||
}
|
||||
|
||||
// 格式3: 已经是 {prices, volumes} 结构
|
||||
if (!Array.isArray(orderBook) && 'prices' in orderBook && 'volumes' in orderBook) {
|
||||
return {
|
||||
prices: orderBook.prices || [],
|
||||
volumes: orderBook.volumes || [],
|
||||
};
|
||||
}
|
||||
|
||||
// 必须是数组才能继续
|
||||
if (!Array.isArray(orderBook) || orderBook.length === 0) {
|
||||
return { prices: [], volumes: [] };
|
||||
}
|
||||
|
||||
const firstItem = orderBook[0];
|
||||
|
||||
// 格式2: [[price, volume], ...]
|
||||
if (Array.isArray(firstItem)) {
|
||||
const prices = orderBook.map((item: unknown) => {
|
||||
const arr = item as [number, number];
|
||||
return arr[0] || 0;
|
||||
});
|
||||
const volumes = orderBook.map((item: unknown) => {
|
||||
const arr = item as [number, number];
|
||||
return arr[1] || 0;
|
||||
});
|
||||
return { prices, volumes };
|
||||
}
|
||||
|
||||
// 格式1: [{price, volume}, ...] (标准格式)
|
||||
if (typeof firstItem === 'object' && firstItem !== null) {
|
||||
const typedBook = orderBook as OrderBookLevel[];
|
||||
const prices = typedBook.map(item => item.price || 0);
|
||||
const volumes = typedBook.map(item => item.volume || 0);
|
||||
return { prices, volumes };
|
||||
}
|
||||
|
||||
return { prices: [], volumes: [] };
|
||||
};
|
||||
|
||||
/**
|
||||
* 计算涨跌幅
|
||||
* @param price - 当前价
|
||||
* @param prevClose - 昨收价
|
||||
* @returns 涨跌幅百分比
|
||||
*/
|
||||
export const calcChangePct = (price: number, prevClose: number): number => {
|
||||
if (!prevClose || prevClose === 0) return 0;
|
||||
return ((price - prevClose) / prevClose) * 100;
|
||||
};
|
||||
507
src/views/StockOverview/components/FlexScreen/index.tsx
Normal file
507
src/views/StockOverview/components/FlexScreen/index.tsx
Normal file
@@ -0,0 +1,507 @@
|
||||
/**
|
||||
* 灵活屏组件
|
||||
* 用户可自定义添加关注的指数/个股,实时显示行情
|
||||
*
|
||||
* 功能:
|
||||
* 1. 添加/删除自选证券
|
||||
* 2. 显示实时行情(通过 WebSocket)
|
||||
* 3. 显示分时走势(结合 ClickHouse 历史数据)
|
||||
* 4. 显示五档盘口(上交所5档,深交所10档)
|
||||
* 5. 本地存储自选列表
|
||||
*/
|
||||
import React, { useState, useEffect, useCallback, useMemo } from 'react';
|
||||
import {
|
||||
Box,
|
||||
Card,
|
||||
CardBody,
|
||||
VStack,
|
||||
HStack,
|
||||
Heading,
|
||||
Text,
|
||||
Input,
|
||||
InputGroup,
|
||||
InputLeftElement,
|
||||
InputRightElement,
|
||||
IconButton,
|
||||
SimpleGrid,
|
||||
Flex,
|
||||
Spacer,
|
||||
Icon,
|
||||
useColorModeValue,
|
||||
useToast,
|
||||
Badge,
|
||||
Tooltip,
|
||||
Collapse,
|
||||
List,
|
||||
ListItem,
|
||||
Spinner,
|
||||
Center,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem,
|
||||
Tag,
|
||||
TagLabel,
|
||||
} from '@chakra-ui/react';
|
||||
import {
|
||||
SearchIcon,
|
||||
CloseIcon,
|
||||
AddIcon,
|
||||
ChevronDownIcon,
|
||||
ChevronUpIcon,
|
||||
SettingsIcon,
|
||||
} from '@chakra-ui/icons';
|
||||
import {
|
||||
FaDesktop,
|
||||
FaTrash,
|
||||
FaSync,
|
||||
FaWifi,
|
||||
FaExclamationCircle,
|
||||
} from 'react-icons/fa';
|
||||
|
||||
import { useRealtimeQuote } from './hooks';
|
||||
import { getFullCode } from './hooks/utils';
|
||||
import QuoteTile from './components/QuoteTile';
|
||||
import { logger } from '@utils/logger';
|
||||
import type { WatchlistItem, ConnectionStatus } from './types';
|
||||
|
||||
// localStorage key under which the watchlist is persisted.
const STORAGE_KEY = 'flexscreen_watchlist';

// Watchlist used when nothing has been saved yet.
const DEFAULT_WATCHLIST: WatchlistItem[] = [
  { code: '000001', name: '上证指数', isIndex: true },
  { code: '399001', name: '深证成指', isIndex: true },
  { code: '399006', name: '创业板指', isIndex: true },
];

// Quick-add suggestions shown when the watchlist is empty.
const HOT_RECOMMENDATIONS: WatchlistItem[] = [
  { code: '000001', name: '上证指数', isIndex: true },
  { code: '399001', name: '深证成指', isIndex: true },
  { code: '399006', name: '创业板指', isIndex: true },
  { code: '399300', name: '沪深300', isIndex: true },
  { code: '600519', name: '贵州茅台', isIndex: false },
  { code: '000858', name: '五粮液', isIndex: false },
  { code: '300750', name: '宁德时代', isIndex: false },
  { code: '002594', name: '比亚迪', isIndex: false },
];
|
||||
|
||||
/** One hit returned by the /api/stocks/search endpoint. */
interface SearchResultItem {
  stock_code: string;
  stock_name: string;
  isIndex?: boolean; // true when the hit is an index rather than a stock
  code?: string; // alternate field names, used when a WatchlistItem-shaped object is passed
  name?: string;
}

/** Envelope of the search API response. */
interface SearchApiResponse {
  success: boolean;
  data?: SearchResultItem[];
}

/** Display info for the WebSocket connection badge. */
interface ConnectionStatusInfo {
  color: string;
  text: string;
}
|
||||
|
||||
/**
 * FlexScreen component.
 *
 * Lets the user maintain a custom watchlist of indices/stocks with
 * realtime quotes: search + add/remove securities, live prices over
 * WebSocket, and persistence in localStorage.
 */
const FlexScreen: React.FC = () => {
  const toast = useToast();

  // Watchlist
  const [watchlist, setWatchlist] = useState<WatchlistItem[]>([]);
  // Search state
  const [searchQuery, setSearchQuery] = useState('');
  const [searchResults, setSearchResults] = useState<SearchResultItem[]>([]);
  const [isSearching, setIsSearching] = useState(false);
  const [showResults, setShowResults] = useState(false);
  // Panel state
  const [isCollapsed, setIsCollapsed] = useState(false);

  // Theme colors
  const cardBg = useColorModeValue('white', '#1a1a1a');
  const borderColor = useColorModeValue('gray.200', '#333');
  const textColor = useColorModeValue('gray.800', 'white');
  const subTextColor = useColorModeValue('gray.600', 'gray.400');
  const searchBg = useColorModeValue('gray.50', '#2a2a2a');
  const hoverBg = useColorModeValue('gray.100', '#333');

  // Subscribed codes carry the exchange suffix so the SSE index 000001.SH
  // and the SZSE stock 000001.SZ stay distinct.
  const subscribedCodes = useMemo(() => {
    return watchlist.map(item => getFullCode(item.code, item.isIndex));
  }, [watchlist]);

  // Realtime quotes over WebSocket.
  const { quotes, connected } = useRealtimeQuote(subscribedCodes);

  // Load the watchlist from localStorage; fall back to the defaults.
  useEffect(() => {
    try {
      const saved = localStorage.getItem(STORAGE_KEY);
      if (saved) {
        const parsed = JSON.parse(saved) as WatchlistItem[];
        if (Array.isArray(parsed) && parsed.length > 0) {
          setWatchlist(parsed);
          return;
        }
      }
    } catch (e) {
      logger.warn('FlexScreen', '加载自选列表失败', e);
    }
    // Saved list missing or invalid — use the defaults.
    setWatchlist(DEFAULT_WATCHLIST);
  }, []);

  // Persist the watchlist whenever it changes (non-empty only, so a
  // transient empty state does not wipe the saved list).
  useEffect(() => {
    if (watchlist.length > 0) {
      try {
        localStorage.setItem(STORAGE_KEY, JSON.stringify(watchlist));
      } catch (e) {
        logger.warn('FlexScreen', '保存自选列表失败', e);
      }
    }
  }, [watchlist]);

  // Query the backend search endpoint.
  const searchSecurities = useCallback(async (query: string): Promise<void> => {
    if (!query.trim()) {
      setSearchResults([]);
      setShowResults(false);
      return;
    }

    setIsSearching(true);
    try {
      const response = await fetch(`/api/stocks/search?q=${encodeURIComponent(query)}&limit=10`);
      const data: SearchApiResponse = await response.json();

      if (data.success) {
        setSearchResults(data.data || []);
        setShowResults(true);
      } else {
        setSearchResults([]);
      }
    } catch (e) {
      logger.error('FlexScreen', '搜索失败', e);
      setSearchResults([]);
    } finally {
      setIsSearching(false);
    }
  }, []);

  // Debounced search (300 ms after the last keystroke).
  useEffect(() => {
    const timer = setTimeout(() => {
      searchSecurities(searchQuery);
    }, 300);
    return () => clearTimeout(timer);
  }, [searchQuery, searchSecurities]);

  // Add a security to the watchlist (from search results or quick-add tags).
  const addSecurity = useCallback(
    (security: SearchResultItem | WatchlistItem): void => {
      const code = 'stock_code' in security ? security.stock_code : security.code;
      const name = 'stock_name' in security ? security.stock_name : security.name;
      // Prefer the isIndex flag provided by the API.
      const isIndex = security.isIndex === true;

      // Unique identity is the suffixed code.
      const fullCode = getFullCode(code, isIndex);

      // Duplicate check uses suffixed codes so the SSE index 000001 and
      // the SZSE stock 000001 do not collide.
      if (watchlist.some(item => getFullCode(item.code, item.isIndex) === fullCode)) {
        toast({
          title: '已在自选列表中',
          status: 'info',
          duration: 2000,
          isClosable: true,
        });
        return;
      }

      // Append to the list.
      setWatchlist(prev => [...prev, { code, name, isIndex }]);

      toast({
        title: `已添加 ${name}${isIndex ? '(指数)' : ''}`,
        status: 'success',
        duration: 2000,
        isClosable: true,
      });

      // Reset the search UI.
      setSearchQuery('');
      setShowResults(false);
    },
    [watchlist, toast]
  );

  // Remove one security by its bare code.
  const removeSecurity = useCallback((code: string): void => {
    setWatchlist(prev => prev.filter(item => item.code !== code));
  }, []);

  // Clear the whole watchlist and its persisted copy.
  const clearWatchlist = useCallback((): void => {
    setWatchlist([]);
    localStorage.removeItem(STORAGE_KEY);
    toast({
      title: '已清空自选列表',
      status: 'info',
      duration: 2000,
      isClosable: true,
    });
  }, [toast]);

  // Restore the built-in default watchlist.
  const resetWatchlist = useCallback((): void => {
    setWatchlist(DEFAULT_WATCHLIST);
    toast({
      title: '已重置为默认列表',
      status: 'success',
      duration: 2000,
      isClosable: true,
    });
  }, [toast]);

  // Connection indicator derived from the per-exchange flags.
  const isAnyConnected = connected.SSE || connected.SZSE;
  const connectionStatus = useMemo((): ConnectionStatusInfo => {
    if (connected.SSE && connected.SZSE) {
      return { color: 'green', text: '上交所/深交所 已连接' };
    }
    if (connected.SSE) {
      return { color: 'yellow', text: '上交所 已连接' };
    }
    if (connected.SZSE) {
      return { color: 'yellow', text: '深交所 已连接' };
    }
    return { color: 'red', text: '未连接' };
  }, [connected]);

  return (
    <Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
      <CardBody>
        {/* Header */}
        <Flex align="center" mb={4}>
          <HStack spacing={3}>
            <Icon as={FaDesktop} boxSize={6} color="purple.500" />
            <Heading size="md" color={textColor}>
              灵活屏
            </Heading>
            <Tooltip label={connectionStatus.text}>
              <Badge
                colorScheme={connectionStatus.color}
                variant="subtle"
                display="flex"
                alignItems="center"
                gap={1}
              >
                <Icon as={FaWifi} boxSize={3} />
                {isAnyConnected ? '实时' : '离线'}
              </Badge>
            </Tooltip>
          </HStack>
          <Spacer />
          <HStack spacing={2}>
            {/* Actions menu */}
            <Menu>
              <MenuButton
                as={IconButton}
                icon={<SettingsIcon />}
                size="sm"
                variant="ghost"
                aria-label="设置"
              />
              <MenuList>
                <MenuItem icon={<FaSync />} onClick={resetWatchlist}>
                  重置为默认
                </MenuItem>
                <MenuItem icon={<FaTrash />} onClick={clearWatchlist} color="red.500">
                  清空列表
                </MenuItem>
              </MenuList>
            </Menu>
            {/* Collapse toggle */}
            <IconButton
              icon={isCollapsed ? <ChevronDownIcon /> : <ChevronUpIcon />}
              size="sm"
              variant="ghost"
              onClick={() => setIsCollapsed(!isCollapsed)}
              aria-label={isCollapsed ? '展开' : '收起'}
            />
          </HStack>
        </Flex>

        {/* Collapsible body */}
        <Collapse in={!isCollapsed} animateOpacity>
          {/* Search box */}
          <Box position="relative" mb={4}>
            <InputGroup size="md">
              <InputLeftElement pointerEvents="none">
                <SearchIcon color="gray.400" />
              </InputLeftElement>
              <Input
                placeholder="搜索股票/指数代码或名称..."
                value={searchQuery}
                onChange={e => setSearchQuery(e.target.value)}
                bg={searchBg}
                borderRadius="lg"
                _focus={{
                  borderColor: 'purple.400',
                  boxShadow: '0 0 0 1px var(--chakra-colors-purple-400)',
                }}
              />
              {searchQuery && (
                <InputRightElement>
                  <IconButton
                    size="sm"
                    icon={<CloseIcon />}
                    variant="ghost"
                    onClick={() => {
                      setSearchQuery('');
                      setShowResults(false);
                    }}
                    aria-label="清空"
                  />
                </InputRightElement>
              )}
            </InputGroup>

            {/* Search results dropdown */}
            <Collapse in={showResults} animateOpacity>
              <Box
                position="absolute"
                top="100%"
                left={0}
                right={0}
                mt={1}
                bg={cardBg}
                borderWidth="1px"
                borderColor={borderColor}
                borderRadius="lg"
                boxShadow="lg"
                maxH="300px"
                overflowY="auto"
                zIndex={10}
              >
                {isSearching ? (
                  <Center p={4}>
                    <Spinner size="sm" color="purple.500" />
                  </Center>
                ) : searchResults.length > 0 ? (
                  <List spacing={0}>
                    {searchResults.map((stock, index) => (
                      <ListItem
                        key={`${stock.stock_code}-${stock.isIndex ? 'index' : 'stock'}`}
                        px={4}
                        py={2}
                        cursor="pointer"
                        _hover={{ bg: hoverBg }}
                        onClick={() => addSecurity(stock)}
                        borderBottomWidth={index < searchResults.length - 1 ? '1px' : '0'}
                        borderColor={borderColor}
                      >
                        <HStack justify="space-between">
                          <VStack align="start" spacing={0}>
                            <HStack spacing={2}>
                              <Text fontWeight="medium" color={textColor}>
                                {stock.stock_name}
                              </Text>
                              <Badge
                                colorScheme={stock.isIndex ? 'purple' : 'blue'}
                                fontSize="xs"
                                variant="subtle"
                              >
                                {stock.isIndex ? '指数' : '股票'}
                              </Badge>
                            </HStack>
                            <Text fontSize="xs" color={subTextColor}>
                              {stock.stock_code}
                            </Text>
                          </VStack>
                          <IconButton
                            icon={<AddIcon />}
                            size="xs"
                            colorScheme="purple"
                            variant="ghost"
                            aria-label="添加"
                          />
                        </HStack>
                      </ListItem>
                    ))}
                  </List>
                ) : (
                  <Center p={4}>
                    <Text color={subTextColor} fontSize="sm">
                      未找到相关证券
                    </Text>
                  </Center>
                )}
              </Box>
            </Collapse>
          </Box>

          {/* Quick add (shown only when the watchlist is empty) */}
          {watchlist.length === 0 && (
            <Box mb={4}>
              <Text fontSize="sm" color={subTextColor} mb={2}>
                热门推荐(点击添加)
              </Text>
              <Flex flexWrap="wrap" gap={2}>
                {HOT_RECOMMENDATIONS.map(item => (
                  <Tag
                    key={item.code}
                    size="md"
                    variant="subtle"
                    colorScheme="purple"
                    cursor="pointer"
                    _hover={{ bg: 'purple.100' }}
                    onClick={() => addSecurity(item)}
                  >
                    <TagLabel>{item.name}</TagLabel>
                  </Tag>
                ))}
              </Flex>
            </Box>
          )}

          {/* Watchlist grid */}
          {watchlist.length > 0 ? (
            <SimpleGrid columns={{ base: 1, md: 2, lg: 3 }} spacing={4}>
              {watchlist.map(item => {
                const fullCode = getFullCode(item.code, item.isIndex);
                return (
                  <QuoteTile
                    key={fullCode}
                    code={item.code}
                    name={item.name}
                    quote={quotes[fullCode] || {}}
                    isIndex={item.isIndex}
                    onRemove={removeSecurity}
                  />
                );
              })}
            </SimpleGrid>
          ) : (
            <Center py={8}>
              <VStack spacing={3}>
                <Icon as={FaExclamationCircle} boxSize={10} color="gray.300" />
                <Text color={subTextColor}>自选列表为空,请搜索添加证券</Text>
              </VStack>
            </Center>
          )}
        </Collapse>
      </CardBody>
    </Card>
  );
};

export default FlexScreen;
|
||||
322
src/views/StockOverview/components/FlexScreen/types.ts
Normal file
322
src/views/StockOverview/components/FlexScreen/types.ts
Normal file
@@ -0,0 +1,322 @@
|
||||
/**
|
||||
* 灵活屏组件类型定义
|
||||
*/
|
||||
|
||||
// ==================== WebSocket 相关类型 ====================
|
||||
|
||||
/** 交易所标识 */
|
||||
export type Exchange = 'SSE' | 'SZSE';
|
||||
|
||||
/** WebSocket 连接状态 */
|
||||
export interface ConnectionStatus {
|
||||
SSE: boolean;
|
||||
SZSE: boolean;
|
||||
}
|
||||
|
||||
/** 盘口档位数据 */
|
||||
export interface OrderBookLevel {
|
||||
price: number;
|
||||
volume: number;
|
||||
}
|
||||
|
||||
// ==================== 行情数据类型 ====================
|
||||
|
||||
/** 盘后交易数据 */
|
||||
export interface AfterhoursData {
|
||||
bidPx: number;
|
||||
bidSize: number;
|
||||
offerPx: number;
|
||||
offerSize: number;
|
||||
volume: number;
|
||||
amount: number;
|
||||
numTrades: number;
|
||||
}
|
||||
|
||||
/** 基础行情数据 */
|
||||
export interface BaseQuoteData {
|
||||
code: string;
|
||||
name: string;
|
||||
price: number;
|
||||
prevClose: number;
|
||||
open: number;
|
||||
high: number;
|
||||
low: number;
|
||||
volume: number;
|
||||
amount: number;
|
||||
change: number;
|
||||
changePct: number;
|
||||
updateTime?: string;
|
||||
exchange: Exchange;
|
||||
}
|
||||
|
||||
/** 股票行情数据 */
|
||||
export interface StockQuoteData extends BaseQuoteData {
|
||||
numTrades?: number;
|
||||
upperLimit?: number; // 涨停价
|
||||
lowerLimit?: number; // 跌停价
|
||||
bidPrices: number[];
|
||||
bidVolumes: number[];
|
||||
askPrices: number[];
|
||||
askVolumes: number[];
|
||||
tradingPhase?: string;
|
||||
afterhours?: AfterhoursData; // 盘后交易数据
|
||||
}
|
||||
|
||||
/** 指数行情数据 */
|
||||
export interface IndexQuoteData extends BaseQuoteData {
|
||||
close?: number;
|
||||
numTrades?: number;
|
||||
bidPrices: number[];
|
||||
bidVolumes: number[];
|
||||
askPrices: number[];
|
||||
askVolumes: number[];
|
||||
tradingPhase?: string;
|
||||
}
|
||||
|
||||
/** 债券行情数据 */
|
||||
export interface BondQuoteData extends BaseQuoteData {
|
||||
numTrades?: number;
|
||||
weightedAvgPx?: number;
|
||||
bidPrices: number[];
|
||||
bidVolumes: number[];
|
||||
askPrices: number[];
|
||||
askVolumes: number[];
|
||||
tradingPhase?: string;
|
||||
isBond: true;
|
||||
}
|
||||
|
||||
/** 港股行情数据 */
|
||||
export interface HKStockQuoteData extends BaseQuoteData {
|
||||
numTrades?: number;
|
||||
nominalPx?: number; // 按盘价
|
||||
referencePx?: number; // 参考价
|
||||
bidPrices: number[];
|
||||
bidVolumes: number[];
|
||||
askPrices: number[];
|
||||
askVolumes: number[];
|
||||
tradingPhase?: string;
|
||||
isHK: true;
|
||||
}
|
||||
|
||||
/** 统一行情数据类型 */
|
||||
export type QuoteData = StockQuoteData | IndexQuoteData | BondQuoteData | HKStockQuoteData;
|
||||
|
||||
/** 行情数据字典 */
|
||||
export interface QuotesMap {
|
||||
[code: string]: QuoteData;
|
||||
}
|
||||
|
||||
// ==================== 上交所 WebSocket 消息类型 ====================
|
||||
|
||||
/** 上交所行情数据 */
|
||||
export interface SSEQuoteItem {
|
||||
security_id: string;
|
||||
security_name: string;
|
||||
prev_close: number;
|
||||
open_price: number;
|
||||
high_price: number;
|
||||
low_price: number;
|
||||
last_price: number;
|
||||
close_price: number;
|
||||
volume: number;
|
||||
amount: number;
|
||||
bid_prices?: number[];
|
||||
bid_volumes?: number[];
|
||||
ask_prices?: number[];
|
||||
ask_volumes?: number[];
|
||||
trading_status?: string;
|
||||
trade_time?: string;
|
||||
update_time?: string;
|
||||
}
|
||||
|
||||
/** 上交所消息 */
|
||||
export interface SSEMessage {
|
||||
type: 'stock' | 'index' | 'etf' | 'bond' | 'option' | 'subscribed' | 'pong' | 'error';
|
||||
timestamp?: string;
|
||||
data?: Record<string, SSEQuoteItem>;
|
||||
channels?: string[];
|
||||
message?: string;
|
||||
}
|
||||
|
||||
// ==================== 深交所 WebSocket 消息类型 ====================
|
||||
|
||||
/** 深交所数据类别 */
|
||||
export type SZSECategory =
|
||||
| 'stock' // 300111 股票快照
|
||||
| 'bond' // 300211 债券快照
|
||||
| 'afterhours_block' // 300611 盘后定价大宗交易
|
||||
| 'afterhours_trading' // 303711 盘后定价交易
|
||||
| 'hk_stock' // 306311 港股快照
|
||||
| 'index' // 309011 指数快照
|
||||
| 'volume_stats' // 309111 成交量统计
|
||||
| 'fund_nav'; // 309211 基金净值
|
||||
|
||||
/** 深交所股票行情数据 */
|
||||
export interface SZSEStockData {
|
||||
security_id: string;
|
||||
orig_time?: number;
|
||||
channel_no?: number;
|
||||
trading_phase?: string;
|
||||
last_px: number;
|
||||
open_px: number;
|
||||
high_px: number;
|
||||
low_px: number;
|
||||
prev_close: number;
|
||||
volume: number;
|
||||
amount: number;
|
||||
num_trades?: number;
|
||||
upper_limit?: number;
|
||||
lower_limit?: number;
|
||||
bids?: OrderBookLevel[];
|
||||
asks?: OrderBookLevel[];
|
||||
}
|
||||
|
||||
/** 深交所指数行情数据 */
|
||||
export interface SZSEIndexData {
|
||||
security_id: string;
|
||||
orig_time?: number;
|
||||
channel_no?: number;
|
||||
trading_phase?: string;
|
||||
current_index: number;
|
||||
open_index: number;
|
||||
high_index: number;
|
||||
low_index: number;
|
||||
close_index?: number;
|
||||
prev_close: number;
|
||||
volume: number;
|
||||
amount: number;
|
||||
num_trades?: number;
|
||||
}
|
||||
|
||||
/** 深交所债券行情数据 */
|
||||
export interface SZSEBondData {
|
||||
security_id: string;
|
||||
orig_time?: number;
|
||||
channel_no?: number;
|
||||
trading_phase?: string;
|
||||
last_px: number;
|
||||
open_px: number;
|
||||
high_px: number;
|
||||
low_px: number;
|
||||
prev_close: number;
|
||||
weighted_avg_px?: number;
|
||||
volume: number;
|
||||
amount: number;
|
||||
num_trades?: number;
|
||||
auction_volume?: number;
|
||||
auction_amount?: number;
|
||||
}
|
||||
|
||||
/** 深交所港股行情数据 */
|
||||
export interface SZSEHKStockData {
|
||||
security_id: string;
|
||||
orig_time?: number;
|
||||
channel_no?: number;
|
||||
trading_phase?: string;
|
||||
last_px: number;
|
||||
open_px: number;
|
||||
high_px: number;
|
||||
low_px: number;
|
||||
prev_close: number;
|
||||
nominal_px?: number;
|
||||
reference_px?: number;
|
||||
volume: number;
|
||||
amount: number;
|
||||
num_trades?: number;
|
||||
vcm_start_time?: number;
|
||||
vcm_end_time?: number;
|
||||
bids?: OrderBookLevel[];
|
||||
asks?: OrderBookLevel[];
|
||||
}
|
||||
|
||||
/** 深交所盘后交易数据 */
|
||||
export interface SZSEAfterhoursData {
|
||||
security_id: string;
|
||||
orig_time?: number;
|
||||
channel_no?: number;
|
||||
trading_phase?: string;
|
||||
prev_close: number;
|
||||
bid_px: number;
|
||||
bid_size: number;
|
||||
offer_px: number;
|
||||
offer_size: number;
|
||||
volume: number;
|
||||
amount: number;
|
||||
num_trades?: number;
|
||||
}
|
||||
|
||||
/** 深交所实时消息 */
|
||||
export interface SZSERealtimeMessage {
|
||||
type: 'realtime';
|
||||
category: SZSECategory;
|
||||
msg_type?: number;
|
||||
timestamp: string;
|
||||
data: SZSEStockData | SZSEIndexData | SZSEBondData | SZSEHKStockData | SZSEAfterhoursData;
|
||||
}
|
||||
|
||||
/** 深交所快照消息 */
|
||||
export interface SZSESnapshotMessage {
|
||||
type: 'snapshot';
|
||||
timestamp: string;
|
||||
data: {
|
||||
stocks?: SZSEStockData[];
|
||||
indexes?: SZSEIndexData[];
|
||||
bonds?: SZSEBondData[];
|
||||
};
|
||||
}
|
||||
|
||||
/** 深交所消息类型 */
|
||||
export type SZSEMessage = SZSERealtimeMessage | SZSESnapshotMessage | { type: 'pong' };
|
||||
|
||||
// ==================== 组件 Props 类型 ====================
|
||||
|
||||
/** 自选证券项 */
|
||||
export interface WatchlistItem {
|
||||
code: string;
|
||||
name: string;
|
||||
isIndex: boolean;
|
||||
}
|
||||
|
||||
/** Props for the QuoteTile component (one watchlist quote card). */
export interface QuoteTileProps {
  code: string;
  name: string;
  /** Quote may be partially populated while streaming data arrives. */
  quote: Partial<QuoteData>;
  isIndex?: boolean;
  /** Invoked with the security code when the user removes the tile. */
  onRemove?: (code: string) => void;
}
|
||||
|
||||
/** Props for the OrderBookPanel component (bid/ask ladder display). */
export interface OrderBookPanelProps {
  /** Parallel arrays: bidPrices[i] pairs with bidVolumes[i]. */
  bidPrices?: number[];
  bidVolumes?: number[];
  askPrices?: number[];
  askVolumes?: number[];
  /** Previous close, used for up/down coloring. */
  prevClose?: number;
  /** Daily price limits, when applicable. */
  upperLimit?: number;
  lowerLimit?: number;
  /** How many price levels to show initially. */
  defaultLevels?: number;
}
|
||||
|
||||
/** Props for the MiniTimelineChart component (compact intraday sparkline). */
export interface MiniTimelineChartProps {
  code: string;
  isIndex?: boolean;
  /** Baseline for the intraday change coloring. */
  prevClose?: number;
  currentPrice?: number;
  /** Chart height in pixels. */
  height?: number;
}
|
||||
|
||||
/** One intraday (minute-line) data point. */
export interface TimelineDataPoint {
  /** Time label, e.g. "09:30". */
  time: string;
  price: number;
}
|
||||
|
||||
/** Return value of the useRealtimeQuote hook. */
export interface UseRealtimeQuoteReturn {
  /** Latest quotes keyed by security code. */
  quotes: QuotesMap;
  /** Websocket connection status. */
  connected: ConnectionStatus;
  /** Subscribe/unsubscribe a security code on the live feed. */
  subscribe: (code: string) => void;
  unsubscribe: (code: string) => void;
}
|
||||
@@ -0,0 +1,147 @@
|
||||
/**
|
||||
* 异动统计摘要组件
|
||||
* 展示指数统计和异动类型统计
|
||||
*/
|
||||
import React from 'react';
|
||||
import {
|
||||
Box,
|
||||
HStack,
|
||||
VStack,
|
||||
Text,
|
||||
Badge,
|
||||
Icon,
|
||||
Stat,
|
||||
StatLabel,
|
||||
StatNumber,
|
||||
StatHelpText,
|
||||
StatArrow,
|
||||
SimpleGrid,
|
||||
useColorModeValue,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaBolt, FaArrowDown, FaRocket, FaChartLine, FaFire, FaVolumeUp } from 'react-icons/fa';
|
||||
|
||||
/**
|
||||
* 异动类型徽章
|
||||
*/
|
||||
const AlertTypeBadge = ({ type, count }) => {
|
||||
const config = {
|
||||
surge: { label: '急涨', color: 'red', icon: FaBolt },
|
||||
surge_up: { label: '暴涨', color: 'red', icon: FaBolt },
|
||||
surge_down: { label: '暴跌', color: 'green', icon: FaArrowDown },
|
||||
limit_up: { label: '涨停', color: 'orange', icon: FaRocket },
|
||||
rank_jump: { label: '排名跃升', color: 'blue', icon: FaChartLine },
|
||||
volume_spike: { label: '放量', color: 'purple', icon: FaVolumeUp },
|
||||
};
|
||||
|
||||
const cfg = config[type] || { label: type, color: 'gray', icon: FaFire };
|
||||
|
||||
return (
|
||||
<Badge colorScheme={cfg.color} variant="subtle" px={2} py={1} borderRadius="md">
|
||||
<HStack spacing={1}>
|
||||
<Icon as={cfg.icon} boxSize={3} />
|
||||
<Text>{cfg.label}</Text>
|
||||
<Text fontWeight="bold">{count}</Text>
|
||||
</HStack>
|
||||
</Badge>
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* 指数统计卡片
|
||||
*/
|
||||
const IndexStatCard = ({ indexData }) => {
|
||||
const cardBg = useColorModeValue('white', '#1a1a1a');
|
||||
const borderColor = useColorModeValue('gray.200', '#333');
|
||||
const subTextColor = useColorModeValue('gray.600', 'gray.400');
|
||||
|
||||
if (!indexData) return null;
|
||||
|
||||
const changePct = indexData.change_pct || 0;
|
||||
const isUp = changePct >= 0;
|
||||
|
||||
return (
|
||||
<SimpleGrid columns={{ base: 2, md: 4 }} spacing={4}>
|
||||
<Stat size="sm">
|
||||
<StatLabel color={subTextColor}>{indexData.name || '上证指数'}</StatLabel>
|
||||
<StatNumber fontSize="xl" color={isUp ? 'red.500' : 'green.500'}>
|
||||
{indexData.latest_price?.toFixed(2) || '-'}
|
||||
</StatNumber>
|
||||
<StatHelpText mb={0}>
|
||||
<StatArrow type={isUp ? 'increase' : 'decrease'} />
|
||||
{changePct?.toFixed(2)}%
|
||||
</StatHelpText>
|
||||
</Stat>
|
||||
|
||||
<Stat size="sm">
|
||||
<StatLabel color={subTextColor}>最高</StatLabel>
|
||||
<StatNumber fontSize="xl" color="red.500">
|
||||
{indexData.high?.toFixed(2) || '-'}
|
||||
</StatNumber>
|
||||
</Stat>
|
||||
|
||||
<Stat size="sm">
|
||||
<StatLabel color={subTextColor}>最低</StatLabel>
|
||||
<StatNumber fontSize="xl" color="green.500">
|
||||
{indexData.low?.toFixed(2) || '-'}
|
||||
</StatNumber>
|
||||
</Stat>
|
||||
|
||||
<Stat size="sm">
|
||||
<StatLabel color={subTextColor}>振幅</StatLabel>
|
||||
<StatNumber fontSize="xl" color="purple.500">
|
||||
{indexData.high && indexData.low && indexData.prev_close
|
||||
? (((indexData.high - indexData.low) / indexData.prev_close) * 100).toFixed(2) + '%'
|
||||
: '-'}
|
||||
</StatNumber>
|
||||
</Stat>
|
||||
</SimpleGrid>
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* 异动统计摘要
|
||||
* @param {Object} props
|
||||
* @param {Object} props.indexData - 指数数据
|
||||
* @param {Array} props.alerts - 异动数组
|
||||
* @param {Object} props.alertSummary - 异动类型统计
|
||||
*/
|
||||
const AlertSummary = ({ indexData, alerts = [], alertSummary = {} }) => {
|
||||
const cardBg = useColorModeValue('white', '#1a1a1a');
|
||||
const borderColor = useColorModeValue('gray.200', '#333');
|
||||
|
||||
// 如果没有 alertSummary,从 alerts 中统计
|
||||
const summary = alertSummary && Object.keys(alertSummary).length > 0
|
||||
? alertSummary
|
||||
: alerts.reduce((acc, alert) => {
|
||||
const type = alert.alert_type || 'unknown';
|
||||
acc[type] = (acc[type] || 0) + 1;
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
const totalAlerts = alerts.length;
|
||||
|
||||
return (
|
||||
<VStack spacing={4} align="stretch">
|
||||
{/* 指数统计 */}
|
||||
<IndexStatCard indexData={indexData} />
|
||||
|
||||
{/* 异动统计 */}
|
||||
{totalAlerts > 0 && (
|
||||
<HStack spacing={2} flexWrap="wrap">
|
||||
<Text fontSize="sm" color="gray.500" mr={2}>
|
||||
异动 {totalAlerts} 次:
|
||||
</Text>
|
||||
{(summary.surge_up > 0 || summary.surge > 0) && (
|
||||
<AlertTypeBadge type="surge_up" count={(summary.surge_up || 0) + (summary.surge || 0)} />
|
||||
)}
|
||||
{summary.surge_down > 0 && <AlertTypeBadge type="surge_down" count={summary.surge_down} />}
|
||||
{summary.limit_up > 0 && <AlertTypeBadge type="limit_up" count={summary.limit_up} />}
|
||||
{summary.volume_spike > 0 && <AlertTypeBadge type="volume_spike" count={summary.volume_spike} />}
|
||||
{summary.rank_jump > 0 && <AlertTypeBadge type="rank_jump" count={summary.rank_jump} />}
|
||||
</HStack>
|
||||
)}
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
export default AlertSummary;
|
||||
@@ -0,0 +1,356 @@
|
||||
/**
|
||||
* 概念异动列表组件 - V2
|
||||
* 展示当日的概念异动记录,点击可展开显示相关股票
|
||||
*/
|
||||
import React, { useState, useCallback } from 'react';
|
||||
import {
|
||||
Box,
|
||||
VStack,
|
||||
HStack,
|
||||
Text,
|
||||
Badge,
|
||||
Icon,
|
||||
Tooltip,
|
||||
useColorModeValue,
|
||||
Flex,
|
||||
Collapse,
|
||||
Spinner,
|
||||
Progress,
|
||||
Table,
|
||||
Thead,
|
||||
Tbody,
|
||||
Tr,
|
||||
Th,
|
||||
Td,
|
||||
TableContainer,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaArrowUp, FaArrowDown, FaFire, FaChevronDown, FaChevronRight } from 'react-icons/fa';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import axios from 'axios';
|
||||
import { getAlertTypeLabel, formatScore, getScoreColor } from '../utils/chartHelpers';
|
||||
|
||||
/**
|
||||
* 紧凑型异动卡片
|
||||
*/
|
||||
const AlertCard = ({ alert, isExpanded, onToggle, stocks, loadingStocks }) => {
|
||||
const navigate = useNavigate();
|
||||
const bgColor = useColorModeValue('white', '#1a1a1a');
|
||||
const hoverBg = useColorModeValue('gray.50', '#252525');
|
||||
const borderColor = useColorModeValue('gray.200', '#333');
|
||||
const expandedBg = useColorModeValue('purple.50', '#1e1e2e');
|
||||
|
||||
const isUp = alert.alert_type !== 'surge_down';
|
||||
const typeColor = isUp ? 'red' : 'green';
|
||||
const isV2 = alert.is_v2;
|
||||
|
||||
// 点击股票跳转
|
||||
const handleStockClick = (e, stockCode) => {
|
||||
e.stopPropagation();
|
||||
navigate(`/company?scode=${stockCode}`);
|
||||
};
|
||||
|
||||
return (
|
||||
<Box
|
||||
bg={isExpanded ? expandedBg : bgColor}
|
||||
borderRadius="lg"
|
||||
borderWidth="1px"
|
||||
borderColor={isExpanded ? 'purple.400' : borderColor}
|
||||
overflow="hidden"
|
||||
transition="all 0.2s"
|
||||
_hover={{ borderColor: 'purple.300' }}
|
||||
>
|
||||
{/* 主卡片 - 点击展开 */}
|
||||
<Box
|
||||
p={3}
|
||||
cursor="pointer"
|
||||
onClick={onToggle}
|
||||
_hover={{ bg: hoverBg }}
|
||||
>
|
||||
<Flex justify="space-between" align="center">
|
||||
{/* 左侧:名称 + 类型 */}
|
||||
<HStack spacing={2} flex={1} minW={0}>
|
||||
<Icon
|
||||
as={isExpanded ? FaChevronDown : FaChevronRight}
|
||||
color="gray.400"
|
||||
boxSize={3}
|
||||
/>
|
||||
<Icon
|
||||
as={isUp ? FaArrowUp : FaArrowDown}
|
||||
color={`${typeColor}.500`}
|
||||
boxSize={3}
|
||||
/>
|
||||
<Text fontWeight="bold" fontSize="sm" noOfLines={1} flex={1}>
|
||||
{alert.concept_name}
|
||||
</Text>
|
||||
{isV2 && (
|
||||
<Badge colorScheme="purple" size="xs" variant="solid" fontSize="9px" px={1}>
|
||||
V2
|
||||
</Badge>
|
||||
)}
|
||||
</HStack>
|
||||
|
||||
{/* 右侧:分数 */}
|
||||
<Badge
|
||||
px={2}
|
||||
py={0.5}
|
||||
borderRadius="full"
|
||||
bg={getScoreColor(alert.final_score)}
|
||||
color="white"
|
||||
fontSize="xs"
|
||||
fontWeight="bold"
|
||||
ml={2}
|
||||
>
|
||||
{formatScore(alert.final_score)}分
|
||||
</Badge>
|
||||
</Flex>
|
||||
|
||||
{/* 第二行:时间 + 关键指标 */}
|
||||
<Flex mt={2} justify="space-between" align="center" fontSize="xs">
|
||||
<HStack spacing={2} color="gray.500">
|
||||
<Text>{alert.time}</Text>
|
||||
<Badge colorScheme={typeColor} size="sm" variant="subtle">
|
||||
{getAlertTypeLabel(alert.alert_type)}
|
||||
</Badge>
|
||||
{/* 确认率 */}
|
||||
{isV2 && alert.confirm_ratio != null && (
|
||||
<HStack spacing={1}>
|
||||
<Box w="30px" h="4px" bg="gray.200" borderRadius="full" overflow="hidden">
|
||||
<Box
|
||||
w={`${(alert.confirm_ratio || 0) * 100}%`}
|
||||
h="100%"
|
||||
bg={(alert.confirm_ratio || 0) >= 0.8 ? 'green.500' : 'orange.500'}
|
||||
/>
|
||||
</Box>
|
||||
<Text>{Math.round((alert.confirm_ratio || 0) * 100)}%</Text>
|
||||
</HStack>
|
||||
)}
|
||||
</HStack>
|
||||
|
||||
{/* Alpha + Z-Score 简化显示 */}
|
||||
<HStack spacing={3}>
|
||||
{alert.alpha != null && (
|
||||
<Text color={(alert.alpha || 0) >= 0 ? 'red.500' : 'green.500'} fontWeight="medium">
|
||||
α {(alert.alpha || 0) >= 0 ? '+' : ''}{(alert.alpha || 0).toFixed(2)}%
|
||||
</Text>
|
||||
)}
|
||||
{isV2 && alert.alpha_zscore != null && (
|
||||
<Tooltip label={`Alpha Z-Score: ${(alert.alpha_zscore || 0).toFixed(2)}σ`}>
|
||||
<HStack spacing={0.5}>
|
||||
<Box
|
||||
w="24px"
|
||||
h="4px"
|
||||
bg="gray.200"
|
||||
borderRadius="full"
|
||||
overflow="hidden"
|
||||
position="relative"
|
||||
>
|
||||
<Box
|
||||
position="absolute"
|
||||
left={(alert.alpha_zscore || 0) >= 0 ? '50%' : undefined}
|
||||
right={(alert.alpha_zscore || 0) < 0 ? '50%' : undefined}
|
||||
w={`${Math.min(Math.abs(alert.alpha_zscore || 0) / 5 * 50, 50)}%`}
|
||||
h="100%"
|
||||
bg={(alert.alpha_zscore || 0) >= 0 ? 'red.500' : 'green.500'}
|
||||
/>
|
||||
</Box>
|
||||
<Text color={(alert.alpha_zscore || 0) >= 0 ? 'red.400' : 'green.400'}>
|
||||
{(alert.alpha_zscore || 0) >= 0 ? '+' : ''}{(alert.alpha_zscore || 0).toFixed(1)}σ
|
||||
</Text>
|
||||
</HStack>
|
||||
</Tooltip>
|
||||
)}
|
||||
{(alert.limit_up_ratio || 0) > 0.05 && (
|
||||
<HStack spacing={0.5} color="orange.500">
|
||||
<Icon as={FaFire} boxSize={3} />
|
||||
<Text>{Math.round((alert.limit_up_ratio || 0) * 100)}%</Text>
|
||||
</HStack>
|
||||
)}
|
||||
</HStack>
|
||||
</Flex>
|
||||
</Box>
|
||||
|
||||
{/* 展开的股票列表 */}
|
||||
<Collapse in={isExpanded} animateOpacity>
|
||||
<Box
|
||||
borderTopWidth="1px"
|
||||
borderColor={borderColor}
|
||||
p={3}
|
||||
bg={useColorModeValue('gray.50', '#151520')}
|
||||
>
|
||||
{loadingStocks ? (
|
||||
<HStack justify="center" py={4}>
|
||||
<Spinner size="sm" color="purple.500" />
|
||||
<Text fontSize="sm" color="gray.500">加载相关股票...</Text>
|
||||
</HStack>
|
||||
) : stocks && stocks.length > 0 ? (
|
||||
<TableContainer maxH="200px" overflowY="auto">
|
||||
<Table size="sm" variant="simple">
|
||||
<Thead position="sticky" top={0} bg={useColorModeValue('gray.50', '#151520')} zIndex={1}>
|
||||
<Tr>
|
||||
<Th px={2} py={1} fontSize="xs" color="gray.500">股票</Th>
|
||||
<Th px={2} py={1} fontSize="xs" color="gray.500" isNumeric>涨跌</Th>
|
||||
<Th px={2} py={1} fontSize="xs" color="gray.500" maxW="120px">原因</Th>
|
||||
</Tr>
|
||||
</Thead>
|
||||
<Tbody>
|
||||
{stocks.slice(0, 10).map((stock, idx) => {
|
||||
const changePct = stock.change_pct;
|
||||
const hasChange = changePct != null && !isNaN(changePct);
|
||||
return (
|
||||
<Tr
|
||||
key={idx}
|
||||
cursor="pointer"
|
||||
_hover={{ bg: hoverBg }}
|
||||
onClick={(e) => handleStockClick(e, stock.code || stock.stock_code)}
|
||||
>
|
||||
<Td px={2} py={1.5}>
|
||||
<Text fontSize="xs" color="cyan.400" fontWeight="medium">
|
||||
{stock.name || stock.stock_name || '-'}
|
||||
</Text>
|
||||
</Td>
|
||||
<Td px={2} py={1.5} isNumeric>
|
||||
<Text
|
||||
fontSize="xs"
|
||||
fontWeight="bold"
|
||||
color={
|
||||
hasChange && changePct > 0 ? 'red.400' :
|
||||
hasChange && changePct < 0 ? 'green.400' : 'gray.400'
|
||||
}
|
||||
>
|
||||
{hasChange
|
||||
? `${changePct > 0 ? '+' : ''}${changePct.toFixed(2)}%`
|
||||
: '-'
|
||||
}
|
||||
</Text>
|
||||
</Td>
|
||||
<Td px={2} py={1.5} maxW="120px">
|
||||
<Text fontSize="xs" color="gray.500" noOfLines={1}>
|
||||
{stock.reason || '-'}
|
||||
</Text>
|
||||
</Td>
|
||||
</Tr>
|
||||
);
|
||||
})}
|
||||
</Tbody>
|
||||
</Table>
|
||||
{stocks.length > 10 && (
|
||||
<Text fontSize="xs" color="gray.500" textAlign="center" mt={2}>
|
||||
共 {stocks.length} 只相关股票,显示前 10 只
|
||||
</Text>
|
||||
)}
|
||||
</TableContainer>
|
||||
) : (
|
||||
<Text fontSize="sm" color="gray.500" textAlign="center" py={2}>
|
||||
暂无相关股票数据
|
||||
</Text>
|
||||
)}
|
||||
</Box>
|
||||
</Collapse>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
/**
 * Concept alert list: shows the day's alerts newest-first; clicking a card
 * expands it and lazily fetches the concept's related stocks (cached per
 * concept for the component's lifetime).
 */
const ConceptAlertList = ({ alerts = [], onAlertClick, selectedAlert, maxHeight = '400px' }) => {
  // Key of the currently expanded alert ("conceptId-time"), or null.
  const [expandedId, setExpandedId] = useState(null);
  // Cache of fetched stocks, keyed by concept id.
  const [conceptStocks, setConceptStocks] = useState({});
  // Per-concept in-flight flags, used to dedupe concurrent fetches.
  const [loadingConcepts, setLoadingConcepts] = useState({});

  const subTextColor = useColorModeValue('gray.500', 'gray.400');

  // Fetch related stocks for a concept, with a secondary ES endpoint as
  // fallback; on total failure caches an empty list so we don't retry forever.
  const fetchConceptStocks = useCallback(async (conceptId) => {
    // Skip when already cached or a request is in flight.
    if (conceptStocks[conceptId] || loadingConcepts[conceptId]) {
      return;
    }

    setLoadingConcepts(prev => ({ ...prev, [conceptId]: true }));

    try {
      // Primary backend API for the concept's constituent stocks.
      const response = await axios.get(`/api/concept/${conceptId}/stocks`);
      if (response.data?.success && response.data?.data?.stocks) {
        setConceptStocks(prev => ({
          ...prev,
          [conceptId]: response.data.data.stocks
        }));
      }
    } catch (error) {
      console.error('获取概念股票失败:', error);
      // Fallback: try fetching directly from the ES-backed endpoint.
      try {
        const esResponse = await axios.get(`/api/es/concept/${conceptId}`);
        if (esResponse.data?.stocks) {
          setConceptStocks(prev => ({
            ...prev,
            [conceptId]: esResponse.data.stocks
          }));
        }
      } catch (esError) {
        console.error('ES 获取也失败:', esError);
        setConceptStocks(prev => ({ ...prev, [conceptId]: [] }));
      }
    } finally {
      setLoadingConcepts(prev => ({ ...prev, [conceptId]: false }));
    }
  }, [conceptStocks, loadingConcepts]);

  // Toggle expansion; expanding triggers the lazy stock fetch and always
  // notifies the parent via onAlertClick.
  const handleToggle = useCallback((alert) => {
    // NOTE(review): key is concept_id + time, so the same concept alerting at
    // two times yields two independently expandable cards — appears intended.
    const alertKey = `${alert.concept_id}-${alert.time}`;

    if (expandedId === alertKey) {
      setExpandedId(null);
    } else {
      setExpandedId(alertKey);
      // Fetch related stocks on first expansion.
      if (alert.concept_id) {
        fetchConceptStocks(alert.concept_id);
      }
    }

    // Notify parent component.
    onAlertClick?.(alert);
  }, [expandedId, fetchConceptStocks, onAlertClick]);

  if (!alerts || alerts.length === 0) {
    return (
      <Box p={4} textAlign="center">
        <Text color={subTextColor} fontSize="sm">
          当日暂无概念异动
        </Text>
      </Box>
    );
  }

  // Sort newest first; lexicographic compare works for "HH:MM" strings.
  const sortedAlerts = [...alerts].sort((a, b) => {
    const timeA = a.time || '00:00';
    const timeB = b.time || '00:00';
    return timeB.localeCompare(timeA);
  });

  return (
    <Box maxH={maxHeight} overflowY="auto" pr={1}>
      <VStack spacing={2} align="stretch">
        {sortedAlerts.map((alert, idx) => {
          const alertKey = `${alert.concept_id}-${alert.time}`;
          return (
            <AlertCard
              key={alertKey || idx}
              alert={alert}
              isExpanded={expandedId === alertKey}
              onToggle={() => handleToggle(alert)}
              stocks={conceptStocks[alert.concept_id]}
              loadingStocks={loadingConcepts[alert.concept_id]}
            />
          );
        })}
      </VStack>
    </Box>
  );
};
|
||||
|
||||
export default ConceptAlertList;
|
||||
@@ -0,0 +1,264 @@
|
||||
/**
|
||||
* 指数分时图组件
|
||||
* 展示大盘分时走势,支持概念异动标注
|
||||
*/
|
||||
import React, { useRef, useEffect, useCallback, useMemo } from 'react';
|
||||
import { Box, useColorModeValue } from '@chakra-ui/react';
|
||||
import * as echarts from 'echarts';
|
||||
import { getAlertMarkPoints } from '../utils/chartHelpers';
|
||||
|
||||
/**
 * Index minute-line (intraday) chart with concept-alert markers.
 *
 * Renders an ECharts line+volume layout: grid 0 holds the price line with
 * alert markPoints, grid 1 holds the volume bars colored by per-minute change.
 *
 * @param {Object} props
 * @param {Object} props.indexData - index data { timeline, prev_close, name, ... }
 * @param {Array} props.alerts - alert records; matched to the timeline by `time`
 * @param {Function} props.onAlertClick - callback when an alert marker is clicked
 * @param {string} props.height - chart height (CSS value)
 */
const IndexMinuteChart = ({ indexData, alerts = [], onAlertClick, height = '350px' }) => {
  const chartRef = useRef(null);
  // Lazily-created ECharts instance, disposed on unmount.
  const chartInstance = useRef(null);

  // NOTE(review): textColor is computed but never used below — candidate for removal.
  const textColor = useColorModeValue('gray.800', 'white');
  const subTextColor = useColorModeValue('gray.600', 'gray.400');
  const gridLineColor = useColorModeValue('#eee', '#333');

  // Build the full ECharts option; null when there is nothing to draw.
  const chartOption = useMemo(() => {
    if (!indexData || !indexData.timeline || indexData.timeline.length === 0) {
      return null;
    }

    const timeline = indexData.timeline || [];
    const times = timeline.map((d) => d.time);
    const prices = timeline.map((d) => d.price);
    const volumes = timeline.map((d) => d.volume);
    const changePcts = timeline.map((d) => d.change_pct);

    // Y-axis range from non-falsy prices (filter(Boolean) also drops 0 —
    // acceptable for index levels, which are never 0).
    const validPrices = prices.filter(Boolean);
    if (validPrices.length === 0) return null;

    const priceMin = Math.min(...validPrices);
    const priceMax = Math.max(...validPrices);
    const priceRange = priceMax - priceMin;
    const yAxisMin = priceMin - priceRange * 0.1;
    const yAxisMax = priceMax + priceRange * 0.25; // extra headroom for markers

    // Alert markers placed on the price series.
    const markPoints = getAlertMarkPoints(alerts, times, prices, priceMax);

    // Gradient color follows the latest direction (red up / green down, CN convention).
    const latestChangePct = changePcts[changePcts.length - 1] || 0;
    const isUp = latestChangePct >= 0;
    const lineColor = isUp ? '#ff4d4d' : '#22c55e';
    const areaColorStops = isUp
      ? [
          { offset: 0, color: 'rgba(255, 77, 77, 0.4)' },
          { offset: 1, color: 'rgba(255, 77, 77, 0.05)' },
        ]
      : [
          { offset: 0, color: 'rgba(34, 197, 94, 0.4)' },
          { offset: 1, color: 'rgba(34, 197, 94, 0.05)' },
        ];

    return {
      backgroundColor: 'transparent',
      tooltip: {
        trigger: 'axis',
        axisPointer: {
          type: 'cross',
          crossStyle: { color: '#999' },
        },
        // Custom tooltip: price/change/volume plus any alerts at that minute.
        formatter: (params) => {
          if (!params || params.length === 0) return '';

          const dataIndex = params[0].dataIndex;
          const time = times[dataIndex];
          const price = prices[dataIndex];
          const changePct = changePcts[dataIndex];
          const volume = volumes[dataIndex];

          let html = `
            <div style="padding: 8px;">
              <div style="font-weight: bold; margin-bottom: 4px;">${time}</div>
              <div>指数: <span style="color: ${changePct >= 0 ? '#ff4d4d' : '#22c55e'}; font-weight: bold;">${price?.toFixed(2)}</span></div>
              <div>涨跌: <span style="color: ${changePct >= 0 ? '#ff4d4d' : '#22c55e'};">${changePct >= 0 ? '+' : ''}${changePct?.toFixed(2)}%</span></div>
              <div>成交量: ${(volume / 10000).toFixed(0)}万手</div>
            </div>
          `;

          // Append concept alerts matching this exact minute.
          const alertsAtTime = alerts.filter((a) => a.time === time);
          if (alertsAtTime.length > 0) {
            html += '<div style="border-top: 1px solid #eee; margin-top: 4px; padding-top: 4px;">';
            html += '<div style="font-weight: bold; color: #ff6b6b;">概念异动:</div>';
            alertsAtTime.forEach((alert) => {
              const typeLabel = {
                surge: '急涨',
                surge_up: '暴涨',
                surge_down: '暴跌',
                limit_up: '涨停增加',
                rank_jump: '排名跃升',
                volume_spike: '放量',
              }[alert.alert_type] || alert.alert_type;
              const typeColor = alert.alert_type === 'surge_down' ? '#2ed573' : '#ff6b6b';
              const alpha = alert.alpha ? ` (α${alert.alpha > 0 ? '+' : ''}${alert.alpha.toFixed(2)}%)` : '';
              html += `<div style="color: ${typeColor}">• ${alert.concept_name} (${typeLabel}${alpha})</div>`;
            });
            html += '</div>';
          }

          return html;
        },
      },
      legend: { show: false },
      // Two stacked grids: price (top) and volume (bottom).
      grid: [
        { left: '8%', right: '3%', top: '8%', height: '58%' },
        { left: '8%', right: '3%', top: '72%', height: '18%' },
      ],
      xAxis: [
        {
          type: 'category',
          data: times,
          axisLine: { lineStyle: { color: gridLineColor } },
          axisLabel: {
            color: subTextColor,
            fontSize: 10,
            // Show ~6 evenly spaced time labels.
            interval: Math.floor(times.length / 6),
          },
          axisTick: { show: false },
          splitLine: { show: false },
        },
        {
          type: 'category',
          gridIndex: 1,
          data: times,
          axisLine: { lineStyle: { color: gridLineColor } },
          axisLabel: { show: false },
          axisTick: { show: false },
          splitLine: { show: false },
        },
      ],
      yAxis: [
        {
          type: 'value',
          min: yAxisMin,
          max: yAxisMax,
          axisLine: { show: false },
          axisLabel: {
            color: subTextColor,
            fontSize: 10,
            formatter: (val) => val.toFixed(0),
          },
          splitLine: { lineStyle: { color: gridLineColor, type: 'dashed' } },
          axisPointer: {
            label: {
              // Crosshair label shows absolute value plus % vs previous close.
              formatter: (params) => {
                if (!indexData.prev_close) return params.value.toFixed(2);
                const pct = ((params.value - indexData.prev_close) / indexData.prev_close) * 100;
                return `${params.value.toFixed(2)} (${pct >= 0 ? '+' : ''}${pct.toFixed(2)}%)`;
              },
            },
          },
        },
        {
          type: 'value',
          gridIndex: 1,
          axisLine: { show: false },
          axisLabel: { show: false },
          splitLine: { show: false },
        },
      ],
      series: [
        // Price line with alert markers.
        {
          name: indexData.name || '上证指数',
          type: 'line',
          data: prices,
          smooth: true,
          symbol: 'none',
          lineStyle: { color: lineColor, width: 1.5 },
          areaStyle: {
            color: new echarts.graphic.LinearGradient(0, 0, 0, 1, areaColorStops),
          },
          markPoint: {
            symbol: 'pin',
            symbolSize: 40,
            data: markPoints,
            animation: true,
          },
        },
        // Volume bars, colored by per-minute change direction.
        {
          name: '成交量',
          type: 'bar',
          xAxisIndex: 1,
          yAxisIndex: 1,
          data: volumes.map((v, i) => ({
            value: v,
            itemStyle: {
              color: changePcts[i] >= 0 ? 'rgba(255, 77, 77, 0.6)' : 'rgba(34, 197, 94, 0.6)',
            },
          })),
          barWidth: '60%',
        },
      ],
    };
  }, [indexData, alerts, subTextColor, gridLineColor]);

  // (Re-)render the chart and wire the markPoint click handler.
  const renderChart = useCallback(() => {
    if (!chartRef.current || !chartOption) return;

    if (!chartInstance.current) {
      chartInstance.current = echarts.init(chartRef.current);
    }

    // `true` = notMerge: replace the whole option to avoid stale series.
    chartInstance.current.setOption(chartOption, true);

    // Click on an alert marker bubbles its alertData to the parent.
    if (onAlertClick) {
      chartInstance.current.off('click');
      chartInstance.current.on('click', 'series.line.markPoint', (params) => {
        if (params.data && params.data.alertData) {
          onAlertClick(params.data.alertData);
        }
      });
    }
  }, [chartOption, onAlertClick]);

  // Re-render whenever data changes.
  useEffect(() => {
    renderChart();
  }, [renderChart]);

  // Resize with the window; dispose the instance on unmount.
  useEffect(() => {
    const handleResize = () => {
      if (chartInstance.current) {
        chartInstance.current.resize();
      }
    };

    window.addEventListener('resize', handleResize);
    return () => {
      window.removeEventListener('resize', handleResize);
      if (chartInstance.current) {
        chartInstance.current.dispose();
        chartInstance.current = null;
      }
    };
  }, []);

  if (!chartOption) {
    return (
      <Box h={height} display="flex" alignItems="center" justifyContent="center" color={subTextColor}>
        暂无数据
      </Box>
    );
  }

  return <Box ref={chartRef} h={height} w="100%" />;
};
|
||||
|
||||
export default IndexMinuteChart;
|
||||
@@ -0,0 +1,3 @@
|
||||
export { default as IndexMinuteChart } from './IndexMinuteChart';
|
||||
export { default as ConceptAlertList } from './ConceptAlertList';
|
||||
export { default as AlertSummary } from './AlertSummary';
|
||||
@@ -0,0 +1 @@
|
||||
export { useHotspotData } from './useHotspotData';
|
||||
@@ -0,0 +1,53 @@
|
||||
/**
|
||||
* 热点概览数据获取 Hook
|
||||
* 负责获取指数分时数据和概念异动数据
|
||||
*/
|
||||
import { useState, useEffect, useCallback } from 'react';
|
||||
import { logger } from '@utils/logger';
|
||||
|
||||
/**
 * Hotspot-overview data hook: fetches index minute data and concept alerts
 * for a given trading day.
 *
 * @param {Date|null} selectedDate - selected trading date; null = today (server default)
 * @returns {{loading: boolean, error: string|null, data: Object|null, refetch: Function}}
 */
export const useHotspotData = (selectedDate) => {
  const [loading, setLoading] = useState(true);
  const [error, setError] = useState(null);
  const [data, setData] = useState(null);

  const fetchData = useCallback(async () => {
    setLoading(true);
    setError(null);

    try {
      // FIX: format the date in LOCAL time. The previous
      // `toISOString().split('T')[0]` converted to UTC first, which shifts
      // the calendar date by one day for any local time before the UTC
      // offset (e.g. before 08:00 in UTC+8) — requesting the wrong trading day.
      let dateParam = '';
      if (selectedDate) {
        const y = selectedDate.getFullYear();
        const m = String(selectedDate.getMonth() + 1).padStart(2, '0');
        const d = String(selectedDate.getDate()).padStart(2, '0');
        dateParam = `?date=${y}-${m}-${d}`;
      }
      const response = await fetch(`/api/market/hotspot-overview${dateParam}`);
      const result = await response.json();

      if (result.success) {
        setData(result.data);
      } else {
        setError(result.error || '获取数据失败');
      }
    } catch (err) {
      logger.error('useHotspotData', 'fetchData', err);
      setError('网络请求失败');
    } finally {
      setLoading(false);
    }
  }, [selectedDate]);

  // Refetch whenever the selected date changes.
  useEffect(() => {
    fetchData();
  }, [fetchData]);

  return {
    loading,
    error,
    data,
    refetch: fetchData,
  };
};
|
||||
|
||||
export default useHotspotData;
|
||||
198
src/views/StockOverview/components/HotspotOverview/index.js
Normal file
198
src/views/StockOverview/components/HotspotOverview/index.js
Normal file
@@ -0,0 +1,198 @@
|
||||
/**
|
||||
* 热点概览组件
|
||||
* 展示大盘分时走势 + 概念异动标注
|
||||
*
|
||||
* 模块化结构:
|
||||
* - hooks/useHotspotData.js - 数据获取
|
||||
* - components/IndexMinuteChart.js - 分时图
|
||||
* - components/ConceptAlertList.js - 异动列表
|
||||
* - components/AlertSummary.js - 统计摘要
|
||||
* - utils/chartHelpers.js - 图表辅助函数
|
||||
*/
|
||||
import React, { useState, useCallback } from 'react';
|
||||
import {
|
||||
Box,
|
||||
Card,
|
||||
CardBody,
|
||||
Heading,
|
||||
Text,
|
||||
HStack,
|
||||
VStack,
|
||||
Spinner,
|
||||
Center,
|
||||
Icon,
|
||||
Flex,
|
||||
Spacer,
|
||||
Tooltip,
|
||||
useColorModeValue,
|
||||
Grid,
|
||||
GridItem,
|
||||
Divider,
|
||||
IconButton,
|
||||
Collapse,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaFire, FaList, FaChartArea, FaChevronDown, FaChevronUp } from 'react-icons/fa';
|
||||
import { InfoIcon } from '@chakra-ui/icons';
|
||||
|
||||
import { useHotspotData } from './hooks';
|
||||
import { IndexMinuteChart, ConceptAlertList, AlertSummary } from './components';
|
||||
|
||||
/**
|
||||
* 热点概览主组件
|
||||
* @param {Object} props
|
||||
* @param {Date|null} props.selectedDate - 选中的交易日期
|
||||
*/
|
||||
const HotspotOverview = ({ selectedDate }) => {
|
||||
const [selectedAlert, setSelectedAlert] = useState(null);
|
||||
const [showAlertList, setShowAlertList] = useState(true);
|
||||
|
||||
// 获取数据
|
||||
const { loading, error, data } = useHotspotData(selectedDate);
|
||||
|
||||
// 颜色主题
|
||||
const cardBg = useColorModeValue('white', '#1a1a1a');
|
||||
const borderColor = useColorModeValue('gray.200', '#333333');
|
||||
const textColor = useColorModeValue('gray.800', 'white');
|
||||
const subTextColor = useColorModeValue('gray.600', 'gray.400');
|
||||
|
||||
// 点击异动标注
|
||||
const handleAlertClick = useCallback((alert) => {
|
||||
setSelectedAlert(alert);
|
||||
// 可以在这里添加滚动到对应位置的逻辑
|
||||
}, []);
|
||||
|
||||
// 渲染加载状态
|
||||
if (loading) {
|
||||
return (
|
||||
<Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
|
||||
<CardBody>
|
||||
<Center h="400px">
|
||||
<VStack spacing={4}>
|
||||
<Spinner size="xl" color="purple.500" thickness="4px" />
|
||||
<Text color={subTextColor}>加载热点概览数据...</Text>
|
||||
</VStack>
|
||||
</Center>
|
||||
</CardBody>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
// 渲染错误状态
|
||||
if (error) {
|
||||
return (
|
||||
<Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
|
||||
<CardBody>
|
||||
<Center h="400px">
|
||||
<VStack spacing={4}>
|
||||
<Icon as={InfoIcon} boxSize={10} color="red.400" />
|
||||
<Text color="red.500">{error}</Text>
|
||||
</VStack>
|
||||
</Center>
|
||||
</CardBody>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
// 无数据
|
||||
if (!data) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { index, alerts, alert_summary } = data;
|
||||
|
||||
return (
|
||||
<Card bg={cardBg} borderWidth="1px" borderColor={borderColor}>
|
||||
<CardBody>
|
||||
{/* 头部 */}
|
||||
<Flex align="center" mb={4}>
|
||||
<HStack spacing={3}>
|
||||
<Icon as={FaFire} boxSize={6} color="orange.500" />
|
||||
<Heading size="md" color={textColor}>
|
||||
热点概览
|
||||
</Heading>
|
||||
</HStack>
|
||||
<Spacer />
|
||||
<HStack spacing={2}>
|
||||
<Tooltip label={showAlertList ? '收起异动列表' : '展开异动列表'}>
|
||||
<IconButton
|
||||
icon={showAlertList ? <FaChevronUp /> : <FaList />}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
onClick={() => setShowAlertList(!showAlertList)}
|
||||
aria-label="切换异动列表"
|
||||
/>
|
||||
</Tooltip>
|
||||
<Tooltip label="展示大盘走势与概念异动的关联">
|
||||
<Icon as={InfoIcon} color={subTextColor} />
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
</Flex>
|
||||
|
||||
{/* 统计摘要 */}
|
||||
<Box mb={4}>
|
||||
<AlertSummary indexData={index} alerts={alerts} alertSummary={alert_summary} />
|
||||
</Box>
|
||||
|
||||
<Divider mb={4} />
|
||||
|
||||
{/* 主体内容:图表 + 异动列表 */}
|
||||
<Grid
|
||||
templateColumns={{ base: '1fr', lg: showAlertList ? '1fr 300px' : '1fr' }}
|
||||
gap={4}
|
||||
>
|
||||
{/* 分时图 */}
|
||||
<GridItem>
|
||||
<Box>
|
||||
<HStack spacing={2} mb={2}>
|
||||
<Icon as={FaChartArea} color="purple.500" boxSize={4} />
|
||||
<Text fontSize="sm" fontWeight="medium" color={textColor}>
|
||||
大盘分时走势
|
||||
</Text>
|
||||
</HStack>
|
||||
<IndexMinuteChart
|
||||
indexData={index}
|
||||
alerts={alerts}
|
||||
onAlertClick={handleAlertClick}
|
||||
height="350px"
|
||||
/>
|
||||
</Box>
|
||||
</GridItem>
|
||||
|
||||
{/* 异动列表(可收起) */}
|
||||
<Collapse in={showAlertList} animateOpacity>
|
||||
<GridItem>
|
||||
<Box>
|
||||
<HStack spacing={2} mb={2}>
|
||||
<Icon as={FaList} color="orange.500" boxSize={4} />
|
||||
<Text fontSize="sm" fontWeight="medium" color={textColor}>
|
||||
异动记录
|
||||
</Text>
|
||||
<Text fontSize="xs" color={subTextColor}>
|
||||
({alerts.length})
|
||||
</Text>
|
||||
</HStack>
|
||||
<ConceptAlertList
|
||||
alerts={alerts}
|
||||
onAlertClick={handleAlertClick}
|
||||
selectedAlert={selectedAlert}
|
||||
maxHeight="350px"
|
||||
/>
|
||||
</Box>
|
||||
</GridItem>
|
||||
</Collapse>
|
||||
</Grid>
|
||||
|
||||
{/* 无异动提示 */}
|
||||
{alerts.length === 0 && (
|
||||
<Center py={4}>
|
||||
<Text color={subTextColor} fontSize="sm">
|
||||
当日暂无概念异动数据
|
||||
</Text>
|
||||
</Center>
|
||||
)}
|
||||
</CardBody>
|
||||
</Card>
|
||||
);
|
||||
};
|
||||
|
||||
export default HotspotOverview;
|
||||
@@ -0,0 +1,160 @@
|
||||
/**
|
||||
* 图表辅助函数
|
||||
* 用于处理异动标注等图表相关逻辑
|
||||
*/
|
||||
|
||||
/**
|
||||
* 获取异动标注的配色和符号
|
||||
* @param {string} alertType - 异动类型
|
||||
* @param {number} importanceScore - 重要性得分
|
||||
* @returns {Object} { color, symbol, symbolSize }
|
||||
*/
|
||||
export const getAlertStyle = (alertType, importanceScore = 0.5) => {
|
||||
let color = '#ff6b6b';
|
||||
let symbol = 'pin';
|
||||
let symbolSize = 35;
|
||||
|
||||
switch (alertType) {
|
||||
case 'surge_up':
|
||||
case 'surge':
|
||||
color = '#ff4757';
|
||||
symbol = 'triangle';
|
||||
symbolSize = 30 + Math.min(importanceScore * 20, 15);
|
||||
break;
|
||||
case 'surge_down':
|
||||
color = '#2ed573';
|
||||
symbol = 'path://M0,0 L10,0 L5,10 Z'; // 向下三角形
|
||||
symbolSize = 30 + Math.min(importanceScore * 20, 15);
|
||||
break;
|
||||
case 'limit_up':
|
||||
color = '#ff6348';
|
||||
symbol = 'diamond';
|
||||
symbolSize = 28;
|
||||
break;
|
||||
case 'rank_jump':
|
||||
color = '#3742fa';
|
||||
symbol = 'circle';
|
||||
symbolSize = 25;
|
||||
break;
|
||||
case 'volume_spike':
|
||||
color = '#ffa502';
|
||||
symbol = 'rect';
|
||||
symbolSize = 25;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return { color, symbol, symbolSize };
|
||||
};
|
||||
|
||||
/**
|
||||
* 获取异动类型的显示标签
|
||||
* @param {string} alertType - 异动类型
|
||||
* @returns {string} 显示标签
|
||||
*/
|
||||
export const getAlertTypeLabel = (alertType) => {
|
||||
const labels = {
|
||||
surge: '急涨',
|
||||
surge_up: '暴涨',
|
||||
surge_down: '暴跌',
|
||||
limit_up: '涨停增加',
|
||||
rank_jump: '排名跃升',
|
||||
volume_spike: '放量',
|
||||
unknown: '异动',
|
||||
};
|
||||
return labels[alertType] || alertType;
|
||||
};
|
||||
|
||||
/**
 * Build ECharts markPoint entries for concept alerts on a minute chart.
 * @param {Array} alerts - alert records
 * @param {Array} times - x-axis time labels of the minute chart
 * @param {Array} prices - prices aligned with `times`
 * @param {number} priceMax - fallback y-position when a time cannot be matched
 * @param {number} maxCount - maximum number of markers to show
 * @returns {Array} ECharts markPoint data items
 */
export const getAlertMarkPoints = (alerts, times, prices, priceMax, maxCount = 15) => {
  if (!alerts || alerts.length === 0) return [];

  // Ranking score: prefer final_score, fall back to importance_score.
  const rankOf = (a) => a.final_score || a.importance_score || 0;

  // Keep only the most important alerts (copy first; sort mutates).
  const topAlerts = alerts
    .slice()
    .sort((a, b) => rankOf(b) - rankOf(a))
    .slice(0, maxCount);

  return topAlerts.map((alert) => {
    // Anchor the marker at the matching minute; otherwise fall back to the
    // alert's own recorded price, then to the chart's max price.
    const idx = times.indexOf(alert.time);
    const yValue = idx >= 0 ? prices[idx] : (alert.index_price || priceMax);

    const style = getAlertStyle(
      alert.alert_type,
      alert.final_score / 100 || alert.importance_score || 0.5
    );

    // Truncate long concept names to keep labels compact.
    let text = alert.concept_name || '';
    if (text.length > 6) {
      text = text.substring(0, 5) + '...';
    }

    // Append the limit-up count when present.
    if (alert.limit_up_count > 0) {
      text += `\n涨停: ${alert.limit_up_count}`;
    }

    const pointsDown = alert.alert_type === 'surge_down';

    return {
      name: alert.concept_name,
      coord: [alert.time, yValue],
      value: text,
      symbol: style.symbol,
      symbolSize: style.symbolSize,
      itemStyle: {
        color: style.color,
        borderColor: '#fff',
        borderWidth: 1,
        shadowBlur: 3,
        shadowColor: 'rgba(0,0,0,0.2)',
      },
      label: {
        show: true,
        // Down alerts label below the marker, others above.
        position: pointsDown ? 'bottom' : 'top',
        formatter: '{b}',
        fontSize: 9,
        color: '#333',
        backgroundColor: pointsDown ? 'rgba(46, 213, 115, 0.9)' : 'rgba(255,255,255,0.9)',
        padding: [2, 4],
        borderRadius: 2,
        borderColor: style.color,
        borderWidth: 1,
      },
      alertData: alert, // keep the raw alert for click handlers
    };
  });
};
|
||||
|
||||
/**
|
||||
* 格式化分数显示
|
||||
* @param {number} score - 分数
|
||||
* @returns {string} 格式化后的分数
|
||||
*/
|
||||
export const formatScore = (score) => {
|
||||
if (score === null || score === undefined) return '-';
|
||||
return Math.round(score).toString();
|
||||
};
|
||||
|
||||
/**
|
||||
* 获取分数对应的颜色
|
||||
* @param {number} score - 分数 (0-100)
|
||||
* @returns {string} 颜色代码
|
||||
*/
|
||||
export const getScoreColor = (score) => {
|
||||
const s = score || 0;
|
||||
if (s >= 80) return '#ff4757';
|
||||
if (s >= 60) return '#ff6348';
|
||||
if (s >= 40) return '#ffa502';
|
||||
return '#747d8c';
|
||||
};
|
||||
@@ -0,0 +1 @@
|
||||
export * from './chartHelpers';
|
||||
@@ -50,9 +50,11 @@ import {
|
||||
SkeletonText,
|
||||
} from '@chakra-ui/react';
|
||||
import { SearchIcon, CloseIcon, ArrowForwardIcon, TrendingUpIcon, InfoIcon, ChevronRightIcon, CalendarIcon } from '@chakra-ui/icons';
|
||||
import { FaChartLine, FaFire, FaRocket, FaBrain, FaCalendarAlt, FaChevronRight, FaArrowUp, FaArrowDown, FaChartBar } from 'react-icons/fa';
|
||||
import { FaChartLine, FaFire, FaRocket, FaBrain, FaCalendarAlt, FaChevronRight, FaArrowUp, FaArrowDown, FaChartBar, FaTag, FaLayerGroup, FaBolt } from 'react-icons/fa';
|
||||
import ConceptStocksModal from '@components/ConceptStocksModal';
|
||||
import TradeDatePicker from '@components/TradeDatePicker';
|
||||
import HotspotOverview from './components/HotspotOverview';
|
||||
import FlexScreen from './components/FlexScreen';
|
||||
import { BsGraphUp, BsLightningFill } from 'react-icons/bs';
|
||||
import * as echarts from 'echarts';
|
||||
import { logger } from '../../utils/logger';
|
||||
@@ -840,6 +842,16 @@ const StockOverview = () => {
|
||||
)}
|
||||
</Box>
|
||||
|
||||
{/* 热点概览 - 大盘走势 + 概念异动 */}
|
||||
<Box mb={10}>
|
||||
<HotspotOverview selectedDate={selectedDate} />
|
||||
</Box>
|
||||
|
||||
{/* 灵活屏 - 实时行情监控 */}
|
||||
<Box mb={10}>
|
||||
<FlexScreen />
|
||||
</Box>
|
||||
|
||||
{/* 今日热门概念 */}
|
||||
<Box mb={10}>
|
||||
<Flex align="center" mb={6}>
|
||||
@@ -927,16 +939,65 @@ const StockOverview = () => {
|
||||
|
||||
<CardBody pt={12}>
|
||||
<VStack align="start" spacing={3}>
|
||||
{/* 概念名称 */}
|
||||
<Heading size="md" noOfLines={1} color={textColor}>
|
||||
{concept.concept_name}
|
||||
</Heading>
|
||||
|
||||
{/* 层级信息 */}
|
||||
{concept.hierarchy && (
|
||||
<HStack spacing={1} flexWrap="wrap">
|
||||
<Icon as={FaLayerGroup} boxSize={3} color="gray.400" />
|
||||
<Text fontSize="xs" color="gray.500">
|
||||
{[concept.hierarchy.lv1, concept.hierarchy.lv2, concept.hierarchy.lv3]
|
||||
.filter(Boolean)
|
||||
.join(' > ')}
|
||||
</Text>
|
||||
</HStack>
|
||||
)}
|
||||
|
||||
{/* 描述 */}
|
||||
<Text fontSize="sm" color={subTextColor} noOfLines={2}>
|
||||
{concept.description || '暂无描述'}
|
||||
</Text>
|
||||
|
||||
{/* 标签 */}
|
||||
{concept.tags && concept.tags.length > 0 && (
|
||||
<Flex flexWrap="wrap" gap={1}>
|
||||
{concept.tags.slice(0, 4).map((tag, idx) => (
|
||||
<Tag
|
||||
key={idx}
|
||||
size="sm"
|
||||
variant="outline"
|
||||
colorScheme="blue"
|
||||
borderRadius="full"
|
||||
>
|
||||
<Icon as={FaTag} boxSize={2} mr={1} />
|
||||
<TagLabel fontSize="xs">{tag}</TagLabel>
|
||||
</Tag>
|
||||
))}
|
||||
{concept.tags.length > 4 && (
|
||||
<Tag size="sm" variant="ghost" colorScheme="gray">
|
||||
<TagLabel fontSize="xs">+{concept.tags.length - 4}</TagLabel>
|
||||
</Tag>
|
||||
)}
|
||||
</Flex>
|
||||
)}
|
||||
|
||||
{/* 爆发日期 */}
|
||||
{concept.outbreak_dates && concept.outbreak_dates.length > 0 && (
|
||||
<HStack spacing={2} fontSize="xs" color="orange.500">
|
||||
<Icon as={FaBolt} />
|
||||
<Text>
|
||||
近期爆发: {concept.outbreak_dates.slice(0, 2).join(', ')}
|
||||
{concept.outbreak_dates.length > 2 && ` 等${concept.outbreak_dates.length}次`}
|
||||
</Text>
|
||||
</HStack>
|
||||
)}
|
||||
|
||||
<Divider />
|
||||
|
||||
{/* 相关股票 */}
|
||||
<Box
|
||||
w="100%"
|
||||
cursor="pointer"
|
||||
@@ -957,7 +1018,7 @@ const StockOverview = () => {
|
||||
overflow="hidden"
|
||||
maxH="24px"
|
||||
>
|
||||
{concept.stocks.map((stock, idx) => (
|
||||
{concept.stocks.slice(0, 5).map((stock, idx) => (
|
||||
<Tag
|
||||
key={idx}
|
||||
size="sm"
|
||||
@@ -965,9 +1026,14 @@ const StockOverview = () => {
|
||||
variant="subtle"
|
||||
flexShrink={0}
|
||||
>
|
||||
<TagLabel>{stock.stock_name}</TagLabel>
|
||||
<TagLabel>{stock.stock_name || stock.name}</TagLabel>
|
||||
</Tag>
|
||||
))}
|
||||
{concept.stocks.length > 5 && (
|
||||
<Tag size="sm" variant="ghost" colorScheme="gray" flexShrink={0}>
|
||||
<TagLabel>+{concept.stocks.length - 5}</TagLabel>
|
||||
</Tag>
|
||||
)}
|
||||
</Flex>
|
||||
)}
|
||||
</Box>
|
||||
|
||||
378
sse_html.html
Normal file
378
sse_html.html
Normal file
@@ -0,0 +1,378 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>VDE 实时行情 - WebSocket 测试</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.container { max-width: 1400px; margin: 0 auto; }
|
||||
h1 { color: #333; margin-bottom: 20px; }
|
||||
|
||||
.status-bar {
|
||||
background: #fff;
|
||||
padding: 15px 20px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 20px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 20px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
}
|
||||
.status-indicator {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border-radius: 50%;
|
||||
background: #ccc;
|
||||
}
|
||||
.status-indicator.connected { background: #4caf50; }
|
||||
.status-indicator.connecting { background: #ff9800; animation: pulse 1s infinite; }
|
||||
.status-indicator.disconnected { background: #f44336; }
|
||||
@keyframes pulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.5; } }
|
||||
|
||||
.controls {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
margin-left: auto;
|
||||
}
|
||||
button {
|
||||
padding: 8px 16px;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
}
|
||||
button.primary { background: #1976d2; color: white; }
|
||||
button.danger { background: #d32f2f; color: white; }
|
||||
button:hover { opacity: 0.9; }
|
||||
|
||||
.grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(400px, 1fr));
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
.card {
|
||||
background: #fff;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
}
|
||||
.card h2 {
|
||||
margin: 0 0 15px 0;
|
||||
font-size: 16px;
|
||||
color: #666;
|
||||
border-bottom: 1px solid #eee;
|
||||
padding-bottom: 10px;
|
||||
}
|
||||
|
||||
.quote-table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 13px;
|
||||
}
|
||||
.quote-table th, .quote-table td {
|
||||
padding: 8px;
|
||||
text-align: right;
|
||||
border-bottom: 1px solid #f0f0f0;
|
||||
}
|
||||
.quote-table th { text-align: left; color: #999; font-weight: normal; }
|
||||
.quote-table td:first-child { text-align: left; font-weight: 500; }
|
||||
.quote-table .name { color: #666; font-size: 12px; }
|
||||
|
||||
.price-up { color: #f44336; }
|
||||
.price-down { color: #4caf50; }
|
||||
.price-flat { color: #999; }
|
||||
|
||||
.update-flash { animation: flash 0.3s; }
|
||||
@keyframes flash { 0%, 100% { background: transparent; } 50% { background: #fff3e0; } }
|
||||
|
||||
#log {
|
||||
background: #263238;
|
||||
color: #aed581;
|
||||
padding: 15px;
|
||||
border-radius: 8px;
|
||||
font-family: monospace;
|
||||
font-size: 12px;
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
margin-top: 20px;
|
||||
}
|
||||
#log .error { color: #ef5350; }
|
||||
#log .info { color: #4fc3f7; }
|
||||
|
||||
.stats {
|
||||
display: flex;
|
||||
gap: 30px;
|
||||
font-size: 13px;
|
||||
color: #666;
|
||||
}
|
||||
.stats span { font-weight: 500; color: #333; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>VDE 实时行情</h1>
|
||||
|
||||
<div class="status-bar">
|
||||
<div class="status-indicator" id="statusIndicator"></div>
|
||||
<span id="statusText">未连接</span>
|
||||
<div class="stats">
|
||||
<div>更新次数: <span id="updateCount">0</span></div>
|
||||
<div>最后更新: <span id="lastUpdate">-</span></div>
|
||||
</div>
|
||||
<div class="controls">
|
||||
<input type="text" id="wsUrl" value="ws://localhost:8765" style="padding: 8px; border: 1px solid #ddd; border-radius: 4px; width: 200px;">
|
||||
<button class="primary" id="connectBtn" onclick="connect()">连接</button>
|
||||
<button class="danger" id="disconnectBtn" onclick="disconnect()" style="display:none">断开</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="grid">
|
||||
<!-- 指数行情 -->
|
||||
<div class="card">
|
||||
<h2>📊 指数行情</h2>
|
||||
<table class="quote-table" id="indexTable">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>代码/名称</th>
|
||||
<th>最新</th>
|
||||
<th>涨跌</th>
|
||||
<th>涨跌幅</th>
|
||||
<th>成交额(亿)</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody></tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- 股票行情 -->
|
||||
<div class="card">
|
||||
<h2>📈 股票行情 (前20)</h2>
|
||||
<table class="quote-table" id="stockTable">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>代码/名称</th>
|
||||
<th>最新</th>
|
||||
<th>涨跌幅</th>
|
||||
<th>买一</th>
|
||||
<th>卖一</th>
|
||||
<th>成交额(万)</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody></tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 日志 -->
|
||||
<div id="log"></div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
let updateCount = 0;
|
||||
const indexData = {};
|
||||
const stockData = {};
|
||||
|
||||
// Append a timestamped line to the on-page log panel and keep it scrolled
// to the newest entry. `type` maps to a CSS class ('info' | 'error' | '').
function log(msg, type = '') {
  const logEl = document.getElementById('log');
  const time = new Date().toLocaleTimeString();
  logEl.innerHTML += `<div class="${type}">[${time}] ${msg}</div>`;
  logEl.scrollTop = logEl.scrollHeight;  // auto-scroll to the latest line
}
|
||||
|
||||
// Reflect the connection state in the UI: indicator dot, status text, and
// which of the connect/disconnect buttons is visible.
// `status` is one of 'connected' | 'connecting' | 'disconnected'.
function setStatus(status) {
  const indicator = document.getElementById('statusIndicator');
  const text = document.getElementById('statusText');
  const connectBtn = document.getElementById('connectBtn');
  const disconnectBtn = document.getElementById('disconnectBtn');

  // The dot's look (green/orange pulse/red) is driven purely by CSS class.
  indicator.className = 'status-indicator ' + status;

  switch(status) {
    case 'connected':
      text.textContent = '已连接';
      connectBtn.style.display = 'none';
      disconnectBtn.style.display = 'inline-block';
      break;
    case 'connecting':
      // Buttons left as-is while the handshake is in flight.
      text.textContent = '连接中...';
      break;
    case 'disconnected':
      text.textContent = '未连接';
      connectBtn.style.display = 'inline-block';
      disconnectBtn.style.display = 'none';
      break;
  }
}
|
||||
|
||||
// Open a WebSocket to the URL in the #wsUrl input, wire up lifecycle
// handlers, and subscribe to all quote channels once connected.
function connect() {
  const url = document.getElementById('wsUrl').value;
  log(`连接到 ${url}...`, 'info');
  setStatus('connecting');

  try {
    ws = new WebSocket(url);

    ws.onopen = () => {
      log('WebSocket 连接成功', 'info');
      setStatus('connected');

      // Subscribe to every channel right after the handshake.
      ws.send(JSON.stringify({
        action: 'subscribe',
        channels: ['index', 'stock', 'etf']
      }));
    };

    ws.onmessage = (event) => {
      try {
        const msg = JSON.parse(event.data);
        handleMessage(msg);
      } catch(e) {
        // Malformed frame: log and keep the connection alive.
        log('解析消息失败: ' + e, 'error');
      }
    };

    ws.onclose = () => {
      log('连接已断开', 'error');
      setStatus('disconnected');
    };

    ws.onerror = (err) => {
      log('连接错误', 'error');
      setStatus('disconnected');
    };
  } catch(e) {
    // new WebSocket() throws synchronously on a malformed URL.
    log('连接失败: ' + e, 'error');
    setStatus('disconnected');
  }
}
|
||||
|
||||
// Close the WebSocket (if any) and drop the reference so a new
// connection can be opened later.
function disconnect() {
  if (!ws) return;
  ws.close();
  ws = null;
}
|
||||
|
||||
// Dispatch one server message: acknowledge subscription confirmations,
// bump the update counters, and merge quote payloads into the caches.
function handleMessage(msg) {
  if (msg.type === 'subscribed') {
    log(`已订阅: ${msg.channels.join(', ')}`, 'info');
    return;
  }

  updateCount++;
  document.getElementById('updateCount').textContent = updateCount;
  document.getElementById('lastUpdate').textContent = new Date().toLocaleTimeString();

  if (msg.type === 'index') {
    // msg.data is a {security_id: quote} map; merge it into the cache.
    Object.assign(indexData, msg.data);
    renderIndexTable();
  } else if (msg.type === 'stock' || msg.type === 'etf') {
    // Stocks and ETFs share the same table.
    Object.assign(stockData, msg.data);
    renderStockTable();
  }
}
|
||||
|
||||
function formatPrice(price, prevClose) {
|
||||
if (!price || price === 0) return '-';
|
||||
const cls = price > prevClose ? 'price-up' : (price < prevClose ? 'price-down' : 'price-flat');
|
||||
return `<span class="${cls}">${price.toFixed(2)}</span>`;
|
||||
}
|
||||
|
||||
function formatChange(change, changePct) {
|
||||
if (change === undefined) return '-';
|
||||
const cls = change > 0 ? 'price-up' : (change < 0 ? 'price-down' : 'price-flat');
|
||||
const sign = change > 0 ? '+' : '';
|
||||
return `<span class="${cls}">${sign}${change.toFixed(2)}</span>`;
|
||||
}
|
||||
|
||||
function formatChangePct(changePct) {
|
||||
if (changePct === undefined) return '-';
|
||||
const cls = changePct > 0 ? 'price-up' : (changePct < 0 ? 'price-down' : 'price-flat');
|
||||
const sign = changePct > 0 ? '+' : '';
|
||||
return `<span class="${cls}">${sign}${changePct.toFixed(2)}%</span>`;
|
||||
}
|
||||
|
||||
// Rebuild the index-quote table body: important indexes first, then the
// remaining ones, capped at 10 rows in total.
//
// BUG FIX: the previous implementation tested
// `Object.keys(indexData).length > 10` *inside* the second loop — a
// loop-invariant condition, so whenever more than 10 indexes existed it
// broke after rendering just one extra row. We now count rendered rows.
function renderIndexTable() {
  const tbody = document.querySelector('#indexTable tbody');
  const importantIndexes = ['000001', '000002', '000003', '000016', '000300'];
  const maxRows = 10;

  let html = '';
  let rendered = 0;

  // Important indexes always come first, in their defined order.
  for (const code of importantIndexes) {
    if (indexData[code]) {
      html += renderIndexRow(indexData[code]);
      rendered++;
    }
  }

  // Fill the remaining slots with the other indexes.
  for (const [code, data] of Object.entries(indexData)) {
    if (rendered >= maxRows) break;
    if (!importantIndexes.includes(code)) {
      html += renderIndexRow(data);
      rendered++;
    }
  }

  tbody.innerHTML = html;
}
|
||||
|
||||
// Build one <tr> of HTML for an index quote.
// `data` carries security_id/security_name, last_price, prev_close, amount.
function renderIndexRow(data) {
  const change = data.last_price - data.prev_close;
  // Guard against a zero prev_close to avoid division by zero.
  const changePct = data.prev_close > 0 ? (change / data.prev_close * 100) : 0;
  const amountYi = (data.amount / 100000000).toFixed(2);  // turnover in 亿元

  return `
    <tr class="update-flash">
      <td>
        ${data.security_id}
        <div class="name">${data.security_name}</div>
      </td>
      <td>${formatPrice(data.last_price, data.prev_close)}</td>
      <td>${formatChange(change)}</td>
      <td>${formatChangePct(changePct)}</td>
      <td>${amountYi}</td>
    </tr>
  `;
}
|
||||
|
||||
// Rebuild the stock-quote table body from the first 20 cached entries.
function renderStockTable() {
  const tbody = document.querySelector('#stockTable tbody');
  const entries = Object.entries(stockData).slice(0, 20);

  let html = '';
  for (const [code, data] of entries) {
    // Guard against a zero prev_close to avoid division by zero.
    const changePct = data.prev_close > 0 ?
      ((data.last_price - data.prev_close) / data.prev_close * 100) : 0;
    const amountWan = (data.amount / 10000).toFixed(0);  // turnover in 万元

    html += `
      <tr class="update-flash">
        <td>
          ${data.security_id}
          <div class="name">${data.security_name}</div>
        </td>
        <td>${formatPrice(data.last_price, data.prev_close)}</td>
        <td>${formatChangePct(changePct)}</td>
        <td>${data.bid_prices?.[0]?.toFixed(2) || '-'}</td>
        <td>${data.ask_prices?.[0]?.toFixed(2) || '-'}</td>
        <td>${amountWan}</td>
      </tr>
    `;
  }

  tbody.innerHTML = html;
}
|
||||
|
||||
// 页面加载后自动连接
|
||||
// window.onload = connect;
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
289
szse_html.html
Normal file
289
szse_html.html
Normal file
@@ -0,0 +1,289 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>深交所行情 WebSocket 测试</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body { font-family: 'Microsoft YaHei', sans-serif; background: #1a1a2e; color: #eee; padding: 20px; }
|
||||
h1 { text-align: center; margin-bottom: 20px; color: #00d4ff; }
|
||||
.container { max-width: 1400px; margin: 0 auto; }
|
||||
.controls { background: #16213e; padding: 15px; border-radius: 8px; margin-bottom: 20px; display: flex; gap: 10px; align-items: center; flex-wrap: wrap; }
|
||||
input { padding: 8px 12px; border: 1px solid #0f3460; border-radius: 4px; background: #1a1a2e; color: #fff; width: 300px; }
|
||||
button { padding: 8px 20px; border: none; border-radius: 4px; cursor: pointer; font-weight: bold; }
|
||||
.btn-connect { background: #00d4ff; color: #000; }
|
||||
.btn-disconnect { background: #e94560; color: #fff; }
|
||||
.btn-ping { background: #ffc107; color: #000; }
|
||||
.btn-clear { background: #6c757d; color: #fff; }
|
||||
.status { padding: 5px 15px; border-radius: 20px; font-size: 14px; }
|
||||
.status.connected { background: #28a745; }
|
||||
.status.disconnected { background: #dc3545; }
|
||||
.stats { display: grid; grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); gap: 10px; margin-bottom: 20px; }
|
||||
.stat-card { background: #16213e; padding: 15px; border-radius: 8px; text-align: center; }
|
||||
.stat-card .value { font-size: 24px; font-weight: bold; color: #00d4ff; }
|
||||
.stat-card .label { font-size: 12px; color: #888; margin-top: 5px; }
|
||||
.panels { display: grid; grid-template-columns: 1fr 1fr; gap: 20px; }
|
||||
@media (max-width: 900px) { .panels { grid-template-columns: 1fr; } }
|
||||
.panel { background: #16213e; border-radius: 8px; overflow: hidden; }
|
||||
.panel-header { background: #0f3460; padding: 10px 15px; font-weight: bold; display: flex; justify-content: space-between; }
|
||||
.panel-body { height: 400px; overflow-y: auto; padding: 10px; font-family: 'Consolas', monospace; font-size: 12px; }
|
||||
.msg { padding: 5px; border-bottom: 1px solid #0f3460; word-break: break-all; }
|
||||
.msg.snapshot { color: #ffc107; }
|
||||
.msg.stock { color: #28a745; }
|
||||
.msg.index { color: #17a2b8; }
|
||||
.msg.bond { color: #6f42c1; }
|
||||
.msg.hk_stock { color: #fd7e14; }
|
||||
.msg.afterhours_block, .msg.afterhours_trading { color: #e83e8c; }
|
||||
.msg.volume_stats { color: #20c997; }
|
||||
.msg.fund_nav { color: #6610f2; }
|
||||
.msg.pong { color: #adb5bd; }
|
||||
.quote-table { width: 100%; border-collapse: collapse; font-size: 13px; }
|
||||
.quote-table th, .quote-table td { padding: 8px; text-align: right; border-bottom: 1px solid #0f3460; }
|
||||
.quote-table th { background: #0f3460; text-align: center; }
|
||||
.quote-table .code { text-align: left; font-weight: bold; }
|
||||
.quote-table .up { color: #f5222d; }
|
||||
.quote-table .down { color: #52c41a; }
|
||||
.quote-table .flat { color: #888; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>深交所行情 WebSocket 测试</h1>
|
||||
|
||||
<div class="controls">
|
||||
<input type="text" id="wsUrl" value="ws://222.128.1.157:8765" placeholder="WebSocket URL">
|
||||
<button class="btn-connect" onclick="connect()">连接</button>
|
||||
<button class="btn-disconnect" onclick="disconnect()">断开</button>
|
||||
<button class="btn-ping" onclick="sendPing()">Ping</button>
|
||||
<button class="btn-clear" onclick="clearLogs()">清空日志</button>
|
||||
<span id="status" class="status disconnected">未连接</span>
|
||||
</div>
|
||||
|
||||
<div class="stats">
|
||||
<div class="stat-card">
|
||||
<div class="value" id="msgCount">0</div>
|
||||
<div class="label">消息总数</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="value" id="stockCount">0</div>
|
||||
<div class="label">股票 (300111)</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="value" id="indexCount">0</div>
|
||||
<div class="label">指数 (309011)</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="value" id="bondCount">0</div>
|
||||
<div class="label">债券 (300211)</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="value" id="hkCount">0</div>
|
||||
<div class="label">港股 (306311)</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="value" id="otherCount">0</div>
|
||||
<div class="label">其他类型</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="panels">
|
||||
<div class="panel">
|
||||
<div class="panel-header">
|
||||
<span>实时行情</span>
|
||||
<span id="quoteTime">--</span>
|
||||
</div>
|
||||
<div class="panel-body" style="padding: 0;">
|
||||
<table class="quote-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>代码</th>
|
||||
<th>类型</th>
|
||||
<th>最新价</th>
|
||||
<th>涨跌幅</th>
|
||||
<th>成交量</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody id="quoteTable"></tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div class="panel">
|
||||
<div class="panel-header">
|
||||
<span>消息日志</span>
|
||||
<span id="logCount">0 条</span>
|
||||
</div>
|
||||
<div class="panel-body" id="logs"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
let msgCount = 0;
|
||||
let counts = { stock: 0, index: 0, bond: 0, hk_stock: 0, other: 0 };
|
||||
let quotes = {};
|
||||
const maxLogs = 200;
|
||||
|
||||
// Connect to the WebSocket URL from the input, guarding against opening a
// duplicate connection, and wire up message/lifecycle handlers.
function connect() {
  const url = document.getElementById('wsUrl').value;
  if (ws && ws.readyState === WebSocket.OPEN) {
    // Already connected — don't open a second socket.
    addLog('已经连接', 'warning');
    return;
  }

  try {
    ws = new WebSocket(url);

    ws.onopen = () => {
      setStatus(true);
      addLog(`已连接到 ${url}`, 'info');
    };

    ws.onmessage = (e) => {
      msgCount++;
      document.getElementById('msgCount').textContent = msgCount;

      try {
        const msg = JSON.parse(e.data);
        handleMessage(msg);
      } catch (err) {
        // Malformed frame: log and keep the connection alive.
        addLog(`解析错误: ${err.message}`, 'error');
      }
    };

    ws.onclose = () => {
      setStatus(false);
      addLog('连接已断开', 'warning');
    };

    ws.onerror = (err) => {
      addLog(`连接错误`, 'error');
    };
  } catch (err) {
    // new WebSocket() throws synchronously on an invalid URL.
    addLog(`连接失败: ${err.message}`, 'error');
  }
}
|
||||
|
||||
// Close the WebSocket (if any) and drop the reference so a fresh
// connection can be made afterwards.
function disconnect() {
  if (!ws) return;
  ws.close();
  ws = null;
}
|
||||
|
||||
// Send an application-level ping frame; the server is expected to reply
// with a {type: 'pong'} message (handled in handleMessage).
function sendPing() {
  if (ws && ws.readyState === WebSocket.OPEN) {
    ws.send(JSON.stringify({ type: 'ping' }));
    addLog('发送 Ping', 'info');
  } else {
    addLog('未连接', 'error');
  }
}
|
||||
|
||||
// Route one server message: initial snapshot, realtime tick, or pong.
// Also refreshes the per-category counters and the quote timestamp.
function handleMessage(msg) {
  if (msg.type === 'snapshot') {
    addLog(`收到快照: 股票${msg.data.stocks?.length || 0}只, 指数${msg.data.indexes?.length || 0}个, 债券${msg.data.bonds?.length || 0}只`, 'snapshot');
    // Seed the quote table: the last 20 stocks plus every index.
    msg.data.stocks?.slice(-20).forEach(s => updateQuote('stock', s));
    msg.data.indexes?.forEach(s => updateQuote('index', s));
  } else if (msg.type === 'realtime') {
    const cat = msg.category;
    // Count per category and merge the tick into the quote cache.
    if (cat === 'stock') { counts.stock++; updateQuote('stock', msg.data); }
    else if (cat === 'index') { counts.index++; updateQuote('index', msg.data); }
    else if (cat === 'bond') { counts.bond++; updateQuote('bond', msg.data); }
    else if (cat === 'hk_stock') { counts.hk_stock++; updateQuote('hk_stock', msg.data); }
    else { counts.other++; }

    document.getElementById('stockCount').textContent = counts.stock;
    document.getElementById('indexCount').textContent = counts.index;
    document.getElementById('bondCount').textContent = counts.bond;
    document.getElementById('hkCount').textContent = counts.hk_stock;
    document.getElementById('otherCount').textContent = counts.other;

    // Log only every 50th message to avoid flooding the panel.
    if (msgCount % 50 === 0) {
      addLog(`${cat}: ${msg.data.security_id} = ${msg.data.last_px || msg.data.current_index || '--'}`, cat);
    }
  } else if (msg.type === 'pong') {
    addLog('收到 Pong', 'pong');
  }

  document.getElementById('quoteTime').textContent = new Date().toLocaleTimeString();
}
|
||||
|
||||
// Merge one quote into the cache (keyed by security_id) and refresh the
// on-screen table.
function updateQuote(type, data) {
  quotes[data.security_id] = { type, ...data };
  renderQuotes();
}
|
||||
|
||||
// Re-render the quote table: indexes first, then stocks/HK stocks/bonds,
// limited to 30 rows.
//
// BUG FIXES:
// 1. The sort key used `order[a.type] || 9`; the index rank is 0, which is
//    falsy, so indexes were ranked 9 and sorted LAST instead of first.
//    `??` respects the 0 rank.
// 2. Removed the dead ternary `q.type === 'index' ? 2 : 2` (both branches
//    were 2) — plain `toFixed(2)`.
function renderQuotes() {
  const tbody = document.getElementById('quoteTable');
  const order = { index: 0, stock: 1, hk_stock: 2, bond: 3 };
  const typeMap = { stock: '股票', index: '指数', bond: '债券', hk_stock: '港股' };

  const sorted = Object.values(quotes).sort((a, b) => {
    return (order[a.type] ?? 9) - (order[b.type] ?? 9);
  }).slice(0, 30);

  tbody.innerHTML = sorted.map(q => {
    const price = q.last_px || q.current_index || 0;
    const prev = q.prev_close || 0;
    // Round first, then classify, so the color always matches the shown value.
    const change = prev ? ((price - prev) / prev * 100).toFixed(2) : '0.00';
    const pct = Number(change);
    const cls = pct > 0 ? 'up' : pct < 0 ? 'down' : 'flat';
    return `<tr>
      <td class="code">${q.security_id}</td>
      <td>${typeMap[q.type] || q.type}</td>
      <td class="${cls}">${price.toFixed(2)}</td>
      <td class="${cls}">${change}%</td>
      <td>${formatVolume(q.volume)}</td>
    </tr>`;
  }).join('');
}
|
||||
|
||||
function formatVolume(v) {
|
||||
if (!v) return '--';
|
||||
if (v >= 100000000) return (v / 100000000).toFixed(2) + '亿';
|
||||
if (v >= 10000) return (v / 10000).toFixed(2) + '万';
|
||||
return v.toString();
|
||||
}
|
||||
|
||||
// Prepend a timestamped entry to the log panel, trim the panel to at most
// maxLogs entries, and refresh the visible log counter.
function addLog(text, type = 'info') {
  const logs = document.getElementById('logs');
  const time = new Date().toLocaleTimeString();
  const div = document.createElement('div');
  div.className = `msg ${type}`;  // type doubles as a CSS color class
  div.textContent = `[${time}] ${text}`;
  logs.insertBefore(div, logs.firstChild);  // newest entry on top

  // Cap the panel size by dropping the oldest entries.
  while (logs.children.length > maxLogs) {
    logs.removeChild(logs.lastChild);
  }
  document.getElementById('logCount').textContent = `${logs.children.length} 条`;
}
|
||||
|
||||
// Reset the log panel and every message counter back to its initial state.
// (The quote table is intentionally left untouched.)
function clearLogs() {
  document.getElementById('logs').innerHTML = '';
  document.getElementById('logCount').textContent = '0 条';
  msgCount = 0;
  counts = { stock: 0, index: 0, bond: 0, hk_stock: 0, other: 0 };
  const counterIds = ['msgCount', 'stockCount', 'indexCount', 'bondCount', 'hkCount', 'otherCount'];
  for (const id of counterIds) {
    document.getElementById(id).textContent = '0';
  }
}
|
||||
|
||||
// Reflect the connection state in the status pill (text + CSS class).
function setStatus(connected) {
  const el = document.getElementById('status');
  if (connected) {
    el.textContent = '已连接';
    el.className = 'status connected';
  } else {
    el.textContent = '未连接';
    el.className = 'status disconnected';
  }
}
|
||||
|
||||
// 页面关闭时断开连接
|
||||
window.onbeforeunload = () => { if (ws) ws.close(); };
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -112,6 +112,42 @@ server {
|
||||
proxy_buffering off;
|
||||
}
|
||||
|
||||
# ============================================
|
||||
# 实时行情 WebSocket 代理(灵活屏功能)
|
||||
# ============================================
|
||||
|
||||
# Shanghai Stock Exchange (SSE) realtime quotes WebSocket.
location /ws/sse {
    proxy_pass http://49.232.185.254:8765;
    proxy_http_version 1.1;
    # WebSocket upgrade handshake headers. NOTE(review): $connection_upgrade
    # is presumably defined by a `map` block elsewhere in this config -- verify.
    proxy_set_header Upgrade $http_upgrade;
    proxy_set_header Connection $connection_upgrade;
    # Forward the original host and client identity to the upstream.
    proxy_set_header Host $host;
    proxy_set_header X-Real-IP $remote_addr;
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    proxy_set_header X-Forwarded-Proto $scheme;
    # Long-lived streaming connection: 7-day timeouts, buffering disabled
    # so quote frames are forwarded immediately.
    proxy_connect_timeout 7d;
    proxy_send_timeout 7d;
    proxy_read_timeout 7d;
    proxy_buffering off;
}
|
||||
|
||||
# Shenzhen Stock Exchange (SZSE) realtime quotes WebSocket.
location /ws/szse {
    proxy_pass http://222.128.1.157:8765;
    proxy_http_version 1.1;
    # WebSocket upgrade handshake headers. NOTE(review): $connection_upgrade
    # is presumably defined by a `map` block elsewhere in this config -- verify.
    proxy_set_header Upgrade $http_upgrade;
    proxy_set_header Connection $connection_upgrade;
    # Forward the original host and client identity to the upstream.
    proxy_set_header Host $host;
    proxy_set_header X-Real-IP $remote_addr;
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    proxy_set_header X-Forwarded-Proto $scheme;
    # Long-lived streaming connection: 7-day timeouts, buffering disabled
    # so quote frames are forwarded immediately.
    proxy_connect_timeout 7d;
    proxy_send_timeout 7d;
    proxy_read_timeout 7d;
    proxy_buffering off;
}
|
||||
|
||||
location /mcp/ {
|
||||
proxy_pass http://127.0.0.1:8900/;
|
||||
proxy_http_version 1.1;
|
||||
@@ -142,7 +178,6 @@ server {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 概念板块API代理
|
||||
location /concept-api/ {
|
||||
proxy_pass http://222.128.1.157:16801/;
|
||||
@@ -158,6 +193,7 @@ server {
|
||||
proxy_send_timeout 60s;
|
||||
proxy_read_timeout 60s;
|
||||
}
|
||||
|
||||
# Elasticsearch API代理(价值论坛)
|
||||
location /es-api/ {
|
||||
proxy_pass http://222.128.1.157:19200/;
|
||||
@@ -223,36 +259,7 @@ server {
|
||||
proxy_send_timeout 86400s;
|
||||
proxy_read_timeout 86400s;
|
||||
}
|
||||
# AI Chat application (Next.js) - MCP integration.
# AI Chat static assets (images, CSS, JS).
location ~ ^/ai-chat/(images|_next/static|_next/image|favicon.ico) {
    proxy_pass http://127.0.0.1:3000;
    proxy_http_version 1.1;
    proxy_set_header Host $host;
    proxy_set_header X-Real-IP $remote_addr;
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    proxy_set_header X-Forwarded-Proto $scheme;

    # Cache settings: hashed Next.js assets are safe to cache for 30 days
    # and mark as immutable.
    expires 30d;
    add_header Cache-Control "public, immutable";
}
|
||||
|
||||
# AI Chat main application.
location /ai-chat {
    proxy_pass http://127.0.0.1:3000;
    proxy_http_version 1.1;
    proxy_set_header Host $host;
    proxy_set_header X-Real-IP $remote_addr;
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    proxy_set_header X-Forwarded-Proto $scheme;
    # Pass cookies through so session-authenticated requests work.
    proxy_set_header Cookie $http_cookie;
    proxy_pass_request_headers on;
    # Allow WebSocket upgrades (e.g. Next.js live features).
    proxy_set_header Upgrade $http_upgrade;
    proxy_set_header Connection "upgrade";
    # Disable buffering and caching so streamed responses flush immediately.
    proxy_buffering off;
    proxy_cache off;
}
|
||||
# iframe 内部资源代理(Bytedesk 聊天窗口的 CSS/JS)
|
||||
location /chat/ {
|
||||
proxy_pass http://43.143.189.195/chat/;
|
||||
@@ -326,6 +333,22 @@ server {
|
||||
add_header Cache-Control "public, max-age=86400";
|
||||
}
|
||||
|
||||
# Bytedesk file-access proxy (2025 files only).
location ^~ /file/2025/ {
    proxy_pass http://43.143.189.195/file/2025/;
    proxy_http_version 1.1;
    proxy_set_header Host $host;
    proxy_set_header X-Real-IP $remote_addr;
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    proxy_set_header X-Forwarded-Proto $scheme;

    # Cache configuration: 1-day proxy cache for 200 responses and a
    # matching 1-day browser cache.
    proxy_cache_valid 200 1d;
    expires 1d;
    add_header Cache-Control "public, max-age=86400";
    # Allow cross-origin access to the served files.
    add_header Access-Control-Allow-Origin *;
}
|
||||
|
||||
|
||||
# Visitor API 代理(Bytedesk 初始化接口)
|
||||
location /visitor/ {
|
||||
|
||||
Reference in New Issue
Block a user