#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Data preparation script: build training data for the Transformer Autoencoder.

Pulls historical minute bars from ClickHouse and computes, per concept and
per minute, the following features:

1. alpha          - excess return (concept change % minus index change %)
2. alpha_delta    - change in alpha over a 5-minute window
3. amt_ratio      - turnover relative to its mean
                    (current / mean of the previous 20 minutes)
4. amt_delta      - change in turnover
5. rank_pct       - percentile rank of alpha
6. limit_up_ratio - share of constituent stocks at limit-up

Output: one feature file per trading day (parquet format).
"""

import hashlib
import json
import logging
import multiprocessing
import os
import sys
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime, timedelta, date
from multiprocessing import Manager
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
from clickhouse_driver import Client
from elasticsearch import Elasticsearch
from sqlalchemy import create_engine, text

warnings.filterwarnings('ignore')

# ==================== Configuration ====================

# NOTE(review): database credentials are hard-coded in this file. They are
# kept unchanged here to preserve behaviour; consider moving them to
# environment variables or a secrets store.
MYSQL_ENGINE = create_engine(
    "mysql+pymysql://root:Zzl5588161!@192.168.1.5:3306/stock",
    echo=False
)

ES_CLIENT = Elasticsearch(['http://127.0.0.1:9200'])
ES_INDEX = 'concept_library_v3'

CLICKHOUSE_CONFIG = {
    'host': '127.0.0.1',
    'port': 9000,
    'user': 'default',
    'password': 'Zzl33818!',
    'database': 'stock'
}

# Output directory (created next to this script at import time)
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'data')
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Feature computation parameters
FEATURE_CONFIG = {
    'alpha_delta_window': 5,       # alpha change window (minutes)
    'amt_ma_window': 20,           # turnover moving-average window (minutes)
    'limit_up_threshold': 9.8,     # limit-up threshold (%)
    'limit_down_threshold': -9.8,  # limit-down threshold (%)
}

REFERENCE_INDEX = '000001.SH'

# ==================== Logging ====================

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# ==================== Utilities ====================

def get_ch_client():
    """Return a new ClickHouse client built from CLICKHOUSE_CONFIG."""
    return Client(**CLICKHOUSE_CONFIG)


def generate_id(name: str) -> str:
    """Return the first 16 hex characters of the MD5 digest of *name*."""
    return hashlib.md5(name.encode('utf-8')).hexdigest()[:16]


def code_to_ch_format(code: str) -> Optional[str]:
    """Convert a bare 6-digit stock code to the suffixed ClickHouse form.

    Codes starting with '6' get '.SH', codes starting with '0' or '3' get
    '.SZ', every other 6-digit code gets '.BJ'.

    Returns:
        The suffixed code, or None when *code* is empty, is not exactly six
        characters long, or is not all digits.
    """
    if not code or len(code) != 6 or not code.isdigit():
        return None
    if code.startswith('6'):
        return f"{code}.SH"
    if code.startswith(('0', '3')):
        return f"{code}.SZ"
    return f"{code}.BJ"


# ==================== Concept list ====================

def get_all_concepts() -> List[dict]:
    """Fetch all leaf concepts from Elasticsearch.

    Scrolls through ES_INDEX 100 documents at a time. Concepts without any
    usable stock code are skipped.

    Returns:
        A list of dicts with keys 'concept_id', 'concept_name' and 'stocks'
        (a non-empty list of stock codes).
    """
    concepts = []

    query = {
        "query": {"match_all": {}},
        "size": 100,
        "_source": ["concept_id", "concept", "stocks"]
    }

    resp = ES_CLIENT.search(index=ES_INDEX, body=query, scroll='2m')
    scroll_id = resp['_scroll_id']

    # try/finally so the server-side scroll context is released even when a
    # page fetch or document parse raises.
    try:
        hits = resp['hits']['hits']
        while hits:
            for hit in hits:
                source = hit['_source']
                raw_stocks = source.get('stocks')
                if not isinstance(raw_stocks, list):
                    continue

                stocks = [
                    stock['code']
                    for stock in raw_stocks
                    if isinstance(stock, dict) and stock.get('code')
                ]
                if stocks:
                    concepts.append({
                        'concept_id': source.get('concept_id'),
                        'concept_name': source.get('concept'),
                        'stocks': stocks
                    })

            resp = ES_CLIENT.scroll(scroll_id=scroll_id, scroll='2m')
            scroll_id = resp['_scroll_id']
            hits = resp['hits']['hits']
    finally:
        ES_CLIENT.clear_scroll(scroll_id=scroll_id)

    print(f"获取到 {len(concepts)} 个概念")
    return concepts


# ==================== Trading days ====================

def get_trading_days(start_date: str, end_date: str) -> List[str]:
    """Return the distinct dates present in stock_minute within the range.

    Args:
        start_date: inclusive lower bound, 'YYYY-MM-DD'.
        end_date: inclusive upper bound, 'YYYY-MM-DD'.

    Returns:
        Dates formatted as 'YYYY-MM-DD' in ascending order; an empty list
        when the range contains no data.
    """
    client = get_ch_client()

    # The dates come from the command line, so bind them as parameters
    # instead of interpolating them into the SQL text.
    query = """
        SELECT DISTINCT toDate(timestamp) as trade_date
        FROM stock_minute
        WHERE toDate(timestamp) >= %(start_date)s
          AND toDate(timestamp) <= %(end_date)s
        ORDER BY trade_date
    """

    result = client.execute(
        query, {'start_date': start_date, 'end_date': end_date}
    )
    days = [row[0].strftime('%Y-%m-%d') for row in result]

    # Indexing days[0] on an empty result used to raise IndexError before
    # the caller could report "no trading days".
    if days:
        print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
    else:
        print(f"找到 {len(days)} 个交易日")
    return days


# ==================== Single-day data ====================

def get_daily_stock_data(trade_date: str, stock_codes: List[str]) -> pd.DataFrame:
    """Fetch one day of minute bars for the given stocks.

    Args:
        trade_date: the day to load, 'YYYY-MM-DD'.
        stock_codes: bare 6-digit stock codes; invalid codes are ignored.

    Returns:
        DataFrame with columns code (bare 6-digit form), timestamp, close,
        volume, amt; empty when there are no valid codes or no rows.
    """
    client = get_ch_client()

    # Convert to the suffixed code format and remember the reverse mapping
    ch_codes = []
    code_map = {}
    for code in stock_codes:
        ch_code = code_to_ch_format(code)
        if ch_code:
            ch_codes.append(ch_code)
            code_map[ch_code] = code

    if not ch_codes:
        return pd.DataFrame()

    # Safe to inline: code_to_ch_format only emits six digits plus a fixed
    # suffix, so the list cannot contain quotes or '%'.
    ch_codes_str = "','".join(ch_codes)

    query = f"""
        SELECT
            code,
            timestamp,
            close,
            volume,
            amt
        FROM stock_minute
        WHERE toDate(timestamp) = %(trade_date)s
          AND code IN ('{ch_codes_str}')
        ORDER BY code, timestamp
    """

    result = client.execute(query, {'trade_date': trade_date})
    if not result:
        return pd.DataFrame()

    df = pd.DataFrame(
        result, columns=['ch_code', 'timestamp', 'close', 'volume', 'amt']
    )
    df['code'] = df['ch_code'].map(code_map)
    df = df.dropna(subset=['code'])

    return df[['code', 'timestamp', 'close', 'volume', 'amt']]


def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) -> pd.DataFrame:
    """Fetch one day of minute bars for an index.

    Returns:
        DataFrame with columns timestamp, close, volume, amt ordered by
        timestamp; empty when the query returns no rows.
    """
    client = get_ch_client()

    query = """
        SELECT
            timestamp,
            close,
            volume,
            amt
        FROM index_minute
        WHERE toDate(timestamp) = %(trade_date)s
          AND code = %(index_code)s
        ORDER BY timestamp
    """

    result = client.execute(
        query, {'trade_date': trade_date, 'index_code': index_code}
    )
    if not result:
        return pd.DataFrame()

    return pd.DataFrame(result, columns=['timestamp', 'close', 'volume', 'amt'])


def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
    """Fetch the previous close (F007N of the prior trading day) per stock.

    Args:
        stock_codes: bare stock codes; anything that is not six digits is
            ignored.
        trade_date: the current day, 'YYYY-MM-DD'.

    Returns:
        Mapping of stock code to previous close. Empty when there are no
        valid codes or when the query fails (the error is printed, not
        raised).
    """
    valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
    if not valid_codes:
        return {}

    # Safe to inline: every code was just validated as six digits.
    codes_str = "','".join(valid_codes)

    # Note: F007N is the "latest traded price", i.e. that day's close, while
    # F002N is "yesterday's close". We need F007N of the previous trading
    # day (its close) to use as today's previous close.
    query = f"""
        SELECT SECCODE, F007N
        FROM ea_trade
        WHERE SECCODE IN ('{codes_str}')
        AND TRADEDATE = (
            SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < :trade_date
        )
        AND F007N IS NOT NULL AND F007N > 0
    """

    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text(query), {'trade_date': trade_date})
            return {row[0]: float(row[1]) for row in result if row[1]}
    except Exception as e:
        # Deliberate best effort: the caller treats an empty mapping as
        # "no data for this day".
        print(f"获取昨收价失败: {e}")
        return {}


def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) -> Optional[float]:
    """Fetch the index's previous close.

    Reads F006N from the most recent ea_exchangetrade row dated before
    *trade_date*, looking the index up by its code without the exchange
    suffix.

    Returns:
        The value as a float, or None when no row is found, the value is
        falsy, or the query fails (the error is printed, not raised).
    """
    code_no_suffix = index_code.split('.')[0]

    try:
        with MYSQL_ENGINE.connect() as conn:
            result = conn.execute(text("""
                SELECT F006N FROM ea_exchangetrade
                WHERE INDEXCODE = :code AND TRADEDATE < :today
                ORDER BY TRADEDATE DESC LIMIT 1
            """), {'code': code_no_suffix, 'today': trade_date}).fetchone()

            if result and result[0]:
                return float(result[0])
    except Exception as e:
        print(f"获取指数昨收失败: {e}")

    return None


# ==================== Feature computation ====================

def compute_daily_features(
    trade_date: str,
    concepts: List[dict],
    all_stocks: List[str]
) -> pd.DataFrame:
    """Compute the features of every concept for one trading day.

    Args:
        trade_date: the day to process, 'YYYY-MM-DD'.
        concepts: dicts with 'concept_id', 'concept_name' and 'stocks'.
        all_stocks: every stock code referenced by *concepts*.

    Returns:
        A flat DataFrame with one row per (timestamp, concept_id) and the
        columns concept_id, alpha, alpha_delta, amt_ratio, amt_delta,
        limit_up_ratio, limit_down_ratio, total_amt, stock_count, rank_pct
        and timestamp. Empty when any required input is missing.
    """
    # 1. Load minute data
    stock_df = get_daily_stock_data(trade_date, all_stocks)
    if stock_df.empty:
        return pd.DataFrame()

    index_df = get_daily_index_data(trade_date)
    if index_df.empty:
        return pd.DataFrame()

    # 2. Load previous closes
    prev_close = get_prev_close(all_stocks, trade_date)
    index_prev_close = get_index_prev_close(trade_date)

    if not prev_close or not index_prev_close:
        return pd.DataFrame()

    # 3. Stock change % (stocks without a previous close are dropped)
    stock_df['prev_close'] = stock_df['code'].map(prev_close)
    # .copy() so the column assignment below targets an independent frame
    stock_df = stock_df.dropna(subset=['prev_close']).copy()
    stock_df['change_pct'] = (
        (stock_df['close'] - stock_df['prev_close']) / stock_df['prev_close'] * 100
    )

    # 4. Index change %, keyed by pd.Timestamp. Keys are normalised on both
    # the build side and the lookup side: np.datetime64 and pd.Timestamp
    # values may hash differently depending on the numpy version, in which
    # case a dict lookup silently falls back to the default of 0.
    index_change_pct = (index_df['close'] - index_prev_close) / index_prev_close * 100
    index_change_map = {
        pd.Timestamp(ts): change
        for ts, change in zip(index_df['timestamp'], index_change_pct)
    }

    # 5. Per-concept lookup tables
    concept_stocks = {c['concept_id']: set(c['stocks']) for c in concepts}
    concept_names = {c['concept_id']: c['concept_name'] for c in concepts}

    # Running per-concept history used for the delta / moving-average features
    concept_history = {cid: {'alpha': [], 'amt': []} for cid in concept_stocks}

    alpha_window = FEATURE_CONFIG['alpha_delta_window']
    amt_window = FEATURE_CONFIG['amt_ma_window']
    limit_up_threshold = FEATURE_CONFIG['limit_up_threshold']
    limit_down_threshold = FEATURE_CONFIG['limit_down_threshold']

    # 6. Walk the minutes in ascending order. One groupby pass replaces a
    # full-frame boolean scan per minute.
    results = []
    for ts, ts_stock_data in stock_df.groupby('timestamp', sort=True):
        index_change = index_change_map.get(pd.Timestamp(ts), 0)

        # Change % and turnover of every stock at this minute
        stock_change = dict(zip(ts_stock_data['code'], ts_stock_data['change_pct']))
        stock_amt = dict(zip(ts_stock_data['code'], ts_stock_data['amt']))

        concept_features = []

        for concept_id, stocks in concept_stocks.items():
            # Constituents that actually have data at this minute
            present = [s for s in stocks if s in stock_change]
            if not present:
                continue

            concept_changes = [stock_change[s] for s in present]
            concept_amts = [stock_amt.get(s, 0) for s in present]

            # Basic statistics
            avg_change = np.mean(concept_changes)
            total_amt = sum(concept_amts)

            # Alpha = concept change % - index change %
            alpha = avg_change - index_change

            # Share of constituents at limit-up / limit-down
            member_count = len(concept_changes)
            limit_up_count = sum(
                1 for c in concept_changes if c >= limit_up_threshold
            )
            limit_down_count = sum(
                1 for c in concept_changes if c <= limit_down_threshold
            )
            limit_up_ratio = limit_up_count / member_count
            limit_down_ratio = limit_down_count / member_count

            # Record the current minute before computing the deltas, so the
            # current value sits at index -1 of each history list.
            history = concept_history[concept_id]
            alpha_history = history['alpha']
            amt_history = history['amt']
            alpha_history.append(alpha)
            amt_history.append(total_amt)

            # Alpha change versus `alpha_window` minutes ago
            alpha_delta = 0
            if len(alpha_history) > alpha_window:
                alpha_delta = alpha - alpha_history[-alpha_window - 1]

            # Turnover relative to the mean of the previous `amt_window`
            # minutes (the current minute is excluded from the mean).
            amt_ratio = 1.0
            amt_delta = 0
            if len(amt_history) > amt_window:
                amt_ma = np.mean(amt_history[-amt_window - 1:-1])
                if amt_ma > 0:
                    amt_ratio = total_amt / amt_ma
                # NOTE(review): the source's indentation was lost; this
                # assignment is assumed to belong to the window branch
                # (hence the 0 default above). Confirm against the original.
                amt_delta = total_amt - amt_history[-2] if len(amt_history) > 1 else 0

            concept_features.append({
                'concept_id': concept_id,
                'alpha': alpha,
                'alpha_delta': alpha_delta,
                'amt_ratio': amt_ratio,
                'amt_delta': amt_delta,
                'limit_up_ratio': limit_up_ratio,
                'limit_down_ratio': limit_down_ratio,
                'total_amt': total_amt,
                'stock_count': member_count,
            })

        if not concept_features:
            continue

        # Percentile rank of alpha across the concepts at this minute
        concept_df = pd.DataFrame(concept_features)
        concept_df['rank_pct'] = concept_df['alpha'].rank(pct=True)

        # Attach the timestamp
        concept_df['timestamp'] = ts

        results.append(concept_df)

    if not results:
        return pd.DataFrame()

    # Merge all minutes
    final_df = pd.concat(results, ignore_index=True)

    # Scale amt_delta by its standard deviation over the whole day
    if 'amt_delta' in final_df.columns:
        amt_delta_std = final_df['amt_delta'].std()
        if amt_delta_std > 0:
            final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std

    return final_df


# ==================== Main flow ====================

def process_single_day(args) -> Tuple[str, bool]:
    """Process one trading day (worker entry point for the process pool).

    Args:
        args: (trade_date, concepts, all_stocks) tuple.

    Returns:
        (trade_date, success) tuple. A day whose output file already exists
        is skipped and reported as a success; a day with no data or with an
        error is reported as a failure.
    """
    trade_date, concepts, all_stocks = args
    output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')

    # Skip days that were already processed
    if os.path.exists(output_file):
        print(f"[{trade_date}] 已存在,跳过")
        return (trade_date, True)

    print(f"[{trade_date}] 开始处理...")

    try:
        df = compute_daily_features(trade_date, concepts, all_stocks)

        if df.empty:
            print(f"[{trade_date}] 无数据")
            return (trade_date, False)

        # Save
        df.to_parquet(output_file, index=False)
        print(f"[{trade_date}] 保存完成")
        return (trade_date, True)

    except Exception as e:
        # Worker boundary: report the failure instead of raising so that one
        # bad day is counted as failed rather than surfacing as a process
        # exception in the parent.
        print(f"[{trade_date}] 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return (trade_date, False)


def main():
    """Parse CLI arguments and build the per-day feature files in parallel."""
    import argparse
    from tqdm import tqdm

    parser = argparse.ArgumentParser(description='准备训练数据')
    parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
    parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
    parser.add_argument('--workers', type=int, default=18, help='并行进程数(默认18)')
    parser.add_argument('--force', action='store_true', help='强制重新处理已存在的文件')

    args = parser.parse_args()

    end_date = args.end or datetime.now().strftime('%Y-%m-%d')

    print("=" * 60)
    print("数据准备 - Transformer Autoencoder 训练数据")
    print("=" * 60)
    print(f"日期范围: {args.start} ~ {end_date}")
    print(f"并行进程数: {args.workers}")

    # 1. Load the concept list
    concepts = get_all_concepts()

    # Collect every stock referenced by any concept
    all_stocks = list(set(s for c in concepts for s in c['stocks']))
    print(f"股票总数: {len(all_stocks)}")

    # 2. Load the trading-day list
    trading_days = get_trading_days(args.start, end_date)

    if not trading_days:
        print("无交易日数据")
        return

    # In force mode, delete existing outputs so the workers regenerate them
    if args.force:
        for trade_date in trading_days:
            output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
            if os.path.exists(output_file):
                os.remove(output_file)
                print(f"删除已有文件: {output_file}")

    # 3. Build the task arguments
    tasks = [(trade_date, concepts, all_stocks) for trade_date in trading_days]

    print(f"\n开始处理 {len(trading_days)} 个交易日({args.workers} 进程并行)...")

    # 4. Process the days in a process pool
    success_count = 0
    failed_dates = []

    with ProcessPoolExecutor(max_workers=args.workers) as executor:
        # Submit every task up front
        futures = {
            executor.submit(process_single_day, task): task[0]
            for task in tasks
        }

        # Show progress with tqdm
        with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
            for future in as_completed(futures):
                trade_date = futures[future]
                try:
                    result_date, success = future.result()
                    if success:
                        success_count += 1
                    else:
                        failed_dates.append(result_date)
                except Exception as e:
                    print(f"\n[{trade_date}] 进程异常: {e}")
                    failed_dates.append(trade_date)
                pbar.update(1)

    print("\n" + "=" * 60)
    print(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
    if failed_dates:
        print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
    print(f"数据保存在: {OUTPUT_DIR}")
    print("=" * 60)


if __name__ == "__main__":
    main()