Files
vf_react/ml/update_baseline.py
2025-12-10 11:02:09 +08:00

133 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
每日盘后运行:更新滚动基线
使用方法:
python ml/update_baseline.py
建议加入 crontab每天 15:30 后运行:
30 15 * * 1-5 cd /path/to/project && python ml/update_baseline.py
"""
import os
import sys
import pickle
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
from tqdm import tqdm
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from ml.prepare_data_v2 import (
get_all_concepts, get_trading_days, compute_raw_concept_features,
init_process_connections, CONFIG, RAW_CACHE_DIR, BASELINE_DIR
)
def update_rolling_baseline(baseline_days: int = 20):
    """
    Update the rolling baseline used for live (intraday) anomaly detection.

    The baseline is, per concept and per intraday time slot, the mean/std
    of alpha, total turnover amount, and rank percentile computed over the
    most recent ``baseline_days`` trading days.

    Parameters
    ----------
    baseline_days : int, default 20
        Number of most recent trading days to include in the baseline.

    Side effects
    ------------
    Writes ``realtime_baseline.pkl`` under ``BASELINE_DIR`` and caches
    freshly computed daily raw features as parquet under ``RAW_CACHE_DIR``.
    Returns early (printing an error) when there are not enough trading
    days or no data could be loaded.
    """
    print("=" * 60)
    print("更新滚动基线(用于实盘)")
    print("=" * 60)
    # Initialize per-process data-source connections.
    init_process_connections()
    # All concepts plus the de-duplicated universe of their member stocks.
    concepts = get_all_concepts()
    all_stocks = list(set(s for c in concepts for s in c['stocks']))
    # Look back 60 calendar days to be sure we capture at least
    # `baseline_days` trading days (weekends/holidays reduce the count).
    today = datetime.now().strftime('%Y-%m-%d')
    start_date = (datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d')
    trading_days = get_trading_days(start_date, today)
    if len(trading_days) < baseline_days:
        print(f"错误:交易日不足 {baseline_days} 天")
        return
    # Keep only the most recent N trading days.
    # NOTE(review): assumes get_trading_days returns dates in ascending
    # order — recent_days[0]/recent_days[-1] below rely on that; confirm.
    recent_days = trading_days[-baseline_days:]
    print(f"使用 {len(recent_days)} 天数据: {recent_days[0]} ~ {recent_days[-1]}")
    # Load raw per-day features, preferring the on-disk parquet cache.
    os.makedirs(RAW_CACHE_DIR, exist_ok=True)
    all_data = []
    for trade_date in tqdm(recent_days, desc="加载数据"):
        cache_file = os.path.join(RAW_CACHE_DIR, f'raw_{trade_date}.parquet')
        if os.path.exists(cache_file):
            df = pd.read_parquet(cache_file)
        else:
            df = compute_raw_concept_features(trade_date, concepts, all_stocks)
            # FIX: persist freshly computed features so subsequent runs hit
            # the cache instead of recomputing the same day from scratch.
            if not df.empty:
                df.to_parquet(cache_file)
        if not df.empty:
            all_data.append(df)
    if not all_data:
        print("错误:无数据")
        return
    combined = pd.concat(all_data, ignore_index=True)
    print(f"总数据量: {len(combined):,}")
    # Compute per-concept baselines: one stats dict per intraday time slot.
    baselines = {}
    for concept_id, group in tqdm(combined.groupby('concept_id'), desc="计算基线"):
        baseline_dict = {}
        for time_slot, slot_group in group.groupby('time_slot'):
            # Skip slots with too few observations for stable statistics.
            if len(slot_group) < CONFIG['min_baseline_samples']:
                continue
            alpha_std = slot_group['alpha'].std()
            amt_std = slot_group['total_amt'].std()
            rank_std = slot_group['rank_pct'].std()
            # Std values are floored to avoid divide-by-near-zero downstream;
            # NaN stds (e.g. single-sample groups) fall back to heuristic
            # defaults before the floor is applied.
            baseline_dict[time_slot] = {
                'alpha_mean': float(slot_group['alpha'].mean()),
                'alpha_std': float(max(alpha_std if pd.notna(alpha_std) else 1.0, 0.1)),
                'amt_mean': float(slot_group['total_amt'].mean()),
                'amt_std': float(max(amt_std if pd.notna(amt_std) else slot_group['total_amt'].mean() * 0.5, 1.0)),
                'rank_mean': float(slot_group['rank_pct'].mean()),
                'rank_std': float(max(rank_std if pd.notna(rank_std) else 0.2, 0.05)),
                'sample_count': len(slot_group),
            }
        if baseline_dict:
            baselines[concept_id] = baseline_dict
    print(f"计算了 {len(baselines)} 个概念的基线")
    # Persist the baseline bundle for the realtime detector to load.
    os.makedirs(BASELINE_DIR, exist_ok=True)
    baseline_file = os.path.join(BASELINE_DIR, 'realtime_baseline.pkl')
    with open(baseline_file, 'wb') as f:
        pickle.dump({
            'baselines': baselines,
            'update_time': datetime.now().isoformat(),
            'date_range': [recent_days[0], recent_days[-1]],
            'baseline_days': baseline_days,
        }, f)
    print(f"基线已保存: {baseline_file}")
    print("=" * 60)
if __name__ == "__main__":
    import argparse

    # CLI: allow overriding the rolling-baseline window length.
    cli = argparse.ArgumentParser()
    cli.add_argument('--days', type=int, default=20, help='基线天数')
    ns = cli.parse_args()
    update_rolling_baseline(baseline_days=ns.days)