133 lines
4.1 KiB
Python
133 lines
4.1 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
每日盘后运行:更新滚动基线
|
||
|
||
使用方法:
|
||
python ml/update_baseline.py
|
||
|
||
建议加入 crontab,每天 15:30 后运行:
|
||
30 15 * * 1-5 cd /path/to/project && python ml/update_baseline.py
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import pickle
|
||
import pandas as pd
|
||
import numpy as np
|
||
from datetime import datetime, timedelta
|
||
from pathlib import Path
|
||
from tqdm import tqdm
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from ml.prepare_data_v2 import (
|
||
get_all_concepts, get_trading_days, compute_raw_concept_features,
|
||
init_process_connections, CONFIG, RAW_CACHE_DIR, BASELINE_DIR
|
||
)
|
||
|
||
|
||
def update_rolling_baseline(baseline_days: int = 20):
|
||
"""
|
||
更新滚动基线(用于实盘检测)
|
||
|
||
基线 = 最近 N 个交易日每个时间片的统计量
|
||
"""
|
||
print("=" * 60)
|
||
print("更新滚动基线(用于实盘)")
|
||
print("=" * 60)
|
||
|
||
# 初始化连接
|
||
init_process_connections()
|
||
|
||
# 获取概念列表
|
||
concepts = get_all_concepts()
|
||
all_stocks = list(set(s for c in concepts for s in c['stocks']))
|
||
|
||
# 获取最近的交易日
|
||
today = datetime.now().strftime('%Y-%m-%d')
|
||
start_date = (datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d') # 多取一些
|
||
|
||
trading_days = get_trading_days(start_date, today)
|
||
|
||
if len(trading_days) < baseline_days:
|
||
print(f"错误:交易日不足 {baseline_days} 天")
|
||
return
|
||
|
||
# 只取最近 N 天
|
||
recent_days = trading_days[-baseline_days:]
|
||
print(f"使用 {len(recent_days)} 天数据: {recent_days[0]} ~ {recent_days[-1]}")
|
||
|
||
# 加载原始数据
|
||
all_data = []
|
||
for trade_date in tqdm(recent_days, desc="加载数据"):
|
||
cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
|
||
|
||
if os.path.exists(cache_file):
|
||
df = pd.read_parquet(cache_file)
|
||
else:
|
||
df = compute_raw_concept_features(trade_date, concepts, all_stocks)
|
||
|
||
if not df.empty:
|
||
all_data.append(df)
|
||
|
||
if not all_data:
|
||
print("错误:无数据")
|
||
return
|
||
|
||
combined = pd.concat(all_data, ignore_index=True)
|
||
print(f"总数据量: {len(combined):,} 条")
|
||
|
||
# 按概念计算基线
|
||
baselines = {}
|
||
|
||
for concept_id, group in tqdm(combined.groupby('concept_id'), desc="计算基线"):
|
||
baseline_dict = {}
|
||
|
||
for time_slot, slot_group in group.groupby('time_slot'):
|
||
if len(slot_group) < CONFIG['min_baseline_samples']:
|
||
continue
|
||
|
||
alpha_std = slot_group['alpha'].std()
|
||
amt_std = slot_group['total_amt'].std()
|
||
rank_std = slot_group['rank_pct'].std()
|
||
|
||
baseline_dict[time_slot] = {
|
||
'alpha_mean': float(slot_group['alpha'].mean()),
|
||
'alpha_std': float(max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1)),
|
||
'amt_mean': float(slot_group['total_amt'].mean()),
|
||
'amt_std': float(max(amt_std if pd.notna(amt_std) else slot_group['total_amt'].mean() * 0.5, 1.0)),
|
||
'rank_mean': float(slot_group['rank_pct'].mean()),
|
||
'rank_std': float(max(rank_std if pd.notna(rank_std) else 0.2, 0.05)),
|
||
'sample_count': len(slot_group),
|
||
}
|
||
|
||
if baseline_dict:
|
||
baselines[concept_id] = baseline_dict
|
||
|
||
print(f"计算了 {len(baselines)} 个概念的基线")
|
||
|
||
# 保存
|
||
os.makedirs(BASELINE_DIR, exist_ok=True)
|
||
baseline_file = os.path.join(BASELINE_DIR, 'realtime_baseline.pkl')
|
||
|
||
with open(baseline_file, 'wb') as f:
|
||
pickle.dump({
|
||
'baselines': baselines,
|
||
'update_time': datetime.now().isoformat(),
|
||
'date_range': [recent_days[0], recent_days[-1]],
|
||
'baseline_days': baseline_days,
|
||
}, f)
|
||
|
||
print(f"基线已保存: {baseline_file}")
|
||
print("=" * 60)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import argparse
|
||
parser = argparse.ArgumentParser()
|
||
parser.add_argument('--days', type=int, default=20, help='基线天数')
|
||
args = parser.parse_args()
|
||
|
||
update_rolling_baseline(args.days)
|