update pay ui
This commit is contained in:
@@ -93,12 +93,12 @@ def backtest_single_day_hybrid(
|
||||
seq_len: int = 30
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
使用融合检测器回测单天数据
|
||||
使用融合检测器回测单天数据(批量优化版)
|
||||
"""
|
||||
alerts = []
|
||||
|
||||
# 按概念分组
|
||||
grouped = df.groupby('concept_id', sort=False)
|
||||
# 按概念分组,预先构建字典
|
||||
grouped_dict = {cid: cdf for cid, cdf in df.groupby('concept_id', sort=False)}
|
||||
|
||||
# 冷却记录
|
||||
cooldown = {}
|
||||
@@ -114,27 +114,46 @@ def backtest_single_day_hybrid(
|
||||
current_time = all_timestamps[t_idx]
|
||||
window_start_time = all_timestamps[t_idx - seq_len + 1]
|
||||
|
||||
minute_alerts = []
|
||||
# 批量收集该时刻所有候选概念
|
||||
batch_sequences = []
|
||||
batch_features = []
|
||||
batch_infos = []
|
||||
|
||||
for concept_id, concept_df in grouped_dict.items():
|
||||
# 检查冷却(提前过滤)
|
||||
if concept_id in cooldown:
|
||||
last_alert = cooldown[concept_id]
|
||||
if isinstance(current_time, datetime):
|
||||
time_diff = (current_time - last_alert).total_seconds() / 60
|
||||
else:
|
||||
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1
|
||||
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
|
||||
continue
|
||||
|
||||
for concept_id, concept_df in grouped:
|
||||
# 获取时间窗口内的数据
|
||||
mask = (concept_df['timestamp'] >= window_start_time) & (concept_df['timestamp'] <= current_time)
|
||||
window_df = concept_df[mask].sort_values('timestamp')
|
||||
window_df = concept_df.loc[mask]
|
||||
|
||||
if len(window_df) < seq_len:
|
||||
continue
|
||||
|
||||
window_df = window_df.tail(seq_len)
|
||||
window_df = window_df.sort_values('timestamp').tail(seq_len)
|
||||
|
||||
# 提取特征序列(给 ML 模型)
|
||||
# 当前时刻特征
|
||||
current_row = window_df.iloc[-1]
|
||||
alpha = current_row.get('alpha', 0)
|
||||
|
||||
# 过滤微小波动(提前过滤)
|
||||
if abs(alpha) < BACKTEST_CONFIG['min_alpha_abs']:
|
||||
continue
|
||||
|
||||
# 提取特征序列
|
||||
sequence = window_df[FEATURES].values
|
||||
sequence = np.nan_to_num(sequence, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
sequence = np.clip(sequence, -BACKTEST_CONFIG['clip_value'], BACKTEST_CONFIG['clip_value'])
|
||||
|
||||
# 当前时刻特征(给规则系统)
|
||||
current_row = window_df.iloc[-1]
|
||||
current_features = {
|
||||
'alpha': current_row.get('alpha', 0),
|
||||
'alpha': alpha,
|
||||
'alpha_delta': current_row.get('alpha_delta', 0),
|
||||
'amt_ratio': current_row.get('amt_ratio', 1),
|
||||
'amt_delta': current_row.get('amt_delta', 0),
|
||||
@@ -142,41 +161,79 @@ def backtest_single_day_hybrid(
|
||||
'limit_up_ratio': current_row.get('limit_up_ratio', 0),
|
||||
}
|
||||
|
||||
# 过滤微小波动
|
||||
if abs(current_features['alpha']) < BACKTEST_CONFIG['min_alpha_abs']:
|
||||
batch_sequences.append(sequence)
|
||||
batch_features.append(current_features)
|
||||
batch_infos.append({
|
||||
'concept_id': concept_id,
|
||||
'stock_count': current_row.get('stock_count', 0),
|
||||
'total_amt': current_row.get('total_amt', 0),
|
||||
})
|
||||
|
||||
if not batch_sequences:
|
||||
continue
|
||||
|
||||
# 批量 ML 推理
|
||||
sequences_array = np.array(batch_sequences)
|
||||
ml_scores = detector.ml_scorer.score(sequences_array) if detector.ml_scorer.is_ready() else [0.0] * len(batch_sequences)
|
||||
if isinstance(ml_scores, float):
|
||||
ml_scores = [ml_scores]
|
||||
|
||||
# 批量规则评分 + 融合
|
||||
minute_alerts = []
|
||||
for i, (features, info) in enumerate(zip(batch_features, batch_infos)):
|
||||
concept_id = info['concept_id']
|
||||
|
||||
# 规则评分
|
||||
rule_score, rule_details = detector.rule_scorer.score(features)
|
||||
|
||||
# ML 评分
|
||||
ml_score = ml_scores[i] if i < len(ml_scores) else 0.0
|
||||
|
||||
# 融合
|
||||
w1 = detector.config['rule_weight']
|
||||
w2 = detector.config['ml_weight']
|
||||
final_score = w1 * rule_score + w2 * ml_score
|
||||
|
||||
# 判断是否异动
|
||||
is_anomaly = False
|
||||
trigger_reason = ''
|
||||
|
||||
if rule_score >= detector.config['rule_trigger']:
|
||||
is_anomaly = True
|
||||
trigger_reason = f'规则强信号({rule_score:.0f}分)'
|
||||
elif ml_score >= detector.config['ml_trigger']:
|
||||
is_anomaly = True
|
||||
trigger_reason = f'ML强信号({ml_score:.0f}分)'
|
||||
elif final_score >= detector.config['fusion_trigger']:
|
||||
is_anomaly = True
|
||||
trigger_reason = f'融合触发({final_score:.0f}分)'
|
||||
|
||||
if not is_anomaly:
|
||||
continue
|
||||
|
||||
# 检查冷却
|
||||
if concept_id in cooldown:
|
||||
last_alert = cooldown[concept_id]
|
||||
if isinstance(current_time, datetime):
|
||||
time_diff = (current_time - last_alert).total_seconds() / 60
|
||||
else:
|
||||
time_diff = BACKTEST_CONFIG['cooldown_minutes'] + 1
|
||||
# 异动类型
|
||||
alpha = features.get('alpha', 0)
|
||||
if alpha >= 1.5:
|
||||
anomaly_type = 'surge_up'
|
||||
elif alpha <= -1.5:
|
||||
anomaly_type = 'surge_down'
|
||||
elif features.get('amt_ratio', 1) >= 3.0:
|
||||
anomaly_type = 'volume_spike'
|
||||
else:
|
||||
anomaly_type = 'unknown'
|
||||
|
||||
if time_diff < BACKTEST_CONFIG['cooldown_minutes']:
|
||||
continue
|
||||
|
||||
# 融合检测
|
||||
result = detector.detect(current_features, sequence)
|
||||
|
||||
if not result.is_anomaly:
|
||||
continue
|
||||
|
||||
# 记录异动
|
||||
alert = {
|
||||
'concept_id': concept_id,
|
||||
'alert_time': current_time,
|
||||
'trade_date': date,
|
||||
'alert_type': result.anomaly_type,
|
||||
'final_score': result.final_score,
|
||||
'rule_score': result.rule_score,
|
||||
'ml_score': result.ml_score,
|
||||
'trigger_reason': result.trigger_reason,
|
||||
'triggered_rules': list(result.rule_details.keys()),
|
||||
**current_features,
|
||||
'stock_count': current_row.get('stock_count', 0),
|
||||
'total_amt': current_row.get('total_amt', 0),
|
||||
'alert_type': anomaly_type,
|
||||
'final_score': final_score,
|
||||
'rule_score': rule_score,
|
||||
'ml_score': ml_score,
|
||||
'trigger_reason': trigger_reason,
|
||||
'triggered_rules': list(rule_details.keys()),
|
||||
**features,
|
||||
**info,
|
||||
}
|
||||
|
||||
minute_alerts.append(alert)
|
||||
@@ -341,6 +398,8 @@ def main():
|
||||
help='规则权重 (0-1)')
|
||||
parser.add_argument('--ml-weight', type=float, default=0.4,
|
||||
help='ML权重 (0-1)')
|
||||
parser.add_argument('--device', type=str, default='cuda',
|
||||
help='设备 (cuda/cpu),默认 cuda')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -355,15 +414,19 @@ def main():
|
||||
print(f"模型目录: {args.checkpoint_dir}")
|
||||
print(f"规则权重: {args.rule_weight}")
|
||||
print(f"ML权重: {args.ml_weight}")
|
||||
print(f"设备: {args.device}")
|
||||
print(f"Dry Run: {args.dry_run}")
|
||||
print("=" * 60)
|
||||
|
||||
# 初始化融合检测器
|
||||
# 初始化融合检测器(使用 GPU)
|
||||
config = {
|
||||
'rule_weight': args.rule_weight,
|
||||
'ml_weight': args.ml_weight,
|
||||
}
|
||||
detector = create_detector(args.checkpoint_dir, config)
|
||||
|
||||
# 修改 detector.py 中 MLScorer 的设备
|
||||
from detector import HybridAnomalyDetector
|
||||
detector = HybridAnomalyDetector(config, args.checkpoint_dir, device=args.device)
|
||||
|
||||
# 获取可用日期
|
||||
dates = get_available_dates(args.data_dir, args.start, args.end)
|
||||
|
||||
Reference in New Issue
Block a user