vf_react/ml/prepare_data.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据准备脚本 - 为 Transformer Autoencoder 准备训练数据
从 ClickHouse 提取历史分钟数据,计算以下特征:
1. alpha - 超额收益(概念涨幅 - 大盘涨幅)
2. alpha_delta - Alpha 变化率5分钟
3. amt_ratio - 成交额相对均值(当前/过去20分钟均值
4. amt_delta - 成交额变化率
5. rank_pct - Alpha 排名百分位
6. limit_up_ratio - 涨停股占比
输出按交易日存储的特征文件parquet格式
"""
import os
import sys
import numpy as np
import pandas as pd
from datetime import datetime, timedelta, date
from sqlalchemy import create_engine, text
from elasticsearch import Elasticsearch
from clickhouse_driver import Client
import hashlib
import json
import logging
from typing import Dict, List, Optional, Set, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import warnings
warnings.filterwarnings('ignore')
# ==================== Configuration ====================
MYSQL_ENGINE = create_engine(
"mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
echo=False
)
ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
ES_INDEX = 'concept_library_v3'
CLICKHOUSE_CONFIG = {
'host': '127.0.0.1',
'port': 9000,
'user': 'default',
'password': 'Zzl33818!',
'database': 'stock'
}
# Output directory
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data')
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Feature computation parameters
FEATURE_CONFIG = {
    'alpha_delta_window': 5,        # alpha change window (minutes)
    'amt_ma_window': 20,            # turnover moving-average window (minutes)
    'limit_up_threshold': 9.8,      # limit-up threshold (%)
    'limit_down_threshold': -9.8,   # limit-down threshold (%)
}
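# With the defaults above, compute_daily_features() below effectively computes,
# at minute t (a sketch of the math, matching its history-buffer logic):
#   alpha(t)       = mean stock change % in the concept - index change %
#   alpha_delta(t) = alpha(t) - alpha(t - 5)
#   amt_ratio(t)   = amt(t) / mean(amt(t-20), ..., amt(t-1))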
REFERENCE_INDEX = '000001.SH'
# ==================== Logging ====================
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# ==================== Helpers ====================
def get_ch_client():
return Client(**CLICKHOUSE_CONFIG)
def generate_id(name: str) -> str:
return hashlib.md5(name.encode('utf-8')).hexdigest()[:16]
def code_to_ch_format(code: str) -> Optional[str]:
    """Map a 6-digit stock code to ClickHouse format with an exchange suffix."""
    if not code or len(code) != 6 or not code.isdigit():
        return None
if code.startswith('6'):
return f"{code}.SH"
elif code.startswith('0') or code.startswith('3'):
return f"{code}.SZ"
else:
return f"{code}.BJ"
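# Illustrative mappings produced by code_to_ch_format() (codes are examples):
#   '600519' -> '600519.SH'   (codes starting with 6 -> Shanghai)
#   '000001' -> '000001.SZ'   (codes starting with 0 or 3 -> Shenzhen)
#   '830799' -> '830799.BJ'   (anything else -> Beijing)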
# ==================== Fetch concept list ====================
def get_all_concepts() -> List[dict]:
    """Fetch all leaf concepts (with their stock codes) from Elasticsearch."""
concepts = []
query = {
"query": {"match_all": {}},
"size": 100,
"_source": ["concept_id", "concept", "stocks"]
}
resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
scroll_id = resp['_scroll_id']
hits = resp['hits']['hits']
while len(hits) > 0:
for hit in hits:
source = hit['_source']
stocks = []
if 'stocks' in source and isinstance(source['stocks'], list):
for stock in source['stocks']:
if isinstance(stock, dict) and 'code' in stock and stock['code']:
stocks.append(stock['code'])
if stocks:
concepts.append({
'concept_id': source.get('concept_id'),
'concept_name': source.get('concept'),
'stocks': stocks
})
resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
scroll_id = resp['_scroll_id']
hits = resp['hits']['hits']
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
logger.info(f"获取到 {len(concepts)} 个概念")
return concepts
# ==================== Fetch trading days ====================
def get_trading_days(start_date: str, end_date: str) -> List[str]:
    """Return the list of trading days between start_date and end_date (inclusive)."""
client = get_ch_client()
query = f"""
SELECT DISTINCT toDate(timestamp) as trade_date
FROM stock_minute
WHERE toDate(timestamp) >= '{start_date}'
AND toDate(timestamp) <= '{end_date}'
ORDER BY trade_date
"""
result = client.execute(query)
days = [row[0].strftime('%Y-%m-%d') for row in result]
logger.info(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
return days
# ==================== Fetch single-day data ====================
def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
    """Fetch minute bars for all given stocks on a single trading day."""
    client = get_ch_client()
    # Convert codes to ClickHouse format
ch_codes = []
code_map = {}
for code in stock_codes:
ch_code = code_to_ch_format(code)
if ch_code:
ch_codes.append(ch_code)
code_map[ch_code] = code
if not ch_codes:
return pd.DataFrame()
ch_codes_str = "','".join(ch_codes)
query = f"""
SELECT
code,
timestamp,
close,
volume,
amt
FROM stock_minute
WHERE toDate(timestamp) = '{trade_date}'
AND code IN ('{ch_codes_str}')
ORDER BY code, timestamp
"""
result = client.execute(query)
if not result:
return pd.DataFrame()
df = pd.DataFrame(result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt'])
df['code'] = df['ch_code'].map(code_map)
df = df.dropna(subset=['code'])
return df[['code', 'timestamp', 'close', 'volume', 'amt']]
def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
"""获取单日指数分钟数据"""
client = get_ch_client()
query = f"""
SELECT
timestamp,
close,
volume,
amt
FROM index_minute
WHERE toDate(timestamp) = '{trade_date}'
AND code = '{index_code}'
ORDER BY timestamp
"""
result = client.execute(query)
if not result:
return pd.DataFrame()
df = pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt'])
return df
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
"""获取昨收价"""
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
if not valid_codes:
return {}
codes_str = "','".join(valid_codes)
query = f"""
SELECT SECCODE, F002N
FROM ea_trade
WHERE SECCODE IN ('{codes_str}')
AND TRADEDATE = (
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
)
AND F002N IS NOT NULL AND F002N > 0
"""
try:
with MYSQL_ENGINE.connect() as conn:
result = conn.execute(text(query))
return {row[0]: float(row[1]) for row in result if row[1]}
except Exception as e:
logger.error(f"获取昨收价失败: {e}")
return {}
def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> float:
"""获取指数昨收价"""
code_no_suffix = index_code.split('.')[0]
try:
with MYSQL_ENGINE.connect() as conn:
result = conn.execute(text("""
SELECT F006N FROM ea_exchangetrade
WHERE INDEXCODE = :code AND TRADEDATE < :today
ORDER BY TRADEDATE DESC LIMIT 1
"""), {'code': code_no_suffix, 'today': trade_date}).fetchone()
if result and result[0]:
return float(result[0])
except Exception as e:
logger.error(f"获取指数昨收失败: {e}")
return None
# ==================== Feature computation ====================
def compute_daily_features(
    trade_date: str,
    concepts: List[dict],
    all_stocks: List[str]
) -> pd.DataFrame:
    """
    Compute features for all concepts on a single trading day.

    Returns a DataFrame with one row per (timestamp, concept_id) and columns:
    alpha, alpha_delta, amt_ratio, amt_delta, rank_pct, limit_up_ratio
    (plus limit_down_ratio, total_amt and stock_count).
    """
    # 1. Fetch raw data
    logger.info("  Fetching stock data...")
stock_df = get_daily_stock_data(trade_date, all_stocks)
    if stock_df.empty:
        logger.warning("  No stock data")
        return pd.DataFrame()
    logger.info("  Fetching index data...")
    index_df = get_daily_index_data(trade_date)
    if index_df.empty:
        logger.warning("  No index data")
        return pd.DataFrame()
    # 2. Fetch previous-day closes
    logger.info("  Fetching previous close prices...")
    prev_close = get_prev_close(all_stocks, trade_date)
    index_prev_close = get_index_prev_close(trade_date)
    if not prev_close or not index_prev_close:
        logger.warning("  No previous close data")
        return pd.DataFrame()
    # 3. Compute per-stock change percentage and turnover
stock_df['prev_close'] = stock_df['code'].map(prev_close)
stock_df = stock_df.dropna(subset=['prev_close'])
stock_df['change_pct'] = (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
    # 4. Compute index change percentage
index_df['change_pct'] = (index_df['close'] - index_prev_close) / index_prev_close * 100
index_change_map = dict(zip(index_df['timestamp'], index_df['change_pct']))
    # 5. Collect all timestamps
    timestamps = sorted(stock_df['timestamp'].unique())
    logger.info(f"  Number of timestamps: {len(timestamps)}")
    # 6. Compute concept features at each timestamp
results = []
    # Mapping from concept to its set of stock codes
concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
concept_names = {c['concept_id']: c['concept_name'] for c in concepts}
    # Per-concept history cache (used for change-rate features)
concept_history = {cid: {'alpha': [], 'amt': []} for cid in concept_stocks}
for ts in timestamps:
ts_stock_data = stock_df[stock_df['timestamp'] == ts]
index_change = index_change_map.get(ts, 0)
        # Per-stock change percentage and turnover at this timestamp
stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))
concept_features = []
for concept_id, stocks in concept_stocks.items():
            # Data for the stocks belonging to this concept
concept_changes = [stock_change[s] for s in stocks if s in stock_change]
concept_amts = [stock_amt.get(s, 0) for s in stocks if s in stock_change]
if not concept_changes:
continue
            # Basic statistics
avg_change = np.mean(concept_changes)
total_amt = sum(concept_amts)
            # alpha = concept change % - index change %
alpha = avg_change - index_change
            # Proportion of limit-up / limit-down stocks
limit_up_count = sum(1 for c in concept_changes if c >= FEATURE_CONFIG['limit_up_threshold'])
limit_down_count = sum(1 for c in concept_changes if c <= FEATURE_CONFIG['limit_down_threshold'])
limit_up_ratio = limit_up_count / len(concept_changes)
limit_down_ratio = limit_down_count / len(concept_changes)
            # Update history buffers
history = concept_history[concept_id]
history['alpha'].append(alpha)
history['amt'].append(total_amt)
            # Alpha change over the configured window
alpha_delta = 0
if len(history['alpha']) > FEATURE_CONFIG['alpha_delta_window']:
alpha_delta = alpha - history['alpha'][-FEATURE_CONFIG['alpha_delta_window']-1]
            # Turnover relative to its moving average
amt_ratio = 1.0
amt_delta = 0
if len(history['amt']) > FEATURE_CONFIG['amt_ma_window']:
amt_ma = np.mean(history['amt'][-FEATURE_CONFIG['amt_ma_window']-1:-1])
if amt_ma > 0:
amt_ratio = total_amt / amt_ma
amt_delta = total_amt - history['amt'][-2] if len(history['amt']) > 1 else 0
concept_features.append({
'concept_id': concept_id,
'alpha': alpha,
'alpha_delta': alpha_delta,
'amt_ratio': amt_ratio,
'amt_delta': amt_delta,
'limit_up_ratio': limit_up_ratio,
'limit_down_ratio': limit_down_ratio,
'total_amt': total_amt,
'stock_count': len(concept_changes),
})
if not concept_features:
continue
        # Cross-sectional percentile rank of alpha at this timestamp
concept_df = pd.DataFrame(concept_features)
concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)
        # Attach the timestamp
concept_df['timestamp'] = ts
results.append(concept_df)
if not results:
return pd.DataFrame()
    # Concatenate all timestamps
final_df = pd.concat(results, ignore_index=True)
    # Standardise amt_delta by its standard deviation over the day
if 'amt_delta' in final_df.columns:
amt_delta_std = final_df['amt_delta'].std()
if amt_delta_std > 0:
final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std
logger.info(f" 计算完成: {len(final_df)} 条记录")
return final_df
# ==================== Main pipeline ====================
def process_single_day(trade_date: str, concepts: List[dict], all_stocks: List[str]) -> str:
    """Process a single trading day and write its feature file; returns the output path."""
    output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
    # Skip days that have already been processed
    if os.path.exists(output_file):
        logger.info(f"[{trade_date}] already exists, skipping")
        return output_file
    logger.info(f"[{trade_date}] processing...")
try:
df = compute_daily_features(trade_date, concepts, all_stocks)
if df.empty:
logger.warning(f"[{trade_date}] 无数据")
return None
        # Save to Parquet
        df.to_parquet(output_file, index=False)
        logger.info(f"[{trade_date}] saved: {output_file}")
return output_file
except Exception as e:
logger.error(f"[{trade_date}] 处理失败: {e}")
import traceback
traceback.print_exc()
return None
def main():
    import argparse
    parser = argparse.ArgumentParser(description='Prepare training data')
    parser.add_argument('--start', type=str, default='2022-01-01', help='start date')
    parser.add_argument('--end', type=str, default=None, help='end date (defaults to today)')
    parser.add_argument('--workers', type=int, default=1,
                        help='number of workers (1 recommended to limit database load; '
                             'days are currently processed sequentially)')
    args = parser.parse_args()
    end_date = args.end or datetime.now().strftime('%Y-%m-%d')
    logger.info("=" * 60)
    logger.info("Data preparation - Transformer Autoencoder training data")
    logger.info("=" * 60)
    logger.info(f"Date range: {args.start} ~ {end_date}")
    # 1. Fetch the concept list
    concepts = get_all_concepts()
    # Collect all stock codes
    all_stocks = list(set(s for c in concepts for s in c['stocks']))
    logger.info(f"Total stocks: {len(all_stocks)}")
    # 2. Fetch the trading days
    trading_days = get_trading_days(args.start, end_date)
    if not trading_days:
        logger.error("No trading days found")
        return
    # 3. Process each trading day
    logger.info(f"\nProcessing {len(trading_days)} trading days...")
    success_count = 0
    for i, trade_date in enumerate(trading_days):
        logger.info(f"\n[{i+1}/{len(trading_days)}] {trade_date}")
        result = process_single_day(trade_date, concepts, all_stocks)
        if result:
            success_count += 1
    logger.info("\n" + "=" * 60)
    logger.info(f"Done: {success_count}/{len(trading_days)} trading days processed")
    logger.info(f"Data saved to: {OUTPUT_DIR}")
    logger.info("=" * 60)
if __name__ == "__main__":
main()
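# Illustrative example of inspecting one day's output after the script has run
# (the date below is hypothetical; file names follow the pattern written by
# process_single_day() above):
#
#   import pandas as pd
#   df = pd.read_parquet('data/features_2024-12-31.parquet')
#   print(df[['timestamp', 'concept_id', 'alpha', 'rank_pct', 'amt_ratio']].head())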