#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Data preparation V2 - feature computation aligned on intraday time slots.

Core ideas:
1. Time-slot alignment: the 09:35 bar is compared with historical 09:35 bars,
   not with the preceding 30 minutes.
2. Z-score features: deviation from the historical distribution of the same
   time slot.
3. Rolling-window baseline: every date uses the N trading days *before* it as
   its baseline (not a fixed "last N days" window).
4. Z-score based momentum: removes intraday volatility heterogeneity.

Fixes carried by this version:
- Rolling-window baseline: avoids leaking future data.
- Z-score momentum: removes the open/close volatility difference.
- Process-level database singletons: avoids connection-pool explosion.
"""

import os
import pickle
import sys
import warnings
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from clickhouse_driver import Client
from elasticsearch import Elasticsearch
from sqlalchemy import create_engine, text
from tqdm import tqdm

warnings.filterwarnings('ignore')

# ==================== Configuration ====================

# NOTE(review): credentials are committed in source. Move them to environment
# variables or a secret store; values are left untouched here so behavior does
# not change.
MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock"
ES_HOST = 'http://127.0.0.1:9200'
ES_INDEX = 'concept_library_v3'

CLICKHOUSE_CONFIG = {
    'host': '127.0.0.1',
    'port': 9000,
    'user': 'default',
    'password': 'Zzl33818!',
    'database': 'stock'
}

REFERENCE_INDEX = '000001.SH'

# Output directories (created at import time)
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data_v2')
BASELINE_DIR = os.path.join(OUTPUT_DIR, 'baselines')
RAW_CACHE_DIR = os.path.join(OUTPUT_DIR, 'raw_cache')
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(BASELINE_DIR, exist_ok=True)
os.makedirs(RAW_CACHE_DIR, exist_ok=True)

# Feature configuration
CONFIG = {
    'baseline_days': 20,          # rolling window size (trading days)
    'min_baseline_samples': 10,   # a slot needs at least this many samples
    'limit_up_threshold': 9.8,
    'limit_down_threshold': -9.8,
    'zscore_clip': 5.0,
}

# Feature list
FEATURES_V2 = [
    'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore',
    'momentum_3m', 'momentum_5m', 'limit_up_ratio',
]

# ==================== Process-level singletons ====================
# One set of connections per process, to avoid connection-pool explosion.

_process_mysql_engine = None
_process_es_client = None
_process_ch_client = None


def init_process_connections(config_overrides: Optional[dict] = None):
    """Create this process's MySQL / ES / ClickHouse connections.

    Intended to be called once per process (it is used as the
    ``ProcessPoolExecutor`` initializer).

    Args:
        config_overrides: optional values merged into the module-level
            ``CONFIG``. Needed because worker processes started with the
            ``spawn`` method re-import this module and would otherwise lose
            changes the parent made to ``CONFIG`` (e.g. ``--baseline-days``).
    """
    global _process_mysql_engine, _process_es_client, _process_ch_client
    if config_overrides:
        CONFIG.update(config_overrides)
    _process_mysql_engine = create_engine(
        MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5
    )
    _process_es_client = Elasticsearch([ES_HOST])
    _process_ch_client = Client(**CLICKHOUSE_CONFIG)


def get_mysql_engine():
    """Return the process-level MySQL engine, creating it on first use."""
    global _process_mysql_engine
    if _process_mysql_engine is None:
        _process_mysql_engine = create_engine(
            MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5
        )
    return _process_mysql_engine


def get_es_client():
    """Return the process-level Elasticsearch client, creating it on first use."""
    global _process_es_client
    if _process_es_client is None:
        _process_es_client = Elasticsearch([ES_HOST])
    return _process_es_client


def get_ch_client():
    """Return the process-level ClickHouse client, creating it on first use."""
    global _process_ch_client
    if _process_ch_client is None:
        _process_ch_client = Client(**CLICKHOUSE_CONFIG)
    return _process_ch_client


# ==================== Utilities ====================

def code_to_ch_format(code: str) -> Optional[str]:
    """Convert a 6-digit stock code to the suffixed form used in ClickHouse.

    Returns ``None`` when *code* is empty or not exactly six digits.
    Codes starting with ``6`` map to ``.SH``, ``0``/``3`` to ``.SZ`` and
    everything else to ``.BJ``.
    """
    if not code or len(code) != 6 or not code.isdigit():
        return None
    if code.startswith('6'):
        return f"{code}.SH"
    if code.startswith(('0', '3')):
        return f"{code}.SZ"
    return f"{code}.BJ"


def time_to_slot(ts) -> str:
    """Convert a timestamp to its time slot (``HH:MM``); strings pass through."""
    if isinstance(ts, str):
        return ts
    return ts.strftime('%H:%M')


def _atomic_to_parquet(df: pd.DataFrame, path: str) -> None:
    """Write *df* to *path* through a temp file and an atomic rename.

    Several workers may compute the same cache entry, and a run may be killed
    mid-write; writing in place could leave a truncated file that later runs
    would treat as a valid cache/output because they only test for existence.
    """
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        df.to_parquet(tmp_path, index=False)
        os.replace(tmp_path, path)
    finally:
        # Only still present when the write or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# ==================== Concept list ====================

def get_all_concepts() -> List[dict]:
    """Fetch every concept that has at least one stock code from ES.

    Returns:
        A list of ``{'concept_id', 'concept_name', 'stocks'}`` dicts, where
        ``stocks`` is the list of non-empty ``code`` values of the document.
    """
    es_client = get_es_client()
    query = {
        "query": {"match_all": {}},
        "size": 100,
        "_source": ["concept_id", "concept", "stocks"]
    }

    concepts = []
    resp = es_client.search(index=ES_INDEX, body=query, scroll='2m')
    scroll_id = resp['_scroll_id']
    try:
        hits = resp['hits']['hits']
        while hits:
            for hit in hits:
                source = hit['_source']
                raw_stocks = source.get('stocks')
                if not isinstance(raw_stocks, list):
                    continue
                stocks = [
                    stock['code']
                    for stock in raw_stocks
                    if isinstance(stock, dict) and stock.get('code')
                ]
                if stocks:
                    concepts.append({
                        'concept_id': source.get('concept_id'),
                        'concept_name': source.get('concept'),
                        'stocks': stocks
                    })

            resp = es_client.scroll(scroll_id=scroll_id, scroll='2m')
            scroll_id = resp['_scroll_id']
            hits = resp['hits']['hits']
    finally:
        # Release the server-side scroll context even when paging fails.
        es_client.clear_scroll(scroll_id=scroll_id)

    print(f"获取到 {len(concepts)} 个概念")
    return concepts


# ==================== Trading days ====================

def get_trading_days(start_date: str, end_date: str) -> List[str]:
    """Return the dates (``YYYY-MM-DD``) that have rows in ``stock_minute``.

    Both bounds are inclusive; the result is sorted ascending.
    """
    client = get_ch_client()
    query = """
        SELECT DISTINCT toDate(timestamp) as trade_date
        FROM stock_minute
        WHERE toDate(timestamp) >= %(start_date)s
          AND toDate(timestamp) <= %(end_date)s
        ORDER BY trade_date
    """
    result = client.execute(
        query, {'start_date': start_date, 'end_date': end_date}
    )
    days = [row[0].strftime('%Y-%m-%d') for row in result]
    if days:
        print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
    return days


# ==================== Previous close ====================

def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
    """Return ``{code: F007N}`` for the latest ``TRADEDATE`` before *trade_date*.

    Codes that are not exactly six digits are ignored. Any database error is
    printed and an empty dict is returned.
    """
    valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
    if not valid_codes:
        return {}

    # The codes were just validated as six digits, so inlining them cannot
    # inject SQL; the date is passed as a bound parameter.
    codes_str = "','".join(valid_codes)
    query = f"""
        SELECT SECCODE, F007N
        FROM ea_trade
        WHERE SECCODE IN ('{codes_str}')
          AND TRADEDATE = (
              SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < :trade_date
          )
          AND F007N IS NOT NULL AND F007N > 0
    """

    try:
        engine = get_mysql_engine()
        with engine.connect() as conn:
            result = conn.execute(text(query), {'trade_date': trade_date})
            return {row[0]: float(row[1]) for row in result if row[1]}
    except Exception as e:
        print(f"获取昨收价失败: {e}")
        return {}


def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> Optional[float]:
    """Return the index's ``F006N`` on the latest ``TRADEDATE`` before *trade_date*.

    Returns ``None`` when no row is found, the value is falsy, or the query
    fails (the error is printed).
    """
    code_no_suffix = index_code.split('.')[0]

    try:
        engine = get_mysql_engine()
        with engine.connect() as conn:
            result = conn.execute(text("""
                SELECT F006N FROM ea_exchangetrade
                WHERE INDEXCODE = :code AND TRADEDATE < :today
                ORDER BY TRADEDATE DESC LIMIT 1
            """), {'code': code_no_suffix, 'today': trade_date}).fetchone()

            if result and result[0]:
                return float(result[0])
    except Exception as e:
        print(f"获取指数昨收失败: {e}")

    return None


# ==================== Minute data ====================

def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
    """Fetch one day of minute bars for the given stocks from ``stock_minute``.

    Returns:
        A DataFrame with columns ``code`` (the original 6-digit code),
        ``timestamp``, ``close``, ``volume``, ``amt``; empty when no code is
        valid or the query returns nothing.
    """
    client = get_ch_client()

    # Map the suffixed ClickHouse code back to the caller's 6-digit code.
    code_map = {}
    for code in stock_codes:
        ch_code = code_to_ch_format(code)
        if ch_code:
            code_map[ch_code] = code

    if not code_map:
        return pd.DataFrame()

    query = """
        SELECT code, timestamp, close, volume, amt
        FROM stock_minute
        WHERE toDate(timestamp) = %(trade_date)s
          AND code IN %(codes)s
        ORDER BY code, timestamp
    """
    # A tuple parameter is rendered as a parenthesised, escaped value list.
    result = client.execute(
        query, {'trade_date': trade_date, 'codes': tuple(code_map)}
    )
    if not result:
        return pd.DataFrame()

    df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
    df['code'] = df['ch_code'].map(code_map)
    df = df.dropna(subset=['code'])

    return df[['code', 'timestamp', 'close', 'volume', 'amt']]


def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
    """Fetch one day of minute bars for *index_code* from ``index_minute``.

    Returns a DataFrame with columns ``timestamp``, ``close``, ``volume``,
    ``amt`` (empty when the query returns nothing).
    """
    client = get_ch_client()

    query = """
        SELECT timestamp, close, volume, amt
        FROM index_minute
        WHERE toDate(timestamp) = %(trade_date)s
          AND code = %(index_code)s
        ORDER BY timestamp
    """
    result = client.execute(
        query, {'trade_date': trade_date, 'index_code': index_code}
    )
    if not result:
        return pd.DataFrame()

    return pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt'])


# ==================== Raw concept features (single day) ====================

def compute_raw_concept_features(
    trade_date: str,
    concepts: List[dict],
    all_stocks: List[str]
) -> pd.DataFrame:
    """Compute one day's raw per-concept, per-minute features.

    For every timestamp and every concept with at least one quoted member the
    result holds ``alpha`` (mean member change minus index change, in percent),
    ``total_amt``, ``limit_up_ratio``, ``stock_count``, ``rank_pct`` (percentile
    rank of ``alpha`` among concepts at that timestamp), ``timestamp``,
    ``time_slot`` and ``trade_date``.

    The result is cached as ``raw_<trade_date>.parquet`` under
    ``RAW_CACHE_DIR``; an existing cache file is returned as-is. An empty
    DataFrame is returned (and nothing is cached) when any input is missing.
    """
    cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
    if os.path.exists(cache_file):
        return pd.read_parquet(cache_file)

    # Load inputs
    stock_df = get_daily_stock_data(trade_date, all_stocks)
    if stock_df.empty:
        return pd.DataFrame()

    index_df = get_daily_index_data(trade_date)
    if index_df.empty:
        return pd.DataFrame()

    prev_close = get_prev_close(all_stocks, trade_date)
    index_prev_close = get_index_prev_close(trade_date)
    if not prev_close or not index_prev_close:
        return pd.DataFrame()

    # Percentage change versus the previous close
    stock_df['prev_close'] = stock_df['code'].map(prev_close)
    stock_df = stock_df.dropna(subset=['prev_close']).copy()
    stock_df['change_pct'] = (
        (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
    )

    index_df['change_pct'] = (
        (index_df['close'] - index_prev_close) / index_prev_close * 100
    )
    index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))

    # Concept -> member codes. An order-preserving de-duplicated list (rather
    # than a set) keeps the float summation order stable across processes.
    concept_stocks = {
        c['concept_id']: list(dict.fromkeys(c['stocks'])) for c in concepts
    }
    limit_up_threshold = CONFIG['limit_up_threshold']

    results = []
    # groupby yields pandas Timestamps in ascending order on every pandas
    # version and avoids one full boolean-mask scan per timestamp.
    for ts, ts_stock_data in stock_df.groupby('timestamp', sort=True):
        # NOTE(review): a minute missing from the index data is treated as a
        # 0% index change, so alpha equals the raw change for that minute.
        index_change = index_change_map.get(ts, 0)

        stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
        stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))

        concept_features = []
        for concept_id, stocks in concept_stocks.items():
            members = [s for s in stocks if s in stock_change]
            if not members:
                continue

            concept_changes = [stock_change[s] for s in members]
            limit_up_count = sum(
                1 for chg in concept_changes if chg >= limit_up_threshold
            )

            concept_features.append({
                'concept_id': concept_id,
                'alpha': np.mean(concept_changes) - index_change,
                'total_amt': sum(stock_amt.get(s, 0) for s in members),
                'limit_up_ratio': limit_up_count / len(concept_changes),
                'stock_count': len(concept_changes),
            })

        if not concept_features:
            continue

        concept_df = pd.DataFrame(concept_features)
        concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)
        concept_df['timestamp'] = ts
        concept_df['time_slot'] = time_to_slot(ts)
        concept_df['trade_date'] = trade_date
        results.append(concept_df)

    if not results:
        return pd.DataFrame()

    result_df = pd.concat(results, ignore_index=True)

    # Cache for later runs and for the rolling baselines of following days
    _atomic_to_parquet(result_df, cache_file)

    return result_df


# ==================== Rolling-window baseline ====================

def compute_rolling_baseline(
    historical_data: pd.DataFrame,
    concept_id: str
) -> Dict[str, Dict]:
    """Compute the per-time-slot baseline of one concept.

    Args:
        historical_data: raw features of the baseline window (may contain
            other concepts; only rows of *concept_id* are used).
        concept_id: the concept to compute the baseline for.

    Returns:
        ``{time_slot: {alpha_mean, alpha_std, amt_mean, amt_std, rank_mean,
        rank_std, sample_count}}``. Slots with fewer than
        ``CONFIG['min_baseline_samples']`` rows are omitted. Standard
        deviations are floored (0.1 / 1.0 / 0.05) so later z-scores cannot
        divide by (near) zero.
    """
    if historical_data.empty:
        return {}

    concept_data = historical_data[historical_data['concept_id'] == concept_id]
    if concept_data.empty:
        return {}

    min_samples = CONFIG['min_baseline_samples']
    baseline_dict = {}

    for time_slot, group in concept_data.groupby('time_slot'):
        if len(group) < min_samples:
            continue

        alpha_std = group['alpha'].std()
        amt_mean = group['total_amt'].mean()
        amt_std = group['total_amt'].std()
        rank_std = group['rank_pct'].std()

        baseline_dict[time_slot] = {
            'alpha_mean': group['alpha'].mean(),
            'alpha_std': max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1),
            'amt_mean': amt_mean,
            'amt_std': max(amt_std if pd.notna(amt_std) else amt_mean * 0.5, 1.0),
            'rank_mean': group['rank_pct'].mean(),
            'rank_std': max(rank_std if pd.notna(rank_std) else 0.2, 0.05),
            'sample_count': len(group),
        }

    return baseline_dict


# ==================== Single-day z-score features (rolling baseline) ====================

def _zscore_momentum(history: List[float], window: int) -> float:
    """Return mean(last *window* z-scores) minus mean(the *window* before them).

    Yields 0.0 until *window* values exist; while fewer than ``2 * window``
    values exist, the first value of the day stands in for the older window.
    """
    if len(history) < window:
        return 0.0
    recent = history[-window:]
    if len(history) >= 2 * window:
        older = history[-2 * window:-window]
    else:
        older = [history[0]]
    return float(np.mean(recent) - np.mean(older))


def compute_zscore_features_rolling(
    trade_date: str,
    concepts: List[dict],
    all_stocks: List[str],
    historical_raw_data: pd.DataFrame  # raw features of the N days before trade_date
) -> pd.DataFrame:
    """Compute one day's z-score features against a rolling baseline.

    Key properties:
    1. The baseline is built only from *historical_raw_data*, which the caller
       restricts to days before *trade_date* (no future leakage).
    2. Momentum is computed on the alpha z-scores (removes intraday
       volatility heterogeneity).

    Rows whose concept or time slot has no valid baseline are skipped.
    Returns an empty DataFrame when nothing can be computed.
    """
    # Today's raw features
    raw_df = compute_raw_concept_features(trade_date, concepts, all_stocks)
    if raw_df.empty or historical_raw_data.empty:
        return pd.DataFrame()

    # Split the history once instead of scanning the whole frame per concept.
    history_by_concept = dict(tuple(historical_raw_data.groupby('concept_id')))
    clip = CONFIG['zscore_clip']

    zscore_records = []

    for concept_id, group in raw_df.groupby('concept_id'):
        concept_history = history_by_concept.get(concept_id)
        if concept_history is None:
            continue

        # Rolling baseline of this concept (history only)
        baseline_dict = compute_rolling_baseline(concept_history, concept_id)
        if not baseline_dict:
            continue

        group = group.sort_values('timestamp').reset_index(drop=True)

        # Alpha z-scores seen so far today, used for the momentum features
        zscore_history = []

        for row in group.itertuples(index=False):
            bl = baseline_dict.get(row.time_slot)
            if bl is None:
                continue

            # Z-scores, clipped to limit the influence of extreme values
            alpha_zscore = np.clip(
                (row.alpha - bl['alpha_mean']) / bl['alpha_std'], -clip, clip
            )
            amt_zscore = np.clip(
                (row.total_amt - bl['amt_mean']) / bl['amt_std'], -clip, clip
            )
            rank_zscore = np.clip(
                (row.rank_pct - bl['rank_mean']) / bl['rank_std'], -clip, clip
            )

            zscore_history.append(alpha_zscore)

            zscore_records.append({
                'concept_id': concept_id,
                'timestamp': row.timestamp,
                'time_slot': row.time_slot,
                'trade_date': trade_date,
                # Raw features
                'alpha': row.alpha,
                'total_amt': row.total_amt,
                'limit_up_ratio': row.limit_up_ratio,
                'stock_count': row.stock_count,
                'rank_pct': row.rank_pct,
                # Z-score features
                'alpha_zscore': alpha_zscore,
                'amt_zscore': amt_zscore,
                'rank_zscore': rank_zscore,
                # Z-score based momentum
                'momentum_3m': _zscore_momentum(zscore_history, 3),
                'momentum_5m': _zscore_momentum(zscore_history, 5),
            })

    if not zscore_records:
        return pd.DataFrame()

    return pd.DataFrame(zscore_records)


# ==================== Multi-process worker ====================

def process_single_day_v2(args) -> Tuple[str, bool]:
    """Compute and store the features of one trading day (worker entry point).

    Args:
        args: ``(trade_date, day_index, concepts, all_stocks,
            all_trading_days)`` where ``day_index`` is the position of
            ``trade_date`` in ``all_trading_days``.

    Returns:
        ``(trade_date, success)``. A day whose output file already exists
        counts as success; any exception is printed and reported as failure.
    """
    trade_date, day_index, concepts, all_stocks, all_trading_days = args
    output_file = os.path.join(OUTPUT_DIR, f'features_v2_{trade_date}.parquet')

    if os.path.exists(output_file):
        return (trade_date, True)

    try:
        # Rolling window: the baseline_days trading days before trade_date
        baseline_days = CONFIG['baseline_days']
        start_idx = max(0, day_index - baseline_days)
        end_idx = day_index  # excludes trade_date itself

        if end_idx <= start_idx:
            # No history available
            return (trade_date, False)

        # Load (or compute and cache) the raw features of each history day
        historical_dfs = []
        for hist_date in all_trading_days[start_idx:end_idx]:
            hist_df = compute_raw_concept_features(hist_date, concepts, all_stocks)
            if not hist_df.empty:
                historical_dfs.append(hist_df)

        if not historical_dfs:
            return (trade_date, False)

        historical_raw_data = pd.concat(historical_dfs, ignore_index=True)

        # Today's z-score features against the rolling baseline
        df = compute_zscore_features_rolling(
            trade_date, concepts, all_stocks, historical_raw_data
        )
        if df.empty:
            return (trade_date, False)

        _atomic_to_parquet(df, output_file)
        return (trade_date, True)

    except Exception as e:
        print(f"[{trade_date}] 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return (trade_date, False)


# ==================== Main ====================

def main():
    """Command-line entry point: pre-compute raw features, then z-score features."""
    import argparse

    parser = argparse.ArgumentParser(description='准备训练数据 V2(滚动窗口基线 + Z-Score 动量)')
    parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
    parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
    parser.add_argument('--workers', type=int, default=18, help='并行进程数')
    parser.add_argument('--baseline-days', type=int, default=20, help='滚动基线窗口大小')
    parser.add_argument('--force', action='store_true', help='强制重新计算(忽略缓存)')

    args = parser.parse_args()
    if args.baseline_days < 1:
        parser.error('--baseline-days 必须 >= 1')

    end_date = args.end or datetime.now().strftime('%Y-%m-%d')
    CONFIG['baseline_days'] = args.baseline_days

    print("=" * 60)
    print("数据准备 V2 - 滚动窗口基线 + Z-Score 动量")
    print("=" * 60)
    print(f"日期范围: {args.start} ~ {end_date}")
    print(f"并行进程数: {args.workers}")
    print(f"滚动基线窗口: {args.baseline_days} 天")
    if args.baseline_days < CONFIG['min_baseline_samples']:
        # With fewer window days than required samples no slot can ever reach
        # the minimum, so every day would fail without an obvious reason.
        print(
            f"警告: 基线窗口 {args.baseline_days} 天小于最少样本数 "
            f"{CONFIG['min_baseline_samples']},将无法生成有效基线"
        )

    # Connections of the main process
    init_process_connections()

    # 1. Concept list
    concepts = get_all_concepts()
    all_stocks = list({s for c in concepts for s in c['stocks']})
    print(f"股票总数: {len(all_stocks)}")

    # 2. Trading days
    trading_days = get_trading_days(args.start, end_date)
    if not trading_days:
        print("无交易日数据")
        return

    # 3. Stage one: pre-compute all raw features (fills the cache)
    print(f"\n{'='*60}")
    print("第一阶段:预计算原始特征(用于滚动基线)")
    print(f"{'='*60}")

    # --force: drop the raw cache and every previously written feature file
    if args.force:
        import shutil
        if os.path.exists(RAW_CACHE_DIR):
            shutil.rmtree(RAW_CACHE_DIR)
            os.makedirs(RAW_CACHE_DIR, exist_ok=True)
        if os.path.exists(OUTPUT_DIR):
            for f in os.listdir(OUTPUT_DIR):
                if f.startswith('features_v2_'):
                    os.remove(os.path.join(OUTPUT_DIR, f))

    # Single process on purpose: the cache is filled sequentially
    print(f"预计算 {len(trading_days)} 天的原始特征...")
    for trade_date in tqdm(trading_days, desc="预计算原始特征"):
        cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
        if not os.path.exists(cache_file):
            compute_raw_concept_features(trade_date, concepts, all_stocks)

    # 4. Stage two: z-score features (multi-process)
    print(f"\n{'='*60}")
    print("第二阶段:计算 Z-Score 特征(滚动基线)")
    print(f"{'='*60}")

    # The first baseline_days days are warm-up only (not enough history)
    start_idx = args.baseline_days
    processable_days = trading_days[start_idx:]

    if not processable_days:
        print(f"错误:需要至少 {args.baseline_days + 1} 天的数据")
        return

    print(f"可处理日期: {processable_days[0]} ~ {processable_days[-1]} ({len(processable_days)} 天)")
    print(f"跳过前 {start_idx} 天(基线预热期)")

    # One task per processable day; the index refers to the full day list
    tasks = [
        (trade_date, i, concepts, all_stocks, trading_days)
        for i, trade_date in enumerate(processable_days, start=start_idx)
    ]

    print(f"开始处理 {len(tasks)} 个交易日({args.workers} 进程并行)...")

    success_count = 0
    failed_dates = []

    # The initializer opens per-process connections and carries the effective
    # CONFIG, so workers honor --baseline-days under any start method.
    with ProcessPoolExecutor(
        max_workers=args.workers,
        initializer=init_process_connections,
        initargs=(dict(CONFIG),),
    ) as executor:
        futures = {
            executor.submit(process_single_day_v2, task): task[0]
            for task in tasks
        }

        with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
            for future in as_completed(futures):
                trade_date = futures[future]
                try:
                    result_date, success = future.result()
                    if success:
                        success_count += 1
                    else:
                        failed_dates.append(result_date)
                except Exception as e:
                    print(f"\n[{trade_date}] 进程异常: {e}")
                    failed_dates.append(trade_date)
                pbar.update(1)

    print("\n" + "=" * 60)
    print(f"处理完成: {success_count}/{len(tasks)} 个交易日")
    if failed_dates:
        print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
    print(f"数据保存在: {OUTPUT_DIR}")
    print("=" * 60)


if __name__ == "__main__":
    main()