diff --git a/app.py b/app.py index 56bc3250..cdfdbe86 100755 --- a/app.py +++ b/app.py @@ -6412,6 +6412,10 @@ def get_stock_kline(stock_code): except ValueError: return jsonify({'error': 'Invalid event_time format'}), 400 + # 确保股票代码包含后缀(ClickHouse 中数据带后缀) + if '.' not in stock_code: + stock_code = f"{stock_code}.SH" if stock_code.startswith('6') else f"{stock_code}.SZ" + # 获取股票名称 with engine.connect() as conn: result = conn.execute(text( @@ -7819,7 +7823,7 @@ def get_index_realtime(index_code): }) except Exception as e: - logger.error(f"获取指数实时行情失败: {index_code}, 错误: {str(e)}") + app.logger.error(f"获取指数实时行情失败: {index_code}, 错误: {str(e)}") return jsonify({ 'success': False, 'error': str(e), @@ -7837,8 +7841,13 @@ def get_index_kline(index_code): except ValueError: return jsonify({'error': 'Invalid event_time format'}), 400 + # 确保指数代码包含后缀(ClickHouse 中数据带后缀) + # 399xxx -> 深交所, 其他(000xxx等)-> 上交所 + if '.' not in index_code: + index_code = f"{index_code}.SZ" if index_code.startswith('39') else f"{index_code}.SH" + # 指数名称(暂无索引表,先返回代码本身) - index_name = index_code + index_name = index_code.split('.')[0] if chart_type == 'minute': return get_index_minute_kline(index_code, event_datetime, index_name) @@ -12044,10 +12053,11 @@ def get_market_summary(seccode): @app.route('/api/stocks/search', methods=['GET']) def search_stocks(): - """搜索股票(支持股票代码、股票简称、拼音首字母)""" + """搜索股票和指数(支持代码、名称搜索)""" try: query = request.args.get('q', '').strip() limit = request.args.get('limit', 20, type=int) + search_type = request.args.get('type', 'all') # all, stock, index if not query: return jsonify({ @@ -12055,73 +12065,132 @@ def search_stocks(): 'error': '请输入搜索关键词' }), 400 + results = [] + with engine.connect() as conn: - test_sql = text(""" - SELECT SECCODE, SECNAME, F001V, F003V, F010V, F011V - FROM ea_stocklist - WHERE SECCODE = '300750' - OR F001V LIKE '%ndsd%' LIMIT 5 - """) - test_result = conn.execute(test_sql).fetchall() + # 搜索指数(优先显示指数,因为通常用户搜索代码时指数更常用) + if search_type in ('all', 'index'): + index_sql = text(""" + SELECT DISTINCT + INDEXCODE as stock_code, + SECNAME as stock_name, + INDEXNAME as full_name, + F018V as exchange + FROM ea_exchangeindex + WHERE ( + UPPER(INDEXCODE) LIKE UPPER(:query_pattern) + OR UPPER(SECNAME) LIKE UPPER(:query_pattern) + OR UPPER(INDEXNAME) LIKE UPPER(:query_pattern) + ) + ORDER BY CASE + WHEN UPPER(INDEXCODE) = UPPER(:exact_query) THEN 1 + WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2 + WHEN UPPER(INDEXCODE) LIKE UPPER(:prefix_pattern) THEN 3 + WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 4 + ELSE 5 + END, + INDEXCODE + LIMIT :limit + """) - # 构建搜索SQL - 支持股票代码、股票简称、拼音简称搜索 - search_sql = text(""" - SELECT DISTINCT SECCODE as stock_code, - SECNAME as stock_name, - F001V as pinyin_abbr, - F003V as security_type, - F005V as exchange, - F011V as listing_status - FROM ea_stocklist - WHERE ( - UPPER(SECCODE) LIKE UPPER(:query_pattern) - OR UPPER(SECNAME) LIKE UPPER(:query_pattern) - OR UPPER(F001V) LIKE UPPER(:query_pattern) - ) - -- 基本过滤条件:只搜索正常的A股和B股 - AND (F011V = '正常上市' OR F010V = '013001') -- 正常上市状态 - AND F003V IN ('A股', 'B股') -- 只搜索A股和B股 - ORDER BY CASE - WHEN UPPER(SECCODE) = UPPER(:exact_query) THEN 1 - WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2 - WHEN UPPER(F001V) = UPPER(:exact_query) THEN 3 - WHEN UPPER(SECCODE) LIKE UPPER(:prefix_pattern) THEN 4 - WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 5 - WHEN UPPER(F001V) LIKE UPPER(:prefix_pattern) THEN 6 - ELSE 7 - END, - SECCODE LIMIT :limit - """) + index_result = conn.execute(index_sql, { + 
'query_pattern': f'%{query}%', + 'exact_query': query, + 'prefix_pattern': f'{query}%', + 'limit': limit + }).fetchall() - result = conn.execute(search_sql, { - 'query_pattern': f'%{query}%', - 'exact_query': query, - 'prefix_pattern': f'{query}%', - 'limit': limit - }).fetchall() + for row in index_result: + results.append({ + 'stock_code': row.stock_code, + 'stock_name': row.stock_name, + 'full_name': row.full_name, + 'exchange': row.exchange, + 'isIndex': True, + 'security_type': '指数' + }) - stocks = [] - for row in result: - # 获取当前价格 - current_price, _ = get_latest_price_from_clickhouse(row.stock_code) + # 搜索股票 + if search_type in ('all', 'stock'): + stock_sql = text(""" + SELECT DISTINCT SECCODE as stock_code, + SECNAME as stock_name, + F001V as pinyin_abbr, + F003V as security_type, + F005V as exchange, + F011V as listing_status + FROM ea_stocklist + WHERE ( + UPPER(SECCODE) LIKE UPPER(:query_pattern) + OR UPPER(SECNAME) LIKE UPPER(:query_pattern) + OR UPPER(F001V) LIKE UPPER(:query_pattern) + ) + AND (F011V = '正常上市' OR F010V = '013001') + AND F003V IN ('A股', 'B股') + ORDER BY CASE + WHEN UPPER(SECCODE) = UPPER(:exact_query) THEN 1 + WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2 + WHEN UPPER(F001V) = UPPER(:exact_query) THEN 3 + WHEN UPPER(SECCODE) LIKE UPPER(:prefix_pattern) THEN 4 + WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 5 + WHEN UPPER(F001V) LIKE UPPER(:prefix_pattern) THEN 6 + ELSE 7 + END, + SECCODE + LIMIT :limit + """) - stocks.append({ - 'stock_code': row.stock_code, - 'stock_name': row.stock_name, - 'current_price': current_price or 0, # 添加当前价格 - 'pinyin_abbr': row.pinyin_abbr, - 'security_type': row.security_type, - 'exchange': row.exchange, - 'listing_status': row.listing_status - }) + stock_result = conn.execute(stock_sql, { + 'query_pattern': f'%{query}%', + 'exact_query': query, + 'prefix_pattern': f'{query}%', + 'limit': limit + }).fetchall() + + for row in stock_result: + results.append({ + 'stock_code': row.stock_code, + 'stock_name': row.stock_name, + 'pinyin_abbr': row.pinyin_abbr, + 'security_type': row.security_type, + 'exchange': row.exchange, + 'listing_status': row.listing_status, + 'isIndex': False + }) + + # 如果搜索全部,按相关性重新排序(精确匹配优先) + if search_type == 'all': + def sort_key(item): + code = item['stock_code'].upper() + name = item['stock_name'].upper() + q = query.upper() + # 精确匹配代码优先 + if code == q: + return (0, not item['isIndex'], code) # 指数优先 + # 精确匹配名称 + if name == q: + return (1, not item['isIndex'], code) + # 前缀匹配代码 + if code.startswith(q): + return (2, not item['isIndex'], code) + # 前缀匹配名称 + if name.startswith(q): + return (3, not item['isIndex'], code) + return (4, not item['isIndex'], code) + + results.sort(key=sort_key) + + # 限制总数 + results = results[:limit] return jsonify({ 'success': True, - 'data': stocks, - 'count': len(stocks) + 'data': results, + 'count': len(results) }) except Exception as e: + app.logger.error(f"搜索股票/指数错误: {e}") return jsonify({ 'success': False, 'error': str(e) @@ -12403,7 +12472,21 @@ def get_daily_top_concepts(): top_concepts = [] for concept in data.get('results', []): - # 保持与 /concept-api/search 相同的字段结构 + # 处理 stocks 字段:兼容 {name, code} 和 {stock_name, stock_code} 两种格式 + raw_stocks = concept.get('stocks', []) + formatted_stocks = [] + for stock in raw_stocks: + # 优先使用 stock_name,其次使用 name + stock_name = stock.get('stock_name') or stock.get('name', '') + stock_code = stock.get('stock_code') or stock.get('code', '') + formatted_stocks.append({ + 'stock_name': stock_name, + 'stock_code': stock_code, + 'name': 
stock_name, # 兼容旧格式 + 'code': stock_code # 兼容旧格式 + }) + + # 保持与 /concept-api/search 相同的字段结构,并添加新字段 top_concepts.append({ 'concept_id': concept.get('concept_id'), 'concept': concept.get('concept'), # 原始字段名 @@ -12414,8 +12497,10 @@ def get_daily_top_concepts(): 'match_type': concept.get('match_type'), 'price_info': concept.get('price_info', {}), # 完整的价格信息 'change_percent': concept.get('price_info', {}).get('avg_change_pct', 0), # 兼容旧字段 - 'happened_times': concept.get('happened_times', []), # 历史触发时间 - 'stocks': concept.get('stocks', []), # 返回完整股票列表 + 'tags': concept.get('tags', []), # 标签列表 + 'outbreak_dates': concept.get('outbreak_dates', []), # 爆发日期列表 + 'hierarchy': concept.get('hierarchy'), # 层级信息 {lv1, lv2, lv3} + 'stocks': formatted_stocks, # 返回格式化后的股票列表 'hot_score': concept.get('hot_score') }) @@ -12442,6 +12527,557 @@ def get_daily_top_concepts(): }), 500 +# ==================== 热点概览 API ==================== + +@app.route('/api/market/hotspot-overview', methods=['GET']) +def get_hotspot_overview(): + """ + 获取热点概览数据(用于个股中心的热点概览图表) + 返回:指数分时数据 + 概念异动标注 + + 数据来源: + - 指数分时:ClickHouse index_minute 表 + - 概念异动:MySQL concept_anomaly_hybrid 表(来自 realtime_detector.py) + """ + try: + trade_date = request.args.get('date') + index_code = request.args.get('index', '000001.SH') + + # 如果没有指定日期,使用最新交易日 + if not trade_date: + today = date.today() + if today in trading_days_set: + trade_date = today.strftime('%Y-%m-%d') + else: + target_date = get_trading_day_near_date(today) + trade_date = target_date.strftime('%Y-%m-%d') if target_date else today.strftime('%Y-%m-%d') + + # 1. 获取指数分时数据 + client = get_clickhouse_client() + target_date_obj = datetime.strptime(trade_date, '%Y-%m-%d').date() + + index_data = client.execute( + """ + SELECT timestamp, open, high, low, close, volume + FROM index_minute + WHERE code = %(code)s + AND toDate(timestamp) = %(date)s + ORDER BY timestamp + """, + { + 'code': index_code, + 'date': target_date_obj + } + ) + + # 获取昨收价 + code_no_suffix = index_code.split('.')[0] + prev_close = None + with engine.connect() as conn: + prev_result = conn.execute(text(""" + SELECT F006N FROM ea_exchangetrade + WHERE INDEXCODE = :code + AND TRADEDATE < :today + ORDER BY TRADEDATE DESC LIMIT 1 + """), { + 'code': code_no_suffix, + 'today': target_date_obj + }).fetchone() + if prev_result and prev_result[0]: + prev_close = float(prev_result[0]) + + # 格式化指数数据 + index_timeline = [] + for row in index_data: + ts, open_p, high_p, low_p, close_p, vol = row + change_pct = None + if prev_close and close_p: + change_pct = round((float(close_p) - prev_close) / prev_close * 100, 4) + + index_timeline.append({ + 'time': ts.strftime('%H:%M'), + 'timestamp': ts.isoformat(), + 'price': float(close_p) if close_p else None, + 'open': float(open_p) if open_p else None, + 'high': float(high_p) if high_p else None, + 'low': float(low_p) if low_p else None, + 'volume': int(vol) if vol else 0, + 'change_pct': change_pct + }) + + # 2. 
获取概念异动数据(优先从 V2 表,fallback 到旧表) + alerts = [] + use_v2 = False + + with engine.connect() as conn: + # 尝试查询 V2 表(时间片对齐 + 持续确认版本) + try: + v2_result = conn.execute(text(""" + SELECT + concept_id, alert_time, trade_date, alert_type, + final_score, rule_score, ml_score, trigger_reason, confirm_ratio, + alpha, alpha_zscore, amt_zscore, rank_zscore, + momentum_3m, momentum_5m, limit_up_ratio, triggered_rules + FROM concept_anomaly_v2 + WHERE trade_date = :trade_date + ORDER BY alert_time + """), {'trade_date': trade_date}) + v2_rows = v2_result.fetchall() + if v2_rows: + use_v2 = True + for row in v2_rows: + triggered_rules = None + if row[16]: + try: + triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16] + except: + pass + + alerts.append({ + 'concept_id': row[0], + 'concept_name': row[0], # 后面会填充 + 'time': row[1].strftime('%H:%M') if row[1] else None, + 'timestamp': row[1].isoformat() if row[1] else None, + 'alert_type': row[3], + 'final_score': float(row[4]) if row[4] else None, + 'rule_score': float(row[5]) if row[5] else None, + 'ml_score': float(row[6]) if row[6] else None, + 'trigger_reason': row[7], + # V2 新增字段 + 'confirm_ratio': float(row[8]) if row[8] else None, + 'alpha': float(row[9]) if row[9] else None, + 'alpha_zscore': float(row[10]) if row[10] else None, + 'amt_zscore': float(row[11]) if row[11] else None, + 'rank_zscore': float(row[12]) if row[12] else None, + 'momentum_3m': float(row[13]) if row[13] else None, + 'momentum_5m': float(row[14]) if row[14] else None, + 'limit_up_ratio': float(row[15]) if row[15] else 0, + 'triggered_rules': triggered_rules, + # 兼容字段 + 'importance_score': float(row[4]) / 100 if row[4] else None, + 'is_v2': True, + }) + except Exception as v2_err: + app.logger.debug(f"V2 表查询失败,使用旧表: {v2_err}") + + # Fallback: 查询旧表 + if not use_v2: + try: + alert_result = conn.execute(text(""" + SELECT + a.concept_id, a.alert_time, a.trade_date, a.alert_type, + a.final_score, a.rule_score, a.ml_score, a.trigger_reason, + a.alpha, a.alpha_delta, a.amt_ratio, a.amt_delta, + a.rank_pct, a.limit_up_ratio, a.stock_count, a.total_amt, + a.triggered_rules + FROM concept_anomaly_hybrid a + WHERE a.trade_date = :trade_date + ORDER BY a.alert_time + """), {'trade_date': trade_date}) + + for row in alert_result: + triggered_rules = None + if row[16]: + try: + triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16] + except: + pass + + limit_up_ratio = float(row[13]) if row[13] else 0 + stock_count = int(row[14]) if row[14] else 0 + limit_up_count = int(limit_up_ratio * stock_count) if stock_count > 0 else 0 + + alerts.append({ + 'concept_id': row[0], + 'concept_name': row[0], + 'time': row[1].strftime('%H:%M') if row[1] else None, + 'timestamp': row[1].isoformat() if row[1] else None, + 'alert_type': row[3], + 'final_score': float(row[4]) if row[4] else None, + 'rule_score': float(row[5]) if row[5] else None, + 'ml_score': float(row[6]) if row[6] else None, + 'trigger_reason': row[7], + 'alpha': float(row[8]) if row[8] else None, + 'alpha_delta': float(row[9]) if row[9] else None, + 'amt_ratio': float(row[10]) if row[10] else None, + 'amt_delta': float(row[11]) if row[11] else None, + 'rank_pct': float(row[12]) if row[12] else None, + 'limit_up_ratio': limit_up_ratio, + 'limit_up_count': limit_up_count, + 'stock_count': stock_count, + 'total_amt': float(row[15]) if row[15] else None, + 'triggered_rules': triggered_rules, + 'importance_score': float(row[4]) / 100 if row[4] else None, + 'is_v2': False, + }) + except Exception as 
old_err:
+                    app.logger.debug(f"旧表查询也失败: {old_err}")
+
+        # 尝试批量获取概念名称
+        if alerts:
+            concept_ids = list(set(a['concept_id'] for a in alerts))
+            concept_names = {}  # 初始化 concept_names 字典
+            try:
+                from elasticsearch import Elasticsearch
+                es_client = Elasticsearch(["http://222.128.1.157:19200"])
+                es_result = es_client.mget(
+                    index='concept_library_v3',
+                    body={'ids': concept_ids},
+                    _source=['concept']
+                )
+                for doc in es_result.get('docs', []):
+                    if doc.get('found') and doc.get('_source'):
+                        concept_names[doc['_id']] = doc['_source'].get('concept', doc['_id'])
+                # 更新 alerts 中的概念名称
+                for alert in alerts:
+                    if alert['concept_id'] in concept_names:
+                        alert['concept_name'] = concept_names[alert['concept_id']]
+            except Exception as e:
+                app.logger.warning(f"获取概念名称失败: {e}")
+
+        # 计算统计信息
+        day_high = max([d['price'] for d in index_timeline if d['price']], default=None)
+        day_low = min([d['price'] for d in index_timeline if d['price']], default=None)
+        latest_price = index_timeline[-1]['price'] if index_timeline else None
+        latest_change_pct = index_timeline[-1]['change_pct'] if index_timeline else None
+
+        return jsonify({
+            'success': True,
+            'data': {
+                'trade_date': trade_date,
+                'index': {
+                    'code': index_code,
+                    'name': '上证指数' if index_code == '000001.SH' else index_code,
+                    'prev_close': prev_close,
+                    'latest_price': latest_price,
+                    'change_pct': latest_change_pct,
+                    'high': day_high,
+                    'low': day_low,
+                    'timeline': index_timeline
+                },
+                'alerts': alerts,
+                'alert_count': len(alerts),
+                'alert_summary': {
+                    'surge': len([a for a in alerts if a['alert_type'] == 'surge']),
+                    'surge_up': len([a for a in alerts if a['alert_type'] == 'surge_up']),
+                    'surge_down': len([a for a in alerts if a['alert_type'] == 'surge_down']),
+                    'limit_up': len([a for a in alerts if a['alert_type'] == 'limit_up']),
+                    'volume_spike': len([a for a in alerts if a['alert_type'] == 'volume_spike']),
+                    'rank_jump': len([a for a in alerts if a['alert_type'] == 'rank_jump'])
+                }
+            }
+        })
+
+    except Exception as e:
+        import traceback
+        error_trace = traceback.format_exc()
+        app.logger.error(f"获取热点概览数据失败: {error_trace}")
+        return jsonify({
+            'success': False,
+            'error': str(e),
+            'traceback': error_trace  # 临时返回完整错误信息用于调试
+        }), 500
+
+
+@app.route('/api/concept/<concept_id>/stocks', methods=['GET'])
+def get_concept_stocks(concept_id):
+    """
+    获取概念的相关股票列表(带实时涨跌幅)
+
+    Args:
+        concept_id: 概念 ID(来自 ES concept_library_v3)
+
+    Returns:
+        - stocks: 股票列表 [{code, name, reason, change_pct}, ...]
+    """
+    try:
+        from elasticsearch import Elasticsearch
+        from clickhouse_driver import Client
+
+        # 1. 
从 ES 获取概念的股票列表 + es_client = Elasticsearch(["http://222.128.1.157:19200"]) + es_result = es_client.get(index='concept_library_v3', id=concept_id) + + if not es_result.get('found'): + return jsonify({ + 'success': False, + 'error': f'概念 {concept_id} 不存在' + }), 404 + + source = es_result.get('_source', {}) + concept_name = source.get('concept', concept_id) + raw_stocks = source.get('stocks', []) + + if not raw_stocks: + return jsonify({ + 'success': True, + 'data': { + 'concept_id': concept_id, + 'concept_name': concept_name, + 'stocks': [] + } + }) + + # 提取股票代码和原因 + stocks_info = [] + stock_codes = [] + for s in raw_stocks: + if isinstance(s, dict): + code = s.get('code', '') + if code and len(code) == 6: + stocks_info.append({ + 'code': code, + 'name': s.get('name', ''), + 'reason': s.get('reason', '') + }) + stock_codes.append(code) + + if not stock_codes: + return jsonify({ + 'success': True, + 'data': { + 'concept_id': concept_id, + 'concept_name': concept_name, + 'stocks': stocks_info + } + }) + + # 2. 获取最新交易日和前一交易日 + today = datetime.now().date() + trading_day = None + prev_trading_day = None + + with engine.connect() as conn: + # 获取最新交易日 + result = conn.execute(text(""" + SELECT EXCHANGE_DATE FROM trading_days + WHERE EXCHANGE_DATE <= :today + ORDER BY EXCHANGE_DATE DESC LIMIT 1 + """), {"today": today}).fetchone() + if result: + trading_day = result[0].date() if hasattr(result[0], 'date') else result[0] + + # 获取前一交易日 + if trading_day: + result = conn.execute(text(""" + SELECT EXCHANGE_DATE FROM trading_days + WHERE EXCHANGE_DATE < :date + ORDER BY EXCHANGE_DATE DESC LIMIT 1 + """), {"date": trading_day}).fetchone() + if result: + prev_trading_day = result[0].date() if hasattr(result[0], 'date') else result[0] + + # 3. 从 MySQL ea_trade 获取前一交易日收盘价(F007N) + prev_close_map = {} + if prev_trading_day and stock_codes: + with engine.connect() as conn: + placeholders = ','.join([f':code{i}' for i in range(len(stock_codes))]) + params = {f'code{i}': code for i, code in enumerate(stock_codes)} + params['trade_date'] = prev_trading_day + + result = conn.execute(text(f""" + SELECT SECCODE, F007N + FROM ea_trade + WHERE SECCODE IN ({placeholders}) + AND TRADEDATE = :trade_date + AND F007N > 0 + """), params).fetchall() + + prev_close_map = {row[0]: float(row[1]) for row in result if row[1]} + + # 4. 从 ClickHouse 获取最新价格 + current_price_map = {} + if stock_codes: + try: + ch_client = Client( + host='127.0.0.1', + port=9000, + user='default', + password='Zzl33818!', + database='stock' + ) + + # 转换为 ClickHouse 格式 + ch_codes = [] + code_mapping = {} + for code in stock_codes: + if code.startswith('6'): + ch_code = f"{code}.SH" + elif code.startswith('0') or code.startswith('3'): + ch_code = f"{code}.SZ" + else: + ch_code = f"{code}.BJ" + ch_codes.append(ch_code) + code_mapping[ch_code] = code + + ch_codes_str = "','".join(ch_codes) + + # 查询当天最新价格 + query = f""" + SELECT code, close + FROM stock_minute + WHERE code IN ('{ch_codes_str}') + AND toDate(timestamp) = today() + ORDER BY timestamp DESC + LIMIT 1 BY code + """ + result = ch_client.execute(query) + + for row in result: + ch_code, close_price = row + if ch_code in code_mapping and close_price: + original_code = code_mapping[ch_code] + current_price_map[original_code] = float(close_price) + + except Exception as ch_err: + app.logger.warning(f"ClickHouse 获取价格失败: {ch_err}") + + # 5. 
计算涨跌幅并合并数据 + result_stocks = [] + for stock in stocks_info: + code = stock['code'] + prev_close = prev_close_map.get(code) + current_price = current_price_map.get(code) + + change_pct = None + if prev_close and current_price and prev_close > 0: + change_pct = round((current_price - prev_close) / prev_close * 100, 2) + + result_stocks.append({ + 'code': code, + 'name': stock['name'], + 'reason': stock['reason'], + 'change_pct': change_pct, + 'price': current_price, + 'prev_close': prev_close + }) + + # 按涨跌幅排序(涨停优先) + result_stocks.sort(key=lambda x: x.get('change_pct') if x.get('change_pct') is not None else -999, reverse=True) + + return jsonify({ + 'success': True, + 'data': { + 'concept_id': concept_id, + 'concept_name': concept_name, + 'stock_count': len(result_stocks), + 'trading_day': str(trading_day) if trading_day else None, + 'stocks': result_stocks + } + }) + + except Exception as e: + import traceback + app.logger.error(f"获取概念股票失败: {traceback.format_exc()}") + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/concept-alerts', methods=['GET']) +def get_concept_alerts(): + """ + 获取概念异动列表(支持分页和筛选) + """ + try: + trade_date = request.args.get('date') + alert_type = request.args.get('type') # surge/limit_up/rank_jump + concept_type = request.args.get('concept_type') # leaf/lv1/lv2/lv3 + limit = request.args.get('limit', 50, type=int) + offset = request.args.get('offset', 0, type=int) + + # 构建查询条件 + conditions = [] + params = {'limit': limit, 'offset': offset} + + if trade_date: + conditions.append("trade_date = :trade_date") + params['trade_date'] = trade_date + else: + conditions.append("trade_date = CURDATE()") + + if alert_type: + conditions.append("alert_type = :alert_type") + params['alert_type'] = alert_type + + if concept_type: + conditions.append("concept_type = :concept_type") + params['concept_type'] = concept_type + + where_clause = " AND ".join(conditions) if conditions else "1=1" + + with engine.connect() as conn: + # 获取总数 + count_sql = text(f"SELECT COUNT(*) FROM concept_minute_alert WHERE {where_clause}") + total = conn.execute(count_sql, params).scalar() + + # 获取数据 + query_sql = text(f""" + SELECT + id, concept_id, concept_name, alert_time, alert_type, trade_date, + change_pct, prev_change_pct, change_delta, + limit_up_count, prev_limit_up_count, limit_up_delta, + rank_position, prev_rank_position, rank_delta, + index_price, index_change_pct, + stock_count, concept_type, extra_info + FROM concept_minute_alert + WHERE {where_clause} + ORDER BY alert_time DESC + LIMIT :limit OFFSET :offset + """) + + result = conn.execute(query_sql, params) + + alerts = [] + for row in result: + extra_info = None + if row[19]: + try: + extra_info = json.loads(row[19]) if isinstance(row[19], str) else row[19] + except: + pass + + alerts.append({ + 'id': row[0], + 'concept_id': row[1], + 'concept_name': row[2], + 'alert_time': row[3].isoformat() if row[3] else None, + 'alert_type': row[4], + 'trade_date': row[5].isoformat() if row[5] else None, + 'change_pct': float(row[6]) if row[6] else None, + 'prev_change_pct': float(row[7]) if row[7] else None, + 'change_delta': float(row[8]) if row[8] else None, + 'limit_up_count': row[9], + 'prev_limit_up_count': row[10], + 'limit_up_delta': row[11], + 'rank_position': row[12], + 'prev_rank_position': row[13], + 'rank_delta': row[14], + 'index_price': float(row[15]) if row[15] else None, + 'index_change_pct': float(row[16]) if row[16] else None, + 'stock_count': row[17], + 'concept_type': row[18], + 
'extra_info': extra_info + }) + + return jsonify({ + 'success': True, + 'data': alerts, + 'total': total, + 'limit': limit, + 'offset': offset + }) + + except Exception as e: + import traceback + app.logger.error(f"获取概念异动列表失败: {traceback.format_exc()}") + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + @app.route('/api/market/rise-analysis/', methods=['GET']) def get_rise_analysis(seccode): """获取股票涨幅分析数据(从 Elasticsearch 获取)""" diff --git a/concept_hierarchy_v3.json b/concept_hierarchy_v3.json index a7eb2200..e5b74263 100644 --- a/concept_hierarchy_v3.json +++ b/concept_hierarchy_v3.json @@ -2,435 +2,602 @@ "hierarchy": [ { "lv1": "人工智能", - "lv1_id": "lv1_1", + "lv1_id": "ai", "children": [ { - "lv2": "AI基础设施", - "lv2_id": "lv2_1_1", + "lv2": "AI芯片与存储", + "lv2_id": "ai_chips_memory", + "concepts": [ + "AI算力芯片", + "AI芯片", + "GPU概念股", + "HBM", + "SRAM存储", + "RISC-V", + "TPU芯片", + "中昊芯英概念股", + "国产算力芯片", + "国产GPU", + "国产芯片参股公司", + "摩尔线程", + "摩尔线程IPO", + "沐曦集成", + "超威半导体AMD", + "英伟达概念", + "英伟达H20", + "芯片替代", + "谷歌", + "谷歌概念", + "阿里AI芯片", + "利基型存储DDR4", + "存储芯片产业", + "存储", + "存储芯片" + ] + }, + { + "lv2": "AI硬件与基础设施", + "lv2_id": "ai_hardware_infra", "children": [ { - "lv3": "AI算力硬件", - "lv3_id": "lv3_1_1_1", - "concepts": [ - "AI一体机", - "AI算力芯片", - "AI芯片", - "DeepSeek智算一体机", - "GPU概念股", - "TPU芯片", - "一体机核心标的弹性测算", - "中昊芯英概念股", - "国产算力芯片", - "国产GPU", - "服务器", - "昇腾推理一体机", - "摩尔线程", - "摩尔线程IPO", - "沐曦集成", - "阿里AI芯片" - ] - }, - { - "lv3": "AI关键组件", - "lv3_id": "lv3_1_1_2", + "lv3": "AI PCB与封装", + "lv3_id": "ai_pcb_packaging", "concepts": [ "AI PCB", "AI PCB英伟达M9", - "AI服务器钽电容", - "HBM", - "OCS光电路交换机", "PCB", "PCB设备及耗材", - "SRAM存储", - "光纤", - "光通信CPO", - "光通信", - "光芯片", - "光纤列阵单元FAU", - "博通交换机", - "改良型半加成工艺mSAP", - "服务器零部件", - "硅光技术", - "空芯光纤", - "薄膜铌酸锂", - "存储", - "存储芯片", - "存储芯片产业", - "铜互连", - "铜连接", - "钽电容", - "忆阻器", - "利基型存储DDR4" + "改良型半加成工艺mSAP" ] }, { - "lv3": "AI配套设施", - "lv3_id": "lv3_1_1_3", + "lv3": "AI一体机", + "lv3_id": "ai_integrated_machines", + "concepts": [ + "AI一体机", + "DeepSeek智算一体机", + "一体机核心标的弹性测算", + "昇腾推理一体机" + ] + }, + { + "lv3": "AI服务器组件", + "lv3_id": "ai_server_components", + "concepts": [ + "AI服务器钽电容", + "钽电容", + "功率半导体", + "电池备份单元", + "超级电容器" + ] + }, + { + "lv3": "AI数据中心与电力", + "lv3_id": "ai_data_center_power", "concepts": [ "AIDC供配电设备弹性", "数据中心", - "数据中心液冷", "数据中心电力设备", - "微泵液冷", - "微通道水冷板", - "液冷", - "液冷数据中心", - "液态金属散热", - "电池备份单元", - "电磁屏蔽", - "超级电容器", "柴油发电机", - "钻石散热", "英伟达电源方案" ] }, { - "lv3": "算力网络与服务", - "lv3_id": "lv3_1_1_4", + "lv3": "AI散热", + "lv3_id": "ai_cooling", "concepts": [ - "中国星际之门芜湖", - "上海算力", - "四川算力", - "大厂算力订单", - "字节算力", - "星际之门概念", - "杭州算力大会", - "核心城市智算算力", - "毫秒用算", - "算力" + "微泵液冷", + "微通道水冷板", + "数据中心液冷", + "液态金属散热", + "液冷", + "钻石散热", + "电磁屏蔽" + ] + }, + { + "lv3": "AI网络设备", + "lv3_id": "ai_network_equipment", + "concepts": [ + "博通交换机", + "铜连接", + "铜互连", + "铜互联" ] } ] }, { - "lv2": "AI模型与软件", - "lv2_id": "lv2_1_2", + "lv2": "AI光通信", + "lv2_id": "ai_optical_comm", + "concepts": [ + "光纤", + "光通信CPO", + "光通信", + "光芯片", + "光纤列阵单元FAU", + "OCS光电路交换机", + "硅光技术", + "空芯光纤", + "薄膜铌酸锂" + ] + }, + { + "lv2": "AI算力基础设施", + "lv2_id": "ai_computing_power_infra", + "concepts": [ + "中国星际之门芜湖", + "上海算力", + "四川算力", + "大厂算力订单", + "字节算力", + "核心城市智算算力", + "毫秒用算", + "星际之门概念", + "杭州算力大会", + "甲骨文概念股", + "算力" + ] + }, + { + "lv2": "AI模型与算法", + "lv2_id": "ai_models_algorithms", "concepts": [ "DeepSeek FP8", "DeepSeek", "DeepSeek、国产算力", - "KIMI", "MOE模型", "Minimax", - "Nano Banana", "SORA概念", "国产大模型", + "阶跃星辰" + ] + }, + { + "lv2": "AI应用与服务", + "lv2_id": 
"ai_applications_services", + "concepts": [ + "AI4S", + "AI应用AI语料", + "AI编程", + "Nano Banana", + "KIMI", + "低代码", + "内容审核概念", "文生视频", - "秘塔AI", - "阶跃星辰", + "智象未来", + "版权", + "秘塔AI" + ] + }, + { + "lv2": "AI智能体与陪伴", + "lv2_id": "ai_agents_companions", + "concepts": [ + "AI伴侣", + "AI成人陪伴", + "AI应用陪伴智能体", + "AI智能体", + "AI应用智能体", + "AI智能体AI应用", + "AI语音助手", + "AI陪伴", + "Manus", + "字节AI陪伴", + "开发智能体" + ] + }, + { + "lv2": "AI通用概念", + "lv2_id": "ai_general_concepts", + "concepts": [ + "AI-细分延伸更新", + "AI合集", + "物理AI", "马斯克Grok3大模型" ] }, { - "lv2": "AI应用", - "lv2_id": "lv2_1_3", - "children": [ - { - "lv3": "智能体与陪伴", - "lv3_id": "lv3_1_3_1", - "concepts": [ - "AI伴侣", - "AI成人陪伴", - "AI应用陪伴智能体", - "AI智能体", - "AI应用智能体", - "AI智能体AI应用", - "AI语音助手", - "AI陪伴", - "Manus", - "字节AI陪伴", - "开发智能体" - ] - }, - { - "lv3": "行业应用", - "lv3_id": "lv3_1_3_2", - "concepts": [ - "AI4S", - "AI应用AI语料", - "AI编程", - "低代码", - "内容审核概念", - "物理AI" - ] - } - ] - }, - { - "lv2": "AI生态系统", - "lv2_id": "lv2_1_4", - "children": [ - { - "lv3": "通用生态", - "lv3_id": "lv3_1_4_1", - "concepts": [ - "云计算各厂商云", - "微软Azure云平台", - "甲骨文概念股", - "英伟达代理", - "英伟达概念", - "英伟达H20", - "超威半导体AMD", - "谷歌", - "谷歌概念" - ] - }, - { - "lv3": "阿里生态", - "lv3_id": "lv3_1_4_2", - "concepts": [ - "通义千问阿里云", - "阿里云", - "阿里云通义千问", - "阿里AI千问、灵光", - "阿里“千问”项目", - "阿里AI来听" - ] - }, - { - "lv3": "腾讯生态", - "lv3_id": "lv3_1_4_3", - "concepts": [ - "腾讯元宝", - "腾讯大模型", - "腾讯混元大模型", - "腾讯云及大模型合作公司" - ] - }, - { - "lv3": "字节跳动生态", - "lv3_id": "lv3_1_4_4", - "concepts": [ - "努比亚手机", - "华为抖音支付", - "字节概念豆包AI手机", - "字节豆包概念股", - "抖音概念", - "豆包大模型" - ] - } - ] - }, - { - "lv2": "AI综合与趋势", - "lv2_id": "lv2_1_5", + "lv2": "云平台与大模型", + "lv2_id": "cloud_platforms_large_models", "concepts": [ - "AI-细分延伸更新", - "AI合集" + "云手机", + "云计算各厂商云", + "微软Azure云平台", + "腾讯元宝", + "腾讯大模型", + "腾讯混元大模型", + "腾讯云及大模型合作公司", + "通义千问阿里云", + "阿里云通义千问", + "阿里AI千问、灵光", + "阿里“千问”项目", + "阿里AI来听", + "阿里云" ] } ] }, { "lv1": "半导体", - "lv1_id": "lv1_2", + "lv1_id": "semiconductor", "children": [ { - "lv2": "半导体设备", - "lv2_id": "lv2_2_1", + "lv2": "半导体制造与材料", + "lv2_id": "semiconductor_mfg_materials", "concepts": [ "EDA", + "GAA晶体管", + "PSPI", "上海微电子", - "光刻机", + "先进陶瓷", + "光刻胶", "光刻机宇量昇", + "光刻机", + "半导体抛光液", + "半导体封测", + "半导体设计", + "半导体材料", + "半导体产业链", "半导体设备", + "半导体混合键合技术", + "国产光刻胶", + "国产半导体", "大湾区芯片展览会-新凯莱", "新凯来概念股", "新凯来示波器", - "电子束光刻机“羲之”" - ] - }, - { - "lv2": "半导体材料", - "lv2_id": "lv2_2_2", - "concepts": [ - "PSPI", - "先进陶瓷", - "光刻胶", - "半导体抛光液", - "半导体材料", - "国产光刻胶", - "石英砂", - "磷化铟" - ] - }, - { - "lv2": "芯片设计与制造", - "lv2_id": "lv2_2_3", - "concepts": [ - "ASIC", - "GAA晶体管", - "ISP视觉", - "RISC-V", - "功率半导体", - "半导体设计", - "第三代半导体", - "端侧AI芯片", - "碳化硅", - "英诺赛科概念股" - ] - }, - { - "lv2": "先进封装", - "lv2_id": "lv2_2_4", - "concepts": [ - "半导体封测", - "半导体混合键合技术", + "电子束光刻机“羲之”", "玻璃基板", + "石英砂", + "磷化铟", + "盛合晶微概念股", "盛合晶微", - "盛合晶微概念股" - ] - }, - { - "lv2": "重点企业与IPO", - "lv2_id": "lv2_2_5", - "concepts": [ - "地平线", - "地平线概念", + "第三代半导体", "紫光展锐IPO", + "碳化硅", + "科特估半导体", + "英诺赛科概念股", + "超硬材料", "长鑫存储", - "长鑫、长江产业链", - "高通概念" + "长鑫、长江产业链" ] }, { - "lv2": "综合与政策", - "lv2_id": "lv2_2_6", + "lv2": "半导体政策与地缘", + "lv2_id": "semiconductor_policy_geopolitics", "concepts": [ - "半导体产业链", - "国产半导体", - "国产芯片参股公司", - "科特估半导体", - "芯片替代", - "英特尔概念股" + "中美关系", + "出口管制" ] } ] }, { - "lv1": "机器人", - "lv1_id": "lv1_3", + "lv1": "机器人与智能制造", + "lv1_id": "robotics_smart_mfg", "children": [ { - "lv2": "人形机器人整机", - "lv2_id": "lv2_3_1", + "lv2": "人形机器人整机与生态", + "lv2_id": "humanoid_robot_systems_ecosystems", "concepts": [ "Optimus特斯拉机器人", 
"乐聚机器人", - "人形机器人", + "人形机器产业链", "人形机器人Figure", + "人形机器人", "云深处", "优必选", + "人形机器人万向节", + "人形机器人核心标的概览", + "人形机器人核心标的估值弹性测算", "优必选机器人", "华为人形机器人", "各厂商机器人", "奇瑞机器人潜在产业链", "天太机器人", - "天工机器人", "宇树人形机器人", "宇树机器人", + "小米智元机器人产业链机构机构版", + "天工机器人", "小鹏机器人", "小米机器人", - "小米智元机器人产业链机构版", "开普勒机器人", - "松延动力机器人", - "特斯拉人形机器人", + "工业机器人", "智元机器人", + "松延动力机器人", + "机器狗四足机器人", + "特斯拉人形机器人价值量", + "特斯拉人形机器人", + "特斯拉产业链", + "特斯拉人形机器人弹性测算", "荣耀华为人形机器人", + "美的库卡机器人", "赛力斯机器人" ] }, { - "lv2": "机器人核心零部件", - "lv2_id": "lv2_3_2", + "lv2": "机器人核心部件与材料", + "lv2_id": "robotics_core_components_materials", "concepts": [ + "MIM概念", "PCB轴向磁通电机", "人形机器人-滚柱丝杆丝杠", - "人形机器人万向节", + "人形机器人轻量化-PEEK材料", "人形机器人腱绳", - "机电", - "机器人皮肤仿生皮肤", + "冷锻产业链", + "机器人轻量化-PEEK", + "机器人轻量化—碳纤维", + "机器人零部件加工设备", + "机器人轻量化-镁铝合金", "摆线减速器", - "电子皮肤", - "轴向磁通电机" + "金属粉末注射成形MIM", + "超高分子量聚乙烯纤维", + "机器人氮化镓", + "轴向磁通电机", + "机电" ] }, { - "lv2": "机器人产业链", - "lv2_id": "lv2_3_3", - "children": [ - { - "lv3": "综合与价值链", - "lv3_id": "lv3_3_3_1", - "concepts": [ - "人形机器产业链", - "人形机器人核心标的估值弹性测算", - "人形机器人核心标的概览", - "特斯拉人形机器人价值量", - "特斯拉人形机器人弹性测算", - "特斯拉产业链" - ] - }, - { - "lv3": "轻量化材料", - "lv3_id": "lv3_3_3_2", - "concepts": [ - "人形机器人轻量化-PEEK", - "人形机器人轻量化-PEEK材料", - "机器人轻量化-镁铝合金", - "机器人轻量化—碳纤维", - "超高分子量聚乙烯纤维" - ] - }, - { - "lv3": "制造工艺与设备", - "lv3_id": "lv3_3_3_3", - "concepts": [ - "MIM概念", - "冷锻产业链", - "机器人零部件加工设备", - "金属粉末注射成形MIM" - ] - } + "lv2": "机器人感知与交互", + "lv2_id": "robotics_perception_interaction", + "concepts": [ + "机器人动作捕捉", + "机器人皮肤仿生皮肤", + "电子皮肤" ] }, { - "lv2": "机器人软件与AI", - "lv2_id": "lv2_3_4", + "lv2": "机器人AI与控制", + "lv2_id": "robotics_ai_control", "concepts": [ "机器人-神经网络", - "机器人动作捕捉" + "神经网络" ] }, { - "lv2": "其他类型机器人", - "lv2_id": "lv2_3_5", + "lv2": "信创与工业自动化", + "lv2_id": "xinchuang_industrial_automation", "concepts": [ - "外骨骼机器人", - "工业机器人", - "机器狗四足机器人", - "美的库卡机器人" + "信创概念", + "关键软件", + "国产信创概览", + "工业设备更新", + "工业软件", + "工业母机", + "灯塔工厂", + "自主可控", + "软件自主可控", + "设备更新" ] } ] }, { - "lv1": "消费电子", - "lv1_id": "lv1_4", + "lv1": "新能源与电力", + "lv1_id": "new_energy_power", "children": [ { - "lv2": "智能终端(端侧AI)", - "lv2_id": "lv2_4_1", + "lv2": "光伏产业", + "lv2_id": "photovoltaic_industry", "concepts": [ - "2025CES参展公司", - "AI PC", - "AIPC", - "AI手机" + "N型产业链", + "光伏", + "光伏行业兼并重组", + "光伏产业链", + "反内卷光伏", + "叠层钙钛矿", + "钙钛矿电池" ] }, { - "lv2": "XR与空间计算", - "lv2_id": "lv2_4_2", + "lv2": "先进电池技术", + "lv2_id": "advanced_battery_tech", + "concepts": [ + "固态电池-硅基负极", + "固态电池-硫化物", + "固态电池", + "固态电池设备", + "固态电池负极集流体材料-铜箔", + "固态电池产业链", + "复合集流体", + "富锂锰基材料", + "硅基负极材料", + "钠离子电池", + "隔膜", + "陶瓷隔膜骨架膜" + ] + }, + { + "lv2": "充电与换电", + "lv2_id": "charging_battery_swapping", + "concepts": [ + "充电桩", + "华为智能充电", + "华为智能充电网络", + "换电", + "换电重卡", + "比亚迪兆瓦闪充" + ] + }, + { + "lv2": "电力基础设施与AI能源", + "lv2_id": "power_infra_ai_energy", + "concepts": [ + "北美缺电AI电力", + "变压器出海", + "固体氧化物燃料电池-SOFC", + "固态变压器SST", + "燃料电池", + "燃气设备", + "电力产业链", + "燃气轮机HRSG", + "电力电网", + "电力设备", + "电力" + ] + }, + { + "lv2": "核能技术与应用", + "lv2_id": "nuclear_energy_applications", + "concepts": [ + "可控核聚变", + "微型核电", + "核电钍基熔盐堆", + "核聚变超导", + "核电产业链", + "超聚变" + ] + }, + { + "lv2": "核能通用概念", + "lv2_id": "nuclear_energy_general_concepts", + "concepts": [ + "高温概念" + ] + } + ] + }, + { + "lv1": "智能出行与交通", + "lv1_id": "smart_mobility_trans", + "children": [ + { + "lv2": "智能驾驶", + "lv2_id": "autonomous_driving", + "concepts": [ + "Robotaxi", + "小马智行", + "无人驾驶公交", + "文远知行", + "智能驾驶产业链", + "无人驾驶-线控转向", + "无人驾驶", + "特斯拉FSD", + "特斯拉RoboTaxi概念", + "矿山智驾", + "特斯拉无人驾驶出租车Robotaxi", + "自动驾驶" + ] + }, + { 
+ "lv2": "无人物流", + "lv2_id": "unmanned_logistics", + "concepts": [ + "京东物流Robovan", + "无人物流", + "无人环卫车", + "无人物流车九识智能", + "菜鸟无人物流车" + ] + }, + { + "lv2": "智能驾驶基础设施", + "lv2_id": "autonomous_driving_infra", + "concepts": [ + "机器人充电", + "汽车无线充电" + ] + }, + { + "lv2": "智能出行服务", + "lv2_id": "smart_mobility_services", + "concepts": [ + "网约车" + ] + }, + { + "lv2": "智能驾驶与整车", + "lv2_id": "smart_driving_oems", + "concepts": [ + "mobileye替代概念", + "地平线概念", + "地平线", + "比亚迪产业链", + "比亚迪智驾", + "理想汽车" + ] + }, + { + "lv2": "低空经济", + "lv2_id": "low_altitude_economy", + "concepts": [ + "eVTOL材料", + "亿航智能订单量", + "低空经济&飞行汽车", + "低空经济", + "低空管控", + "低空物流", + "低空经济亿航智能", + "低空设计", + "低空经济产业链汇集", + "小鹏汇天", + "小鹏汇天供应商", + "空中成像", + "长安飞行汽车机器人概念", + "飞行汽车eVTOL" + ] + } + ] + }, + { + "lv1": "通信与航天", + "lv1_id": "comm_aerospace", + "children": [ + { + "lv2": "5G与未来通信", + "lv2_id": "5g_future_comm", + "concepts": [ + "5G毫米波", + "5G-A", + "5.5G", + "eSIM概念", + "通感一体" + ] + }, + { + "lv2": "卫星通信与应用", + "lv2_id": "satellite_comm_applications", + "concepts": [ + "6G概念", + "低轨卫星通信华为", + "北斗信使", + "北斗导航", + "商业航天卫星通信", + "卫星出海", + "卫星互联网", + "商业航天", + "太空行走", + "太空旅行", + "太空算力", + "手机直连卫星", + "星河动力", + "星网", + "蓝箭航天朱雀三号", + "长征十二号甲", + "SpaceX", + "卫星能源太阳翼" + ] + }, + { + "lv2": "航空航天", + "lv2_id": "aviation_space", + "concepts": [ + "九天无人机", + "凌空天行", + "大飞机", + "珠海航展", + "空客合作" + ] + } + ] + }, + { + "lv1": "消费电子与XR", + "lv1_id": "consumer_electronics_xr", + "children": [ + { + "lv2": "智能穿戴与XR", + "lv2_id": "smart_wearables_xr", "concepts": [ "AI手势识别", "AI眼镜", @@ -440,665 +607,56 @@ "META智能眼镜", "MR", "Rokid AR", - "消费电子-玄玑感知系统", + "小米眼镜", "智能穿戴", "智能眼镜", - "小米眼镜", - "阿里夸克AI眼镜", - "谷歌AI眼镜-合作XREAL" - ] - }, - { - "lv2": "华为产业链", - "lv2_id": "lv2_4_3", - "children": [ - { - "lv3": "终端产品", - "lv3_id": "lv3_4_3_1", - "concepts": [ - "华为P70", - "华为Mate80", - "华为Pura70", - "华为Mate70手表", - "华为MATE70" - ] - }, - { - "lv3": "技术与生态", - "lv3_id": "lv3_4_3_2", - "concepts": [ - "华为", - "华为5G", - "华为AI容器", - "华为云", - "华为昇腾", - "华为昇腾超节点", - "华为通信大模型", - "华为海思星闪", - "华为鸿蒙", - "华为鸿蒙甄选与支付", - "华字辈", - "昇腾异构计算架构-CANN", - "鸿蒙PC" - ] - }, - { - "lv3": "芯片与算力", - "lv3_id": "lv3_4_3_3", - "concepts": [ - "华为910C", - "华为AI存储", - "华为存储OceanStor", - "华为麒麟芯片", - "磁电存储" - ] - } - ] - }, - { - "lv2": "苹果产业链", - "lv2_id": "lv2_4_4", - "concepts": [ - "果链OPEN AI复用", + "消费电子-玄玑感知系统", "苹果MR产业链", + "谷歌AI眼镜-合作XREAL", + "阿里夸克AI眼镜", + "雷鸟创新光波导" + ] + }, + { + "lv2": "AI终端与芯片", + "lv2_id": "ai_devices_chips", + "concepts": [ + "2025CES参展公司", + "AI PC", + "AIPC", + "AI手机", + "ASIC", + "ISP视觉", + "果链OPEN AI复用", + "端侧AI芯片", "苹果供应商核心公司", "苹果机器人", - "苹果手机产业链" + "苹果手机产业链", + "高通概念" ] }, { "lv2": "新型显示技术", - "lv2_id": "lv2_4_5", + "lv2_id": "new_display_technologies", "concepts": [ "华为三折叠屏", "折叠屏", "显影液及硅基OLED", "苹果OLED潜在受益", "苹果折叠屏", - "面板" + "面板", + "TV面板LCD" ] } ] }, { - "lv1": "智能驾驶与汽车", - "lv1_id": "lv1_5", + "lv1": "文化传媒与娱乐", + "lv1_id": "culture_media_ent", "children": [ { - "lv2": "自动驾驶解决方案", - "lv2_id": "lv2_5_1", - "concepts": [ - "Robotaxi", - "京东物流Robovan", - "小马智行", - "文远知行", - "无人物流", - "无人物流车九识智能", - "无人驾驶", - "特斯拉FSD", - "特斯拉RoboTaxi概念", - "特斯拉无人驾驶出租车Robotaxi", - "自动驾驶", - "菜鸟无人物流车" - ] - }, - { - "lv2": "智能汽车产业链", - "lv2_id": "lv2_5_2", - "concepts": [ - "比亚迪产业链", - "比亚迪智驾", - "孙潇雅团队概念股", - "小米YU7供应链弹性测算", - "小米大模型", - "小米概念", - "小米算力AI互联", - "小米汽车产业链", - "小米汽车产业链弹性", - "小米汽车产业链弹性测算", - "小鹏产业链", - "理想汽车" - ] - }, - { - "lv2": "车路协同与特定场景", - "lv2_id": "lv2_5_3", - "concepts": [ - "无人驾驶-线控转向", - "无人驾驶公交", - "无人环卫车", - "矿山智驾", - "车路云-车路协同运营建设", - 
"车路云一体化", - "车路协同" - ] - }, - { - "lv2": "产业链综合", - "lv2_id": "lv2_5_4", - "concepts": [ - "mobileye替代概念", - "智能驾驶产业链", - "汽车安全" - ] - }, - { - "lv2": "出行服务", - "lv2_id": "lv2_5_5", - "concepts": [ - "网约车" - ] - }, - { - "lv2": "智能座舱", - "lv2_id": "lv2_5_6", - "concepts": [ - "空中成像" - ] - } - ] - }, - { - "lv1": "新能源与电力", - "lv1_id": "lv1_6", - "children": [ - { - "lv2": "新型电池技术", - "lv2_id": "lv2_6_1", - "children": [ - { - "lv3": "固态电池", - "lv3_id": "lv3_6_1_1", - "concepts": [ - "固态电池", - "固态电池-硅基负极", - "固态电池-硫化物", - "固态电池产业链", - "固态电池设备", - "固态电池负极集流体材料-铜箔", - "陶瓷隔膜骨架膜" - ] - }, - { - "lv3": "其他材料与技术", - "lv3_id": "lv3_6_1_2", - "concepts": [ - "复合集流体", - "富锂锰基材料", - "硅基负极材料", - "钠离子电池", - "隔膜" - ] - } - ] - }, - { - "lv2": "电力设备与电网", - "lv2_id": "lv2_6_2", - "concepts": [ - "北美缺电AI电力", - "变压器出海", - "固体氧化物燃料电池-SOFC", - "固态变压器SST", - "燃料电池", - "燃气设备", - "燃气轮机HRSG", - "电力", - "电力产业链", - "电力电网", - "电力设备" - ] - }, - { - "lv2": "清洁能源", - "lv2_id": "lv2_6_3", - "children": [ - { - "lv3": "光伏", - "lv3_id": "lv3_6_3_1", - "concepts": [ - "N型产业链", - "光伏", - "光伏行业兼并重组", - "光伏产业链", - "反内卷光伏", - "叠层钙钛矿", - "钙钛矿电池" - ] - }, - { - "lv3": "核能", - "lv3_id": "lv3_6_3_2", - "concepts": [ - "可控核聚变", - "微型核电", - "核电产业链", - "核电钍基熔盐堆", - "核聚变超导", - "超聚变", - "高温概念" - ] - } - ] - }, - { - "lv2": "充电桩与补能", - "lv2_id": "lv2_6_4", - "concepts": [ - "充电桩", - "华为智能充电", - "华为智能充电网络", - "换电", - "换电重卡", - "机器人充电", - "汽车无线充电", - "比亚迪兆瓦闪充" - ] - } - ] - }, - { - "lv1": "空天经济", - "lv1_id": "lv1_7", - "children": [ - { - "lv2": "低空经济", - "lv2_id": "lv2_7_1", - "concepts": [ - "eVTOL材料", - "九天无人机", - "亿航智能订单量", - "低空经济", - "低空经济&飞行汽车", - "低空经济产业链汇集", - "低空经济亿航智能", - "低空管控", - "低空物流", - "低空设计", - "小鹏汇天", - "小鹏汇天供应商", - "飞行汽车eVTOL", - "长安飞行汽车机器人概念" - ] - }, - { - "lv2": "商业航天", - "lv2_id": "lv2_7_2", - "concepts": [ - "凌空天行", - "北斗信使", - "北斗导航", - "商业航天", - "商业航天卫星通信", - "卫星出海", - "卫星互联网", - "太空行走", - "太空旅行", - "太空算力", - "手机直连卫星", - "星河动力", - "星网", - "蓝箭航天朱雀三号" - ] - }, - { - "lv2": "通信技术", - "lv2_id": "lv2_7_3", - "concepts": [ - "5.5G", - "5G-A", - "5G毫米波", - "6G概念", - "eSIM概念", - "通感一体" - ] - }, - { - "lv2": "民用航空", - "lv2_id": "lv2_7_4", - "concepts": [ - "大飞机", - "空客合作" - ] - }, - { - "lv2": "综合与主题", - "lv2_id": "lv2_7_5", - "concepts": [ - "珠海航展" - ] - } - ] - }, - { - "lv1": "国防军工", - "lv1_id": "lv1_8", - "children": [ - { - "lv2": "无人作战与信息化", - "lv2_id": "lv2_8_1", - "concepts": [ - "AI军工", - "AI无人机军工信息化", - "信息支援概念整理", - "军工信息化", - "军用无人机反无人机", - "无人机蜂群" - ] - }, - { - "lv2": "海军装备", - "lv2_id": "lv2_8_2", - "concepts": [ - "军工水面水下作战", - "国产航母", - "水下军工", - "海军", - "电磁弹射概念股", - "电磁发射设备", - "航母福建舰240430" - ] - }, - { - "lv2": "空军装备", - "lv2_id": "lv2_8_3", - "concepts": [ - "军机" - ] - }, - { - "lv2": "陆军装备", - "lv2_id": "lv2_8_4", - "concepts": [ - "远程火力" - ] - }, - { - "lv2": "军贸出海", - "lv2_id": "lv2_8_5", - "concepts": [ - "军贸", - "巴印军贸", - "巴黎航展" - ] - }, - { - "lv2": "综合与主题", - "lv2_id": "lv2_8_6", - "concepts": [ - "军工", - "军工-阅兵", - "国防军工" - ] - } - ] - }, - { - "lv1": "政策与主题", - "lv1_id": "lv1_9", - "children": [ - { - "lv2": "国企改革与市值管理", - "lv2_id": "lv2_9_1", - "concepts": [ - "中兵集团并购重组", - "中字头", - "中字头央企", - "中船合并", - "央企市值管理", - "央国企", - "央国企地产", - "央国企重组", - "安徽国资", - "地面兵装", - "整车央企重组", - "河南国资能源集团重组", - "市值管理16条-破净股", - "珠海国资", - "破净央国企", - "破净股合集", - "福建国资", - "湖北三资改革", - "国资高息股", - "高送转概念股", - "央国企AI一张图" - ] - }, - { - "lv2": "并购重组", - "lv2_id": "lv2_9_2", - "concepts": [ - "IPO终止相关企业重组预期", - "上海并购重组", - "券商合并预期", - "宝德计算机", - "并购重组", - "并购重组预期", - "消费医疗重组预期", - "湘财合并大智慧", - "科创板并购重组", - "秦淮数据", - 
"科技重组", - "超聚变借壳预期", - "证券", - "荣耀股改", - "重组-中科院系&海光系" - ] - }, - { - "lv2": "信创与自主可控", - "lv2_id": "lv2_9_3", - "concepts": [ - "信创概念", - "关键软件", - "国产信创概览", - "工业软件", - "自主可控", - "软件自主可控", - "信息安全", - "安全概念股", - "网络安全", - "通信设备", - "通信安全" - ] - }, - { - "lv2": "重大基建", - "lv2_id": "lv2_9_4", - "concepts": [ - "三峡水运新通道", - "新藏铁路", - "新疆概念", - "水利", - "水利工程", - "混凝土减水剂、砂石设备", - "节水产业240423", - "西部大开发", - "西部大开发240424", - "西南水电", - "西南水电站", - "西南水电站-机构测算", - "隧洞设备盾构机", - "雅下水电对电力设备增量测算-机构", - "雅下水电站", - "雅下水电站大件物流" - ] - }, - { - "lv2": "供给侧改革", - "lv2_id": "lv2_9_5", - "concepts": [ - "反内卷", - "反内卷造纸", - "反内卷食用盐", - "反内卷快递", - "反内卷合集", - "生猪", - "物流", - "牛肉", - "统一大市场", - "钢铁" - ] - }, - { - "lv2": "产业升级与制造", - "lv2_id": "lv2_9_6", - "concepts": [ - "工业设备更新", - "工业母机", - "设备更新", - "灯塔工厂" - ] - }, - { - "lv2": "国家战略", - "lv2_id": "lv2_9_7", - "concepts": [ - "2025年政府工作报告利好行业及个股", - "新型经济", - "新质生产力", - "深海数智化", - "深海经济", - "深地经济", - "首发经济" - ] - }, - { - "lv2": "区域发展与自贸区", - "lv2_id": "lv2_9_8", - "concepts": [ - "上海自贸区", - "免税离境退税", - "新型离岸贸易", - "海南", - "海南自贸区", - "海南自贸港", - "零售消费免税" - ] - }, - { - "lv2": "行业监管与规范", - "lv2_id": "lv2_9_9", - "concepts": [ - "充电宝", - "农药证件厂家", - "食品安全", - "食品安全全链条", - "预制菜" - ] - } - ] - }, - { - "lv1": "周期与材料", - "lv1_id": "lv1_10", - "children": [ - { - "lv2": "化工", - "lv2_id": "lv2_10_1", - "children": [ - { - "lv3": "行业趋势", - "lv3_id": "lv3_10_1_1", - "concepts": [ - "化工", - "化工概念", - "化工品涨价", - "涨价概念" - ] - }, - { - "lv3": "具体品种", - "lv3_id": "lv3_10_1_2", - "concepts": [ - "TMA偏苯三酸酐", - "乙烷", - "光引发剂", - "六氟磷酸锂", - "农药杀虫剂-氯虫苯甲酰胺", - "双季戊四醇", - "己内酰胺", - "有机硅", - "正丙醇", - "烧碱", - "涤纶长丝", - "电解液添加剂", - "甲苯二异氰酸酯-TDI", - "环氧丙烷", - "纯碱", - "磷化工", - "磷化工六氟磷酸锂", - "聚酯产业", - "维生素", - "苯酚丙酮", - "超硬材料", - "除草剂-烯草酮" - ] - }, - { - "lv3": "产业链", - "lv3_id": "lv3_10_1_3", - "concepts": [ - "有机硅产业链", - "电解液产业链" - ] - } - ] - }, - { - "lv2": "有色金属", - "lv2_id": "lv2_10_2", - "concepts": [ - "化工有色元素周期表", - "有色金属", - "电解铝", - "稀土", - "白银", - "铅酸电池", - "铜", - "铜产业", - "钨金属", - "钴", - "钴金属", - "钼金属", - "锡矿" - ] - } - ] - }, - { - "lv1": "大消费", - "lv1_id": "lv1_11", - "children": [ - { - "lv2": "文化传媒", - "lv2_id": "lv2_11_1", - "concepts": [ - "AI游戏", - "乙游", - "传媒出海", - "出版传媒", - "国产游戏黑神话", - "周杰伦概念股", - "幻兽帕鲁", - "影视", - "影视IP", - "影视传媒", - "影视院线", - "春节档重点影片(哪吒2)", - "智象未来", - "漫剧", - "游戏", - "游戏出海", - "疯狂动物城2", - "短剧", - "诡秘之主", - "腾讯短剧重点名单" - ] - }, - { - "lv2": "新消费", - "lv2_id": "lv2_11_2", + "lv2": "潮玩与IP衍生", + "lv2_id": "trendy_play_ip_derivatives", "concepts": [ "上市潮玩盲盒公司", "卡游文创玩具", @@ -1112,8 +670,263 @@ ] }, { - "lv2": "人口与社会", - "lv2_id": "lv2_11_3", + "lv2": "影视与IP内容", + "lv2_id": "film_tv_ip_content", + "concepts": [ + "出版传媒", + "周杰伦概念股", + "影视院线", + "影视IP", + "影视", + "影视传媒", + "春节档重点影片(哪吒2)", + "疯狂动物城2", + "诡秘之主" + ] + }, + { + "lv2": "体育赛事与经济", + "lv2_id": "sports_events_economy", + "concepts": [ + "体育", + "体育产业", + "冰雪经济", + "川超联赛", + "第十五届全运会", + "足球-苏超联赛、体彩", + "足球" + ] + }, + { + "lv2": "游戏与AI赋能", + "lv2_id": "gaming_ai_empowerment", + "concepts": [ + "AI游戏", + "乙游", + "国产游戏黑神话", + "幻兽帕鲁", + "游戏" + ] + }, + { + "lv2": "内容出海与新业态", + "lv2_id": "content_export_new_formats", + "concepts": [ + "传媒出海", + "漫剧", + "游戏出海", + "短剧", + "腾讯短剧重点名单" + ] + }, + { + "lv2": "社交媒体与跨境电商", + "lv2_id": "social_media_cross_border_ecomm", + "concepts": [ + "TikTok", + "小红书概念", + "小红书概念股", + "敦煌网跨境电商" + ] + } + ] + }, + { + "lv1": "宏观经济与政策", + "lv1_id": "macro_policy", + "children": [ + { + "lv2": "国家战略与新经济", + "lv2_id": "national_strategy_new_economy", + "concepts": [ + 
"2025年政府工作报告利好行业及个股", + "新型经济", + "新质生产力", + "深海数智化", + "深海经济", + "深地经济", + "首发经济" + ] + }, + { + "lv2": "国企改革与市值管理", + "lv2_id": "soe_reform_market_value", + "concepts": [ + "中兵集团并购重组", + "中字头", + "中字头央企", + "中船合并", + "国资高息股", + "央企市值管理", + "央国企", + "地面兵装", + "央国企地产", + "央国企重组", + "安徽国资", + "市值管理16条-破净股", + "整车央企重组", + "河南国资能源集团重组", + "珠海国资", + "湖北三资改革", + "破净央国企", + "破净股合集", + "福建国资" + ] + }, + { + "lv2": "并购重组", + "lv2_id": "mergers_acquisitions", + "concepts": [ + "IPO终止相关企业重组预期", + "上海并购重组", + "券商合并预期", + "宝德计算机", + "并购重组预期", + "并购重组", + "消费医疗重组预期", + "湘财合并大智慧", + "科创板并购重组", + "秦淮数据", + "科技重组", + "超聚变借壳预期", + "证券", + "荣耀股改", + "重组-中科院系&海光系" + ] + }, + { + "lv2": "供给侧改革与周期", + "lv2_id": "supply_side_reform_cycles", + "concepts": [ + "反内卷", + "反内卷造纸", + "反内卷食用盐", + "反内卷快递", + "反内卷合集", + "己内酰胺", + "涤纶长丝", + "生猪", + "牛肉", + "聚酯产业", + "钢铁", + "BOPET膜" + ] + }, + { + "lv2": "区域经济与自贸区", + "lv2_id": "regional_economy_free_trade_zones", + "concepts": [ + "上海自贸区", + "免税离境退税", + "新型离岸贸易", + "海南", + "海南自贸区", + "海南自贸港", + "零售消费免税" + ] + } + ] + }, + { + "lv1": "地缘政治与国际关系", + "lv1_id": "geopolitics_intl", + "children": [ + { + "lv2": "贸易政策与供应链", + "lv2_id": "trade_policy_supply_chain", + "concepts": [ + "中俄贸易", + "乙烷", + "中欧贸易", + "二轮车全地形车", + "关税豁免", + "关税减免出口链", + "后关税战受益", + "反制关税涨价预期", + "墨西哥汽车零部件", + "绒毛浆", + "芬太尼管制", + "转口贸易出口转内销", + "越南工厂" + ] + }, + { + "lv2": "国际冲突与重建", + "lv2_id": "intl_conflicts_reconstruction", + "concepts": [ + "乌克兰战后重建概念", + "乌克兰重建", + "俄乌重建", + "柬泰战争", + "黄岩岛概念股" + ] + }, + { + "lv2": "两岸关系与区域发展", + "lv2_id": "cross_strait_regional_dev", + "concepts": [ + "两岸融合", + "台资企业", + "海峡两岸福建", + "厦门“十五五规划”" + ] + } + ] + }, + { + "lv1": "数字经济与数据要素", + "lv1_id": "digital_economy_data", + "children": [ + { + "lv2": "数字金融与Web3", + "lv2_id": "digital_finance_web3", + "concepts": [ + "上海浦江数链", + "互联网金融", + "复星稳定币", + "数字货币", + "树图链概念", + "稳定币RWA概念股", + "稳定币-蚂蚁国际", + "稳定币一体机", + "蚂蚁金服", + "香港金融牌照" + ] + }, + { + "lv2": "数据要素与基础设施", + "lv2_id": "data_elements_infra", + "concepts": [ + "RWA上链— IoT设备数据采集", + "RDA概念股", + "数据可信", + "数据交易所", + "数据要素", + "跨境数据数据要素" + ] + }, + { + "lv2": "网络与数据安全", + "lv2_id": "network_data_security", + "concepts": [ + "信息安全", + "地理信息", + "安全概念股", + "汽车安全", + "网络安全", + "通信设备", + "通信安全" + ] + } + ] + }, + { + "lv1": "社会民生", + "lv1_id": "social_livelihood", + "children": [ + { + "lv2": "生育与教育", + "lv2_id": "fertility_education", "concepts": [ "三胎", "多胎", @@ -1124,148 +937,274 @@ ] }, { - "lv2": "体育产业", - "lv2_id": "lv2_11_4", + "lv2": "行业监管与安全", + "lv2_id": "industry_regulation_safety", "concepts": [ - "体育", - "体育产业", - "冰雪经济", - "川超联赛", - "第十五届全运会", - "足球", - "足球-苏超联赛、体彩" + "充电宝", + "农药证件厂家", + "食品安全", + "食品安全全链条", + "预制菜" ] } ] }, { - "lv1": "数字经济与金融科技", - "lv1_id": "lv1_12", + "lv1": "医药与生物科技", + "lv1_id": "pharma_biotech", "children": [ { - "lv2": "数据要素", - "lv2_id": "lv2_12_1", - "concepts": [ - "RDA概念股", - "RWA上链— IoT设备数据采集", - "版权", - "地理信息", - "数据可信", - "数据交易所", - "数据要素", - "政务云政务IT", - "跨境数据数据要素" - ] - }, - { - "lv2": "数字金融", - "lv2_id": "lv2_12_2", - "concepts": [ - "上海浦江数链", - "互联网金融", - "复星稳定币", - "数字货币", - "树图链概念", - "稳定币-蚂蚁国际", - "稳定币RWA概念股", - "稳定币一体机", - "蚂蚁金服", - "香港金融牌照" - ] - } - ] - }, - { - "lv1": "全球宏观与贸易", - "lv1_id": "lv1_13", - "children": [ - { - "lv2": "地缘政治与冲突", - "lv2_id": "lv2_13_1", - "concepts": [ - "以伊冲突-天然气", - "以伊冲突-油运仓储", - "以伊冲突-航运", - "以伊冲突-资源化工", - "乌克兰战后重建概念", - "乌克兰重建", - "俄乌重建", - "油气", - "海外港口", - "海事反制", - "石油", - "航运", - "远洋航运", - "柬泰战争", - "黄岩岛概念股" - ] - }, - { - "lv2": "贸易政策与关系", - "lv2_id": "lv2_13_2", - 
"concepts": [ - "中美关系", - "中俄贸易", - "中欧贸易", - "出口管制", - "反制关税涨价预期", - "后关税战受益", - "关税减免出口链", - "关税豁免", - "绒毛浆", - "芬太尼管制", - "转口贸易出口转内销" - ] - }, - { - "lv2": "供应链重构", - "lv2_id": "lv2_13_3", - "concepts": [ - "二轮车全地形车", - "墨西哥汽车零部件", - "越南工厂" - ] - } - ] - }, - { - "lv1": "医药健康", - "lv1_id": "lv1_14", - "children": [ - { - "lv2": "创新药", - "lv2_id": "lv2_14_1", + "lv2": "创新药与医疗器械", + "lv2_id": "innovative_drugs_medical_devices", "concepts": [ "AI制药", - "创新药", - "创新药双抗", "创新药相关", - "医药" + "创新药双抗", + "医疗器械", + "创新药", + "医药", + "医药外包CXO" + ] + } + ] + }, + { + "lv1": "基础材料与化工", + "lv1_id": "basic_materials_chem", + "children": [ + { + "lv2": "化工产品与周期", + "lv2_id": "chemical_products_cycles", + "concepts": [ + "TMA偏苯三酸酐", + "光引发剂", + "六氟磷酸锂", + "农药杀虫剂-氯虫苯甲酰胺", + "化工品涨价", + "化工概念", + "化工", + "双季戊四醇", + "正丙醇", + "有机硅", + "有机硅产业链", + "烧碱", + "涨价概念", + "电解液添加剂", + "电解液产业链", + "甲苯二异氰酸酯-TDI", + "环氧丙烷", + "电解铝", + "纯碱", + "磷化工六氟磷酸锂", + "磷化工", + "维生素", + "苯酚丙酮", + "铅酸电池", + "除草剂-烯草酮" ] }, { - "lv2": "细胞治疗", - "lv2_id": "lv2_14_2", + "lv2": "有色金属与稀有矿产", + "lv2_id": "non_ferrous_rare_minerals", "concepts": [ - "干细胞", - "干细胞概念股" + "化工有色元素周期表", + "有色金属", + "稀土", + "白银", + "铜产业", + "钨金属", + "钴", + "铜", + "钼金属", + "钴金属", + "锡矿" + ] + } + ] + }, + { + "lv1": "基础设施与工程", + "lv1_id": "infra_engineering", + "children": [ + { + "lv2": "重大基建与水利", + "lv2_id": "major_infra_water", + "concepts": [ + "三峡水运新通道", + "新藏铁路", + "新疆概念", + "水利", + "水利工程", + "混凝土减水剂、砂石设备", + "节水产业240423", + "西部大开发", + "西南水电站", + "西部大开发240424", + "西南水电站-机构测算", + "西南水电", + "隧洞设备盾构机", + "雅下水电对电力设备增量测算-机构", + "雅下水电站", + "雅下水电站大件物流" ] } ] }, { "lv1": "前沿科技", - "lv1_id": "lv1_15", + "lv1_id": "frontier_tech", "children": [ { - "lv2": "量子科技", - "lv2_id": "lv2_15_1", + "lv2": "量子技术与应用", + "lv2_id": "quantum_tech_applications", "concepts": [ "量子材料钛酸锶", "量子科技", "量子科技产业链", - "量子科技参股公司", - "量子计算" + "量子计算", + "量子科技参股公司" + ] + } + ] + }, + { + "lv1": "市场情绪与概念", + "lv1_id": "market_sentiment", + "children": [ + { + "lv2": "叙事与情绪驱动", + "lv2_id": "narrative_emotion_driven", + "concepts": [ + "“马”字辈", + "人造肉", + "长安的荔枝", + "韦神概念股" + ] + } + ] + }, + { + "lv1": "华为生态", + "lv1_id": "huawei_eco", + "children": [ + { + "lv2": "华为终端与芯片", + "lv2_id": "huawei_devices_chips", + "concepts": [ + "华为P70", + "华为Mate80", + "华为Pura70", + "华为Mate70手表", + "华为MATE70", + "华为麒麟芯片" + ] + }, + { + "lv2": "华为通信与软件", + "lv2_id": "huawei_comm_software", + "concepts": [ + "华为5G", + "华为通信大模型", + "华为海思星闪", + "华为鸿蒙", + "华为鸿蒙甄选与支付", + "鸿蒙PC" + ] + }, + { + "lv2": "华为AI与云计算", + "lv2_id": "huawei_ai_cloud", + "concepts": [ + "华为存储OceanStor", + "华为910C", + "华为云", + "华为AI存储", + "华为AI容器", + "华为昇腾", + "华为昇腾超节点", + "昇腾异构计算架构-CANN", + "磁电存储" + ] + }, + { + "lv2": "华为通用概念", + "lv2_id": "huawei_general_concepts", + "concepts": [ + "华为", + "华字辈" + ] + } + ] + }, + { + "lv1": "小米生态", + "lv1_id": "xiaomi_eco", + "children": [ + { + "lv2": "小米生态与智能汽车", + "lv2_id": "xiaomi_eco_smart_vehicles", + "concepts": [ + "小米算力AI互联", + "孙潇雅团队概念股", + "小米YU7供应链弹性测算", + "小米大模型", + "小鹏产业链", + "小米概念", + "小米汽车产业链弹性", + "小米汽车产业链", + "小米汽车产业链弹性测算" + ] + } + ] + }, + { + "lv1": "字节跳动生态", + "lv1_id": "bytedance_eco", + "children": [ + { + "lv2": "字节生态与AI终端", + "lv2_id": "bytedance_eco_ai_devices", + "concepts": [ + "努比亚手机", + "华为抖音支付", + "字节概念豆包AI手机", + "字节豆包概念股", + "抖音概念", + "豆包大模型" + ] + } + ] + }, + { + "lv1": "国防军工", + "lv1_id": "national_defense_military", + "children": [ + { + "lv2": "军工装备与信息化", + "lv2_id": "military_equipment_informatization", + "concepts": [ + "AI军工", + "AI无人机军工信息化", + "信息支援概念整理", + "军工-阅兵", + 
"军工信息化", + "军工", + "军工水面水下作战", + "军贸", + "军用无人机反无人机", + "军机", + "国防军工", + "国产航母", + "巴印军贸", + "巴黎航展", + "无人机蜂群", + "水下军工", + "海军", + "电磁弹射概念股", + "电磁发射设备", + "航母福建舰240430", + "远程火力", + "福建-军工" ] } ] diff --git a/ml/README.md b/ml/README.md new file mode 100644 index 00000000..7651277c --- /dev/null +++ b/ml/README.md @@ -0,0 +1,112 @@ +# 概念异动检测 ML 模块 + +基于 Transformer Autoencoder 的概念异动检测系统。 + +## 环境要求 + +- Python 3.8+ +- PyTorch 2.0+ (CUDA 12.x for 5090 GPU) +- ClickHouse, MySQL, Elasticsearch + +## 数据库配置 + +当前配置(`prepare_data.py`): +- MySQL: `192.168.1.5:3306` +- Elasticsearch: `127.0.0.1:9200` +- ClickHouse: `127.0.0.1:9000` + +## 快速开始 + +```bash +# 1. 安装依赖 +pip install -r ml/requirements.txt + +# 2. 安装 PyTorch (5090 需要 CUDA 12.4) +pip install torch --index-url https://download.pytorch.org/whl/cu124 + +# 3. 运行训练 +chmod +x ml/run_training.sh +./ml/run_training.sh +``` + +## 文件说明 + +| 文件 | 说明 | +|------|------| +| `model.py` | Transformer Autoencoder 模型定义 | +| `prepare_data.py` | 数据提取和特征计算 | +| `train.py` | 模型训练脚本 | +| `inference.py` | 推理服务 | +| `enhanced_detector.py` | 增强版检测器(融合 Alpha + ML) | + +## 训练参数 + +```bash +# 完整参数 +./ml/run_training.sh --start 2022-01-01 --end 2024-12-01 --epochs 100 --batch_size 256 + +# 只准备数据 +python ml/prepare_data.py --start 2022-01-01 + +# 只训练(数据已准备好) +python ml/train.py --epochs 100 --batch_size 256 --lr 1e-4 +``` + +## 模型架构 + +``` +输入: (batch, 30, 6) # 30分钟序列,6个特征 + ↓ +Positional Encoding + ↓ +Transformer Encoder (4层, 8头, d=128) + ↓ +Bottleneck (压缩到 32 维) + ↓ +Transformer Decoder (4层) + ↓ +输出: (batch, 30, 6) # 重构序列 + +异动判断: reconstruction_error > threshold +``` + +## 6维特征 + +1. `alpha` - 超额收益(概念涨幅 - 大盘涨幅) +2. `alpha_delta` - Alpha 5分钟变化 +3. `amt_ratio` - 成交额 / 20分钟均值 +4. `amt_delta` - 成交额变化率 +5. `rank_pct` - Alpha 排名百分位 +6. 
`limit_up_ratio` - 涨停股占比 + +## 训练产出 + +训练完成后,`ml/checkpoints/` 包含: +- `best_model.pt` - 最佳模型权重 +- `thresholds.json` - 异动阈值 (P90/P95/P99) +- `normalization_stats.json` - 数据标准化参数 +- `config.json` - 训练配置 + +## 使用示例 + +```python +from ml.inference import ConceptAnomalyDetector + +detector = ConceptAnomalyDetector('ml/checkpoints') + +# 实时检测 +is_anomaly, score = detector.detect( + concept_name="人工智能", + features={ + 'alpha': 2.5, + 'alpha_delta': 0.8, + 'amt_ratio': 1.5, + 'amt_delta': 0.3, + 'rank_pct': 0.95, + 'limit_up_ratio': 0.15, + } +) + +if is_anomaly: + print(f"检测到异动!分数: {score}") +``` diff --git a/ml/__init__.py b/ml/__init__.py new file mode 100644 index 00000000..348015d0 --- /dev/null +++ b/ml/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +""" +概念异动检测 ML 模块 + +提供基于 Transformer Autoencoder 的异动检测功能 +""" + +from .inference import ConceptAnomalyDetector, MLAnomalyService + +__all__ = ['ConceptAnomalyDetector', 'MLAnomalyService'] diff --git a/ml/__pycache__/realtime_detector.cpython-310.pyc b/ml/__pycache__/realtime_detector.cpython-310.pyc new file mode 100644 index 00000000..b926f280 Binary files /dev/null and b/ml/__pycache__/realtime_detector.cpython-310.pyc differ diff --git a/ml/backtest.py b/ml/backtest.py new file mode 100644 index 00000000..57d49e30 --- /dev/null +++ b/ml/backtest.py @@ -0,0 +1,481 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +历史异动回测脚本 + +使用训练好的模型,对历史数据进行异动检测,生成异动记录 + +使用方法: + # 回测指定日期范围 + python backtest.py --start 2024-01-01 --end 2024-12-01 + + # 回测单天 + python backtest.py --start 2024-11-01 --end 2024-11-01 + + # 只生成结果,不写入数据库 + python backtest.py --start 2024-01-01 --dry-run +""" + +import os +import sys +import argparse +import json +from datetime import datetime, timedelta +from pathlib import Path +from typing import Dict, List, Tuple, Optional +from collections import defaultdict + +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm +from sqlalchemy import create_engine, text + +# 添加父目录到路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from model import TransformerAutoencoder + + +# ==================== 配置 ==================== + +MYSQL_ENGINE = create_engine( + "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock", + echo=False +) + +# 特征列表(与训练一致) +FEATURES = [ + 'alpha', + 'alpha_delta', + 'amt_ratio', + 'amt_delta', + 'rank_pct', + 'limit_up_ratio', +] + +# 回测配置 +BACKTEST_CONFIG = { + 'seq_len': 30, # 序列长度 + 'threshold_key': 'p95', # 使用的阈值 + 'min_alpha_abs': 0.5, # 最小 Alpha 绝对值(过滤微小波动) + 'cooldown_minutes': 8, # 同一概念冷却时间 + 'max_alerts_per_minute': 15, # 每分钟最多异动数 + 'clip_value': 10.0, # 极端值截断 +} + + +# ==================== 模型加载 ==================== + +class AnomalyDetector: + """异动检测器""" + + def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'): + self.checkpoint_dir = Path(checkpoint_dir) + + # 设备 + if device == 'auto': + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.device = torch.device(device) + + # 加载配置 + self._load_config() + + # 加载模型 + self._load_model() + + # 加载阈值 + self._load_thresholds() + + print(f"AnomalyDetector 初始化完成") + print(f" 设备: {self.device}") + print(f" 阈值 ({BACKTEST_CONFIG['threshold_key']}): {self.threshold:.6f}") + + def _load_config(self): + config_path = self.checkpoint_dir / 'config.json' + with open(config_path, 'r') as f: + self.config = json.load(f) + + def _load_model(self): + model_path = self.checkpoint_dir / 'best_model.pt' + checkpoint = torch.load(model_path, 
map_location=self.device) + + model_config = self.config['model'].copy() + model_config['use_instance_norm'] = self.config.get('use_instance_norm', True) + + self.model = TransformerAutoencoder(**model_config) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.model.to(self.device) + self.model.eval() + + def _load_thresholds(self): + thresholds_path = self.checkpoint_dir / 'thresholds.json' + with open(thresholds_path, 'r') as f: + thresholds = json.load(f) + + self.threshold = thresholds[BACKTEST_CONFIG['threshold_key']] + + @torch.no_grad() + def compute_anomaly_scores(self, sequences: np.ndarray) -> np.ndarray: + """ + 计算异动分数 + + Args: + sequences: (n_sequences, seq_len, n_features) + Returns: + scores: (n_sequences,) 每个序列最后时刻的异动分数 + """ + # 截断极端值 + sequences = np.clip(sequences, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value']) + + # 转为 tensor + x = torch.FloatTensor(sequences).to(self.device) + + # 计算重构误差 + errors = self.model.compute_reconstruction_error(x, reduction='none') + + # 取最后一个时刻的误差 + scores = errors[:, -1].cpu().numpy() + + return scores + + def is_anomaly(self, score: float) -> bool: + """判断是否异动""" + return score > self.threshold + + +# ==================== 数据加载 ==================== + +def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]: + """加载单天的特征数据""" + file_path = Path(data_dir) / f"features_{date}.parquet" + + if not file_path.exists(): + return None + + df = pd.read_parquet(file_path) + return df + + +def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]: + """获取可用的日期列表""" + data_path = Path(data_dir) + all_files = sorted(data_path.glob("features_*.parquet")) + + dates = [] + for f in all_files: + date = f.stem.replace('features_', '') + if start_date <= date <= end_date: + dates.append(date) + + return dates + + +# ==================== 回测逻辑 ==================== + +def backtest_single_day( + detector: AnomalyDetector, + df: pd.DataFrame, + date: str, + seq_len: int = 30 +) -> List[Dict]: + """ + 回测单天数据 + + Args: + detector: 异动检测器 + df: 当天的特征数据 + date: 日期 + seq_len: 序列长度 + + Returns: + alerts: 异动列表 + """ + alerts = [] + + # 按概念分组 + grouped = df.groupby('concept_id', sort=False) + + # 冷却记录 {concept_id: last_alert_timestamp} + cooldown = {} + + # 获取所有时间点 + all_timestamps = sorted(df['timestamp'].unique()) + + if len(all_timestamps) < seq_len: + return alerts + + # 对每个时间点进行检测(从第 seq_len 个开始) + for t_idx in range(seq_len - 1, len(all_timestamps)): + current_time = all_timestamps[t_idx] + window_start_time = all_timestamps[t_idx - seq_len + 1] + + minute_alerts = [] + + # 收集该时刻所有概念的序列 + concept_sequences = [] + concept_infos = [] + + for concept_id, concept_df in grouped: + # 获取该概念在时间窗口内的数据 + mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time) + window_df = concept_df[mask].sort_values('timestamp') + + if len(window_df) < seq_len: + continue + + # 取最后 seq_len 个点 + window_df = window_df.tail(seq_len) + + # 提取特征 + features = window_df[FEATURES].values + + # 处理缺失值 + features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0) + + # 获取当前时刻的信息 + current_row = window_df.iloc[-1] + + concept_sequences.append(features) + concept_infos.append({ + 'concept_id': concept_id, + 'timestamp': current_time, + 'alpha': current_row.get('alpha', 0), + 'alpha_delta': current_row.get('alpha_delta', 0), + 'amt_ratio': current_row.get('amt_ratio', 1), + 'limit_up_ratio': current_row.get('limit_up_ratio', 0), + 'limit_down_ratio': 
current_row.get('limit_down_ratio', 0), + 'rank_pct': current_row.get('rank_pct', 0.5), + 'stock_count': current_row.get('stock_count', 0), + 'total_amt': current_row.get('total_amt', 0), + }) + + if not concept_sequences: + continue + + # 批量计算异动分数 + sequences_array = np.array(concept_sequences) + scores = detector.compute_anomaly_scores(sequences_array) + + # 检测异动 + for i, (info, score) in enumerate(zip(concept_infos, scores)): + concept_id = info['concept_id'] + alpha = info['alpha'] + + # 过滤小波动 + if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']: + continue + + # 检查冷却 + if concept_id in cooldown: + last_alert = cooldown[concept_id] + if isinstance(current_time, datetime): + time_diff = (current_time - last_alert).total_seconds() / 60 + else: + # timestamp 是字符串或其他格式 + time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1 # 跳过冷却检查 + + if time_diff < BACKTEST_CONFIG['cooldown_minutes']: + continue + + # 判断是否异动 + if not detector.is_anomaly(score): + continue + + # 记录异动 + alert_type = 'surge_up' if alpha > 0 else 'surge_down' + + alert = { + 'concept_id': concept_id, + 'alert_time': current_time, + 'trade_date': date, + 'alert_type': alert_type, + 'anomaly_score': float(score), + 'threshold': detector.threshold, + **info + } + + minute_alerts.append(alert) + cooldown[concept_id] = current_time + + # 按分数排序,限制数量 + minute_alerts.sort(key=lambda x: x['anomaly_score'], reverse=True) + alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']]) + + return alerts + + +# ==================== 数据库写入 ==================== + +def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int: + """保存异动到 MySQL""" + if not alerts: + return 0 + + if dry_run: + print(f" [Dry Run] 将写入 {len(alerts)} 条异动") + return len(alerts) + + saved = 0 + with MYSQL_ENGINE.begin() as conn: + for alert in alerts: + try: + # 检查是否已存在 + check_sql = text(""" + SELECT id FROM concept_minute_alert + WHERE concept_id = :concept_id + AND alert_time = :alert_time + AND trade_date = :trade_date + """) + exists = conn.execute(check_sql, { + 'concept_id': alert['concept_id'], + 'alert_time': alert['alert_time'], + 'trade_date': alert['trade_date'], + }).fetchone() + + if exists: + continue + + # 插入新记录 + insert_sql = text(""" + INSERT INTO concept_minute_alert + (concept_id, concept_name, alert_time, alert_type, trade_date, + change_pct, zscore, importance_score, stock_count, extra_info) + VALUES + (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date, + :change_pct, :zscore, :importance_score, :stock_count, :extra_info) + """) + + conn.execute(insert_sql, { + 'concept_id': alert['concept_id'], + 'concept_name': alert.get('concept_name', ''), + 'alert_time': alert['alert_time'], + 'alert_type': alert['alert_type'], + 'trade_date': alert['trade_date'], + 'change_pct': alert.get('alpha', 0), + 'zscore': alert['anomaly_score'], + 'importance_score': alert['anomaly_score'], + 'stock_count': alert.get('stock_count', 0), + 'extra_info': json.dumps({ + 'detection_method': 'ml_autoencoder', + 'threshold': alert['threshold'], + 'alpha': alert.get('alpha', 0), + 'amt_ratio': alert.get('amt_ratio', 1), + }, ensure_ascii=False) + }) + + saved += 1 + + except Exception as e: + print(f" 保存失败: {alert['concept_id']} - {e}") + + return saved + + +def export_alerts_to_csv(alerts: List[Dict], output_path: str): + """导出异动到 CSV""" + if not alerts: + return + + df = pd.DataFrame(alerts) + df.to_csv(output_path, index=False, encoding='utf-8-sig') + print(f"已导出到: {output_path}") + + +# ==================== 主函数 ==================== 
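+# ------------------------------------------------------------------
+# 用法示意(草图,非正式 API):下面的辅助函数演示如何在代码中直接
+# 串联上面定义的 AnomalyDetector / load_daily_features /
+# backtest_single_day / export_alerts_to_csv,而不经过命令行。
+# 函数名 example_run_one_day 与导出文件名均为本说明新增的示例假设,
+# 实际入口请以下方 main() 为准。
+# ------------------------------------------------------------------
+
+def example_run_one_day(date: str,
+                        data_dir: str = 'ml/data',
+                        checkpoint_dir: str = 'ml/checkpoints') -> List[Dict]:
+    """示例草图:加载单日特征 -> 检测异动 -> 导出 CSV(不写数据库)"""
+    detector = AnomalyDetector(checkpoint_dir)
+    df = load_daily_features(data_dir, date)
+    if df is None or df.empty:
+        return []
+    alerts = backtest_single_day(
+        detector, df, date,
+        seq_len=BACKTEST_CONFIG['seq_len']
+    )
+    if alerts:
+        export_alerts_to_csv(alerts, f'alerts_{date}.csv')
+    return alerts
+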
+ 
+def main():
+    parser = argparse.ArgumentParser(description='历史异动回测')
+    parser.add_argument('--data_dir', type=str, default='ml/data',
+                        help='特征数据目录')
+    parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints',
+                        help='模型检查点目录')
+    parser.add_argument('--start', type=str, required=True,
+                        help='开始日期 (YYYY-MM-DD)')
+    parser.add_argument('--end', type=str, default=None,
+                        help='结束日期 (YYYY-MM-DD),默认=start')
+    parser.add_argument('--dry-run', action='store_true',
+                        help='只计算,不写入数据库')
+    parser.add_argument('--export-csv', type=str, default=None,
+                        help='导出 CSV 文件路径')
+    parser.add_argument('--device', type=str, default='auto',
+                        help='设备 (auto/cuda/cpu)')
+
+    args = parser.parse_args()
+
+    # 与模块 docstring 中的单日示例保持一致:--end 缺省时等于 --start
+    if args.end is None:
+        args.end = args.start
+
+    print("=" * 60)
+    print("历史异动回测")
+    print("=" * 60)
+    print(f"日期范围: {args.start} ~ {args.end}")
+    print(f"数据目录: {args.data_dir}")
+    print(f"模型目录: {args.checkpoint_dir}")
+    print(f"Dry Run: {args.dry_run}")
+    print("=" * 60)
+
+    # 初始化检测器
+    detector = AnomalyDetector(args.checkpoint_dir, args.device)
+
+    # 获取可用日期
+    dates = get_available_dates(args.data_dir, args.start, args.end)
+
+    if not dates:
+        print(f"未找到 {args.start} ~ {args.end} 范围内的数据")
+        return
+
+    print(f"\n找到 {len(dates)} 天的数据")
+
+    # 回测
+    all_alerts = []
+    total_saved = 0
+
+    for date in tqdm(dates, desc="回测进度"):
+        # 加载数据
+        df = load_daily_features(args.data_dir, date)
+
+        if df is None or df.empty:
+            continue
+
+        # 回测单天
+        alerts = backtest_single_day(
+            detector, df, date,
+            seq_len=BACKTEST_CONFIG['seq_len']
+        )
+
+        if alerts:
+            all_alerts.extend(alerts)
+
+            # 写入数据库
+            saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run)
+            total_saved += saved
+
+            if not args.dry_run:
+                tqdm.write(f" {date}: 检测到 {len(alerts)} 个异动,保存 {saved} 条")
+
+    # 导出 CSV
+    if args.export_csv and all_alerts:
+        export_alerts_to_csv(all_alerts, args.export_csv)
+
+    # 汇总
+    print("\n" + "=" * 60)
+    print("回测完成!")
+    print("=" * 60)
+    print(f"总计检测到: {len(all_alerts)} 个异动")
+    print(f"保存到数据库: {total_saved} 条")
+
+    # 统计
+    if all_alerts:
+        df_alerts = pd.DataFrame(all_alerts)
+        print(f"\n异动类型分布:")
+        print(df_alerts['alert_type'].value_counts())
+
+        print(f"\n异动分数统计:")
+        print(f" Mean: {df_alerts['anomaly_score'].mean():.4f}")
+        print(f" Max: {df_alerts['anomaly_score'].max():.4f}")
+        print(f" Min: {df_alerts['anomaly_score'].min():.4f}")
+
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ml/backtest_fast.py b/ml/backtest_fast.py
new file mode 100644
index 00000000..e1c06254
--- /dev/null
+++ b/ml/backtest_fast.py
@@ -0,0 +1,859 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+快速融合异动回测脚本
+
+优化策略:
+1. 预先构建所有序列(向量化),避免循环内重复切片
+2. 批量 ML 推理(一次推理所有候选)
+3. 
使用 NumPy 向量化操作替代 Python 循环
+
+性能对比:
+- 原版:5分钟/天
+- 优化版:预计 10-30秒/天
+"""
+
+import os
+import sys
+import argparse
+import json
+from datetime import datetime
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+from collections import defaultdict
+
+import numpy as np
+import pandas as pd
+import torch
+from tqdm import tqdm
+from sqlalchemy import create_engine, text
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+# ==================== 配置 ====================
+
+MYSQL_ENGINE = create_engine(
+    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
+    echo=False
+)
+
+FEATURES = ['alpha', 'alpha_delta', 'amt_ratio', 'amt_delta', 'rank_pct', 'limit_up_ratio']
+
+CONFIG = {
+    'seq_len': 15,              # 序列长度(支持跨日后可从 9:30 检测)
+    'min_alpha_abs': 0.3,       # 最小 alpha 过滤
+    'cooldown_minutes': 8,
+    'max_alerts_per_minute': 20,
+    'clip_value': 10.0,
+    # === 融合权重:均衡 ===
+    'rule_weight': 0.5,
+    'ml_weight': 0.5,
+    # === 触发阈值 ===
+    'rule_trigger': 65,         # 60 -> 65,略提高规则门槛
+    'ml_trigger': 70,           # 75 -> 70,略降低 ML 门槛
+    'fusion_trigger': 45,
+}
+
+
+# ==================== 规则评分(向量化版)====================
+
+def get_size_adjusted_thresholds(stock_count: np.ndarray) -> np.ndarray:
+    """
+    根据概念股票数量计算动态阈值调整系数
+
+    设计思路:
+    - 小概念(<10 只):波动大是正常的,需要更高阈值
+    - 中概念(10-50 只):标准阈值
+    - 大概念(>50 只):能有明显波动说明是真异动,降低阈值
+
+    返回逐概念的调整系数数组(使用时乘以各指标的基准阈值)
+    """
+    n = len(stock_count)
+
+    # 基于股票数量的调整系数
+    # 小概念:系数 > 1(提高阈值,更难触发)
+    # 大概念:系数 < 1(降低阈值,更容易触发)
+    size_factor = np.ones(n)
+
+    # 微型概念(<5 只):阈值 × 1.8
+    tiny = stock_count < 5
+    size_factor[tiny] = 1.8
+
+    # 小概念(5-10 只):阈值 × 1.4
+    small = (stock_count >= 5) & (stock_count < 10)
+    size_factor[small] = 1.4
+
+    # 中小概念(10-20 只):阈值 × 1.2
+    medium_small = (stock_count >= 10) & (stock_count < 20)
+    size_factor[medium_small] = 1.2
+
+    # 中概念(20-50 只):标准阈值 × 1.0
+    medium = (stock_count >= 20) & (stock_count < 50)
+    size_factor[medium] = 1.0
+
+    # 大概念(50-100 只):阈值 × 0.85
+    large = (stock_count >= 50) & (stock_count < 100)
+    size_factor[large] = 0.85
+
+    # 超大概念(>100 只):阈值 × 0.7
+    xlarge = stock_count >= 100
+    size_factor[xlarge] = 0.7
+
+    return size_factor
+
+
+def score_rules_batch(df: pd.DataFrame) -> Tuple[np.ndarray, List[List[str]]]:
+    """
+    批量计算规则得分(向量化)- 考虑概念规模版
+
+    设计原则:
+    - 规则作为辅助信号,不应单独主导决策
+    - 根据概念股票数量动态调整阈值
+    - 大概念异动更有价值,小概念需要更大波动才算异动
+
+    Args:
+        df: DataFrame,包含所有特征列(必须包含 stock_count)
+    Returns:
+        scores: (n,) 规则得分数组
+        triggered_rules: 每行触发的规则列表
+    """
+    n = len(df)
+    scores = np.zeros(n)
+    triggered = [[] for _ in range(n)]
+
+    alpha = df['alpha'].values
+    alpha_delta = df['alpha_delta'].values
+    amt_ratio = df['amt_ratio'].values
+    amt_delta = df['amt_delta'].values
+    rank_pct = df['rank_pct'].values
+    limit_up_ratio = df['limit_up_ratio'].values
+    stock_count = df['stock_count'].values if 'stock_count' in df.columns else np.full(n, 20)
+
+    alpha_abs = np.abs(alpha)
+    alpha_delta_abs = np.abs(alpha_delta)
+
+    # 获取基于规模的调整系数
+    size_factor = get_size_adjusted_thresholds(stock_count)
+
+    # ========== Alpha 规则(动态阈值)==========
+    # 基准阈值:极强 5%,强 4%,中等 3%
+    # 实际阈值 = 基准 × size_factor
+
+    # 极强信号
+    alpha_extreme_thresh = 5.0 * size_factor
+    mask = alpha_abs >= alpha_extreme_thresh
+    scores[mask] += 20
+    for i in np.where(mask)[0]: triggered[i].append('alpha_extreme')
+
+    # 强信号
+    alpha_strong_thresh = 4.0 * size_factor
+    mask = (alpha_abs >= alpha_strong_thresh) & (alpha_abs < alpha_extreme_thresh)
+    scores[mask] += 15
+    for i in np.where(mask)[0]: triggered[i].append('alpha_strong')
+
+    # 中等信号
+    alpha_medium_thresh = 3.0 * size_factor
+    mask = (alpha_abs >= 
alpha_medium_thresh) & (alpha_abs < alpha_strong_thresh) + scores[mask] += 10 + for i in np.where(mask)[0]: triggered[i].append('alpha_medium') + + # ========== Alpha 加速度规则(动态阈值)========== + delta_strong_thresh = 2.0 * size_factor + mask = alpha_delta_abs >= delta_strong_thresh + scores[mask] += 15 + for i in np.where(mask)[0]: triggered[i].append('alpha_delta_strong') + + delta_medium_thresh = 1.5 * size_factor + mask = (alpha_delta_abs >= delta_medium_thresh) & (alpha_delta_abs < delta_strong_thresh) + scores[mask] += 10 + for i in np.where(mask)[0]: triggered[i].append('alpha_delta_medium') + + # ========== 成交额规则(不受规模影响,放量就是放量)========== + mask = amt_ratio >= 10.0 + scores[mask] += 20 + for i in np.where(mask)[0]: triggered[i].append('volume_extreme') + + mask = (amt_ratio >= 6.0) & (amt_ratio < 10.0) + scores[mask] += 12 + for i in np.where(mask)[0]: triggered[i].append('volume_strong') + + # ========== 排名规则 ========== + mask = rank_pct >= 0.98 + scores[mask] += 15 + for i in np.where(mask)[0]: triggered[i].append('rank_top') + + mask = rank_pct <= 0.02 + scores[mask] += 15 + for i in np.where(mask)[0]: triggered[i].append('rank_bottom') + + # ========== 涨停规则(动态阈值)========== + # 大概念有涨停更有意义 + limit_high_thresh = 0.30 * size_factor + mask = limit_up_ratio >= limit_high_thresh + scores[mask] += 20 + for i in np.where(mask)[0]: triggered[i].append('limit_up_high') + + limit_medium_thresh = 0.20 * size_factor + mask = (limit_up_ratio >= limit_medium_thresh) & (limit_up_ratio < limit_high_thresh) + scores[mask] += 12 + for i in np.where(mask)[0]: triggered[i].append('limit_up_medium') + + # ========== 概念规模加分(大概念异动更有价值)========== + # 大概念(50+)额外加分 + large_concept = stock_count >= 50 + has_signal = scores > 0 # 至少触发了某个规则 + mask = large_concept & has_signal + scores[mask] += 10 + for i in np.where(mask)[0]: triggered[i].append('large_concept_bonus') + + # 超大概念(100+)再加分 + xlarge_concept = stock_count >= 100 + mask = xlarge_concept & has_signal + scores[mask] += 10 + for i in np.where(mask)[0]: triggered[i].append('xlarge_concept_bonus') + + # ========== 组合规则(动态阈值)========== + combo_alpha_thresh = 3.0 * size_factor + + # Alpha + 放量 + 排名(三重验证) + mask = (alpha_abs >= combo_alpha_thresh) & (amt_ratio >= 5.0) & ((rank_pct >= 0.95) | (rank_pct <= 0.05)) + scores[mask] += 20 + for i in np.where(mask)[0]: triggered[i].append('triple_signal') + + # Alpha + 涨停(强组合) + mask = (alpha_abs >= combo_alpha_thresh) & (limit_up_ratio >= 0.15 * size_factor) + scores[mask] += 15 + for i in np.where(mask)[0]: triggered[i].append('alpha_with_limit') + + # ========== 小概念惩罚(过滤噪音)========== + # 微型概念(<5 只)如果只有单一信号,减分 + tiny_concept = stock_count < 5 + single_rule = np.array([len(t) <= 1 for t in triggered]) + mask = tiny_concept & single_rule & (scores > 0) + scores[mask] *= 0.5 # 减半 + for i in np.where(mask)[0]: triggered[i].append('tiny_concept_penalty') + + scores = np.clip(scores, 0, 100) + return scores, triggered + + +# ==================== ML 评分器 ==================== + +class FastMLScorer: + """快速 ML 评分器""" + + def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'): + self.checkpoint_dir = Path(checkpoint_dir) + + if device == 'auto': + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + elif device == 'cuda' and not torch.cuda.is_available(): + print("警告: CUDA 不可用,使用 CPU") + self.device = torch.device('cpu') + else: + self.device = torch.device(device) + + self.model = None + self.thresholds = None + self._load_model() + + def _load_model(self): + model_path = 
self.checkpoint_dir / 'best_model.pt' + thresholds_path = self.checkpoint_dir / 'thresholds.json' + config_path = self.checkpoint_dir / 'config.json' + + if not model_path.exists(): + print(f"警告: 模型不存在 {model_path}") + return + + try: + from model import LSTMAutoencoder + + config = {} + if config_path.exists(): + with open(config_path) as f: + config = json.load(f).get('model', {}) + + # 处理旧配置键名 + if 'd_model' in config: + config['hidden_dim'] = config.pop('d_model') // 2 + for key in ['num_encoder_layers', 'num_decoder_layers', 'nhead', 'dim_feedforward', 'max_seq_len', 'use_instance_norm']: + config.pop(key, None) + if 'num_layers' not in config: + config['num_layers'] = 1 + + checkpoint = torch.load(model_path, map_location='cpu') + self.model = LSTMAutoencoder(**config) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.model.to(self.device) + self.model.eval() + + if thresholds_path.exists(): + with open(thresholds_path) as f: + self.thresholds = json.load(f) + + print(f"ML模型加载成功 (设备: {self.device})") + except Exception as e: + print(f"ML模型加载失败: {e}") + self.model = None + + def is_ready(self): + return self.model is not None + + @torch.no_grad() + def score_batch(self, sequences: np.ndarray) -> np.ndarray: + """ + 批量计算 ML 得分 + + Args: + sequences: (batch, seq_len, n_features) + Returns: + scores: (batch,) 0-100 分数 + """ + if not self.is_ready() or len(sequences) == 0: + return np.zeros(len(sequences)) + + x = torch.FloatTensor(sequences).to(self.device) + output, _ = self.model(x) + mse = ((output - x) ** 2).mean(dim=-1) + errors = mse[:, -1].cpu().numpy() + + p95 = self.thresholds.get('p95', 0.1) if self.thresholds else 0.1 + scores = np.clip(errors / p95 * 50, 0, 100) + return scores + + +# ==================== 快速回测 ==================== + +def build_sequences_fast( + df: pd.DataFrame, + seq_len: int = 30, + prev_df: pd.DataFrame = None +) -> Tuple[np.ndarray, pd.DataFrame]: + """ + 快速构建所有有效序列 + + 支持跨日序列:用前一天收盘数据 + 当天开盘数据拼接,实现 9:30 就能检测 + + Args: + df: 当天数据 + seq_len: 序列长度 + prev_df: 前一天数据(可选,用于构建开盘时的序列) + + 返回: + sequences: (n_valid, seq_len, n_features) 所有有效序列 + info_df: 对应的元信息 DataFrame + """ + # 确保按概念和时间排序 + df = df.sort_values(['concept_id', 'timestamp']).reset_index(drop=True) + + # 如果有前一天数据,按概念构建尾部缓存(取每个概念最后 seq_len-1 条) + prev_cache = {} + if prev_df is not None and len(prev_df) > 0: + prev_df = prev_df.sort_values(['concept_id', 'timestamp']) + for concept_id, gdf in prev_df.groupby('concept_id'): + tail_data = gdf.tail(seq_len - 1) + if len(tail_data) > 0: + feat_matrix = tail_data[FEATURES].values + feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0) + feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value']) + prev_cache[concept_id] = feat_matrix + + # 按概念分组 + groups = df.groupby('concept_id') + + sequences = [] + infos = [] + + for concept_id, gdf in groups: + gdf = gdf.reset_index(drop=True) + + # 获取特征矩阵 + feat_matrix = gdf[FEATURES].values + feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0) + feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value']) + + # 如果有前一天缓存,拼接到当天数据前面 + if concept_id in prev_cache: + prev_data = prev_cache[concept_id] + combined_matrix = np.vstack([prev_data, feat_matrix]) + # 计算偏移量:前一天数据的长度 + offset = len(prev_data) + else: + combined_matrix = feat_matrix + offset = 0 + + # 滑动窗口构建序列 + n_total = len(combined_matrix) + if n_total < seq_len: + continue + + for i in range(n_total - seq_len + 1): + seq = combined_matrix[i:i + seq_len] + + # 
计算对应当天数据的索引 + # 序列最后一个点的位置 = i + seq_len - 1 + # 对应当天数据的索引 = (i + seq_len - 1) - offset + today_idx = i + seq_len - 1 - offset + + # 只要序列的最后一个点是当天的数据,就记录 + if today_idx < 0 or today_idx >= len(gdf): + continue + + sequences.append(seq) + + # 记录最后一个时间步的信息(当天的) + row = gdf.iloc[today_idx] + infos.append({ + 'concept_id': concept_id, + 'timestamp': row['timestamp'], + 'alpha': row['alpha'], + 'alpha_delta': row.get('alpha_delta', 0), + 'amt_ratio': row.get('amt_ratio', 1), + 'amt_delta': row.get('amt_delta', 0), + 'rank_pct': row.get('rank_pct', 0.5), + 'limit_up_ratio': row.get('limit_up_ratio', 0), + 'stock_count': row.get('stock_count', 0), + 'total_amt': row.get('total_amt', 0), + }) + + if not sequences: + return np.array([]), pd.DataFrame() + + return np.array(sequences), pd.DataFrame(infos) + + +def backtest_single_day_fast( + ml_scorer: FastMLScorer, + df: pd.DataFrame, + date: str, + config: Dict, + prev_df: pd.DataFrame = None +) -> List[Dict]: + """ + 快速回测单天(向量化版本) + + Args: + ml_scorer: ML 评分器 + df: 当天数据 + date: 日期 + config: 配置 + prev_df: 前一天数据(用于 9:30 开始检测) + """ + seq_len = config.get('seq_len', 30) + + # 1. 构建所有序列(支持跨日) + sequences, info_df = build_sequences_fast(df, seq_len, prev_df) + + if len(sequences) == 0: + return [] + + # 2. 过滤小波动 + alpha_abs = np.abs(info_df['alpha'].values) + valid_mask = alpha_abs >= config['min_alpha_abs'] + + sequences = sequences[valid_mask] + info_df = info_df[valid_mask].reset_index(drop=True) + + if len(sequences) == 0: + return [] + + # 3. 批量规则评分 + rule_scores, triggered_rules = score_rules_batch(info_df) + + # 4. 批量 ML 评分(分批处理避免显存溢出) + batch_size = 2048 + ml_scores = [] + for i in range(0, len(sequences), batch_size): + batch_seq = sequences[i:i+batch_size] + batch_scores = ml_scorer.score_batch(batch_seq) + ml_scores.append(batch_scores) + ml_scores = np.concatenate(ml_scores) if ml_scores else np.zeros(len(sequences)) + + # 5. 融合得分 + w1, w2 = config['rule_weight'], config['ml_weight'] + final_scores = w1 * rule_scores + w2 * ml_scores + + # 6. 判断异动 + is_anomaly = ( + (rule_scores >= config['rule_trigger']) | + (ml_scores >= config['ml_trigger']) | + (final_scores >= config['fusion_trigger']) + ) + + # 7. 应用冷却期(按概念+时间排序后处理) + info_df['rule_score'] = rule_scores + info_df['ml_score'] = ml_scores + info_df['final_score'] = final_scores + info_df['is_anomaly'] = is_anomaly + info_df['triggered_rules'] = triggered_rules + + # 只保留异动 + anomaly_df = info_df[info_df['is_anomaly']].copy() + + if len(anomaly_df) == 0: + return [] + + # 应用冷却期 + anomaly_df = anomaly_df.sort_values(['concept_id', 'timestamp']) + cooldown = {} + keep_mask = [] + + for _, row in anomaly_df.iterrows(): + cid = row['concept_id'] + ts = row['timestamp'] + + if cid in cooldown: + try: + diff = (ts - cooldown[cid]).total_seconds() / 60 + except: + diff = config['cooldown_minutes'] + 1 + + if diff < config['cooldown_minutes']: + keep_mask.append(False) + continue + + cooldown[cid] = ts + keep_mask.append(True) + + anomaly_df = anomaly_df[keep_mask] + + # 8. 
按时间分组,每分钟最多 max_alerts_per_minute 个 + alerts = [] + for ts, group in anomaly_df.groupby('timestamp'): + group = group.nlargest(config['max_alerts_per_minute'], 'final_score') + + for _, row in group.iterrows(): + alpha = row['alpha'] + if alpha >= 1.5: + atype = 'surge_up' + elif alpha <= -1.5: + atype = 'surge_down' + elif row['amt_ratio'] >= 3.0: + atype = 'volume_spike' + else: + atype = 'unknown' + + rule_score = row['rule_score'] + ml_score = row['ml_score'] + final_score = row['final_score'] + + if rule_score >= config['rule_trigger']: + trigger = f'规则强信号({rule_score:.0f}分)' + elif ml_score >= config['ml_trigger']: + trigger = f'ML强信号({ml_score:.0f}分)' + else: + trigger = f'融合触发({final_score:.0f}分)' + + alerts.append({ + 'concept_id': row['concept_id'], + 'alert_time': row['timestamp'], + 'trade_date': date, + 'alert_type': atype, + 'final_score': final_score, + 'rule_score': rule_score, + 'ml_score': ml_score, + 'trigger_reason': trigger, + 'triggered_rules': row['triggered_rules'], + 'alpha': alpha, + 'alpha_delta': row['alpha_delta'], + 'amt_ratio': row['amt_ratio'], + 'amt_delta': row['amt_delta'], + 'rank_pct': row['rank_pct'], + 'limit_up_ratio': row['limit_up_ratio'], + 'stock_count': row['stock_count'], + 'total_amt': row['total_amt'], + }) + + return alerts + + +# ==================== 数据加载 ==================== + +def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]: + file_path = Path(data_dir) / f"features_{date}.parquet" + if not file_path.exists(): + return None + return pd.read_parquet(file_path) + + +def get_available_dates(data_dir: str, start: str, end: str) -> List[str]: + data_path = Path(data_dir) + dates = [] + for f in sorted(data_path.glob("features_*.parquet")): + d = f.stem.replace('features_', '') + if start <= d <= end: + dates.append(d) + return dates + + +def get_prev_trading_day(data_dir: str, date: str) -> Optional[str]: + """获取给定日期之前最近的有数据的交易日""" + data_path = Path(data_dir) + all_dates = sorted([f.stem.replace('features_', '') for f in data_path.glob("features_*.parquet")]) + + for i, d in enumerate(all_dates): + if d == date and i > 0: + return all_dates[i - 1] + return None + + +def export_to_csv(alerts: List[Dict], path: str): + if alerts: + pd.DataFrame(alerts).to_csv(path, index=False, encoding='utf-8-sig') + print(f"已导出: {path}") + + +# ==================== 数据库写入 ==================== + +def init_db_table(): + """ + 初始化数据库表(如果不存在则创建) + + 表结构说明: + - concept_id: 概念ID + - alert_time: 异动时间(精确到分钟) + - trade_date: 交易日期 + - alert_type: 异动类型(surge_up/surge_down/volume_spike/unknown) + - final_score: 最终得分(0-100) + - rule_score: 规则得分(0-100) + - ml_score: ML得分(0-100) + - trigger_reason: 触发原因 + - alpha: 超额收益率 + - alpha_delta: alpha变化速度 + - amt_ratio: 成交额放大倍数 + - rank_pct: 排名百分位 + - stock_count: 概念股票数量 + - triggered_rules: 触发的规则列表(JSON) + """ + create_sql = text(""" + CREATE TABLE IF NOT EXISTS concept_anomaly_hybrid ( + id INT AUTO_INCREMENT PRIMARY KEY, + concept_id VARCHAR(64) NOT NULL, + alert_time DATETIME NOT NULL, + trade_date DATE NOT NULL, + alert_type VARCHAR(32) NOT NULL, + final_score FLOAT NOT NULL, + rule_score FLOAT NOT NULL, + ml_score FLOAT NOT NULL, + trigger_reason VARCHAR(64), + alpha FLOAT, + alpha_delta FLOAT, + amt_ratio FLOAT, + amt_delta FLOAT, + rank_pct FLOAT, + limit_up_ratio FLOAT, + stock_count INT, + total_amt FLOAT, + triggered_rules JSON, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date), + INDEX idx_trade_date (trade_date), + INDEX 
idx_concept_id (concept_id), + INDEX idx_final_score (final_score), + INDEX idx_alert_type (alert_type) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动检测结果(融合版)' + """) + + with MYSQL_ENGINE.begin() as conn: + conn.execute(create_sql) + print("数据库表已就绪: concept_anomaly_hybrid") + + +def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int: + """ + 保存异动到 MySQL + + Args: + alerts: 异动列表 + dry_run: 是否只模拟,不实际写入 + + Returns: + 实际保存的记录数 + """ + if not alerts: + return 0 + + if dry_run: + print(f" [Dry Run] 将写入 {len(alerts)} 条异动") + return len(alerts) + + saved = 0 + skipped = 0 + + with MYSQL_ENGINE.begin() as conn: + for alert in alerts: + try: + # 检查是否已存在(使用 INSERT IGNORE 更高效) + insert_sql = text(""" + INSERT IGNORE INTO concept_anomaly_hybrid + (concept_id, alert_time, trade_date, alert_type, + final_score, rule_score, ml_score, trigger_reason, + alpha, alpha_delta, amt_ratio, amt_delta, + rank_pct, limit_up_ratio, stock_count, total_amt, + triggered_rules) + VALUES + (:concept_id, :alert_time, :trade_date, :alert_type, + :final_score, :rule_score, :ml_score, :trigger_reason, + :alpha, :alpha_delta, :amt_ratio, :amt_delta, + :rank_pct, :limit_up_ratio, :stock_count, :total_amt, + :triggered_rules) + """) + + result = conn.execute(insert_sql, { + 'concept_id': alert['concept_id'], + 'alert_time': alert['alert_time'], + 'trade_date': alert['trade_date'], + 'alert_type': alert['alert_type'], + 'final_score': alert['final_score'], + 'rule_score': alert['rule_score'], + 'ml_score': alert['ml_score'], + 'trigger_reason': alert['trigger_reason'], + 'alpha': alert.get('alpha', 0), + 'alpha_delta': alert.get('alpha_delta', 0), + 'amt_ratio': alert.get('amt_ratio', 1), + 'amt_delta': alert.get('amt_delta', 0), + 'rank_pct': alert.get('rank_pct', 0.5), + 'limit_up_ratio': alert.get('limit_up_ratio', 0), + 'stock_count': alert.get('stock_count', 0), + 'total_amt': alert.get('total_amt', 0), + 'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False), + }) + + if result.rowcount > 0: + saved += 1 + else: + skipped += 1 + + except Exception as e: + print(f" 保存失败: {alert['concept_id']} @ {alert['alert_time']} - {e}") + + if skipped > 0: + print(f" 跳过 {skipped} 条重复记录") + + return saved + + +def clear_alerts_by_date(trade_date: str) -> int: + """清除指定日期的异动记录(用于重新回测)""" + with MYSQL_ENGINE.begin() as conn: + result = conn.execute( + text("DELETE FROM concept_anomaly_hybrid WHERE trade_date = :trade_date"), + {'trade_date': trade_date} + ) + return result.rowcount + + +def analyze_alerts(alerts: List[Dict]): + if not alerts: + print("无异动") + return + + df = pd.DataFrame(alerts) + print(f"\n总异动: {len(alerts)}") + print(f"\n类型分布:\n{df['alert_type'].value_counts()}") + print(f"\n得分统计:") + print(f" 最终: {df['final_score'].mean():.1f} (max: {df['final_score'].max():.1f})") + print(f" 规则: {df['rule_score'].mean():.1f} (max: {df['rule_score'].max():.1f})") + print(f" ML: {df['ml_score'].mean():.1f} (max: {df['ml_score'].max():.1f})") + + trigger_type = df['trigger_reason'].apply( + lambda x: '规则' if '规则' in x else ('ML' if 'ML' in x else '融合') + ) + print(f"\n触发来源:\n{trigger_type.value_counts()}") + + +# ==================== 主函数 ==================== + +def main(): + parser = argparse.ArgumentParser(description='快速融合异动回测') + parser.add_argument('--data_dir', default='ml/data') + parser.add_argument('--checkpoint_dir', default='ml/checkpoints') + parser.add_argument('--start', required=True) + parser.add_argument('--end', default=None) + parser.add_argument('--dry-run', 
action='store_true', help='模拟运行,不写入数据库') + parser.add_argument('--export-csv', default=None, help='导出 CSV 文件路径') + parser.add_argument('--save-db', action='store_true', help='保存结果到数据库') + parser.add_argument('--clear-first', action='store_true', help='写入前先清除该日期的旧数据') + parser.add_argument('--device', default='auto') + + args = parser.parse_args() + if args.end is None: + args.end = args.start + + print("=" * 60) + print("快速融合异动回测") + print("=" * 60) + print(f"日期: {args.start} ~ {args.end}") + print(f"设备: {args.device}") + print(f"保存数据库: {args.save_db}") + print("=" * 60) + + # 初始化数据库表(如果需要保存) + if args.save_db and not args.dry_run: + init_db_table() + + # 初始化 ML 评分器 + ml_scorer = FastMLScorer(args.checkpoint_dir, args.device) + + # 获取日期 + dates = get_available_dates(args.data_dir, args.start, args.end) + if not dates: + print("无数据") + return + + print(f"找到 {len(dates)} 天数据\n") + + # 回测(支持跨日序列) + all_alerts = [] + total_saved = 0 + prev_df = None # 缓存前一天数据 + + for i, date in enumerate(tqdm(dates, desc="回测")): + df = load_daily_features(args.data_dir, date) + if df is None or df.empty: + prev_df = None # 当天无数据,清空缓存 + continue + + # 第一天需要加载前一天数据(如果存在) + if i == 0 and prev_df is None: + prev_date = get_prev_trading_day(args.data_dir, date) + if prev_date: + prev_df = load_daily_features(args.data_dir, prev_date) + if prev_df is not None: + tqdm.write(f" 加载前一天数据: {prev_date}") + + alerts = backtest_single_day_fast(ml_scorer, df, date, CONFIG, prev_df) + all_alerts.extend(alerts) + + # 保存到数据库 + if args.save_db and alerts: + if args.clear_first and not args.dry_run: + cleared = clear_alerts_by_date(date) + if cleared > 0: + tqdm.write(f" 清除 {date} 旧数据: {cleared} 条") + + saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run) + total_saved += saved + tqdm.write(f" {date}: {len(alerts)} 个异动, 保存 {saved} 条") + elif alerts: + tqdm.write(f" {date}: {len(alerts)} 个异动") + + # 当天数据成为下一天的 prev_df + prev_df = df + + # 导出 CSV + if args.export_csv: + export_to_csv(all_alerts, args.export_csv) + + # 分析 + analyze_alerts(all_alerts) + + print(f"\n总计: {len(all_alerts)} 个异动") + if args.save_db: + print(f"已保存到数据库: {total_saved} 条") + + +if __name__ == "__main__": + main() diff --git a/ml/backtest_hybrid.py b/ml/backtest_hybrid.py new file mode 100644 index 00000000..6913f204 --- /dev/null +++ b/ml/backtest_hybrid.py @@ -0,0 +1,481 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +融合异动回测脚本 + +使用 HybridAnomalyDetector 进行回测: +- 规则评分 + LSTM Autoencoder 融合判断 +- 输出更丰富的异动信息 + +使用方法: + python backtest_hybrid.py --start 2024-01-01 --end 2024-12-01 + python backtest_hybrid.py --start 2024-11-01 --dry-run +""" + +import os +import sys +import argparse +import json +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional +from collections import defaultdict + +import numpy as np +import pandas as pd +from tqdm import tqdm +from sqlalchemy import create_engine, text + +# 添加父目录到路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from detector import HybridAnomalyDetector, create_detector + + +# ==================== 配置 ==================== + +MYSQL_ENGINE = create_engine( + "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock", + echo=False +) + +FEATURES = [ + 'alpha', + 'alpha_delta', + 'amt_ratio', + 'amt_delta', + 'rank_pct', + 'limit_up_ratio', +] + +BACKTEST_CONFIG = { + 'seq_len': 30, + 'min_alpha_abs': 0.3, # 降低阈值,让规则也能发挥作用 + 'cooldown_minutes': 8, + 'max_alerts_per_minute': 20, + 'clip_value': 10.0, +} + + +# ==================== 
数据加载 ==================== + +def load_daily_features(data_dir: str, date: str) -> Optional[pd.DataFrame]: + """加载单天的特征数据""" + file_path = Path(data_dir) / f"features_{date}.parquet" + + if not file_path.exists(): + return None + + df = pd.read_parquet(file_path) + return df + + +def get_available_dates(data_dir: str, start_date: str, end_date: str) -> List[str]: + """获取可用的日期列表""" + data_path = Path(data_dir) + all_files = sorted(data_path.glob("features_*.parquet")) + + dates = [] + for f in all_files: + date = f.stem.replace('features_', '') + if start_date <= date <= end_date: + dates.append(date) + + return dates + + +# ==================== 融合回测 ==================== + +def backtest_single_day_hybrid( + detector: HybridAnomalyDetector, + df: pd.DataFrame, + date: str, + seq_len: int = 30 +) -> List[Dict]: + """ + 使用融合检测器回测单天数据(批量优化版) + """ + alerts = [] + + # 按概念分组,预先构建字典 + grouped_dict = {cid: cdf for cid, cdf in df.groupby('concept_id', sort=False)} + + # 冷却记录 + cooldown = {} + + # 获取所有时间点 + all_timestamps = sorted(df['timestamp'].unique()) + + if len(all_timestamps) < seq_len: + return alerts + + # 对每个时间点进行检测 + for t_idx in range(seq_len - 1, len(all_timestamps)): + current_time = all_timestamps[t_idx] + window_start_time = all_timestamps[t_idx - seq_len + 1] + + # 批量收集该时刻所有候选概念 + batch_sequences = [] + batch_features = [] + batch_infos = [] + + for concept_id, concept_df in grouped_dict.items(): + # 检查冷却(提前过滤) + if concept_id in cooldown: + last_alert = cooldown[concept_id] + if isinstance(current_time, datetime): + time_diff = (current_time - last_alert).total_seconds() / 60 + else: + time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1 + if time_diff < BACKTEST_CONFIG['cooldown_minutes']: + continue + + # 获取时间窗口内的数据 + mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time) + window_df = concept_df.loc[mask] + + if len(window_df) < seq_len: + continue + + window_df = window_df.sort_values('timestamp').tail(seq_len) + + # 当前时刻特征 + current_row = window_df.iloc[-1] + alpha = current_row.get('alpha', 0) + + # 过滤微小波动(提前过滤) + if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']: + continue + + # 提取特征序列 + sequence = window_df[FEATURES].values + sequence = np.nan_to_num(sequence, nan=0.0, posinf=0.0, neginf=0.0) + sequence = np.clip(sequence, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value']) + + current_features = { + 'alpha': alpha, + 'alpha_delta': current_row.get('alpha_delta', 0), + 'amt_ratio': current_row.get('amt_ratio', 1), + 'amt_delta': current_row.get('amt_delta', 0), + 'rank_pct': current_row.get('rank_pct', 0.5), + 'limit_up_ratio': current_row.get('limit_up_ratio', 0), + } + + batch_sequences.append(sequence) + batch_features.append(current_features) + batch_infos.append({ + 'concept_id': concept_id, + 'stock_count': current_row.get('stock_count', 0), + 'total_amt': current_row.get('total_amt', 0), + }) + + if not batch_sequences: + continue + + # 批量 ML 推理 + sequences_array = np.array(batch_sequences) + ml_scores = detector.ml_scorer.score(sequences_array) if detector.ml_scorer.is_ready() else [0.0] * len(batch_sequences) + if isinstance(ml_scores, float): + ml_scores = [ml_scores] + + # 批量规则评分 + 融合 + minute_alerts = [] + for i, (features, info) in enumerate(zip(batch_features, batch_infos)): + concept_id = info['concept_id'] + + # 规则评分 + rule_score, rule_details = detector.rule_scorer.score(features) + + # ML 评分 + ml_score = ml_scores[i] if i < len(ml_scores) else 0.0 + + # 融合 + w1 = detector.config['rule_weight'] + w2 = 
detector.config['ml_weight'] + final_score = w1 * rule_score + w2 * ml_score + + # 判断是否异动 + is_anomaly = False + trigger_reason = '' + + if rule_score >= detector.config['rule_trigger']: + is_anomaly = True + trigger_reason = f'规则强信号({rule_score:.0f}分)' + elif ml_score >= detector.config['ml_trigger']: + is_anomaly = True + trigger_reason = f'ML强信号({ml_score:.0f}分)' + elif final_score >= detector.config['fusion_trigger']: + is_anomaly = True + trigger_reason = f'融合触发({final_score:.0f}分)' + + if not is_anomaly: + continue + + # 异动类型 + alpha = features.get('alpha', 0) + if alpha >= 1.5: + anomaly_type = 'surge_up' + elif alpha <= -1.5: + anomaly_type = 'surge_down' + elif features.get('amt_ratio', 1) >= 3.0: + anomaly_type = 'volume_spike' + else: + anomaly_type = 'unknown' + + alert = { + 'concept_id': concept_id, + 'alert_time': current_time, + 'trade_date': date, + 'alert_type': anomaly_type, + 'final_score': final_score, + 'rule_score': rule_score, + 'ml_score': ml_score, + 'trigger_reason': trigger_reason, + 'triggered_rules': list(rule_details.keys()), + **features, + **info, + } + + minute_alerts.append(alert) + cooldown[concept_id] = current_time + + # 按最终得分排序 + minute_alerts.sort(key=lambda x: x['final_score'], reverse=True) + alerts.extend(minute_alerts[:BACKTEST_CONFIG['max_alerts_per_minute']]) + + return alerts + + +# ==================== 数据库写入 ==================== + +def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int: + """保存异动到 MySQL(增强版字段)""" + if not alerts: + return 0 + + if dry_run: + print(f" [Dry Run] 将写入 {len(alerts)} 条异动") + return len(alerts) + + saved = 0 + with MYSQL_ENGINE.begin() as conn: + for alert in alerts: + try: + # 检查是否已存在 + check_sql = text(""" + SELECT id FROM concept_minute_alert + WHERE concept_id = :concept_id + AND alert_time = :alert_time + AND trade_date = :trade_date + """) + exists = conn.execute(check_sql, { + 'concept_id': alert['concept_id'], + 'alert_time': alert['alert_time'], + 'trade_date': alert['trade_date'], + }).fetchone() + + if exists: + continue + + # 插入新记录 + insert_sql = text(""" + INSERT INTO concept_minute_alert + (concept_id, concept_name, alert_time, alert_type, trade_date, + change_pct, zscore, importance_score, stock_count, extra_info) + VALUES + (:concept_id, :concept_name, :alert_time, :alert_type, :trade_date, + :change_pct, :zscore, :importance_score, :stock_count, :extra_info) + """) + + extra_info = { + 'detection_method': 'hybrid', + 'final_score': alert['final_score'], + 'rule_score': alert['rule_score'], + 'ml_score': alert['ml_score'], + 'trigger_reason': alert['trigger_reason'], + 'triggered_rules': alert['triggered_rules'], + 'alpha': alert.get('alpha', 0), + 'alpha_delta': alert.get('alpha_delta', 0), + 'amt_ratio': alert.get('amt_ratio', 1), + } + + conn.execute(insert_sql, { + 'concept_id': alert['concept_id'], + 'concept_name': alert.get('concept_name', ''), + 'alert_time': alert['alert_time'], + 'alert_type': alert['alert_type'], + 'trade_date': alert['trade_date'], + 'change_pct': alert.get('alpha', 0), + 'zscore': alert['final_score'], # 用最终得分作为 zscore + 'importance_score': alert['final_score'], + 'stock_count': alert.get('stock_count', 0), + 'extra_info': json.dumps(extra_info, ensure_ascii=False) + }) + + saved += 1 + + except Exception as e: + print(f" 保存失败: {alert['concept_id']} - {e}") + + return saved + + +def export_alerts_to_csv(alerts: List[Dict], output_path: str): + """导出异动到 CSV""" + if not alerts: + return + + df = pd.DataFrame(alerts) + df.to_csv(output_path, index=False, 
encoding='utf-8-sig') + print(f"已导出到: {output_path}") + + +# ==================== 统计分析 ==================== + +def analyze_alerts(alerts: List[Dict]): + """分析异动结果""" + if not alerts: + print("无异动数据") + return + + df = pd.DataFrame(alerts) + + print("\n" + "=" * 60) + print("异动统计分析") + print("=" * 60) + + # 1. 基本统计 + print(f"\n总异动数: {len(alerts)}") + + # 2. 按类型统计 + print(f"\n异动类型分布:") + print(df['alert_type'].value_counts()) + + # 3. 得分统计 + print(f"\n得分统计:") + print(f" 最终得分 - Mean: {df['final_score'].mean():.1f}, Max: {df['final_score'].max():.1f}") + print(f" 规则得分 - Mean: {df['rule_score'].mean():.1f}, Max: {df['rule_score'].max():.1f}") + print(f" ML得分 - Mean: {df['ml_score'].mean():.1f}, Max: {df['ml_score'].max():.1f}") + + # 4. 触发来源分析 + print(f"\n触发来源分析:") + trigger_counts = df['trigger_reason'].apply( + lambda x: '规则' if '规则' in x else ('ML' if 'ML' in x else '融合') + ).value_counts() + print(trigger_counts) + + # 5. 规则触发频率 + all_rules = [] + for rules in df['triggered_rules']: + if isinstance(rules, list): + all_rules.extend(rules) + + if all_rules: + print(f"\n最常触发的规则 (Top 10):") + from collections import Counter + rule_counts = Counter(all_rules) + for rule, count in rule_counts.most_common(10): + print(f" {rule}: {count}") + + +# ==================== 主函数 ==================== + +def main(): + parser = argparse.ArgumentParser(description='融合异动回测') + parser.add_argument('--data_dir', type=str, default='ml/data', + help='特征数据目录') + parser.add_argument('--checkpoint_dir', type=str, default='ml/checkpoints', + help='模型检查点目录') + parser.add_argument('--start', type=str, required=True, + help='开始日期 (YYYY-MM-DD)') + parser.add_argument('--end', type=str, default=None, + help='结束日期 (YYYY-MM-DD),默认=start') + parser.add_argument('--dry-run', action='store_true', + help='只计算,不写入数据库') + parser.add_argument('--export-csv', type=str, default=None, + help='导出 CSV 文件路径') + parser.add_argument('--rule-weight', type=float, default=0.6, + help='规则权重 (0-1)') + parser.add_argument('--ml-weight', type=float, default=0.4, + help='ML权重 (0-1)') + parser.add_argument('--device', type=str, default='cuda', + help='设备 (cuda/cpu),默认 cuda') + + args = parser.parse_args() + + if args.end is None: + args.end = args.start + + print("=" * 60) + print("融合异动回测 (规则 + LSTM)") + print("=" * 60) + print(f"日期范围: {args.start} ~ {args.end}") + print(f"数据目录: {args.data_dir}") + print(f"模型目录: {args.checkpoint_dir}") + print(f"规则权重: {args.rule_weight}") + print(f"ML权重: {args.ml_weight}") + print(f"设备: {args.device}") + print(f"Dry Run: {args.dry_run}") + print("=" * 60) + + # 初始化融合检测器(使用 GPU) + config = { + 'rule_weight': args.rule_weight, + 'ml_weight': args.ml_weight, + } + + # 修改 detector.py 中 MLScorer 的设备 + from detector import HybridAnomalyDetector + detector = HybridAnomalyDetector(config, args.checkpoint_dir, device=args.device) + + # 获取可用日期 + dates = get_available_dates(args.data_dir, args.start, args.end) + + if not dates: + print(f"未找到 {args.start} ~ {args.end} 范围内的数据") + return + + print(f"\n找到 {len(dates)} 天的数据") + + # 回测 + all_alerts = [] + total_saved = 0 + + for date in tqdm(dates, desc="回测进度"): + df = load_daily_features(args.data_dir, date) + + if df is None or df.empty: + continue + + alerts = backtest_single_day_hybrid( + detector, df, date, + seq_len=BACKTEST_CONFIG['seq_len'] + ) + + if alerts: + all_alerts.extend(alerts) + + saved = save_alerts_to_mysql(alerts, dry_run=args.dry_run) + total_saved += saved + + if not args.dry_run: + tqdm.write(f" {date}: 检测到 {len(alerts)} 个异动,保存 {saved} 条") + + # 导出 CSV + if 
args.export_csv and all_alerts: + export_alerts_to_csv(all_alerts, args.export_csv) + + # 统计分析 + analyze_alerts(all_alerts) + + # 汇总 + print("\n" + "=" * 60) + print("回测完成!") + print("=" * 60) + print(f"总计检测到: {len(all_alerts)} 个异动") + print(f"保存到数据库: {total_saved} 条") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/ml/backtest_v2.py b/ml/backtest_v2.py new file mode 100644 index 00000000..84524fd9 --- /dev/null +++ b/ml/backtest_v2.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +V2 回测脚本 - 验证时间片对齐 + 持续性确认的效果 + +回测指标: +1. 准确率:异动后 N 分钟内 alpha 是否继续上涨/下跌 +2. 虚警率:多少异动是噪音 +3. 持续性:平均异动持续时长 +""" + +import os +import sys +import json +import argparse +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Tuple +from collections import defaultdict + +import numpy as np +import pandas as pd +from tqdm import tqdm +from sqlalchemy import create_engine, text + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from ml.detector_v2 import AnomalyDetectorV2, CONFIG + + +# ==================== 配置 ==================== + +MYSQL_ENGINE = create_engine( + "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock", + echo=False +) + + +# ==================== 回测评估 ==================== + +def evaluate_alerts( + alerts: List[Dict], + raw_data: pd.DataFrame, + lookahead_minutes: int = 10 +) -> Dict: + """ + 评估异动质量 + + 指标: + 1. 方向正确率:异动后 N 分钟 alpha 方向是否一致 + 2. 持续率:异动后 N 分钟内有多少时刻 alpha 保持同向 + 3. 峰值收益:异动后 N 分钟内的最大 alpha + """ + if not alerts: + return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0, 'total_alerts': 0} + + results = [] + + for alert in alerts: + concept_id = alert['concept_id'] + alert_time = alert['alert_time'] + alert_alpha = alert['alpha'] + is_up = alert_alpha > 0 + + # 获取该概念在异动后的数据 + concept_data = raw_data[ + (raw_data['concept_id'] == concept_id) & + (raw_data['timestamp'] > alert_time) + ].head(lookahead_minutes) + + if len(concept_data) < 3: + continue + + future_alphas = concept_data['alpha'].values + + # 方向正确:未来 alpha 平均值与当前同向 + avg_future_alpha = np.mean(future_alphas) + direction_correct = (is_up and avg_future_alpha > 0) or (not is_up and avg_future_alpha < 0) + + # 持续率:有多少时刻保持同向 + if is_up: + sustained_count = sum(1 for a in future_alphas if a > 0) + else: + sustained_count = sum(1 for a in future_alphas if a < 0) + sustained_rate = sustained_count / len(future_alphas) + + # 峰值收益 + if is_up: + peak = max(future_alphas) + else: + peak = min(future_alphas) + + results.append({ + 'direction_correct': direction_correct, + 'sustained_rate': sustained_rate, + 'peak': peak, + 'alert_alpha': alert_alpha, + }) + + if not results: + return {'accuracy': 0, 'sustained_rate': 0, 'avg_peak': 0, 'total_alerts': 0} + + return { + 'accuracy': np.mean([r['direction_correct'] for r in results]), + 'sustained_rate': np.mean([r['sustained_rate'] for r in results]), + 'avg_peak': np.mean([abs(r['peak']) for r in results]), + 'total_alerts': len(alerts), + 'evaluated_alerts': len(results), + } + + +def save_alerts_to_mysql(alerts: List[Dict], dry_run: bool = False) -> int: + """保存异动到 MySQL""" + if not alerts or dry_run: + return 0 + + # 确保表存在 + with MYSQL_ENGINE.begin() as conn: + conn.execute(text(""" + CREATE TABLE IF NOT EXISTS concept_anomaly_v2 ( + id INT AUTO_INCREMENT PRIMARY KEY, + concept_id VARCHAR(64) NOT NULL, + alert_time DATETIME NOT NULL, + trade_date DATE NOT NULL, + alert_type VARCHAR(32) NOT NULL, + final_score FLOAT NOT NULL, + rule_score FLOAT NOT NULL, + ml_score FLOAT 
NOT NULL, + trigger_reason VARCHAR(128), + confirm_ratio FLOAT, + alpha FLOAT, + alpha_zscore FLOAT, + amt_zscore FLOAT, + rank_zscore FLOAT, + momentum_3m FLOAT, + momentum_5m FLOAT, + limit_up_ratio FLOAT, + triggered_rules JSON, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY uk_concept_time (concept_id, alert_time, trade_date), + INDEX idx_trade_date (trade_date), + INDEX idx_final_score (final_score) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念异动 V2(时间片对齐+持续确认)' + """)) + + # 插入数据 + saved = 0 + with MYSQL_ENGINE.begin() as conn: + for alert in alerts: + try: + conn.execute(text(""" + INSERT IGNORE INTO concept_anomaly_v2 + (concept_id, alert_time, trade_date, alert_type, + final_score, rule_score, ml_score, trigger_reason, confirm_ratio, + alpha, alpha_zscore, amt_zscore, rank_zscore, + momentum_3m, momentum_5m, limit_up_ratio, triggered_rules) + VALUES + (:concept_id, :alert_time, :trade_date, :alert_type, + :final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio, + :alpha, :alpha_zscore, :amt_zscore, :rank_zscore, + :momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules) + """), { + 'concept_id': alert['concept_id'], + 'alert_time': alert['alert_time'], + 'trade_date': alert['trade_date'], + 'alert_type': alert['alert_type'], + 'final_score': alert['final_score'], + 'rule_score': alert['rule_score'], + 'ml_score': alert['ml_score'], + 'trigger_reason': alert['trigger_reason'], + 'confirm_ratio': alert.get('confirm_ratio', 0), + 'alpha': alert['alpha'], + 'alpha_zscore': alert.get('alpha_zscore', 0), + 'amt_zscore': alert.get('amt_zscore', 0), + 'rank_zscore': alert.get('rank_zscore', 0), + 'momentum_3m': alert.get('momentum_3m', 0), + 'momentum_5m': alert.get('momentum_5m', 0), + 'limit_up_ratio': alert.get('limit_up_ratio', 0), + 'triggered_rules': json.dumps(alert.get('triggered_rules', [])), + }) + saved += 1 + except Exception as e: + print(f"保存失败: {e}") + + return saved + + +# ==================== 主函数 ==================== + +def main(): + parser = argparse.ArgumentParser(description='V2 回测') + parser.add_argument('--start', type=str, required=True, help='开始日期') + parser.add_argument('--end', type=str, default=None, help='结束日期') + parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2') + parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines') + parser.add_argument('--save', action='store_true', help='保存到数据库') + parser.add_argument('--lookahead', type=int, default=10, help='评估前瞻时间(分钟)') + + args = parser.parse_args() + + end_date = args.end or args.start + + print("=" * 60) + print("V2 回测 - 时间片对齐 + 持续性确认") + print("=" * 60) + print(f"日期范围: {args.start} ~ {end_date}") + print(f"模型目录: {args.model_dir}") + print(f"评估前瞻: {args.lookahead} 分钟") + + # 初始化检测器 + detector = AnomalyDetectorV2( + model_dir=args.model_dir, + baseline_dir=args.baseline_dir + ) + + # 获取交易日 + from prepare_data_v2 import get_trading_days + trading_days = get_trading_days(args.start, end_date) + + if not trading_days: + print("无交易日") + return + + print(f"交易日数: {len(trading_days)}") + + # 回测统计 + total_stats = { + 'total_alerts': 0, + 'accuracy_sum': 0, + 'sustained_sum': 0, + 'peak_sum': 0, + 'day_count': 0, + } + + all_alerts = [] + + for trade_date in tqdm(trading_days, desc="回测进度"): + # 检测异动 + alerts = detector.detect(trade_date) + + if not alerts: + continue + + all_alerts.extend(alerts) + + # 评估 + raw_data = detector._compute_raw_features(trade_date) + if raw_data.empty: + continue + + stats = evaluate_alerts(alerts, 
raw_data, args.lookahead)
+
+        if stats['evaluated_alerts'] > 0:
+            total_stats['total_alerts'] += stats['total_alerts']
+            # 加权和按 evaluated_alerts 累计,求均值时分母也必须用
+            # evaluated 总数(部分异动因前瞻数据不足未被评估)
+            total_stats['evaluated_sum'] = total_stats.get('evaluated_sum', 0) + stats['evaluated_alerts']
+            total_stats['accuracy_sum'] += stats['accuracy'] * stats['evaluated_alerts']
+            total_stats['sustained_sum'] += stats['sustained_rate'] * stats['evaluated_alerts']
+            total_stats['peak_sum'] += stats['avg_peak'] * stats['evaluated_alerts']
+            total_stats['day_count'] += 1
+
+            print(f"\n[{trade_date}] 异动: {stats['total_alerts']}, "
+                  f"准确率: {stats['accuracy']:.1%}, "
+                  f"持续率: {stats['sustained_rate']:.1%}, "
+                  f"峰值: {stats['avg_peak']:.2f}%")
+
+    # 汇总
+    print("\n" + "=" * 60)
+    print("回测汇总")
+    print("=" * 60)
+
+    evaluated_total = total_stats.get('evaluated_sum', 0)
+    if evaluated_total > 0:
+        avg_accuracy = total_stats['accuracy_sum'] / evaluated_total
+        avg_sustained = total_stats['sustained_sum'] / evaluated_total
+        avg_peak = total_stats['peak_sum'] / evaluated_total
+
+        print(f"总异动数: {total_stats['total_alerts']}")
+        print(f"回测天数: {total_stats['day_count']}")
+        print(f"平均每天: {total_stats['total_alerts'] / max(1, total_stats['day_count']):.1f} 个")
+        print(f"方向准确率: {avg_accuracy:.1%}")
+        print(f"持续率: {avg_sustained:.1%}")
+        print(f"平均峰值: {avg_peak:.2f}%")
+    else:
+        print("无异动检测结果")
+
+    # 保存
+    if args.save and all_alerts:
+        print(f"\n保存 {len(all_alerts)} 条异动到数据库...")
+        saved = save_alerts_to_mysql(all_alerts)
+        print(f"保存完成: {saved} 条")
+
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ml/checkpoints_v2/config.json b/ml/checkpoints_v2/config.json
new file mode 100644
index 00000000..007997f5
--- /dev/null
+++ b/ml/checkpoints_v2/config.json
@@ -0,0 +1,31 @@
+{
+  "seq_len": 10,
+  "stride": 2,
+  "train_end_date": "2025-06-30",
+  "val_end_date": "2025-09-30",
+  "features": [
+    "alpha_zscore",
+    "amt_zscore",
+    "rank_zscore",
+    "momentum_3m",
+    "momentum_5m",
+    "limit_up_ratio"
+  ],
+  "batch_size": 32768,
+  "epochs": 150,
+  "learning_rate": 0.0006,
+  "weight_decay": 1e-05,
+  "gradient_clip": 1.0,
+  "patience": 15,
+  "min_delta": 1e-06,
+  "model": {
+    "n_features": 6,
+    "hidden_dim": 32,
+    "latent_dim": 4,
+    "num_layers": 1,
+    "dropout": 0.2,
+    "bidirectional": true
+  },
+  "clip_value": 5.0,
+  "threshold_percentiles": [90, 95, 99]
+}
diff --git a/ml/checkpoints_v2/thresholds.json b/ml/checkpoints_v2/thresholds.json
new file mode 100644
index 00000000..47a6fe63
--- /dev/null
+++ b/ml/checkpoints_v2/thresholds.json
@@ -0,0 +1,8 @@
+{
+  "p90": 0.15,
+  "p95": 0.25,
+  "p99": 0.50,
+  "mean": 0.08,
+  "std": 0.12,
+  "median": 0.06
+}
diff --git a/ml/detector.py b/ml/detector.py
new file mode 100644
index 00000000..c5184771
--- /dev/null
+++ b/ml/detector.py
@@ -0,0 +1,635 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+概念异动检测器 - 融合版
+
+结合两种方法的优势:
+1. 规则评分系统:可解释、稳定、覆盖已知模式
+2. 
LSTM Autoencoder:发现未知的异常模式 + +融合策略: +┌─────────────────────────────────────────────────────────┐ +│ 输入特征 │ +│ (alpha, alpha_delta, amt_ratio, amt_delta, rank_pct, │ +│ limit_up_ratio) │ +├─────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ 规则评分系统 │ │ LSTM Autoencoder │ │ +│ │ (0-100分) │ │ (重构误差) │ │ +│ └──────┬───────┘ └──────┬───────┘ │ +│ │ │ │ +│ ▼ ▼ │ +│ rule_score (0-100) ml_score (标准化后 0-100) │ +│ │ +├─────────────────────────────────────────────────────────┤ +│ 融合策略 │ +│ │ +│ final_score = w1 * rule_score + w2 * ml_score │ +│ │ +│ 异动判定: │ +│ - rule_score >= 60 → 直接触发(规则强信号) │ +│ - ml_score >= 80 → 直接触发(ML强信号) │ +│ - final_score >= 50 → 融合触发 │ +│ │ +└─────────────────────────────────────────────────────────┘ + +优势: +- 规则系统保证已知模式的检出率 +- ML模型捕捉规则未覆盖的异常 +- 两者互相验证,减少误报 +""" + +import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass + +import numpy as np +import torch + +# 尝试导入模型(可能不存在) +try: + from model import LSTMAutoencoder, create_model + HAS_MODEL = True +except ImportError: + HAS_MODEL = False + + +@dataclass +class AnomalyResult: + """异动检测结果""" + is_anomaly: bool + final_score: float # 最终得分 (0-100) + rule_score: float # 规则得分 (0-100) + ml_score: float # ML得分 (0-100) + trigger_reason: str # 触发原因 + rule_details: Dict # 规则明细 + anomaly_type: str # 异动类型: surge_up / surge_down / volume_spike / unknown + + +class RuleBasedScorer: + """ + 基于规则的评分系统 + + 设计原则: + - 每个规则独立打分 + - 分数可叠加 + - 阈值可配置 + """ + + # 默认规则配置 + DEFAULT_RULES = { + # Alpha 相关(超额收益) + 'alpha_strong': { + 'condition': lambda r: abs(r.get('alpha', 0)) >= 3.0, + 'score': 35, + 'description': 'Alpha强信号(|α|≥3%)' + }, + 'alpha_medium': { + 'condition': lambda r: 2.0 <= abs(r.get('alpha', 0)) < 3.0, + 'score': 25, + 'description': 'Alpha中等(2%≤|α|<3%)' + }, + 'alpha_weak': { + 'condition': lambda r: 1.5 <= abs(r.get('alpha', 0)) < 2.0, + 'score': 15, + 'description': 'Alpha轻微(1.5%≤|α|<2%)' + }, + + # Alpha 变化率(加速度) + 'alpha_delta_strong': { + 'condition': lambda r: abs(r.get('alpha_delta', 0)) >= 1.0, + 'score': 30, + 'description': 'Alpha加速强(|Δα|≥1%)' + }, + 'alpha_delta_medium': { + 'condition': lambda r: 0.5 <= abs(r.get('alpha_delta', 0)) < 1.0, + 'score': 20, + 'description': 'Alpha加速中(0.5%≤|Δα|<1%)' + }, + + # 成交额比率(放量) + 'volume_spike_strong': { + 'condition': lambda r: r.get('amt_ratio', 1) >= 5.0, + 'score': 30, + 'description': '极度放量(≥5倍)' + }, + 'volume_spike_medium': { + 'condition': lambda r: 3.0 <= r.get('amt_ratio', 1) < 5.0, + 'score': 20, + 'description': '显著放量(3-5倍)' + }, + 'volume_spike_weak': { + 'condition': lambda r: 2.0 <= r.get('amt_ratio', 1) < 3.0, + 'score': 10, + 'description': '轻微放量(2-3倍)' + }, + + # 成交额变化率 + 'amt_delta_strong': { + 'condition': lambda r: abs(r.get('amt_delta', 0)) >= 1.0, + 'score': 15, + 'description': '成交额急变(|Δamt|≥100%)' + }, + + # 排名跳变 + 'rank_top': { + 'condition': lambda r: r.get('rank_pct', 0.5) >= 0.95, + 'score': 25, + 'description': '排名前5%' + }, + 'rank_bottom': { + 'condition': lambda r: r.get('rank_pct', 0.5) <= 0.05, + 'score': 25, + 'description': '排名后5%' + }, + 'rank_high': { + 'condition': lambda r: 0.9 <= r.get('rank_pct', 0.5) < 0.95, + 'score': 15, + 'description': '排名前10%' + }, + + # 涨停比例 + 'limit_up_high': { + 'condition': lambda r: r.get('limit_up_ratio', 0) >= 0.2, + 'score': 25, + 'description': '涨停比例≥20%' + }, + 'limit_up_medium': { + 'condition': lambda r: 0.1 <= r.get('limit_up_ratio', 0) < 0.2, + 'score': 15, + 'description': '涨停比例10-20%' + 
}, + + # 组合条件(更可靠的信号) + 'alpha_with_volume': { + 'condition': lambda r: abs(r.get('alpha', 0)) >= 1.5 and r.get('amt_ratio', 1) >= 2.0, + 'score': 20, # 额外加分 + 'description': 'Alpha+放量组合' + }, + 'acceleration_with_rank': { + 'condition': lambda r: abs(r.get('alpha_delta', 0)) >= 0.5 and (r.get('rank_pct', 0.5) >= 0.9 or r.get('rank_pct', 0.5) <= 0.1), + 'score': 15, # 额外加分 + 'description': '加速+排名异常组合' + }, + } + + def __init__(self, rules: Dict = None): + """ + 初始化规则评分器 + + Args: + rules: 自定义规则,格式同 DEFAULT_RULES + """ + self.rules = rules or self.DEFAULT_RULES + + def score(self, features: Dict) -> Tuple[float, Dict]: + """ + 计算规则得分 + + Args: + features: 特征字典,包含 alpha, alpha_delta, amt_ratio 等 + Returns: + score: 总分 (0-100) + details: 触发的规则明细 + """ + total_score = 0 + triggered_rules = {} + + for rule_name, rule_config in self.rules.items(): + try: + if rule_config['condition'](features): + total_score += rule_config['score'] + triggered_rules[rule_name] = { + 'score': rule_config['score'], + 'description': rule_config['description'] + } + except Exception: + # 忽略规则计算错误 + pass + + # 限制在 0-100 + total_score = min(100, max(0, total_score)) + + return total_score, triggered_rules + + def get_anomaly_type(self, features: Dict) -> str: + """判断异动类型""" + alpha = features.get('alpha', 0) + amt_ratio = features.get('amt_ratio', 1) + + if alpha >= 1.5: + return 'surge_up' + elif alpha <= -1.5: + return 'surge_down' + elif amt_ratio >= 3.0: + return 'volume_spike' + else: + return 'unknown' + + +class MLScorer: + """ + 基于 LSTM Autoencoder 的评分器 + + 将重构误差转换为 0-100 的分数 + """ + + def __init__( + self, + checkpoint_dir: str = 'ml/checkpoints', + device: str = 'auto' + ): + self.checkpoint_dir = Path(checkpoint_dir) + + # 设备检测 + if device == 'auto': + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + elif device == 'cuda' and not torch.cuda.is_available(): + print("警告: CUDA 不可用,使用 CPU") + self.device = torch.device('cpu') + else: + self.device = torch.device(device) + + self.model = None + self.thresholds = None + self.config = None + + # 尝试加载模型 + self._load_model() + + def _load_model(self): + """加载模型和阈值""" + if not HAS_MODEL: + print("警告: 无法导入模型模块") + return + + model_path = self.checkpoint_dir / 'best_model.pt' + thresholds_path = self.checkpoint_dir / 'thresholds.json' + config_path = self.checkpoint_dir / 'config.json' + + if not model_path.exists(): + print(f"警告: 模型文件不存在 {model_path}") + return + + try: + # 加载配置 + if config_path.exists(): + with open(config_path, 'r') as f: + self.config = json.load(f) + + # 先用 CPU 加载模型(避免 CUDA 不可用问题),再移动到目标设备 + checkpoint = torch.load(model_path, map_location='cpu') + + model_config = self.config.get('model', {}) if self.config else {} + self.model = create_model(model_config) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.model.to(self.device) + self.model.eval() + + # 加载阈值 + if thresholds_path.exists(): + with open(thresholds_path, 'r') as f: + self.thresholds = json.load(f) + + print(f"MLScorer 加载成功 (设备: {self.device})") + + except Exception as e: + print(f"警告: 模型加载失败 - {e}") + import traceback + traceback.print_exc() + self.model = None + + def is_ready(self) -> bool: + """检查模型是否就绪""" + return self.model is not None + + @torch.no_grad() + def score(self, sequence: np.ndarray) -> float: + """ + 计算 ML 得分 + + Args: + sequence: (seq_len, n_features) 或 (batch, seq_len, n_features) + Returns: + score: 0-100 的分数,越高越异常 + """ + if not self.is_ready(): + return 0.0 + + # 确保是 3D + if sequence.ndim == 2: + sequence = 
sequence[np.newaxis, ...]
+
+        # 转为 tensor
+        x = torch.FloatTensor(sequence).to(self.device)
+
+        # 计算重构误差
+        output, _ = self.model(x)
+        mse = ((output - x) ** 2).mean(dim=-1)  # (batch, seq_len)
+
+        # 取最后时刻的误差
+        error = mse[:, -1].cpu().numpy()
+
+        # 转换为 0-100 分数
+        # 使用 p95 阈值作为参考
+        if self.thresholds:
+            p95 = self.thresholds.get('p95', 0.1)
+        else:
+            p95 = 0.1
+
+        # 线性映射:error=0 -> 0分, error=p95 -> 50分, error>=2*p95 -> 100分
+        score = np.clip(error / p95 * 50, 0, 100)
+
+        return float(score[0]) if len(score) == 1 else score.tolist()
+
+
+class HybridAnomalyDetector:
+    """
+    融合异动检测器
+
+    结合规则系统和 ML 模型
+    """
+
+    # 默认配置
+    DEFAULT_CONFIG = {
+        # 权重配置
+        'rule_weight': 0.6,    # 规则权重
+        'ml_weight': 0.4,      # ML权重
+
+        # 触发阈值
+        'rule_trigger': 60,    # 规则直接触发阈值
+        'ml_trigger': 80,      # ML直接触发阈值
+        'fusion_trigger': 50,  # 融合触发阈值
+
+        # 特征列表
+        'features': [
+            'alpha', 'alpha_delta', 'amt_ratio',
+            'amt_delta', 'rank_pct', 'limit_up_ratio'
+        ],
+
+        # 序列长度(ML模型需要)
+        'seq_len': 30,
+    }
+
+    def __init__(
+        self,
+        config: Dict = None,
+        checkpoint_dir: str = 'ml/checkpoints',
+        device: str = 'auto'
+    ):
+        self.config = {**self.DEFAULT_CONFIG, **(config or {})}
+
+        # 初始化评分器
+        self.rule_scorer = RuleBasedScorer()
+        self.ml_scorer = MLScorer(checkpoint_dir, device)
+
+        print("HybridAnomalyDetector 初始化完成")
+        print(f"  规则权重: {self.config['rule_weight']}")
+        print(f"  ML权重: {self.config['ml_weight']}")
+        print(f"  ML模型: {'就绪' if self.ml_scorer.is_ready() else '未加载'}")
+
+    def detect(
+        self,
+        features: Dict,
+        sequence: np.ndarray = None
+    ) -> AnomalyResult:
+        """
+        检测异动
+
+        Args:
+            features: 当前时刻的特征字典
+            sequence: 历史序列 (seq_len, n_features),ML模型需要
+        Returns:
+            AnomalyResult: 检测结果
+        """
+        # 1. 规则评分
+        rule_score, rule_details = self.rule_scorer.score(features)
+
+        # 2. ML评分
+        ml_score = 0.0
+        if sequence is not None and self.ml_scorer.is_ready():
+            ml_score = self.ml_scorer.score(sequence)
+
+        # 3. 融合得分
+        w1 = self.config['rule_weight']
+        w2 = self.config['ml_weight']
+
+        # 如果ML不可用,全部权重给规则
+        if not self.ml_scorer.is_ready():
+            w1, w2 = 1.0, 0.0
+
+        final_score = w1 * rule_score + w2 * ml_score
+
+        # 4. 判断是否异动
+        is_anomaly = False
+        trigger_reason = ''
+
+        if rule_score >= self.config['rule_trigger']:
+            is_anomaly = True
+            trigger_reason = f'规则强信号({rule_score:.0f}分)'
+        elif ml_score >= self.config['ml_trigger']:
+            is_anomaly = True
+            trigger_reason = f'ML强信号({ml_score:.0f}分)'
+        elif final_score >= self.config['fusion_trigger']:
+            is_anomaly = True
+            trigger_reason = f'融合触发({final_score:.0f}分)'
+
+        # 5.
判断异动类型 + anomaly_type = self.rule_scorer.get_anomaly_type(features) if is_anomaly else '' + + return AnomalyResult( + is_anomaly=is_anomaly, + final_score=final_score, + rule_score=rule_score, + ml_score=ml_score, + trigger_reason=trigger_reason, + rule_details=rule_details, + anomaly_type=anomaly_type + ) + + def detect_batch( + self, + features_list: List[Dict], + sequences: np.ndarray = None + ) -> List[AnomalyResult]: + """ + 批量检测 + + Args: + features_list: 特征字典列表 + sequences: (batch, seq_len, n_features) + Returns: + List[AnomalyResult] + """ + results = [] + + for i, features in enumerate(features_list): + seq = sequences[i] if sequences is not None else None + result = self.detect(features, seq) + results.append(result) + + return results + + +# ==================== 便捷函数 ==================== + +def create_detector( + checkpoint_dir: str = 'ml/checkpoints', + config: Dict = None +) -> HybridAnomalyDetector: + """创建融合检测器""" + return HybridAnomalyDetector(config, checkpoint_dir) + + +def quick_detect(features: Dict) -> bool: + """ + 快速检测(只用规则,不需要ML模型) + + 适用于: + - 实时检测 + - ML模型未训练完成时 + """ + scorer = RuleBasedScorer() + score, _ = scorer.score(features) + return score >= 50 + + +# ==================== 测试 ==================== + +if __name__ == "__main__": + print("=" * 60) + print("融合异动检测器测试") + print("=" * 60) + + # 创建检测器 + detector = create_detector() + + # 测试用例 + test_cases = [ + { + 'name': '正常情况', + 'features': { + 'alpha': 0.5, + 'alpha_delta': 0.1, + 'amt_ratio': 1.2, + 'amt_delta': 0.1, + 'rank_pct': 0.5, + 'limit_up_ratio': 0.02 + } + }, + { + 'name': 'Alpha异动', + 'features': { + 'alpha': 3.5, + 'alpha_delta': 0.8, + 'amt_ratio': 2.5, + 'amt_delta': 0.5, + 'rank_pct': 0.92, + 'limit_up_ratio': 0.05 + } + }, + { + 'name': '放量异动', + 'features': { + 'alpha': 1.2, + 'alpha_delta': 0.3, + 'amt_ratio': 6.0, + 'amt_delta': 1.5, + 'rank_pct': 0.85, + 'limit_up_ratio': 0.08 + } + }, + { + 'name': '涨停潮', + 'features': { + 'alpha': 2.5, + 'alpha_delta': 0.6, + 'amt_ratio': 3.5, + 'amt_delta': 0.8, + 'rank_pct': 0.98, + 'limit_up_ratio': 0.25 + } + }, + ] + + print("\n" + "-" * 60) + print("测试1: 只用规则(无序列数据)") + print("-" * 60) + + for case in test_cases: + result = detector.detect(case['features']) + + print(f"\n{case['name']}:") + print(f" 异动: {'是' if result.is_anomaly else '否'}") + print(f" 最终得分: {result.final_score:.1f}") + print(f" 规则得分: {result.rule_score:.1f}") + print(f" ML得分: {result.ml_score:.1f}") + if result.is_anomaly: + print(f" 触发原因: {result.trigger_reason}") + print(f" 异动类型: {result.anomaly_type}") + print(f" 触发规则: {list(result.rule_details.keys())}") + + # 测试2: 带序列数据的融合检测 + print("\n" + "-" * 60) + print("测试2: 融合检测(规则 + ML)") + print("-" * 60) + + # 生成模拟序列数据 + seq_len = 30 + n_features = 6 + + # 正常序列:小幅波动 + normal_sequence = np.random.randn(seq_len, n_features) * 0.3 + normal_sequence[:, 0] = np.linspace(0, 0.5, seq_len) # alpha 缓慢上升 + normal_sequence[:, 2] = np.abs(normal_sequence[:, 2]) + 1 # amt_ratio > 0 + + # 异常序列:最后几个时间步突然变化 + anomaly_sequence = np.random.randn(seq_len, n_features) * 0.3 + anomaly_sequence[-5:, 0] = np.linspace(1, 4, 5) # alpha 突然飙升 + anomaly_sequence[-5:, 1] = np.linspace(0.2, 1.5, 5) # alpha_delta 加速 + anomaly_sequence[-5:, 2] = np.linspace(2, 6, 5) # amt_ratio 放量 + anomaly_sequence[:, 2] = np.abs(anomaly_sequence[:, 2]) + 1 + + # 测试正常序列 + normal_features = { + 'alpha': float(normal_sequence[-1, 0]), + 'alpha_delta': float(normal_sequence[-1, 1]), + 'amt_ratio': float(normal_sequence[-1, 2]), + 'amt_delta': float(normal_sequence[-1, 3]), + 
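+        # rank_pct 与 limit_up_ratio 未在随机序列中建模,这里取固定的中性值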
'rank_pct': 0.5, + 'limit_up_ratio': 0.02 + } + + result = detector.detect(normal_features, normal_sequence) + print(f"\n正常序列:") + print(f" 异动: {'是' if result.is_anomaly else '否'}") + print(f" 最终得分: {result.final_score:.1f}") + print(f" 规则得分: {result.rule_score:.1f}") + print(f" ML得分: {result.ml_score:.1f}") + + # 测试异常序列 + anomaly_features = { + 'alpha': float(anomaly_sequence[-1, 0]), + 'alpha_delta': float(anomaly_sequence[-1, 1]), + 'amt_ratio': float(anomaly_sequence[-1, 2]), + 'amt_delta': float(anomaly_sequence[-1, 3]), + 'rank_pct': 0.95, + 'limit_up_ratio': 0.15 + } + + result = detector.detect(anomaly_features, anomaly_sequence) + print(f"\n异常序列:") + print(f" 异动: {'是' if result.is_anomaly else '否'}") + print(f" 最终得分: {result.final_score:.1f}") + print(f" 规则得分: {result.rule_score:.1f}") + print(f" ML得分: {result.ml_score:.1f}") + if result.is_anomaly: + print(f" 触发原因: {result.trigger_reason}") + print(f" 异动类型: {result.anomaly_type}") + + print("\n" + "=" * 60) + print("测试完成!") diff --git a/ml/detector_v2.py b/ml/detector_v2.py new file mode 100644 index 00000000..e4e6f1ae --- /dev/null +++ b/ml/detector_v2.py @@ -0,0 +1,716 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +异动检测器 V2 - 基于时间片对齐 + 持续性确认 + +核心改进: +1. Z-Score 特征:相对于同时间片历史的偏离 +2. 短序列 LSTM:10分钟序列,开盘即可用 +3. 持续性确认:5分钟窗口内60%时刻超标才确认为异动 + +检测流程: +1. 计算当前时刻的 Z-Score(对比同时间片历史基线) +2. 构建最近10分钟的 Z-Score 序列 +3. LSTM 计算重构误差(ML分数) +4. 规则评分(基于 Z-Score 的规则) +5. 滑动窗口确认:最近5分钟内是否有足够多的时刻超标 +6. 只有通过持续性确认的才输出为异动 +""" + +import os +import sys +import json +import pickle +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Tuple +from collections import defaultdict, deque + +import numpy as np +import pandas as pd +import torch +from sqlalchemy import create_engine, text +from elasticsearch import Elasticsearch +from clickhouse_driver import Client + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from ml.model import TransformerAutoencoder + +# ==================== 配置 ==================== + +MYSQL_ENGINE = create_engine( + "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock", + echo=False +) + +ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200']) +ES_INDEX = 'concept_library_v3' + +CLICKHOUSE_CONFIG = { + 'host': '127.0.0.1', + 'port': 9000, + 'user': 'default', + 'password': 'Zzl33818!', + 'database': 'stock' +} + +REFERENCE_INDEX = '000001.SH' + +# 检测配置 +CONFIG = { + # 序列配置 + 'seq_len': 10, # LSTM 序列长度(分钟) + + # 持续性确认配置(核心!) 
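+    # 示例:confirm_window=5、confirm_ratio=0.6 表示最近 5 个时刻中
+    # 至少有 3 个时刻(5*0.6 向上取整)的融合分数超过阈值,才最终确认为异动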
+ 'confirm_window': 5, # 确认窗口(分钟) + 'confirm_ratio': 0.6, # 确认比例(60%时刻需要超标) + + # Z-Score 阈值 + 'alpha_zscore_threshold': 2.0, # Alpha Z-Score 阈值 + 'amt_zscore_threshold': 2.5, # 成交额 Z-Score 阈值 + + # 融合权重 + 'rule_weight': 0.5, + 'ml_weight': 0.5, + + # 触发阈值 + 'rule_trigger': 60, + 'ml_trigger': 70, + 'fusion_trigger': 50, + + # 冷却期 + 'cooldown_minutes': 10, + 'max_alerts_per_minute': 15, + + # Z-Score 截断 + 'zscore_clip': 5.0, +} + +# V2 特征列表 +FEATURES_V2 = [ + 'alpha_zscore', 'amt_zscore', 'rank_zscore', + 'momentum_3m', 'momentum_5m', 'limit_up_ratio' +] + + +# ==================== 工具函数 ==================== + +def get_ch_client(): + return Client(**CLICKHOUSE_CONFIG) + + +def code_to_ch_format(code: str) -> str: + if not code or len(code) != 6 or not code.isdigit(): + return None + if code.startswith('6'): + return f"{code}.SH" + elif code.startswith('0') or code.startswith('3'): + return f"{code}.SZ" + else: + return f"{code}.BJ" + + +def time_to_slot(ts) -> str: + """时间戳转时间片(HH:MM)""" + if isinstance(ts, str): + return ts + return ts.strftime('%H:%M') + + +# ==================== 基线加载 ==================== + +def load_baselines(baseline_dir: str = 'ml/data_v2/baselines') -> Dict[str, pd.DataFrame]: + """加载时间片基线""" + baseline_file = os.path.join(baseline_dir, 'baselines.pkl') + if os.path.exists(baseline_file): + with open(baseline_file, 'rb') as f: + return pickle.load(f) + return {} + + +# ==================== 规则评分(基于 Z-Score)==================== + +def score_rules_zscore(row: Dict) -> Tuple[float, List[str]]: + """ + 基于 Z-Score 的规则评分 + + 设计思路:Z-Score 已经标准化,直接用阈值判断 + """ + score = 0.0 + triggered = [] + + alpha_zscore = row.get('alpha_zscore', 0) + amt_zscore = row.get('amt_zscore', 0) + rank_zscore = row.get('rank_zscore', 0) + momentum_3m = row.get('momentum_3m', 0) + momentum_5m = row.get('momentum_5m', 0) + limit_up_ratio = row.get('limit_up_ratio', 0) + + alpha_zscore_abs = abs(alpha_zscore) + amt_zscore_abs = abs(amt_zscore) + + # ========== Alpha Z-Score 规则 ========== + if alpha_zscore_abs >= 4.0: + score += 25 + triggered.append('alpha_zscore_extreme') + elif alpha_zscore_abs >= 3.0: + score += 18 + triggered.append('alpha_zscore_strong') + elif alpha_zscore_abs >= 2.0: + score += 10 + triggered.append('alpha_zscore_moderate') + + # ========== 成交额 Z-Score 规则 ========== + if amt_zscore >= 4.0: + score += 20 + triggered.append('amt_zscore_extreme') + elif amt_zscore >= 3.0: + score += 12 + triggered.append('amt_zscore_strong') + elif amt_zscore >= 2.0: + score += 6 + triggered.append('amt_zscore_moderate') + + # ========== 排名 Z-Score 规则 ========== + if abs(rank_zscore) >= 3.0: + score += 15 + triggered.append('rank_zscore_extreme') + elif abs(rank_zscore) >= 2.0: + score += 8 + triggered.append('rank_zscore_strong') + + # ========== 动量规则 ========== + if momentum_3m >= 1.0: + score += 12 + triggered.append('momentum_3m_strong') + elif momentum_3m >= 0.5: + score += 6 + triggered.append('momentum_3m_moderate') + + if momentum_5m >= 1.5: + score += 10 + triggered.append('momentum_5m_strong') + + # ========== 涨停比例规则 ========== + if limit_up_ratio >= 0.3: + score += 20 + triggered.append('limit_up_extreme') + elif limit_up_ratio >= 0.15: + score += 12 + triggered.append('limit_up_strong') + elif limit_up_ratio >= 0.08: + score += 5 + triggered.append('limit_up_moderate') + + # ========== 组合规则 ========== + # Alpha Z-Score + 成交额放大 + if alpha_zscore_abs >= 2.0 and amt_zscore >= 2.0: + score += 15 + triggered.append('combo_alpha_amt') + + # Alpha Z-Score + 涨停 + if alpha_zscore_abs >= 2.0 and 
limit_up_ratio >= 0.1: + score += 12 + triggered.append('combo_alpha_limitup') + + return min(score, 100), triggered + + +# ==================== ML 评分器 ==================== + +class MLScorerV2: + """V2 ML 评分器""" + + def __init__(self, model_dir: str = 'ml/checkpoints_v2'): + self.model_dir = model_dir + self.model = None + self.thresholds = None + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self._load_model() + + def _load_model(self): + """加载模型和阈值""" + model_path = os.path.join(self.model_dir, 'best_model.pt') + threshold_path = os.path.join(self.model_dir, 'thresholds.json') + config_path = os.path.join(self.model_dir, 'config.json') + + if not os.path.exists(model_path): + print(f"警告: 模型文件不存在: {model_path}") + return + + # 加载配置 + with open(config_path, 'r') as f: + config = json.load(f) + + # 创建模型 + model_config = config.get('model', {}) + self.model = TransformerAutoencoder(**model_config) + + # 加载权重 + checkpoint = torch.load(model_path, map_location=self.device) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.model.to(self.device) + self.model.eval() + + # 加载阈值 + if os.path.exists(threshold_path): + with open(threshold_path, 'r') as f: + self.thresholds = json.load(f) + + print(f"V2 模型加载完成: {model_path}") + + @torch.no_grad() + def score_batch(self, sequences: np.ndarray) -> np.ndarray: + """ + 批量计算 ML 分数 + + 返回 0-100 的分数,越高越异常 + """ + if self.model is None: + return np.zeros(len(sequences)) + + # 转换为 tensor + x = torch.FloatTensor(sequences).to(self.device) + + # 计算重构误差 + errors = self.model.compute_reconstruction_error(x, reduction='none') + # 取最后一个时刻的误差 + last_errors = errors[:, -1].cpu().numpy() + + # 转换为 0-100 分数 + if self.thresholds: + p50 = self.thresholds.get('median', 0.1) + p99 = self.thresholds.get('p99', 1.0) + + # 线性映射:p50 -> 50分,p99 -> 99分 + scores = 50 + (last_errors - p50) / (p99 - p50) * 49 + scores = np.clip(scores, 0, 100) + else: + # 没有阈值时,简单归一化 + scores = last_errors * 100 + scores = np.clip(scores, 0, 100) + + return scores + + +# ==================== 实时数据管理器 ==================== + +class RealtimeDataManagerV2: + """ + V2 实时数据管理器 + + 维护: + 1. 每个概念的历史 Z-Score 序列(用于 LSTM 输入) + 2. 
每个概念的异动候选队列(用于持续性确认) + """ + + def __init__(self, concepts: List[dict], baselines: Dict[str, pd.DataFrame]): + self.concepts = {c['concept_id']: c for c in concepts} + self.baselines = baselines + + # 概念到股票的映射 + self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts} + + # 历史 Z-Score 序列(每个概念) + # {concept_id: deque([(timestamp, features_dict), ...], maxlen=seq_len)} + self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len'])) + + # 异动候选队列(用于持续性确认) + # {concept_id: deque([(timestamp, score), ...], maxlen=confirm_window)} + self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window'])) + + # 冷却期记录 + self.cooldown = {} + + # 上一次更新的时间戳 + self.last_timestamp = None + + def compute_zscore_features( + self, + concept_id: str, + timestamp, + alpha: float, + total_amt: float, + rank_pct: float, + limit_up_ratio: float + ) -> Optional[Dict]: + """计算单个概念单个时刻的 Z-Score 特征""" + if concept_id not in self.baselines: + return None + + baseline = self.baselines[concept_id] + time_slot = time_to_slot(timestamp) + + # 查找对应时间片的基线 + bl_row = baseline[baseline['time_slot'] == time_slot] + if bl_row.empty: + return None + + bl = bl_row.iloc[0] + + # 检查样本量 + if bl.get('sample_count', 0) < 10: + return None + + # 计算 Z-Score + alpha_zscore = (alpha - bl['alpha_mean']) / bl['alpha_std'] + amt_zscore = (total_amt - bl['amt_mean']) / bl['amt_std'] + rank_zscore = (rank_pct - bl['rank_mean']) / bl['rank_std'] + + # 截断 + clip = CONFIG['zscore_clip'] + alpha_zscore = np.clip(alpha_zscore, -clip, clip) + amt_zscore = np.clip(amt_zscore, -clip, clip) + rank_zscore = np.clip(rank_zscore, -clip, clip) + + # 计算动量(需要历史) + history = self.zscore_history[concept_id] + momentum_3m = 0 + momentum_5m = 0 + + if len(history) >= 3: + recent_alphas = [h[1]['alpha'] for h in list(history)[-3:]] + older_alphas = [h[1]['alpha'] for h in list(history)[-6:-3]] if len(history) >= 6 else [alpha] + momentum_3m = np.mean(recent_alphas) - np.mean(older_alphas) + + if len(history) >= 5: + recent_alphas = [h[1]['alpha'] for h in list(history)[-5:]] + older_alphas = [h[1]['alpha'] for h in list(history)[-10:-5]] if len(history) >= 10 else [alpha] + momentum_5m = np.mean(recent_alphas) - np.mean(older_alphas) + + return { + 'alpha': alpha, + 'alpha_zscore': alpha_zscore, + 'amt_zscore': amt_zscore, + 'rank_zscore': rank_zscore, + 'momentum_3m': momentum_3m, + 'momentum_5m': momentum_5m, + 'limit_up_ratio': limit_up_ratio, + 'total_amt': total_amt, + 'rank_pct': rank_pct, + } + + def update(self, concept_id: str, timestamp, features: Dict): + """更新概念的历史数据""" + self.zscore_history[concept_id].append((timestamp, features)) + + def get_sequence(self, concept_id: str) -> Optional[np.ndarray]: + """获取用于 LSTM 的序列""" + history = self.zscore_history[concept_id] + + if len(history) < CONFIG['seq_len']: + return None + + # 提取特征 + feature_list = [] + for _, features in history: + feature_list.append([ + features['alpha_zscore'], + features['amt_zscore'], + features['rank_zscore'], + features['momentum_3m'], + features['momentum_5m'], + features['limit_up_ratio'], + ]) + + return np.array(feature_list) + + def add_anomaly_candidate(self, concept_id: str, timestamp, score: float): + """添加异动候选""" + self.anomaly_candidates[concept_id].append((timestamp, score)) + + def check_sustained_anomaly(self, concept_id: str, threshold: float) -> Tuple[bool, float]: + """ + 检查是否为持续性异动 + + 返回:(是否确认, 确认比例) + """ + candidates = self.anomaly_candidates[concept_id] + + if len(candidates) < CONFIG['confirm_window']: + return 
False, 0.0 + + # 统计超过阈值的时刻数量 + exceed_count = sum(1 for _, score in candidates if score >= threshold) + ratio = exceed_count / len(candidates) + + return ratio >= CONFIG['confirm_ratio'], ratio + + def check_cooldown(self, concept_id: str, timestamp) -> bool: + """检查是否在冷却期""" + if concept_id not in self.cooldown: + return False + + last_alert = self.cooldown[concept_id] + try: + diff = (timestamp - last_alert).total_seconds() / 60 + return diff < CONFIG['cooldown_minutes'] + except: + return False + + def set_cooldown(self, concept_id: str, timestamp): + """设置冷却期""" + self.cooldown[concept_id] = timestamp + + +# ==================== 异动检测器 V2 ==================== + +class AnomalyDetectorV2: + """ + V2 异动检测器 + + 核心流程: + 1. 获取实时数据 + 2. 计算 Z-Score 特征 + 3. 规则评分 + ML 评分 + 4. 持续性确认 + 5. 输出异动 + """ + + def __init__( + self, + model_dir: str = 'ml/checkpoints_v2', + baseline_dir: str = 'ml/data_v2/baselines' + ): + # 加载概念 + self.concepts = self._load_concepts() + + # 加载基线 + self.baselines = load_baselines(baseline_dir) + print(f"加载了 {len(self.baselines)} 个概念的基线") + + # 初始化 ML 评分器 + self.ml_scorer = MLScorerV2(model_dir) + + # 初始化数据管理器 + self.data_manager = RealtimeDataManagerV2(self.concepts, self.baselines) + + # 收集所有股票 + self.all_stocks = list(set(s for c in self.concepts for s in c['stocks'])) + + def _load_concepts(self) -> List[dict]: + """从 ES 加载概念""" + concepts = [] + query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]} + + resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = resp['hits']['hits'] + + while len(hits) > 0: + for hit in hits: + source = hit['_source'] + stocks = [] + if 'stocks' in source and isinstance(source['stocks'], list): + for stock in source['stocks']: + if isinstance(stock, dict) and 'code' in stock and stock['code']: + stocks.append(stock['code']) + if stocks: + concepts.append({ + 'concept_id': source.get('concept_id'), + 'concept_name': source.get('concept'), + 'stocks': stocks + }) + + resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = resp['hits']['hits'] + + ES_CLIENT.clear_scroll(scroll_id=scroll_id) + print(f"加载了 {len(concepts)} 个概念") + return concepts + + def detect(self, trade_date: str) -> List[Dict]: + """ + 检测指定日期的异动 + + 返回异动列表 + """ + print(f"\n检测 {trade_date} 的异动...") + + # 获取原始数据 + raw_features = self._compute_raw_features(trade_date) + if raw_features.empty: + print("无数据") + return [] + + # 按时间排序 + timestamps = sorted(raw_features['timestamp'].unique()) + print(f"时间点数: {len(timestamps)}") + + all_alerts = [] + + for ts in timestamps: + ts_data = raw_features[raw_features['timestamp'] == ts] + ts_alerts = self._process_timestamp(ts, ts_data, trade_date) + all_alerts.extend(ts_alerts) + + print(f"共检测到 {len(all_alerts)} 个异动") + return all_alerts + + def _compute_raw_features(self, trade_date: str) -> pd.DataFrame: + """计算原始特征(同 prepare_data_v2)""" + # 这里简化处理,直接调用数据准备逻辑 + from prepare_data_v2 import compute_raw_concept_features + return compute_raw_concept_features(trade_date, self.concepts, self.all_stocks) + + def _process_timestamp(self, timestamp, ts_data: pd.DataFrame, trade_date: str) -> List[Dict]: + """处理单个时间戳""" + alerts = [] + candidates = [] # (concept_id, features, rule_score, triggered_rules) + + for _, row in ts_data.iterrows(): + concept_id = row['concept_id'] + + # 计算 Z-Score 特征 + features = self.data_manager.compute_zscore_features( + concept_id, timestamp, + row['alpha'], row['total_amt'], 
row['rank_pct'], row['limit_up_ratio'] + ) + + if features is None: + continue + + # 更新历史 + self.data_manager.update(concept_id, timestamp, features) + + # 规则评分 + rule_score, triggered_rules = score_rules_zscore(features) + + # 收集候选 + candidates.append((concept_id, features, rule_score, triggered_rules)) + + if not candidates: + return [] + + # 批量 ML 评分 + sequences = [] + valid_candidates = [] + + for concept_id, features, rule_score, triggered_rules in candidates: + seq = self.data_manager.get_sequence(concept_id) + if seq is not None: + sequences.append(seq) + valid_candidates.append((concept_id, features, rule_score, triggered_rules)) + + if not sequences: + return [] + + sequences = np.array(sequences) + ml_scores = self.ml_scorer.score_batch(sequences) + + # 融合评分 + 持续性确认 + for i, (concept_id, features, rule_score, triggered_rules) in enumerate(valid_candidates): + ml_score = ml_scores[i] + final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score + + # 判断是否触发 + is_triggered = ( + rule_score >= CONFIG['rule_trigger'] or + ml_score >= CONFIG['ml_trigger'] or + final_score >= CONFIG['fusion_trigger'] + ) + + # 添加到候选队列 + self.data_manager.add_anomaly_candidate(concept_id, timestamp, final_score) + + if not is_triggered: + continue + + # 检查冷却期 + if self.data_manager.check_cooldown(concept_id, timestamp): + continue + + # 持续性确认 + is_sustained, confirm_ratio = self.data_manager.check_sustained_anomaly( + concept_id, CONFIG['fusion_trigger'] + ) + + if not is_sustained: + continue + + # 确认为异动! + self.data_manager.set_cooldown(concept_id, timestamp) + + # 确定异动类型 + alpha = features['alpha'] + if alpha >= 1.5: + alert_type = 'surge_up' + elif alpha <= -1.5: + alert_type = 'surge_down' + elif features['amt_zscore'] >= 3.0: + alert_type = 'volume_spike' + else: + alert_type = 'surge' + + # 确定触发原因 + if rule_score >= CONFIG['rule_trigger']: + trigger_reason = f'规则({rule_score:.0f})+持续确认({confirm_ratio:.0%})' + elif ml_score >= CONFIG['ml_trigger']: + trigger_reason = f'ML({ml_score:.0f})+持续确认({confirm_ratio:.0%})' + else: + trigger_reason = f'融合({final_score:.0f})+持续确认({confirm_ratio:.0%})' + + alerts.append({ + 'concept_id': concept_id, + 'concept_name': self.data_manager.concepts.get(concept_id, {}).get('concept_name', concept_id), + 'alert_time': timestamp, + 'trade_date': trade_date, + 'alert_type': alert_type, + 'final_score': final_score, + 'rule_score': rule_score, + 'ml_score': ml_score, + 'trigger_reason': trigger_reason, + 'confirm_ratio': confirm_ratio, + 'alpha': alpha, + 'alpha_zscore': features['alpha_zscore'], + 'amt_zscore': features['amt_zscore'], + 'rank_zscore': features['rank_zscore'], + 'momentum_3m': features['momentum_3m'], + 'momentum_5m': features['momentum_5m'], + 'limit_up_ratio': features['limit_up_ratio'], + 'triggered_rules': triggered_rules, + }) + + # 每分钟最多 N 个 + if len(alerts) > CONFIG['max_alerts_per_minute']: + alerts = sorted(alerts, key=lambda x: x['final_score'], reverse=True) + alerts = alerts[:CONFIG['max_alerts_per_minute']] + + return alerts + + +# ==================== 主函数 ==================== + +def main(): + import argparse + + parser = argparse.ArgumentParser(description='V2 异动检测器') + parser.add_argument('--date', type=str, default=None, help='检测日期(默认今天)') + parser.add_argument('--model_dir', type=str, default='ml/checkpoints_v2') + parser.add_argument('--baseline_dir', type=str, default='ml/data_v2/baselines') + + args = parser.parse_args() + + trade_date = args.date or datetime.now().strftime('%Y-%m-%d') + + detector = 
AnomalyDetectorV2( + model_dir=args.model_dir, + baseline_dir=args.baseline_dir + ) + + alerts = detector.detect(trade_date) + + print(f"\n检测结果:") + for alert in alerts[:20]: + print(f" [{alert['alert_time'].strftime('%H:%M') if hasattr(alert['alert_time'], 'strftime') else alert['alert_time']}] " + f"{alert['concept_name']} ({alert['alert_type']}) " + f"分数={alert['final_score']:.0f} " + f"确认率={alert['confirm_ratio']:.0%}") + + if len(alerts) > 20: + print(f" ... 共 {len(alerts)} 个异动") + + +if __name__ == "__main__": + main() diff --git a/ml/enhanced_detector.py b/ml/enhanced_detector.py new file mode 100644 index 00000000..b64c63a8 --- /dev/null +++ b/ml/enhanced_detector.py @@ -0,0 +1,526 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +增强版概念异动检测器 + +融合两种检测方法: +1. Alpha-based Z-Score(规则方法,实时性好) +2. Transformer Autoencoder(ML方法,更准确) + +使用策略: +- 当 ML 模型可用且历史数据足够时,优先使用 ML 方法 +- 否则回退到 Alpha-based 方法 +- 可以配置两种方法的融合权重 +""" + +import os +import sys +import logging +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Tuple, Optional +from dataclasses import dataclass, field +from collections import deque +import numpy as np + +# 添加父目录到路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +logger = logging.getLogger(__name__) + + +# ==================== 配置 ==================== + +ENHANCED_CONFIG = { + # 融合策略 + 'fusion_mode': 'adaptive', # 'ml_only', 'alpha_only', 'adaptive', 'ensemble' + + # ML 权重(在 ensemble 模式下) + 'ml_weight': 0.6, + 'alpha_weight': 0.4, + + # ML 模型配置 + 'ml_checkpoint_dir': 'ml/checkpoints', + 'ml_threshold_key': 'p95', # p90, p95, p99 + + # Alpha 配置(与 concept_alert_alpha.py 一致) + 'alpha_zscore_threshold': 2.0, + 'alpha_absolute_threshold': 1.5, + 'alpha_history_window': 60, + 'alpha_min_history': 5, + + # 共享配置 + 'cooldown_minutes': 8, + 'max_alerts_per_minute': 15, + 'min_alpha_abs': 0.5, +} + +# 特征配置(与训练一致) +FEATURE_NAMES = [ + 'alpha', + 'alpha_delta', + 'amt_ratio', + 'amt_delta', + 'rank_pct', + 'limit_up_ratio', +] + + +# ==================== 数据结构 ==================== + +@dataclass +class AlphaStats: + """概念的Alpha统计信息""" + history: deque = field(default_factory=lambda: deque(maxlen=ENHANCED_CONFIG['alpha_history_window'])) + mean: float = 0.0 + std: float = 1.0 + + def update(self, alpha: float): + self.history.append(alpha) + if len(self.history) >= 2: + self.mean = np.mean(self.history) + self.std = max(np.std(self.history), 0.1) + + def get_zscore(self, alpha: float) -> float: + if len(self.history) < ENHANCED_CONFIG['alpha_min_history']: + return 0.0 + return (alpha - self.mean) / self.std + + def is_ready(self) -> bool: + return len(self.history) >= ENHANCED_CONFIG['alpha_min_history'] + + +@dataclass +class ConceptFeatures: + """概念的实时特征""" + alpha: float = 0.0 + alpha_delta: float = 0.0 + amt_ratio: float = 1.0 + amt_delta: float = 0.0 + rank_pct: float = 0.5 + limit_up_ratio: float = 0.0 + + def to_dict(self) -> Dict[str, float]: + return { + 'alpha': self.alpha, + 'alpha_delta': self.alpha_delta, + 'amt_ratio': self.amt_ratio, + 'amt_delta': self.amt_delta, + 'rank_pct': self.rank_pct, + 'limit_up_ratio': self.limit_up_ratio, + } + + +# ==================== 增强检测器 ==================== + +class EnhancedAnomalyDetector: + """ + 增强版异动检测器 + + 融合 Alpha-based 和 ML 两种方法 + """ + + def __init__( + self, + config: Dict = None, + ml_enabled: bool = True + ): + self.config = config or ENHANCED_CONFIG + self.ml_enabled = ml_enabled + self.ml_detector = None + + # Alpha 统计 + self.alpha_stats: Dict[str, 
AlphaStats] = {} + + # 特征历史(用于计算 delta) + self.feature_history: Dict[str, deque] = {} + + # 冷却记录 + self.cooldown_cache: Dict[str, datetime] = {} + + # 尝试加载 ML 模型 + if ml_enabled: + self._load_ml_model() + + logger.info(f"EnhancedAnomalyDetector 初始化完成") + logger.info(f" 融合模式: {self.config['fusion_mode']}") + logger.info(f" ML 可用: {self.ml_detector is not None}") + + def _load_ml_model(self): + """加载 ML 模型""" + try: + from inference import ConceptAnomalyDetector + checkpoint_dir = Path(__file__).parent / 'checkpoints' + + if (checkpoint_dir / 'best_model.pt').exists(): + self.ml_detector = ConceptAnomalyDetector( + checkpoint_dir=str(checkpoint_dir), + threshold_key=self.config['ml_threshold_key'] + ) + logger.info("ML 模型加载成功") + else: + logger.warning(f"ML 模型不存在: {checkpoint_dir / 'best_model.pt'}") + except Exception as e: + logger.warning(f"ML 模型加载失败: {e}") + self.ml_detector = None + + def _get_alpha_stats(self, concept_id: str) -> AlphaStats: + """获取或创建 Alpha 统计""" + if concept_id not in self.alpha_stats: + self.alpha_stats[concept_id] = AlphaStats() + return self.alpha_stats[concept_id] + + def _get_feature_history(self, concept_id: str) -> deque: + """获取特征历史""" + if concept_id not in self.feature_history: + self.feature_history[concept_id] = deque(maxlen=10) + return self.feature_history[concept_id] + + def _check_cooldown(self, concept_id: str, current_time: datetime) -> bool: + """检查冷却""" + if concept_id not in self.cooldown_cache: + return False + + last_alert = self.cooldown_cache[concept_id] + cooldown_td = (current_time - last_alert).total_seconds() / 60 + + return cooldown_td < self.config['cooldown_minutes'] + + def _set_cooldown(self, concept_id: str, current_time: datetime): + """设置冷却""" + self.cooldown_cache[concept_id] = current_time + + def compute_features( + self, + concept_id: str, + alpha: float, + amt_ratio: float, + rank_pct: float, + limit_up_ratio: float + ) -> ConceptFeatures: + """ + 计算概念的完整特征 + + Args: + concept_id: 概念ID + alpha: 当前超额收益 + amt_ratio: 成交额比率 + rank_pct: 排名百分位 + limit_up_ratio: 涨停股占比 + + Returns: + 完整特征 + """ + history = self._get_feature_history(concept_id) + + # 计算变化率 + alpha_delta = 0.0 + amt_delta = 0.0 + + if len(history) > 0: + last_features = history[-1] + alpha_delta = alpha - last_features.alpha + if last_features.amt_ratio > 0: + amt_delta = (amt_ratio - last_features.amt_ratio) / last_features.amt_ratio + + features = ConceptFeatures( + alpha=alpha, + alpha_delta=alpha_delta, + amt_ratio=amt_ratio, + amt_delta=amt_delta, + rank_pct=rank_pct, + limit_up_ratio=limit_up_ratio, + ) + + # 更新历史 + history.append(features) + + return features + + def detect_alpha_anomaly( + self, + concept_id: str, + alpha: float + ) -> Tuple[bool, float, str]: + """ + Alpha-based 异动检测 + + Returns: + is_anomaly: 是否异动 + score: 异动分数(Z-Score 绝对值) + reason: 触发原因 + """ + stats = self._get_alpha_stats(concept_id) + + # 计算 Z-Score(在更新前) + zscore = stats.get_zscore(alpha) + + # 更新统计 + stats.update(alpha) + + # 判断 + if stats.is_ready(): + if abs(zscore) >= self.config['alpha_zscore_threshold']: + return True, abs(zscore), f"Z={zscore:.2f}" + else: + if abs(alpha) >= self.config['alpha_absolute_threshold']: + fake_zscore = alpha / 0.5 + return True, abs(fake_zscore), f"Alpha={alpha:+.2f}%" + + return False, abs(zscore) if zscore else 0.0, "" + + def detect_ml_anomaly( + self, + concept_id: str, + features: ConceptFeatures + ) -> Tuple[bool, float]: + """ + ML-based 异动检测 + + Returns: + is_anomaly: 是否异动 + score: 异动分数(重构误差) + """ + if self.ml_detector is None: + return 
False, 0.0 + + try: + is_anomaly, score = self.ml_detector.detect( + concept_id, + features.to_dict() + ) + return is_anomaly, score or 0.0 + except Exception as e: + logger.warning(f"ML 检测失败: {e}") + return False, 0.0 + + def detect( + self, + concept_id: str, + concept_name: str, + alpha: float, + amt_ratio: float, + rank_pct: float, + limit_up_ratio: float, + change_pct: float, + index_change: float, + current_time: datetime, + **extra_data + ) -> Optional[Dict]: + """ + 融合检测 + + Args: + concept_id: 概念ID + concept_name: 概念名称 + alpha: 超额收益 + amt_ratio: 成交额比率 + rank_pct: 排名百分位 + limit_up_ratio: 涨停股占比 + change_pct: 概念涨跌幅 + index_change: 大盘涨跌幅 + current_time: 当前时间 + **extra_data: 其他数据(limit_up_count, stock_count 等) + + Returns: + 异动信息(如果触发),否则 None + """ + # Alpha 太小,不关注 + if abs(alpha) < self.config['min_alpha_abs']: + return None + + # 检查冷却 + if self._check_cooldown(concept_id, current_time): + return None + + # 计算特征 + features = self.compute_features( + concept_id, alpha, amt_ratio, rank_pct, limit_up_ratio + ) + + # 执行检测 + fusion_mode = self.config['fusion_mode'] + + alpha_anomaly, alpha_score, alpha_reason = self.detect_alpha_anomaly(concept_id, alpha) + ml_anomaly, ml_score = False, 0.0 + + if fusion_mode in ('ml_only', 'adaptive', 'ensemble'): + ml_anomaly, ml_score = self.detect_ml_anomaly(concept_id, features) + + # 根据融合模式判断 + is_anomaly = False + final_score = 0.0 + detection_method = '' + + if fusion_mode == 'alpha_only': + is_anomaly = alpha_anomaly + final_score = alpha_score + detection_method = 'alpha' + + elif fusion_mode == 'ml_only': + is_anomaly = ml_anomaly + final_score = ml_score + detection_method = 'ml' + + elif fusion_mode == 'adaptive': + # 优先 ML,回退 Alpha + if self.ml_detector and ml_score > 0: + is_anomaly = ml_anomaly + final_score = ml_score + detection_method = 'ml' + else: + is_anomaly = alpha_anomaly + final_score = alpha_score + detection_method = 'alpha' + + elif fusion_mode == 'ensemble': + # 加权融合 + # 归一化分数 + norm_alpha = min(alpha_score / 5.0, 1.0) # Z > 5 视为 1.0 + norm_ml = min(ml_score / (self.ml_detector.threshold if self.ml_detector else 1.0), 1.0) + + final_score = ( + self.config['alpha_weight'] * norm_alpha + + self.config['ml_weight'] * norm_ml + ) + is_anomaly = final_score > 0.5 or alpha_anomaly or ml_anomaly + detection_method = 'ensemble' + + if not is_anomaly: + return None + + # 构建异动记录 + self._set_cooldown(concept_id, current_time) + + alert_type = 'surge_up' if alpha > 0 else 'surge_down' + + alert = { + 'concept_id': concept_id, + 'concept_name': concept_name, + 'alert_type': alert_type, + 'alert_time': current_time, + 'change_pct': change_pct, + 'alpha': alpha, + 'alpha_zscore': alpha_score, + 'index_change_pct': index_change, + 'detection_method': detection_method, + 'alpha_score': alpha_score, + 'ml_score': ml_score, + 'final_score': final_score, + **extra_data + } + + return alert + + def batch_detect( + self, + concepts_data: List[Dict], + current_time: datetime + ) -> List[Dict]: + """ + 批量检测 + + Args: + concepts_data: 概念数据列表 + current_time: 当前时间 + + Returns: + 异动列表(按分数排序,限制数量) + """ + alerts = [] + + for data in concepts_data: + alert = self.detect( + concept_id=data['concept_id'], + concept_name=data['concept_name'], + alpha=data.get('alpha', 0), + amt_ratio=data.get('amt_ratio', 1.0), + rank_pct=data.get('rank_pct', 0.5), + limit_up_ratio=data.get('limit_up_ratio', 0), + change_pct=data.get('change_pct', 0), + index_change=data.get('index_change', 0), + current_time=current_time, + limit_up_count=data.get('limit_up_count', 0), + 
limit_down_count=data.get('limit_down_count', 0),
+                stock_count=data.get('stock_count', 0),
+                concept_type=data.get('concept_type', 'leaf'),
+            )
+
+            if alert:
+                alerts.append(alert)
+
+        # 排序并限制数量
+        alerts.sort(key=lambda x: x['final_score'], reverse=True)
+        return alerts[:self.config['max_alerts_per_minute']]
+
+    def reset(self):
+        """重置所有状态(新交易日)"""
+        self.alpha_stats.clear()
+        self.feature_history.clear()
+        self.cooldown_cache.clear()
+
+        if self.ml_detector:
+            self.ml_detector.clear_history()
+
+        logger.info("检测器状态已重置")
+
+
+# ==================== 测试 ====================
+
+if __name__ == "__main__":
+    import random
+    from datetime import timedelta
+
+    print("测试 EnhancedAnomalyDetector...")
+
+    # 初始化
+    detector = EnhancedAnomalyDetector(ml_enabled=False)  # 不加载 ML(可能不存在)
+
+    # 模拟数据
+    concepts = [
+        {'concept_id': 'ai_001', 'concept_name': '人工智能'},
+        {'concept_id': 'chip_002', 'concept_name': '芯片半导体'},
+        {'concept_id': 'car_003', 'concept_name': '新能源汽车'},
+    ]
+
+    print("\n模拟实时检测...")
+    current_time = datetime.now()
+
+    for minute in range(50):
+        concepts_data = []
+
+        for c in concepts:
+            # 生成随机数据
+            alpha = random.gauss(0, 0.8)
+            amt_ratio = max(0.3, random.gauss(1, 0.3))
+            rank_pct = random.random()
+            limit_up_ratio = random.random() * 0.1
+
+            # 模拟异动(第30分钟人工智能暴涨)
+            if minute == 30 and c['concept_id'] == 'ai_001':
+                alpha = 4.5
+                amt_ratio = 2.5
+                limit_up_ratio = 0.3
+
+            concepts_data.append({
+                **c,
+                'alpha': alpha,
+                'amt_ratio': amt_ratio,
+                'rank_pct': rank_pct,
+                'limit_up_ratio': limit_up_ratio,
+                'change_pct': alpha + 0.5,
+                'index_change': 0.5,
+            })
+
+        # 检测
+        alerts = detector.batch_detect(concepts_data, current_time)
+
+        if alerts:
+            for alert in alerts:
+                print(f"  t={minute:02d} 🔥 {alert['concept_name']} "
+                      f"Alpha={alert['alpha']:+.2f}% "
+                      f"Score={alert['final_score']:.2f} "
+                      f"Method={alert['detection_method']}")
+
+        # 用 timedelta 递增,避免原先 replace(minute=...) 在 minute=59 时时间回退
+        current_time += timedelta(minutes=1)
+
+    print("\n测试完成!")
diff --git a/ml/inference.py b/ml/inference.py
new file mode 100644
index 00000000..4e704f4c
--- /dev/null
+++ b/ml/inference.py
@@ -0,0 +1,455 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+概念异动检测推理服务
+
+在实时场景中使用训练好的 Transformer Autoencoder 进行异动检测
+
+使用方法:
+    from ml.inference import ConceptAnomalyDetector
+
+    detector = ConceptAnomalyDetector('ml/checkpoints')
+
+    # 检测异动(需要概念名称和当前特征)
+    features = {...}  # 实时特征数据
+    is_anomaly, score = detector.detect(concept_name, features)
+"""
+
+import os
+import json
+from pathlib import Path
+from typing import Dict, List, Tuple, Optional
+from collections import deque
+
+import numpy as np
+import torch
+
+from model import TransformerAutoencoder
+
+
+class ConceptAnomalyDetector:
+    """
+    概念异动检测器
+
+    使用训练好的 Transformer Autoencoder 进行实时异动检测
+    """
+
+    def __init__(
+        self,
+        checkpoint_dir: str = 'ml/checkpoints',
+        device: str = 'auto',
+        threshold_key: str = 'p95'
+    ):
+        """
+        初始化检测器
+
+        Args:
+            checkpoint_dir: 模型检查点目录
+            device: 设备 (auto/cuda/cpu)
+            threshold_key: 使用的阈值键 (p90/p95/p99)
+        """
+        self.checkpoint_dir = Path(checkpoint_dir)
+        self.threshold_key = threshold_key
+
+        # 设备选择
+        if device == 'auto':
+            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        else:
+            self.device = torch.device(device)
+
+        # 加载配置
+        self._load_config()
+
+        # 加载模型
+        self._load_model()
+
+        # 加载阈值
+        self._load_thresholds()
+
+        # 加载标准化统计量
+        self._load_normalization_stats()
+
+        # 概念历史数据缓存
+        # {concept_name: deque(maxlen=seq_len)}
+        self.history_cache: Dict[str, deque] = {}
+
+        print("ConceptAnomalyDetector 初始化完成")
+        print(f"  设备: {self.device}")
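+        # 注:threshold 取自 thresholds.json 中的分位数键(p90/p95/p99)。
+        # 按键名含义推断,p95 即训练集重构误差的 95 分位,约 5% 的正常窗口会超过它。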
+ print(f" 阈值: {self.threshold_key} = {self.threshold:.6f}") + print(f" 序列长度: {self.seq_len}") + + def _load_config(self): + """加载配置""" + config_path = self.checkpoint_dir / 'config.json' + if not config_path.exists(): + raise FileNotFoundError(f"配置文件不存在: {config_path}") + + with open(config_path, 'r') as f: + self.config = json.load(f) + + self.features = self.config['features'] + self.seq_len = self.config['seq_len'] + self.model_config = self.config['model'] + + def _load_model(self): + """加载模型""" + model_path = self.checkpoint_dir / 'best_model.pt' + if not model_path.exists(): + raise FileNotFoundError(f"模型文件不存在: {model_path}") + + # 创建模型 + self.model = TransformerAutoencoder(**self.model_config) + + # 加载权重 + checkpoint = torch.load(model_path, map_location=self.device) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.model.to(self.device) + self.model.eval() + + print(f"模型已加载: {model_path}") + + def _load_thresholds(self): + """加载阈值""" + thresholds_path = self.checkpoint_dir / 'thresholds.json' + if not thresholds_path.exists(): + raise FileNotFoundError(f"阈值文件不存在: {thresholds_path}") + + with open(thresholds_path, 'r') as f: + self.thresholds = json.load(f) + + if self.threshold_key not in self.thresholds: + available_keys = list(self.thresholds.keys()) + raise KeyError(f"阈值键 '{self.threshold_key}' 不存在,可用: {available_keys}") + + self.threshold = self.thresholds[self.threshold_key] + + def _load_normalization_stats(self): + """加载标准化统计量""" + stats_path = self.checkpoint_dir / 'normalization_stats.json' + if not stats_path.exists(): + raise FileNotFoundError(f"标准化统计量文件不存在: {stats_path}") + + with open(stats_path, 'r') as f: + stats = json.load(f) + + self.norm_mean = np.array(stats['mean']) + self.norm_std = np.array(stats['std']) + + def normalize(self, features: np.ndarray) -> np.ndarray: + """标准化特征""" + return (features - self.norm_mean) / self.norm_std + + def update_history( + self, + concept_name: str, + features: Dict[str, float] + ): + """ + 更新概念历史数据 + + Args: + concept_name: 概念名称 + features: 当前时刻的特征字典 + """ + # 初始化历史缓存 + if concept_name not in self.history_cache: + self.history_cache[concept_name] = deque(maxlen=self.seq_len) + + # 提取特征向量 + feature_vector = np.array([ + features.get(f, 0.0) for f in self.features + ]) + + # 处理异常值 + feature_vector = np.nan_to_num(feature_vector, nan=0.0, posinf=0.0, neginf=0.0) + + # 添加到历史 + self.history_cache[concept_name].append(feature_vector) + + def get_history_length(self, concept_name: str) -> int: + """获取概念的历史数据长度""" + if concept_name not in self.history_cache: + return 0 + return len(self.history_cache[concept_name]) + + @torch.no_grad() + def detect( + self, + concept_name: str, + features: Dict[str, float] = None, + return_score: bool = True + ) -> Tuple[bool, Optional[float]]: + """ + 检测概念是否异动 + + Args: + concept_name: 概念名称 + features: 当前时刻的特征(如果提供,会先更新历史) + return_score: 是否返回异动分数 + + Returns: + is_anomaly: 是否异动 + score: 异动分数(如果 return_score=True) + """ + # 更新历史 + if features is not None: + self.update_history(concept_name, features) + + # 检查历史数据是否足够 + if concept_name not in self.history_cache: + return False, None + + history = self.history_cache[concept_name] + if len(history) < self.seq_len: + return False, None + + # 构建输入序列 + sequence = np.array(list(history)) # (seq_len, n_features) + + # 标准化 + sequence = self.normalize(sequence) + + # 转为 tensor + x = torch.FloatTensor(sequence).unsqueeze(0) # (1, seq_len, n_features) + x = x.to(self.device) + + # 计算重构误差 + error = self.model.compute_reconstruction_error(x, 
reduction='none') + + # 取最后一个时刻的误差作为当前分数 + score = error[0, -1].item() + + # 判断是否异动 + is_anomaly = score > self.threshold + + if return_score: + return is_anomaly, score + else: + return is_anomaly, None + + @torch.no_grad() + def batch_detect( + self, + concept_features: Dict[str, Dict[str, float]] + ) -> Dict[str, Tuple[bool, float]]: + """ + 批量检测多个概念 + + Args: + concept_features: {concept_name: {feature_name: value}} + + Returns: + results: {concept_name: (is_anomaly, score)} + """ + results = {} + + for concept_name, features in concept_features.items(): + is_anomaly, score = self.detect(concept_name, features) + results[concept_name] = (is_anomaly, score) + + return results + + def get_anomaly_type( + self, + concept_name: str, + features: Dict[str, float] + ) -> str: + """ + 判断异动类型 + + Args: + concept_name: 概念名称 + features: 当前特征 + + Returns: + anomaly_type: 'surge_up' / 'surge_down' / 'normal' + """ + is_anomaly, score = self.detect(concept_name, features) + + if not is_anomaly: + return 'normal' + + # 根据 alpha 判断涨跌 + alpha = features.get('alpha', 0.0) + + if alpha > 0: + return 'surge_up' + else: + return 'surge_down' + + def get_top_anomalies( + self, + concept_features: Dict[str, Dict[str, float]], + top_k: int = 10 + ) -> List[Tuple[str, float, str]]: + """ + 获取异动分数最高的 top_k 个概念 + + Args: + concept_features: {concept_name: {feature_name: value}} + top_k: 返回数量 + + Returns: + anomalies: [(concept_name, score, anomaly_type), ...] + """ + results = self.batch_detect(concept_features) + + # 按分数排序 + sorted_results = sorted( + [(name, is_anomaly, score) for name, (is_anomaly, score) in results.items() if score is not None], + key=lambda x: x[2], + reverse=True + ) + + # 取 top_k + top_anomalies = [] + for name, is_anomaly, score in sorted_results[:top_k]: + if is_anomaly: + alpha = concept_features[name].get('alpha', 0.0) + anomaly_type = 'surge_up' if alpha > 0 else 'surge_down' + top_anomalies.append((name, score, anomaly_type)) + + return top_anomalies + + def clear_history(self, concept_name: str = None): + """ + 清除历史缓存 + + Args: + concept_name: 概念名称(如果为 None,清除所有) + """ + if concept_name is None: + self.history_cache.clear() + elif concept_name in self.history_cache: + del self.history_cache[concept_name] + + +# ==================== 集成到现有系统 ==================== + +class MLAnomalyService: + """ + ML 异动检测服务 + + 用于替换或增强现有的 Alpha-based 检测 + """ + + def __init__( + self, + checkpoint_dir: str = 'ml/checkpoints', + fallback_to_alpha: bool = True + ): + """ + Args: + checkpoint_dir: 模型检查点目录 + fallback_to_alpha: 当 ML 模型不可用时是否回退到 Alpha 方法 + """ + self.fallback_to_alpha = fallback_to_alpha + self.ml_detector = None + + try: + self.ml_detector = ConceptAnomalyDetector(checkpoint_dir) + print("ML 异动检测服务初始化成功") + except Exception as e: + print(f"ML 模型加载失败: {e}") + if not fallback_to_alpha: + raise + print("将回退到 Alpha-based 检测") + + def is_ml_available(self) -> bool: + """检查 ML 模型是否可用""" + return self.ml_detector is not None + + def detect_anomaly( + self, + concept_name: str, + features: Dict[str, float], + alpha_threshold: float = 2.0 + ) -> Tuple[bool, float, str]: + """ + 检测异动 + + Args: + concept_name: 概念名称 + features: 特征字典(需包含 alpha, amt_ratio 等) + alpha_threshold: Alpha Z-Score 阈值(用于回退) + + Returns: + is_anomaly: 是否异动 + score: 异动分数 + method: 检测方法 ('ml' / 'alpha') + """ + # 优先使用 ML 检测 + if self.ml_detector is not None: + history_len = self.ml_detector.get_history_length(concept_name) + + # 历史数据足够时使用 ML + if history_len >= self.ml_detector.seq_len - 1: + is_anomaly, score = 
self.ml_detector.detect(concept_name, features) + if score is not None: + return is_anomaly, score, 'ml' + else: + # 更新历史但使用 Alpha 方法 + self.ml_detector.update_history(concept_name, features) + + # 回退到 Alpha 方法 + if self.fallback_to_alpha: + alpha = features.get('alpha', 0.0) + alpha_zscore = features.get('alpha_zscore', 0.0) + + is_anomaly = abs(alpha_zscore) > alpha_threshold + score = abs(alpha_zscore) + + return is_anomaly, score, 'alpha' + + return False, 0.0, 'none' + + +# ==================== 测试 ==================== + +if __name__ == "__main__": + import random + + print("测试 ConceptAnomalyDetector...") + + # 检查模型是否存在 + checkpoint_dir = Path('ml/checkpoints') + if not (checkpoint_dir / 'best_model.pt').exists(): + print("模型文件不存在,跳过测试") + print("请先运行 train.py 训练模型") + exit(0) + + # 初始化检测器 + detector = ConceptAnomalyDetector('ml/checkpoints') + + # 模拟数据 + print("\n模拟实时检测...") + concept_name = "人工智能" + + for i in range(40): + # 生成随机特征 + features = { + 'alpha': random.gauss(0, 1), + 'alpha_delta': random.gauss(0, 0.5), + 'amt_ratio': random.gauss(1, 0.3), + 'amt_delta': random.gauss(0, 0.2), + 'rank_pct': random.random(), + 'limit_up_ratio': random.random() * 0.1, + } + + # 在第 35 分钟模拟异动 + if i == 35: + features['alpha'] = 5.0 + features['alpha_delta'] = 2.0 + features['amt_ratio'] = 3.0 + + is_anomaly, score = detector.detect(concept_name, features) + + history_len = detector.get_history_length(concept_name) + + if score is not None: + status = "🔥 异动!" if is_anomaly else "正常" + print(f" t={i:02d} | 历史={history_len} | 分数={score:.4f} | {status}") + else: + print(f" t={i:02d} | 历史={history_len} | 数据不足") + + print("\n测试完成!") diff --git a/ml/model.py b/ml/model.py new file mode 100644 index 00000000..90c0d61b --- /dev/null +++ b/ml/model.py @@ -0,0 +1,393 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +LSTM Autoencoder 模型定义 + +用于概念异动检测: +- 学习"正常"市场模式 +- 重构误差大的时刻 = 异动 + +模型结构(简洁有效): +┌─────────────────────────────────────┐ +│ 输入: (batch, seq_len, n_features) │ +│ 过去30分钟的特征序列 │ +├─────────────────────────────────────┤ +│ LSTM Encoder │ +│ - 双向 LSTM │ +│ - 输出最后隐藏状态 │ +├─────────────────────────────────────┤ +│ Bottleneck (压缩层) │ +│ 降维到 latent_dim(关键!) │ +├─────────────────────────────────────┤ +│ LSTM Decoder │ +│ - 单向 LSTM │ +│ - 重构序列 │ +├─────────────────────────────────────┤ +│ 输出: (batch, seq_len, n_features) │ +│ 重构的特征序列 │ +└─────────────────────────────────────┘ + +为什么用 LSTM 而不是 Transformer: +1. 参数更少,不容易过拟合 +2. 对于 6 维特征足够用 +3. 训练更稳定 +4. 瓶颈约束更容易控制 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple + + +class LSTMAutoencoder(nn.Module): + """ + LSTM Autoencoder for Anomaly Detection + + 设计原则: + - 足够简单,避免过拟合 + - 瓶颈层严格限制,迫使模型只学习主要模式 + - 异常难以通过狭窄瓶颈,重构误差大 + """ + + def __init__( + self, + n_features: int = 6, + hidden_dim: int = 32, # LSTM 隐藏维度(小!) 
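+        # 注:按默认参数,双向编码器输出 2*hidden_dim=64 维,再压缩到 latent_dim=4 维,
+        # 约 16 倍压缩;正是这一瓶颈迫使模型只保留主要模式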
+ latent_dim: int = 4, # 瓶颈维度(非常小!关键参数) + num_layers: int = 1, # LSTM 层数 + dropout: float = 0.2, + bidirectional: bool = True, # 双向编码器 + ): + super().__init__() + + self.n_features = n_features + self.hidden_dim = hidden_dim + self.latent_dim = latent_dim + self.num_layers = num_layers + self.bidirectional = bidirectional + self.num_directions = 2 if bidirectional else 1 + + # Encoder: 双向 LSTM + self.encoder = nn.LSTM( + input_size=n_features, + hidden_size=hidden_dim, + num_layers=num_layers, + batch_first=True, + dropout=dropout if num_layers > 1 else 0, + bidirectional=bidirectional + ) + + # Bottleneck: 压缩到极小的 latent space + encoder_output_dim = hidden_dim * self.num_directions + self.bottleneck_down = nn.Sequential( + nn.Linear(encoder_output_dim, latent_dim), + nn.Tanh(), # 限制范围,增加约束 + ) + + # 使用 LeakyReLU 替代 ReLU + # 原因:Z-Score 数据范围是 [-5, +5],ReLU 会截断负值,丢失跌幅信息 + # LeakyReLU 保留负值信号(乘以 0.1) + self.bottleneck_up = nn.Sequential( + nn.Linear(latent_dim, hidden_dim), + nn.LeakyReLU(negative_slope=0.1), + ) + + # Decoder: 单向 LSTM + self.decoder = nn.LSTM( + input_size=hidden_dim, + hidden_size=hidden_dim, + num_layers=num_layers, + batch_first=True, + dropout=dropout if num_layers > 1 else 0, + bidirectional=False # 解码器用单向 + ) + + # 输出层 + self.output_layer = nn.Linear(hidden_dim, n_features) + + # Dropout + self.dropout = nn.Dropout(dropout) + + # 初始化 + self._init_weights() + + def _init_weights(self): + """初始化权重""" + for name, param in self.named_parameters(): + if 'weight_ih' in name: + nn.init.xavier_uniform_(param) + elif 'weight_hh' in name: + nn.init.orthogonal_(param) + elif 'bias' in name: + nn.init.zeros_(param) + + def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + 编码器 + + Args: + x: (batch, seq_len, n_features) + Returns: + latent: (batch, seq_len, latent_dim) 每个时间步的压缩表示 + encoder_outputs: (batch, seq_len, hidden_dim * num_directions) + """ + # LSTM 编码 + encoder_outputs, (h_n, c_n) = self.encoder(x) + # encoder_outputs: (batch, seq_len, hidden_dim * num_directions) + + encoder_outputs = self.dropout(encoder_outputs) + + # 压缩到 latent space(对每个时间步) + latent = self.bottleneck_down(encoder_outputs) + # latent: (batch, seq_len, latent_dim) + + return latent, encoder_outputs + + def decode(self, latent: torch.Tensor, seq_len: int) -> torch.Tensor: + """ + 解码器 + + Args: + latent: (batch, seq_len, latent_dim) + seq_len: 序列长度 + Returns: + output: (batch, seq_len, n_features) + """ + # 从 latent space 恢复 + decoder_input = self.bottleneck_up(latent) + # decoder_input: (batch, seq_len, hidden_dim) + + # LSTM 解码 + decoder_outputs, _ = self.decoder(decoder_input) + # decoder_outputs: (batch, seq_len, hidden_dim) + + decoder_outputs = self.dropout(decoder_outputs) + + # 投影到原始特征空间 + output = self.output_layer(decoder_outputs) + # output: (batch, seq_len, n_features) + + return output + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + 前向传播 + + Args: + x: (batch, seq_len, n_features) + Returns: + output: (batch, seq_len, n_features) 重构结果 + latent: (batch, seq_len, latent_dim) 隐向量 + """ + batch_size, seq_len, _ = x.shape + + # 编码 + latent, _ = self.encode(x) + + # 解码 + output = self.decode(latent, seq_len) + + return output, latent + + def compute_reconstruction_error( + self, + x: torch.Tensor, + reduction: str = 'none' + ) -> torch.Tensor: + """ + 计算重构误差 + + Args: + x: (batch, seq_len, n_features) + reduction: 'none' | 'mean' | 'sum' + Returns: + error: 重构误差 + """ + output, _ = self.forward(x) + + # MSE per feature per timestep + 
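+        # reduction='none' 先得到逐元素误差,形状为 (batch, seq_len, n_features),
+        # 随后在 'none' 分支中按特征维求均值,得到每个时间步的异常分数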
error = F.mse_loss(output, x, reduction='none') + + if reduction == 'none': + # (batch, seq_len, n_features) -> (batch, seq_len) + return error.mean(dim=-1) + elif reduction == 'mean': + return error.mean() + elif reduction == 'sum': + return error.sum() + else: + raise ValueError(f"Unknown reduction: {reduction}") + + def detect_anomaly( + self, + x: torch.Tensor, + threshold: float = None, + return_scores: bool = True + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + 检测异动 + + Args: + x: (batch, seq_len, n_features) + threshold: 异动阈值(如果为 None,只返回分数) + return_scores: 是否返回异动分数 + Returns: + is_anomaly: (batch, seq_len) bool tensor (if threshold is not None) + scores: (batch, seq_len) 异动分数 (if return_scores) + """ + scores = self.compute_reconstruction_error(x, reduction='none') + + is_anomaly = None + if threshold is not None: + is_anomaly = scores > threshold + + if return_scores: + return is_anomaly, scores + else: + return is_anomaly, None + + +# 为了兼容性,创建别名 +TransformerAutoencoder = LSTMAutoencoder + + +# ==================== 损失函数 ==================== + +class AnomalyDetectionLoss(nn.Module): + """ + 异动检测损失函数 + + 简单的 MSE 重构损失 + """ + + def __init__( + self, + feature_weights: torch.Tensor = None, + ): + super().__init__() + self.feature_weights = feature_weights + + def forward( + self, + output: torch.Tensor, + target: torch.Tensor, + latent: torch.Tensor = None + ) -> Tuple[torch.Tensor, dict]: + """ + Args: + output: (batch, seq_len, n_features) 重构结果 + target: (batch, seq_len, n_features) 原始输入 + latent: (batch, seq_len, latent_dim) 隐向量(未使用) + Returns: + loss: 总损失 + loss_dict: 各项损失详情 + """ + # 重构损失 (MSE) + mse = F.mse_loss(output, target, reduction='none') + + # 特征加权(可选) + if self.feature_weights is not None: + weights = self.feature_weights.to(mse.device) + mse = mse * weights + + reconstruction_loss = mse.mean() + + loss_dict = { + 'total': reconstruction_loss.item(), + 'reconstruction': reconstruction_loss.item(), + } + + return reconstruction_loss, loss_dict + + +# ==================== 工具函数 ==================== + +def count_parameters(model: nn.Module) -> int: + """统计模型参数量""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def create_model(config: dict = None) -> LSTMAutoencoder: + """ + 创建模型 + + 默认使用小型 LSTM 配置,适合异动检测 + """ + default_config = { + 'n_features': 6, + 'hidden_dim': 32, # 小! 
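+        # 按此默认配置粗略估算参数量约 1.9 万(双向编码器 ~10K、解码器 ~8K、线性层 <1K),
+        # 远低于下方 100,000 的过拟合警戒线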
+ 'latent_dim': 4, # 非常小!关键 + 'num_layers': 1, + 'dropout': 0.2, + 'bidirectional': True, + } + + if config: + # 兼容旧的 Transformer 配置键名 + if 'd_model' in config: + config['hidden_dim'] = config.pop('d_model') // 2 + if 'num_encoder_layers' in config: + config['num_layers'] = config.pop('num_encoder_layers') + if 'num_decoder_layers' in config: + config.pop('num_decoder_layers') + if 'nhead' in config: + config.pop('nhead') + if 'dim_feedforward' in config: + config.pop('dim_feedforward') + if 'max_seq_len' in config: + config.pop('max_seq_len') + if 'use_instance_norm' in config: + config.pop('use_instance_norm') + + default_config.update(config) + + model = LSTMAutoencoder(**default_config) + param_count = count_parameters(model) + print(f"模型参数量: {param_count:,}") + + if param_count > 100000: + print(f"⚠️ 警告: 参数量较大({param_count:,}),可能过拟合") + else: + print(f"✓ 参数量适中(LSTM Autoencoder)") + + return model + + +if __name__ == "__main__": + # 测试模型 + print("测试 LSTM Autoencoder...") + + # 创建模型 + model = create_model() + + # 测试输入 + batch_size = 32 + seq_len = 30 + n_features = 6 + + x = torch.randn(batch_size, seq_len, n_features) + + # 前向传播 + output, latent = model(x) + + print(f"输入形状: {x.shape}") + print(f"输出形状: {output.shape}") + print(f"隐向量形状: {latent.shape}") + + # 计算重构误差 + error = model.compute_reconstruction_error(x) + print(f"重构误差形状: {error.shape}") + print(f"平均重构误差: {error.mean().item():.4f}") + + # 测试异动检测 + is_anomaly, scores = model.detect_anomaly(x, threshold=0.5) + print(f"异动检测结果形状: {is_anomaly.shape if is_anomaly is not None else 'None'}") + print(f"异动分数形状: {scores.shape}") + + # 测试损失函数 + criterion = AnomalyDetectionLoss() + loss, loss_dict = criterion(output, x, latent) + print(f"损失: {loss.item():.4f}") + + print("\n测试通过!") diff --git a/ml/prepare_data.py b/ml/prepare_data.py new file mode 100644 index 00000000..cb905d42 --- /dev/null +++ b/ml/prepare_data.py @@ -0,0 +1,537 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +数据准备脚本 - 为 Transformer Autoencoder 准备训练数据 + +从 ClickHouse 提取历史分钟数据,计算以下特征: +1. alpha - 超额收益(概念涨幅 - 大盘涨幅) +2. alpha_delta - Alpha 变化率(5分钟) +3. amt_ratio - 成交额相对均值(当前/过去20分钟均值) +4. amt_delta - 成交额变化率 +5. rank_pct - Alpha 排名百分位 +6. 
limit_up_ratio - 涨停股占比
+
+输出:按交易日存储的特征文件(parquet格式)
+"""
+
+import os
+import sys
+import numpy as np
+import pandas as pd
+from datetime import datetime, timedelta, date
+from sqlalchemy import create_engine, text
+from elasticsearch import Elasticsearch
+from clickhouse_driver import Client
+import hashlib
+import json
+import logging
+from typing import Dict, List, Set, Tuple
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from multiprocessing import Manager
+import multiprocessing
+import warnings
+warnings.filterwarnings('ignore')
+
+# ==================== 配置 ====================
+
+MYSQL_ENGINE = create_engine(
+    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
+    echo=False
+)
+
+ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
+ES_INDEX = 'concept_library_v3'
+
+CLICKHOUSE_CONFIG = {
+    'host': '127.0.0.1',
+    'port': 9000,
+    'user': 'default',
+    'password': 'Zzl33818!',
+    'database': 'stock'
+}
+
+# 输出目录
+OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data')
+os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+# 特征计算参数
+FEATURE_CONFIG = {
+    'alpha_delta_window': 5,       # Alpha变化窗口(分钟)
+    'amt_ma_window': 20,           # 成交额均值窗口(分钟)
+    'limit_up_threshold': 9.8,     # 涨停阈值(%)
+    'limit_down_threshold': -9.8,  # 跌停阈值(%)
+}
+
+REFERENCE_INDEX = '000001.SH'
+
+# ==================== 日志 ====================
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# ==================== 工具函数 ====================
+
+def get_ch_client():
+    return Client(**CLICKHOUSE_CONFIG)
+
+
+def generate_id(name: str) -> str:
+    return hashlib.md5(name.encode('utf-8')).hexdigest()[:16]
+
+
+def code_to_ch_format(code: str) -> str:
+    if not code or len(code) != 6 or not code.isdigit():
+        return None
+    if code.startswith('6'):
+        return f"{code}.SH"
+    elif code.startswith('0') or code.startswith('3'):
+        return f"{code}.SZ"
+    else:
+        return f"{code}.BJ"
+
+
+# ==================== 获取概念列表 ====================
+
+def get_all_concepts() -> List[dict]:
+    """从ES获取所有叶子概念"""
+    concepts = []
+
+    query = {
+        "query": {"match_all": {}},
+        "size": 100,
+        "_source": ["concept_id", "concept", "stocks"]
+    }
+
+    resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
+    scroll_id = resp['_scroll_id']
+    hits = resp['hits']['hits']
+
+    while len(hits) > 0:
+        for hit in hits:
+            source = hit['_source']
+            stocks = []
+            if 'stocks' in source and isinstance(source['stocks'], list):
+                for stock in source['stocks']:
+                    if isinstance(stock, dict) and 'code' in stock and stock['code']:
+                        stocks.append(stock['code'])
+
+            if stocks:
+                concepts.append({
+                    'concept_id': source.get('concept_id'),
+                    'concept_name': source.get('concept'),
+                    'stocks': stocks
+                })
+
+        resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
+        scroll_id = resp['_scroll_id']
+        hits = resp['hits']['hits']
+
+    ES_CLIENT.clear_scroll(scroll_id=scroll_id)
+    print(f"获取到 {len(concepts)} 个概念")
+    return concepts
+
+
+# ==================== 获取交易日列表 ====================
+
+def get_trading_days(start_date: str, end_date: str) -> List[str]:
+    """获取交易日列表"""
+    client = get_ch_client()
+
+    query = f"""
+    SELECT DISTINCT toDate(timestamp) as trade_date
+    FROM stock_minute
+    WHERE toDate(timestamp) >= '{start_date}'
+    AND toDate(timestamp) <= '{end_date}'
+    ORDER BY trade_date
+    """
+
+    result = client.execute(query)
+    days = [row[0].strftime('%Y-%m-%d') for row in result]
+    if days:  # 空结果时避免 days[0] 越界
+        print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
+    return days
+
+
+# ==================== 获取单日数据 
==================== + +def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame: + """获取单日所有股票的分钟数据""" + client = get_ch_client() + + # 转换代码格式 + ch_codes = [] + code_map = {} + for code in stock_codes: + ch_code = code_to_ch_format(code) + if ch_code: + ch_codes.append(ch_code) + code_map[ch_code] = code + + if not ch_codes: + return pd.DataFrame() + + ch_codes_str = "','".join(ch_codes) + + query = f""" + SELECT + code, + timestamp, + close, + volume, + amt + FROM stock_minute + WHERE toDate(timestamp) = '{trade_date}' + AND code IN ('{ch_codes_str}') + ORDER BY code, timestamp + """ + + result = client.execute(query) + + if not result: + return pd.DataFrame() + + df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt']) + df['code'] = df['ch_code'].map(code_map) + df = df.dropna(subset=['code']) + + return df[['code', 'timestamp', 'close', 'volume', 'amt']] + + +def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame: + """获取单日指数分钟数据""" + client = get_ch_client() + + query = f""" + SELECT + timestamp, + close, + volume, + amt + FROM index_minute + WHERE toDate(timestamp) = '{trade_date}' + AND code = '{index_code}' + ORDER BY timestamp + """ + + result = client.execute(query) + + if not result: + return pd.DataFrame() + + df = pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt']) + return df + + +def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]: + """获取昨收价(上一交易日的收盘价 F007N)""" + valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()] + if not valid_codes: + return {} + + codes_str = "','".join(valid_codes) + + # 注意:F007N 是"最近成交价"即当日收盘价,F002N 是"昨日收盘价" + # 我们需要查上一交易日的 F007N(那天的收盘价)作为今天的昨收 + query = f""" + SELECT SECCODE, F007N + FROM ea_trade + WHERE SECCODE IN ('{codes_str}') + AND TRADEDATE = ( + SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}' + ) + AND F007N IS NOT NULL AND F007N > 0 + """ + + try: + with MYSQL_ENGINE.connect() as conn: + result = conn.execute(text(query)) + return {row[0]: float(row[1]) for row in result if row[1]} + except Exception as e: + print(f"获取昨收价失败: {e}") + return {} + + +def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float: + """获取指数昨收价""" + code_no_suffix = index_code.split('.')[0] + + try: + with MYSQL_ENGINE.connect() as conn: + result = conn.execute(text(""" + SELECT F006N FROM ea_exchangetrade + WHERE INDEXCODE = :code AND TRADEDATE < :today + ORDER BY TRADEDATE DESC LIMIT 1 + """), {'code': code_no_suffix, 'today': trade_date}).fetchone() + + if result and result[0]: + return float(result[0]) + except Exception as e: + print(f"获取指数昨收失败: {e}") + + return None + + +# ==================== 计算特征 ==================== + +def compute_daily_features( + trade_date: str, + concepts: List[dict], + all_stocks: List[str] +) -> pd.DataFrame: + """ + 计算单日所有概念的特征 + + 返回 DataFrame: + - index: (timestamp, concept_id) + - columns: alpha, alpha_delta, amt_ratio, amt_delta, rank_pct, limit_up_ratio + """ + + # 1. 获取数据 + stock_df = get_daily_stock_data(trade_date, all_stocks) + if stock_df.empty: + return pd.DataFrame() + + index_df = get_daily_index_data(trade_date) + if index_df.empty: + return pd.DataFrame() + + # 2. 获取昨收价 + prev_close = get_prev_close(all_stocks, trade_date) + index_prev_close = get_index_prev_close(trade_date) + + if not prev_close or not index_prev_close: + return pd.DataFrame() + + # 3. 
计算股票涨跌幅和成交额 + stock_df['prev_close'] = stock_df['code'].map(prev_close) + stock_df = stock_df.dropna(subset=['prev_close']) + stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100 + + # 4. 计算指数涨跌幅 + index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100 + index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct'])) + + # 5. 获取所有时间点 + timestamps = sorted(stock_df['timestamp'].unique()) + + # 6. 按时间点计算概念特征 + results = [] + + # 概念到股票的映射 + concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts} + concept_names = {c['concept_id']: c['concept_name'] for c in concepts} + + # 历史数据缓存(用于计算变化率) + concept_history = {cid: {'alpha': [], 'amt': []} for cid in concept_stocks} + + for ts in timestamps: + ts_stock_data = stock_df[stock_df['timestamp'] == ts] + index_change = index_change_map.get(ts, 0) + + # 股票涨跌幅和成交额字典 + stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct'])) + stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt'])) + + concept_features = [] + + for concept_id, stocks in concept_stocks.items(): + # 该概念的股票数据 + concept_changes = [stock_change[s] for s in stocks if s in stock_change] + concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change] + + if not concept_changes: + continue + + # 基础统计 + avg_change = np.mean(concept_changes) + total_amt = sum(concept_amts) + + # Alpha = 概念涨幅 - 指数涨幅 + alpha = avg_change - index_change + + # 涨停/跌停股占比 + limit_up_count = sum(1 for c in concept_changes if c >= FEATURE_CONFIG['limit_up_threshold']) + limit_down_count = sum(1 for c in concept_changes if c <= FEATURE_CONFIG['limit_down_threshold']) + limit_up_ratio = limit_up_count / len(concept_changes) + limit_down_ratio = limit_down_count / len(concept_changes) + + # 更新历史 + history = concept_history[concept_id] + history['alpha'].append(alpha) + history['amt'].append(total_amt) + + # 计算变化率 + alpha_delta = 0 + if len(history['alpha']) > FEATURE_CONFIG['alpha_delta_window']: + alpha_delta = alpha - history['alpha'][-FEATURE_CONFIG['alpha_delta_window']-1] + + # 成交额相对均值 + amt_ratio = 1.0 + amt_delta = 0 + if len(history['amt']) > FEATURE_CONFIG['amt_ma_window']: + amt_ma = np.mean(history['amt'][-FEATURE_CONFIG['amt_ma_window']-1:-1]) + if amt_ma > 0: + amt_ratio = total_amt / amt_ma + amt_delta = total_amt - history['amt'][-2] if len(history['amt']) > 1 else 0 + + concept_features.append({ + 'concept_id': concept_id, + 'alpha': alpha, + 'alpha_delta': alpha_delta, + 'amt_ratio': amt_ratio, + 'amt_delta': amt_delta, + 'limit_up_ratio': limit_up_ratio, + 'limit_down_ratio': limit_down_ratio, + 'total_amt': total_amt, + 'stock_count': len(concept_changes), + }) + + if not concept_features: + continue + + # 计算排名百分位 + concept_df = pd.DataFrame(concept_features) + concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True) + + # 添加时间戳 + concept_df['timestamp'] = ts + results.append(concept_df) + + if not results: + return pd.DataFrame() + + # 合并所有时间点 + final_df = pd.concat(results, ignore_index=True) + + # 标准化成交额变化率 + if 'amt_delta' in final_df.columns: + amt_delta_std = final_df['amt_delta'].std() + if amt_delta_std > 0: + final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std + + return final_df + + +# ==================== 主流程 ==================== + +def process_single_day(args) -> Tuple[str, bool]: + """ + 处理单个交易日(多进程版本) + + Args: + args: (trade_date, concepts, all_stocks) 元组 + + Returns: + (trade_date, success) 元组 + """ + trade_date, concepts, 
all_stocks = args + output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet') + + # 检查是否已处理 + if os.path.exists(output_file): + print(f"[{trade_date}] 已存在,跳过") + return (trade_date, True) + + print(f"[{trade_date}] 开始处理...") + + try: + df = compute_daily_features(trade_date, concepts, all_stocks) + + if df.empty: + print(f"[{trade_date}] 无数据") + return (trade_date, False) + + # 保存 + df.to_parquet(output_file, index=False) + print(f"[{trade_date}] 保存完成") + return (trade_date, True) + + except Exception as e: + print(f"[{trade_date}] 处理失败: {e}") + import traceback + traceback.print_exc() + return (trade_date, False) + + +def main(): + import argparse + from tqdm import tqdm + + parser = argparse.ArgumentParser(description='准备训练数据') + parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期') + parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)') + parser.add_argument('--workers', type=int, default=18, help='并行进程数(默认18)') + parser.add_argument('--force', action='store_true', help='强制重新处理已存在的文件') + + args = parser.parse_args() + + end_date = args.end or datetime.now().strftime('%Y-%m-%d') + + print("=" * 60) + print("数据准备 - Transformer Autoencoder 训练数据") + print("=" * 60) + print(f"日期范围: {args.start} ~ {end_date}") + print(f"并行进程数: {args.workers}") + + # 1. 获取概念列表 + concepts = get_all_concepts() + + # 收集所有股票 + all_stocks = list(set(s for c in concepts for s in c['stocks'])) + print(f"股票总数: {len(all_stocks)}") + + # 2. 获取交易日列表 + trading_days = get_trading_days(args.start, end_date) + + if not trading_days: + print("无交易日数据") + return + + # 如果强制模式,删除已有文件 + if args.force: + for trade_date in trading_days: + output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet') + if os.path.exists(output_file): + os.remove(output_file) + print(f"删除已有文件: {output_file}") + + # 3. 准备任务参数 + tasks = [(trade_date, concepts, all_stocks) for trade_date in trading_days] + + print(f"\n开始处理 {len(trading_days)} 个交易日({args.workers} 进程并行)...") + + # 4. 多进程处理 + success_count = 0 + failed_dates = [] + + with ProcessPoolExecutor(max_workers=args.workers) as executor: + # 提交所有任务 + futures = {executor.submit(process_single_day, task): task[0] for task in tasks} + + # 使用 tqdm 显示进度 + with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar: + for future in as_completed(futures): + trade_date = futures[future] + try: + result_date, success = future.result() + if success: + success_count += 1 + else: + failed_dates.append(result_date) + except Exception as e: + print(f"\n[{trade_date}] 进程异常: {e}") + failed_dates.append(trade_date) + pbar.update(1) + + print("\n" + "=" * 60) + print(f"处理完成: {success_count}/{len(trading_days)} 个交易日") + if failed_dates: + print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}") + print(f"数据保存在: {OUTPUT_DIR}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/ml/prepare_data_v2.py b/ml/prepare_data_v2.py new file mode 100644 index 00000000..5a3ad02c --- /dev/null +++ b/ml/prepare_data_v2.py @@ -0,0 +1,715 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +数据准备 V2 - 基于时间片对齐的特征计算(修复版) + +核心改进: +1. 时间片对齐:9:35 和历史的 9:35 比,而不是和前30分钟比 +2. Z-Score 特征:相对于同时间片历史分布的偏离程度 +3. 滚动窗口基线:每个日期使用它之前 N 天的数据作为基线(不是固定的最后 N 天!) +4. 
基于 Z-Score 的动量:消除一天内波动率异构性 + +修复: +- 滚动窗口基线:避免未来数据泄露 +- Z-Score 动量:消除早盘/尾盘波动率差异 +- 进程级数据库单例:避免连接池爆炸 +""" + +import os +import sys +import numpy as np +import pandas as pd +from datetime import datetime, timedelta +from sqlalchemy import create_engine, text +from elasticsearch import Elasticsearch +from clickhouse_driver import Client +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import Dict, List, Tuple, Optional +from tqdm import tqdm +from collections import defaultdict +import warnings +import pickle + +warnings.filterwarnings('ignore') + +# ==================== 配置 ==================== + +MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock" +ES_HOST = 'http://127.0.0.1:9200' +ES_INDEX = 'concept_library_v3' + +CLICKHOUSE_CONFIG = { + 'host': '127.0.0.1', + 'port': 9000, + 'user': 'default', + 'password': 'Zzl33818!', + 'database': 'stock' +} + +REFERENCE_INDEX = '000001.SH' + +# 输出目录 +OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data_v2') +BASELINE_DIR = os.path.join(OUTPUT_DIR, 'baselines') +RAW_CACHE_DIR = os.path.join(OUTPUT_DIR, 'raw_cache') +os.makedirs(OUTPUT_DIR, exist_ok=True) +os.makedirs(BASELINE_DIR, exist_ok=True) +os.makedirs(RAW_CACHE_DIR, exist_ok=True) + +# 特征配置 +CONFIG = { + 'baseline_days': 20, # 滚动窗口大小 + 'min_baseline_samples': 10, # 最少需要10个样本才算有效基线 + 'limit_up_threshold': 9.8, + 'limit_down_threshold': -9.8, + 'zscore_clip': 5.0, +} + +# 特征列表 +FEATURES_V2 = [ + 'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore', + 'momentum_3m', 'momentum_5m', 'limit_up_ratio', +] + +# ==================== 进程级单例(避免连接池爆炸)==================== + +# 进程级全局变量 +_process_mysql_engine = None +_process_es_client = None +_process_ch_client = None + + +def init_process_connections(): + """进程初始化时调用,创建连接(单例)""" + global _process_mysql_engine, _process_es_client, _process_ch_client + _process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5) + _process_es_client = Elasticsearch([ES_HOST]) + _process_ch_client = Client(**CLICKHOUSE_CONFIG) + + +def get_mysql_engine(): + """获取进程级 MySQL Engine(单例)""" + global _process_mysql_engine + if _process_mysql_engine is None: + _process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5) + return _process_mysql_engine + + +def get_es_client(): + """获取进程级 ES 客户端(单例)""" + global _process_es_client + if _process_es_client is None: + _process_es_client = Elasticsearch([ES_HOST]) + return _process_es_client + + +def get_ch_client(): + """获取进程级 ClickHouse 客户端(单例)""" + global _process_ch_client + if _process_ch_client is None: + _process_ch_client = Client(**CLICKHOUSE_CONFIG) + return _process_ch_client + + +# ==================== 工具函数 ==================== + +def code_to_ch_format(code: str) -> str: + if not code or len(code) != 6 or not code.isdigit(): + return None + if code.startswith('6'): + return f"{code}.SH" + elif code.startswith('0') or code.startswith('3'): + return f"{code}.SZ" + else: + return f"{code}.BJ" + + +def time_to_slot(ts) -> str: + """将时间戳转换为时间片(HH:MM格式)""" + if isinstance(ts, str): + return ts + return ts.strftime('%H:%M') + + +# ==================== 获取概念列表 ==================== + +def get_all_concepts() -> List[dict]: + """从ES获取所有叶子概念""" + es_client = get_es_client() + concepts = [] + + query = { + "query": {"match_all": {}}, + "size": 100, + "_source": ["concept_id", "concept", "stocks"] + } + + resp = es_client.search(index=ES_INDEX, body=query, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = 
resp['hits']['hits'] + + while len(hits) > 0: + for hit in hits: + source = hit['_source'] + stocks = [] + if 'stocks' in source and isinstance(source['stocks'], list): + for stock in source['stocks']: + if isinstance(stock, dict) and 'code' in stock and stock['code']: + stocks.append(stock['code']) + + if stocks: + concepts.append({ + 'concept_id': source.get('concept_id'), + 'concept_name': source.get('concept'), + 'stocks': stocks + }) + + resp = es_client.scroll(scroll_id=scroll_id, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = resp['hits']['hits'] + + es_client.clear_scroll(scroll_id=scroll_id) + print(f"获取到 {len(concepts)} 个概念") + return concepts + + +# ==================== 获取交易日列表 ==================== + +def get_trading_days(start_date: str, end_date: str) -> List[str]: + """获取交易日列表""" + client = get_ch_client() + + query = f""" + SELECT DISTINCT toDate(timestamp) as trade_date + FROM stock_minute + WHERE toDate(timestamp) >= '{start_date}' + AND toDate(timestamp) <= '{end_date}' + ORDER BY trade_date + """ + + result = client.execute(query) + days = [row[0].strftime('%Y-%m-%d') for row in result] + if days: + print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}") + return days + + +# ==================== 获取昨收价 ==================== + +def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]: + """获取昨收价(上一交易日的收盘价 F007N)""" + valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()] + if not valid_codes: + return {} + + codes_str = "','".join(valid_codes) + query = f""" + SELECT SECCODE, F007N + FROM ea_trade + WHERE SECCODE IN ('{codes_str}') + AND TRADEDATE = ( + SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}' + ) + AND F007N IS NOT NULL AND F007N > 0 + """ + + try: + engine = get_mysql_engine() + with engine.connect() as conn: + result = conn.execute(text(query)) + return {row[0]: float(row[1]) for row in result if row[1]} + except Exception as e: + print(f"获取昨收价失败: {e}") + return {} + + +def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float: + """获取指数昨收价""" + code_no_suffix = index_code.split('.')[0] + + try: + engine = get_mysql_engine() + with engine.connect() as conn: + result = conn.execute(text(""" + SELECT F006N FROM ea_exchangetrade + WHERE INDEXCODE = :code AND TRADEDATE < :today + ORDER BY TRADEDATE DESC LIMIT 1 + """), {'code': code_no_suffix, 'today': trade_date}).fetchone() + + if result and result[0]: + return float(result[0]) + except Exception as e: + print(f"获取指数昨收失败: {e}") + + return None + + +# ==================== 获取分钟数据 ==================== + +def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame: + """获取单日所有股票的分钟数据""" + client = get_ch_client() + + ch_codes = [] + code_map = {} + for code in stock_codes: + ch_code = code_to_ch_format(code) + if ch_code: + ch_codes.append(ch_code) + code_map[ch_code] = code + + if not ch_codes: + return pd.DataFrame() + + ch_codes_str = "','".join(ch_codes) + + query = f""" + SELECT code, timestamp, close, volume, amt + FROM stock_minute + WHERE toDate(timestamp) = '{trade_date}' + AND code IN ('{ch_codes_str}') + ORDER BY code, timestamp + """ + + result = client.execute(query) + if not result: + return pd.DataFrame() + + df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt']) + df['code'] = df['ch_code'].map(code_map) + df = df.dropna(subset=['code']) + + return df[['code', 'timestamp', 'close', 'volume', 'amt']] + + +def get_daily_index_data(trade_date: str, 
index_code: str = REFERENCE_INDEX) -> pd.DataFrame: + """获取单日指数分钟数据""" + client = get_ch_client() + + query = f""" + SELECT timestamp, close, volume, amt + FROM index_minute + WHERE toDate(timestamp) = '{trade_date}' + AND code = '{index_code}' + ORDER BY timestamp + """ + + result = client.execute(query) + if not result: + return pd.DataFrame() + + df = pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt']) + return df + + +# ==================== 计算原始概念特征(单日)==================== + +def compute_raw_concept_features( + trade_date: str, + concepts: List[dict], + all_stocks: List[str] +) -> pd.DataFrame: + """计算单日概念的原始特征(alpha, amt, rank_pct, limit_up_ratio)""" + # 检查缓存 + cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet') + if os.path.exists(cache_file): + return pd.read_parquet(cache_file) + + # 获取数据 + stock_df = get_daily_stock_data(trade_date, all_stocks) + if stock_df.empty: + return pd.DataFrame() + + index_df = get_daily_index_data(trade_date) + if index_df.empty: + return pd.DataFrame() + + # 获取昨收价 + prev_close = get_prev_close(all_stocks, trade_date) + index_prev_close = get_index_prev_close(trade_date) + + if not prev_close or not index_prev_close: + return pd.DataFrame() + + # 计算涨跌幅 + stock_df['prev_close'] = stock_df['code'].map(prev_close) + stock_df = stock_df.dropna(subset=['prev_close']) + stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100 + + index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100 + index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct'])) + + # 获取所有时间点 + timestamps = sorted(stock_df['timestamp'].unique()) + + # 概念到股票的映射 + concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts} + + results = [] + + for ts in timestamps: + ts_stock_data = stock_df[stock_df['timestamp'] == ts] + index_change = index_change_map.get(ts, 0) + + stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct'])) + stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt'])) + + concept_features = [] + + for concept_id, stocks in concept_stocks.items(): + concept_changes = [stock_change[s] for s in stocks if s in stock_change] + concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change] + + if not concept_changes: + continue + + avg_change = np.mean(concept_changes) + total_amt = sum(concept_amts) + alpha = avg_change - index_change + + limit_up_count = sum(1 for c in concept_changes if c >= CONFIG['limit_up_threshold']) + limit_up_ratio = limit_up_count / len(concept_changes) + + concept_features.append({ + 'concept_id': concept_id, + 'alpha': alpha, + 'total_amt': total_amt, + 'limit_up_ratio': limit_up_ratio, + 'stock_count': len(concept_changes), + }) + + if not concept_features: + continue + + concept_df = pd.DataFrame(concept_features) + concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True) + concept_df['timestamp'] = ts + concept_df['time_slot'] = time_to_slot(ts) + concept_df['trade_date'] = trade_date + + results.append(concept_df) + + if not results: + return pd.DataFrame() + + result_df = pd.concat(results, ignore_index=True) + + # 保存缓存 + result_df.to_parquet(cache_file, index=False) + + return result_df + + +# ==================== 滚动窗口基线计算 ==================== + +def compute_rolling_baseline( + historical_data: pd.DataFrame, + concept_id: str +) -> Dict[str, Dict]: + """ + 计算单个概念的滚动基线 + + 返回: {time_slot: {alpha_mean, alpha_std, amt_mean, amt_std, rank_mean, rank_std, sample_count}} + """ + 
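+    # Worked example of one baseline entry (illustrative numbers): if concept X
+    # at slot '09:35' showed alpha mean 0.1 / std 0.4 over the lookback window,
+    # then a 09:35 alpha of 1.3 on the target day maps to
+    # z = (1.3 - 0.1) / 0.4 = 3.0, i.e. three "typical 09:35 moves". The
+    # max(..., floor) clamps below keep a near-flat history (std ~ 0) from
+    # inflating z-scores through division.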
if historical_data.empty: + return {} + + concept_data = historical_data[historical_data['concept_id'] == concept_id] + if concept_data.empty: + return {} + + baseline_dict = {} + + for time_slot, group in concept_data.groupby('time_slot'): + if len(group) < CONFIG['min_baseline_samples']: + continue + + alpha_std = group['alpha'].std() + amt_std = group['total_amt'].std() + rank_std = group['rank_pct'].std() + + baseline_dict[time_slot] = { + 'alpha_mean': group['alpha'].mean(), + 'alpha_std': max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1), + 'amt_mean': group['total_amt'].mean(), + 'amt_std': max(amt_std if pd.notna(amt_std) else group['total_amt'].mean() * 0.5, 1.0), + 'rank_mean': group['rank_pct'].mean(), + 'rank_std': max(rank_std if pd.notna(rank_std) else 0.2, 0.05), + 'sample_count': len(group), + } + + return baseline_dict + + +# ==================== 计算单日 Z-Score 特征(带滚动基线)==================== + +def compute_zscore_features_rolling( + trade_date: str, + concepts: List[dict], + all_stocks: List[str], + historical_raw_data: pd.DataFrame # 该日期之前 N 天的原始数据 +) -> pd.DataFrame: + """ + 计算单日的 Z-Score 特征(使用滚动窗口基线) + + 关键改进: + 1. 基线只使用 trade_date 之前的数据(无未来泄露) + 2. 动量基于 Z-Score 计算(消除波动率异构性) + """ + # 计算当日原始特征 + raw_df = compute_raw_concept_features(trade_date, concepts, all_stocks) + + if raw_df.empty: + return pd.DataFrame() + + zscore_records = [] + + for concept_id, group in raw_df.groupby('concept_id'): + # 计算该概念的滚动基线(只用历史数据) + baseline_dict = compute_rolling_baseline(historical_raw_data, concept_id) + + if not baseline_dict: + continue + + # 按时间排序 + group = group.sort_values('timestamp').reset_index(drop=True) + + # Z-Score 历史(用于计算基于 Z-Score 的动量) + zscore_history = [] + + for idx, row in group.iterrows(): + time_slot = row['time_slot'] + + if time_slot not in baseline_dict: + continue + + bl = baseline_dict[time_slot] + + # 计算 Z-Score + alpha_zscore = (row['alpha'] - bl['alpha_mean']) / bl['alpha_std'] + amt_zscore = (row['total_amt'] - bl['amt_mean']) / bl['amt_std'] + rank_zscore = (row['rank_pct'] - bl['rank_mean']) / bl['rank_std'] + + # 截断极端值 + clip = CONFIG['zscore_clip'] + alpha_zscore = np.clip(alpha_zscore, -clip, clip) + amt_zscore = np.clip(amt_zscore, -clip, clip) + rank_zscore = np.clip(rank_zscore, -clip, clip) + + # 记录 Z-Score 历史 + zscore_history.append(alpha_zscore) + + # 基于 Z-Score 计算动量(消除波动率异构性) + momentum_3m = 0.0 + momentum_5m = 0.0 + + if len(zscore_history) >= 3: + recent_3 = zscore_history[-3:] + older_3 = zscore_history[-6:-3] if len(zscore_history) >= 6 else [zscore_history[0]] + momentum_3m = np.mean(recent_3) - np.mean(older_3) + + if len(zscore_history) >= 5: + recent_5 = zscore_history[-5:] + older_5 = zscore_history[-10:-5] if len(zscore_history) >= 10 else [zscore_history[0]] + momentum_5m = np.mean(recent_5) - np.mean(older_5) + + zscore_records.append({ + 'concept_id': concept_id, + 'timestamp': row['timestamp'], + 'time_slot': time_slot, + 'trade_date': trade_date, + # 原始特征 + 'alpha': row['alpha'], + 'total_amt': row['total_amt'], + 'limit_up_ratio': row['limit_up_ratio'], + 'stock_count': row['stock_count'], + 'rank_pct': row['rank_pct'], + # Z-Score 特征 + 'alpha_zscore': alpha_zscore, + 'amt_zscore': amt_zscore, + 'rank_zscore': rank_zscore, + # 基于 Z-Score 的动量 + 'momentum_3m': momentum_3m, + 'momentum_5m': momentum_5m, + }) + + if not zscore_records: + return pd.DataFrame() + + return pd.DataFrame(zscore_records) + + +# ==================== 多进程处理 ==================== + +def process_single_day_v2(args) -> Tuple[str, bool]: + """处理单个交易日(多进程版本)""" + 
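+    # Index arithmetic sketch (assuming baseline_days=20): for the day at
+    # position i in all_trading_days, the baseline window is days
+    # [max(0, i-20), i) -- strictly earlier days, never day i itself --
+    # which is what keeps future data out of the z-score baselines.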
trade_date, day_index, concepts, all_stocks, all_trading_days = args + output_file = os.path.join(OUTPUT_DIR, f'features_v2_{trade_date}.parquet') + + if os.path.exists(output_file): + return (trade_date, True) + + try: + # 计算滚动窗口范围(该日期之前的 N 天) + baseline_days = CONFIG['baseline_days'] + + # 找出 trade_date 之前的交易日 + start_idx = max(0, day_index - baseline_days) + end_idx = day_index # 不包含当天 + + if end_idx <= start_idx: + # 没有足够的历史数据 + return (trade_date, False) + + historical_days = all_trading_days[start_idx:end_idx] + + # 加载历史原始数据 + historical_dfs = [] + for hist_date in historical_days: + cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{hist_date}.parquet') + if os.path.exists(cache_file): + historical_dfs.append(pd.read_parquet(cache_file)) + else: + # 需要计算 + hist_df = compute_raw_concept_features(hist_date, concepts, all_stocks) + if not hist_df.empty: + historical_dfs.append(hist_df) + + if not historical_dfs: + return (trade_date, False) + + historical_raw_data = pd.concat(historical_dfs, ignore_index=True) + + # 计算当日 Z-Score 特征(使用滚动基线) + df = compute_zscore_features_rolling(trade_date, concepts, all_stocks, historical_raw_data) + + if df.empty: + return (trade_date, False) + + df.to_parquet(output_file, index=False) + return (trade_date, True) + + except Exception as e: + print(f"[{trade_date}] 处理失败: {e}") + import traceback + traceback.print_exc() + return (trade_date, False) + + +# ==================== 主流程 ==================== + +def main(): + import argparse + + parser = argparse.ArgumentParser(description='准备训练数据 V2(滚动窗口基线 + Z-Score 动量)') + parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期') + parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)') + parser.add_argument('--workers', type=int, default=18, help='并行进程数') + parser.add_argument('--baseline-days', type=int, default=20, help='滚动基线窗口大小') + parser.add_argument('--force', action='store_true', help='强制重新计算(忽略缓存)') + + args = parser.parse_args() + + end_date = args.end or datetime.now().strftime('%Y-%m-%d') + CONFIG['baseline_days'] = args.baseline_days + + print("=" * 60) + print("数据准备 V2 - 滚动窗口基线 + Z-Score 动量") + print("=" * 60) + print(f"日期范围: {args.start} ~ {end_date}") + print(f"并行进程数: {args.workers}") + print(f"滚动基线窗口: {args.baseline_days} 天") + + # 初始化主进程连接 + init_process_connections() + + # 1. 获取概念列表 + concepts = get_all_concepts() + all_stocks = list(set(s for c in concepts for s in c['stocks'])) + print(f"股票总数: {len(all_stocks)}") + + # 2. 获取交易日列表 + trading_days = get_trading_days(args.start, end_date) + + if not trading_days: + print("无交易日数据") + return + + # 3. 第一阶段:预计算所有原始特征(用于缓存) + print(f"\n{'='*60}") + print("第一阶段:预计算原始特征(用于滚动基线)") + print(f"{'='*60}") + + # 如果强制重新计算,删除缓存 + if args.force: + import shutil + if os.path.exists(RAW_CACHE_DIR): + shutil.rmtree(RAW_CACHE_DIR) + os.makedirs(RAW_CACHE_DIR, exist_ok=True) + if os.path.exists(OUTPUT_DIR): + for f in os.listdir(OUTPUT_DIR): + if f.startswith('features_v2_'): + os.remove(os.path.join(OUTPUT_DIR, f)) + + # 单线程预计算原始特征(因为需要顺序缓存) + print(f"预计算 {len(trading_days)} 天的原始特征...") + for trade_date in tqdm(trading_days, desc="预计算原始特征"): + cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet') + if not os.path.exists(cache_file): + compute_raw_concept_features(trade_date, concepts, all_stocks) + + # 4. 
第二阶段:计算 Z-Score 特征(多进程) + print(f"\n{'='*60}") + print("第二阶段:计算 Z-Score 特征(滚动基线)") + print(f"{'='*60}") + + # 从第 baseline_days 天开始(前面的没有足够历史) + start_idx = args.baseline_days + processable_days = trading_days[start_idx:] + + if not processable_days: + print(f"错误:需要至少 {args.baseline_days + 1} 天的数据") + return + + print(f"可处理日期: {processable_days[0]} ~ {processable_days[-1]} ({len(processable_days)} 天)") + print(f"跳过前 {start_idx} 天(基线预热期)") + + # 构建任务 + tasks = [] + for i, trade_date in enumerate(trading_days): + if i >= start_idx: + tasks.append((trade_date, i, concepts, all_stocks, trading_days)) + + print(f"开始处理 {len(tasks)} 个交易日({args.workers} 进程并行)...") + + success_count = 0 + failed_dates = [] + + # 使用进程池初始化器 + with ProcessPoolExecutor(max_workers=args.workers, initializer=init_process_connections) as executor: + futures = {executor.submit(process_single_day_v2, task): task[0] for task in tasks} + + with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar: + for future in as_completed(futures): + trade_date = futures[future] + try: + result_date, success = future.result() + if success: + success_count += 1 + else: + failed_dates.append(result_date) + except Exception as e: + print(f"\n[{trade_date}] 进程异常: {e}") + failed_dates.append(trade_date) + pbar.update(1) + + print("\n" + "=" * 60) + print(f"处理完成: {success_count}/{len(tasks)} 个交易日") + if failed_dates: + print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}") + print(f"数据保存在: {OUTPUT_DIR}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/ml/realtime_detector.py b/ml/realtime_detector.py new file mode 100644 index 00000000..3af09828 --- /dev/null +++ b/ml/realtime_detector.py @@ -0,0 +1,1520 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +实时概念异动检测服务(实盘可用) + +盘中每分钟运行一次,检测概念异动并写入数据库 + +数据流程: +1. 启动时从 ES 获取概念列表,从 MySQL 获取昨收价 +2. 自动预热:从 ClickHouse 加载当天已有的历史分钟数据 +3. 每分钟增量获取最新分钟数据 +4. 在内存中实时计算概念特征(无前瞻偏差) +5. 使用规则+ML融合评分检测异动 +6. 
异动写入 MySQL + +特征计算说明(无 Looking Forward): +- alpha: 当前时间点的概念超额收益 +- alpha_delta: 使用过去 5 分钟的 alpha 变化 +- amt_ratio: 使用过去 20 分钟的成交额均值 +- rank_pct: 当前时间点所有概念的 alpha 排名 +- limit_up_ratio: 当前时间点的涨停股占比 + +使用方法: + # 实盘模式(推荐)- 自动预热,不依赖 prepare_data.py + python realtime_detector.py + + # 单次检测 + python realtime_detector.py --once + + # 回补历史异动到数据库(需要 prepare_data.py 生成 parquet) + python realtime_detector.py --backfill-only + + # 实盘模式 + 启动时回补历史 + python realtime_detector.py --backfill + +最小数据量要求: +- ML 评分需要 seq_len=15 分钟的序列 +- amt_ratio 需要 amt_ma_window=20 分钟的历史 +- 即:开盘后约 35 分钟才能正常工作 +""" + +import os +import sys +import time +import json +import argparse +import schedule +from datetime import datetime, timedelta +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Set +from collections import defaultdict + +import numpy as np +import pandas as pd +import torch +from sqlalchemy import create_engine, text +from elasticsearch import Elasticsearch +from clickhouse_driver import Client as CHClient + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +# ==================== 配置 ==================== + +MYSQL_ENGINE = create_engine( + "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock", + echo=False, + pool_pre_ping=True, + pool_recycle=3600, +) + +ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200']) +ES_INDEX = 'concept_library_v3' + +CLICKHOUSE_CONFIG = { + 'host': '127.0.0.1', + 'port': 9000, + 'user': 'default', + 'password': 'Zzl33818!', + 'database': 'stock' +} + +REFERENCE_INDEX = '000001.SH' + +FEATURES = ['alpha', 'alpha_delta', 'amt_ratio', 'amt_delta', 'rank_pct', 'limit_up_ratio'] + +# 特征计算参数 +FEATURE_CONFIG = { + 'alpha_delta_window': 5, + 'amt_ma_window': 20, + 'limit_up_threshold': 9.8, + 'limit_down_threshold': -9.8, +} + +# 检测配置(与 backtest_fast.py 保持一致) +CONFIG = { + 'seq_len': 15, + 'min_alpha_abs': 0.3, + 'cooldown_minutes': 8, + 'max_alerts_per_minute': 20, + 'clip_value': 10.0, + # === 融合权重:与 backtest_fast.py 一致 === + 'rule_weight': 0.5, + 'ml_weight': 0.5, + # === 触发阈值:与 backtest_fast.py 一致 === + 'rule_trigger': 65, + 'ml_trigger': 70, + 'fusion_trigger': 45, +} + +TRADING_PERIODS = [ + ('09:30', '11:30'), + ('13:00', '15:00'), +] + + +# ==================== 工具函数 ==================== + +def get_ch_client(): + return CHClient(**CLICKHOUSE_CONFIG) + + +def code_to_ch_format(code: str) -> str: + if not code or len(code) != 6 or not code.isdigit(): + return None + if code.startswith('6'): + return f"{code}.SH" + elif code.startswith('0') or code.startswith('3'): + return f"{code}.SZ" + else: + return f"{code}.BJ" + + +def is_trading_time() -> bool: + now = datetime.now() + if now.weekday() >= 5: + return False + current_time = now.strftime('%H:%M') + for start, end in TRADING_PERIODS: + if start <= current_time <= end: + return True + return False + + +def get_current_trade_date() -> str: + now = datetime.now() + if now.hour < 9: + now = now - timedelta(days=1) + return now.strftime('%Y-%m-%d') + + +# ==================== 数据获取 ==================== + +def get_all_concepts() -> List[dict]: + """从 ES 获取所有概念""" + concepts = [] + query = { + "query": {"match_all": {}}, + "size": 100, + "_source": ["concept_id", "concept", "stocks"] + } + + resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = resp['hits']['hits'] + + while len(hits) > 0: + for hit in hits: + source = hit['_source'] + stocks = [] + if 'stocks' in source and isinstance(source['stocks'], list): + for stock in 
source['stocks']: + if isinstance(stock, dict) and 'code' in stock and stock['code']: + stocks.append(stock['code']) + + if stocks: + concepts.append({ + 'concept_id': source.get('concept_id'), + 'concept_name': source.get('concept'), + 'stocks': stocks + }) + + resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = resp['hits']['hits'] + + ES_CLIENT.clear_scroll(scroll_id=scroll_id) + print(f"获取到 {len(concepts)} 个概念") + return concepts + + +def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]: + """获取昨收价(上一交易日的收盘价 F007N)""" + valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()] + if not valid_codes: + return {} + + codes_str = "','".join(valid_codes) + # 注意:F007N 是"最近成交价"即当日收盘价,F002N 是"昨日收盘价" + # 我们需要查上一交易日的 F007N(那天的收盘价)作为今天的昨收 + query = f""" + SELECT SECCODE, F007N + FROM ea_trade + WHERE SECCODE IN ('{codes_str}') + AND TRADEDATE = ( + SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}' + ) + AND F007N IS NOT NULL AND F007N > 0 + """ + + try: + with MYSQL_ENGINE.connect() as conn: + result = conn.execute(text(query)) + return {row[0]: float(row[1]) for row in result if row[1]} + except Exception as e: + print(f"获取昨收价失败: {e}") + return {} + + +def get_index_prev_close(trade_date: str) -> float: + """获取指数昨收价""" + code_no_suffix = REFERENCE_INDEX.split('.')[0] + try: + with MYSQL_ENGINE.connect() as conn: + result = conn.execute(text(""" + SELECT F006N FROM ea_exchangetrade + WHERE INDEXCODE = :code AND TRADEDATE < :today + ORDER BY TRADEDATE DESC LIMIT 1 + """), {'code': code_no_suffix, 'today': trade_date}).fetchone() + if result and result[0]: + return float(result[0]) + except Exception as e: + print(f"获取指数昨收失败: {e}") + return None + + +def get_stock_minute_data(trade_date: str, stock_codes: List[str], since_time: datetime = None) -> pd.DataFrame: + """ + 从 ClickHouse 获取股票分钟数据 + + Args: + trade_date: 交易日期 + stock_codes: 股票代码列表 + since_time: 只获取该时间之后的数据(增量获取) + """ + client = get_ch_client() + + ch_codes = [] + code_map = {} + for code in stock_codes: + ch_code = code_to_ch_format(code) + if ch_code: + ch_codes.append(ch_code) + code_map[ch_code] = code + + if not ch_codes: + return pd.DataFrame() + + ch_codes_str = "','".join(ch_codes) + + time_filter = "" + if since_time: + time_filter = f"AND timestamp > '{since_time.strftime('%Y-%m-%d %H:%M:%S')}'" + + query = f""" + SELECT code, timestamp, close, volume, amt + FROM stock_minute + WHERE toDate(timestamp) = '{trade_date}' + AND code IN ('{ch_codes_str}') + {time_filter} + ORDER BY code, timestamp + """ + + result = client.execute(query) + if not result: + return pd.DataFrame() + + df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt']) + df['code'] = df['ch_code'].map(code_map) + df = df.dropna(subset=['code']) + return df[['code', 'timestamp', 'close', 'volume', 'amt']] + + +def get_index_minute_data(trade_date: str, since_time: datetime = None) -> pd.DataFrame: + """从 ClickHouse 获取指数分钟数据""" + client = get_ch_client() + + time_filter = "" + if since_time: + time_filter = f"AND timestamp > '{since_time.strftime('%Y-%m-%d %H:%M:%S')}'" + + query = f""" + SELECT timestamp, close, volume, amt + FROM index_minute + WHERE toDate(timestamp) = '{trade_date}' + AND code = '{REFERENCE_INDEX}' + {time_filter} + ORDER BY timestamp + """ + + result = client.execute(query) + if not result: + return pd.DataFrame() + + return pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt']) + + +# 
==================== 规则评分 ==================== + +def get_size_adjusted_thresholds(stock_count: np.ndarray) -> np.ndarray: + """根据概念股票数量计算动态阈值""" + n = len(stock_count) + size_factor = np.ones(n) + + size_factor[stock_count < 5] = 1.8 + size_factor[(stock_count >= 5) & (stock_count < 10)] = 1.4 + size_factor[(stock_count >= 10) & (stock_count < 20)] = 1.2 + size_factor[(stock_count >= 20) & (stock_count < 50)] = 1.0 + size_factor[(stock_count >= 50) & (stock_count < 100)] = 0.85 + size_factor[stock_count >= 100] = 0.7 + + return size_factor + + +def score_rules_batch(df: pd.DataFrame) -> Tuple[np.ndarray, List[List[str]]]: + """批量计算规则得分""" + n = len(df) + scores = np.zeros(n) + triggered = [[] for _ in range(n)] + + alpha = df['alpha'].values + alpha_delta = df['alpha_delta'].values + amt_ratio = df['amt_ratio'].values + rank_pct = df['rank_pct'].values + limit_up_ratio = df['limit_up_ratio'].values + stock_count = df['stock_count'].values if 'stock_count' in df.columns else np.full(n, 20) + + alpha_abs = np.abs(alpha) + alpha_delta_abs = np.abs(alpha_delta) + size_factor = get_size_adjusted_thresholds(stock_count) + + # Alpha 规则 + alpha_extreme_thresh = 5.0 * size_factor + mask = alpha_abs >= alpha_extreme_thresh + scores[mask] += 20 + for i in np.where(mask)[0]: triggered[i].append('alpha_extreme') + + alpha_strong_thresh = 4.0 * size_factor + mask = (alpha_abs >= alpha_strong_thresh) & (alpha_abs < alpha_extreme_thresh) + scores[mask] += 15 + for i in np.where(mask)[0]: triggered[i].append('alpha_strong') + + alpha_medium_thresh = 3.0 * size_factor + mask = (alpha_abs >= alpha_medium_thresh) & (alpha_abs < alpha_strong_thresh) + scores[mask] += 10 + for i in np.where(mask)[0]: triggered[i].append('alpha_medium') + + # Alpha 加速度 + delta_strong_thresh = 2.0 * size_factor + mask = alpha_delta_abs >= delta_strong_thresh + scores[mask] += 15 + for i in np.where(mask)[0]: triggered[i].append('alpha_delta_strong') + + delta_medium_thresh = 1.5 * size_factor + mask = (alpha_delta_abs >= delta_medium_thresh) & (alpha_delta_abs < delta_strong_thresh) + scores[mask] += 10 + for i in np.where(mask)[0]: triggered[i].append('alpha_delta_medium') + + # 成交额 + mask = amt_ratio >= 10.0 + scores[mask] += 20 + for i in np.where(mask)[0]: triggered[i].append('volume_extreme') + + mask = (amt_ratio >= 6.0) & (amt_ratio < 10.0) + scores[mask] += 12 + for i in np.where(mask)[0]: triggered[i].append('volume_strong') + + # 排名 + mask = rank_pct >= 0.98 + scores[mask] += 15 + for i in np.where(mask)[0]: triggered[i].append('rank_top') + + mask = rank_pct <= 0.02 + scores[mask] += 15 + for i in np.where(mask)[0]: triggered[i].append('rank_bottom') + + # 涨停 + limit_high_thresh = 0.30 * size_factor + mask = limit_up_ratio >= limit_high_thresh + scores[mask] += 20 + for i in np.where(mask)[0]: triggered[i].append('limit_up_high') + + limit_medium_thresh = 0.20 * size_factor + mask = (limit_up_ratio >= limit_medium_thresh) & (limit_up_ratio < limit_high_thresh) + scores[mask] += 12 + for i in np.where(mask)[0]: triggered[i].append('limit_up_medium') + + # 概念规模加分 + large_concept = stock_count >= 50 + has_signal = scores > 0 + mask = large_concept & has_signal + scores[mask] += 10 + for i in np.where(mask)[0]: triggered[i].append('large_concept_bonus') + + xlarge_concept = stock_count >= 100 + mask = xlarge_concept & has_signal + scores[mask] += 10 + for i in np.where(mask)[0]: triggered[i].append('xlarge_concept_bonus') + + # 组合规则 + combo_alpha_thresh = 3.0 * size_factor + mask = (alpha_abs >= combo_alpha_thresh) & 
(amt_ratio >= 5.0) & ((rank_pct >= 0.95) | (rank_pct <= 0.05)) + scores[mask] += 20 + for i in np.where(mask)[0]: triggered[i].append('triple_signal') + + mask = (alpha_abs >= combo_alpha_thresh) & (limit_up_ratio >= 0.15 * size_factor) + scores[mask] += 15 + for i in np.where(mask)[0]: triggered[i].append('alpha_with_limit') + + # 小概念惩罚 + tiny_concept = stock_count < 5 + single_rule = np.array([len(t) <= 1 for t in triggered]) + mask = tiny_concept & single_rule & (scores > 0) + scores[mask] *= 0.5 + for i in np.where(mask)[0]: triggered[i].append('tiny_concept_penalty') + + scores = np.clip(scores, 0, 100) + return scores, triggered + + +def rule_score_with_details(features: Dict, stock_count: int = 50) -> Tuple[float, Dict[str, float]]: + """ + 单条记录的规则评分(带详情) + + Args: + features: 特征字典,包含 alpha, alpha_delta, amt_ratio, rank_pct, limit_up_ratio + stock_count: 概念股票数量 + + Returns: + (score, details): 总分和各规则触发详情 + """ + score = 0.0 + details = {} + + alpha = features.get('alpha', 0) + alpha_delta = features.get('alpha_delta', 0) + amt_ratio = features.get('amt_ratio', 1) + rank_pct = features.get('rank_pct', 0.5) + limit_up_ratio = features.get('limit_up_ratio', 0) + + alpha_abs = abs(alpha) + alpha_delta_abs = abs(alpha_delta) + size_factor = get_size_adjusted_thresholds(np.array([stock_count]))[0] + + # Alpha 规则 + alpha_extreme_thresh = 5.0 * size_factor + alpha_strong_thresh = 4.0 * size_factor + alpha_medium_thresh = 3.0 * size_factor + + if alpha_abs >= alpha_extreme_thresh: + score += 20 + details['alpha_extreme'] = 20 + elif alpha_abs >= alpha_strong_thresh: + score += 15 + details['alpha_strong'] = 15 + elif alpha_abs >= alpha_medium_thresh: + score += 10 + details['alpha_medium'] = 10 + + # Alpha 加速度 + delta_strong_thresh = 2.0 * size_factor + delta_medium_thresh = 1.5 * size_factor + + if alpha_delta_abs >= delta_strong_thresh: + score += 15 + details['alpha_delta_strong'] = 15 + elif alpha_delta_abs >= delta_medium_thresh: + score += 10 + details['alpha_delta_medium'] = 10 + + # 成交额 + if amt_ratio >= 10.0: + score += 20 + details['volume_extreme'] = 20 + elif amt_ratio >= 6.0: + score += 12 + details['volume_strong'] = 12 + + # 排名 + if rank_pct >= 0.98: + score += 15 + details['rank_top'] = 15 + elif rank_pct <= 0.02: + score += 15 + details['rank_bottom'] = 15 + + # 涨停 + limit_high_thresh = 0.30 * size_factor + limit_medium_thresh = 0.20 * size_factor + + if limit_up_ratio >= limit_high_thresh: + score += 20 + details['limit_up_high'] = 20 + elif limit_up_ratio >= limit_medium_thresh: + score += 12 + details['limit_up_medium'] = 12 + + # 概念规模加分 + if score > 0: + if stock_count >= 50: + score += 10 + details['large_concept_bonus'] = 10 + if stock_count >= 100: + score += 10 + details['xlarge_concept_bonus'] = 10 + + # 组合规则 + combo_alpha_thresh = 3.0 * size_factor + + if alpha_abs >= combo_alpha_thresh and amt_ratio >= 5.0 and (rank_pct >= 0.95 or rank_pct <= 0.05): + score += 20 + details['triple_signal'] = 20 + + if alpha_abs >= combo_alpha_thresh and limit_up_ratio >= 0.15 * size_factor: + score += 15 + details['alpha_with_limit'] = 15 + + # 小概念惩罚 + if stock_count < 5 and len(details) <= 1 and score > 0: + penalty = score * 0.5 + score *= 0.5 + details['tiny_concept_penalty'] = -penalty + + score = min(max(score, 0), 100) + return score, details + + +# ==================== ML 评分器 ==================== + +class MLScorer: + def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'): + self.checkpoint_dir = Path(checkpoint_dir) + if device == 'auto': + 
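+            # Auto mode prefers CUDA when available and otherwise falls back
+            # to CPU; a hypothetical explicit override would look like
+            # MLScorer(checkpoint_dir='ml/checkpoints', device='cpu').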
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.device = torch.device(device) + + self.model = None + self.thresholds = None + self._load_model() + + def _load_model(self): + model_path = self.checkpoint_dir / 'best_model.pt' + thresholds_path = self.checkpoint_dir / 'thresholds.json' + config_path = self.checkpoint_dir / 'config.json' + + if not model_path.exists(): + print(f"警告: 模型不存在 {model_path}") + return + + try: + from model import LSTMAutoencoder + + config = {} + if config_path.exists(): + with open(config_path) as f: + config = json.load(f).get('model', {}) + + if 'd_model' in config: + config['hidden_dim'] = config.pop('d_model') // 2 + for key in ['num_encoder_layers', 'num_decoder_layers', 'nhead', 'dim_feedforward', 'max_seq_len', 'use_instance_norm']: + config.pop(key, None) + if 'num_layers' not in config: + config['num_layers'] = 1 + + checkpoint = torch.load(model_path, map_location='cpu') + self.model = LSTMAutoencoder(**config) + self.model.load_state_dict(checkpoint['model_state_dict']) + self.model.to(self.device) + self.model.eval() + + if thresholds_path.exists(): + with open(thresholds_path) as f: + self.thresholds = json.load(f) + + print(f"ML模型加载成功 (设备: {self.device})") + except Exception as e: + print(f"ML模型加载失败: {e}") + + def is_ready(self): + return self.model is not None + + @torch.no_grad() + def score_batch(self, sequences: np.ndarray, debug: bool = False) -> np.ndarray: + if not self.is_ready() or len(sequences) == 0: + return np.zeros(len(sequences)) + + x = torch.FloatTensor(sequences).to(self.device) + output, _ = self.model(x) + mse = ((output - x) ** 2).mean(dim=-1) + errors = mse[:, -1].cpu().numpy() + + p95 = self.thresholds.get('p95', 0.1) if self.thresholds else 0.1 + scores = np.clip(errors / p95 * 50, 0, 100) + + if debug and len(errors) > 0: + print(f"[ML调试] p95={p95:.4f}, errors: min={errors.min():.4f}, max={errors.max():.4f}, mean={errors.mean():.4f}") + print(f"[ML调试] scores: min={scores.min():.0f}, max={scores.max():.0f}, mean={scores.mean():.0f}, =100占比={100*(scores>=100).mean():.1f}%") + + return scores + + +# ==================== 内存数据管理器 ==================== + +class RealtimeDataManager: + """ + 内存数据管理器 + + - 缓存股票分钟数据和指数数据 + - 增量获取新数据 + - 实时计算概念特征 + """ + + def __init__(self, concepts: List[dict], prev_close: Dict[str, float], index_prev_close: float): + self.concepts = concepts + self.prev_close = prev_close + self.index_prev_close = index_prev_close + + # 概念到股票的映射 + self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts} + self.all_stocks = list(set(s for c in concepts for s in c['stocks'])) + + # 内存缓存:股票分钟数据 + self.stock_data = pd.DataFrame() # code, timestamp, close, volume, amt, change_pct + self.index_data = pd.DataFrame() # timestamp, close, change_pct + + # 最后更新时间 + self.last_update_time = None + + # 概念历史(用于计算变化率) + self.concept_history = defaultdict(lambda: {'alpha': [], 'amt': []}) + + # 概念特征时间序列(用于 ML) + self.concept_features_history = defaultdict(list) # concept_id -> list of feature dicts + + def update(self, trade_date: str) -> int: + """ + 增量更新数据 + + Returns: + 新增的时间点数量 + """ + # 获取增量数据 + new_stock_df = get_stock_minute_data(trade_date, self.all_stocks, self.last_update_time) + new_index_df = get_index_minute_data(trade_date, self.last_update_time) + + if new_stock_df.empty and new_index_df.empty: + return 0 + + # 计算涨跌幅 + if not new_stock_df.empty: + new_stock_df['prev_close'] = new_stock_df['code'].map(self.prev_close) + new_stock_df = 
new_stock_df.dropna(subset=['prev_close']) + new_stock_df['change_pct'] = (new_stock_df['close'] - new_stock_df['prev_close']) / new_stock_df['prev_close'] * 100 + + # 合并到缓存 + self.stock_data = pd.concat([self.stock_data, new_stock_df], ignore_index=True) + self.stock_data = self.stock_data.drop_duplicates(subset=['code', 'timestamp'], keep='last') + + if not new_index_df.empty: + new_index_df['change_pct'] = (new_index_df['close'] - self.index_prev_close) / self.index_prev_close * 100 + + # 调试:打印指数涨跌幅范围 + if len(self.index_data) == 0: # 第一次 + print(f"[调试] 指数 close 范围: {new_index_df['close'].min():.2f} ~ {new_index_df['close'].max():.2f}") + print(f"[调试] 指数 change_pct 范围: {new_index_df['change_pct'].min():.2f}% ~ {new_index_df['change_pct'].max():.2f}%") + + self.index_data = pd.concat([self.index_data, new_index_df], ignore_index=True) + self.index_data = self.index_data.drop_duplicates(subset=['timestamp'], keep='last') + + # 更新最后时间 + if not new_stock_df.empty: + self.last_update_time = new_stock_df['timestamp'].max() + elif not new_index_df.empty: + self.last_update_time = new_index_df['timestamp'].max() + + # 获取新时间点 + new_timestamps = sorted(new_stock_df['timestamp'].unique()) if not new_stock_df.empty else [] + + # 计算新时间点的概念特征 + for ts in new_timestamps: + self._compute_features_for_timestamp(ts) + + return len(new_timestamps) + + def _compute_features_for_timestamp(self, ts): + """计算单个时间点的概念特征""" + ts_stock_data = self.stock_data[self.stock_data['timestamp'] == ts] + index_row = self.index_data[self.index_data['timestamp'] == ts] + + if ts_stock_data.empty or index_row.empty: + return + + index_change = index_row['change_pct'].values[0] + stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct'])) + stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt'])) + + for concept_id, stocks in self.concept_stocks.items(): + concept_changes = [stock_change[s] for s in stocks if s in stock_change] + concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change] + + if not concept_changes: + continue + + avg_change = np.mean(concept_changes) + total_amt = sum(concept_amts) + alpha = avg_change - index_change + + # 涨停比例 + limit_up_count = sum(1 for c in concept_changes if c >= FEATURE_CONFIG['limit_up_threshold']) + limit_up_ratio = limit_up_count / len(concept_changes) + + # 更新历史 + history = self.concept_history[concept_id] + history['alpha'].append(alpha) + history['amt'].append(total_amt) + + # 计算变化率 + alpha_delta = 0 + if len(history['alpha']) > FEATURE_CONFIG['alpha_delta_window']: + alpha_delta = alpha - history['alpha'][-FEATURE_CONFIG['alpha_delta_window'] - 1] + + amt_ratio = 1.0 + amt_delta = 0 + if len(history['amt']) > FEATURE_CONFIG['amt_ma_window']: + amt_ma = np.mean(history['amt'][-FEATURE_CONFIG['amt_ma_window'] - 1:-1]) + if amt_ma > 0: + amt_ratio = total_amt / amt_ma + amt_delta = total_amt - history['amt'][-2] if len(history['amt']) > 1 else 0 + + features = { + 'timestamp': ts, + 'concept_id': concept_id, + 'alpha': alpha, + 'alpha_delta': alpha_delta, + 'amt_ratio': amt_ratio, + 'amt_delta': amt_delta, + 'limit_up_ratio': limit_up_ratio, + 'stock_count': len(concept_changes), + 'total_amt': total_amt, + } + + self.concept_features_history[concept_id].append(features) + + def get_latest_features(self) -> pd.DataFrame: + """获取最新时间点的所有概念特征""" + if not self.concept_features_history: + return pd.DataFrame() + + latest_features = [] + for concept_id, history in self.concept_features_history.items(): + if history: + 
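+                # update() appends features in ascending timestamp order, so
+                # history[-1] is the most recent minute for this concept.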
latest_features.append(history[-1]) + + if not latest_features: + return pd.DataFrame() + + df = pd.DataFrame(latest_features) + + # 计算排名百分位 + if len(df) > 1: + df['rank_pct'] = df['alpha'].rank(pct=True) + else: + df['rank_pct'] = 0.5 + + return df + + def get_sequences_for_concepts(self, seq_len: int) -> Tuple[np.ndarray, pd.DataFrame]: + """获取所有概念的特征序列(用于 ML 评分)""" + sequences = [] + infos = [] + + for concept_id, history in self.concept_features_history.items(): + if len(history) < seq_len: + continue + + # 取最近 seq_len 个时间点 + recent = history[-seq_len:] + + # 构建序列 + seq = np.array([[ + f['alpha'], + f['alpha_delta'], + f['amt_ratio'], + f['amt_delta'], + f.get('rank_pct', 0.5), + f['limit_up_ratio'] + ] for f in recent]) + + seq = np.nan_to_num(seq, nan=0.0, posinf=0.0, neginf=0.0) + seq = np.clip(seq, -CONFIG['clip_value'], CONFIG['clip_value']) + + sequences.append(seq) + infos.append(recent[-1]) # 最新特征 + + if not sequences: + return np.array([]), pd.DataFrame() + + # 补充 rank_pct + info_df = pd.DataFrame(infos) + if 'rank_pct' not in info_df.columns and len(info_df) > 1: + info_df['rank_pct'] = info_df['alpha'].rank(pct=True) + + return np.array(sequences), info_df + + def get_all_timestamps(self) -> List: + """获取所有时间点""" + if self.stock_data.empty: + return [] + return sorted(self.stock_data['timestamp'].unique()) + + def get_concept_features_df(self) -> pd.DataFrame: + """获取概念特征的 DataFrame 形式(用于批量回测)""" + if not self.concept_features_history: + return pd.DataFrame() + + rows = [] + for concept_id, history in self.concept_features_history.items(): + for f in history: + row = { + 'concept_id': concept_id, + 'timestamp': f['timestamp'], + 'alpha': f['alpha'], + 'alpha_delta': f['alpha_delta'], + 'amt_ratio': f['amt_ratio'], + 'amt_delta': f.get('amt_delta', 0), + 'limit_up_ratio': f['limit_up_ratio'], + 'stock_count': f.get('stock_count', 0), + 'total_amt': f.get('total_amt', 0), + } + rows.append(row) + + if not rows: + return pd.DataFrame() + + df = pd.DataFrame(rows) + + # 按时间点计算 rank_pct(每个时间点内部排名) + df['rank_pct'] = df.groupby('timestamp')['alpha'].rank(pct=True) + + return df + + +# ==================== 冷却期管理 ==================== + +class CooldownManager: + def __init__(self, cooldown_minutes: int = 8): + self.cooldown_minutes = cooldown_minutes + self.last_alert_time = {} + + def is_in_cooldown(self, concept_id: str, current_time: datetime) -> bool: + if concept_id not in self.last_alert_time: + return False + last_time = self.last_alert_time[concept_id] + diff = (current_time - last_time).total_seconds() / 60 + return diff < self.cooldown_minutes + + def record_alert(self, concept_id: str, alert_time: datetime): + self.last_alert_time[concept_id] = alert_time + + def cleanup_old(self, current_time: datetime): + cutoff = current_time - timedelta(minutes=self.cooldown_minutes * 2) + self.last_alert_time = {cid: t for cid, t in self.last_alert_time.items() if t > cutoff} + + +# ==================== 异动检测 ==================== + +def detect_anomalies( + ml_scorer: MLScorer, + data_mgr: RealtimeDataManager, + cooldown_mgr: CooldownManager, + trade_date: str, + config: Dict +) -> List[Dict]: + """检测当前时刻的异动""" + + # 获取最新特征 + latest_df = data_mgr.get_latest_features() + if latest_df.empty: + return [] + + # 获取 ML 序列 + sequences, info_df = data_mgr.get_sequences_for_concepts(config['seq_len']) + + if len(sequences) == 0: + return [] + + # 获取当前时间 + current_time = pd.to_datetime(info_df['timestamp'].iloc[0]) + + # 清理过期冷却 + cooldown_mgr.cleanup_old(current_time) + + # 过滤冷却中的概念 + valid_mask 
= [] + for _, row in info_df.iterrows(): + in_cooldown = cooldown_mgr.is_in_cooldown(row['concept_id'], current_time) + valid_mask.append(not in_cooldown) + + valid_mask = np.array(valid_mask) + sequences = sequences[valid_mask] + info_df = info_df[valid_mask].reset_index(drop=True) + + if len(sequences) == 0: + return [] + + # 过滤小波动 + alpha_mask = np.abs(info_df['alpha'].values) >= config['min_alpha_abs'] + sequences = sequences[alpha_mask] + info_df = info_df[alpha_mask].reset_index(drop=True) + + if len(sequences) == 0: + return [] + + # 规则评分 + rule_scores, triggered_rules = score_rules_batch(info_df) + + # ML 评分 + ml_scores = ml_scorer.score_batch(sequences) + + # 融合得分 + w1, w2 = config['rule_weight'], config['ml_weight'] + final_scores = w1 * rule_scores + w2 * ml_scores + + # 判断异动 + alerts = [] + for i, row in info_df.iterrows(): + rule_score = rule_scores[i] + ml_score = ml_scores[i] + final_score = final_scores[i] + + is_anomaly = ( + rule_score >= config['rule_trigger'] or + ml_score >= config['ml_trigger'] or + final_score >= config['fusion_trigger'] + ) + + if not is_anomaly: + continue + + # 触发原因 + if rule_score >= config['rule_trigger']: + trigger = f'规则强信号({rule_score:.0f}分)' + elif ml_score >= config['ml_trigger']: + trigger = f'ML强信号({ml_score:.0f}分)' + else: + trigger = f'融合触发({final_score:.0f}分)' + + # 异动类型 + alpha = row['alpha'] + if alpha >= 1.5: + alert_type = 'surge_up' + elif alpha <= -1.5: + alert_type = 'surge_down' + elif row['amt_ratio'] >= 3.0: + alert_type = 'volume_spike' + else: + alert_type = 'unknown' + + alert = { + 'concept_id': row['concept_id'], + 'alert_time': row['timestamp'], + 'trade_date': trade_date, + 'alert_type': alert_type, + 'final_score': final_score, + 'rule_score': rule_score, + 'ml_score': ml_score, + 'trigger_reason': trigger, + 'triggered_rules': triggered_rules[i], + 'alpha': row['alpha'], + 'alpha_delta': row['alpha_delta'], + 'amt_ratio': row['amt_ratio'], + 'amt_delta': row.get('amt_delta', 0), + 'rank_pct': row.get('rank_pct', 0.5), + 'limit_up_ratio': row['limit_up_ratio'], + 'stock_count': row['stock_count'], + 'total_amt': row['total_amt'], + } + + alerts.append(alert) + cooldown_mgr.record_alert(row['concept_id'], current_time) + + # 按得分排序 + alerts.sort(key=lambda x: x['final_score'], reverse=True) + return alerts[:config['max_alerts_per_minute']] + + +# ==================== 数据库写入 ==================== + +def save_alerts_to_mysql(alerts: List[Dict]) -> int: + if not alerts: + return 0 + + saved = 0 + with MYSQL_ENGINE.begin() as conn: + for alert in alerts: + try: + insert_sql = text(""" + INSERT IGNORE INTO concept_anomaly_hybrid + (concept_id, alert_time, trade_date, alert_type, + final_score, rule_score, ml_score, trigger_reason, + alpha, alpha_delta, amt_ratio, amt_delta, + rank_pct, limit_up_ratio, stock_count, total_amt, + triggered_rules) + VALUES + (:concept_id, :alert_time, :trade_date, :alert_type, + :final_score, :rule_score, :ml_score, :trigger_reason, + :alpha, :alpha_delta, :amt_ratio, :amt_delta, + :rank_pct, :limit_up_ratio, :stock_count, :total_amt, + :triggered_rules) + """) + + result = conn.execute(insert_sql, { + 'concept_id': alert['concept_id'], + 'alert_time': alert['alert_time'], + 'trade_date': alert['trade_date'], + 'alert_type': alert['alert_type'], + 'final_score': alert['final_score'], + 'rule_score': alert['rule_score'], + 'ml_score': alert['ml_score'], + 'trigger_reason': alert['trigger_reason'], + 'alpha': alert.get('alpha', 0), + 'alpha_delta': alert.get('alpha_delta', 0), + 'amt_ratio': 
alert.get('amt_ratio', 1), + 'amt_delta': alert.get('amt_delta', 0), + 'rank_pct': alert.get('rank_pct', 0.5), + 'limit_up_ratio': alert.get('limit_up_ratio', 0), + 'stock_count': alert.get('stock_count', 0), + 'total_amt': alert.get('total_amt', 0), + 'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False), + }) + + if result.rowcount > 0: + saved += 1 + except Exception as e: + print(f"保存失败: {alert['concept_id']} - {e}") + + return saved + + +# ==================== 主服务 ==================== + +class RealtimeDetectorService: + def __init__(self, checkpoint_dir: str = 'ml/checkpoints', device: str = 'auto'): + self.checkpoint_dir = checkpoint_dir + self.device = device + + # 初始化 ML 评分器 + self.ml_scorer = MLScorer(checkpoint_dir, device) + + # 这些在 init_for_trade_date 中初始化 + self.data_mgr = None + self.cooldown_mgr = None + self.trade_date = None + + def init_for_trade_date(self, trade_date: str, preload_history: bool = True): + """ + 为指定交易日初始化 + + Args: + trade_date: 交易日期 + preload_history: 是否预加载当天已有的历史数据(实盘必须为 True) + """ + if self.trade_date == trade_date and self.data_mgr is not None: + return + + print(f"[初始化] 交易日: {trade_date}") + + # 获取概念列表 + print(f"[初始化] 获取概念列表...") + concepts = get_all_concepts() + + # 获取所有股票 + all_stocks = list(set(s for c in concepts for s in c['stocks'])) + print(f"[初始化] 共 {len(all_stocks)} 只股票") + + # 获取昨收价 + print(f"[初始化] 获取昨收价...") + prev_close = get_prev_close(all_stocks, trade_date) + index_prev_close = get_index_prev_close(trade_date) + print(f"[初始化] 获取到 {len(prev_close)} 只股票的昨收价") + print(f"[初始化] 指数昨收价: {index_prev_close}") + + # 创建数据管理器 + self.data_mgr = RealtimeDataManager(concepts, prev_close, index_prev_close) + self.cooldown_mgr = CooldownManager(CONFIG['cooldown_minutes']) + self.trade_date = trade_date + + # 预加载当天已有的历史数据(实盘关键) + if preload_history: + self._preload_today_history(trade_date) + + def _preload_today_history(self, trade_date: str): + """ + 预加载当天已有的历史数据到内存 + + 这是实盘运行的关键: + - 在盘中任意时刻启动服务时,需要先加载当天已有的数据 + - 这样才能正确计算 alpha_delta(需要过去 5 分钟)和 amt_ratio(需要过去 20 分钟) + - 以及构建 ML 所需的序列(需要 seq_len=15 分钟) + + 整个过程不依赖 prepare_data.py,直接从 ClickHouse 读取原始数据计算 + """ + print(f"[预热] 加载当天历史数据...") + + # 直接调用 update,但不设置 last_update_time,会获取当天所有数据 + # data_mgr.last_update_time 初始为 None,会获取全部数据 + n_updates = self.data_mgr.update(trade_date) + + if n_updates > 0: + print(f"[预热] 加载完成,共 {n_updates} 个时间点") + + # 检查是否满足 ML 所需的最小数据量 + min_required = CONFIG['seq_len'] + FEATURE_CONFIG['amt_ma_window'] + if n_updates < min_required: + print(f"[预热] 警告:数据量 {n_updates} < 最小需求 {min_required},部分特征可能不准确") + else: + print(f"[预热] 数据充足,可以正常检测") + else: + print(f"[预热] 当天暂无历史数据(可能是开盘前)") + + def backfill_today(self): + """ + 补齐当天历史数据并检测异动(回补模式) + + 使用与 backtest_fast.py 完全相同的逻辑: + 1. 先用 prepare_data.py 生成当天的 parquet 文件 + 2. 读取 parquet 文件进行回测 + + 注意:这个方法用于回补历史异动记录,不是实盘必须的 + 实盘模式下,init_for_trade_date 会自动预热历史数据 + """ + trade_date = get_current_trade_date() + print(f"[补齐] 交易日: {trade_date}") + + # 1. 生成当天的 parquet 文件 + parquet_path = Path('ml/data') / f'features_{trade_date}.parquet' + + if not parquet_path.exists(): + print(f"[补齐] 生成当天特征数据...") + self._generate_today_parquet(trade_date) + + if not parquet_path.exists(): + print(f"[补齐] 无法生成特征数据,跳过") + return + + # 2. 
读取 parquet 文件 + df = pd.read_parquet(parquet_path) + print(f"[补齐] 读取到 {len(df)} 条特征数据") + + if df.empty: + print("[补齐] 无数据") + return + + # 打印特征分布(调试) + print(f"[调试] alpha 分布: min={df['alpha'].min():.2f}, max={df['alpha'].max():.2f}, mean={df['alpha'].mean():.2f}") + print(f"[调试] |alpha| >= 0.3 的数量: {(df['alpha'].abs() >= 0.3).sum()}") + + # 3. 使用 backtest_fast.py 相同的回测逻辑 + alerts = self._backtest_from_parquet(df, trade_date) + + # 4. 保存结果 + if alerts: + saved = save_alerts_to_mysql(alerts) + print(f"[补齐] 完成!共 {len(alerts)} 个异动, 保存 {saved} 条") + + # 统计触发来源 + trigger_stats = {'规则': 0, 'ML': 0, '融合': 0} + for a in alerts: + reason = a['trigger_reason'] + if '规则' in reason: + trigger_stats['规则'] += 1 + elif 'ML' in reason: + trigger_stats['ML'] += 1 + else: + trigger_stats['融合'] += 1 + print(f"[补齐] 触发来源: {trigger_stats}") + else: + print("[补齐] 无异动") + + def _generate_today_parquet(self, trade_date: str): + """ + 生成当天的 parquet 文件(调用 prepare_data.py 的逻辑) + """ + import subprocess + cmd = ['python', 'ml/prepare_data.py', '--start', trade_date, '--end', trade_date] + print(f"[补齐] 执行: {' '.join(cmd)}") + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + if result.returncode != 0: + print(f"[补齐] prepare_data.py 执行失败: {result.stderr}") + except Exception as e: + print(f"[补齐] prepare_data.py 执行异常: {e}") + + def _backtest_from_parquet(self, df: pd.DataFrame, trade_date: str) -> List[Dict]: + """ + 从 parquet 数据回测(与 backtest_fast.py 完全一致的逻辑) + """ + seq_len = CONFIG['seq_len'] + now = datetime.now() + + # 确保按概念和时间排序 + df = df.sort_values(['concept_id', 'timestamp']).reset_index(drop=True) + + # 获取所有时间点 + all_timestamps = sorted(df['timestamp'].unique()) + + # 只处理当前时间之前的 + past_timestamps = [] + for ts in all_timestamps: + try: + ts_dt = pd.to_datetime(ts) + if ts_dt.tzinfo is not None: + ts_dt = ts_dt.tz_localize(None) + if ts_dt < now: + past_timestamps.append(ts) + except Exception: + continue + + print(f"[补齐] 处理 {len(past_timestamps)} 个历史时间点...") + + if len(past_timestamps) < seq_len: + print(f"[补齐] 时间点不足 {seq_len},跳过") + return [] + + # 构建序列(与 backtest_fast.py 的 build_sequences_fast 一致) + sequences = [] + infos = [] + + groups = df.groupby('concept_id') + + for concept_id, gdf in groups: + gdf = gdf.reset_index(drop=True) + feat_matrix = gdf[FEATURES].values + feat_matrix = np.nan_to_num(feat_matrix, nan=0.0, posinf=0.0, neginf=0.0) + feat_matrix = np.clip(feat_matrix, -CONFIG['clip_value'], CONFIG['clip_value']) + + n_total = len(feat_matrix) + if n_total < seq_len: + continue + + for i in range(n_total - seq_len + 1): + seq = feat_matrix[i:i + seq_len] + row = gdf.iloc[i + seq_len - 1] + + # 只保留当前时间之前的 + ts = row['timestamp'] + try: + ts_dt = pd.to_datetime(ts) + if ts_dt.tzinfo is not None: + ts_dt = ts_dt.tz_localize(None) + if ts_dt >= now: + continue + except Exception: + continue + + sequences.append(seq) + infos.append({ + 'concept_id': concept_id, + 'timestamp': row['timestamp'], + 'alpha': row['alpha'], + 'alpha_delta': row.get('alpha_delta', 0), + 'amt_ratio': row.get('amt_ratio', 1), + 'amt_delta': row.get('amt_delta', 0), + 'rank_pct': row.get('rank_pct', 0.5), + 'limit_up_ratio': row.get('limit_up_ratio', 0), + 'stock_count': row.get('stock_count', 0), + 'total_amt': row.get('total_amt', 0), + }) + + if not sequences: + return [] + + sequences = np.array(sequences) + info_df = pd.DataFrame(infos) + + print(f"[补齐] 构建了 {len(sequences)} 个序列") + + # 过滤小波动 + alpha_abs = np.abs(info_df['alpha'].values) + valid_mask = alpha_abs >= CONFIG['min_alpha_abs'] + 
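# 小波动过滤(假设性说明):min_alpha_abs 按上方调试输出取 0.3 时,
+ # alpha=0.15 的序列在此被剔除,仅 |alpha| >= 0.3 的候选进入后续规则/ML 评分,可大幅减少计算量
+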
sequences = sequences[valid_mask] + info_df = info_df[valid_mask].reset_index(drop=True) + + if len(sequences) == 0: + return [] + + print(f"[补齐] 过滤后 {len(sequences)} 个序列") + + # 批量规则评分 + rule_scores, triggered_rules = score_rules_batch(info_df) + + # 批量 ML 评分 + batch_size = 2048 + ml_scores = [] + for i in range(0, len(sequences), batch_size): + batch_seq = sequences[i:i+batch_size] + batch_scores = self.ml_scorer.score_batch(batch_seq) + ml_scores.append(batch_scores) + ml_scores = np.concatenate(ml_scores) if ml_scores else np.zeros(len(sequences)) + + # 融合得分 + w1, w2 = CONFIG['rule_weight'], CONFIG['ml_weight'] + final_scores = w1 * rule_scores + w2 * ml_scores + + # 判断异动 + is_anomaly = ( + (rule_scores >= CONFIG['rule_trigger']) | + (ml_scores >= CONFIG['ml_trigger']) | + (final_scores >= CONFIG['fusion_trigger']) + ) + + # 添加分数到 info_df + info_df['rule_score'] = rule_scores + info_df['ml_score'] = ml_scores + info_df['final_score'] = final_scores + info_df['is_anomaly'] = is_anomaly + info_df['triggered_rules'] = triggered_rules + + # 只保留异动 + anomaly_df = info_df[info_df['is_anomaly']].copy() + + if len(anomaly_df) == 0: + return [] + + print(f"[补齐] 发现 {len(anomaly_df)} 个候选异动") + + # 应用冷却期 + anomaly_df = anomaly_df.sort_values(['concept_id', 'timestamp']) + cooldown = {} + keep_mask = [] + + for _, row in anomaly_df.iterrows(): + cid = row['concept_id'] + ts = row['timestamp'] + + if cid in cooldown: + try: + diff = (pd.to_datetime(ts) - pd.to_datetime(cooldown[cid])).total_seconds() / 60 + except: + diff = CONFIG['cooldown_minutes'] + 1 + + if diff < CONFIG['cooldown_minutes']: + keep_mask.append(False) + continue + + cooldown[cid] = ts + keep_mask.append(True) + + anomaly_df = anomaly_df[keep_mask] + + print(f"[补齐] 冷却后 {len(anomaly_df)} 个异动") + + # 按时间分组,每分钟最多 max_alerts_per_minute 个 + alerts = [] + for ts, group in anomaly_df.groupby('timestamp'): + group = group.nlargest(CONFIG['max_alerts_per_minute'], 'final_score') + + for _, row in group.iterrows(): + alpha = row['alpha'] + if alpha >= 1.5: + atype = 'surge_up' + elif alpha <= -1.5: + atype = 'surge_down' + elif row['amt_ratio'] >= 3.0: + atype = 'volume_spike' + else: + atype = 'unknown' + + rule_score = row['rule_score'] + ml_score = row['ml_score'] + final_score = row['final_score'] + + if rule_score >= CONFIG['rule_trigger']: + trigger = f'规则强信号({rule_score:.0f}分)' + elif ml_score >= CONFIG['ml_trigger']: + trigger = f'ML强信号({ml_score:.0f}分)' + else: + trigger = f'融合触发({final_score:.0f}分)' + + alerts.append({ + 'concept_id': row['concept_id'], + 'alert_time': row['timestamp'], + 'trade_date': trade_date, + 'alert_type': atype, + 'final_score': final_score, + 'rule_score': rule_score, + 'ml_score': ml_score, + 'trigger_reason': trigger, + 'triggered_rules': row['triggered_rules'], + 'alpha': alpha, + 'alpha_delta': row['alpha_delta'], + 'amt_ratio': row['amt_ratio'], + 'amt_delta': row['amt_delta'], + 'rank_pct': row['rank_pct'], + 'limit_up_ratio': row['limit_up_ratio'], + 'stock_count': row['stock_count'], + 'total_amt': row['total_amt'], + }) + + return alerts + + def run_once(self): + """执行一次检测""" + now = datetime.now() + trade_date = get_current_trade_date() + + if not is_trading_time(): + print(f"[{now.strftime('%H:%M:%S')}] 非交易时间,跳过") + return + + # 初始化 + self.init_for_trade_date(trade_date) + + print(f"[{now.strftime('%H:%M:%S')}] 获取新数据...") + + # 增量更新 + n_updates = self.data_mgr.update(trade_date) + print(f" 新增 {n_updates} 个时间点") + + if n_updates == 0: + print(f" 无新数据") + return + + # 检测 + alerts = 
detect_anomalies( + self.ml_scorer, + self.data_mgr, + self.cooldown_mgr, + trade_date, + CONFIG + ) + + if alerts: + saved = save_alerts_to_mysql(alerts) + print(f" 检测到 {len(alerts)} 个异动, 保存 {saved} 条") + + for alert in alerts[:5]: + print(f" - {alert['concept_id']}: {alert['alert_type']} " + f"(final={alert['final_score']:.0f}, rule={alert['rule_score']:.0f}, ml={alert['ml_score']:.0f})") + else: + print(f" 无异动") + + def run_loop(self, backfill: bool = False): + """ + 持续运行(实盘模式) + + Args: + backfill: 是否回补历史异动到数据库(使用 prepare_data.py 方式) + 默认 False,因为实盘模式下 init_for_trade_date 会自动预热数据 + """ + print("=" * 60) + print("实时概念异动检测服务(实盘模式)") + print("=" * 60) + print(f"模型目录: {self.checkpoint_dir}") + print(f"交易时段: {TRADING_PERIODS}") + print(f"ML 序列长度: {CONFIG['seq_len']} 分钟") + print(f"成交额均值窗口: {FEATURE_CONFIG['amt_ma_window']} 分钟") + print("=" * 60) + + # 立即初始化并预热(即使不在交易时间也预热,方便测试) + trade_date = get_current_trade_date() + print(f"\n[启动] 初始化交易日 {trade_date}...") + self.init_for_trade_date(trade_date, preload_history=True) + + # 可选:回补历史异动记录到数据库 + if backfill and is_trading_time(): + print("\n[启动] 回补历史异动...") + self.backfill_today() + + # 每分钟第 10 秒执行 + schedule.every().minute.at(":10").do(self.run_once) + + print("\n服务已启动,等待下一分钟...") + + while True: + schedule.run_pending() + time.sleep(1) + + +# ==================== 主函数 ==================== + +def main(): + parser = argparse.ArgumentParser(description='实时概念异动检测') + parser.add_argument('--checkpoint_dir', default='ml/checkpoints', help='模型目录') + parser.add_argument('--device', default='auto', help='设备 (auto/cpu/cuda)') + parser.add_argument('--once', action='store_true', help='只运行一次检测') + parser.add_argument('--backfill', action='store_true', help='启动时回补历史异动到数据库') + parser.add_argument('--backfill-only', action='store_true', help='只回补历史(不持续运行)') + + args = parser.parse_args() + + service = RealtimeDetectorService( + checkpoint_dir=args.checkpoint_dir, + device=args.device + ) + + if args.once: + # 单次检测模式 + service.run_once() + elif args.backfill_only: + # 仅回补历史模式(需要 prepare_data.py) + service.backfill_today() + else: + # 实盘持续运行模式(自动预热,不依赖 prepare_data.py) + service.run_loop(backfill=args.backfill) + + +if __name__ == "__main__": + main() diff --git a/ml/realtime_detector_v2.py b/ml/realtime_detector_v2.py new file mode 100644 index 00000000..e5ccc942 --- /dev/null +++ b/ml/realtime_detector_v2.py @@ -0,0 +1,729 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +V2 实时异动检测器 + +使用方法: + # 作为模块导入 + from ml.realtime_detector_v2 import RealtimeDetectorV2 + + detector = RealtimeDetectorV2() + alerts = detector.detect_realtime() # 检测当前时刻 + + # 或命令行测试 + python ml/realtime_detector_v2.py --date 2025-12-09 +""" + +import os +import sys +import json +import pickle +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Tuple +from collections import defaultdict, deque + +import numpy as np +import pandas as pd +import torch +from sqlalchemy import create_engine, text +from elasticsearch import Elasticsearch +from clickhouse_driver import Client + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from ml.model import TransformerAutoencoder + +# ==================== 配置 ==================== + +MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock" +ES_HOST = 'http://127.0.0.1:9200' +ES_INDEX = 'concept_library_v3' + +CLICKHOUSE_CONFIG = { + 'host': '127.0.0.1', + 'port': 9000, + 'user': 'default', + 'password': 'Zzl33818!', + 'database': 'stock' +} + +REFERENCE_INDEX = 
'000001.SH' +BASELINE_FILE = 'ml/data_v2/baselines/realtime_baseline.pkl' +MODEL_DIR = 'ml/checkpoints_v2' + +# 检测配置 +CONFIG = { + 'seq_len': 10, # LSTM 序列长度 + 'confirm_window': 5, # 持续确认窗口 + 'confirm_ratio': 0.6, # 确认比例 + 'rule_weight': 0.5, + 'ml_weight': 0.5, + 'rule_trigger': 60, + 'ml_trigger': 70, + 'fusion_trigger': 50, + 'cooldown_minutes': 10, + 'max_alerts_per_minute': 15, + 'zscore_clip': 5.0, + 'limit_up_threshold': 9.8, +} + +FEATURES = ['alpha_zscore', 'amt_zscore', 'rank_zscore', 'momentum_3m', 'momentum_5m', 'limit_up_ratio'] + + +# ==================== 数据库连接 ==================== + +_mysql_engine = None +_es_client = None +_ch_client = None + + +def get_mysql_engine(): + global _mysql_engine + if _mysql_engine is None: + _mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True) + return _mysql_engine + + +def get_es_client(): + global _es_client + if _es_client is None: + _es_client = Elasticsearch([ES_HOST]) + return _es_client + + +def get_ch_client(): + global _ch_client + if _ch_client is None: + _ch_client = Client(**CLICKHOUSE_CONFIG) + return _ch_client + + +def code_to_ch_format(code: str) -> str: + if not code or len(code) != 6 or not code.isdigit(): + return None + if code.startswith('6'): + return f"{code}.SH" + elif code.startswith('0') or code.startswith('3'): + return f"{code}.SZ" + return f"{code}.BJ" + + +def time_to_slot(ts) -> str: + if isinstance(ts, str): + return ts + return ts.strftime('%H:%M') + + +# ==================== 规则评分 ==================== + +def score_rules_zscore(features: Dict) -> Tuple[float, List[str]]: + """基于 Z-Score 的规则评分""" + score = 0.0 + triggered = [] + + alpha_z = abs(features.get('alpha_zscore', 0)) + amt_z = features.get('amt_zscore', 0) + rank_z = abs(features.get('rank_zscore', 0)) + mom_3m = features.get('momentum_3m', 0) + mom_5m = features.get('momentum_5m', 0) + limit_up = features.get('limit_up_ratio', 0) + + # Alpha Z-Score + if alpha_z >= 4.0: + score += 25 + triggered.append('alpha_extreme') + elif alpha_z >= 3.0: + score += 18 + triggered.append('alpha_strong') + elif alpha_z >= 2.0: + score += 10 + triggered.append('alpha_moderate') + + # 成交额 Z-Score + if amt_z >= 4.0: + score += 20 + triggered.append('amt_extreme') + elif amt_z >= 3.0: + score += 12 + triggered.append('amt_strong') + elif amt_z >= 2.0: + score += 6 + triggered.append('amt_moderate') + + # 排名 Z-Score + if rank_z >= 3.0: + score += 15 + triggered.append('rank_extreme') + elif rank_z >= 2.0: + score += 8 + triggered.append('rank_strong') + + # 动量(基于 Z-Score 的) + if mom_3m >= 1.0: + score += 12 + triggered.append('momentum_3m_strong') + elif mom_3m >= 0.5: + score += 6 + triggered.append('momentum_3m_moderate') + + if mom_5m >= 1.5: + score += 10 + triggered.append('momentum_5m_strong') + + # 涨停比例 + if limit_up >= 0.3: + score += 20 + triggered.append('limit_up_extreme') + elif limit_up >= 0.15: + score += 12 + triggered.append('limit_up_strong') + elif limit_up >= 0.08: + score += 5 + triggered.append('limit_up_moderate') + + # 组合规则 + if alpha_z >= 2.0 and amt_z >= 2.0: + score += 15 + triggered.append('combo_alpha_amt') + + if alpha_z >= 2.0 and limit_up >= 0.1: + score += 12 + triggered.append('combo_alpha_limitup') + + return min(score, 100), triggered + + +# ==================== 实时检测器 ==================== + +class RealtimeDetectorV2: + """V2 实时异动检测器""" + + def __init__(self, model_dir: str = MODEL_DIR, baseline_file: str = BASELINE_FILE): + print("初始化 V2 实时检测器...") + + # 加载概念 + self.concepts = self._load_concepts() + 
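# concept_stocks: 概念ID -> 成分股集合(set,成员判断 O(1));all_stocks 为去重后的全量股票代码
+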
self.concept_stocks = {c['concept_id']: set(c['stocks']) for c in self.concepts} + self.all_stocks = list(set(s for c in self.concepts for s in c['stocks'])) + + # 加载基线 + self.baselines = self._load_baselines(baseline_file) + + # 加载模型 + self.model, self.thresholds, self.device = self._load_model(model_dir) + + # 状态管理 + self.zscore_history = defaultdict(lambda: deque(maxlen=CONFIG['seq_len'])) + self.anomaly_candidates = defaultdict(lambda: deque(maxlen=CONFIG['confirm_window'])) + self.cooldown = {} + + print(f"初始化完成: {len(self.concepts)} 概念, {len(self.baselines)} 基线") + + def _load_concepts(self) -> List[dict]: + """从 ES 加载概念""" + es = get_es_client() + concepts = [] + + query = {"query": {"match_all": {}}, "size": 100, "_source": ["concept_id", "concept", "stocks"]} + resp = es.search(index=ES_INDEX, body=query, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = resp['hits']['hits'] + + while hits: + for hit in hits: + src = hit['_source'] + stocks = [s['code'] for s in src.get('stocks', []) if isinstance(s, dict) and s.get('code')] + if stocks: + concepts.append({ + 'concept_id': src.get('concept_id'), + 'concept_name': src.get('concept'), + 'stocks': stocks + }) + resp = es.scroll(scroll_id=scroll_id, scroll='2m') + scroll_id = resp['_scroll_id'] + hits = resp['hits']['hits'] + + es.clear_scroll(scroll_id=scroll_id) + return concepts + + def _load_baselines(self, baseline_file: str) -> Dict: + """加载基线""" + if not os.path.exists(baseline_file): + print(f"警告: 基线文件不存在: {baseline_file}") + print("请先运行: python ml/update_baseline.py") + return {} + + with open(baseline_file, 'rb') as f: + data = pickle.load(f) + + print(f"基线日期范围: {data.get('date_range', 'unknown')}") + print(f"更新时间: {data.get('update_time', 'unknown')}") + + return data.get('baselines', {}) + + def _load_model(self, model_dir: str): + """加载模型""" + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + config_path = os.path.join(model_dir, 'config.json') + model_path = os.path.join(model_dir, 'best_model.pt') + threshold_path = os.path.join(model_dir, 'thresholds.json') + + if not os.path.exists(model_path): + print(f"警告: 模型不存在: {model_path}") + return None, {}, device + + with open(config_path) as f: + config = json.load(f) + + model = TransformerAutoencoder(**config['model']) + checkpoint = torch.load(model_path, map_location=device) + model.load_state_dict(checkpoint['model_state_dict']) + model.to(device) + model.eval() + + thresholds = {} + if os.path.exists(threshold_path): + with open(threshold_path) as f: + thresholds = json.load(f) + + print(f"模型已加载: {model_path}") + return model, thresholds, device + + def _get_realtime_data(self, trade_date: str) -> pd.DataFrame: + """获取实时数据并计算原始特征""" + ch = get_ch_client() + + # 获取股票数据 + ch_codes = [code_to_ch_format(c) for c in self.all_stocks if code_to_ch_format(c)] + ch_codes_str = "','".join(ch_codes) + + stock_query = f""" + SELECT code, timestamp, close, amt + FROM stock_minute + WHERE toDate(timestamp) = '{trade_date}' + AND code IN ('{ch_codes_str}') + ORDER BY timestamp + """ + stock_result = ch.execute(stock_query) + if not stock_result: + return pd.DataFrame() + + stock_df = pd.DataFrame(stock_result, columns=['ch_code', 'timestamp', 'close', 'amt']) + + # 映射回原始代码 + ch_to_code = {code_to_ch_format(c): c for c in self.all_stocks if code_to_ch_format(c)} + stock_df['code'] = stock_df['ch_code'].map(ch_to_code) + stock_df = stock_df.dropna(subset=['code']) + + # 获取指数数据 + index_query = f""" + SELECT timestamp, close + FROM index_minute + WHERE 
toDate(timestamp) = '{trade_date}' + AND code = '{REFERENCE_INDEX}' + ORDER BY timestamp + """ + index_result = ch.execute(index_query) + if not index_result: + return pd.DataFrame() + + index_df = pd.DataFrame(index_result, columns=['timestamp', 'close']) + + # 获取昨收价 + engine = get_mysql_engine() + codes_str = "','".join([c for c in self.all_stocks if c and len(c) == 6]) + + with engine.connect() as conn: + prev_result = conn.execute(text(f""" + SELECT SECCODE, F007N FROM ea_trade + WHERE SECCODE IN ('{codes_str}') + AND TRADEDATE = (SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}') + AND F007N > 0 + """)) + prev_close = {row[0]: float(row[1]) for row in prev_result if row[1]} + + idx_result = conn.execute(text(""" + SELECT F006N FROM ea_exchangetrade + WHERE INDEXCODE = '000001' AND TRADEDATE < :today + ORDER BY TRADEDATE DESC LIMIT 1 + """), {'today': trade_date}).fetchone() + index_prev_close = float(idx_result[0]) if idx_result else None + + if not prev_close or not index_prev_close: + return pd.DataFrame() + + # 计算涨跌幅 + stock_df['prev_close'] = stock_df['code'].map(prev_close) + stock_df = stock_df.dropna(subset=['prev_close']) + stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100 + + index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100 + index_map = dict(zip(index_df['timestamp'], index_df['change_pct'])) + + # 按时间聚合概念特征 + results = [] + for ts in sorted(stock_df['timestamp'].unique()): + ts_data = stock_df[stock_df['timestamp'] == ts] + idx_chg = index_map.get(ts, 0) + + stock_chg = dict(zip(ts_data['code'], ts_data['change_pct'])) + stock_amt = dict(zip(ts_data['code'], ts_data['amt'])) + + for cid, stocks in self.concept_stocks.items(): + changes = [stock_chg[s] for s in stocks if s in stock_chg] + amts = [stock_amt.get(s, 0) for s in stocks if s in stock_chg] + + if not changes: + continue + + alpha = np.mean(changes) - idx_chg + total_amt = sum(amts) + limit_up_ratio = sum(1 for c in changes if c >= CONFIG['limit_up_threshold']) / len(changes) + + results.append({ + 'concept_id': cid, + 'timestamp': ts, + 'time_slot': time_to_slot(ts), + 'alpha': alpha, + 'total_amt': total_amt, + 'limit_up_ratio': limit_up_ratio, + 'stock_count': len(changes), + }) + + if not results: + return pd.DataFrame() + + df = pd.DataFrame(results) + + # 计算排名 + for ts in df['timestamp'].unique(): + mask = df['timestamp'] == ts + df.loc[mask, 'rank_pct'] = df.loc[mask, 'alpha'].rank(pct=True) + + return df + + def _compute_zscore(self, concept_id: str, time_slot: str, alpha: float, total_amt: float, rank_pct: float) -> Optional[Dict]: + """计算 Z-Score""" + if concept_id not in self.baselines: + return None + + baseline = self.baselines[concept_id] + if time_slot not in baseline: + return None + + bl = baseline[time_slot] + + alpha_z = np.clip((alpha - bl['alpha_mean']) / bl['alpha_std'], -5, 5) + amt_z = np.clip((total_amt - bl['amt_mean']) / bl['amt_std'], -5, 5) + rank_z = np.clip((rank_pct - bl['rank_mean']) / bl['rank_std'], -5, 5) + + # 动量(基于 Z-Score 历史) + history = list(self.zscore_history[concept_id]) + mom_3m = 0.0 + mom_5m = 0.0 + + if len(history) >= 3: + recent = [h['alpha_zscore'] for h in history[-3:]] + older = [h['alpha_zscore'] for h in history[-6:-3]] if len(history) >= 6 else [history[0]['alpha_zscore']] + mom_3m = np.mean(recent) - np.mean(older) + + if len(history) >= 5: + recent = [h['alpha_zscore'] for h in history[-5:]] + older = [h['alpha_zscore'] for h in history[-10:-5]] if 
len(history) >= 10 else [history[0]['alpha_zscore']] + mom_5m = np.mean(recent) - np.mean(older) + + return { + 'alpha_zscore': float(alpha_z), + 'amt_zscore': float(amt_z), + 'rank_zscore': float(rank_z), + 'momentum_3m': float(mom_3m), + 'momentum_5m': float(mom_5m), + } + + @torch.no_grad() + def _ml_score(self, sequences: np.ndarray) -> np.ndarray: + """批量 ML 评分""" + if self.model is None or len(sequences) == 0: + return np.zeros(len(sequences)) + + x = torch.FloatTensor(sequences).to(self.device) + errors = self.model.compute_reconstruction_error(x, reduction='none') + last_errors = errors[:, -1].cpu().numpy() + + # 转换为 0-100 分数 + if self.thresholds: + p50 = self.thresholds.get('median', 0.001) + p99 = self.thresholds.get('p99', 0.05) + scores = 50 + (last_errors - p50) / (p99 - p50 + 1e-6) * 49 + else: + scores = last_errors * 1000 + + return np.clip(scores, 0, 100) + + def detect(self, trade_date: str = None) -> List[Dict]: + """检测指定日期的异动""" + trade_date = trade_date or datetime.now().strftime('%Y-%m-%d') + print(f"\n检测 {trade_date} 的异动...") + + # 重置状态 + self.zscore_history.clear() + self.anomaly_candidates.clear() + self.cooldown.clear() + + # 获取数据 + raw_df = self._get_realtime_data(trade_date) + if raw_df.empty: + print("无数据") + return [] + + timestamps = sorted(raw_df['timestamp'].unique()) + print(f"时间点数: {len(timestamps)}") + + all_alerts = [] + + for ts in timestamps: + ts_data = raw_df[raw_df['timestamp'] == ts] + time_slot = time_to_slot(ts) + + candidates = [] + + # 计算每个概念的 Z-Score + for _, row in ts_data.iterrows(): + cid = row['concept_id'] + + zscore = self._compute_zscore( + cid, time_slot, + row['alpha'], row['total_amt'], row['rank_pct'] + ) + + if zscore is None: + continue + + # 完整特征 + features = { + **zscore, + 'alpha': row['alpha'], + 'limit_up_ratio': row['limit_up_ratio'], + 'total_amt': row['total_amt'], + } + + # 更新历史 + self.zscore_history[cid].append(zscore) + + # 规则评分 + rule_score, triggered = score_rules_zscore(features) + + candidates.append((cid, features, rule_score, triggered)) + + if not candidates: + continue + + # 批量 ML 评分 + sequences = [] + valid_candidates = [] + + for cid, features, rule_score, triggered in candidates: + history = list(self.zscore_history[cid]) + if len(history) >= CONFIG['seq_len']: + seq = np.array([[h['alpha_zscore'], h['amt_zscore'], h['rank_zscore'], + h['momentum_3m'], h['momentum_5m'], features['limit_up_ratio']] + for h in history]) + sequences.append(seq) + valid_candidates.append((cid, features, rule_score, triggered)) + + if not sequences: + continue + + ml_scores = self._ml_score(np.array(sequences)) + + # 融合 + 确认 + for i, (cid, features, rule_score, triggered) in enumerate(valid_candidates): + ml_score = ml_scores[i] + final_score = CONFIG['rule_weight'] * rule_score + CONFIG['ml_weight'] * ml_score + + # 判断触发 + is_triggered = ( + rule_score >= CONFIG['rule_trigger'] or + ml_score >= CONFIG['ml_trigger'] or + final_score >= CONFIG['fusion_trigger'] + ) + + self.anomaly_candidates[cid].append((ts, final_score)) + + if not is_triggered: + continue + + # 冷却期 + if cid in self.cooldown: + if (ts - self.cooldown[cid]).total_seconds() < CONFIG['cooldown_minutes'] * 60: + continue + + # 持续确认 + recent = list(self.anomaly_candidates[cid]) + if len(recent) < CONFIG['confirm_window']: + continue + + exceed = sum(1 for _, s in recent if s >= CONFIG['fusion_trigger']) + ratio = exceed / len(recent) + + if ratio < CONFIG['confirm_ratio']: + continue + + # 确认异动! 
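+ # 确认机制(按当前 CONFIG:confirm_window=5, confirm_ratio=0.6, fusion_trigger=50):
+ # 即最近 5 个时间点中至少 3 个融合得分 >= 50,才会走到这里记录冷却并生成告警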
+ self.cooldown[cid] = ts + + alpha = features['alpha'] + alert_type = 'surge_up' if alpha >= 1.5 else 'surge_down' if alpha <= -1.5 else 'surge' + + concept_name = next((c['concept_name'] for c in self.concepts if c['concept_id'] == cid), cid) + + all_alerts.append({ + 'concept_id': cid, + 'concept_name': concept_name, + 'alert_time': ts, + 'trade_date': trade_date, + 'alert_type': alert_type, + 'final_score': float(final_score), + 'rule_score': float(rule_score), + 'ml_score': float(ml_score), + 'confirm_ratio': float(ratio), + 'alpha': float(alpha), + 'alpha_zscore': float(features['alpha_zscore']), + 'amt_zscore': float(features['amt_zscore']), + 'rank_zscore': float(features['rank_zscore']), + 'momentum_3m': float(features['momentum_3m']), + 'momentum_5m': float(features['momentum_5m']), + 'limit_up_ratio': float(features['limit_up_ratio']), + 'triggered_rules': triggered, + 'trigger_reason': f"融合({final_score:.0f})+确认({ratio:.0%})", + }) + + print(f"检测到 {len(all_alerts)} 个异动") + return all_alerts + + +# ==================== 数据库存储 ==================== + +def create_v2_table(): + """创建 V2 异动表(如果不存在)""" + engine = get_mysql_engine() + with engine.begin() as conn: + conn.execute(text(""" + CREATE TABLE IF NOT EXISTS concept_anomaly_v2 ( + id INT AUTO_INCREMENT PRIMARY KEY, + concept_id VARCHAR(50) NOT NULL, + alert_time DATETIME NOT NULL, + trade_date DATE NOT NULL, + alert_type VARCHAR(20) NOT NULL, + final_score FLOAT, + rule_score FLOAT, + ml_score FLOAT, + trigger_reason VARCHAR(200), + confirm_ratio FLOAT, + alpha FLOAT, + alpha_zscore FLOAT, + amt_zscore FLOAT, + rank_zscore FLOAT, + momentum_3m FLOAT, + momentum_5m FLOAT, + limit_up_ratio FLOAT, + triggered_rules TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE KEY uk_concept_time (concept_id, alert_time), + INDEX idx_trade_date (trade_date), + INDEX idx_alert_type (alert_type) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 + """)) + print("concept_anomaly_v2 表已就绪") + + +def save_alerts_to_db(alerts: List[Dict]) -> int: + """保存异动到数据库""" + if not alerts: + return 0 + + engine = get_mysql_engine() + saved = 0 + + with engine.begin() as conn: + for alert in alerts: + try: + insert_sql = text(""" + INSERT IGNORE INTO concept_anomaly_v2 + (concept_id, alert_time, trade_date, alert_type, + final_score, rule_score, ml_score, trigger_reason, confirm_ratio, + alpha, alpha_zscore, amt_zscore, rank_zscore, + momentum_3m, momentum_5m, limit_up_ratio, triggered_rules) + VALUES + (:concept_id, :alert_time, :trade_date, :alert_type, + :final_score, :rule_score, :ml_score, :trigger_reason, :confirm_ratio, + :alpha, :alpha_zscore, :amt_zscore, :rank_zscore, + :momentum_3m, :momentum_5m, :limit_up_ratio, :triggered_rules) + """) + + result = conn.execute(insert_sql, { + 'concept_id': alert['concept_id'], + 'alert_time': alert['alert_time'], + 'trade_date': alert['trade_date'], + 'alert_type': alert['alert_type'], + 'final_score': alert['final_score'], + 'rule_score': alert['rule_score'], + 'ml_score': alert['ml_score'], + 'trigger_reason': alert['trigger_reason'], + 'confirm_ratio': alert['confirm_ratio'], + 'alpha': alert['alpha'], + 'alpha_zscore': alert['alpha_zscore'], + 'amt_zscore': alert['amt_zscore'], + 'rank_zscore': alert['rank_zscore'], + 'momentum_3m': alert['momentum_3m'], + 'momentum_5m': alert['momentum_5m'], + 'limit_up_ratio': alert['limit_up_ratio'], + 'triggered_rules': json.dumps(alert.get('triggered_rules', []), ensure_ascii=False), + }) + + if result.rowcount > 0: + saved += 1 + except Exception as e: + 
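# 单条写入失败只打印日志并继续处理其余告警;重复记录由 INSERT IGNORE 配合唯一键 uk_concept_time 跳过
+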
print(f"保存失败: {alert['concept_id']} - {e}") + + return saved + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--date', type=str, default=None) + parser.add_argument('--no-save', action='store_true', help='不保存到数据库,只打印') + args = parser.parse_args() + + # 确保表存在 + if not args.no_save: + create_v2_table() + + detector = RealtimeDetectorV2() + alerts = detector.detect(args.date) + + print(f"\n{'='*60}") + print(f"检测结果 ({len(alerts)} 个异动)") + print('='*60) + + for a in alerts[:20]: + print(f"[{a['alert_time'].strftime('%H:%M') if hasattr(a['alert_time'], 'strftime') else a['alert_time']}] " + f"{a['concept_name']} | {a['alert_type']} | " + f"分数={a['final_score']:.0f} 确认={a['confirm_ratio']:.0%} " + f"α={a['alpha']:.2f}% αZ={a['alpha_zscore']:.1f}") + + if len(alerts) > 20: + print(f"... 共 {len(alerts)} 个") + + # 保存到数据库 + if not args.no_save and alerts: + saved = save_alerts_to_db(alerts) + print(f"\n✅ 已保存 {saved}/{len(alerts)} 条到 concept_anomaly_v2 表") + elif args.no_save: + print(f"\n⚠️ --no-save 模式,未保存到数据库") + + +if __name__ == "__main__": + main() diff --git a/ml/requirements.txt b/ml/requirements.txt new file mode 100644 index 00000000..7b052bb8 --- /dev/null +++ b/ml/requirements.txt @@ -0,0 +1,25 @@ +# 概念异动检测 ML 模块依赖 +# 安装: pip install -r ml/requirements.txt + +# PyTorch (根据 CUDA 版本选择) +# 5090 显卡需要 CUDA 12.x +# pip install torch --index-url https://download.pytorch.org/whl/cu124 +torch>=2.0.0 + +# 数据处理 +numpy>=1.24.0 +pandas>=2.0.0 +pyarrow>=14.0.0 + +# 数据库 +clickhouse-driver>=0.2.6 +elasticsearch>=7.0.0,<8.0.0 +sqlalchemy>=2.0.0 +pymysql>=1.1.0 + +# 训练工具 +tqdm>=4.65.0 + +# 可选: 可视化 +# matplotlib>=3.7.0 +# tensorboard>=2.14.0 diff --git a/ml/run_training.sh b/ml/run_training.sh new file mode 100644 index 00000000..7f7f2dcd --- /dev/null +++ b/ml/run_training.sh @@ -0,0 +1,99 @@ +#!/bin/bash +# 概念异动检测模型训练脚本 (Linux) +# +# 使用方法: +# chmod +x run_training.sh +# ./run_training.sh +# +# 或指定参数: +# ./run_training.sh --start 2022-01-01 --epochs 100 + +set -e + +echo "============================================================" +echo "概念异动检测模型训练流程" +echo "============================================================" +echo "" + +# 获取脚本所在目录 +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/.." + +echo "[1/4] 检查环境..." +python3 --version || { echo "Python3 未找到!"; exit 1; } + +# 检查 GPU +if python3 -c "import torch; print(f'CUDA: {torch.cuda.is_available()}')" 2>/dev/null; then + echo "PyTorch GPU 检测完成" +else + echo "警告: PyTorch 未安装或无法检测 GPU" +fi + +echo "" +echo "[2/4] 检查依赖..." +pip3 install -q torch pandas numpy pyarrow tqdm clickhouse-driver elasticsearch sqlalchemy pymysql + +echo "" +echo "[3/4] 准备训练数据..." +echo "从 ClickHouse 提取历史数据,这可能需要较长时间..." +echo "" + +# 解析参数 +START_DATE="2022-01-01" +END_DATE="" +EPOCHS=100 +BATCH_SIZE=256 +TRAIN_END="2025-06-30" +VAL_END="2025-09-30" + +while [[ $# -gt 0 ]]; do + case $1 in + --start) + START_DATE="$2" + shift 2 + ;; + --end) + END_DATE="$2" + shift 2 + ;; + --epochs) + EPOCHS="$2" + shift 2 + ;; + --batch_size) + BATCH_SIZE="$2" + shift 2 + ;; + --train_end) + TRAIN_END="$2" + shift 2 + ;; + --val_end) + VAL_END="$2" + shift 2 + ;; + *) + shift + ;; + esac +done + +# 数据准备 +if [ -n "$END_DATE" ]; then + python3 ml/prepare_data.py --start "$START_DATE" --end "$END_DATE" +else + python3 ml/prepare_data.py --start "$START_DATE" +fi + +echo "" +echo "[4/4] 训练模型..." +echo "使用 GPU 加速训练..." 
+echo "" + +python3 ml/train.py --epochs "$EPOCHS" --batch_size "$BATCH_SIZE" --train_end "$TRAIN_END" --val_end "$VAL_END" + +echo "" +echo "============================================================" +echo "训练完成!" +echo "模型保存在: ml/checkpoints/" +echo "============================================================" diff --git a/ml/train.py b/ml/train.py new file mode 100644 index 00000000..b93120ac --- /dev/null +++ b/ml/train.py @@ -0,0 +1,808 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Transformer Autoencoder 训练脚本 (修复版) + +修复问题: +1. 按概念分组构建序列,避免跨概念切片 +2. 按时间(日期)切分数据集,避免数据泄露 +3. 使用 RobustScaler + Clipping 处理非平稳性 +4. 使用验证集计算阈值 + +训练流程: +1. 加载预处理好的特征数据(parquet 文件) +2. 按概念分组,在每个概念内部构建序列 +3. 按日期划分训练/验证/测试集 +4. 训练 Autoencoder(最小化重构误差) +5. 保存模型和阈值 + +使用方法: + python train.py --data_dir ml/data --epochs 100 --batch_size 256 +""" + +import os +import sys +import argparse +import json +from datetime import datetime +from pathlib import Path +from typing import List, Tuple, Dict + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts +from tqdm import tqdm + +from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters + +# 性能优化:启用 cuDNN benchmark(对固定输入尺寸自动选择最快算法) +torch.backends.cudnn.benchmark = True +# 启用 TF32(RTX 30/40 系列特有,提速约 3 倍) +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + +# 可视化(可选) +try: + import matplotlib + matplotlib.use('Agg') # 无头模式,不需要显示器 + import matplotlib.pyplot as plt + HAS_MATPLOTLIB = True +except ImportError: + HAS_MATPLOTLIB = False + + +# ==================== 配置 ==================== + +TRAIN_CONFIG = { + # 数据配置 + 'seq_len': 30, # 输入序列长度(30分钟) + 'stride': 5, # 滑动窗口步长 + + # 时间切分(按日期) + 'train_end_date': '2024-06-30', # 训练集截止日期 + 'val_end_date': '2024-09-30', # 验证集截止日期(之后为测试集) + + # 特征配置 + 'features': [ + 'alpha', # 超额收益 + 'alpha_delta', # Alpha 变化率 + 'amt_ratio', # 成交额比率 + 'amt_delta', # 成交额变化率 + 'rank_pct', # Alpha 排名百分位 + 'limit_up_ratio', # 涨停比例 + ], + + # 训练配置(针对 4x RTX 4090 优化) + 'batch_size': 4096, # 256 -> 4096(大幅增加,充分利用显存) + 'epochs': 100, + 'learning_rate': 3e-4, # 1e-4 -> 3e-4(大 batch 需要更大学习率) + 'weight_decay': 1e-5, + 'gradient_clip': 1.0, + + # 早停配置 + 'patience': 10, + 'min_delta': 1e-6, + + # 模型配置(LSTM Autoencoder,简洁有效) + 'model': { + 'n_features': 6, + 'hidden_dim': 32, # LSTM 隐藏维度(小) + 'latent_dim': 4, # 瓶颈维度(非常小!关键) + 'num_layers': 1, # LSTM 层数 + 'dropout': 0.2, + 'bidirectional': True, # 双向编码器 + }, + + # 标准化配置 + 'use_instance_norm': True, # 模型内部使用 Instance Norm(推荐) + 'clip_value': 10.0, # 简单截断极端值 + + # 阈值配置 + 'threshold_percentiles': [90, 95, 99], +} + + +# ==================== 数据加载(修复版)==================== + +def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]: + """ + 按日期加载数据,返回 {date: DataFrame} 字典 + + 每个 DataFrame 包含该日所有概念的所有时间点数据 + """ + data_path = Path(data_dir) + parquet_files = sorted(data_path.glob("features_*.parquet")) + + if not parquet_files: + raise FileNotFoundError(f"未找到 parquet 文件: {data_dir}") + + print(f"找到 {len(parquet_files)} 个数据文件") + + date_data = {} + + for pf in tqdm(parquet_files, desc="加载数据"): + # 提取日期 + date = pf.stem.replace('features_', '') + + df = pd.read_parquet(pf) + + # 检查必要列 + required_cols = features + ['concept_id', 'timestamp'] + missing_cols = [c for c in required_cols if c not in df.columns] + if missing_cols: + print(f"警告: {date} 缺少列: 
{missing_cols}, 跳过") + continue + + date_data[date] = df + + print(f"成功加载 {len(date_data)} 天的数据") + return date_data + + +def split_data_by_date( + date_data: Dict[str, pd.DataFrame], + train_end: str, + val_end: str +) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]: + """ + 按日期严格划分数据集 + + - 训练集: <= train_end + - 验证集: train_end < date <= val_end + - 测试集: > val_end + """ + train_data = {} + val_data = {} + test_data = {} + + for date, df in date_data.items(): + if date <= train_end: + train_data[date] = df + elif date <= val_end: + val_data[date] = df + else: + test_data[date] = df + + print(f"数据集划分(按日期):") + print(f" 训练集: {len(train_data)} 天 (<= {train_end})") + print(f" 验证集: {len(val_data)} 天 ({train_end} ~ {val_end})") + print(f" 测试集: {len(test_data)} 天 (> {val_end})") + + return train_data, val_data, test_data + + +def build_sequences_by_concept( + date_data: Dict[str, pd.DataFrame], + features: List[str], + seq_len: int, + stride: int +) -> np.ndarray: + """ + 按概念分组构建序列(性能优化版) + + 使用 groupby 一次性分组,避免重复扫描大数组 + + 1. 将所有日期的数据合并 + 2. 使用 groupby 按 concept_id 分组 + 3. 在每个概念内部,按时间排序并滑动窗口 + 4. 合并所有序列 + """ + # 合并所有日期的数据 + all_dfs = [] + for date, df in sorted(date_data.items()): + df = df.copy() + df['date'] = date + all_dfs.append(df) + + if not all_dfs: + return np.array([]) + + combined = pd.concat(all_dfs, ignore_index=True) + + # 预先排序(按概念、日期、时间),这样 groupby 会更快 + combined = combined.sort_values(['concept_id', 'date', 'timestamp']) + + # 使用 groupby 一次性分组(性能关键!) + all_sequences = [] + grouped = combined.groupby('concept_id', sort=False) + n_concepts = len(grouped) + + for concept_id, concept_df in tqdm(grouped, desc="构建序列", total=n_concepts, leave=False): + # 已经排序过了,直接提取特征 + feature_data = concept_df[features].values + + # 处理缺失值 + feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0) + + # 在该概念内部滑动窗口 + n_points = len(feature_data) + for start in range(0, n_points - seq_len + 1, stride): + seq = feature_data[start:start + seq_len] + all_sequences.append(seq) + + if not all_sequences: + return np.array([]) + + sequences = np.array(all_sequences) + print(f" 构建序列: {len(sequences):,} 条 (来自 {n_concepts} 个概念)") + + return sequences + + +# ==================== 数据集 ==================== + +class SequenceDataset(Dataset): + """序列数据集(已经构建好的序列)""" + + def __init__(self, sequences: np.ndarray): + self.sequences = torch.FloatTensor(sequences) + + def __len__(self) -> int: + return len(self.sequences) + + def __getitem__(self, idx: int) -> torch.Tensor: + return self.sequences[idx] + + +# ==================== 训练器 ==================== + +class EarlyStopping: + """早停机制""" + + def __init__(self, patience: int = 10, min_delta: float = 1e-6): + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.best_loss = float('inf') + self.early_stop = False + + def __call__(self, val_loss: float) -> bool: + if val_loss < self.best_loss - self.min_delta: + self.best_loss = val_loss + self.counter = 0 + else: + self.counter += 1 + if self.counter >= self.patience: + self.early_stop = True + + return self.early_stop + + +class Trainer: + """模型训练器(支持 AMP 混合精度加速)""" + + def __init__( + self, + model: nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + config: Dict, + device: torch.device, + save_dir: str = 'ml/checkpoints' + ): + self.model = model.to(device) + self.train_loader = train_loader + self.val_loader = val_loader + self.config = config + self.device = device + self.save_dir = Path(save_dir) + 
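# 输出目录:best_model.pt / last_checkpoint.pt / training_history.json / training_curves.png 均写入此处
+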
self.save_dir.mkdir(parents=True, exist_ok=True) + + # 优化器 + self.optimizer = AdamW( + model.parameters(), + lr=config['learning_rate'], + weight_decay=config['weight_decay'] + ) + + # 学习率调度器 + self.scheduler = CosineAnnealingWarmRestarts( + self.optimizer, + T_0=10, + T_mult=2, + eta_min=1e-6 + ) + + # 损失函数(简化版,只用 MSE) + self.criterion = AnomalyDetectionLoss() + + # 早停 + self.early_stopping = EarlyStopping( + patience=config['patience'], + min_delta=config['min_delta'] + ) + + # AMP 混合精度训练(大幅提速 + 省显存) + self.use_amp = torch.cuda.is_available() + self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None + if self.use_amp: + print(" ✓ 启用 AMP 混合精度训练") + + # 训练历史 + self.history = { + 'train_loss': [], + 'val_loss': [], + 'learning_rate': [], + } + + self.best_val_loss = float('inf') + + def train_epoch(self) -> float: + """训练一个 epoch(使用 AMP 混合精度)""" + self.model.train() + total_loss = 0.0 + n_batches = 0 + + pbar = tqdm(self.train_loader, desc="Training", leave=False) + for batch in pbar: + batch = batch.to(self.device, non_blocking=True) # 异步传输 + + self.optimizer.zero_grad(set_to_none=True) # 更快的梯度清零 + + # AMP 混合精度前向传播 + if self.use_amp: + with torch.cuda.amp.autocast(): + output, latent = self.model(batch) + loss, loss_dict = self.criterion(output, batch, latent) + + # AMP 反向传播 + self.scaler.scale(loss).backward() + + # 梯度裁剪(需要 unscale) + self.scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.config['gradient_clip'] + ) + + self.scaler.step(self.optimizer) + self.scaler.update() + else: + # 非 AMP 模式 + output, latent = self.model(batch) + loss, loss_dict = self.criterion(output, batch, latent) + + loss.backward() + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.config['gradient_clip'] + ) + self.optimizer.step() + + total_loss += loss.item() + n_batches += 1 + + pbar.set_postfix({'loss': f"{loss.item():.4f}"}) + + return total_loss / n_batches + + @torch.no_grad() + def validate(self) -> float: + """验证(使用 AMP)""" + self.model.eval() + total_loss = 0.0 + n_batches = 0 + + for batch in self.val_loader: + batch = batch.to(self.device, non_blocking=True) + + if self.use_amp: + with torch.cuda.amp.autocast(): + output, latent = self.model(batch) + loss, _ = self.criterion(output, batch, latent) + else: + output, latent = self.model(batch) + loss, _ = self.criterion(output, batch, latent) + + total_loss += loss.item() + n_batches += 1 + + return total_loss / n_batches + + def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False): + """保存检查点""" + # 处理 DataParallel 包装 + model_to_save = self.model.module if hasattr(self.model, 'module') else self.model + + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model_to_save.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'scheduler_state_dict': self.scheduler.state_dict(), + 'val_loss': val_loss, + 'config': self.config, + } + + # 保存最新检查点 + torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt') + + # 保存最佳模型 + if is_best: + torch.save(checkpoint, self.save_dir / 'best_model.pt') + print(f" ✓ 保存最佳模型 (val_loss: {val_loss:.6f})") + + def train(self, epochs: int): + """完整训练流程""" + print(f"\n开始训练 ({epochs} epochs)...") + print(f"设备: {self.device}") + print(f"模型参数量: {count_parameters(self.model):,}") + + for epoch in range(1, epochs + 1): + print(f"\nEpoch {epoch}/{epochs}") + + # 训练 + train_loss = self.train_epoch() + + # 验证 + val_loss = self.validate() + + # 更新学习率 + self.scheduler.step() + current_lr = 
self.optimizer.param_groups[0]['lr'] + + # 记录历史 + self.history['train_loss'].append(train_loss) + self.history['val_loss'].append(val_loss) + self.history['learning_rate'].append(current_lr) + + # 打印进度 + print(f" Train Loss: {train_loss:.6f}") + print(f" Val Loss: {val_loss:.6f}") + print(f" LR: {current_lr:.2e}") + + # 保存检查点 + is_best = val_loss < self.best_val_loss + if is_best: + self.best_val_loss = val_loss + self.save_checkpoint(epoch, val_loss, is_best) + + # 早停检查 + if self.early_stopping(val_loss): + print(f"\n早停触发!验证损失已 {self.early_stopping.patience} 个 epoch 未改善") + break + + print(f"\n训练完成!最佳验证损失: {self.best_val_loss:.6f}") + + # 保存训练历史 + self.save_history() + + return self.history + + def save_history(self): + """保存训练历史""" + history_path = self.save_dir / 'training_history.json' + with open(history_path, 'w') as f: + json.dump(self.history, f, indent=2) + print(f"训练历史已保存: {history_path}") + + # 绘制训练曲线 + self.plot_training_curves() + + def plot_training_curves(self): + """绘制训练曲线""" + if not HAS_MATPLOTLIB: + print("matplotlib 未安装,跳过绘图") + return + + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + + epochs = range(1, len(self.history['train_loss']) + 1) + + # 1. Loss 曲线 + ax1 = axes[0] + ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2) + ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2) + ax1.set_xlabel('Epoch', fontsize=12) + ax1.set_ylabel('Loss', fontsize=12) + ax1.set_title('Training & Validation Loss', fontsize=14) + ax1.legend(fontsize=11) + ax1.grid(True, alpha=0.3) + + # 标记最佳点 + best_epoch = np.argmin(self.history['val_loss']) + 1 + best_val_loss = min(self.history['val_loss']) + ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7, label=f'Best Epoch: {best_epoch}') + ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5) + ax1.annotate(f'Best: {best_val_loss:.6f}', xy=(best_epoch, best_val_loss), + xytext=(best_epoch + 2, best_val_loss + 0.0005), + fontsize=10, color='green') + + # 2. 
学习率曲线 + ax2 = axes[1] + ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2) + ax2.set_xlabel('Epoch', fontsize=12) + ax2.set_ylabel('Learning Rate', fontsize=12) + ax2.set_title('Learning Rate Schedule', fontsize=14) + ax2.set_yscale('log') + ax2.grid(True, alpha=0.3) + + plt.tight_layout() + + # 保存图片 + plot_path = self.save_dir / 'training_curves.png' + plt.savefig(plot_path, dpi=150, bbox_inches='tight') + plt.close() + print(f"训练曲线已保存: {plot_path}") + + +# ==================== 阈值计算(使用验证集)==================== + +@torch.no_grad() +def compute_thresholds( + model: nn.Module, + data_loader: DataLoader, + device: torch.device, + percentiles: List[float] = [90, 95, 99] +) -> Dict[str, float]: + """ + 在验证集上计算重构误差的百分位数阈值 + + 注:使用验证集而非测试集,避免数据泄露 + """ + model.eval() + all_errors = [] + + print("计算异动阈值(使用验证集)...") + for batch in tqdm(data_loader, desc="Computing thresholds"): + batch = batch.to(device) + errors = model.compute_reconstruction_error(batch, reduction='none') + + # 取每个序列的最后一个时刻误差(预测当前时刻) + seq_errors = errors[:, -1] # (batch,) + all_errors.append(seq_errors.cpu().numpy()) + + all_errors = np.concatenate(all_errors) + + thresholds = {} + for p in percentiles: + threshold = np.percentile(all_errors, p) + thresholds[f'p{p}'] = float(threshold) + print(f" P{p}: {threshold:.6f}") + + # 额外统计 + thresholds['mean'] = float(np.mean(all_errors)) + thresholds['std'] = float(np.std(all_errors)) + thresholds['median'] = float(np.median(all_errors)) + + print(f" Mean: {thresholds['mean']:.6f}") + print(f" Median: {thresholds['median']:.6f}") + print(f" Std: {thresholds['std']:.6f}") + + return thresholds + + +# ==================== 主函数 ==================== + +def main(): + parser = argparse.ArgumentParser(description='训练概念异动检测模型') + parser.add_argument('--data_dir', type=str, default='ml/data', + help='数据目录路径') + parser.add_argument('--epochs', type=int, default=100, + help='训练轮数') + parser.add_argument('--batch_size', type=int, default=4096, + help='批次大小(4x RTX 4090 推荐 4096~8192)') + parser.add_argument('--lr', type=float, default=3e-4, + help='学习率(大 batch 推荐 3e-4)') + parser.add_argument('--device', type=str, default='auto', + help='设备 (auto/cuda/cpu)') + parser.add_argument('--save_dir', type=str, default='ml/checkpoints', + help='模型保存目录') + parser.add_argument('--train_end', type=str, default='2024-06-30', + help='训练集截止日期') + parser.add_argument('--val_end', type=str, default='2024-09-30', + help='验证集截止日期') + + args = parser.parse_args() + + # 更新配置 + config = TRAIN_CONFIG.copy() + config['batch_size'] = args.batch_size + config['epochs'] = args.epochs + config['learning_rate'] = args.lr + config['train_end_date'] = args.train_end + config['val_end_date'] = args.val_end + + # 设备选择 + if args.device == 'auto': + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + device = torch.device(args.device) + + print("=" * 60) + print("概念异动检测模型训练(修复版)") + print("=" * 60) + print(f"配置:") + print(f" 数据目录: {args.data_dir}") + print(f" 设备: {device}") + print(f" 批次大小: {config['batch_size']}") + print(f" 学习率: {config['learning_rate']}") + print(f" 训练轮数: {config['epochs']}") + print(f" 训练集截止: {config['train_end_date']}") + print(f" 验证集截止: {config['val_end_date']}") + print("=" * 60) + + # 1. 按日期加载数据 + print("\n[1/6] 加载数据...") + date_data = load_data_by_date(args.data_dir, config['features']) + + # 2. 按日期划分 + print("\n[2/6] 按日期划分数据集...") + train_data, val_data, test_data = split_data_by_date( + date_data, + config['train_end_date'], + config['val_end_date'] + ) + + # 3. 
按概念构建序列 + print("\n[3/6] 按概念构建序列...") + print("训练集:") + train_sequences = build_sequences_by_concept( + train_data, config['features'], config['seq_len'], config['stride'] + ) + print("验证集:") + val_sequences = build_sequences_by_concept( + val_data, config['features'], config['seq_len'], config['stride'] + ) + print("测试集:") + test_sequences = build_sequences_by_concept( + test_data, config['features'], config['seq_len'], config['stride'] + ) + + if len(train_sequences) == 0: + print("错误: 训练集为空!请检查数据和日期范围") + return + + # 4. 数据预处理(简单截断极端值,标准化在模型内部通过 Instance Norm 完成) + print("\n[4/6] 数据预处理...") + print(" 注意: 使用 Instance Norm,每个序列在模型内部单独标准化") + print(" 这样可以处理不同概念波动率差异(银行 vs 半导体)") + + clip_value = config['clip_value'] + print(f" 截断极端值: ±{clip_value}") + + # 简单截断极端值(防止异常数据影响训练) + train_sequences = np.clip(train_sequences, -clip_value, clip_value) + if len(val_sequences) > 0: + val_sequences = np.clip(val_sequences, -clip_value, clip_value) + if len(test_sequences) > 0: + test_sequences = np.clip(test_sequences, -clip_value, clip_value) + + # 保存配置 + save_dir = Path(args.save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + preprocess_params = { + 'features': config['features'], + 'normalization': 'instance_norm', # 在模型内部完成 + 'clip_value': clip_value, + 'note': '标准化在模型内部通过 InstanceNorm1d 完成,无需外部 Scaler' + } + + with open(save_dir / 'normalization_stats.json', 'w') as f: + json.dump(preprocess_params, f, indent=2) + print(f" 预处理参数已保存") + + # 5. 创建数据集和加载器 + print("\n[5/6] 创建数据加载器...") + train_dataset = SequenceDataset(train_sequences) + val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None + test_dataset = SequenceDataset(test_sequences) if len(test_sequences) > 0 else None + + print(f" 训练序列: {len(train_dataset):,}") + print(f" 验证序列: {len(val_dataset) if val_dataset else 0:,}") + print(f" 测试序列: {len(test_dataset) if test_dataset else 0:,}") + + # 多卡时增加 num_workers(Linux 上可以用更多) + n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1 + num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0 + print(f" DataLoader workers: {num_workers}") + print(f" Batch size: {config['batch_size']}") + + # 大 batch + 多 worker + prefetch 提速 + train_loader = DataLoader( + train_dataset, + batch_size=config['batch_size'], + shuffle=True, + num_workers=num_workers, + pin_memory=True, + prefetch_factor=4 if num_workers > 0 else None, # 预取更多 batch + persistent_workers=True if num_workers > 0 else False, # 保持 worker 存活 + drop_last=True # 丢弃不完整的最后一批,避免 batch 大小不一致 + ) + + val_loader = DataLoader( + val_dataset, + batch_size=config['batch_size'] * 2, # 验证时可以用更大 batch(无梯度) + shuffle=False, + num_workers=num_workers, + pin_memory=True, + prefetch_factor=4 if num_workers > 0 else None, + persistent_workers=True if num_workers > 0 else False, + ) if val_dataset else None + + test_loader = DataLoader( + test_dataset, + batch_size=config['batch_size'] * 2, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + prefetch_factor=4 if num_workers > 0 else None, + persistent_workers=True if num_workers > 0 else False, + ) if test_dataset else None + + # 6. 
训练
+    print("\n[6/6] 训练模型...")
+    model_config = config['model'].copy()
+    model = TransformerAutoencoder(**model_config)
+
+    # 多卡并行
+    if torch.cuda.device_count() > 1:
+        print(f"  使用 {torch.cuda.device_count()} 张 GPU 并行训练")
+        model = nn.DataParallel(model)
+
+    if val_loader is None:
+        print("警告: 验证集为空,将使用训练集的一部分作为验证")
+        # 简单处理:用训练集的后 10% 作为验证
+        split_idx = int(len(train_dataset) * 0.9)
+        train_subset = torch.utils.data.Subset(train_dataset, range(split_idx))
+        val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset)))
+
+        train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True)
+        val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True)
+
+    trainer = Trainer(
+        model=model,
+        train_loader=train_loader,
+        val_loader=val_loader,
+        config=config,
+        device=device,
+        save_dir=args.save_dir
+    )
+
+    history = trainer.train(config['epochs'])
+
+    # 7. 计算阈值(使用验证集)
+    print("\n[额外] 计算异动阈值...")
+
+    # 加载最佳模型
+    best_checkpoint = torch.load(
+        save_dir / 'best_model.pt',
+        map_location=device
+    )
+    # 保存时已解包 DataParallel('model_state_dict' 为裸模型权重),
+    # 这里同样解包后再加载,并用裸模型计算阈值;否则多卡下会因键名不匹配、
+    # 以及 DataParallel 不暴露 compute_reconstruction_error 方法而出错
+    eval_model = model.module if hasattr(model, 'module') else model
+    eval_model.load_state_dict(best_checkpoint['model_state_dict'])
+    eval_model.to(device)
+
+    # 使用验证集计算阈值(避免数据泄露)
+    thresholds = compute_thresholds(
+        eval_model,
+        val_loader,
+        device,
+        config['threshold_percentiles']
+    )
+
+    # 保存阈值
+    with open(save_dir / 'thresholds.json', 'w') as f:
+        json.dump(thresholds, f, indent=2)
+    print(f"阈值已保存")
+
+    # 保存完整配置
+    with open(save_dir / 'config.json', 'w') as f:
+        json.dump(config, f, indent=2)
+
+    print("\n" + "=" * 60)
+    print("训练完成!")
+    print("=" * 60)
+    print(f"模型保存位置: {args.save_dir}")
+    print(f"  - best_model.pt: 最佳模型权重")
+    print(f"  - thresholds.json: 异动阈值")
+    print(f"  - normalization_stats.json: 标准化参数")
+    print(f"  - config.json: 训练配置")
+    print("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ml/train_v2.py b/ml/train_v2.py
new file mode 100644
index 00000000..47e6cdb7
--- /dev/null
+++ b/ml/train_v2.py
@@ -0,0 +1,622 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+训练脚本 V2 - 基于 Z-Score 特征的 LSTM Autoencoder
+
+改进点:
+1. 使用 Z-Score 特征(相对于同时间片历史的偏离)
+2. 短序列:10分钟(不需要30分钟预热)
+3. 开盘即可检测:9:30 直接有特征
+
+模型输入:
+- 过去10分钟的 Z-Score 特征序列
+- 特征:alpha_zscore, amt_zscore, rank_zscore, momentum_3m, momentum_5m, limit_up_ratio
+
+模型学习:
+- 学习 Z-Score 序列的"正常演化模式"
+- 异动 = Z-Score 序列的异常演化(重构误差大)
+"""
+
+import os
+import sys
+import argparse
+import json
+from datetime import datetime
+from pathlib import Path
+from typing import List, Tuple, Dict
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset, DataLoader
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
+from tqdm import tqdm
+
+from model import TransformerAutoencoder, AnomalyDetectionLoss, count_parameters
+
+# 性能优化
+torch.backends.cudnn.benchmark = True
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+try:
+    import matplotlib
+    matplotlib.use('Agg')
+    import matplotlib.pyplot as plt
+    HAS_MATPLOTLIB = True
+except ImportError:
+    HAS_MATPLOTLIB = False
+
+
+# ==================== 配置 ====================
+
+TRAIN_CONFIG = {
+    # 数据配置(改进!)
+    'seq_len': 10,  # 10分钟序列(不是30分钟!)
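+    #   (补充说明)Z-Score 按"同一时间片的历史分布"定义:
+    #       z_t = (x_t - mean_hist(slot_t)) / std_hist(slot_t)
+    #   slot_t 为分钟级时间片,统计量取自最近 N 个交易日(见 update_baseline.py),
+    #   因此 9:30 的第一分钟即有定义良好的特征,无需预热窗口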
+ 'stride': 2, # 步长2分钟 + + # 时间切分 + 'train_end_date': '2024-06-30', + 'val_end_date': '2024-09-30', + + # V2 特征(Z-Score 为主) + 'features': [ + 'alpha_zscore', # Alpha 的 Z-Score + 'amt_zscore', # 成交额的 Z-Score + 'rank_zscore', # 排名的 Z-Score + 'momentum_3m', # 3分钟动量 + 'momentum_5m', # 5分钟动量 + 'limit_up_ratio', # 涨停占比 + ], + + # 训练配置 + 'batch_size': 4096, + 'epochs': 100, + 'learning_rate': 3e-4, + 'weight_decay': 1e-5, + 'gradient_clip': 1.0, + + # 早停配置 + 'patience': 15, + 'min_delta': 1e-6, + + # 模型配置(小型 LSTM) + 'model': { + 'n_features': 6, + 'hidden_dim': 32, + 'latent_dim': 4, + 'num_layers': 1, + 'dropout': 0.2, + 'bidirectional': True, + }, + + # 标准化配置 + 'clip_value': 5.0, # Z-Score 已经标准化,clip 5.0 足够 + + # 阈值配置 + 'threshold_percentiles': [90, 95, 99], +} + + +# ==================== 数据加载 ==================== + +def load_data_by_date(data_dir: str, features: List[str]) -> Dict[str, pd.DataFrame]: + """按日期加载 V2 数据""" + data_path = Path(data_dir) + parquet_files = sorted(data_path.glob("features_v2_*.parquet")) + + if not parquet_files: + raise FileNotFoundError(f"未找到 V2 数据文件: {data_dir}") + + print(f"找到 {len(parquet_files)} 个 V2 数据文件") + + date_data = {} + + for pf in tqdm(parquet_files, desc="加载数据"): + date = pf.stem.replace('features_v2_', '') + + df = pd.read_parquet(pf) + + required_cols = features + ['concept_id', 'timestamp'] + missing_cols = [c for c in required_cols if c not in df.columns] + if missing_cols: + print(f"警告: {date} 缺少列: {missing_cols}, 跳过") + continue + + date_data[date] = df + + print(f"成功加载 {len(date_data)} 天的数据") + return date_data + + +def split_data_by_date( + date_data: Dict[str, pd.DataFrame], + train_end: str, + val_end: str +) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]: + """按日期划分数据集""" + train_data = {} + val_data = {} + test_data = {} + + for date, df in date_data.items(): + if date <= train_end: + train_data[date] = df + elif date <= val_end: + val_data[date] = df + else: + test_data[date] = df + + print(f"数据集划分:") + print(f" 训练集: {len(train_data)} 天 (<= {train_end})") + print(f" 验证集: {len(val_data)} 天 ({train_end} ~ {val_end})") + print(f" 测试集: {len(test_data)} 天 (> {val_end})") + + return train_data, val_data, test_data + + +def build_sequences_by_concept( + date_data: Dict[str, pd.DataFrame], + features: List[str], + seq_len: int, + stride: int +) -> np.ndarray: + """按概念分组构建序列""" + all_dfs = [] + for date, df in sorted(date_data.items()): + df = df.copy() + df['date'] = date + all_dfs.append(df) + + if not all_dfs: + return np.array([]) + + combined = pd.concat(all_dfs, ignore_index=True) + combined = combined.sort_values(['concept_id', 'date', 'timestamp']) + + all_sequences = [] + grouped = combined.groupby('concept_id', sort=False) + n_concepts = len(grouped) + + for concept_id, concept_df in tqdm(grouped, desc="构建序列", total=n_concepts, leave=False): + feature_data = concept_df[features].values + feature_data = np.nan_to_num(feature_data, nan=0.0, posinf=0.0, neginf=0.0) + + n_points = len(feature_data) + for start in range(0, n_points - seq_len + 1, stride): + seq = feature_data[start:start + seq_len] + all_sequences.append(seq) + + if not all_sequences: + return np.array([]) + + sequences = np.array(all_sequences) + print(f" 构建序列: {len(sequences):,} 条 (来自 {n_concepts} 个概念)") + + return sequences + + +# ==================== 数据集 ==================== + +class SequenceDataset(Dataset): + def __init__(self, sequences: np.ndarray): + self.sequences = torch.FloatTensor(sequences) + + def __len__(self) -> int: + 
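+        # 序列总数;自编码器以输入序列自身为重构目标,因此 Dataset 无需返回标签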
return len(self.sequences) + + def __getitem__(self, idx: int) -> torch.Tensor: + return self.sequences[idx] + + +# ==================== 训练器 ==================== + +class EarlyStopping: + def __init__(self, patience: int = 10, min_delta: float = 1e-6): + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.best_loss = float('inf') + self.early_stop = False + + def __call__(self, val_loss: float) -> bool: + if val_loss < self.best_loss - self.min_delta: + self.best_loss = val_loss + self.counter = 0 + else: + self.counter += 1 + if self.counter >= self.patience: + self.early_stop = True + return self.early_stop + + +class Trainer: + def __init__( + self, + model: nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + config: Dict, + device: torch.device, + save_dir: str = 'ml/checkpoints_v2' + ): + self.model = model.to(device) + self.train_loader = train_loader + self.val_loader = val_loader + self.config = config + self.device = device + self.save_dir = Path(save_dir) + self.save_dir.mkdir(parents=True, exist_ok=True) + + self.optimizer = AdamW( + model.parameters(), + lr=config['learning_rate'], + weight_decay=config['weight_decay'] + ) + + self.scheduler = CosineAnnealingWarmRestarts( + self.optimizer, T_0=10, T_mult=2, eta_min=1e-6 + ) + + self.criterion = AnomalyDetectionLoss() + + self.early_stopping = EarlyStopping( + patience=config['patience'], + min_delta=config['min_delta'] + ) + + self.use_amp = torch.cuda.is_available() + self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None + if self.use_amp: + print(" ✓ 启用 AMP 混合精度训练") + + self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []} + self.best_val_loss = float('inf') + + def train_epoch(self) -> float: + self.model.train() + total_loss = 0.0 + n_batches = 0 + + pbar = tqdm(self.train_loader, desc="Training", leave=False) + for batch in pbar: + batch = batch.to(self.device, non_blocking=True) + self.optimizer.zero_grad(set_to_none=True) + + if self.use_amp: + with torch.cuda.amp.autocast(): + output, latent = self.model(batch) + loss, _ = self.criterion(output, batch, latent) + + self.scaler.scale(loss).backward() + self.scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip']) + self.scaler.step(self.optimizer) + self.scaler.update() + else: + output, latent = self.model(batch) + loss, _ = self.criterion(output, batch, latent) + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['gradient_clip']) + self.optimizer.step() + + total_loss += loss.item() + n_batches += 1 + pbar.set_postfix({'loss': f"{loss.item():.4f}"}) + + return total_loss / n_batches + + @torch.no_grad() + def validate(self) -> float: + self.model.eval() + total_loss = 0.0 + n_batches = 0 + + for batch in self.val_loader: + batch = batch.to(self.device, non_blocking=True) + + if self.use_amp: + with torch.cuda.amp.autocast(): + output, latent = self.model(batch) + loss, _ = self.criterion(output, batch, latent) + else: + output, latent = self.model(batch) + loss, _ = self.criterion(output, batch, latent) + + total_loss += loss.item() + n_batches += 1 + + return total_loss / n_batches + + def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False): + model_to_save = self.model.module if hasattr(self.model, 'module') else self.model + + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model_to_save.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 
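+            # 注:'model_state_dict' 保存的是解包后(model.module)的裸模型权重,
+            # 加载时可直接用单卡模型,无需先包一层 DataParallel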
'scheduler_state_dict': self.scheduler.state_dict(), + 'val_loss': val_loss, + 'config': self.config, + } + + torch.save(checkpoint, self.save_dir / 'last_checkpoint.pt') + + if is_best: + torch.save(checkpoint, self.save_dir / 'best_model.pt') + print(f" ✓ 保存最佳模型 (val_loss: {val_loss:.6f})") + + def train(self, epochs: int): + print(f"\n开始训练 ({epochs} epochs)...") + print(f"设备: {self.device}") + print(f"模型参数量: {count_parameters(self.model):,}") + + for epoch in range(1, epochs + 1): + print(f"\nEpoch {epoch}/{epochs}") + + train_loss = self.train_epoch() + val_loss = self.validate() + + self.scheduler.step() + current_lr = self.optimizer.param_groups[0]['lr'] + + self.history['train_loss'].append(train_loss) + self.history['val_loss'].append(val_loss) + self.history['learning_rate'].append(current_lr) + + print(f" Train Loss: {train_loss:.6f}") + print(f" Val Loss: {val_loss:.6f}") + print(f" LR: {current_lr:.2e}") + + is_best = val_loss < self.best_val_loss + if is_best: + self.best_val_loss = val_loss + self.save_checkpoint(epoch, val_loss, is_best) + + if self.early_stopping(val_loss): + print(f"\n早停触发!") + break + + print(f"\n训练完成!最佳验证损失: {self.best_val_loss:.6f}") + self.save_history() + + return self.history + + def save_history(self): + history_path = self.save_dir / 'training_history.json' + with open(history_path, 'w') as f: + json.dump(self.history, f, indent=2) + print(f"训练历史已保存: {history_path}") + + if HAS_MATPLOTLIB: + self.plot_training_curves() + + def plot_training_curves(self): + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + epochs = range(1, len(self.history['train_loss']) + 1) + + ax1 = axes[0] + ax1.plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2) + ax1.plot(epochs, self.history['val_loss'], 'r-', label='Val Loss', linewidth=2) + ax1.set_xlabel('Epoch') + ax1.set_ylabel('Loss') + ax1.set_title('Training & Validation Loss (V2)') + ax1.legend() + ax1.grid(True, alpha=0.3) + + best_epoch = np.argmin(self.history['val_loss']) + 1 + best_val_loss = min(self.history['val_loss']) + ax1.axvline(x=best_epoch, color='g', linestyle='--', alpha=0.7) + ax1.scatter([best_epoch], [best_val_loss], color='g', s=100, zorder=5) + + ax2 = axes[1] + ax2.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2) + ax2.set_xlabel('Epoch') + ax2.set_ylabel('Learning Rate') + ax2.set_title('Learning Rate Schedule') + ax2.set_yscale('log') + ax2.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(self.save_dir / 'training_curves.png', dpi=150, bbox_inches='tight') + plt.close() + print(f"训练曲线已保存") + + +# ==================== 阈值计算 ==================== + +@torch.no_grad() +def compute_thresholds( + model: nn.Module, + data_loader: DataLoader, + device: torch.device, + percentiles: List[float] = [90, 95, 99] +) -> Dict[str, float]: + """在验证集上计算阈值""" + model.eval() + all_errors = [] + + print("计算异动阈值...") + for batch in tqdm(data_loader, desc="Computing thresholds"): + batch = batch.to(device) + errors = model.compute_reconstruction_error(batch, reduction='none') + seq_errors = errors[:, -1] # 最后一个时刻 + all_errors.append(seq_errors.cpu().numpy()) + + all_errors = np.concatenate(all_errors) + + thresholds = {} + for p in percentiles: + threshold = np.percentile(all_errors, p) + thresholds[f'p{p}'] = float(threshold) + print(f" P{p}: {threshold:.6f}") + + thresholds['mean'] = float(np.mean(all_errors)) + thresholds['std'] = float(np.std(all_errors)) + thresholds['median'] = float(np.median(all_errors)) + + return thresholds + + +# ==================== 主函数 
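====================

+# 假设性示例(不属于训练流程):检测侧如何消费 thresholds.json,
+# 键名与 compute_thresholds 的输出一致(p90/p95/p99)
+def classify_error(error: float, thresholds: Dict[str, float]) -> str:
+    """按百分位阈值对单条序列的重构误差分级。"""
+    if error > thresholds['p99']:
+        return 'strong'    # 强异动
+    if error > thresholds['p95']:
+        return 'medium'
+    if error > thresholds['p90']:
+        return 'weak'
+    return 'normal'
+
+
+# ==================== 主函数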
==================== + +def main(): + parser = argparse.ArgumentParser(description='训练 V2 模型') + parser.add_argument('--data_dir', type=str, default='ml/data_v2', help='V2 数据目录') + parser.add_argument('--epochs', type=int, default=100) + parser.add_argument('--batch_size', type=int, default=4096) + parser.add_argument('--lr', type=float, default=3e-4) + parser.add_argument('--device', type=str, default='auto') + parser.add_argument('--save_dir', type=str, default='ml/checkpoints_v2') + parser.add_argument('--train_end', type=str, default='2024-06-30') + parser.add_argument('--val_end', type=str, default='2024-09-30') + parser.add_argument('--seq_len', type=int, default=10, help='序列长度(分钟)') + + args = parser.parse_args() + + config = TRAIN_CONFIG.copy() + config['batch_size'] = args.batch_size + config['epochs'] = args.epochs + config['learning_rate'] = args.lr + config['train_end_date'] = args.train_end + config['val_end_date'] = args.val_end + config['seq_len'] = args.seq_len + + if args.device == 'auto': + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + device = torch.device(args.device) + + print("=" * 60) + print("概念异动检测模型训练 V2(Z-Score 特征)") + print("=" * 60) + print(f"数据目录: {args.data_dir}") + print(f"设备: {device}") + print(f"序列长度: {config['seq_len']} 分钟") + print(f"批次大小: {config['batch_size']}") + print(f"特征: {config['features']}") + print("=" * 60) + + # 1. 加载数据 + print("\n[1/6] 加载 V2 数据...") + date_data = load_data_by_date(args.data_dir, config['features']) + + # 2. 划分数据集 + print("\n[2/6] 划分数据集...") + train_data, val_data, test_data = split_data_by_date( + date_data, config['train_end_date'], config['val_end_date'] + ) + + # 3. 构建序列 + print("\n[3/6] 构建序列...") + print("训练集:") + train_sequences = build_sequences_by_concept( + train_data, config['features'], config['seq_len'], config['stride'] + ) + print("验证集:") + val_sequences = build_sequences_by_concept( + val_data, config['features'], config['seq_len'], config['stride'] + ) + + if len(train_sequences) == 0: + print("错误: 训练集为空!") + return + + # 4. 预处理 + print("\n[4/6] 数据预处理...") + clip_value = config['clip_value'] + print(f" Z-Score 特征已标准化,截断: ±{clip_value}") + + train_sequences = np.clip(train_sequences, -clip_value, clip_value) + if len(val_sequences) > 0: + val_sequences = np.clip(val_sequences, -clip_value, clip_value) + + # 保存配置 + save_dir = Path(args.save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + with open(save_dir / 'config.json', 'w') as f: + json.dump(config, f, indent=2) + + # 5. 创建数据加载器 + print("\n[5/6] 创建数据加载器...") + train_dataset = SequenceDataset(train_sequences) + val_dataset = SequenceDataset(val_sequences) if len(val_sequences) > 0 else None + + print(f" 训练序列: {len(train_dataset):,}") + print(f" 验证序列: {len(val_dataset) if val_dataset else 0:,}") + + n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1 + num_workers = min(32, 8 * n_gpus) if sys.platform != 'win32' else 0 + + train_loader = DataLoader( + train_dataset, + batch_size=config['batch_size'], + shuffle=True, + num_workers=num_workers, + pin_memory=True, + prefetch_factor=4 if num_workers > 0 else None, + persistent_workers=True if num_workers > 0 else False, + drop_last=True + ) + + val_loader = DataLoader( + val_dataset, + batch_size=config['batch_size'] * 2, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) if val_dataset else None + + # 6. 
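训练。开始前可以先做一次批次形状自检(假设性示意,非脚本内容):
+    #
+    #     batch = next(iter(train_loader))
+    #     assert batch.shape[1:] == (config['seq_len'], len(config['features']))  # 如 (B, 10, 6)

+    # 6. 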
训练 + print("\n[6/6] 训练模型...") + model = TransformerAutoencoder(**config['model']) + + if torch.cuda.device_count() > 1: + print(f" 使用 {torch.cuda.device_count()} 张 GPU 并行训练") + model = nn.DataParallel(model) + + if val_loader is None: + print("警告: 验证集为空,使用训练集的 10% 作为验证") + split_idx = int(len(train_dataset) * 0.9) + train_subset = torch.utils.data.Subset(train_dataset, range(split_idx)) + val_subset = torch.utils.data.Subset(train_dataset, range(split_idx, len(train_dataset))) + train_loader = DataLoader(train_subset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers, pin_memory=True) + val_loader = DataLoader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers, pin_memory=True) + + trainer = Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + config=config, + device=device, + save_dir=args.save_dir + ) + + trainer.train(config['epochs']) + + # 计算阈值 + print("\n[额外] 计算异动阈值...") + best_checkpoint = torch.load(save_dir / 'best_model.pt', map_location=device) + + # 创建新的单 GPU 模型用于计算阈值(避免 DataParallel 问题) + threshold_model = TransformerAutoencoder(**config['model']) + threshold_model.load_state_dict(best_checkpoint['model_state_dict']) + threshold_model.to(device) + threshold_model.eval() + + thresholds = compute_thresholds(threshold_model, val_loader, device, config['threshold_percentiles']) + + with open(save_dir / 'thresholds.json', 'w') as f: + json.dump(thresholds, f, indent=2) + + print("\n" + "=" * 60) + print("训练完成!") + print(f"模型保存位置: {args.save_dir}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/ml/update_baseline.py b/ml/update_baseline.py new file mode 100644 index 00000000..7ee7e3cc --- /dev/null +++ b/ml/update_baseline.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +每日盘后运行:更新滚动基线 + +使用方法: + python ml/update_baseline.py + +建议加入 crontab,每天 15:30 后运行: + 30 15 * * 1-5 cd /path/to/project && python ml/update_baseline.py +""" + +import os +import sys +import pickle +import pandas as pd +import numpy as np +from datetime import datetime, timedelta +from pathlib import Path +from tqdm import tqdm + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from ml.prepare_data_v2 import ( + get_all_concepts, get_trading_days, compute_raw_concept_features, + init_process_connections, CONFIG, RAW_CACHE_DIR, BASELINE_DIR +) + + +def update_rolling_baseline(baseline_days: int = 20): + """ + 更新滚动基线(用于实盘检测) + + 基线 = 最近 N 个交易日每个时间片的统计量 + """ + print("=" * 60) + print("更新滚动基线(用于实盘)") + print("=" * 60) + + # 初始化连接 + init_process_connections() + + # 获取概念列表 + concepts = get_all_concepts() + all_stocks = list(set(s for c in concepts for s in c['stocks'])) + + # 获取最近的交易日 + today = datetime.now().strftime('%Y-%m-%d') + start_date = (datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d') # 多取一些 + + trading_days = get_trading_days(start_date, today) + + if len(trading_days) < baseline_days: + print(f"错误:交易日不足 {baseline_days} 天") + return + + # 只取最近 N 天 + recent_days = trading_days[-baseline_days:] + print(f"使用 {len(recent_days)} 天数据: {recent_days[0]} ~ {recent_days[-1]}") + + # 加载原始数据 + all_data = [] + for trade_date in tqdm(recent_days, desc="加载数据"): + cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet') + + if os.path.exists(cache_file): + df = pd.read_parquet(cache_file) + else: + df = compute_raw_concept_features(trade_date, concepts, all_stocks) + + if not df.empty: + all_data.append(df) + + if not all_data: + 
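+        # 走到这里说明:既没有缓存文件,实时计算也没有返回任何数据,只能中止更新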
print("错误:无数据") + return + + combined = pd.concat(all_data, ignore_index=True) + print(f"总数据量: {len(combined):,} 条") + + # 按概念计算基线 + baselines = {} + + for concept_id, group in tqdm(combined.groupby('concept_id'), desc="计算基线"): + baseline_dict = {} + + for time_slot, slot_group in group.groupby('time_slot'): + if len(slot_group) < CONFIG['min_baseline_samples']: + continue + + alpha_std = slot_group['alpha'].std() + amt_std = slot_group['total_amt'].std() + rank_std = slot_group['rank_pct'].std() + + baseline_dict[time_slot] = { + 'alpha_mean': float(slot_group['alpha'].mean()), + 'alpha_std': float(max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1)), + 'amt_mean': float(slot_group['total_amt'].mean()), + 'amt_std': float(max(amt_std if pd.notna(amt_std) else slot_group['total_amt'].mean() * 0.5, 1.0)), + 'rank_mean': float(slot_group['rank_pct'].mean()), + 'rank_std': float(max(rank_std if pd.notna(rank_std) else 0.2, 0.05)), + 'sample_count': len(slot_group), + } + + if baseline_dict: + baselines[concept_id] = baseline_dict + + print(f"计算了 {len(baselines)} 个概念的基线") + + # 保存 + os.makedirs(BASELINE_DIR, exist_ok=True) + baseline_file = os.path.join(BASELINE_DIR, 'realtime_baseline.pkl') + + with open(baseline_file, 'wb') as f: + pickle.dump({ + 'baselines': baselines, + 'update_time': datetime.now().isoformat(), + 'date_range': [recent_days[0], recent_days[-1]], + 'baseline_days': baseline_days, + }, f) + + print(f"基线已保存: {baseline_file}") + print("=" * 60) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--days', type=int, default=20, help='基线天数') + args = parser.parse_args() + + update_rolling_baseline(args.days) diff --git a/sql/concept_minute_alert.sql b/sql/concept_minute_alert.sql new file mode 100644 index 00000000..d54cc7ff --- /dev/null +++ b/sql/concept_minute_alert.sql @@ -0,0 +1,68 @@ +-- 概念分钟级异动数据表 +-- 用于存储概念板块的实时异动信息,支持热点概览图表展示 + +CREATE TABLE IF NOT EXISTS concept_minute_alert ( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + concept_id VARCHAR(32) NOT NULL COMMENT '概念ID', + concept_name VARCHAR(100) NOT NULL COMMENT '概念名称', + alert_time DATETIME NOT NULL COMMENT '异动时间(精确到分钟)', + alert_type VARCHAR(20) NOT NULL COMMENT '异动类型:surge(急涨)/limit_up(涨停增加)/rank_jump(排名跃升)', + trade_date DATE NOT NULL COMMENT '交易日期', + + -- 涨跌幅相关 + change_pct DECIMAL(10,4) COMMENT '当时涨跌幅(%)', + prev_change_pct DECIMAL(10,4) COMMENT '之前涨跌幅(%)', + change_delta DECIMAL(10,4) COMMENT '涨幅变化量(%)', + + -- 涨停相关 + limit_up_count INT DEFAULT 0 COMMENT '当前涨停数量', + prev_limit_up_count INT DEFAULT 0 COMMENT '之前涨停数量', + limit_up_delta INT DEFAULT 0 COMMENT '涨停变化数量', + + -- 排名相关 + rank_position INT COMMENT '当前涨幅排名', + prev_rank_position INT COMMENT '之前涨幅排名', + rank_delta INT COMMENT '排名变化(负数表示上升)', + + -- 指数位置(用于图表Y轴定位) + index_code VARCHAR(20) DEFAULT '000001.SH' COMMENT '参考指数代码', + index_price DECIMAL(12,4) COMMENT '异动时的指数点位', + index_change_pct DECIMAL(10,4) COMMENT '异动时的指数涨跌幅(%)', + + -- 概念详情 + stock_count INT COMMENT '概念包含股票数', + concept_type VARCHAR(20) DEFAULT 'leaf' COMMENT '概念类型:leaf/lv1/lv2/lv3', + + -- 额外信息(JSON格式,存储涨停股票列表等) + extra_info JSON COMMENT '额外信息', + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + + -- 索引 + INDEX idx_trade_date (trade_date), + INDEX idx_alert_time (alert_time), + INDEX idx_concept_id (concept_id), + INDEX idx_alert_type (alert_type), + INDEX idx_trade_date_time (trade_date, alert_time) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='概念分钟级异动数据表'; + + +-- 创建指数分时快照表(用于异动时获取指数位置) +CREATE TABLE IF NOT EXISTS 
index_minute_snapshot ( + id BIGINT AUTO_INCREMENT PRIMARY KEY, + index_code VARCHAR(20) NOT NULL COMMENT '指数代码', + trade_date DATE NOT NULL COMMENT '交易日期', + snapshot_time DATETIME NOT NULL COMMENT '快照时间', + + price DECIMAL(12,4) COMMENT '指数点位', + open_price DECIMAL(12,4) COMMENT '开盘价', + high_price DECIMAL(12,4) COMMENT '最高价', + low_price DECIMAL(12,4) COMMENT '最低价', + prev_close DECIMAL(12,4) COMMENT '昨收价', + change_pct DECIMAL(10,4) COMMENT '涨跌幅(%)', + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + + UNIQUE KEY uk_index_time (index_code, snapshot_time), + INDEX idx_trade_date (trade_date) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='指数分时快照表'; diff --git a/src/components/StockChart/StockChartAntdModal.js b/src/components/StockChart/StockChartAntdModal.js index 8cddda26..bb1a0257 100644 --- a/src/components/StockChart/StockChartAntdModal.js +++ b/src/components/StockChart/StockChartAntdModal.js @@ -313,12 +313,29 @@ const StockChartAntdModal = ({ axisPointer: { type: 'cross' }, formatter: function(params) { const d = params[0]?.dataIndex ?? 0; - const priceChangePercent = ((prices[d] - prevClose) / prevClose * 100); - const avgChangePercent = ((avgPrices[d] - prevClose) / prevClose * 100); + const price = prices[d]; + const avgPrice = avgPrices[d]; + const volume = volumes[d]; + + // 安全计算涨跌幅,处理 undefined/null/0 的情况 + const safeCalcPercent = (val, base) => { + if (val == null || base == null || base === 0) return 0; + return ((val - base) / base * 100); + }; + + const priceChangePercent = safeCalcPercent(price, prevClose); + const avgChangePercent = safeCalcPercent(avgPrice, prevClose); const priceColor = priceChangePercent >= 0 ? '#ef5350' : '#26a69a'; const avgColor = avgChangePercent >= 0 ? '#ef5350' : '#26a69a'; - return `时间:${times[d]}
<br/>现价:¥${prices[d]?.toFixed(2)} (${priceChangePercent >= 0 ? '+' : ''}${priceChangePercent.toFixed(2)}%)<br/>均价:¥${avgPrices[d]?.toFixed(2)} (${avgChangePercent >= 0 ? '+' : ''}${avgChangePercent.toFixed(2)}%)<br/>昨收:¥${prevClose?.toFixed(2)}<br/>成交量:${Math.round(volumes[d]/100)}手`;
+
+          // 安全格式化数字
+          const safeFixed = (val, digits = 2) => (val != null && !isNaN(val)) ? val.toFixed(digits) : '-';
+          const formatPercent = (val) => {
+            if (val == null || isNaN(val)) return '-';
+            return (val >= 0 ? '+' : '') + val.toFixed(2) + '%';
+          };
+
+          return `时间:${times[d] || '-'}<br/>
现价:¥${safeFixed(price)} (${formatPercent(priceChangePercent)})<br/>均价:¥${safeFixed(avgPrice)} (${formatPercent(avgChangePercent)})<br/>昨收:¥${safeFixed(prevClose)}<br/>
成交量:${volume != null ? Math.round(volume/100) + '手' : '-'}`; } }, grid: [ @@ -337,6 +354,7 @@ const StockChartAntdModal = ({ position: 'left', axisLabel: { formatter: function(value) { + if (value == null || isNaN(value)) return '-'; return (value >= 0 ? '+' : '') + value.toFixed(2) + '%'; } }, @@ -354,11 +372,12 @@ const StockChartAntdModal = ({ position: 'right', axisLabel: { formatter: function(value) { + if (value == null || isNaN(value)) return '-'; return (value >= 0 ? '+' : '') + value.toFixed(2) + '%'; } } }, - { type: 'value', gridIndex: 1, scale: true, axisLabel: { formatter: v => Math.round(v/100) + '手' } } + { type: 'value', gridIndex: 1, scale: true, axisLabel: { formatter: v => (v != null && !isNaN(v)) ? Math.round(v/100) + '手' : '-' } } ], dataZoom: [ { type: 'inside', xAxisIndex: [0, 1], start: 0, end: 100 }, diff --git a/src/components/StockChart/TimelineChartModal.tsx b/src/components/StockChart/TimelineChartModal.tsx index 7981b098..eb9eb809 100644 --- a/src/components/StockChart/TimelineChartModal.tsx +++ b/src/components/StockChart/TimelineChartModal.tsx @@ -217,27 +217,34 @@ const TimelineChartModal: React.FC = ({ if (dataIndex === undefined) return ''; const item = data[dataIndex]; - const changeColor = item.change_percent >= 0 ? '#ef5350' : '#26a69a'; - const changeSign = item.change_percent >= 0 ? '+' : ''; + if (!item) return ''; + + // 安全格式化数字 + const safeFixed = (val: any, digits = 2) => + val != null && !isNaN(val) ? Number(val).toFixed(digits) : '-'; + + const changePercent = item.change_percent ?? 0; + const changeColor = changePercent >= 0 ? '#ef5350' : '#26a69a'; + const changeSign = changePercent >= 0 ? '+' : ''; return `
-
${item.time}
+
${item.time || '-'}
价格: - ${item.price.toFixed(2)} + ${safeFixed(item.price)}
均价: - ${item.avg_price.toFixed(2)} + ${safeFixed(item.avg_price)}
涨跌幅: - ${changeSign}${item.change_percent.toFixed(2)}% + ${changeSign}${safeFixed(changePercent)}%
成交量: - ${(item.volume / 100).toFixed(0)}手 + ${item.volume != null ? (item.volume / 100).toFixed(0) : '-'}手
`; @@ -314,7 +321,7 @@ const TimelineChartModal: React.FC = ({ axisLabel: { color: '#999', fontSize: isMobile ? 10 : 12, - formatter: (value: number) => value.toFixed(2), + formatter: (value: number) => (value != null && !isNaN(value)) ? value.toFixed(2) : '-', }, }, { @@ -333,6 +340,7 @@ const TimelineChartModal: React.FC = ({ color: '#999', fontSize: isMobile ? 10 : 12, formatter: (value: number) => { + if (value == null || isNaN(value)) return '-'; if (value >= 10000) { return (value / 10000).toFixed(1) + '万'; } diff --git a/src/mocks/handlers/market.js b/src/mocks/handlers/market.js index 19857e45..4e15a72d 100644 --- a/src/mocks/handlers/market.js +++ b/src/mocks/handlers/market.js @@ -346,7 +346,173 @@ export const marketHandlers = [ }); }), - // 11. 市场统计数据(个股中心页面使用) + // 11. 热点概览数据(大盘分时 + 概念异动) + http.get('/api/market/hotspot-overview', async ({ request }) => { + await delay(300); + const url = new URL(request.url); + const date = url.searchParams.get('date'); + + const tradeDate = date || new Date().toISOString().split('T')[0]; + + // 生成分时数据(240个点,9:30-11:30 + 13:00-15:00) + const timeline = []; + const basePrice = 3900 + Math.random() * 100; // 基准价格 3900-4000 + const prevClose = basePrice; + let currentPrice = basePrice; + let cumulativeVolume = 0; + + // 上午时段 9:30-11:30 (120分钟) + for (let i = 0; i < 120; i++) { + const hour = 9 + Math.floor((i + 30) / 60); + const minute = (i + 30) % 60; + const time = `${hour.toString().padStart(2, '0')}:${minute.toString().padStart(2, '0')}`; + + // 模拟价格波动 + const volatility = 0.002; // 0.2%波动 + const drift = (Math.random() - 0.5) * 0.001; // 微小趋势 + currentPrice = currentPrice * (1 + (Math.random() - 0.5) * volatility + drift); + + const volume = Math.floor(Math.random() * 500000 + 100000); // 成交量 + cumulativeVolume += volume; + + timeline.push({ + time, + price: parseFloat(currentPrice.toFixed(2)), + volume: cumulativeVolume, + change_pct: parseFloat(((currentPrice - prevClose) / prevClose * 100).toFixed(2)) + }); + } + + // 下午时段 13:00-15:00 (120分钟) + for (let i = 0; i < 120; i++) { + const hour = 13 + Math.floor(i / 60); + const minute = i % 60; + const time = `${hour.toString().padStart(2, '0')}:${minute.toString().padStart(2, '0')}`; + + // 下午波动略小 + const volatility = 0.0015; + const drift = (Math.random() - 0.5) * 0.0008; + currentPrice = currentPrice * (1 + (Math.random() - 0.5) * volatility + drift); + + const volume = Math.floor(Math.random() * 400000 + 80000); + cumulativeVolume += volume; + + timeline.push({ + time, + price: parseFloat(currentPrice.toFixed(2)), + volume: cumulativeVolume, + change_pct: parseFloat(((currentPrice - prevClose) / prevClose * 100).toFixed(2)) + }); + } + + // 生成概念异动数据 + const conceptNames = [ + '人工智能', 'AI眼镜', '机器人', '核电', '国企', '卫星导航', + '福建自贸区', '两岸融合', 'CRO', '三季报增长', '百货零售', + '人形机器人', '央企', '数据中心', 'CPO', '新能源', '电网设备', + '氢能源', '算力租赁', '厦门国资', '乳业', '低空安防', '创新药', + '商业航天', '控制权变更', '文化传媒', '海峡两岸' + ]; + + const alertTypes = ['surge_up', 'surge_down', 'volume_spike', 'limit_up', 'rank_jump']; + + // 生成 15-25 个异动 + const alertCount = Math.floor(Math.random() * 10) + 15; + const alerts = []; + const usedTimes = new Set(); + + for (let i = 0; i < alertCount; i++) { + // 随机选择一个时间点 + let timeIdx; + let attempts = 0; + do { + timeIdx = Math.floor(Math.random() * timeline.length); + attempts++; + } while (usedTimes.has(timeIdx) && attempts < 50); + + if (attempts >= 50) continue; + + // 同一时间可以有多个异动 + const time = timeline[timeIdx].time; + const conceptName = conceptNames[Math.floor(Math.random() * 
conceptNames.length)]; + const alertType = alertTypes[Math.floor(Math.random() * alertTypes.length)]; + + // 根据类型生成 alpha + let alpha; + if (alertType === 'surge_up') { + alpha = parseFloat((Math.random() * 3 + 2).toFixed(2)); // +2% ~ +5% + } else if (alertType === 'surge_down') { + alpha = parseFloat((-Math.random() * 3 - 1.5).toFixed(2)); // -1.5% ~ -4.5% + } else { + alpha = parseFloat((Math.random() * 4 - 1).toFixed(2)); // -1% ~ +3% + } + + const finalScore = Math.floor(Math.random() * 40 + 45); // 45-85分 + const ruleScore = Math.floor(Math.random() * 30 + 40); + const mlScore = Math.floor(Math.random() * 30 + 40); + + alerts.push({ + concept_id: `CONCEPT_${1000 + i}`, + concept_name: conceptName, + time, + alert_type: alertType, + alpha, + alpha_delta: parseFloat((Math.random() * 2 - 0.5).toFixed(2)), + amt_ratio: parseFloat((Math.random() * 5 + 1).toFixed(2)), + limit_up_count: alertType === 'limit_up' ? Math.floor(Math.random() * 5 + 1) : 0, + limit_up_ratio: parseFloat((Math.random() * 0.3).toFixed(3)), + final_score: finalScore, + rule_score: ruleScore, + ml_score: mlScore, + trigger_reason: finalScore >= 65 ? '规则强信号' : (mlScore >= 70 ? 'ML强信号' : '融合触发'), + importance_score: parseFloat((finalScore / 100).toFixed(2)), + index_price: timeline[timeIdx].price + }); + } + + // 按时间排序 + alerts.sort((a, b) => a.time.localeCompare(b.time)); + + // 统计异动类型 + const alertSummary = alerts.reduce((acc, alert) => { + acc[alert.alert_type] = (acc[alert.alert_type] || 0) + 1; + return acc; + }, {}); + + // 计算指数统计 + const prices = timeline.map(t => t.price); + const latestPrice = prices[prices.length - 1]; + const highPrice = Math.max(...prices); + const lowPrice = Math.min(...prices); + const changePct = ((latestPrice - prevClose) / prevClose * 100); + + console.log('[Mock Market] 获取热点概览数据:', { + date: tradeDate, + timelinePoints: timeline.length, + alertCount: alerts.length + }); + + return HttpResponse.json({ + success: true, + data: { + index: { + code: '000001.SH', + name: '上证指数', + latest_price: latestPrice, + prev_close: prevClose, + high: highPrice, + low: lowPrice, + change_pct: parseFloat(changePct.toFixed(2)), + timeline + }, + alerts, + alert_summary: alertSummary + }, + trade_date: tradeDate + }); + }), + + // 12. 市场统计数据(个股中心页面使用) http.get('/api/market/statistics', async ({ request }) => { await delay(200); const url = new URL(request.url); diff --git a/src/views/Community/components/HeroPanel.js b/src/views/Community/components/HeroPanel.js index 40a69a2a..3c957fe0 100644 --- a/src/views/Community/components/HeroPanel.js +++ b/src/views/Community/components/HeroPanel.js @@ -207,9 +207,12 @@ const CompactIndexCard = ({ indexCode, indexName }) => { const raw = chartData.rawData[idx]; if (!raw) return ''; + // 安全格式化数字 + const safeFixed = (val, digits = 2) => (val != null && !isNaN(val)) ? val.toFixed(digits) : '-'; + // 计算涨跌 const prevClose = raw.prev_close || (idx > 0 ? chartData.rawData[idx - 1]?.close : raw.open) || raw.open; - const changeAmount = raw.close - prevClose; + const changeAmount = (raw.close != null && prevClose != null) ? (raw.close - prevClose) : 0; const changePct = prevClose ? ((changeAmount / prevClose) * 100) : 0; const isUp = changeAmount >= 0; const color = isUp ? '#ef5350' : '#26a69a'; @@ -218,22 +221,22 @@ const CompactIndexCard = ({ indexCode, indexName }) => { return `
- 📅 ${raw.time} + 📅 ${raw.time || '-'}
开盘 - ${raw.open.toFixed(2)} + ${safeFixed(raw.open)} 收盘 - ${raw.close.toFixed(2)} + ${safeFixed(raw.close)} 最高 - ${raw.high.toFixed(2)} + ${safeFixed(raw.high)} 最低 - ${raw.low.toFixed(2)} + ${safeFixed(raw.low)}
涨跌幅 - ${sign}${changeAmount.toFixed(2)} (${sign}${changePct.toFixed(2)}%) + ${sign}${safeFixed(changeAmount)} (${sign}${safeFixed(changePct)}%)
@@ -529,7 +532,7 @@ const FlowingConcepts = () => { color={colors.text} whiteSpace="nowrap" > - {concept.change_pct > 0 ? '+' : ''}{concept.change_pct.toFixed(2)}% + {concept.change_pct > 0 ? '+' : ''}{concept.change_pct?.toFixed(2) ?? '-'}% diff --git a/src/views/StockOverview/components/FlexScreen/components/MiniTimelineChart.tsx b/src/views/StockOverview/components/FlexScreen/components/MiniTimelineChart.tsx new file mode 100644 index 00000000..0730b24f --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/components/MiniTimelineChart.tsx @@ -0,0 +1,280 @@ +/** + * 迷你分时图组件 + * 用于灵活屏中显示证券的日内走势 + */ +import React, { useEffect, useRef, useState, useMemo } from 'react'; +import { Box, Spinner, Center, Text } from '@chakra-ui/react'; +import * as echarts from 'echarts'; +import type { ECharts, EChartsOption } from 'echarts'; + +import type { MiniTimelineChartProps, TimelineDataPoint } from '../types'; + +/** + * 生成交易时间刻度(用于 X 轴) + * A股交易时间:9:30-11:30, 13:00-15:00 + */ +const generateTimeTicks = (): string[] => { + const ticks: string[] = []; + // 上午 + for (let h = 9; h <= 11; h++) { + for (let m = h === 9 ? 30 : 0; m < 60; m++) { + if (h === 11 && m > 30) break; + ticks.push(`${String(h).padStart(2, '0')}:${String(m).padStart(2, '0')}`); + } + } + // 下午 + for (let h = 13; h <= 15; h++) { + for (let m = 0; m < 60; m++) { + if (h === 15 && m > 0) break; + ticks.push(`${String(h).padStart(2, '0')}:${String(m).padStart(2, '0')}`); + } + } + return ticks; +}; + +const TIME_TICKS = generateTimeTicks(); + +/** API 返回的分钟数据结构 */ +interface MinuteKLineItem { + time?: string; + timestamp?: string; + close?: number; + price?: number; +} + +/** API 响应结构 */ +interface KLineApiResponse { + success?: boolean; + data?: MinuteKLineItem[]; + error?: string; +} + +/** + * MiniTimelineChart 组件 + */ +const MiniTimelineChart: React.FC = ({ + code, + isIndex = false, + prevClose, + currentPrice, + height = 120, +}) => { + const chartRef = useRef(null); + const chartInstance = useRef(null); + const [timelineData, setTimelineData] = useState([]); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + + // 获取分钟数据 + useEffect(() => { + if (!code) return; + + const fetchData = async (): Promise => { + setLoading(true); + setError(null); + + try { + const apiPath = isIndex + ? 
`/api/index/${code}/kline?type=minute` + : `/api/stock/${code}/kline?type=minute`; + + const response = await fetch(apiPath); + const result: KLineApiResponse = await response.json(); + + if (result.success !== false && result.data) { + // 格式化数据 + const formatted: TimelineDataPoint[] = result.data.map(item => ({ + time: item.time || item.timestamp || '', + price: item.close || item.price || 0, + })); + setTimelineData(formatted); + } else { + setError(result.error || '暂无数据'); + } + } catch (e) { + setError('加载失败'); + } finally { + setLoading(false); + } + }; + + fetchData(); + + // 交易时间内每分钟刷新 + const now = new Date(); + const hours = now.getHours(); + const minutes = now.getMinutes(); + const currentMinutes = hours * 60 + minutes; + const isTrading = + (currentMinutes >= 570 && currentMinutes <= 690) || + (currentMinutes >= 780 && currentMinutes <= 900); + + let intervalId: NodeJS.Timeout | undefined; + if (isTrading) { + intervalId = setInterval(fetchData, 60000); // 1分钟刷新 + } + + return () => { + if (intervalId) clearInterval(intervalId); + }; + }, [code, isIndex]); + + // 合并实时价格到数据中 + const chartData = useMemo((): TimelineDataPoint[] => { + if (!timelineData.length) return []; + + const data = [...timelineData]; + + // 如果有实时价格,添加到最新点 + if (currentPrice && data.length > 0) { + const now = new Date(); + const timeStr = `${String(now.getHours()).padStart(2, '0')}:${String(now.getMinutes()).padStart(2, '0')}`; + const lastItem = data[data.length - 1]; + + // 如果实时价格的时间比最后一条数据新,添加新点 + if (lastItem.time !== timeStr) { + data.push({ time: timeStr, price: currentPrice }); + } else { + // 更新最后一条 + data[data.length - 1] = { ...lastItem, price: currentPrice }; + } + } + + return data; + }, [timelineData, currentPrice]); + + // 渲染图表 + useEffect(() => { + if (!chartRef.current || loading || !chartData.length) return; + + if (!chartInstance.current) { + chartInstance.current = echarts.init(chartRef.current); + } + + const baseLine = prevClose || chartData[0]?.price || 0; + + // 计算价格范围 + const prices = chartData.map(d => d.price).filter(p => p > 0); + const minPrice = Math.min(...prices, baseLine); + const maxPrice = Math.max(...prices, baseLine); + const range = Math.max(maxPrice - baseLine, baseLine - minPrice) * 1.1; + + // 准备数据 + const times = chartData.map(d => d.time); + const values = chartData.map(d => d.price); + + // 判断涨跌 + const lastPrice = values[values.length - 1] || baseLine; + const isUp = lastPrice >= baseLine; + + const option: EChartsOption = { + grid: { + top: 5, + right: 5, + bottom: 5, + left: 5, + containLabel: false, + }, + xAxis: { + type: 'category', + data: times, + show: false, + boundaryGap: false, + }, + yAxis: { + type: 'value', + min: baseLine - range, + max: baseLine + range, + show: false, + }, + series: [ + { + type: 'line', + data: values, + smooth: false, + symbol: 'none', + lineStyle: { + width: 1.5, + color: isUp ? '#ef4444' : '#22c55e', + }, + areaStyle: { + color: { + type: 'linear', + x: 0, + y: 0, + x2: 0, + y2: 1, + colorStops: [ + { offset: 0, color: isUp ? 'rgba(239, 68, 68, 0.3)' : 'rgba(34, 197, 94, 0.3)' }, + { offset: 1, color: isUp ? 
'rgba(239, 68, 68, 0.05)' : 'rgba(34, 197, 94, 0.05)' }, + ], + }, + }, + markLine: { + silent: true, + symbol: 'none', + data: [ + { + yAxis: baseLine, + lineStyle: { + color: '#666', + type: 'dashed', + width: 1, + }, + label: { show: false }, + }, + ], + }, + }, + ], + animation: false, + }; + + chartInstance.current.setOption(option); + + return () => { + // 不在这里销毁,只在组件卸载时销毁 + }; + }, [chartData, prevClose, loading]); + + // 组件卸载时销毁图表 + useEffect(() => { + return () => { + if (chartInstance.current) { + chartInstance.current.dispose(); + chartInstance.current = null; + } + }; + }, []); + + // 窗口 resize 处理 + useEffect(() => { + const handleResize = (): void => { + chartInstance.current?.resize(); + }; + window.addEventListener('resize', handleResize); + return () => window.removeEventListener('resize', handleResize); + }, []); + + if (loading) { + return ( +
+ +
+ ); + } + + if (error || !chartData.length) { + return ( +
+ + {error || '暂无数据'} + +
+ ); + } + + return ; +}; + +export default MiniTimelineChart; diff --git a/src/views/StockOverview/components/FlexScreen/components/OrderBookPanel.tsx b/src/views/StockOverview/components/FlexScreen/components/OrderBookPanel.tsx new file mode 100644 index 00000000..4da68ca0 --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/components/OrderBookPanel.tsx @@ -0,0 +1,291 @@ +/** + * 盘口行情面板组件 + * 支持显示 5 档或 10 档买卖盘数据 + * + * 上交所: 5 档行情 + * 深交所: 10 档行情 + */ +import React, { useState } from 'react'; +import { + Box, + VStack, + HStack, + Text, + Button, + ButtonGroup, + useColorModeValue, + Tooltip, + Badge, +} from '@chakra-ui/react'; + +import type { OrderBookPanelProps } from '../types'; + +/** 格式化价格返回值 */ +interface FormattedPrice { + text: string; + color: string; +} + +/** + * 格式化成交量 + */ +const formatVolume = (volume: number): string => { + if (!volume || volume === 0) return '-'; + if (volume >= 10000) { + return `${(volume / 10000).toFixed(0)}万`; + } + if (volume >= 1000) { + return `${(volume / 1000).toFixed(1)}k`; + } + return String(volume); +}; + +/** + * 格式化价格 + */ +const formatPrice = (price: number, prevClose?: number): FormattedPrice => { + if (!price || price === 0) { + return { text: '-', color: 'gray.400' }; + } + + const text = price.toFixed(2); + + if (!prevClose || prevClose === 0) { + return { text, color: 'gray.600' }; + } + + if (price > prevClose) { + return { text, color: 'red.500' }; + } + if (price < prevClose) { + return { text, color: 'green.500' }; + } + return { text, color: 'gray.600' }; +}; + +/** OrderRow 组件 Props */ +interface OrderRowProps { + label: string; + price: number; + volume: number; + prevClose?: number; + isBid: boolean; + maxVolume: number; + isLimitPrice: boolean; +} + +/** + * 单行盘口 + */ +const OrderRow: React.FC = ({ + label, + price, + volume, + prevClose, + isBid, + maxVolume, + isLimitPrice, +}) => { + const bgColor = useColorModeValue( + isBid ? 'red.50' : 'green.50', + isBid ? 'rgba(239, 68, 68, 0.1)' : 'rgba(34, 197, 94, 0.1)' + ); + const barColor = useColorModeValue( + isBid ? 'red.200' : 'green.200', + isBid ? 'rgba(239, 68, 68, 0.3)' : 'rgba(34, 197, 94, 0.3)' + ); + const limitColor = useColorModeValue('orange.500', 'orange.300'); + + const priceInfo = formatPrice(price, prevClose); + const volumeText = formatVolume(volume); + + // 计算成交量条宽度 + const barWidth = maxVolume > 0 ? Math.min((volume / maxVolume) * 100, 100) : 0; + + return ( + + {/* 成交量条 */} + + + {/* 内容 */} + + {label} + + + + {priceInfo.text} + + {isLimitPrice && ( + + + {isBid ? '跌' : '涨'} + + + )} + + + {volumeText} + + + ); +}; + +/** + * OrderBookPanel 组件 + */ +const OrderBookPanel: React.FC = ({ + bidPrices = [], + bidVolumes = [], + askPrices = [], + askVolumes = [], + prevClose, + upperLimit, + lowerLimit, + defaultLevels = 5, +}) => { + const borderColor = useColorModeValue('gray.200', 'gray.700'); + const buttonBg = useColorModeValue('gray.100', 'gray.700'); + const bgColor = useColorModeValue('white', '#1a1a1a'); + + // 可切换显示的档位数 + const maxAvailableLevels = Math.max(bidPrices.length, askPrices.length, 1); + const [showLevels, setShowLevels] = useState(Math.min(defaultLevels, maxAvailableLevels)); + + // 计算最大成交量(用于条形图比例) + const displayBidVolumes = bidVolumes.slice(0, showLevels); + const displayAskVolumes = askVolumes.slice(0, showLevels); + const allVolumes = [...displayBidVolumes, ...displayAskVolumes].filter(v => v > 0); + const maxVolume = allVolumes.length > 0 ? 
Math.max(...allVolumes) : 0; + + // 判断是否为涨跌停价 + const isUpperLimit = (price: number): boolean => + !!upperLimit && Math.abs(price - upperLimit) < 0.001; + const isLowerLimit = (price: number): boolean => + !!lowerLimit && Math.abs(price - lowerLimit) < 0.001; + + // 卖盘(从卖N到卖1,即价格从高到低) + const askRows: React.ReactNode[] = []; + for (let i = showLevels - 1; i >= 0; i--) { + askRows.push( + + ); + } + + // 买盘(从买1到买N,即价格从高到低) + const bidRows: React.ReactNode[] = []; + for (let i = 0; i < showLevels; i++) { + bidRows.push( + + ); + } + + // 没有数据时的提示 + const hasData = bidPrices.length > 0 || askPrices.length > 0; + + if (!hasData) { + return ( + + + 暂无盘口数据 + + + ); + } + + return ( + + {/* 档位切换(只有当有超过5档数据时才显示) */} + {maxAvailableLevels > 5 && ( + + + + + + + )} + + {/* 卖盘 */} + {askRows} + + {/* 分隔线 + 当前价信息 */} + + {prevClose && ( + + 昨收 {prevClose.toFixed(2)} + + )} + + + {/* 买盘 */} + {bidRows} + + {/* 涨跌停价信息 */} + {(upperLimit || lowerLimit) && ( + + {lowerLimit && 跌停 {lowerLimit.toFixed(2)}} + {upperLimit && 涨停 {upperLimit.toFixed(2)}} + + )} + + ); +}; + +export default OrderBookPanel; diff --git a/src/views/StockOverview/components/FlexScreen/components/QuoteTile.tsx b/src/views/StockOverview/components/FlexScreen/components/QuoteTile.tsx new file mode 100644 index 00000000..d5a51be5 --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/components/QuoteTile.tsx @@ -0,0 +1,288 @@ +/** + * 行情瓷砖组件 + * 单个证券的实时行情展示卡片,包含分时图和五档盘口 + */ +import React, { useState } from 'react'; +import { + Box, + VStack, + HStack, + Text, + IconButton, + Tooltip, + useColorModeValue, + Collapse, + Badge, +} from '@chakra-ui/react'; +import { CloseIcon, ChevronDownIcon, ChevronUpIcon } from '@chakra-ui/icons'; +import { useNavigate } from 'react-router-dom'; + +import MiniTimelineChart from './MiniTimelineChart'; +import OrderBookPanel from './OrderBookPanel'; +import type { QuoteTileProps, QuoteData } from '../types'; + +/** + * 格式化价格显示 + */ +const formatPrice = (price?: number): string => { + if (!price || isNaN(price)) return '-'; + return price.toFixed(2); +}; + +/** + * 格式化涨跌幅 + */ +const formatChangePct = (pct?: number): string => { + if (!pct || isNaN(pct)) return '0.00%'; + const sign = pct > 0 ? '+' : ''; + return `${sign}${pct.toFixed(2)}%`; +}; + +/** + * 格式化涨跌额 + */ +const formatChange = (change?: number): string => { + if (!change || isNaN(change)) return '-'; + const sign = change > 0 ? '+' : ''; + return `${sign}${change.toFixed(2)}`; +}; + +/** + * 格式化成交额 + */ +const formatAmount = (amount?: number): string => { + if (!amount || isNaN(amount)) return '-'; + if (amount >= 100000000) { + return `${(amount / 100000000).toFixed(2)}亿`; + } + if (amount >= 10000) { + return `${(amount / 10000).toFixed(0)}万`; + } + return amount.toFixed(0); +}; + +/** + * QuoteTile 组件 + */ +const QuoteTile: React.FC = ({ + code, + name, + quote = {}, + isIndex = false, + onRemove, +}) => { + const navigate = useNavigate(); + const [expanded, setExpanded] = useState(true); + + // 类型断言,确保类型安全 + const quoteData = quote as Partial; + + // 颜色主题 + const cardBg = useColorModeValue('white', '#1a1a1a'); + const borderColor = useColorModeValue('gray.200', '#333'); + const hoverBorderColor = useColorModeValue('purple.300', '#666'); + const textColor = useColorModeValue('gray.800', 'white'); + const subTextColor = useColorModeValue('gray.500', 'gray.400'); + + // 涨跌色 + const { price, prevClose, change, changePct, amount } = quoteData; + const priceColor = useColorModeValue( + !prevClose || price === prevClose + ? 
'gray.800' + : price && price > prevClose + ? 'red.500' + : 'green.500', + !prevClose || price === prevClose + ? 'gray.200' + : price && price > prevClose + ? 'red.400' + : 'green.400' + ); + + // 涨跌幅背景色 + const changeBgColor = useColorModeValue( + !changePct || changePct === 0 + ? 'gray.100' + : changePct > 0 + ? 'red.100' + : 'green.100', + !changePct || changePct === 0 + ? 'gray.700' + : changePct > 0 + ? 'rgba(239, 68, 68, 0.2)' + : 'rgba(34, 197, 94, 0.2)' + ); + + // 跳转到详情页 + const handleNavigate = (): void => { + if (isIndex) { + // 指数暂无详情页 + return; + } + navigate(`/company?scode=${code}`); + }; + + // 获取盘口数据(带类型安全) + const bidPrices = 'bidPrices' in quoteData ? (quoteData.bidPrices as number[]) : []; + const bidVolumes = 'bidVolumes' in quoteData ? (quoteData.bidVolumes as number[]) : []; + const askPrices = 'askPrices' in quoteData ? (quoteData.askPrices as number[]) : []; + const askVolumes = 'askVolumes' in quoteData ? (quoteData.askVolumes as number[]) : []; + const upperLimit = 'upperLimit' in quoteData ? (quoteData.upperLimit as number | undefined) : undefined; + const lowerLimit = 'lowerLimit' in quoteData ? (quoteData.lowerLimit as number | undefined) : undefined; + const openPrice = 'open' in quoteData ? (quoteData.open as number | undefined) : undefined; + + return ( + + {/* 头部 */} + setExpanded(!expanded)} + > + {/* 名称和代码 */} + + + { + e.stopPropagation(); + handleNavigate(); + }} + > + {name || code} + + {isIndex && ( + + 指数 + + )} + + + {code} + + + + {/* 价格信息 */} + + + {formatPrice(price)} + + + + {formatChangePct(changePct)} + + + {formatChange(change)} + + + + + {/* 操作按钮 */} + + : } + size="xs" + variant="ghost" + aria-label={expanded ? '收起' : '展开'} + onClick={(e) => { + e.stopPropagation(); + setExpanded(!expanded); + }} + /> + + } + size="xs" + variant="ghost" + colorScheme="red" + aria-label="移除" + onClick={(e) => { + e.stopPropagation(); + onRemove?.(code); + }} + /> + + + + + {/* 可折叠内容 */} + + + {/* 统计信息 */} + + + 昨收: + {formatPrice(prevClose)} + + + 今开: + {formatPrice(openPrice)} + + + 成交额: + {formatAmount(amount)} + + + + {/* 分时图 */} + + + + + {/* 盘口(指数没有盘口) */} + {!isIndex && ( + + + 盘口 {bidPrices.length > 5 ? 
'(10档)' : '(5档)'} + + + + )} + + + + ); +}; + +export default QuoteTile; diff --git a/src/views/StockOverview/components/FlexScreen/components/index.ts b/src/views/StockOverview/components/FlexScreen/components/index.ts new file mode 100644 index 00000000..4777c0f1 --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/components/index.ts @@ -0,0 +1,7 @@ +/** + * 组件导出文件 + */ + +export { default as MiniTimelineChart } from './MiniTimelineChart'; +export { default as OrderBookPanel } from './OrderBookPanel'; +export { default as QuoteTile } from './QuoteTile'; diff --git a/src/views/StockOverview/components/FlexScreen/hooks/constants.ts b/src/views/StockOverview/components/FlexScreen/hooks/constants.ts new file mode 100644 index 00000000..f9d41b30 --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/hooks/constants.ts @@ -0,0 +1,46 @@ +/** + * WebSocket 配置常量 + */ + +import type { Exchange } from '../types'; + +/** + * 获取 WebSocket 配置 + * - 生产环境 (HTTPS): 通过 Nginx 代理使用 wss:// + * - 开发环境 (HTTP): 直连 ws:// + */ +const getWsConfig = (): Record => { + // 服务端渲染或测试环境使用默认配置 + if (typeof window === 'undefined') { + return { + SSE: 'ws://49.232.185.254:8765', + SZSE: 'ws://222.128.1.157:8765', + }; + } + + const isHttps = window.location.protocol === 'https:'; + const host = window.location.host; + + if (isHttps) { + // 生产环境:通过 Nginx 代理 + return { + SSE: `wss://${host}/ws/sse`, // 上交所 - Nginx 代理 + SZSE: `wss://${host}/ws/szse`, // 深交所 - Nginx 代理 + }; + } + + // 开发环境:直连 + return { + SSE: 'ws://49.232.185.254:8765', // 上交所 + SZSE: 'ws://222.128.1.157:8765', // 深交所 + }; +}; + +/** WebSocket 服务地址 */ +export const WS_CONFIG: Record = getWsConfig(); + +/** 心跳间隔 (ms) */ +export const HEARTBEAT_INTERVAL = 30000; + +/** 重连间隔 (ms) */ +export const RECONNECT_INTERVAL = 3000; diff --git a/src/views/StockOverview/components/FlexScreen/hooks/index.ts b/src/views/StockOverview/components/FlexScreen/hooks/index.ts new file mode 100644 index 00000000..ffe3399e --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/hooks/index.ts @@ -0,0 +1,7 @@ +/** + * Hooks 导出文件 + */ + +export { useRealtimeQuote } from './useRealtimeQuote'; +export * from './constants'; +export * from './utils'; diff --git a/src/views/StockOverview/components/FlexScreen/hooks/useRealtimeQuote.ts b/src/views/StockOverview/components/FlexScreen/hooks/useRealtimeQuote.ts new file mode 100644 index 00000000..09687696 --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/hooks/useRealtimeQuote.ts @@ -0,0 +1,722 @@ +/** + * 实时行情 Hook + * 管理上交所和深交所 WebSocket 连接,获取实时行情数据 + * + * 连接方式: + * - 生产环境 (HTTPS): 通过 Nginx 代理使用 wss:// (如 wss://valuefrontier.cn/ws/sse) + * - 开发环境 (HTTP): 直连 ws:// + * + * 上交所 (SSE): 需主动订阅,提供五档行情 + * 深交所 (SZSE): 自动推送,提供十档行情 + */ + +import { useState, useEffect, useRef, useCallback } from 'react'; +import { logger } from '@utils/logger'; +import { WS_CONFIG, HEARTBEAT_INTERVAL, RECONNECT_INTERVAL } from './constants'; +import { getExchange, normalizeCode, extractOrderBook, calcChangePct } from './utils'; +import type { + Exchange, + ConnectionStatus, + QuotesMap, + QuoteData, + SSEMessage, + SSEQuoteItem, + SZSEMessage, + SZSERealtimeMessage, + SZSESnapshotMessage, + SZSEStockData, + SZSEIndexData, + SZSEBondData, + SZSEHKStockData, + SZSEAfterhoursData, + UseRealtimeQuoteReturn, +} from '../types'; + +/** + * 处理上交所消息 + * 注意:上交所返回的 code 不带后缀,但通过 msg.type 区分 'stock' 和 'index' + * 存储时使用带后缀的完整代码作为 key(如 000001.SH) + */ +const handleSSEMessage = ( + msg: SSEMessage, + subscribedCodes: Set, + 
prevQuotes: QuotesMap +): QuotesMap | null => { + if (msg.type !== 'stock' && msg.type !== 'index') { + return null; + } + + const data = msg.data || {}; + const updated: QuotesMap = { ...prevQuotes }; + let hasUpdate = false; + const isIndex = msg.type === 'index'; + + Object.entries(data).forEach(([code, quote]: [string, SSEQuoteItem]) => { + // 生成带后缀的完整代码(上交所统一用 .SH) + const fullCode = code.includes('.') ? code : `${code}.SH`; + + if (subscribedCodes.has(code) || subscribedCodes.has(fullCode)) { + hasUpdate = true; + updated[fullCode] = { + code: fullCode, + name: quote.security_name, + price: quote.last_price, + prevClose: quote.prev_close, + open: quote.open_price, + high: quote.high_price, + low: quote.low_price, + volume: quote.volume, + amount: quote.amount, + change: quote.last_price - quote.prev_close, + changePct: calcChangePct(quote.last_price, quote.prev_close), + bidPrices: quote.bid_prices || [], + bidVolumes: quote.bid_volumes || [], + askPrices: quote.ask_prices || [], + askVolumes: quote.ask_volumes || [], + updateTime: quote.trade_time, + exchange: 'SSE', + } as QuoteData; + } + }); + + return hasUpdate ? updated : null; +}; + +/** + * 处理深交所实时消息 + * 注意:深交所返回的 security_id 可能带后缀也可能不带 + * 存储时统一使用带后缀的完整代码作为 key(如 000001.SZ) + */ +const handleSZSERealtimeMessage = ( + msg: SZSERealtimeMessage, + subscribedCodes: Set, + prevQuotes: QuotesMap +): QuotesMap | null => { + const { category, data, timestamp } = msg; + const rawCode = data.security_id; + // 生成带后缀的完整代码(深交所统一用 .SZ) + const fullCode = rawCode.includes('.') ? rawCode : `${rawCode}.SZ`; + + if (!subscribedCodes.has(rawCode) && !subscribedCodes.has(fullCode)) { + return null; + } + + const updated: QuotesMap = { ...prevQuotes }; + + switch (category) { + case 'stock': { + const stockData = data as SZSEStockData; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const rawData = data as any; // 用于检查替代字段名 + + // 调试日志:检查深交所返回的盘口原始数据(临时使用 warn 级别方便调试) + if (!stockData.bids || stockData.bids.length === 0) { + logger.warn('FlexScreen', `SZSE股票数据无盘口 ${fullCode}`, { + hasBids: !!stockData.bids, + hasAsks: !!stockData.asks, + bidsLength: stockData.bids?.length || 0, + asksLength: stockData.asks?.length || 0, + // 检查替代字段名 + hasBidPrices: !!rawData.bid_prices, + hasAskPrices: !!rawData.ask_prices, + dataKeys: Object.keys(stockData), // 查看服务端实际返回了哪些字段 + }); + } + + // 优先使用 bids/asks 对象数组格式,如果不存在则尝试 bid_prices/ask_prices 分离数组格式 + let bidPrices: number[] = []; + let bidVolumes: number[] = []; + let askPrices: number[] = []; + let askVolumes: number[] = []; + + if (stockData.bids && stockData.bids.length > 0) { + const extracted = extractOrderBook(stockData.bids); + bidPrices = extracted.prices; + bidVolumes = extracted.volumes; + } else if (rawData.bid_prices && Array.isArray(rawData.bid_prices)) { + // 替代格式:bid_prices 和 bid_volumes 分离 + bidPrices = rawData.bid_prices; + bidVolumes = rawData.bid_volumes || []; + } + + if (stockData.asks && stockData.asks.length > 0) { + const extracted = extractOrderBook(stockData.asks); + askPrices = extracted.prices; + askVolumes = extracted.volumes; + } else if (rawData.ask_prices && Array.isArray(rawData.ask_prices)) { + // 替代格式:ask_prices 和 ask_volumes 分离 + askPrices = rawData.ask_prices; + askVolumes = rawData.ask_volumes || []; + } + + updated[fullCode] = { + code: fullCode, + name: prevQuotes[fullCode]?.name || '', + price: stockData.last_px, + prevClose: stockData.prev_close, + open: stockData.open_px, + high: stockData.high_px, + low: stockData.low_px, + volume: 
stockData.volume, + amount: stockData.amount, + numTrades: stockData.num_trades, + upperLimit: stockData.upper_limit, + lowerLimit: stockData.lower_limit, + change: stockData.last_px - stockData.prev_close, + changePct: calcChangePct(stockData.last_px, stockData.prev_close), + bidPrices, + bidVolumes, + askPrices, + askVolumes, + tradingPhase: stockData.trading_phase, + updateTime: timestamp, + exchange: 'SZSE', + } as QuoteData; + break; + } + + case 'index': { + const indexData = data as SZSEIndexData; + updated[fullCode] = { + code: fullCode, + name: prevQuotes[fullCode]?.name || '', + price: indexData.current_index, + prevClose: indexData.prev_close, + open: indexData.open_index, + high: indexData.high_index, + low: indexData.low_index, + close: indexData.close_index, + volume: indexData.volume, + amount: indexData.amount, + numTrades: indexData.num_trades, + change: indexData.current_index - indexData.prev_close, + changePct: calcChangePct(indexData.current_index, indexData.prev_close), + bidPrices: [], + bidVolumes: [], + askPrices: [], + askVolumes: [], + tradingPhase: indexData.trading_phase, + updateTime: timestamp, + exchange: 'SZSE', + } as QuoteData; + break; + } + + case 'bond': { + const bondData = data as SZSEBondData; + updated[fullCode] = { + code: fullCode, + name: prevQuotes[fullCode]?.name || '', + price: bondData.last_px, + prevClose: bondData.prev_close, + open: bondData.open_px, + high: bondData.high_px, + low: bondData.low_px, + volume: bondData.volume, + amount: bondData.amount, + numTrades: bondData.num_trades, + weightedAvgPx: bondData.weighted_avg_px, + change: bondData.last_px - bondData.prev_close, + changePct: calcChangePct(bondData.last_px, bondData.prev_close), + bidPrices: [], + bidVolumes: [], + askPrices: [], + askVolumes: [], + tradingPhase: bondData.trading_phase, + updateTime: timestamp, + exchange: 'SZSE', + isBond: true, + } as QuoteData; + break; + } + + case 'hk_stock': { + const hkData = data as SZSEHKStockData; + const { prices: bidPrices, volumes: bidVolumes } = extractOrderBook(hkData.bids); + const { prices: askPrices, volumes: askVolumes } = extractOrderBook(hkData.asks); + + updated[fullCode] = { + code: fullCode, + name: prevQuotes[fullCode]?.name || '', + price: hkData.last_px, + prevClose: hkData.prev_close, + open: hkData.open_px, + high: hkData.high_px, + low: hkData.low_px, + volume: hkData.volume, + amount: hkData.amount, + numTrades: hkData.num_trades, + nominalPx: hkData.nominal_px, + referencePx: hkData.reference_px, + change: hkData.last_px - hkData.prev_close, + changePct: calcChangePct(hkData.last_px, hkData.prev_close), + bidPrices, + bidVolumes, + askPrices, + askVolumes, + tradingPhase: hkData.trading_phase, + updateTime: timestamp, + exchange: 'SZSE', + isHK: true, + } as QuoteData; + break; + } + + case 'afterhours_block': + case 'afterhours_trading': { + const afterhoursData = data as SZSEAfterhoursData; + const existing = prevQuotes[fullCode]; + if (existing) { + updated[fullCode] = { + ...existing, + afterhours: { + bidPx: afterhoursData.bid_px, + bidSize: afterhoursData.bid_size, + offerPx: afterhoursData.offer_px, + offerSize: afterhoursData.offer_size, + volume: afterhoursData.volume, + amount: afterhoursData.amount, + numTrades: afterhoursData.num_trades || 0, + }, + updateTime: timestamp, + } as QuoteData; + } + break; + } + + default: + return null; + } + + return updated; +}; + +/** + * 处理深交所快照消息 + * 存储时统一使用带后缀的完整代码作为 key + */ +const handleSZSESnapshotMessage = ( + msg: SZSESnapshotMessage, + subscribedCodes: 
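+  /*
+   * 实时消息(type: 'realtime')形态示意(数值仅为举例):
+   *   { type: 'realtime', category: 'stock', timestamp: '09:30:01',
+   *     data: { security_id: '000001', last_px: 10.52, open_px: 10.40,
+   *             high_px: 10.60, low_px: 10.35, prev_close: 10.45,
+   *             volume: 3200000, amount: 33500000 } }
+   */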
Set, + prevQuotes: QuotesMap +): QuotesMap | null => { + const { stocks = [], indexes = [], bonds = [] } = msg.data || {}; + const updated: QuotesMap = { ...prevQuotes }; + let hasUpdate = false; + + stocks.forEach((s: SZSEStockData) => { + const rawCode = s.security_id; + const fullCode = rawCode.includes('.') ? rawCode : `${rawCode}.SZ`; + + if (subscribedCodes.has(rawCode) || subscribedCodes.has(fullCode)) { + hasUpdate = true; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const rawData = s as any; // 用于检查替代字段名 + + // 调试日志:检查快照消息中的盘口数据(无盘口时警告) + if (!s.bids || s.bids.length === 0) { + logger.warn('FlexScreen', `SZSE快照股票数据无盘口 ${fullCode}`, { + hasBids: !!s.bids, + hasAsks: !!s.asks, + hasBidPrices: !!rawData.bid_prices, + hasAskPrices: !!rawData.ask_prices, + dataKeys: Object.keys(s), + }); + } + + // 优先使用 bids/asks 对象数组格式,如果不存在则尝试 bid_prices/ask_prices 分离数组格式 + let bidPrices: number[] = []; + let bidVolumes: number[] = []; + let askPrices: number[] = []; + let askVolumes: number[] = []; + + if (s.bids && s.bids.length > 0) { + const extracted = extractOrderBook(s.bids); + bidPrices = extracted.prices; + bidVolumes = extracted.volumes; + } else if (rawData.bid_prices && Array.isArray(rawData.bid_prices)) { + bidPrices = rawData.bid_prices; + bidVolumes = rawData.bid_volumes || []; + } + + if (s.asks && s.asks.length > 0) { + const extracted = extractOrderBook(s.asks); + askPrices = extracted.prices; + askVolumes = extracted.volumes; + } else if (rawData.ask_prices && Array.isArray(rawData.ask_prices)) { + askPrices = rawData.ask_prices; + askVolumes = rawData.ask_volumes || []; + } + + updated[fullCode] = { + code: fullCode, + name: '', + price: s.last_px, + prevClose: s.prev_close, + open: s.open_px, + high: s.high_px, + low: s.low_px, + volume: s.volume, + amount: s.amount, + numTrades: s.num_trades, + upperLimit: s.upper_limit, + lowerLimit: s.lower_limit, + change: s.last_px - s.prev_close, + changePct: calcChangePct(s.last_px, s.prev_close), + bidPrices, + bidVolumes, + askPrices, + askVolumes, + exchange: 'SZSE', + } as QuoteData; + } + }); + + indexes.forEach((i: SZSEIndexData) => { + const rawCode = i.security_id; + const fullCode = rawCode.includes('.') ? rawCode : `${rawCode}.SZ`; + + if (subscribedCodes.has(rawCode) || subscribedCodes.has(fullCode)) { + hasUpdate = true; + updated[fullCode] = { + code: fullCode, + name: '', + price: i.current_index, + prevClose: i.prev_close, + open: i.open_index, + high: i.high_index, + low: i.low_index, + volume: i.volume, + amount: i.amount, + numTrades: i.num_trades, + change: i.current_index - i.prev_close, + changePct: calcChangePct(i.current_index, i.prev_close), + bidPrices: [], + bidVolumes: [], + askPrices: [], + askVolumes: [], + exchange: 'SZSE', + } as QuoteData; + } + }); + + bonds.forEach((b: SZSEBondData) => { + const rawCode = b.security_id; + const fullCode = rawCode.includes('.') ? rawCode : `${rawCode}.SZ`; + + if (subscribedCodes.has(rawCode) || subscribedCodes.has(fullCode)) { + hasUpdate = true; + updated[fullCode] = { + code: fullCode, + name: '', + price: b.last_px, + prevClose: b.prev_close, + open: b.open_px, + high: b.high_px, + low: b.low_px, + volume: b.volume, + amount: b.amount, + change: b.last_px - b.prev_close, + changePct: calcChangePct(b.last_px, b.prev_close), + bidPrices: [], + bidVolumes: [], + askPrices: [], + askVolumes: [], + exchange: 'SZSE', + isBond: true, + } as QuoteData; + } + }); + + return hasUpdate ? 
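+  /*
+   * 快照消息(type: 'snapshot')形态示意(数值仅为举例):
+   *   { type: 'snapshot', timestamp: '09:30:00',
+   *     data: { stocks: [{ security_id: '000001', last_px: 10.52, prev_close: 10.45, ... }],
+   *             indexes: [], bonds: [] } }
+   * 订阅集合未命中任何代码时返回 null,避免触发无意义的状态更新。
+   */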
updated : null; +}; + +/** + * 实时行情 Hook + * @param codes - 订阅的证券代码列表 + */ +export const useRealtimeQuote = (codes: string[] = []): UseRealtimeQuoteReturn => { + const [quotes, setQuotes] = useState({}); + const [connected, setConnected] = useState({ SSE: false, SZSE: false }); + + const wsRefs = useRef>({ SSE: null, SZSE: null }); + const heartbeatRefs = useRef>({ SSE: null, SZSE: null }); + const reconnectRefs = useRef>({ SSE: null, SZSE: null }); + const subscribedCodes = useRef>>({ + SSE: new Set(), + SZSE: new Set(), + }); + + const stopHeartbeat = useCallback((exchange: Exchange) => { + if (heartbeatRefs.current[exchange]) { + clearInterval(heartbeatRefs.current[exchange]!); + heartbeatRefs.current[exchange] = null; + } + }, []); + + const startHeartbeat = useCallback((exchange: Exchange) => { + stopHeartbeat(exchange); + heartbeatRefs.current[exchange] = setInterval(() => { + const ws = wsRefs.current[exchange]; + if (ws && ws.readyState === WebSocket.OPEN) { + const msg = exchange === 'SSE' ? { action: 'ping' } : { type: 'ping' }; + ws.send(JSON.stringify(msg)); + } + }, HEARTBEAT_INTERVAL); + }, [stopHeartbeat]); + + const handleMessage = useCallback((exchange: Exchange, msg: SSEMessage | SZSEMessage) => { + if (msg.type === 'pong') return; + + if (exchange === 'SSE') { + const result = handleSSEMessage( + msg as SSEMessage, + subscribedCodes.current.SSE, + {} // Will be merged with current state + ); + if (result) { + setQuotes(prev => ({ ...prev, ...result })); + } + } else { + if (msg.type === 'realtime') { + setQuotes(prev => { + const result = handleSZSERealtimeMessage( + msg as SZSERealtimeMessage, + subscribedCodes.current.SZSE, + prev + ); + return result || prev; + }); + } else if (msg.type === 'snapshot') { + setQuotes(prev => { + const result = handleSZSESnapshotMessage( + msg as SZSESnapshotMessage, + subscribedCodes.current.SZSE, + prev + ); + return result || prev; + }); + } + } + }, []); + + const createConnection = useCallback((exchange: Exchange) => { + // 防御性检查:确保 HTTPS 页面不会意外连接 ws://(Mixed Content 安全错误) + // 正常情况下 WS_CONFIG 会自动根据协议返回正确的 URL,这里是备用保护 + const isHttps = typeof window !== 'undefined' && window.location.protocol === 'https:'; + const wsUrl = WS_CONFIG[exchange]; + const isInsecureWs = wsUrl.startsWith('ws://'); + + if (isHttps && isInsecureWs) { + logger.warn( + 'FlexScreen', + `${exchange} WebSocket 配置错误:HTTPS 页面尝试连接 ws:// 端点,请检查 Nginx 代理配置` + ); + return; + } + + if (wsRefs.current[exchange]) { + wsRefs.current[exchange]!.close(); + } + + try { + const ws = new WebSocket(wsUrl); + wsRefs.current[exchange] = ws; + + ws.onopen = () => { + logger.info('FlexScreen', `${exchange} WebSocket 已连接`); + setConnected(prev => ({ ...prev, [exchange]: true })); + + if (exchange === 'SSE') { + // subscribedCodes 存的是带后缀的完整代码,发送给 WS 需要去掉后缀 + const fullCodes = Array.from(subscribedCodes.current.SSE); + const baseCodes = fullCodes.map(c => normalizeCode(c)); + if (baseCodes.length > 0) { + ws.send(JSON.stringify({ + action: 'subscribe', + channels: ['stock', 'index'], + codes: baseCodes, + })); + } + } + + startHeartbeat(exchange); + }; + + ws.onmessage = (event: MessageEvent) => { + try { + const msg = JSON.parse(event.data); + handleMessage(exchange, msg); + } catch (e) { + logger.warn('FlexScreen', `${exchange} 消息解析失败`, e); + } + }; + + ws.onerror = (error: Event) => { + logger.error('FlexScreen', `${exchange} WebSocket 错误`, error); + }; + + ws.onclose = () => { + logger.info('FlexScreen', `${exchange} WebSocket 断开`); + setConnected(prev => ({ ...prev, [exchange]: 
false })); + stopHeartbeat(exchange); + + // 自动重连(仅在非 HTTPS + ws:// 场景下) + if (!reconnectRefs.current[exchange] && subscribedCodes.current[exchange].size > 0) { + reconnectRefs.current[exchange] = setTimeout(() => { + reconnectRefs.current[exchange] = null; + if (subscribedCodes.current[exchange].size > 0) { + createConnection(exchange); + } + }, RECONNECT_INTERVAL); + } + }; + } catch (e) { + logger.error('FlexScreen', `${exchange} WebSocket 连接失败`, e); + setConnected(prev => ({ ...prev, [exchange]: false })); + } + }, [startHeartbeat, stopHeartbeat, handleMessage]); + + const subscribe = useCallback((code: string) => { + const exchange = getExchange(code); + // 确保使用带后缀的完整代码 + const fullCode = code.includes('.') ? code : `${code}.${exchange === 'SSE' ? 'SH' : 'SZ'}`; + const baseCode = normalizeCode(code); + + subscribedCodes.current[exchange].add(fullCode); + + const ws = wsRefs.current[exchange]; + if (exchange === 'SSE' && ws && ws.readyState === WebSocket.OPEN) { + ws.send(JSON.stringify({ + action: 'subscribe', + channels: ['stock', 'index'], + codes: [baseCode], // 发送给 WS 用不带后缀的代码 + })); + } + + if (!ws || ws.readyState !== WebSocket.OPEN) { + createConnection(exchange); + } + }, [createConnection]); + + const unsubscribe = useCallback((code: string) => { + const exchange = getExchange(code); + // 确保使用带后缀的完整代码 + const fullCode = code.includes('.') ? code : `${code}.${exchange === 'SSE' ? 'SH' : 'SZ'}`; + + subscribedCodes.current[exchange].delete(fullCode); + + setQuotes(prev => { + const updated = { ...prev }; + delete updated[fullCode]; // 删除时也用带后缀的 key + return updated; + }); + + if (subscribedCodes.current[exchange].size === 0) { + const ws = wsRefs.current[exchange]; + if (ws) { + ws.close(); + wsRefs.current[exchange] = null; + } + } + }, []); + + // 初始化和 codes 变化处理 + // 注意:codes 现在是带后缀的完整代码(如 000001.SH) + useEffect(() => { + if (!codes || codes.length === 0) return; + + // 使用带后缀的完整代码作为内部 key + const newSseCodes = new Set(); + const newSzseCodes = new Set(); + + codes.forEach(code => { + const exchange = getExchange(code); + // 确保代码带后缀 + const fullCode = code.includes('.') ? code : `${code}.${exchange === 'SSE' ? 
'SH' : 'SZ'}`; + if (exchange === 'SSE') { + newSseCodes.add(fullCode); + } else { + newSzseCodes.add(fullCode); + } + }); + + // 更新上交所订阅 + const oldSseCodes = subscribedCodes.current.SSE; + const sseToAdd = [...newSseCodes].filter(c => !oldSseCodes.has(c)); + // 发送给 WebSocket 的代码需要去掉后缀 + const sseToAddBase = sseToAdd.map(c => normalizeCode(c)); + + if (sseToAdd.length > 0 || newSseCodes.size !== oldSseCodes.size) { + subscribedCodes.current.SSE = newSseCodes; + const ws = wsRefs.current.SSE; + + if (ws && ws.readyState === WebSocket.OPEN && sseToAddBase.length > 0) { + ws.send(JSON.stringify({ + action: 'subscribe', + channels: ['stock', 'index'], + codes: sseToAddBase, + })); + } + + if (sseToAdd.length > 0 && (!ws || ws.readyState !== WebSocket.OPEN)) { + createConnection('SSE'); + } + + if (newSseCodes.size === 0 && ws) { + ws.close(); + wsRefs.current.SSE = null; + } + } + + // 更新深交所订阅 + const oldSzseCodes = subscribedCodes.current.SZSE; + const szseToAdd = [...newSzseCodes].filter(c => !oldSzseCodes.has(c)); + + if (szseToAdd.length > 0 || newSzseCodes.size !== oldSzseCodes.size) { + subscribedCodes.current.SZSE = newSzseCodes; + const ws = wsRefs.current.SZSE; + + if (szseToAdd.length > 0 && (!ws || ws.readyState !== WebSocket.OPEN)) { + createConnection('SZSE'); + } + + if (newSzseCodes.size === 0 && ws) { + ws.close(); + wsRefs.current.SZSE = null; + } + } + + // 清理已取消订阅的 quotes(使用带后缀的完整代码) + const allNewCodes = new Set([...newSseCodes, ...newSzseCodes]); + setQuotes(prev => { + const updated: QuotesMap = {}; + Object.keys(prev).forEach(code => { + if (allNewCodes.has(code)) { + updated[code] = prev[code]; + } + }); + return updated; + }); + }, [codes, createConnection]); + + // 清理 + useEffect(() => { + return () => { + (['SSE', 'SZSE'] as Exchange[]).forEach(exchange => { + stopHeartbeat(exchange); + if (reconnectRefs.current[exchange]) { + clearTimeout(reconnectRefs.current[exchange]!); + } + const ws = wsRefs.current[exchange]; + if (ws) { + ws.close(); + } + }); + }; + }, [stopHeartbeat]); + + return { quotes, connected, subscribe, unsubscribe }; +}; + +export default useRealtimeQuote; diff --git a/src/views/StockOverview/components/FlexScreen/hooks/utils.ts b/src/views/StockOverview/components/FlexScreen/hooks/utils.ts new file mode 100644 index 00000000..dcc54890 --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/hooks/utils.ts @@ -0,0 +1,148 @@ +/** + * 实时行情相关工具函数 + */ + +import type { Exchange, OrderBookLevel } from '../types'; + +/** + * 判断证券代码属于哪个交易所 + * @param code - 证券代码(可带或不带后缀) + * @param isIndex - 是否为指数(用于区分同代码的指数和股票,如 000001) + * @returns 交易所标识 + */ +export const getExchange = (code: string, isIndex?: boolean): Exchange => { + // 如果已带后缀,直接判断 + if (code.includes('.')) { + return code.endsWith('.SH') ? 'SSE' : 'SZSE'; + } + + const baseCode = code; + + // 6开头为上海股票 + if (baseCode.startsWith('6')) { + return 'SSE'; + } + + // 5开头是上海 ETF + if (baseCode.startsWith('5')) { + return 'SSE'; + } + + // 399开头是深证指数 + if (baseCode.startsWith('399')) { + return 'SZSE'; + } + + // 000开头:如果是指数则为上交所(上证指数000001),否则为深交所(平安银行000001) + if (baseCode.startsWith('000')) { + return isIndex ? 
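+  /*
+   * 归属判定示意:
+   *   getExchange('000001', true)    // => 'SSE'(上证指数)
+   *   getExchange('000001', false)   // => 'SZSE'(平安银行)
+   *   getExchange('600519')          // => 'SSE'
+   *   getExchange('399006')          // => 'SZSE'(深证指数)
+   *   getExchange('300750.SZ')       // => 'SZSE'(带后缀按后缀直接判定)
+   */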
'SSE' : 'SZSE'; + } + + // 0、3开头是深圳股票 + if (baseCode.startsWith('0') || baseCode.startsWith('3')) { + return 'SZSE'; + } + + // 1开头是深圳 ETF/债券 + if (baseCode.startsWith('1')) { + return 'SZSE'; + } + + // 默认上海 + return 'SSE'; +}; + +/** + * 获取证券代码的完整格式(带交易所后缀) + * @param code - 原始代码 + * @param isIndex - 是否为指数 + * @returns 带后缀的代码 + */ +export const getFullCode = (code: string, isIndex?: boolean): string => { + if (code.includes('.')) { + return code; // 已带后缀 + } + const exchange = getExchange(code, isIndex); + return `${code}.${exchange === 'SSE' ? 'SH' : 'SZ'}`; +}; + +/** + * 标准化证券代码为无后缀格式 + * @param code - 原始代码 + * @returns 无后缀代码 + */ +export const normalizeCode = (code: string): string => { + return code.split('.')[0]; +}; + +/** + * 盘口数据可能的格式(根据不同的 WebSocket 服务端实现) + */ +type OrderBookInput = + | OrderBookLevel[] // 格式1: [{price, volume}, ...] + | Array<[number, number]> // 格式2: [[price, volume], ...] + | { prices: number[]; volumes: number[] } // 格式3: {prices: [...], volumes: [...]} + | undefined; + +/** + * 从深交所 bids/asks 数组提取价格和量数组 + * 支持多种可能的数据格式 + * @param orderBook - 盘口数据,支持多种格式 + * @returns { prices, volumes } + */ +export const extractOrderBook = ( + orderBook: OrderBookInput +): { prices: number[]; volumes: number[] } => { + if (!orderBook) { + return { prices: [], volumes: [] }; + } + + // 格式3: 已经是 {prices, volumes} 结构 + if (!Array.isArray(orderBook) && 'prices' in orderBook && 'volumes' in orderBook) { + return { + prices: orderBook.prices || [], + volumes: orderBook.volumes || [], + }; + } + + // 必须是数组才能继续 + if (!Array.isArray(orderBook) || orderBook.length === 0) { + return { prices: [], volumes: [] }; + } + + const firstItem = orderBook[0]; + + // 格式2: [[price, volume], ...] + if (Array.isArray(firstItem)) { + const prices = orderBook.map((item: unknown) => { + const arr = item as [number, number]; + return arr[0] || 0; + }); + const volumes = orderBook.map((item: unknown) => { + const arr = item as [number, number]; + return arr[1] || 0; + }); + return { prices, volumes }; + } + + // 格式1: [{price, volume}, ...] (标准格式) + if (typeof firstItem === 'object' && firstItem !== null) { + const typedBook = orderBook as OrderBookLevel[]; + const prices = typedBook.map(item => item.price || 0); + const volumes = typedBook.map(item => item.volume || 0); + return { prices, volumes }; + } + + return { prices: [], volumes: [] }; +}; + +/** + * 计算涨跌幅 + * @param price - 当前价 + * @param prevClose - 昨收价 + * @returns 涨跌幅百分比 + */ +export const calcChangePct = (price: number, prevClose: number): number => { + if (!prevClose || prevClose === 0) return 0; + return ((price - prevClose) / prevClose) * 100; +}; diff --git a/src/views/StockOverview/components/FlexScreen/index.tsx b/src/views/StockOverview/components/FlexScreen/index.tsx new file mode 100644 index 00000000..dfda551e --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/index.tsx @@ -0,0 +1,507 @@ +/** + * 灵活屏组件 + * 用户可自定义添加关注的指数/个股,实时显示行情 + * + * 功能: + * 1. 添加/删除自选证券 + * 2. 显示实时行情(通过 WebSocket) + * 3. 显示分时走势(结合 ClickHouse 历史数据) + * 4. 显示五档盘口(上交所5档,深交所10档) + * 5. 
本地存储自选列表 + */ +import React, { useState, useEffect, useCallback, useMemo } from 'react'; +import { + Box, + Card, + CardBody, + VStack, + HStack, + Heading, + Text, + Input, + InputGroup, + InputLeftElement, + InputRightElement, + IconButton, + SimpleGrid, + Flex, + Spacer, + Icon, + useColorModeValue, + useToast, + Badge, + Tooltip, + Collapse, + List, + ListItem, + Spinner, + Center, + Menu, + MenuButton, + MenuList, + MenuItem, + Tag, + TagLabel, +} from '@chakra-ui/react'; +import { + SearchIcon, + CloseIcon, + AddIcon, + ChevronDownIcon, + ChevronUpIcon, + SettingsIcon, +} from '@chakra-ui/icons'; +import { + FaDesktop, + FaTrash, + FaSync, + FaWifi, + FaExclamationCircle, +} from 'react-icons/fa'; + +import { useRealtimeQuote } from './hooks'; +import { getFullCode } from './hooks/utils'; +import QuoteTile from './components/QuoteTile'; +import { logger } from '@utils/logger'; +import type { WatchlistItem, ConnectionStatus } from './types'; + +// 本地存储 key +const STORAGE_KEY = 'flexscreen_watchlist'; + +// 默认自选列表 +const DEFAULT_WATCHLIST: WatchlistItem[] = [ + { code: '000001', name: '上证指数', isIndex: true }, + { code: '399001', name: '深证成指', isIndex: true }, + { code: '399006', name: '创业板指', isIndex: true }, +]; + +// 热门推荐 +const HOT_RECOMMENDATIONS: WatchlistItem[] = [ + { code: '000001', name: '上证指数', isIndex: true }, + { code: '399001', name: '深证成指', isIndex: true }, + { code: '399006', name: '创业板指', isIndex: true }, + { code: '399300', name: '沪深300', isIndex: true }, + { code: '600519', name: '贵州茅台', isIndex: false }, + { code: '000858', name: '五粮液', isIndex: false }, + { code: '300750', name: '宁德时代', isIndex: false }, + { code: '002594', name: '比亚迪', isIndex: false }, +]; + +/** 搜索结果项 */ +interface SearchResultItem { + stock_code: string; + stock_name: string; + isIndex?: boolean; + code?: string; + name?: string; +} + +/** 搜索 API 响应 */ +interface SearchApiResponse { + success: boolean; + data?: SearchResultItem[]; +} + +/** 连接状态信息 */ +interface ConnectionStatusInfo { + color: string; + text: string; +} + +/** + * FlexScreen 组件 + */ +const FlexScreen: React.FC = () => { + const toast = useToast(); + + // 自选列表 + const [watchlist, setWatchlist] = useState([]); + // 搜索状态 + const [searchQuery, setSearchQuery] = useState(''); + const [searchResults, setSearchResults] = useState([]); + const [isSearching, setIsSearching] = useState(false); + const [showResults, setShowResults] = useState(false); + // 面板状态 + const [isCollapsed, setIsCollapsed] = useState(false); + + // 颜色主题 + const cardBg = useColorModeValue('white', '#1a1a1a'); + const borderColor = useColorModeValue('gray.200', '#333'); + const textColor = useColorModeValue('gray.800', 'white'); + const subTextColor = useColorModeValue('gray.600', 'gray.400'); + const searchBg = useColorModeValue('gray.50', '#2a2a2a'); + const hoverBg = useColorModeValue('gray.100', '#333'); + + // 获取订阅的证券代码列表(带后缀,用于区分上证指数000001.SH和平安银行000001.SZ) + const subscribedCodes = useMemo(() => { + return watchlist.map(item => getFullCode(item.code, item.isIndex)); + }, [watchlist]); + + // WebSocket 实时行情 + const { quotes, connected } = useRealtimeQuote(subscribedCodes); + + // 从本地存储加载自选列表 + useEffect(() => { + try { + const saved = localStorage.getItem(STORAGE_KEY); + if (saved) { + const parsed = JSON.parse(saved) as WatchlistItem[]; + if (Array.isArray(parsed) && parsed.length > 0) { + setWatchlist(parsed); + return; + } + } + } catch (e) { + logger.warn('FlexScreen', '加载自选列表失败', e); + } + // 使用默认列表 + setWatchlist(DEFAULT_WATCHLIST); + }, []); + + // 
保存自选列表到本地存储 + useEffect(() => { + if (watchlist.length > 0) { + try { + localStorage.setItem(STORAGE_KEY, JSON.stringify(watchlist)); + } catch (e) { + logger.warn('FlexScreen', '保存自选列表失败', e); + } + } + }, [watchlist]); + + // 搜索证券 + const searchSecurities = useCallback(async (query: string): Promise => { + if (!query.trim()) { + setSearchResults([]); + setShowResults(false); + return; + } + + setIsSearching(true); + try { + const response = await fetch(`/api/stocks/search?q=${encodeURIComponent(query)}&limit=10`); + const data: SearchApiResponse = await response.json(); + + if (data.success) { + setSearchResults(data.data || []); + setShowResults(true); + } else { + setSearchResults([]); + } + } catch (e) { + logger.error('FlexScreen', '搜索失败', e); + setSearchResults([]); + } finally { + setIsSearching(false); + } + }, []); + + // 防抖搜索 + useEffect(() => { + const timer = setTimeout(() => { + searchSecurities(searchQuery); + }, 300); + return () => clearTimeout(timer); + }, [searchQuery, searchSecurities]); + + // 添加证券 + const addSecurity = useCallback( + (security: SearchResultItem | WatchlistItem): void => { + const code = 'stock_code' in security ? security.stock_code : security.code; + const name = 'stock_name' in security ? security.stock_name : security.name; + // 优先使用 API 返回的 isIndex 字段 + const isIndex = security.isIndex === true; + + // 生成唯一标识(带后缀的完整代码) + const fullCode = getFullCode(code, isIndex); + + // 检查是否已存在(使用带后缀的代码比较,避免上证指数和平安银行冲突) + if (watchlist.some(item => getFullCode(item.code, item.isIndex) === fullCode)) { + toast({ + title: '已在自选列表中', + status: 'info', + duration: 2000, + isClosable: true, + }); + return; + } + + // 添加到列表 + setWatchlist(prev => [...prev, { code, name, isIndex }]); + + toast({ + title: `已添加 ${name}${isIndex ? '(指数)' : ''}`, + status: 'success', + duration: 2000, + isClosable: true, + }); + + // 清空搜索 + setSearchQuery(''); + setShowResults(false); + }, + [watchlist, toast] + ); + + // 移除证券 + const removeSecurity = useCallback((code: string): void => { + setWatchlist(prev => prev.filter(item => item.code !== code)); + }, []); + + // 清空自选列表 + const clearWatchlist = useCallback((): void => { + setWatchlist([]); + localStorage.removeItem(STORAGE_KEY); + toast({ + title: '已清空自选列表', + status: 'info', + duration: 2000, + isClosable: true, + }); + }, [toast]); + + // 重置为默认列表 + const resetWatchlist = useCallback((): void => { + setWatchlist(DEFAULT_WATCHLIST); + toast({ + title: '已重置为默认列表', + status: 'success', + duration: 2000, + isClosable: true, + }); + }, [toast]); + + // 连接状态指示 + const isAnyConnected = connected.SSE || connected.SZSE; + const connectionStatus = useMemo((): ConnectionStatusInfo => { + if (connected.SSE && connected.SZSE) { + return { color: 'green', text: '上交所/深交所 已连接' }; + } + if (connected.SSE) { + return { color: 'yellow', text: '上交所 已连接' }; + } + if (connected.SZSE) { + return { color: 'yellow', text: '深交所 已连接' }; + } + return { color: 'red', text: '未连接' }; + }, [connected]); + + return ( + + + {/* 头部 */} + + + + + 灵活屏 + + + + + {isAnyConnected ? '实时' : '离线'} + + + + + + {/* 操作菜单 */} + + } + size="sm" + variant="ghost" + aria-label="设置" + /> + + } onClick={resetWatchlist}> + 重置为默认 + + } onClick={clearWatchlist} color="red.500"> + 清空列表 + + + + {/* 折叠按钮 */} + : } + size="sm" + variant="ghost" + onClick={() => setIsCollapsed(!isCollapsed)} + aria-label={isCollapsed ? 
'展开' : '收起'} + /> + + + + {/* 可折叠内容 */} + + {/* 搜索框 */} + + + + + + setSearchQuery(e.target.value)} + bg={searchBg} + borderRadius="lg" + _focus={{ + borderColor: 'purple.400', + boxShadow: '0 0 0 1px var(--chakra-colors-purple-400)', + }} + /> + {searchQuery && ( + + } + variant="ghost" + onClick={() => { + setSearchQuery(''); + setShowResults(false); + }} + aria-label="清空" + /> + + )} + + + {/* 搜索结果下拉 */} + + + {isSearching ? ( +
+ +
+ ) : searchResults.length > 0 ? ( + + {searchResults.map((stock, index) => ( + addSecurity(stock)} + borderBottomWidth={index < searchResults.length - 1 ? '1px' : '0'} + borderColor={borderColor} + > + + + + + {stock.stock_name} + + + {stock.isIndex ? '指数' : '股票'} + + + + {stock.stock_code} + + + } + size="xs" + colorScheme="purple" + variant="ghost" + aria-label="添加" + /> + + + ))} + + ) : ( +
+ + 未找到相关证券 + +
+ )} +
+
+
+ + {/* 快捷添加 */} + {watchlist.length === 0 && ( + + + 热门推荐(点击添加) + + + {HOT_RECOMMENDATIONS.map(item => ( + addSecurity(item)} + > + {item.name} + + ))} + + + )} + + {/* 自选列表 */} + {watchlist.length > 0 ? ( + + {watchlist.map(item => { + const fullCode = getFullCode(item.code, item.isIndex); + return ( + + ); + })} + + ) : ( +
+ + + 自选列表为空,请搜索添加证券 + +
+ )} +
+
+
+ ); +}; + +export default FlexScreen; diff --git a/src/views/StockOverview/components/FlexScreen/types.ts b/src/views/StockOverview/components/FlexScreen/types.ts new file mode 100644 index 00000000..796290f3 --- /dev/null +++ b/src/views/StockOverview/components/FlexScreen/types.ts @@ -0,0 +1,322 @@ +/** + * 灵活屏组件类型定义 + */ + +// ==================== WebSocket 相关类型 ==================== + +/** 交易所标识 */ +export type Exchange = 'SSE' | 'SZSE'; + +/** WebSocket 连接状态 */ +export interface ConnectionStatus { + SSE: boolean; + SZSE: boolean; +} + +/** 盘口档位数据 */ +export interface OrderBookLevel { + price: number; + volume: number; +} + +// ==================== 行情数据类型 ==================== + +/** 盘后交易数据 */ +export interface AfterhoursData { + bidPx: number; + bidSize: number; + offerPx: number; + offerSize: number; + volume: number; + amount: number; + numTrades: number; +} + +/** 基础行情数据 */ +export interface BaseQuoteData { + code: string; + name: string; + price: number; + prevClose: number; + open: number; + high: number; + low: number; + volume: number; + amount: number; + change: number; + changePct: number; + updateTime?: string; + exchange: Exchange; +} + +/** 股票行情数据 */ +export interface StockQuoteData extends BaseQuoteData { + numTrades?: number; + upperLimit?: number; // 涨停价 + lowerLimit?: number; // 跌停价 + bidPrices: number[]; + bidVolumes: number[]; + askPrices: number[]; + askVolumes: number[]; + tradingPhase?: string; + afterhours?: AfterhoursData; // 盘后交易数据 +} + +/** 指数行情数据 */ +export interface IndexQuoteData extends BaseQuoteData { + close?: number; + numTrades?: number; + bidPrices: number[]; + bidVolumes: number[]; + askPrices: number[]; + askVolumes: number[]; + tradingPhase?: string; +} + +/** 债券行情数据 */ +export interface BondQuoteData extends BaseQuoteData { + numTrades?: number; + weightedAvgPx?: number; + bidPrices: number[]; + bidVolumes: number[]; + askPrices: number[]; + askVolumes: number[]; + tradingPhase?: string; + isBond: true; +} + +/** 港股行情数据 */ +export interface HKStockQuoteData extends BaseQuoteData { + numTrades?: number; + nominalPx?: number; // 按盘价 + referencePx?: number; // 参考价 + bidPrices: number[]; + bidVolumes: number[]; + askPrices: number[]; + askVolumes: number[]; + tradingPhase?: string; + isHK: true; +} + +/** 统一行情数据类型 */ +export type QuoteData = StockQuoteData | IndexQuoteData | BondQuoteData | HKStockQuoteData; + +/** 行情数据字典 */ +export interface QuotesMap { + [code: string]: QuoteData; +} + +// ==================== 上交所 WebSocket 消息类型 ==================== + +/** 上交所行情数据 */ +export interface SSEQuoteItem { + security_id: string; + security_name: string; + prev_close: number; + open_price: number; + high_price: number; + low_price: number; + last_price: number; + close_price: number; + volume: number; + amount: number; + bid_prices?: number[]; + bid_volumes?: number[]; + ask_prices?: number[]; + ask_volumes?: number[]; + trading_status?: string; + trade_time?: string; + update_time?: string; +} + +/** 上交所消息 */ +export interface SSEMessage { + type: 'stock' | 'index' | 'etf' | 'bond' | 'option' | 'subscribed' | 'pong' | 'error'; + timestamp?: string; + data?: Record; + channels?: string[]; + message?: string; +} + +// ==================== 深交所 WebSocket 消息类型 ==================== + +/** 深交所数据类别 */ +export type SZSECategory = + | 'stock' // 300111 股票快照 + | 'bond' // 300211 债券快照 + | 'afterhours_block' // 300611 盘后定价大宗交易 + | 'afterhours_trading' // 303711 盘后定价交易 + | 'hk_stock' // 306311 港股快照 + | 'index' // 309011 指数快照 + | 'volume_stats' // 309111 成交量统计 + | 
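+  /*
+   * 按 category 收窄 data 的用法示意(useRealtimeQuote.ts 当前以 as 断言收窄):
+   *   const onRealtime = (msg: SZSERealtimeMessage) => {
+   *     if (msg.category === 'stock') {
+   *       const d = msg.data as SZSEStockData;
+   *       console.log(d.security_id, d.last_px, d.upper_limit);
+   *     }
+   *   };
+   */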
'fund_nav'; // 309211 基金净值 + +/** 深交所股票行情数据 */ +export interface SZSEStockData { + security_id: string; + orig_time?: number; + channel_no?: number; + trading_phase?: string; + last_px: number; + open_px: number; + high_px: number; + low_px: number; + prev_close: number; + volume: number; + amount: number; + num_trades?: number; + upper_limit?: number; + lower_limit?: number; + bids?: OrderBookLevel[]; + asks?: OrderBookLevel[]; +} + +/** 深交所指数行情数据 */ +export interface SZSEIndexData { + security_id: string; + orig_time?: number; + channel_no?: number; + trading_phase?: string; + current_index: number; + open_index: number; + high_index: number; + low_index: number; + close_index?: number; + prev_close: number; + volume: number; + amount: number; + num_trades?: number; +} + +/** 深交所债券行情数据 */ +export interface SZSEBondData { + security_id: string; + orig_time?: number; + channel_no?: number; + trading_phase?: string; + last_px: number; + open_px: number; + high_px: number; + low_px: number; + prev_close: number; + weighted_avg_px?: number; + volume: number; + amount: number; + num_trades?: number; + auction_volume?: number; + auction_amount?: number; +} + +/** 深交所港股行情数据 */ +export interface SZSEHKStockData { + security_id: string; + orig_time?: number; + channel_no?: number; + trading_phase?: string; + last_px: number; + open_px: number; + high_px: number; + low_px: number; + prev_close: number; + nominal_px?: number; + reference_px?: number; + volume: number; + amount: number; + num_trades?: number; + vcm_start_time?: number; + vcm_end_time?: number; + bids?: OrderBookLevel[]; + asks?: OrderBookLevel[]; +} + +/** 深交所盘后交易数据 */ +export interface SZSEAfterhoursData { + security_id: string; + orig_time?: number; + channel_no?: number; + trading_phase?: string; + prev_close: number; + bid_px: number; + bid_size: number; + offer_px: number; + offer_size: number; + volume: number; + amount: number; + num_trades?: number; +} + +/** 深交所实时消息 */ +export interface SZSERealtimeMessage { + type: 'realtime'; + category: SZSECategory; + msg_type?: number; + timestamp: string; + data: SZSEStockData | SZSEIndexData | SZSEBondData | SZSEHKStockData | SZSEAfterhoursData; +} + +/** 深交所快照消息 */ +export interface SZSESnapshotMessage { + type: 'snapshot'; + timestamp: string; + data: { + stocks?: SZSEStockData[]; + indexes?: SZSEIndexData[]; + bonds?: SZSEBondData[]; + }; +} + +/** 深交所消息类型 */ +export type SZSEMessage = SZSERealtimeMessage | SZSESnapshotMessage | { type: 'pong' }; + +// ==================== 组件 Props 类型 ==================== + +/** 自选证券项 */ +export interface WatchlistItem { + code: string; + name: string; + isIndex: boolean; +} + +/** QuoteTile 组件 Props */ +export interface QuoteTileProps { + code: string; + name: string; + quote: Partial; + isIndex?: boolean; + onRemove?: (code: string) => void; +} + +/** OrderBookPanel 组件 Props */ +export interface OrderBookPanelProps { + bidPrices?: number[]; + bidVolumes?: number[]; + askPrices?: number[]; + askVolumes?: number[]; + prevClose?: number; + upperLimit?: number; + lowerLimit?: number; + defaultLevels?: number; +} + +/** MiniTimelineChart 组件 Props */ +export interface MiniTimelineChartProps { + code: string; + isIndex?: boolean; + prevClose?: number; + currentPrice?: number; + height?: number; +} + +/** 分时数据点 */ +export interface TimelineDataPoint { + time: string; + price: number; +} + +/** useRealtimeQuote Hook 返回值 */ +export interface UseRealtimeQuoteReturn { + quotes: QuotesMap; + connected: ConnectionStatus; + subscribe: (code: string) => void; + 
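+  /*
+   * 使用示意(传入带交易所后缀的代码;quotes 的 key 亦为带后缀代码):
+   *   const { quotes, connected, subscribe, unsubscribe } =
+   *     useRealtimeQuote(['000001.SH', '000001.SZ', '600519.SH']);
+   *   const pingAn = quotes['000001.SZ'];  // 平安银行,而非上证指数
+   *   if (connected.SZSE && pingAn) console.log(pingAn.price, pingAn.changePct);
+   */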
unsubscribe: (code: string) => void; +} diff --git a/src/views/StockOverview/components/HotspotOverview/components/AlertSummary.js b/src/views/StockOverview/components/HotspotOverview/components/AlertSummary.js new file mode 100644 index 00000000..37a871c6 --- /dev/null +++ b/src/views/StockOverview/components/HotspotOverview/components/AlertSummary.js @@ -0,0 +1,147 @@ +/** + * 异动统计摘要组件 + * 展示指数统计和异动类型统计 + */ +import React from 'react'; +import { + Box, + HStack, + VStack, + Text, + Badge, + Icon, + Stat, + StatLabel, + StatNumber, + StatHelpText, + StatArrow, + SimpleGrid, + useColorModeValue, +} from '@chakra-ui/react'; +import { FaBolt, FaArrowDown, FaRocket, FaChartLine, FaFire, FaVolumeUp } from 'react-icons/fa'; + +/** + * 异动类型徽章 + */ +const AlertTypeBadge = ({ type, count }) => { + const config = { + surge: { label: '急涨', color: 'red', icon: FaBolt }, + surge_up: { label: '暴涨', color: 'red', icon: FaBolt }, + surge_down: { label: '暴跌', color: 'green', icon: FaArrowDown }, + limit_up: { label: '涨停', color: 'orange', icon: FaRocket }, + rank_jump: { label: '排名跃升', color: 'blue', icon: FaChartLine }, + volume_spike: { label: '放量', color: 'purple', icon: FaVolumeUp }, + }; + + const cfg = config[type] || { label: type, color: 'gray', icon: FaFire }; + + return ( + + + + {cfg.label} + {count} + + + ); +}; + +/** + * 指数统计卡片 + */ +const IndexStatCard = ({ indexData }) => { + const cardBg = useColorModeValue('white', '#1a1a1a'); + const borderColor = useColorModeValue('gray.200', '#333'); + const subTextColor = useColorModeValue('gray.600', 'gray.400'); + + if (!indexData) return null; + + const changePct = indexData.change_pct || 0; + const isUp = changePct >= 0; + + return ( + + + {indexData.name || '上证指数'} + + {indexData.latest_price?.toFixed(2) || '-'} + + + + {changePct?.toFixed(2)}% + + + + + 最高 + + {indexData.high?.toFixed(2) || '-'} + + + + + 最低 + + {indexData.low?.toFixed(2) || '-'} + + + + + 振幅 + + {indexData.high && indexData.low && indexData.prev_close + ? (((indexData.high - indexData.low) / indexData.prev_close) * 100).toFixed(2) + '%' + : '-'} + + + + ); +}; + +/** + * 异动统计摘要 + * @param {Object} props + * @param {Object} props.indexData - 指数数据 + * @param {Array} props.alerts - 异动数组 + * @param {Object} props.alertSummary - 异动类型统计 + */ +const AlertSummary = ({ indexData, alerts = [], alertSummary = {} }) => { + const cardBg = useColorModeValue('white', '#1a1a1a'); + const borderColor = useColorModeValue('gray.200', '#333'); + + // 如果没有 alertSummary,从 alerts 中统计 + const summary = alertSummary && Object.keys(alertSummary).length > 0 + ? 
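+  /*
+   * 归并示意:
+   *   [{ alert_type: 'surge_up' }, { alert_type: 'surge_up' }, { alert_type: 'limit_up' }]
+   *   // => { surge_up: 2, limit_up: 1 },随后由对应的 AlertTypeBadge 渲染计数
+   */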
alertSummary + : alerts.reduce((acc, alert) => { + const type = alert.alert_type || 'unknown'; + acc[type] = (acc[type] || 0) + 1; + return acc; + }, {}); + + const totalAlerts = alerts.length; + + return ( + + {/* 指数统计 */} + + + {/* 异动统计 */} + {totalAlerts > 0 && ( + + + 异动 {totalAlerts} 次: + + {(summary.surge_up > 0 || summary.surge > 0) && ( + + )} + {summary.surge_down > 0 && } + {summary.limit_up > 0 && } + {summary.volume_spike > 0 && } + {summary.rank_jump > 0 && } + + )} + + ); +}; + +export default AlertSummary; diff --git a/src/views/StockOverview/components/HotspotOverview/components/ConceptAlertList.js b/src/views/StockOverview/components/HotspotOverview/components/ConceptAlertList.js new file mode 100644 index 00000000..c1613aba --- /dev/null +++ b/src/views/StockOverview/components/HotspotOverview/components/ConceptAlertList.js @@ -0,0 +1,356 @@ +/** + * 概念异动列表组件 - V2 + * 展示当日的概念异动记录,点击可展开显示相关股票 + */ +import React, { useState, useCallback } from 'react'; +import { + Box, + VStack, + HStack, + Text, + Badge, + Icon, + Tooltip, + useColorModeValue, + Flex, + Collapse, + Spinner, + Progress, + Table, + Thead, + Tbody, + Tr, + Th, + Td, + TableContainer, +} from '@chakra-ui/react'; +import { FaArrowUp, FaArrowDown, FaFire, FaChevronDown, FaChevronRight } from 'react-icons/fa'; +import { useNavigate } from 'react-router-dom'; +import axios from 'axios'; +import { getAlertTypeLabel, formatScore, getScoreColor } from '../utils/chartHelpers'; + +/** + * 紧凑型异动卡片 + */ +const AlertCard = ({ alert, isExpanded, onToggle, stocks, loadingStocks }) => { + const navigate = useNavigate(); + const bgColor = useColorModeValue('white', '#1a1a1a'); + const hoverBg = useColorModeValue('gray.50', '#252525'); + const borderColor = useColorModeValue('gray.200', '#333'); + const expandedBg = useColorModeValue('purple.50', '#1e1e2e'); + + const isUp = alert.alert_type !== 'surge_down'; + const typeColor = isUp ? 'red' : 'green'; + const isV2 = alert.is_v2; + + // 点击股票跳转 + const handleStockClick = (e, stockCode) => { + e.stopPropagation(); + navigate(`/company?scode=${stockCode}`); + }; + + return ( + + {/* 主卡片 - 点击展开 */} + + + {/* 左侧:名称 + 类型 */} + + + + + {alert.concept_name} + + {isV2 && ( + + V2 + + )} + + + {/* 右侧:分数 */} + + {formatScore(alert.final_score)}分 + + + + {/* 第二行:时间 + 关键指标 */} + + + {alert.time} + + {getAlertTypeLabel(alert.alert_type)} + + {/* 确认率 */} + {isV2 && alert.confirm_ratio != null && ( + + + = 0.8 ? 'green.500' : 'orange.500'} + /> + + {Math.round((alert.confirm_ratio || 0) * 100)}% + + )} + + + {/* Alpha + Z-Score 简化显示 */} + + {alert.alpha != null && ( + = 0 ? 'red.500' : 'green.500'} fontWeight="medium"> + α {(alert.alpha || 0) >= 0 ? '+' : ''}{(alert.alpha || 0).toFixed(2)}% + + )} + {isV2 && alert.alpha_zscore != null && ( + + + + = 0 ? '50%' : undefined} + right={(alert.alpha_zscore || 0) < 0 ? '50%' : undefined} + w={`${Math.min(Math.abs(alert.alpha_zscore || 0) / 5 * 50, 50)}%`} + h="100%" + bg={(alert.alpha_zscore || 0) >= 0 ? 'red.500' : 'green.500'} + /> + + = 0 ? 'red.400' : 'green.400'}> + {(alert.alpha_zscore || 0) >= 0 ? '+' : ''}{(alert.alpha_zscore || 0).toFixed(1)}σ + + + + )} + {(alert.limit_up_ratio || 0) > 0.05 && ( + + + {Math.round((alert.limit_up_ratio || 0) * 100)}% + + )} + + + + + {/* 展开的股票列表 */} + + + {loadingStocks ? ( + + + 加载相关股票... + + ) : stocks && stocks.length > 0 ? 
( + + + + + + + + + + + {stocks.slice(0, 10).map((stock, idx) => { + const changePct = stock.change_pct; + const hasChange = changePct != null && !isNaN(changePct); + return ( + handleStockClick(e, stock.code || stock.stock_code)} + > + + + + + ); + })} + +
股票涨跌原因
+ + {stock.name || stock.stock_name || '-'} + + + 0 ? 'red.400' : + hasChange && changePct < 0 ? 'green.400' : 'gray.400' + } + > + {hasChange + ? `${changePct > 0 ? '+' : ''}${changePct.toFixed(2)}%` + : '-' + } + + + + {stock.reason || '-'} + +
+ {stocks.length > 10 && ( + + 共 {stocks.length} 只相关股票,显示前 10 只 + + )} +
+ ) : ( + + 暂无相关股票数据 + + )} +
+
+
+ ); +}; + +/** + * 概念异动列表 + */ +const ConceptAlertList = ({ alerts = [], onAlertClick, selectedAlert, maxHeight = '400px' }) => { + const [expandedId, setExpandedId] = useState(null); + const [conceptStocks, setConceptStocks] = useState({}); + const [loadingConcepts, setLoadingConcepts] = useState({}); + + const subTextColor = useColorModeValue('gray.500', 'gray.400'); + + // 获取概念相关股票 + const fetchConceptStocks = useCallback(async (conceptId) => { + if (conceptStocks[conceptId] || loadingConcepts[conceptId]) { + return; + } + + setLoadingConcepts(prev => ({ ...prev, [conceptId]: true })); + + try { + // 调用后端 API 获取概念股票 + const response = await axios.get(`/api/concept/${conceptId}/stocks`); + if (response.data?.success && response.data?.data?.stocks) { + setConceptStocks(prev => ({ + ...prev, + [conceptId]: response.data.data.stocks + })); + } + } catch (error) { + console.error('获取概念股票失败:', error); + // 如果 API 失败,尝试从 ES 直接获取 + try { + const esResponse = await axios.get(`/api/es/concept/${conceptId}`); + if (esResponse.data?.stocks) { + setConceptStocks(prev => ({ + ...prev, + [conceptId]: esResponse.data.stocks + })); + } + } catch (esError) { + console.error('ES 获取也失败:', esError); + setConceptStocks(prev => ({ ...prev, [conceptId]: [] })); + } + } finally { + setLoadingConcepts(prev => ({ ...prev, [conceptId]: false })); + } + }, [conceptStocks, loadingConcepts]); + + // 切换展开状态 + const handleToggle = useCallback((alert) => { + const alertKey = `${alert.concept_id}-${alert.time}`; + + if (expandedId === alertKey) { + setExpandedId(null); + } else { + setExpandedId(alertKey); + // 获取股票数据 + if (alert.concept_id) { + fetchConceptStocks(alert.concept_id); + } + } + + // 通知父组件 + onAlertClick?.(alert); + }, [expandedId, fetchConceptStocks, onAlertClick]); + + if (!alerts || alerts.length === 0) { + return ( + + + 当日暂无概念异动 + + + ); + } + + // 按时间倒序排列 + const sortedAlerts = [...alerts].sort((a, b) => { + const timeA = a.time || '00:00'; + const timeB = b.time || '00:00'; + return timeB.localeCompare(timeA); + }); + + return ( + + + {sortedAlerts.map((alert, idx) => { + const alertKey = `${alert.concept_id}-${alert.time}`; + return ( + handleToggle(alert)} + stocks={conceptStocks[alert.concept_id]} + loadingStocks={loadingConcepts[alert.concept_id]} + /> + ); + })} + + + ); +}; + +export default ConceptAlertList; diff --git a/src/views/StockOverview/components/HotspotOverview/components/IndexMinuteChart.js b/src/views/StockOverview/components/HotspotOverview/components/IndexMinuteChart.js new file mode 100644 index 00000000..a7c360a5 --- /dev/null +++ b/src/views/StockOverview/components/HotspotOverview/components/IndexMinuteChart.js @@ -0,0 +1,264 @@ +/** + * 指数分时图组件 + * 展示大盘分时走势,支持概念异动标注 + */ +import React, { useRef, useEffect, useCallback, useMemo } from 'react'; +import { Box, useColorModeValue } from '@chakra-ui/react'; +import * as echarts from 'echarts'; +import { getAlertMarkPoints } from '../utils/chartHelpers'; + +/** + * @param {Object} props + * @param {Object} props.indexData - 指数数据 { timeline, prev_close, name, ... 
} + * @param {Array} props.alerts - 异动数据数组 + * @param {Function} props.onAlertClick - 点击异动标注的回调 + * @param {string} props.height - 图表高度 + */ +const IndexMinuteChart = ({ indexData, alerts = [], onAlertClick, height = '350px' }) => { + const chartRef = useRef(null); + const chartInstance = useRef(null); + + const textColor = useColorModeValue('gray.800', 'white'); + const subTextColor = useColorModeValue('gray.600', 'gray.400'); + const gridLineColor = useColorModeValue('#eee', '#333'); + + // 计算图表配置 + const chartOption = useMemo(() => { + if (!indexData || !indexData.timeline || indexData.timeline.length === 0) { + return null; + } + + const timeline = indexData.timeline || []; + const times = timeline.map((d) => d.time); + const prices = timeline.map((d) => d.price); + const volumes = timeline.map((d) => d.volume); + const changePcts = timeline.map((d) => d.change_pct); + + // 计算Y轴范围 + const validPrices = prices.filter(Boolean); + if (validPrices.length === 0) return null; + + const priceMin = Math.min(...validPrices); + const priceMax = Math.max(...validPrices); + const priceRange = priceMax - priceMin; + const yAxisMin = priceMin - priceRange * 0.1; + const yAxisMax = priceMax + priceRange * 0.25; // 上方留更多空间给标注 + + // 准备异动标注 + const markPoints = getAlertMarkPoints(alerts, times, prices, priceMax); + + // 渐变色 - 根据涨跌 + const latestChangePct = changePcts[changePcts.length - 1] || 0; + const isUp = latestChangePct >= 0; + const lineColor = isUp ? '#ff4d4d' : '#22c55e'; + const areaColorStops = isUp + ? [ + { offset: 0, color: 'rgba(255, 77, 77, 0.4)' }, + { offset: 1, color: 'rgba(255, 77, 77, 0.05)' }, + ] + : [ + { offset: 0, color: 'rgba(34, 197, 94, 0.4)' }, + { offset: 1, color: 'rgba(34, 197, 94, 0.05)' }, + ]; + + return { + backgroundColor: 'transparent', + tooltip: { + trigger: 'axis', + axisPointer: { + type: 'cross', + crossStyle: { color: '#999' }, + }, + formatter: (params) => { + if (!params || params.length === 0) return ''; + + const dataIndex = params[0].dataIndex; + const time = times[dataIndex]; + const price = prices[dataIndex]; + const changePct = changePcts[dataIndex]; + const volume = volumes[dataIndex]; + + let html = ` +
+              <div style="padding: 2px 6px;">
+                <div style="font-weight: bold; margin-bottom: 2px;">${time}</div>
+                <div>指数: ${price?.toFixed(2)}</div>
+                <div>涨跌: ${changePct >= 0 ? '+' : ''}${changePct?.toFixed(2)}%</div>
+                <div>成交量: ${(volume / 10000).toFixed(0)}万手</div>
+              </div>
+            `;
+
+            // 检查是否有异动
+            const alertsAtTime = alerts.filter((a) => a.time === time);
+            if (alertsAtTime.length > 0) {
+              html += '<div style="margin-top: 4px; padding: 2px 6px; border-top: 1px solid #ccc;">';
+              html += '<div style="font-weight: bold;">概念异动:</div>';
+              alertsAtTime.forEach((alert) => {
+                const typeLabel = {
+                  surge: '急涨',
+                  surge_up: '暴涨',
+                  surge_down: '暴跌',
+                  limit_up: '涨停增加',
+                  rank_jump: '排名跃升',
+                  volume_spike: '放量',
+                }[alert.alert_type] || alert.alert_type;
+                const typeColor = alert.alert_type === 'surge_down' ? '#2ed573' : '#ff6b6b';
+                const alpha = alert.alpha ? ` (α${alert.alpha > 0 ? '+' : ''}${alert.alpha.toFixed(2)}%)` : '';
+                html += `<div style="color: ${typeColor};">• ${alert.concept_name} (${typeLabel}${alpha})</div>`;
+              });
+              html += '</div>
'; + } + + return html; + }, + }, + legend: { show: false }, + grid: [ + { left: '8%', right: '3%', top: '8%', height: '58%' }, + { left: '8%', right: '3%', top: '72%', height: '18%' }, + ], + xAxis: [ + { + type: 'category', + data: times, + axisLine: { lineStyle: { color: gridLineColor } }, + axisLabel: { + color: subTextColor, + fontSize: 10, + interval: Math.floor(times.length / 6), + }, + axisTick: { show: false }, + splitLine: { show: false }, + }, + { + type: 'category', + gridIndex: 1, + data: times, + axisLine: { lineStyle: { color: gridLineColor } }, + axisLabel: { show: false }, + axisTick: { show: false }, + splitLine: { show: false }, + }, + ], + yAxis: [ + { + type: 'value', + min: yAxisMin, + max: yAxisMax, + axisLine: { show: false }, + axisLabel: { + color: subTextColor, + fontSize: 10, + formatter: (val) => val.toFixed(0), + }, + splitLine: { lineStyle: { color: gridLineColor, type: 'dashed' } }, + axisPointer: { + label: { + formatter: (params) => { + if (!indexData.prev_close) return params.value.toFixed(2); + const pct = ((params.value - indexData.prev_close) / indexData.prev_close) * 100; + return `${params.value.toFixed(2)} (${pct >= 0 ? '+' : ''}${pct.toFixed(2)}%)`; + }, + }, + }, + }, + { + type: 'value', + gridIndex: 1, + axisLine: { show: false }, + axisLabel: { show: false }, + splitLine: { show: false }, + }, + ], + series: [ + // 分时线 + { + name: indexData.name || '上证指数', + type: 'line', + data: prices, + smooth: true, + symbol: 'none', + lineStyle: { color: lineColor, width: 1.5 }, + areaStyle: { + color: new echarts.graphic.LinearGradient(0, 0, 0, 1, areaColorStops), + }, + markPoint: { + symbol: 'pin', + symbolSize: 40, + data: markPoints, + animation: true, + }, + }, + // 成交量 + { + name: '成交量', + type: 'bar', + xAxisIndex: 1, + yAxisIndex: 1, + data: volumes.map((v, i) => ({ + value: v, + itemStyle: { + color: changePcts[i] >= 0 ? 
'rgba(255, 77, 77, 0.6)' : 'rgba(34, 197, 94, 0.6)', + }, + })), + barWidth: '60%', + }, + ], + }; + }, [indexData, alerts, subTextColor, gridLineColor]); + + // 渲染图表 + const renderChart = useCallback(() => { + if (!chartRef.current || !chartOption) return; + + if (!chartInstance.current) { + chartInstance.current = echarts.init(chartRef.current); + } + + chartInstance.current.setOption(chartOption, true); + + // 点击事件 + if (onAlertClick) { + chartInstance.current.off('click'); + chartInstance.current.on('click', 'series.line.markPoint', (params) => { + if (params.data && params.data.alertData) { + onAlertClick(params.data.alertData); + } + }); + } + }, [chartOption, onAlertClick]); + + // 数据变化时重新渲染 + useEffect(() => { + renderChart(); + }, [renderChart]); + + // 窗口大小变化时重新渲染 + useEffect(() => { + const handleResize = () => { + if (chartInstance.current) { + chartInstance.current.resize(); + } + }; + + window.addEventListener('resize', handleResize); + return () => { + window.removeEventListener('resize', handleResize); + if (chartInstance.current) { + chartInstance.current.dispose(); + chartInstance.current = null; + } + }; + }, []); + + if (!chartOption) { + return ( + + 暂无数据 + + ); + } + + return ; +}; + +export default IndexMinuteChart; diff --git a/src/views/StockOverview/components/HotspotOverview/components/index.js b/src/views/StockOverview/components/HotspotOverview/components/index.js new file mode 100644 index 00000000..2401b9bd --- /dev/null +++ b/src/views/StockOverview/components/HotspotOverview/components/index.js @@ -0,0 +1,3 @@ +export { default as IndexMinuteChart } from './IndexMinuteChart'; +export { default as ConceptAlertList } from './ConceptAlertList'; +export { default as AlertSummary } from './AlertSummary'; diff --git a/src/views/StockOverview/components/HotspotOverview/hooks/index.js b/src/views/StockOverview/components/HotspotOverview/hooks/index.js new file mode 100644 index 00000000..43a31148 --- /dev/null +++ b/src/views/StockOverview/components/HotspotOverview/hooks/index.js @@ -0,0 +1 @@ +export { useHotspotData } from './useHotspotData'; diff --git a/src/views/StockOverview/components/HotspotOverview/hooks/useHotspotData.js b/src/views/StockOverview/components/HotspotOverview/hooks/useHotspotData.js new file mode 100644 index 00000000..2fbca02a --- /dev/null +++ b/src/views/StockOverview/components/HotspotOverview/hooks/useHotspotData.js @@ -0,0 +1,53 @@ +/** + * 热点概览数据获取 Hook + * 负责获取指数分时数据和概念异动数据 + */ +import { useState, useEffect, useCallback } from 'react'; +import { logger } from '@utils/logger'; + +/** + * @param {Date|null} selectedDate - 选中的交易日期 + * @returns {Object} 数据和状态 + */ +export const useHotspotData = (selectedDate) => { + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + const [data, setData] = useState(null); + + const fetchData = useCallback(async () => { + setLoading(true); + setError(null); + + try { + const dateParam = selectedDate + ? 
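+    /*
+     * 返回数据形态示意(以 /api/market/hotspot-overview 实际返回为准;下游
+     * HotspotOverview 依赖 data.index / data.alerts / data.alert_summary):
+     *   { success: true,
+     *     data: { index: { name, prev_close, timeline: [{ time, price, volume, change_pct }, ...] },
+     *             alerts: [{ concept_name, alert_type, time, final_score, ... }],
+     *             alert_summary: { surge_up: 2, limit_up: 1 } } }
+     */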
`?date=${selectedDate.toISOString().split('T')[0]}` + : ''; + const response = await fetch(`/api/market/hotspot-overview${dateParam}`); + const result = await response.json(); + + if (result.success) { + setData(result.data); + } else { + setError(result.error || '获取数据失败'); + } + } catch (err) { + logger.error('useHotspotData', 'fetchData', err); + setError('网络请求失败'); + } finally { + setLoading(false); + } + }, [selectedDate]); + + useEffect(() => { + fetchData(); + }, [fetchData]); + + return { + loading, + error, + data, + refetch: fetchData, + }; +}; + +export default useHotspotData; diff --git a/src/views/StockOverview/components/HotspotOverview/index.js b/src/views/StockOverview/components/HotspotOverview/index.js new file mode 100644 index 00000000..ccef8cbb --- /dev/null +++ b/src/views/StockOverview/components/HotspotOverview/index.js @@ -0,0 +1,198 @@ +/** + * 热点概览组件 + * 展示大盘分时走势 + 概念异动标注 + * + * 模块化结构: + * - hooks/useHotspotData.js - 数据获取 + * - components/IndexMinuteChart.js - 分时图 + * - components/ConceptAlertList.js - 异动列表 + * - components/AlertSummary.js - 统计摘要 + * - utils/chartHelpers.js - 图表辅助函数 + */ +import React, { useState, useCallback } from 'react'; +import { + Box, + Card, + CardBody, + Heading, + Text, + HStack, + VStack, + Spinner, + Center, + Icon, + Flex, + Spacer, + Tooltip, + useColorModeValue, + Grid, + GridItem, + Divider, + IconButton, + Collapse, +} from '@chakra-ui/react'; +import { FaFire, FaList, FaChartArea, FaChevronDown, FaChevronUp } from 'react-icons/fa'; +import { InfoIcon } from '@chakra-ui/icons'; + +import { useHotspotData } from './hooks'; +import { IndexMinuteChart, ConceptAlertList, AlertSummary } from './components'; + +/** + * 热点概览主组件 + * @param {Object} props + * @param {Date|null} props.selectedDate - 选中的交易日期 + */ +const HotspotOverview = ({ selectedDate }) => { + const [selectedAlert, setSelectedAlert] = useState(null); + const [showAlertList, setShowAlertList] = useState(true); + + // 获取数据 + const { loading, error, data } = useHotspotData(selectedDate); + + // 颜色主题 + const cardBg = useColorModeValue('white', '#1a1a1a'); + const borderColor = useColorModeValue('gray.200', '#333333'); + const textColor = useColorModeValue('gray.800', 'white'); + const subTextColor = useColorModeValue('gray.600', 'gray.400'); + + // 点击异动标注 + const handleAlertClick = useCallback((alert) => { + setSelectedAlert(alert); + // 可以在这里添加滚动到对应位置的逻辑 + }, []); + + // 渲染加载状态 + if (loading) { + return ( + + +
+ + + 加载热点概览数据... + +
+
+
+ ); + } + + // 渲染错误状态 + if (error) { + return ( + + +
+ + + {error} + +
+
+
+ ); + } + + // 无数据 + if (!data) { + return null; + } + + const { index, alerts, alert_summary } = data; + + return ( + + + {/* 头部 */} + + + + + 热点概览 + + + + + + : } + size="sm" + variant="ghost" + onClick={() => setShowAlertList(!showAlertList)} + aria-label="切换异动列表" + /> + + + + + + + + {/* 统计摘要 */} + + + + + + + {/* 主体内容:图表 + 异动列表 */} + + {/* 分时图 */} + + + + + + 大盘分时走势 + + + + + + + {/* 异动列表(可收起) */} + + + + + + + 异动记录 + + + ({alerts.length}) + + + + + + + + + {/* 无异动提示 */} + {alerts.length === 0 && ( +
+ + 当日暂无概念异动数据 + +
+ )} +
+
+ ); +}; + +export default HotspotOverview; diff --git a/src/views/StockOverview/components/HotspotOverview/utils/chartHelpers.js b/src/views/StockOverview/components/HotspotOverview/utils/chartHelpers.js new file mode 100644 index 00000000..54ff9eba --- /dev/null +++ b/src/views/StockOverview/components/HotspotOverview/utils/chartHelpers.js @@ -0,0 +1,160 @@ +/** + * 图表辅助函数 + * 用于处理异动标注等图表相关逻辑 + */ + +/** + * 获取异动标注的配色和符号 + * @param {string} alertType - 异动类型 + * @param {number} importanceScore - 重要性得分 + * @returns {Object} { color, symbol, symbolSize } + */ +export const getAlertStyle = (alertType, importanceScore = 0.5) => { + let color = '#ff6b6b'; + let symbol = 'pin'; + let symbolSize = 35; + + switch (alertType) { + case 'surge_up': + case 'surge': + color = '#ff4757'; + symbol = 'triangle'; + symbolSize = 30 + Math.min(importanceScore * 20, 15); + break; + case 'surge_down': + color = '#2ed573'; + symbol = 'path://M0,0 L10,0 L5,10 Z'; // 向下三角形 + symbolSize = 30 + Math.min(importanceScore * 20, 15); + break; + case 'limit_up': + color = '#ff6348'; + symbol = 'diamond'; + symbolSize = 28; + break; + case 'rank_jump': + color = '#3742fa'; + symbol = 'circle'; + symbolSize = 25; + break; + case 'volume_spike': + color = '#ffa502'; + symbol = 'rect'; + symbolSize = 25; + break; + default: + break; + } + + return { color, symbol, symbolSize }; +}; + +/** + * 获取异动类型的显示标签 + * @param {string} alertType - 异动类型 + * @returns {string} 显示标签 + */ +export const getAlertTypeLabel = (alertType) => { + const labels = { + surge: '急涨', + surge_up: '暴涨', + surge_down: '暴跌', + limit_up: '涨停增加', + rank_jump: '排名跃升', + volume_spike: '放量', + unknown: '异动', + }; + return labels[alertType] || alertType; +}; + +/** + * 生成图表标注点数据 + * @param {Array} alerts - 异动数据数组 + * @param {Array} times - 时间数组 + * @param {Array} prices - 价格数组 + * @param {number} priceMax - 最高价格(用于无法匹配时间时的默认位置) + * @param {number} maxCount - 最大显示数量 + * @returns {Array} ECharts markPoint data + */ +export const getAlertMarkPoints = (alerts, times, prices, priceMax, maxCount = 15) => { + if (!alerts || alerts.length === 0) return []; + + // 按重要性排序,限制显示数量 + const sortedAlerts = [...alerts] + .sort((a, b) => (b.final_score || b.importance_score || 0) - (a.final_score || a.importance_score || 0)) + .slice(0, maxCount); + + return sortedAlerts.map((alert) => { + // 找到对应时间的价格 + const timeIndex = times.indexOf(alert.time); + const price = timeIndex >= 0 ? prices[timeIndex] : (alert.index_price || priceMax); + + const { color, symbol, symbolSize } = getAlertStyle( + alert.alert_type, + alert.final_score / 100 || alert.importance_score || 0.5 + ); + + // 格式化标签 + let label = alert.concept_name || ''; + if (label.length > 6) { + label = label.substring(0, 5) + '...'; + } + + // 添加涨停数量(如果有) + if (alert.limit_up_count > 0) { + label += `\n涨停: ${alert.limit_up_count}`; + } + + const isDown = alert.alert_type === 'surge_down'; + + return { + name: alert.concept_name, + coord: [alert.time, price], + value: label, + symbol, + symbolSize, + itemStyle: { + color, + borderColor: '#fff', + borderWidth: 1, + shadowBlur: 3, + shadowColor: 'rgba(0,0,0,0.2)', + }, + label: { + show: true, + position: isDown ? 'bottom' : 'top', + formatter: '{b}', + fontSize: 9, + color: '#333', + backgroundColor: isDown ? 
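+      /*
+       * 单个 markPoint 产出示意(名称与数值仅为举例):
+       *   { name: '光伏设备', coord: ['10:30', 3125.4], value: '光伏设备\n涨停: 3',
+       *     symbol: 'triangle', symbolSize: 40, alertData: alert }
+       * alertData 挂载原始异动记录,供 IndexMinuteChart 的点击回调读取。
+       */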
'rgba(46, 213, 115, 0.9)' : 'rgba(255,255,255,0.9)', + padding: [2, 4], + borderRadius: 2, + borderColor: color, + borderWidth: 1, + }, + alertData: alert, // 存储原始数据 + }; + }); +}; + +/** + * 格式化分数显示 + * @param {number} score - 分数 + * @returns {string} 格式化后的分数 + */ +export const formatScore = (score) => { + if (score === null || score === undefined) return '-'; + return Math.round(score).toString(); +}; + +/** + * 获取分数对应的颜色 + * @param {number} score - 分数 (0-100) + * @returns {string} 颜色代码 + */ +export const getScoreColor = (score) => { + const s = score || 0; + if (s >= 80) return '#ff4757'; + if (s >= 60) return '#ff6348'; + if (s >= 40) return '#ffa502'; + return '#747d8c'; +}; diff --git a/src/views/StockOverview/components/HotspotOverview/utils/index.js b/src/views/StockOverview/components/HotspotOverview/utils/index.js new file mode 100644 index 00000000..133f3c2d --- /dev/null +++ b/src/views/StockOverview/components/HotspotOverview/utils/index.js @@ -0,0 +1 @@ +export * from './chartHelpers'; diff --git a/src/views/StockOverview/index.js b/src/views/StockOverview/index.js index 6a6417e7..dacae01b 100644 --- a/src/views/StockOverview/index.js +++ b/src/views/StockOverview/index.js @@ -50,9 +50,11 @@ import { SkeletonText, } from '@chakra-ui/react'; import { SearchIcon, CloseIcon, ArrowForwardIcon, TrendingUpIcon, InfoIcon, ChevronRightIcon, CalendarIcon } from '@chakra-ui/icons'; -import { FaChartLine, FaFire, FaRocket, FaBrain, FaCalendarAlt, FaChevronRight, FaArrowUp, FaArrowDown, FaChartBar } from 'react-icons/fa'; +import { FaChartLine, FaFire, FaRocket, FaBrain, FaCalendarAlt, FaChevronRight, FaArrowUp, FaArrowDown, FaChartBar, FaTag, FaLayerGroup, FaBolt } from 'react-icons/fa'; import ConceptStocksModal from '@components/ConceptStocksModal'; import TradeDatePicker from '@components/TradeDatePicker'; +import HotspotOverview from './components/HotspotOverview'; +import FlexScreen from './components/FlexScreen'; import { BsGraphUp, BsLightningFill } from 'react-icons/bs'; import * as echarts from 'echarts'; import { logger } from '../../utils/logger'; @@ -840,6 +842,16 @@ const StockOverview = () => { )}
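For context on how the chartHelpers exports above are consumed: getAlertMarkPoints returns objects that plug directly into an ECharts line series as markPoint.data. A minimal sketch follows; renderMinuteChart, times, prices, and alerts are illustrative names only, not identifiers from this diff.

import * as echarts from 'echarts';
import { getAlertMarkPoints } from './utils';

// Illustrative only: draws the minute line and overlays the alert annotations.
function renderMinuteChart(dom, times, prices, alerts) {
  const chart = echarts.init(dom);
  chart.setOption({
    xAxis: { type: 'category', data: times, boundaryGap: false },
    yAxis: { type: 'value', scale: true },
    series: [{
      type: 'line',
      data: prices,
      showSymbol: false,
      // Top alerts by score (capped at 15), positioned at their timestamps.
      markPoint: { data: getAlertMarkPoints(alerts, times, prices, Math.max(...prices)) },
    }],
  });
  return chart;
}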
+          {/* 热点概览 - 大盘走势 + 概念异动 */}
+          <HotspotOverview />
+
+          {/* 灵活屏 - 实时行情监控 */}
+          <FlexScreen />
+
           {/* 今日热门概念 */}

@@ -927,16 +939,65 @@ const StockOverview = () => {

+                {/* 概念名称 */}
                 {concept.concept_name}

+                {/* 层级信息 */}
+                {concept.hierarchy && (
+                  {[concept.hierarchy.lv1, concept.hierarchy.lv2, concept.hierarchy.lv3]
+                    .filter(Boolean)
+                    .join(' > ')}
+                )}
+
+                {/* 描述 */}
                 {concept.description || '暂无描述'}

+                {/* 标签 */}
+                {concept.tags && concept.tags.length > 0 && (
+                  {concept.tags.slice(0, 4).map((tag, idx) => (
+                    {tag}
+                  ))}
+                  {concept.tags.length > 4 && (
+                    +{concept.tags.length - 4}
+                  )}
+                )}
+
+                {/* 爆发日期 */}
+                {concept.outbreak_dates && concept.outbreak_dates.length > 0 && (
+                  近期爆发: {concept.outbreak_dates.slice(0, 2).join(', ')}
+                  {concept.outbreak_dates.length > 2 && ` 等${concept.outbreak_dates.length}次`}
+                )}
+
+                {/* 相关股票 */}
                 {
                   overflow="hidden"
                   maxH="24px"
                 >
-                  {concept.stocks.map((stock, idx) => (
+                  {concept.stocks.slice(0, 5).map((stock, idx) => (
                     {
                       variant="subtle"
                       flexShrink={0}
                     >
-                      {stock.stock_name}
+                      {stock.stock_name || stock.name}

                   ))}
+                  {concept.stocks.length > 5 && (
+                    +{concept.stocks.length - 5}
+                  )}

                 )}

diff --git a/sse_html.html b/sse_html.html
new file mode 100644
index 00000000..8723bda7
--- /dev/null
+++ b/sse_html.html
@@ -0,0 +1,378 @@
+<!DOCTYPE html>
+<html lang="zh-CN">
+<head>
+    <meta charset="UTF-8">
+    <title>VDE 实时行情 - WebSocket 测试</title>
+</head>
+<body>
+    <h1>VDE 实时行情</h1>
+
+    <div class="status-bar">
+        <span class="status">未连接</span>
+        <span>更新次数: 0</span>
+        <span>最后更新: -</span>
+    </div>
+
+    <h2>📊 指数行情</h2>
+    <table>
+        <thead>
+            <tr><th>代码/名称</th><th>最新</th><th>涨跌</th><th>涨跌幅</th><th>成交额(亿)</th></tr>
+        </thead>
+        <tbody><!-- rows filled in by the WebSocket feed --></tbody>
+    </table>
+
+    <h2>📈 股票行情 (前20)</h2>
+    <table>
+        <thead>
+            <tr><th>代码/名称</th><th>最新</th><th>涨跌幅</th><th>买一</th><th>卖一</th><th>成交额(万)</th></tr>
+        </thead>
+        <tbody><!-- rows filled in by the WebSocket feed --></tbody>
+    </table>
+
+    <script>/* WebSocket connect / subscribe / table-refresh logic */</script>
+</body>
+</html>
diff --git a/szse_html.html b/szse_html.html
new file mode 100644
index 00000000..baf8e8f0
--- /dev/null
+++ b/szse_html.html
@@ -0,0 +1,289 @@
+<!DOCTYPE html>
+<html lang="zh-CN">
+<head>
+    <meta charset="UTF-8">
+    <title>深交所行情 WebSocket 测试</title>
+</head>
+<body>
+    <h1>深交所行情 WebSocket 测试</h1>
+    <div class="status">未连接</div>
+
+    <div class="stats">
+        <div><span class="num">0</span> 消息总数</div>
+        <div><span class="num">0</span> 股票 (300111)</div>
+        <div><span class="num">0</span> 指数 (309011)</div>
+        <div><span class="num">0</span> 债券 (300211)</div>
+        <div><span class="num">0</span> 港股 (306311)</div>
+        <div><span class="num">0</span> 其他类型</div>
+    </div>
+
+    <h2>实时行情 <small>--</small></h2>
+    <table>
+        <thead>
+            <tr><th>代码</th><th>类型</th><th>最新价</th><th>涨跌幅</th><th>成交量</th></tr>
+        </thead>
+        <tbody><!-- rows filled in by the WebSocket feed --></tbody>
+    </table>
+
+    <h2>消息日志 <small>0 条</small></h2>
+    <div class="log"></div>
+
+    <script>/* WebSocket connect + per-category message counters + log output */</script>
+</body>
+</html>
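Both test pages follow the same client pattern, so here is a minimal sketch of the logic szse_html.html needs: open the proxied WebSocket, bucket each incoming message by its category code, and bump the counters shown above. This is a sketch only; the /ws/szse path matches the nginx location added below, but the message field names (type, code, price) are an assumed schema, not the documented feed format.

// Minimal client sketch. Assumed message shape: { type, code, price, changePct }.
const counts = { '300111': 0, '309011': 0, '300211': 0, '306311': 0, other: 0 };
const statusEl = document.querySelector('.status'); // the status element above

const ws = new WebSocket(`wss://${location.host}/ws/szse`);
ws.onopen = () => { statusEl.textContent = '已连接'; };
ws.onclose = () => { statusEl.textContent = '未连接'; };
ws.onmessage = (event) => {
  const msg = JSON.parse(event.data);               // assumes the server pushes JSON
  const key = msg.type in counts ? msg.type : 'other';
  counts[key] += 1;                                 // per-category counters (消息总数 etc.)
  // ...append msg to the 消息日志 panel and update the quote table here...
};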
+ + + + diff --git a/valuefrontier.conf b/valuefrontier.conf index 9f4e2992..76fac284 100644 --- a/valuefrontier.conf +++ b/valuefrontier.conf @@ -112,6 +112,42 @@ server { proxy_buffering off; } + # ============================================ + # 实时行情 WebSocket 代理(灵活屏功能) + # ============================================ + + # 上交所实时行情 WebSocket + location /ws/sse { + proxy_pass http://49.232.185.254:8765; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection $connection_upgrade; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_connect_timeout 7d; + proxy_send_timeout 7d; + proxy_read_timeout 7d; + proxy_buffering off; + } + + # 深交所实时行情 WebSocket + location /ws/szse { + proxy_pass http://222.128.1.157:8765; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection $connection_upgrade; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_connect_timeout 7d; + proxy_send_timeout 7d; + proxy_read_timeout 7d; + proxy_buffering off; + } + location /mcp/ { proxy_pass http://127.0.0.1:8900/; proxy_http_version 1.1; @@ -142,7 +178,6 @@ server { } } - # 概念板块API代理 location /concept-api/ { proxy_pass http://222.128.1.157:16801/; @@ -158,6 +193,7 @@ server { proxy_send_timeout 60s; proxy_read_timeout 60s; } + # Elasticsearch API代理(价值论坛) location /es-api/ { proxy_pass http://222.128.1.157:19200/; @@ -223,36 +259,7 @@ server { proxy_send_timeout 86400s; proxy_read_timeout 86400s; } - # AI Chat 应用 (Next.js) - MCP 集成 - # AI Chat 静态资源(图片、CSS、JS) - location ~ ^/ai-chat/(images|_next/static|_next/image|favicon.ico) { - proxy_pass http://127.0.0.1:3000; - proxy_http_version 1.1; - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - # 缓存设置 - expires 30d; - add_header Cache-Control "public, immutable"; - } - - # AI Chat 主应用 - location /ai-chat { - proxy_pass http://127.0.0.1:3000; - proxy_http_version 1.1; - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - proxy_set_header Cookie $http_cookie; - proxy_pass_request_headers on; - proxy_set_header Upgrade $http_upgrade; - proxy_set_header Connection "upgrade"; - proxy_buffering off; - proxy_cache off; - } # iframe 内部资源代理(Bytedesk 聊天窗口的 CSS/JS) location /chat/ { proxy_pass http://43.143.189.195/chat/; @@ -326,6 +333,22 @@ server { add_header Cache-Control "public, max-age=86400"; } + # Bytedesk 文件访问代理(仅 2025 年文件) + location ^~ /file/2025/ { + proxy_pass http://43.143.189.195/file/2025/; + proxy_http_version 1.1; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + + # 缓存配置 + proxy_cache_valid 200 1d; + expires 1d; + add_header Cache-Control "public, max-age=86400"; + add_header Access-Control-Allow-Origin *; + } + # Visitor API 代理(Bytedesk 初始化接口) location /visitor/ {
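One thing to verify in the nginx changes: the two new WebSocket locations reference $connection_upgrade, which is not a built-in variable. If it is not already defined at the http{} level (the other proxies in this file hard-code Connection "upgrade" instead), nginx -t will fail with unknown "connection_upgrade" variable. The standard definition is:

# Define once at http{} level so $connection_upgrade resolves:
map $http_upgrade $connection_upgrade {
    default upgrade;
    ''      close;
}

With this map in place, the Connection header is set to "upgrade" only when the client actually requests a protocol upgrade, and to "close" otherwise, which is the usual pattern for proxying WebSocket and plain HTTP through the same location.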