538 lines
16 KiB
Python
538 lines
16 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
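"""
Data preparation script - builds training data for the Transformer Autoencoder.

Extracts historical minute-level data from ClickHouse and computes the
following features:
1. alpha - excess return (concept change minus market index change)
2. alpha_delta - change in alpha (over 5 minutes)
3. amt_ratio - turnover relative to its mean (current / mean of the previous 20 minutes)
4. amt_delta - change in turnover
5. rank_pct - percentile rank of alpha
6. limit_up_ratio - share of constituent stocks at limit-up

Output: one feature file per trading day (parquet format).
"""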
|
||
|
||
import os
|
||
import sys
|
||
import numpy as np
|
||
import pandas as pd
|
||
from datetime import datetime, timedelta, date
|
||
from sqlalchemy import create_engine, text
|
||
from elasticsearch import Elasticsearch
|
||
from clickhouse_driver import Client
|
||
import hashlib
|
||
import json
|
||
import logging
|
||
from typing import Dict, List, Set, Tuple
|
||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||
from multiprocessing import Manager
|
||
import multiprocessing
|
||
import warnings
|
||
warnings.filterwarnings('ignore')
|
||
|
||
# ==================== Configuration ====================

# NOTE(review): the MySQL and ClickHouse credentials below are hard-coded in
# source; consider loading them from environment variables or a secrets store.
MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)

ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
ES_INDEX = 'concept_library_v3'

CLICKHOUSE_CONFIG = {
    'host': '127.0.0.1',
    'port': 9000,
    'user': 'default',
    'password': 'Zzl33818!',
    'database': 'stock'
}

# Output directory (created at import time, next to this script)
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data')
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Feature computation parameters
FEATURE_CONFIG = {
    'alpha_delta_window': 5,       # alpha change window (minutes)
    'amt_ma_window': 20,           # turnover moving-average window (minutes)
    'limit_up_threshold': 9.8,     # limit-up threshold (%)
    'limit_down_threshold': -9.8,  # limit-down threshold (%)
}

# Index whose return is subtracted from the concept return to obtain alpha
REFERENCE_INDEX = '000001.SH'

# ==================== Logging ====================

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 工具函数 ====================
|
||
|
||
def get_ch_client():
    """Build a new ClickHouse client from ``CLICKHOUSE_CONFIG``."""
    client = Client(**CLICKHOUSE_CONFIG)
    return client
|
||
|
||
|
||
def generate_id(name: str) -> str:
    """Return the first 16 hex characters of the MD5 digest of ``name``.

    ``name`` is encoded as UTF-8 before hashing.
    """
    digest = hashlib.md5(name.encode('utf-8'))
    hex_digest = digest.hexdigest()
    return hex_digest[:16]
|
||
|
||
|
||
def code_to_ch_format(code: str) -> str:
    """Append an exchange suffix to a 6-digit stock code.

    Codes starting with ``6`` get ``.SH``, codes starting with ``0`` or ``3``
    get ``.SZ``, and every other 6-digit code gets ``.BJ``.

    Returns:
        The suffixed code, or ``None`` when ``code`` is empty, is not exactly
        six characters long, or is not all digits.
    """
    is_valid = bool(code) and len(code) == 6 and code.isdigit()
    if not is_valid:
        return None

    leading = code[0]
    if leading == '6':
        suffix = 'SH'
    elif leading in ('0', '3'):
        suffix = 'SZ'
    else:
        suffix = 'BJ'
    return f"{code}.{suffix}"
|
||
|
||
|
||
# ==================== 获取概念列表 ====================
|
||
|
||
def get_all_concepts() -> List[dict]:
    """Fetch every concept that lists at least one stock code from Elasticsearch.

    Pages through all documents of ``ES_INDEX`` with the scroll API (100
    documents per page). Documents whose ``stocks`` field is missing, is not a
    list, or contains no entry with a truthy ``code`` are skipped.

    Returns:
        A list of dicts with keys ``concept_id``, ``concept_name`` and
        ``stocks`` (the list of stock codes as stored in the index).
    """
    query = {
        "query": {"match_all": {}},
        "size": 100,
        "_source": ["concept_id", "concept", "stocks"]
    }

    concepts = []

    resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
    scroll_id = resp['_scroll_id']

    # Release the server-side scroll context even if paging fails midway;
    # otherwise it lingers on the cluster until the 2-minute keep-alive expires.
    try:
        hits = resp['hits']['hits']
        while hits:
            for hit in hits:
                source = hit['_source']
                raw_stocks = source.get('stocks')
                if not isinstance(raw_stocks, list):
                    continue

                stocks = [
                    stock['code']
                    for stock in raw_stocks
                    if isinstance(stock, dict) and stock.get('code')
                ]
                if stocks:
                    concepts.append({
                        'concept_id': source.get('concept_id'),
                        'concept_name': source.get('concept'),
                        'stocks': stocks
                    })

            resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
            scroll_id = resp['_scroll_id']
            hits = resp['hits']['hits']
    finally:
        ES_CLIENT.clear_scroll(scroll_id=scroll_id)

    print(f"获取到 {len(concepts)} 个概念")
    return concepts
|
||
|
||
|
||
# ==================== 获取交易日列表 ====================
|
||
|
||
def get_trading_days(start_date: str, end_date: str) -> List[str]:
    """List the dates in ``[start_date, end_date]`` that have minute data.

    Args:
        start_date: Inclusive lower bound, formatted ``YYYY-MM-DD``.
        end_date: Inclusive upper bound, formatted ``YYYY-MM-DD``.

    Returns:
        Ascending list of ``YYYY-MM-DD`` strings for every distinct date found
        in ``stock_minute``; an empty list when the range has no data.
    """
    client = get_ch_client()

    # Dates come from the command line, so pass them as driver parameters
    # instead of formatting them into the SQL text.
    query = """
        SELECT DISTINCT toDate(timestamp) AS trade_date
        FROM stock_minute
        WHERE toDate(timestamp) >= %(start_date)s
          AND toDate(timestamp) <= %(end_date)s
        ORDER BY trade_date
    """

    result = client.execute(query, {'start_date': start_date, 'end_date': end_date})
    days = [row[0].strftime('%Y-%m-%d') for row in result]

    # Guard the range summary: indexing days[0] on an empty result would raise
    # IndexError before the caller gets a chance to handle "no trading days".
    if days:
        print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
    else:
        print("找到 0 个交易日")
    return days
|
||
|
||
|
||
# ==================== 获取单日数据 ====================
|
||
|
||
def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
    """Load one day of minute bars for the given stocks from ClickHouse.

    Codes that ``code_to_ch_format`` rejects are dropped before querying.

    Returns:
        A DataFrame with columns ``code`` (the original, un-suffixed code),
        ``timestamp``, ``close``, ``volume`` and ``amt``; an empty DataFrame
        when no code is valid or the query returns no rows.
    """
    client = get_ch_client()

    # Pair each suffixed ClickHouse code with the raw code it came from
    converted = ((code_to_ch_format(raw), raw) for raw in stock_codes)
    pairs = [(ch_code, raw) for ch_code, raw in converted if ch_code]
    if not pairs:
        return pd.DataFrame()

    code_map = dict(pairs)
    ch_codes_str = "','".join(ch_code for ch_code, _ in pairs)

    query = f"""
    SELECT
        code,
        timestamp,
        close,
        volume,
        amt
    FROM stock_minute
    WHERE toDate(timestamp) = '{trade_date}'
    AND code IN ('{ch_codes_str}')
    ORDER BY code, timestamp
    """

    rows = client.execute(query)
    if not rows:
        return pd.DataFrame()

    frame = pd.DataFrame(rows, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
    # Map back to the raw code and discard rows that cannot be mapped
    frame['code'] = frame['ch_code'].map(code_map)
    frame = frame.dropna(subset=['code'])

    wanted_columns = ['code', 'timestamp', 'close', 'volume', 'amt']
    return frame[wanted_columns]
|
||
|
||
|
||
def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
    """Load one day of minute bars for an index from ClickHouse.

    Returns:
        A DataFrame with columns ``timestamp``, ``close``, ``volume`` and
        ``amt`` ordered by timestamp, or an empty DataFrame when the query
        returns no rows.
    """
    query = f"""
    SELECT
        timestamp,
        close,
        volume,
        amt
    FROM index_minute
    WHERE toDate(timestamp) = '{trade_date}'
    AND code = '{index_code}'
    ORDER BY timestamp
    """

    rows = get_ch_client().execute(query)
    if rows:
        return pd.DataFrame(rows, columns=['timestamp', 'close', 'volume', 'amt'])
    return pd.DataFrame()
|
||
|
||
|
||
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
    """Fetch each stock's close (column F007N) from the previous trading day.

    The previous trading day is the latest ``TRADEDATE`` in ``ea_trade`` that
    is strictly earlier than ``trade_date``.

    Args:
        stock_codes: Raw stock codes; anything that is not a 6-digit string is
            ignored.
        trade_date: The day whose previous close is wanted.

    Returns:
        Mapping of stock code to previous close. Empty when no code is valid
        or when the query fails (the error is printed, not raised).
    """
    valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
    if not valid_codes:
        return {}

    # Safe to inline: every entry was just validated to be exactly six digits.
    codes_str = "','".join(valid_codes)

    # Per the original author's note: F007N is the "latest traded price", i.e.
    # that day's close, while F002N is "yesterday's close". The previous
    # trading day's F007N is therefore used as today's previous close.
    # trade_date is bound as a parameter, matching get_index_prev_close.
    query = f"""
        SELECT SECCODE, F007N
        FROM ea_trade
        WHERE SECCODE IN ('{codes_str}')
          AND TRADEDATE = (
              SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < :trade_date
          )
          AND F007N IS NOT NULL AND F007N > 0
    """

    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text(query), {'trade_date': trade_date})
            return {row[0]: float(row[1]) for row in result if row[1]}
    except Exception as e:
        # Best-effort: the caller treats an empty mapping as "no data".
        print(f"获取昨收价失败: {e}")
        return {}
|
||
|
||
|
||
def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float:
    """Fetch the index value (column F006N) from the latest row before ``trade_date``.

    The exchange suffix of ``index_code`` (everything from the first ``.``) is
    stripped before querying ``ea_exchangetrade``.

    Returns:
        The value as a float, or ``None`` when no row is found, the value is
        falsy, or the query fails (the error is printed, not raised).
    """
    bare_code = index_code.partition('.')[0]
    params = {'code': bare_code, 'today': trade_date}

    try:
        with MYSQL_ENGINE.connect() as conn:
            statement = text("""
                SELECT F006N FROM ea_exchangetrade
                WHERE INDEXCODE = :code AND TRADEDATE < :today
                ORDER BY TRADEDATE DESC LIMIT 1
            """)
            row = conn.execute(statement, params).fetchone()

            if row and row[0]:
                return float(row[0])
    except Exception as e:
        print(f"获取指数昨收失败: {e}")

    return None
|
||
|
||
|
||
# ==================== 计算特征 ====================
|
||
|
||
def compute_daily_features(
    trade_date: str,
    concepts: List[dict],
    all_stocks: List[str]
) -> pd.DataFrame:
    """Compute per-minute features for every concept on one trading day.

    Args:
        trade_date: Trading day to process.
        concepts: Dicts with keys ``concept_id``, ``concept_name`` and
            ``stocks`` (list of stock codes).
        all_stocks: Every stock code that should be loaded for the day.

    Returns:
        A DataFrame with one row per (timestamp, concept) and the columns
        ``concept_id``, ``alpha``, ``alpha_delta``, ``amt_ratio``,
        ``amt_delta``, ``limit_up_ratio``, ``limit_down_ratio``,
        ``total_amt``, ``stock_count``, ``rank_pct`` and ``timestamp``.
        An empty DataFrame is returned when stock data, index data or either
        previous close is unavailable.
    """
    # 1. Load minute bars
    stock_df = get_daily_stock_data(trade_date, all_stocks)
    if stock_df.empty:
        return pd.DataFrame()

    index_df = get_daily_index_data(trade_date)
    if index_df.empty:
        return pd.DataFrame()

    # 2. Previous closes
    prev_close = get_prev_close(all_stocks, trade_date)
    index_prev_close = get_index_prev_close(trade_date)
    if not prev_close or not index_prev_close:
        return pd.DataFrame()

    # 3. Stock change (%) versus previous close; stocks without one are dropped
    stock_df['prev_close'] = stock_df['code'].map(prev_close)
    stock_df = stock_df.dropna(subset=['prev_close']).copy()
    stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100

    # 4. Index change (%) keyed by timestamp
    index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
    index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))

    alpha_window = FEATURE_CONFIG['alpha_delta_window']
    amt_window = FEATURE_CONFIG['amt_ma_window']
    limit_up = FEATURE_CONFIG['limit_up_threshold']
    limit_down = FEATURE_CONFIG['limit_down_threshold']

    # Concept -> constituent stock codes
    concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}

    # Per-concept history of alpha / turnover, used for the delta features.
    # A concept's history only grows on minutes where it has stock data.
    concept_history = {cid: {'alpha': [], 'amt': []} for cid in concept_stocks}

    results = []

    # 5./6. Walk the timestamps in ascending order. groupby yields keys of the
    # same scalar type as iterating the index column, so the dict lookup below
    # matches; numpy.datetime64 scalars from Series.unique() can hash
    # differently from pandas.Timestamp keys on older NumPy, which would make
    # every lookup silently fall back to 0. It also avoids rescanning the whole
    # frame with a boolean mask for every minute.
    for ts, ts_stock_data in stock_df.groupby('timestamp', sort=True):
        # Minutes with no index bar fall back to an index change of 0
        index_change = index_change_map.get(ts, 0)

        stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
        stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))

        concept_features = []

        for concept_id, stocks in concept_stocks.items():
            # Only constituents that traded this minute contribute
            concept_changes = [stock_change[s] for s in stocks if s in stock_change]
            concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]

            if not concept_changes:
                continue

            avg_change = np.mean(concept_changes)
            total_amt = sum(concept_amts)

            # Alpha = concept change - index change
            alpha = avg_change - index_change

            # Share of constituents at limit-up / limit-down
            limit_up_count = sum(1 for c in concept_changes if c >= limit_up)
            limit_down_count = sum(1 for c in concept_changes if c <= limit_down)
            limit_up_ratio = limit_up_count / len(concept_changes)
            limit_down_ratio = limit_down_count / len(concept_changes)

            history = concept_history[concept_id]
            history['alpha'].append(alpha)
            history['amt'].append(total_amt)

            # Alpha change versus `alpha_window` observations ago
            alpha_delta = 0
            if len(history['alpha']) > alpha_window:
                alpha_delta = alpha - history['alpha'][-alpha_window - 1]

            # Turnover relative to the mean of the previous `amt_window`
            # observations (the current one is excluded from the mean)
            amt_ratio = 1.0
            amt_delta = 0
            if len(history['amt']) > amt_window:
                amt_ma = np.mean(history['amt'][-amt_window - 1:-1])
                if amt_ma > 0:
                    amt_ratio = total_amt / amt_ma
                # NOTE(review): the source this was recovered from had lost its
                # indentation; amt_delta is assumed to sit inside the window
                # check (so it stays 0 until the window fills) - confirm.
                amt_delta = total_amt - history['amt'][-2] if len(history['amt']) > 1 else 0

            concept_features.append({
                'concept_id': concept_id,
                'alpha': alpha,
                'alpha_delta': alpha_delta,
                'amt_ratio': amt_ratio,
                'amt_delta': amt_delta,
                'limit_up_ratio': limit_up_ratio,
                'limit_down_ratio': limit_down_ratio,
                'total_amt': total_amt,
                'stock_count': len(concept_changes),
            })

        if not concept_features:
            continue

        # Percentile rank of alpha across the concepts present this minute
        concept_df = pd.DataFrame(concept_features)
        concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)
        concept_df['timestamp'] = ts
        results.append(concept_df)

    if not results:
        return pd.DataFrame()

    final_df = pd.concat(results, ignore_index=True)

    # Scale amt_delta by its standard deviation over the whole day
    if 'amt_delta' in final_df.columns:
        amt_delta_std = final_df['amt_delta'].std()
        if amt_delta_std > 0:
            final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std

    return final_df
|
||
|
||
|
||
# ==================== 主流程 ====================
|
||
|
||
def process_single_day(args) -> Tuple[str, bool]:
    """Compute and store the feature file for one trading day (worker entry point).

    Args:
        args: ``(trade_date, concepts, all_stocks)`` tuple.

    Returns:
        ``(trade_date, success)``. ``success`` is True when the feature file
        already existed or was written, False when the day produced no data or
        processing raised.
    """
    import traceback

    trade_date, concepts, all_stocks = args
    output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')

    # Skip days that were already processed
    if os.path.exists(output_file):
        print(f"[{trade_date}] 已存在,跳过")
        return (trade_date, True)

    print(f"[{trade_date}] 开始处理...")

    # Write to a temporary path and rename on success: an interrupted write
    # must not leave a partial file that the existence check above would later
    # mistake for a finished day.
    tmp_file = f'{output_file}.tmp'

    try:
        df = compute_daily_features(trade_date, concepts, all_stocks)

        if df.empty:
            print(f"[{trade_date}] 无数据")
            return (trade_date, False)

        df.to_parquet(tmp_file, index=False)
        os.replace(tmp_file, output_file)
        print(f"[{trade_date}] 保存完成")
        return (trade_date, True)

    except Exception as e:
        print(f"[{trade_date}] 处理失败: {e}")
        traceback.print_exc()
        # Best-effort cleanup of a partially written temp file
        try:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
        except OSError:
            pass
        return (trade_date, False)
|
||
|
||
|
||
def main():
    """Command-line entry point: build feature files for a date range in parallel.

    Parses ``--start``, ``--end``, ``--workers`` and ``--force``, loads the
    concept list and trading days, optionally deletes existing feature files,
    then processes each trading day in a process pool and prints a summary.
    """
    import argparse
    from tqdm import tqdm

    parser = argparse.ArgumentParser(description='准备训练数据')
    parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
    parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
    parser.add_argument('--workers', type=int, default=18, help='并行进程数(默认18)')
    parser.add_argument('--force', action='store_true', help='强制重新处理已存在的文件')
    cli = parser.parse_args()

    end_date = cli.end if cli.end else datetime.now().strftime('%Y-%m-%d')
    banner = "=" * 60

    print(banner)
    print("数据准备 - Transformer Autoencoder 训练数据")
    print(banner)
    print(f"日期范围: {cli.start} ~ {end_date}")
    print(f"并行进程数: {cli.workers}")

    # 1. Concepts and the union of their constituent stocks
    concepts = get_all_concepts()
    unique_stocks = set()
    for concept in concepts:
        unique_stocks.update(concept['stocks'])
    all_stocks = list(unique_stocks)
    print(f"股票总数: {len(all_stocks)}")

    # 2. Trading days in range
    trading_days = get_trading_days(cli.start, end_date)
    if not trading_days:
        print("无交易日数据")
        return

    # In force mode, remove existing outputs so every day is recomputed
    if cli.force:
        for day in trading_days:
            existing = os.path.join(OUTPUT_DIR, f'features_{day}.parquet')
            if os.path.exists(existing):
                os.remove(existing)
                print(f"删除已有文件: {existing}")

    print(f"\n开始处理 {len(trading_days)} 个交易日({cli.workers} 进程并行)...")

    # 3./4. Fan the days out over a process pool and tally the outcomes
    success_count = 0
    failed_dates = []

    with ProcessPoolExecutor(max_workers=cli.workers) as pool:
        futures = {}
        for day in trading_days:
            futures[pool.submit(process_single_day, (day, concepts, all_stocks))] = day

        with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
            for finished in as_completed(futures):
                day = futures[finished]
                try:
                    result_date, ok = finished.result()
                except Exception as e:
                    print(f"\n[{day}] 进程异常: {e}")
                    failed_dates.append(day)
                else:
                    if ok:
                        success_count += 1
                    else:
                        failed_dates.append(result_date)
                pbar.update(1)

    print("\n" + banner)
    print(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
    if failed_dates:
        print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
    print(f"数据保存在: {OUTPUT_DIR}")
    print(banner)
|
||
|
||
|
||
# Run the CLI only when executed as a script, not when imported
# (worker processes may re-import this module).
if __name__ == "__main__":
    main()
|