Files
vf_react/ml/prepare_data_v2.py
2025-12-10 11:02:09 +08:00

716 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
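Data preparation V2 - feature computation aligned by time slot (fixed version).

Core improvements:
1. Time-slot alignment: 9:35 is compared with the historical 9:35, not with the preceding 30 minutes.
2. Z-Score features: deviation from the historical distribution of the same time slot.
3. Rolling-window baseline: each date uses the N days before it as its baseline (not a fixed set of the last N days!).
4. Z-Score-based momentum: removes intraday volatility heterogeneity.

Fixes:
- Rolling-window baseline: avoids leaking future data.
- Z-Score momentum: removes the volatility difference between the opening and closing sessions.
- Process-level database singletons: avoids connection-pool explosion.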
"""
import os
import sys
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from sqlalchemy import create_engine, text
from elasticsearch import Elasticsearch
from clickhouse_driver import Client
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
from collections import defaultdict
import warnings
import pickle
warnings.filterwarnings('ignore')
# ==================== 配置 ====================
# NOTE(security): connection credentials used to be hard-coded here only.
# Every connection setting can now be overridden with a ``VF_*`` environment
# variable; the literals remain solely as backward-compatible defaults and
# should be rotated and removed from version control.
MYSQL_URL = os.environ.get(
    'VF_MYSQL_URL',
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
)
ES_HOST = os.environ.get('VF_ES_HOST', 'http://127.0.0.1:9200')
ES_INDEX = os.environ.get('VF_ES_INDEX', 'concept_library_v3')
CLICKHOUSE_CONFIG = {
    'host': os.environ.get('VF_CLICKHOUSE_HOST', '127.0.0.1'),
    # Environment values are strings; the driver config held an int port.
    'port': int(os.environ.get('VF_CLICKHOUSE_PORT', '9000')),
    'user': os.environ.get('VF_CLICKHOUSE_USER', 'default'),
    'password': os.environ.get('VF_CLICKHOUSE_PASSWORD', 'Zzl33818!'),
    'database': os.environ.get('VF_CLICKHOUSE_DATABASE', 'stock'),
}

# Index whose intraday change is subtracted from concept changes (alpha).
REFERENCE_INDEX = '000001.SH'

# Output directories, created next to this file at import time.
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data_v2')
BASELINE_DIR = os.path.join(OUTPUT_DIR, 'baselines')
RAW_CACHE_DIR = os.path.join(OUTPUT_DIR, 'raw_cache')
for _directory in (OUTPUT_DIR, BASELINE_DIR, RAW_CACHE_DIR):
    os.makedirs(_directory, exist_ok=True)

# Feature configuration.
CONFIG = {
    'baseline_days': 20,         # rolling-window size, in trading days
    'min_baseline_samples': 10,  # minimum rows per time slot for a valid baseline
    'limit_up_threshold': 9.8,
    'limit_down_threshold': -9.8,
    'zscore_clip': 5.0,          # z-scores are clipped to [-clip, +clip]
}

# Feature list.
FEATURES_V2 = [
    'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore',
    'momentum_3m', 'momentum_5m', 'limit_up_ratio',
]
# ==================== 进程级单例(避免连接池爆炸)====================
# 进程级全局变量
# Per-process connection handles. They start empty and are filled either by
# init_process_connections() or lazily by the get_* accessors.
_process_mysql_engine = _process_es_client = _process_ch_client = None


def init_process_connections():
    """Create fresh connections for the current process.

    Unconditionally replaces the three module-level handles (MySQL engine,
    Elasticsearch client, ClickHouse client) with newly constructed objects.
    """
    global _process_mysql_engine, _process_es_client, _process_ch_client
    engine_options = {'echo': False, 'pool_pre_ping': True, 'pool_size': 5}
    _process_mysql_engine = create_engine(MYSQL_URL, **engine_options)
    _process_es_client = Elasticsearch([ES_HOST])
    _process_ch_client = Client(**CLICKHOUSE_CONFIG)
def get_mysql_engine():
    """Return the process-level MySQL engine, creating it on first use."""
    global _process_mysql_engine
    if _process_mysql_engine is not None:
        return _process_mysql_engine
    _process_mysql_engine = create_engine(
        MYSQL_URL,
        echo=False,
        pool_pre_ping=True,
        pool_size=5,
    )
    return _process_mysql_engine
def get_es_client():
    """Return the process-level Elasticsearch client, creating it on first use."""
    global _process_es_client
    if _process_es_client is not None:
        return _process_es_client
    _process_es_client = Elasticsearch([ES_HOST])
    return _process_es_client
def get_ch_client():
    """Return the process-level ClickHouse client, creating it on first use."""
    global _process_ch_client
    if _process_ch_client is not None:
        return _process_ch_client
    _process_ch_client = Client(**CLICKHOUSE_CONFIG)
    return _process_ch_client
# ==================== 工具函数 ====================
def code_to_ch_format(code: str) -> Optional[str]:
    """Append the market suffix used by the ClickHouse tables to a stock code.

    Args:
        code: Bare six-digit stock code, e.g. ``'600000'``.

    Returns:
        ``'<code>.SH'`` when the code starts with ``6``, ``'<code>.SZ'`` when it
        starts with ``0`` or ``3``, ``'<code>.BJ'`` for any other leading digit,
        or ``None`` when ``code`` is not a string of exactly six ASCII digits.
    """
    # Non-string input (None, int, ...) is rejected instead of raising on len().
    if not isinstance(code, str) or len(code) != 6:
        return None
    # str.isdigit() also accepts non-ASCII digits (e.g. full-width forms),
    # which would produce a code that matches nothing in the database.
    if not all('0' <= ch <= '9' for ch in code):
        return None
    if code.startswith('6'):
        return f"{code}.SH"
    if code.startswith(('0', '3')):
        return f"{code}.SZ"
    return f"{code}.BJ"
def time_to_slot(ts) -> str:
    """Return the ``HH:MM`` time-slot label for a timestamp.

    A string argument is returned unchanged; any other value is formatted
    through its ``strftime`` method.
    """
    return ts if isinstance(ts, str) else ts.strftime('%H:%M')
# ==================== 获取概念列表 ====================
def get_all_concepts() -> List[dict]:
    """Load every concept that has at least one stock code from Elasticsearch.

    Pages through ``ES_INDEX`` with the scroll API, 100 documents per page,
    using a ``match_all`` query.
    NOTE(review): the original docstring said "leaf concepts"; nothing in the
    query filters for that, so presumably the index only holds leaf concepts.

    Returns:
        A list of dicts with keys ``concept_id``, ``concept_name`` and
        ``stocks`` (the non-empty ``code`` values of the document's ``stocks``
        entries). Documents without any usable stock code are skipped.
    """
    es_client = get_es_client()
    query = {
        "query": {"match_all": {}},
        "size": 100,
        "_source": ["concept_id", "concept", "stocks"],
    }
    concepts = []
    resp = es_client.search(index=ES_INDEX, body=query, scroll='2m')
    scroll_id = resp['_scroll_id']
    try:
        hits = resp['hits']['hits']
        while hits:
            for hit in hits:
                source = hit['_source']
                raw_stocks = source.get('stocks')
                if not isinstance(raw_stocks, list):
                    continue
                stocks = [
                    stock['code']
                    for stock in raw_stocks
                    if isinstance(stock, dict) and stock.get('code')
                ]
                if stocks:
                    concepts.append({
                        'concept_id': source.get('concept_id'),
                        'concept_name': source.get('concept'),
                        'stocks': stocks,
                    })
            resp = es_client.scroll(scroll_id=scroll_id, scroll='2m')
            scroll_id = resp['_scroll_id']
            hits = resp['hits']['hits']
    finally:
        # Release the server-side scroll context even when paging fails.
        es_client.clear_scroll(scroll_id=scroll_id)
    print(f"获取到 {len(concepts)} 个概念")
    return concepts
# ==================== 获取交易日列表 ====================
def get_trading_days(start_date: str, end_date: str) -> List[str]:
    """List the dates that have rows in ``stock_minute`` within a date range.

    Args:
        start_date: Inclusive lower bound, interpolated into the SQL as-is.
        end_date: Inclusive upper bound, interpolated into the SQL as-is.

    Returns:
        Distinct dates formatted as ``YYYY-MM-DD``, in ascending order.
    """
    sql = (
        "\n"
        "SELECT DISTINCT toDate(timestamp) as trade_date\n"
        "FROM stock_minute\n"
        f"WHERE toDate(timestamp) >= '{start_date}'\n"
        f"AND toDate(timestamp) <= '{end_date}'\n"
        "ORDER BY trade_date\n"
    )
    rows = get_ch_client().execute(sql)
    trading_days = [row[0].strftime('%Y-%m-%d') for row in rows]
    if trading_days:
        first_day, last_day = trading_days[0], trading_days[-1]
        print(f"找到 {len(trading_days)} 个交易日: {first_day} ~ {last_day}")
    return trading_days
# ==================== 获取昨收价 ====================
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
    """Fetch the previous close (column ``F007N``) for a list of stocks.

    The reference day is the latest ``TRADEDATE`` in ``ea_trade`` that is
    strictly before ``trade_date``, taken over the whole table rather than per
    stock, so a stock with no row on that day is absent from the result.

    Args:
        stock_codes: Bare six-digit stock codes; anything else is ignored.
        trade_date: Date whose previous close is wanted.

    Returns:
        ``{code: previous_close}``. Empty when no code is valid or when the
        query fails; the failure is printed, not raised.
    """
    valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
    if not valid_codes:
        return {}
    # Only all-digit codes reach this point, so inlining the IN list cannot
    # introduce quotes or bind markers. The date is bound as a parameter, in
    # line with get_index_prev_close.
    codes_str = "','".join(valid_codes)
    query = f"""
        SELECT SECCODE, F007N
        FROM ea_trade
        WHERE SECCODE IN ('{codes_str}')
          AND TRADEDATE = (
              SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < :trade_date
          )
          AND F007N IS NOT NULL AND F007N > 0
    """
    try:
        engine = get_mysql_engine()
        with engine.connect() as conn:
            result = conn.execute(text(query), {'trade_date': trade_date})
            return {row[0]: float(row[1]) for row in result if row[1]}
    except Exception as e:
        # Best effort: callers treat an empty dict as "no data for this day".
        print(f"获取昨收价失败: {e}")
        return {}
def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float:
    """Fetch the index value (column ``F006N``) of the latest day before ``trade_date``.

    Args:
        trade_date: Date whose previous close is wanted.
        index_code: Index code; the part before the first ``.`` is used to
            match ``INDEXCODE``.

    Returns:
        The value as a float, or ``None`` when no row is found, the value is
        falsy, or the query fails (the failure is printed, not raised).
    """
    bare_code, *_ = index_code.split('.')
    statement = (
        "\n"
        "SELECT F006N FROM ea_exchangetrade\n"
        "WHERE INDEXCODE = :code AND TRADEDATE < :today\n"
        "ORDER BY TRADEDATE DESC LIMIT 1\n"
    )
    bind_values = {'code': bare_code, 'today': trade_date}
    try:
        with get_mysql_engine().connect() as conn:
            row = conn.execute(text(statement), bind_values).fetchone()
            if row and row[0]:
                return float(row[0])
    except Exception as e:
        print(f"获取指数昨收失败: {e}")
    return None
# ==================== 获取分钟数据 ====================
def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
    """Load one day of minute bars for the given stocks from ``stock_minute``.

    Args:
        trade_date: Day to load, interpolated into the SQL as-is.
        stock_codes: Bare stock codes; those ``code_to_ch_format`` cannot
            convert are skipped.

    Returns:
        A frame with columns ``code`` (bare code), ``timestamp``, ``close``,
        ``volume`` and ``amt``, ordered by suffixed code then timestamp. An
        empty frame when no code is convertible or the query returns nothing.
    """
    client = get_ch_client()

    # (suffixed code, bare code) for every convertible input code.
    pairs = [
        (ch_code, code)
        for ch_code, code in ((code_to_ch_format(c), c) for c in stock_codes)
        if ch_code
    ]
    if not pairs:
        return pd.DataFrame()
    ch_codes = [ch_code for ch_code, _ in pairs]
    code_map = dict(pairs)

    ch_codes_str = "','".join(ch_codes)
    sql = (
        "\n"
        "SELECT code, timestamp, close, volume, amt\n"
        "FROM stock_minute\n"
        f"WHERE toDate(timestamp) = '{trade_date}'\n"
        f"AND code IN ('{ch_codes_str}')\n"
        "ORDER BY code, timestamp\n"
    )
    rows = client.execute(sql)
    if not rows:
        return pd.DataFrame()

    value_columns = ['timestamp', 'close', 'volume', 'amt']
    frame = pd.DataFrame(rows, columns=['ch_code'] + value_columns)
    # Map the suffixed codes back to bare codes; drop rows with no mapping.
    frame['code'] = frame['ch_code'].map(code_map)
    frame = frame.dropna(subset=['code'])
    return frame[['code'] + value_columns]
def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
    """Load one day of minute bars for an index from ``index_minute``.

    Args:
        trade_date: Day to load, interpolated into the SQL as-is.
        index_code: Index code matched against the ``code`` column.

    Returns:
        A frame with columns ``timestamp``, ``close``, ``volume`` and ``amt``
        ordered by timestamp, or an empty frame when there are no rows.
    """
    client = get_ch_client()
    sql = (
        "\n"
        "SELECT timestamp, close, volume, amt\n"
        "FROM index_minute\n"
        f"WHERE toDate(timestamp) = '{trade_date}'\n"
        f"AND code = '{index_code}'\n"
        "ORDER BY timestamp\n"
    )
    rows = client.execute(sql)
    if rows:
        return pd.DataFrame(rows, columns=['timestamp', 'close', 'volume', 'amt'])
    return pd.DataFrame()
# ==================== 计算原始概念特征(单日)====================
def compute_raw_concept_features(
    trade_date: str,
    concepts: List[dict],
    all_stocks: List[str]
) -> pd.DataFrame:
    """Compute one day of raw per-concept features for every timestamp.

    For each timestamp and each concept with at least one member stock that
    has data at that timestamp, the result holds:

    * ``alpha``: mean ``change_pct`` of those members minus the index
      ``change_pct`` (a timestamp missing from the index data counts as 0);
    * ``total_amt``: sum of the members' ``amt``;
    * ``limit_up_ratio``: share of members whose ``change_pct`` is at least
      ``CONFIG['limit_up_threshold']``;
    * ``stock_count``: number of members with data;
    * ``rank_pct``: percentile rank of ``alpha`` among that timestamp's concepts;
    * ``timestamp``, ``time_slot`` (``HH:MM``) and ``trade_date``.

    ``change_pct`` is measured against the previous close, in percent.

    The result is cached as ``raw_<trade_date>.parquet`` in ``RAW_CACHE_DIR``
    and served from there on later calls.

    Args:
        trade_date: Day to compute.
        concepts: Dicts with at least ``concept_id`` and ``stocks``.
        all_stocks: Bare codes of every stock to load.

    Returns:
        The feature frame, or an empty frame when stock data, index data or
        previous closes are unavailable, or no concept has data.
    """
    cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
    if os.path.exists(cache_file):
        return pd.read_parquet(cache_file)

    # Load the day's minute data.
    stock_df = get_daily_stock_data(trade_date, all_stocks)
    if stock_df.empty:
        return pd.DataFrame()
    index_df = get_daily_index_data(trade_date)
    if index_df.empty:
        return pd.DataFrame()

    # Load previous closes.
    prev_close = get_prev_close(all_stocks, trade_date)
    index_prev_close = get_index_prev_close(trade_date)
    if not prev_close or not index_prev_close:
        return pd.DataFrame()

    # Percentage change against the previous close; stocks without one are
    # dropped. assign/copy keep the writes on an independent frame.
    stock_df = stock_df.assign(prev_close=stock_df['code'].map(prev_close))
    stock_df = stock_df.dropna(subset=['prev_close']).copy()
    stock_df['change_pct'] = (
        (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
    )
    index_df['change_pct'] = (
        (index_df['close'] - index_prev_close) / index_prev_close * 100
    )
    index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))

    concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
    limit_up_threshold = CONFIG['limit_up_threshold']

    results = []
    # One groupby pass replaces a full-frame boolean mask per timestamp. Group
    # keys come out in ascending order as pandas Timestamp objects, whereas
    # iterating Series.unique() can yield numpy.datetime64 on older pandas,
    # which has no strftime.
    for ts, ts_stock_data in stock_df.groupby('timestamp'):
        index_change = index_change_map.get(ts, 0)
        stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
        stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))

        concept_features = []
        for concept_id, stocks in concept_stocks.items():
            members = [s for s in stocks if s in stock_change]
            if not members:
                continue
            concept_changes = [stock_change[s] for s in members]
            total_amt = sum(stock_amt.get(s, 0) for s in members)
            limit_up_count = sum(1 for c in concept_changes if c >= limit_up_threshold)
            concept_features.append({
                'concept_id': concept_id,
                'alpha': np.mean(concept_changes) - index_change,
                'total_amt': total_amt,
                'limit_up_ratio': limit_up_count / len(concept_changes),
                'stock_count': len(concept_changes),
            })
        if not concept_features:
            continue

        concept_df = pd.DataFrame(concept_features)
        concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)
        concept_df['timestamp'] = ts
        concept_df['time_slot'] = time_to_slot(ts)
        concept_df['trade_date'] = trade_date
        results.append(concept_df)

    if not results:
        return pd.DataFrame()
    result_df = pd.concat(results, ignore_index=True)

    # Write to a temporary file and rename it into place, so a crash or a
    # concurrent worker never leaves a half-written parquet that a later run
    # would read as a valid cache entry.
    tmp_file = f'{cache_file}.{os.getpid()}.tmp'
    try:
        result_df.to_parquet(tmp_file, index=False)
        os.replace(tmp_file, cache_file)
    finally:
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
    return result_df
# ==================== 滚动窗口基线计算 ====================
def compute_rolling_baseline(
    historical_data: pd.DataFrame,
    concept_id: str
) -> Dict[str, Dict]:
    """Build the per-time-slot baseline statistics of one concept.

    Args:
        historical_data: Raw feature rows with at least ``concept_id``,
            ``time_slot``, ``alpha``, ``total_amt`` and ``rank_pct``.
        concept_id: Concept whose rows are summarised.

    Returns:
        ``{time_slot: {alpha_mean, alpha_std, amt_mean, amt_std, rank_mean,
        rank_std, sample_count}}``. Slots with fewer than
        ``CONFIG['min_baseline_samples']`` rows are left out. Empty when there
        is no history for the concept.
    """
    if historical_data.empty:
        return {}
    concept_rows = historical_data.loc[historical_data['concept_id'] == concept_id]
    if concept_rows.empty:
        return {}

    def _floored_std(std_value, fallback, floor):
        # A NaN standard deviation is replaced by the fallback; the result is
        # then bounded from below so z-scores never divide by ~0.
        usable = std_value if pd.notna(std_value) else fallback
        return max(usable, floor)

    baselines = {}
    for slot, slot_rows in concept_rows.groupby('time_slot'):
        sample_count = len(slot_rows)
        if sample_count < CONFIG['min_baseline_samples']:
            continue
        alpha = slot_rows['alpha']
        amt = slot_rows['total_amt']
        rank = slot_rows['rank_pct']
        amt_mean = amt.mean()
        baselines[slot] = {
            'alpha_mean': alpha.mean(),
            'alpha_std': _floored_std(alpha.std(), 1.0, 0.1),
            'amt_mean': amt_mean,
            'amt_std': _floored_std(amt.std(), amt_mean * 0.5, 1.0),
            'rank_mean': rank.mean(),
            'rank_std': _floored_std(rank.std(), 0.2, 0.05),
            'sample_count': sample_count,
        }
    return baselines
# ==================== 计算单日 Z-Score 特征(带滚动基线)====================
def _zscore_momentum(history: List[float], window: int) -> float:
    """Mean of the latest ``window`` z-scores minus the mean of the ``window`` before.

    Returns 0.0 while fewer than ``window`` values exist. While fewer than
    ``2 * window`` values exist, the first z-score is the comparison point.
    """
    if len(history) < window:
        return 0.0
    recent = history[-window:]
    if len(history) >= 2 * window:
        older = history[-2 * window:-window]
    else:
        older = [history[0]]
    return np.mean(recent) - np.mean(older)


def compute_zscore_features_rolling(
    trade_date: str,
    concepts: List[dict],
    all_stocks: List[str],
    historical_raw_data: pd.DataFrame  # raw features of the N days before trade_date
) -> pd.DataFrame:
    """Compute one day of Z-Score features against a rolling baseline.

    Each (concept, time slot) value of the day is standardised with the
    baseline that ``compute_rolling_baseline`` derives from
    ``historical_raw_data``; the caller is responsible for passing only rows
    from before ``trade_date``. Z-scores are clipped to
    ``+/- CONFIG['zscore_clip']``. Momentum is computed on the day's sequence
    of clipped alpha z-scores with windows of 3 and 5 observations.

    Args:
        trade_date: Day to compute.
        concepts: Concept dicts, forwarded to ``compute_raw_concept_features``.
        all_stocks: Stock codes, forwarded to ``compute_raw_concept_features``.
        historical_raw_data: Raw feature rows used to build the baselines.

    Returns:
        One row per (concept, timestamp) that has a baseline, carrying the raw
        features, the three z-scores and both momentum values. An empty frame
        when the day has no raw features or no concept has a baseline.
    """
    raw_df = compute_raw_concept_features(trade_date, concepts, all_stocks)
    if raw_df.empty:
        return pd.DataFrame()

    # Split the history by concept once instead of re-filtering the whole
    # frame for every concept.
    if historical_raw_data.empty:
        history_by_concept = {}
    else:
        history_by_concept = dict(iter(historical_raw_data.groupby('concept_id')))

    clip = CONFIG['zscore_clip']
    zscore_records = []
    for concept_id, group in raw_df.groupby('concept_id'):
        concept_history = history_by_concept.get(concept_id)
        if concept_history is None:
            continue
        baseline_dict = compute_rolling_baseline(concept_history, concept_id)
        if not baseline_dict:
            continue

        # Clipped alpha z-scores seen so far today, in time order. Slots
        # without a baseline contribute nothing.
        zscore_history = []
        for row in group.sort_values('timestamp').itertuples(index=False):
            bl = baseline_dict.get(row.time_slot)
            if bl is None:
                continue

            alpha_zscore = np.clip(
                (row.alpha - bl['alpha_mean']) / bl['alpha_std'], -clip, clip)
            amt_zscore = np.clip(
                (row.total_amt - bl['amt_mean']) / bl['amt_std'], -clip, clip)
            rank_zscore = np.clip(
                (row.rank_pct - bl['rank_mean']) / bl['rank_std'], -clip, clip)

            zscore_history.append(alpha_zscore)
            zscore_records.append({
                'concept_id': concept_id,
                'timestamp': row.timestamp,
                'time_slot': row.time_slot,
                'trade_date': trade_date,
                # raw features
                'alpha': row.alpha,
                'total_amt': row.total_amt,
                'limit_up_ratio': row.limit_up_ratio,
                'stock_count': row.stock_count,
                'rank_pct': row.rank_pct,
                # Z-Score features
                'alpha_zscore': alpha_zscore,
                'amt_zscore': amt_zscore,
                'rank_zscore': rank_zscore,
                # momentum on z-scores rather than raw alpha
                'momentum_3m': _zscore_momentum(zscore_history, 3),
                'momentum_5m': _zscore_momentum(zscore_history, 5),
            })

    if not zscore_records:
        return pd.DataFrame()
    return pd.DataFrame(zscore_records)
# ==================== 多进程处理 ====================
def process_single_day_v2(args) -> Tuple[str, bool]:
    """Compute and store the Z-Score features of one trading day (worker entry point).

    Args:
        args: Tuple ``(trade_date, day_index, concepts, all_stocks,
            all_trading_days)`` where ``day_index`` is the position of
            ``trade_date`` in ``all_trading_days``.

    Returns:
        ``(trade_date, success)``. ``success`` is True when the output file
        already exists or was written; False when there is no usable history,
        no features could be computed, or an exception occurred (it is
        printed with its traceback, not raised).
    """
    trade_date, day_index, concepts, all_stocks, all_trading_days = args
    output_file = os.path.join(OUTPUT_DIR, f'features_v2_{trade_date}.parquet')
    if os.path.exists(output_file):
        return (trade_date, True)
    try:
        # Rolling window: up to baseline_days trading days strictly before
        # trade_date.
        baseline_days = CONFIG['baseline_days']
        start_idx = max(0, day_index - baseline_days)
        end_idx = day_index  # the day itself is excluded
        if end_idx <= start_idx:
            # No history at all before this day.
            return (trade_date, False)

        # compute_raw_concept_features serves cached days from disk itself, so
        # no separate cache lookup is needed here.
        historical_dfs = []
        for hist_date in all_trading_days[start_idx:end_idx]:
            hist_df = compute_raw_concept_features(hist_date, concepts, all_stocks)
            if not hist_df.empty:
                historical_dfs.append(hist_df)
        if not historical_dfs:
            return (trade_date, False)
        historical_raw_data = pd.concat(historical_dfs, ignore_index=True)

        df = compute_zscore_features_rolling(
            trade_date, concepts, all_stocks, historical_raw_data)
        if df.empty:
            return (trade_date, False)

        # Write-then-rename: an interrupted run must not leave a partial file,
        # because the existence check above would treat it as complete.
        tmp_file = f'{output_file}.{os.getpid()}.tmp'
        try:
            df.to_parquet(tmp_file, index=False)
            os.replace(tmp_file, output_file)
        finally:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
        return (trade_date, True)
    except Exception as e:
        print(f"[{trade_date}] 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return (trade_date, False)
# ==================== 主流程 ====================
def _init_worker(config_overrides):
    """Pool initializer: replay the parent's CONFIG values, then open connections.

    Workers started with the spawn or forkserver method re-import this module
    and would otherwise see the default CONFIG rather than the command-line
    overrides applied in main().
    """
    CONFIG.update(config_overrides)
    init_process_connections()


def main():
    """Command-line entry point: build the V2 feature files for a date range.

    Phase 1 computes (and caches) the raw concept features of every trading
    day sequentially. Phase 2 computes the Z-Score features in a process pool
    for every day that has ``--baseline-days`` trading days before it.
    """
    import argparse
    parser = argparse.ArgumentParser(description='准备训练数据 V2滚动窗口基线 + Z-Score 动量)')
    parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
    parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
    parser.add_argument('--workers', type=int, default=18, help='并行进程数')
    parser.add_argument('--baseline-days', type=int, default=20, help='滚动基线窗口大小')
    parser.add_argument('--force', action='store_true', help='强制重新计算(忽略缓存)')
    args = parser.parse_args()

    # Reject values that can only fail later: the pool rejects workers < 1 and
    # a window below 1 leaves every day without history.
    if args.workers < 1:
        parser.error('--workers 必须 >= 1')
    if args.baseline_days < 1:
        parser.error('--baseline-days 必须 >= 1')

    end_date = args.end or datetime.now().strftime('%Y-%m-%d')
    CONFIG['baseline_days'] = args.baseline_days
    print("=" * 60)
    print("数据准备 V2 - 滚动窗口基线 + Z-Score 动量")
    print("=" * 60)
    print(f"日期范围: {args.start} ~ {end_date}")
    print(f"并行进程数: {args.workers}")
    print(f"滚动基线窗口: {args.baseline_days} 天")
    if args.baseline_days < CONFIG['min_baseline_samples']:
        # Presumably each day contributes one row per (concept, time slot), so
        # a window shorter than the sample minimum yields no valid baseline.
        print(
            f"警告: 基线窗口 {args.baseline_days} 天小于最少样本数 "
            f"{CONFIG['min_baseline_samples']},很可能无法生成有效基线"
        )

    # Connections for the main process.
    init_process_connections()

    # 1. Concepts and the union of their stocks.
    concepts = get_all_concepts()
    all_stocks = list(set(s for c in concepts for s in c['stocks']))
    print(f"股票总数: {len(all_stocks)}")

    # 2. Trading days.
    trading_days = get_trading_days(args.start, end_date)
    if not trading_days:
        print("无交易日数据")
        return

    # 3. Phase 1: pre-compute the raw features (fills the cache).
    print(f"\n{'='*60}")
    print("第一阶段:预计算原始特征(用于滚动基线)")
    print(f"{'='*60}")

    # --force: drop the raw cache and all previously written feature files.
    if args.force:
        import shutil
        if os.path.exists(RAW_CACHE_DIR):
            shutil.rmtree(RAW_CACHE_DIR)
        os.makedirs(RAW_CACHE_DIR, exist_ok=True)
        if os.path.exists(OUTPUT_DIR):
            for f in os.listdir(OUTPUT_DIR):
                if f.startswith('features_v2_'):
                    os.remove(os.path.join(OUTPUT_DIR, f))

    # Done sequentially in this process so the cache is filled before the
    # workers start reading it.
    print(f"预计算 {len(trading_days)} 天的原始特征...")
    for trade_date in tqdm(trading_days, desc="预计算原始特征"):
        cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
        if not os.path.exists(cache_file):
            compute_raw_concept_features(trade_date, concepts, all_stocks)

    # 4. Phase 2: Z-Score features in parallel.
    print(f"\n{'='*60}")
    print("第二阶段:计算 Z-Score 特征(滚动基线)")
    print(f"{'='*60}")

    # The first baseline_days days are warm-up only.
    start_idx = args.baseline_days
    processable_days = trading_days[start_idx:]
    if not processable_days:
        print(f"错误:需要至少 {args.baseline_days + 1} 天的数据")
        return
    print(f"可处理日期: {processable_days[0]} ~ {processable_days[-1]} ({len(processable_days)} 天)")
    print(f"跳过前 {start_idx} 天(基线预热期)")

    # Each task carries the day's index so the worker can slice its window.
    tasks = [
        (trade_date, i, concepts, all_stocks, trading_days)
        for i, trade_date in enumerate(trading_days)
        if i >= start_idx
    ]
    print(f"开始处理 {len(tasks)} 个交易日({args.workers} 进程并行)...")

    success_count = 0
    failed_dates = []
    # The initializer receives a CONFIG snapshot so every worker uses the same
    # settings as this process regardless of the start method.
    with ProcessPoolExecutor(
        max_workers=args.workers,
        initializer=_init_worker,
        initargs=(dict(CONFIG),),
    ) as executor:
        futures = {executor.submit(process_single_day_v2, task): task[0] for task in tasks}
        with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
            for future in as_completed(futures):
                trade_date = futures[future]
                try:
                    result_date, success = future.result()
                    if success:
                        success_count += 1
                    else:
                        failed_dates.append(result_date)
                except Exception as e:
                    print(f"\n[{trade_date}] 进程异常: {e}")
                    failed_dates.append(trade_date)
                pbar.update(1)

    print("\n" + "=" * 60)
    print(f"处理完成: {success_count}/{len(tasks)} 个交易日")
    if failed_dates:
        print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
    print(f"数据保存在: {OUTPUT_DIR}")
    print("=" * 60)


if __name__ == "__main__":
    main()