update pay ui
This commit is contained in:
715
ml/prepare_data_v2.py
Normal file
715
ml/prepare_data_v2.py
Normal file
@@ -0,0 +1,715 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据准备 V2 - 基于时间片对齐的特征计算(修复版)
|
||||
|
||||
核心改进:
|
||||
1. 时间片对齐:9:35 和历史的 9:35 比,而不是和前30分钟比
|
||||
2. Z-Score 特征:相对于同时间片历史分布的偏离程度
|
||||
3. 滚动窗口基线:每个日期使用它之前 N 天的数据作为基线(不是固定的最后 N 天!)
|
||||
4. 基于 Z-Score 的动量:消除一天内波动率异构性
|
||||
|
||||
修复:
|
||||
- 滚动窗口基线:避免未来数据泄露
|
||||
- Z-Score 动量:消除早盘/尾盘波动率差异
|
||||
- 进程级数据库单例:避免连接池爆炸
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import create_engine, text
|
||||
from elasticsearch import Elasticsearch
|
||||
from clickhouse_driver import Client
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from tqdm import tqdm
|
||||
from collections import defaultdict
|
||||
import warnings
|
||||
import pickle
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# ==================== 配置 ====================
|
||||
|
||||
MYSQL_URL = "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock"
|
||||
ES_HOST = 'http://127.0.0.1:9200'
|
||||
ES_INDEX = 'concept_library_v3'
|
||||
|
||||
CLICKHOUSE_CONFIG = {
|
||||
'host': '127.0.0.1',
|
||||
'port': 9000,
|
||||
'user': 'default',
|
||||
'password': 'Zzl33818!',
|
||||
'database': 'stock'
|
||||
}
|
||||
|
||||
REFERENCE_INDEX = '000001.SH'
|
||||
|
||||
# 输出目录
|
||||
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data_v2')
|
||||
BASELINE_DIR = os.path.join(OUTPUT_DIR, 'baselines')
|
||||
RAW_CACHE_DIR = os.path.join(OUTPUT_DIR, 'raw_cache')
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
os.makedirs(BASELINE_DIR, exist_ok=True)
|
||||
os.makedirs(RAW_CACHE_DIR, exist_ok=True)
|
||||
|
||||
# 特征配置
|
||||
CONFIG = {
|
||||
'baseline_days': 20, # 滚动窗口大小
|
||||
'min_baseline_samples': 10, # 最少需要10个样本才算有效基线
|
||||
'limit_up_threshold': 9.8,
|
||||
'limit_down_threshold': -9.8,
|
||||
'zscore_clip': 5.0,
|
||||
}
|
||||
|
||||
# 特征列表
|
||||
FEATURES_V2 = [
|
||||
'alpha', 'alpha_zscore', 'amt_zscore', 'rank_zscore',
|
||||
'momentum_3m', 'momentum_5m', 'limit_up_ratio',
|
||||
]
|
||||
|
||||
# ==================== 进程级单例(避免连接池爆炸)====================
|
||||
|
||||
# 进程级全局变量
|
||||
_process_mysql_engine = None
|
||||
_process_es_client = None
|
||||
_process_ch_client = None
|
||||
|
||||
|
||||
def init_process_connections():
|
||||
"""进程初始化时调用,创建连接(单例)"""
|
||||
global _process_mysql_engine, _process_es_client, _process_ch_client
|
||||
_process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5)
|
||||
_process_es_client = Elasticsearch([ES_HOST])
|
||||
_process_ch_client = Client(**CLICKHOUSE_CONFIG)
|
||||
|
||||
|
||||
def get_mysql_engine():
|
||||
"""获取进程级 MySQL Engine(单例)"""
|
||||
global _process_mysql_engine
|
||||
if _process_mysql_engine is None:
|
||||
_process_mysql_engine = create_engine(MYSQL_URL, echo=False, pool_pre_ping=True, pool_size=5)
|
||||
return _process_mysql_engine
|
||||
|
||||
|
||||
def get_es_client():
|
||||
"""获取进程级 ES 客户端(单例)"""
|
||||
global _process_es_client
|
||||
if _process_es_client is None:
|
||||
_process_es_client = Elasticsearch([ES_HOST])
|
||||
return _process_es_client
|
||||
|
||||
|
||||
def get_ch_client():
|
||||
"""获取进程级 ClickHouse 客户端(单例)"""
|
||||
global _process_ch_client
|
||||
if _process_ch_client is None:
|
||||
_process_ch_client = Client(**CLICKHOUSE_CONFIG)
|
||||
return _process_ch_client
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
def code_to_ch_format(code: str) -> str:
|
||||
if not code or len(code) != 6 or not code.isdigit():
|
||||
return None
|
||||
if code.startswith('6'):
|
||||
return f"{code}.SH"
|
||||
elif code.startswith('0') or code.startswith('3'):
|
||||
return f"{code}.SZ"
|
||||
else:
|
||||
return f"{code}.BJ"
|
||||
|
||||
|
||||
def time_to_slot(ts) -> str:
|
||||
"""将时间戳转换为时间片(HH:MM格式)"""
|
||||
if isinstance(ts, str):
|
||||
return ts
|
||||
return ts.strftime('%H:%M')
|
||||
|
||||
|
||||
# ==================== 获取概念列表 ====================
|
||||
|
||||
def get_all_concepts() -> List[dict]:
|
||||
"""从ES获取所有叶子概念"""
|
||||
es_client = get_es_client()
|
||||
concepts = []
|
||||
|
||||
query = {
|
||||
"query": {"match_all": {}},
|
||||
"size": 100,
|
||||
"_source": ["concept_id", "concept", "stocks"]
|
||||
}
|
||||
|
||||
resp = es_client.search(index=ES_INDEX, body=query, scroll='2m')
|
||||
scroll_id = resp['_scroll_id']
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
while len(hits) > 0:
|
||||
for hit in hits:
|
||||
source = hit['_source']
|
||||
stocks = []
|
||||
if 'stocks' in source and isinstance(source['stocks'], list):
|
||||
for stock in source['stocks']:
|
||||
if isinstance(stock, dict) and 'code' in stock and stock['code']:
|
||||
stocks.append(stock['code'])
|
||||
|
||||
if stocks:
|
||||
concepts.append({
|
||||
'concept_id': source.get('concept_id'),
|
||||
'concept_name': source.get('concept'),
|
||||
'stocks': stocks
|
||||
})
|
||||
|
||||
resp = es_client.scroll(scroll_id=scroll_id, scroll='2m')
|
||||
scroll_id = resp['_scroll_id']
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
es_client.clear_scroll(scroll_id=scroll_id)
|
||||
print(f"获取到 {len(concepts)} 个概念")
|
||||
return concepts
|
||||
|
||||
|
||||
# ==================== 获取交易日列表 ====================
|
||||
|
||||
def get_trading_days(start_date: str, end_date: str) -> List[str]:
|
||||
"""获取交易日列表"""
|
||||
client = get_ch_client()
|
||||
|
||||
query = f"""
|
||||
SELECT DISTINCT toDate(timestamp) as trade_date
|
||||
FROM stock_minute
|
||||
WHERE toDate(timestamp) >= '{start_date}'
|
||||
AND toDate(timestamp) <= '{end_date}'
|
||||
ORDER BY trade_date
|
||||
"""
|
||||
|
||||
result = client.execute(query)
|
||||
days = [row[0].strftime('%Y-%m-%d') for row in result]
|
||||
if days:
|
||||
print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
|
||||
return days
|
||||
|
||||
|
||||
# ==================== 获取昨收价 ====================
|
||||
|
||||
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
||||
"""获取昨收价(上一交易日的收盘价 F007N)"""
|
||||
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
|
||||
if not valid_codes:
|
||||
return {}
|
||||
|
||||
codes_str = "','".join(valid_codes)
|
||||
query = f"""
|
||||
SELECT SECCODE, F007N
|
||||
FROM ea_trade
|
||||
WHERE SECCODE IN ('{codes_str}')
|
||||
AND TRADEDATE = (
|
||||
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
|
||||
)
|
||||
AND F007N IS NOT NULL AND F007N > 0
|
||||
"""
|
||||
|
||||
try:
|
||||
engine = get_mysql_engine()
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text(query))
|
||||
return {row[0]: float(row[1]) for row in result if row[1]}
|
||||
except Exception as e:
|
||||
print(f"获取昨收价失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float:
|
||||
"""获取指数昨收价"""
|
||||
code_no_suffix = index_code.split('.')[0]
|
||||
|
||||
try:
|
||||
engine = get_mysql_engine()
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text("""
|
||||
SELECT F006N FROM ea_exchangetrade
|
||||
WHERE INDEXCODE = :code AND TRADEDATE < :today
|
||||
ORDER BY TRADEDATE DESC LIMIT 1
|
||||
"""), {'code': code_no_suffix, 'today': trade_date}).fetchone()
|
||||
|
||||
if result and result[0]:
|
||||
return float(result[0])
|
||||
except Exception as e:
|
||||
print(f"获取指数昨收失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ==================== 获取分钟数据 ====================
|
||||
|
||||
def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
|
||||
"""获取单日所有股票的分钟数据"""
|
||||
client = get_ch_client()
|
||||
|
||||
ch_codes = []
|
||||
code_map = {}
|
||||
for code in stock_codes:
|
||||
ch_code = code_to_ch_format(code)
|
||||
if ch_code:
|
||||
ch_codes.append(ch_code)
|
||||
code_map[ch_code] = code
|
||||
|
||||
if not ch_codes:
|
||||
return pd.DataFrame()
|
||||
|
||||
ch_codes_str = "','".join(ch_codes)
|
||||
|
||||
query = f"""
|
||||
SELECT code, timestamp, close, volume, amt
|
||||
FROM stock_minute
|
||||
WHERE toDate(timestamp) = '{trade_date}'
|
||||
AND code IN ('{ch_codes_str}')
|
||||
ORDER BY code, timestamp
|
||||
"""
|
||||
|
||||
result = client.execute(query)
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
|
||||
df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
|
||||
df['code'] = df['ch_code'].map(code_map)
|
||||
df = df.dropna(subset=['code'])
|
||||
|
||||
return df[['code', 'timestamp', 'close', 'volume', 'amt']]
|
||||
|
||||
|
||||
def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
|
||||
"""获取单日指数分钟数据"""
|
||||
client = get_ch_client()
|
||||
|
||||
query = f"""
|
||||
SELECT timestamp, close, volume, amt
|
||||
FROM index_minute
|
||||
WHERE toDate(timestamp) = '{trade_date}'
|
||||
AND code = '{index_code}'
|
||||
ORDER BY timestamp
|
||||
"""
|
||||
|
||||
result = client.execute(query)
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
|
||||
df = pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt'])
|
||||
return df
|
||||
|
||||
|
||||
# ==================== 计算原始概念特征(单日)====================
|
||||
|
||||
def compute_raw_concept_features(
|
||||
trade_date: str,
|
||||
concepts: List[dict],
|
||||
all_stocks: List[str]
|
||||
) -> pd.DataFrame:
|
||||
"""计算单日概念的原始特征(alpha, amt, rank_pct, limit_up_ratio)"""
|
||||
# 检查缓存
|
||||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
||||
if os.path.exists(cache_file):
|
||||
return pd.read_parquet(cache_file)
|
||||
|
||||
# 获取数据
|
||||
stock_df = get_daily_stock_data(trade_date, all_stocks)
|
||||
if stock_df.empty:
|
||||
return pd.DataFrame()
|
||||
|
||||
index_df = get_daily_index_data(trade_date)
|
||||
if index_df.empty:
|
||||
return pd.DataFrame()
|
||||
|
||||
# 获取昨收价
|
||||
prev_close = get_prev_close(all_stocks, trade_date)
|
||||
index_prev_close = get_index_prev_close(trade_date)
|
||||
|
||||
if not prev_close or not index_prev_close:
|
||||
return pd.DataFrame()
|
||||
|
||||
# 计算涨跌幅
|
||||
stock_df['prev_close'] = stock_df['code'].map(prev_close)
|
||||
stock_df = stock_df.dropna(subset=['prev_close'])
|
||||
stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
|
||||
|
||||
index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
|
||||
index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))
|
||||
|
||||
# 获取所有时间点
|
||||
timestamps = sorted(stock_df['timestamp'].unique())
|
||||
|
||||
# 概念到股票的映射
|
||||
concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
|
||||
|
||||
results = []
|
||||
|
||||
for ts in timestamps:
|
||||
ts_stock_data = stock_df[stock_df['timestamp'] == ts]
|
||||
index_change = index_change_map.get(ts, 0)
|
||||
|
||||
stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
|
||||
stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))
|
||||
|
||||
concept_features = []
|
||||
|
||||
for concept_id, stocks in concept_stocks.items():
|
||||
concept_changes = [stock_change[s] for s in stocks if s in stock_change]
|
||||
concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]
|
||||
|
||||
if not concept_changes:
|
||||
continue
|
||||
|
||||
avg_change = np.mean(concept_changes)
|
||||
total_amt = sum(concept_amts)
|
||||
alpha = avg_change - index_change
|
||||
|
||||
limit_up_count = sum(1 for c in concept_changes if c >= CONFIG['limit_up_threshold'])
|
||||
limit_up_ratio = limit_up_count / len(concept_changes)
|
||||
|
||||
concept_features.append({
|
||||
'concept_id': concept_id,
|
||||
'alpha': alpha,
|
||||
'total_amt': total_amt,
|
||||
'limit_up_ratio': limit_up_ratio,
|
||||
'stock_count': len(concept_changes),
|
||||
})
|
||||
|
||||
if not concept_features:
|
||||
continue
|
||||
|
||||
concept_df = pd.DataFrame(concept_features)
|
||||
concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)
|
||||
concept_df['timestamp'] = ts
|
||||
concept_df['time_slot'] = time_to_slot(ts)
|
||||
concept_df['trade_date'] = trade_date
|
||||
|
||||
results.append(concept_df)
|
||||
|
||||
if not results:
|
||||
return pd.DataFrame()
|
||||
|
||||
result_df = pd.concat(results, ignore_index=True)
|
||||
|
||||
# 保存缓存
|
||||
result_df.to_parquet(cache_file, index=False)
|
||||
|
||||
return result_df
|
||||
|
||||
|
||||
# ==================== 滚动窗口基线计算 ====================
|
||||
|
||||
def compute_rolling_baseline(
|
||||
historical_data: pd.DataFrame,
|
||||
concept_id: str
|
||||
) -> Dict[str, Dict]:
|
||||
"""
|
||||
计算单个概念的滚动基线
|
||||
|
||||
返回: {time_slot: {alpha_mean, alpha_std, amt_mean, amt_std, rank_mean, rank_std, sample_count}}
|
||||
"""
|
||||
if historical_data.empty:
|
||||
return {}
|
||||
|
||||
concept_data = historical_data[historical_data['concept_id'] == concept_id]
|
||||
if concept_data.empty:
|
||||
return {}
|
||||
|
||||
baseline_dict = {}
|
||||
|
||||
for time_slot, group in concept_data.groupby('time_slot'):
|
||||
if len(group) < CONFIG['min_baseline_samples']:
|
||||
continue
|
||||
|
||||
alpha_std = group['alpha'].std()
|
||||
amt_std = group['total_amt'].std()
|
||||
rank_std = group['rank_pct'].std()
|
||||
|
||||
baseline_dict[time_slot] = {
|
||||
'alpha_mean': group['alpha'].mean(),
|
||||
'alpha_std': max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1),
|
||||
'amt_mean': group['total_amt'].mean(),
|
||||
'amt_std': max(amt_std if pd.notna(amt_std) else group['total_amt'].mean() * 0.5, 1.0),
|
||||
'rank_mean': group['rank_pct'].mean(),
|
||||
'rank_std': max(rank_std if pd.notna(rank_std) else 0.2, 0.05),
|
||||
'sample_count': len(group),
|
||||
}
|
||||
|
||||
return baseline_dict
|
||||
|
||||
|
||||
# ==================== 计算单日 Z-Score 特征(带滚动基线)====================
|
||||
|
||||
def compute_zscore_features_rolling(
|
||||
trade_date: str,
|
||||
concepts: List[dict],
|
||||
all_stocks: List[str],
|
||||
historical_raw_data: pd.DataFrame # 该日期之前 N 天的原始数据
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
计算单日的 Z-Score 特征(使用滚动窗口基线)
|
||||
|
||||
关键改进:
|
||||
1. 基线只使用 trade_date 之前的数据(无未来泄露)
|
||||
2. 动量基于 Z-Score 计算(消除波动率异构性)
|
||||
"""
|
||||
# 计算当日原始特征
|
||||
raw_df = compute_raw_concept_features(trade_date, concepts, all_stocks)
|
||||
|
||||
if raw_df.empty:
|
||||
return pd.DataFrame()
|
||||
|
||||
zscore_records = []
|
||||
|
||||
for concept_id, group in raw_df.groupby('concept_id'):
|
||||
# 计算该概念的滚动基线(只用历史数据)
|
||||
baseline_dict = compute_rolling_baseline(historical_raw_data, concept_id)
|
||||
|
||||
if not baseline_dict:
|
||||
continue
|
||||
|
||||
# 按时间排序
|
||||
group = group.sort_values('timestamp').reset_index(drop=True)
|
||||
|
||||
# Z-Score 历史(用于计算基于 Z-Score 的动量)
|
||||
zscore_history = []
|
||||
|
||||
for idx, row in group.iterrows():
|
||||
time_slot = row['time_slot']
|
||||
|
||||
if time_slot not in baseline_dict:
|
||||
continue
|
||||
|
||||
bl = baseline_dict[time_slot]
|
||||
|
||||
# 计算 Z-Score
|
||||
alpha_zscore = (row['alpha'] - bl['alpha_mean']) / bl['alpha_std']
|
||||
amt_zscore = (row['total_amt'] - bl['amt_mean']) / bl['amt_std']
|
||||
rank_zscore = (row['rank_pct'] - bl['rank_mean']) / bl['rank_std']
|
||||
|
||||
# 截断极端值
|
||||
clip = CONFIG['zscore_clip']
|
||||
alpha_zscore = np.clip(alpha_zscore, -clip, clip)
|
||||
amt_zscore = np.clip(amt_zscore, -clip, clip)
|
||||
rank_zscore = np.clip(rank_zscore, -clip, clip)
|
||||
|
||||
# 记录 Z-Score 历史
|
||||
zscore_history.append(alpha_zscore)
|
||||
|
||||
# 基于 Z-Score 计算动量(消除波动率异构性)
|
||||
momentum_3m = 0.0
|
||||
momentum_5m = 0.0
|
||||
|
||||
if len(zscore_history) >= 3:
|
||||
recent_3 = zscore_history[-3:]
|
||||
older_3 = zscore_history[-6:-3] if len(zscore_history) >= 6 else [zscore_history[0]]
|
||||
momentum_3m = np.mean(recent_3) - np.mean(older_3)
|
||||
|
||||
if len(zscore_history) >= 5:
|
||||
recent_5 = zscore_history[-5:]
|
||||
older_5 = zscore_history[-10:-5] if len(zscore_history) >= 10 else [zscore_history[0]]
|
||||
momentum_5m = np.mean(recent_5) - np.mean(older_5)
|
||||
|
||||
zscore_records.append({
|
||||
'concept_id': concept_id,
|
||||
'timestamp': row['timestamp'],
|
||||
'time_slot': time_slot,
|
||||
'trade_date': trade_date,
|
||||
# 原始特征
|
||||
'alpha': row['alpha'],
|
||||
'total_amt': row['total_amt'],
|
||||
'limit_up_ratio': row['limit_up_ratio'],
|
||||
'stock_count': row['stock_count'],
|
||||
'rank_pct': row['rank_pct'],
|
||||
# Z-Score 特征
|
||||
'alpha_zscore': alpha_zscore,
|
||||
'amt_zscore': amt_zscore,
|
||||
'rank_zscore': rank_zscore,
|
||||
# 基于 Z-Score 的动量
|
||||
'momentum_3m': momentum_3m,
|
||||
'momentum_5m': momentum_5m,
|
||||
})
|
||||
|
||||
if not zscore_records:
|
||||
return pd.DataFrame()
|
||||
|
||||
return pd.DataFrame(zscore_records)
|
||||
|
||||
|
||||
# ==================== 多进程处理 ====================
|
||||
|
||||
def process_single_day_v2(args) -> Tuple[str, bool]:
|
||||
"""处理单个交易日(多进程版本)"""
|
||||
trade_date, day_index, concepts, all_stocks, all_trading_days = args
|
||||
output_file = os.path.join(OUTPUT_DIR, f'features_v2_{trade_date}.parquet')
|
||||
|
||||
if os.path.exists(output_file):
|
||||
return (trade_date, True)
|
||||
|
||||
try:
|
||||
# 计算滚动窗口范围(该日期之前的 N 天)
|
||||
baseline_days = CONFIG['baseline_days']
|
||||
|
||||
# 找出 trade_date 之前的交易日
|
||||
start_idx = max(0, day_index - baseline_days)
|
||||
end_idx = day_index # 不包含当天
|
||||
|
||||
if end_idx <= start_idx:
|
||||
# 没有足够的历史数据
|
||||
return (trade_date, False)
|
||||
|
||||
historical_days = all_trading_days[start_idx:end_idx]
|
||||
|
||||
# 加载历史原始数据
|
||||
historical_dfs = []
|
||||
for hist_date in historical_days:
|
||||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{hist_date}.parquet')
|
||||
if os.path.exists(cache_file):
|
||||
historical_dfs.append(pd.read_parquet(cache_file))
|
||||
else:
|
||||
# 需要计算
|
||||
hist_df = compute_raw_concept_features(hist_date, concepts, all_stocks)
|
||||
if not hist_df.empty:
|
||||
historical_dfs.append(hist_df)
|
||||
|
||||
if not historical_dfs:
|
||||
return (trade_date, False)
|
||||
|
||||
historical_raw_data = pd.concat(historical_dfs, ignore_index=True)
|
||||
|
||||
# 计算当日 Z-Score 特征(使用滚动基线)
|
||||
df = compute_zscore_features_rolling(trade_date, concepts, all_stocks, historical_raw_data)
|
||||
|
||||
if df.empty:
|
||||
return (trade_date, False)
|
||||
|
||||
df.to_parquet(output_file, index=False)
|
||||
return (trade_date, True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{trade_date}] 处理失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return (trade_date, False)
|
||||
|
||||
|
||||
# ==================== 主流程 ====================
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='准备训练数据 V2(滚动窗口基线 + Z-Score 动量)')
|
||||
parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
|
||||
parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
|
||||
parser.add_argument('--workers', type=int, default=18, help='并行进程数')
|
||||
parser.add_argument('--baseline-days', type=int, default=20, help='滚动基线窗口大小')
|
||||
parser.add_argument('--force', action='store_true', help='强制重新计算(忽略缓存)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
end_date = args.end or datetime.now().strftime('%Y-%m-%d')
|
||||
CONFIG['baseline_days'] = args.baseline_days
|
||||
|
||||
print("=" * 60)
|
||||
print("数据准备 V2 - 滚动窗口基线 + Z-Score 动量")
|
||||
print("=" * 60)
|
||||
print(f"日期范围: {args.start} ~ {end_date}")
|
||||
print(f"并行进程数: {args.workers}")
|
||||
print(f"滚动基线窗口: {args.baseline_days} 天")
|
||||
|
||||
# 初始化主进程连接
|
||||
init_process_connections()
|
||||
|
||||
# 1. 获取概念列表
|
||||
concepts = get_all_concepts()
|
||||
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
||||
print(f"股票总数: {len(all_stocks)}")
|
||||
|
||||
# 2. 获取交易日列表
|
||||
trading_days = get_trading_days(args.start, end_date)
|
||||
|
||||
if not trading_days:
|
||||
print("无交易日数据")
|
||||
return
|
||||
|
||||
# 3. 第一阶段:预计算所有原始特征(用于缓存)
|
||||
print(f"\n{'='*60}")
|
||||
print("第一阶段:预计算原始特征(用于滚动基线)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# 如果强制重新计算,删除缓存
|
||||
if args.force:
|
||||
import shutil
|
||||
if os.path.exists(RAW_CACHE_DIR):
|
||||
shutil.rmtree(RAW_CACHE_DIR)
|
||||
os.makedirs(RAW_CACHE_DIR, exist_ok=True)
|
||||
if os.path.exists(OUTPUT_DIR):
|
||||
for f in os.listdir(OUTPUT_DIR):
|
||||
if f.startswith('features_v2_'):
|
||||
os.remove(os.path.join(OUTPUT_DIR, f))
|
||||
|
||||
# 单线程预计算原始特征(因为需要顺序缓存)
|
||||
print(f"预计算 {len(trading_days)} 天的原始特征...")
|
||||
for trade_date in tqdm(trading_days, desc="预计算原始特征"):
|
||||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
||||
if not os.path.exists(cache_file):
|
||||
compute_raw_concept_features(trade_date, concepts, all_stocks)
|
||||
|
||||
# 4. 第二阶段:计算 Z-Score 特征(多进程)
|
||||
print(f"\n{'='*60}")
|
||||
print("第二阶段:计算 Z-Score 特征(滚动基线)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# 从第 baseline_days 天开始(前面的没有足够历史)
|
||||
start_idx = args.baseline_days
|
||||
processable_days = trading_days[start_idx:]
|
||||
|
||||
if not processable_days:
|
||||
print(f"错误:需要至少 {args.baseline_days + 1} 天的数据")
|
||||
return
|
||||
|
||||
print(f"可处理日期: {processable_days[0]} ~ {processable_days[-1]} ({len(processable_days)} 天)")
|
||||
print(f"跳过前 {start_idx} 天(基线预热期)")
|
||||
|
||||
# 构建任务
|
||||
tasks = []
|
||||
for i, trade_date in enumerate(trading_days):
|
||||
if i >= start_idx:
|
||||
tasks.append((trade_date, i, concepts, all_stocks, trading_days))
|
||||
|
||||
print(f"开始处理 {len(tasks)} 个交易日({args.workers} 进程并行)...")
|
||||
|
||||
success_count = 0
|
||||
failed_dates = []
|
||||
|
||||
# 使用进程池初始化器
|
||||
with ProcessPoolExecutor(max_workers=args.workers, initializer=init_process_connections) as executor:
|
||||
futures = {executor.submit(process_single_day_v2, task): task[0] for task in tasks}
|
||||
|
||||
with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
|
||||
for future in as_completed(futures):
|
||||
trade_date = futures[future]
|
||||
try:
|
||||
result_date, success = future.result()
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
failed_dates.append(result_date)
|
||||
except Exception as e:
|
||||
print(f"\n[{trade_date}] 进程异常: {e}")
|
||||
failed_dates.append(trade_date)
|
||||
pbar.update(1)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"处理完成: {success_count}/{len(tasks)} 个交易日")
|
||||
if failed_dates:
|
||||
print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
|
||||
print(f"数据保存在: {OUTPUT_DIR}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user