#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Daily post-market job: refresh the rolling baseline.

Usage:
    python ml/update_baseline.py

Recommended crontab entry (trading weekdays, after the 15:00 close):
    30 15 * * 1-5 cd /path/to/project && python ml/update_baseline.py
"""

import os
import sys
import pickle
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
from tqdm import tqdm

# Make the project root importable when running this file as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from ml.prepare_data_v2 import (
    get_all_concepts,
    get_trading_days,
    compute_raw_concept_features,
    init_process_connections,
    CONFIG,
    RAW_CACHE_DIR,
    BASELINE_DIR
)


def update_rolling_baseline(baseline_days: int = 20):
    """
    Update the rolling baseline used for live (intraday) detection.

    The baseline is, per concept and per intraday time slot, a set of
    summary statistics (mean/std of alpha, total traded amount and rank
    percentile) aggregated over the most recent ``baseline_days`` trading
    days. The result is pickled into ``BASELINE_DIR``.

    Args:
        baseline_days: number of most recent trading days to aggregate
            (default 20).

    Side effects:
        Creates ``BASELINE_DIR`` if needed and overwrites
        ``realtime_baseline.pkl`` inside it.
    """
    print("=" * 60)
    print("更新滚动基线(用于实盘)")
    print("=" * 60)

    # Initialize per-process data-source connections.
    init_process_connections()

    # Concept universe and the de-duplicated union of all member stocks.
    concepts = get_all_concepts()
    all_stocks = list(set(s for c in concepts for s in c['stocks']))

    # Look back enough *calendar* days to cover baseline_days *trading*
    # days. BUGFIX: the lookback was hard-coded to 60 calendar days, which
    # silently caps the usable --days at roughly 40 trading days; it now
    # scales with baseline_days (~2 calendar days per trading day + slack)
    # while keeping the original 60-day window for the default of 20.
    today = datetime.now().strftime('%Y-%m-%d')
    lookback_days = max(60, baseline_days * 2 + 10)
    start_date = (datetime.now() - timedelta(days=lookback_days)).strftime('%Y-%m-%d')
    trading_days = get_trading_days(start_date, today)

    if len(trading_days) < baseline_days:
        print(f"错误:交易日不足 {baseline_days} 天")
        return

    # Keep only the most recent N trading days.
    recent_days = trading_days[-baseline_days:]
    print(f"使用 {len(recent_days)} 天数据: {recent_days[0]} ~ {recent_days[-1]}")

    # Load raw per-slot concept features, preferring the parquet cache and
    # recomputing on a cache miss. Empty frames are skipped.
    all_data = []
    for trade_date in tqdm(recent_days, desc="加载数据"):
        cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
        if os.path.exists(cache_file):
            df = pd.read_parquet(cache_file)
        else:
            df = compute_raw_concept_features(trade_date, concepts, all_stocks)
        if not df.empty:
            all_data.append(df)

    if not all_data:
        print("错误:无数据")
        return

    combined = pd.concat(all_data, ignore_index=True)
    print(f"总数据量: {len(combined):,} 条")

    # Per-concept, per-time-slot statistics. Std values get NaN fallbacks
    # (pandas returns NaN for single-sample std) plus hard floors so that
    # downstream z-scores never divide by ~0.
    baselines = {}
    for concept_id, group in tqdm(combined.groupby('concept_id'), desc="计算基线"):
        baseline_dict = {}
        for time_slot, slot_group in group.groupby('time_slot'):
            # Skip slots with too few observations to be statistically useful.
            if len(slot_group) < CONFIG['min_baseline_samples']:
                continue
            alpha_std = slot_group['alpha'].std()
            amt_std = slot_group['total_amt'].std()
            rank_std = slot_group['rank_pct'].std()
            baseline_dict[time_slot] = {
                'alpha_mean': float(slot_group['alpha'].mean()),
                'alpha_std': float(max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1)),
                'amt_mean': float(slot_group['total_amt'].mean()),
                'amt_std': float(max(amt_std if pd.notna(amt_std) else slot_group['total_amt'].mean() * 0.5, 1.0)),
                'rank_mean': float(slot_group['rank_pct'].mean()),
                'rank_std': float(max(rank_std if pd.notna(rank_std) else 0.2, 0.05)),
                'sample_count': len(slot_group),
            }
        if baseline_dict:
            baselines[concept_id] = baseline_dict

    print(f"计算了 {len(baselines)} 个概念的基线")

    # Persist the baseline together with provenance metadata.
    os.makedirs(BASELINE_DIR, exist_ok=True)
    baseline_file = os.path.join(BASELINE_DIR, 'realtime_baseline.pkl')
    with open(baseline_file, 'wb') as f:
        pickle.dump({
            'baselines': baselines,
            'update_time': datetime.now().isoformat(),
            'date_range': [recent_days[0], recent_days[-1]],
            'baseline_days': baseline_days,
        }, f)

    print(f"基线已保存: {baseline_file}")
    print("=" * 60)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--days', type=int, default=20, help='基线天数')
    args = parser.parse_args()
    update_rolling_baseline(args.days)