update pay ui

This commit is contained in:
2025-12-10 11:02:09 +08:00
parent d9daaeed19
commit e501ac3819
21 changed files with 5514 additions and 151 deletions

View File

@@ -26,7 +26,9 @@ import hashlib
import json
import logging
from typing import Dict, List, Set, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Manager
import multiprocessing
import warnings
warnings.filterwarnings('ignore')
@@ -128,7 +130,7 @@ def get_all_concepts() -> List[dict]:
hits = resp['hits']['hits']
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
logger.info(f"获取到 {len(concepts)} 个概念")
print(f"获取到 {len(concepts)} 个概念")
return concepts
@@ -148,7 +150,7 @@ def get_trading_days(start_date: str, end_date: str) -> List[str]:
result = client.execute(query)
days = [row[0].strftime('%Y-%m-%d') for row in result]
logger.info(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
return days
@@ -223,21 +225,23 @@ def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) ->
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
"""获取昨收价"""
"""获取昨收价(上一交易日的收盘价 F007N"""
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
if not valid_codes:
return {}
codes_str = "','".join(valid_codes)
# 注意F007N 是"最近成交价"即当日收盘价F002N 是"昨日收盘价"
# 我们需要查上一交易日的 F007N那天的收盘价作为今天的昨收
query = f"""
SELECT SECCODE, F002N
SELECT SECCODE, F007N
FROM ea_trade
WHERE SECCODE IN ('{codes_str}')
AND TRADEDATE = (
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
)
AND F002N IS NOT NULL AND F002N > 0
AND F007N IS NOT NULL AND F007N > 0
"""
try:
@@ -245,7 +249,7 @@ def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
result = conn.execute(text(query))
return {row[0]: float(row[1]) for row in result if row[1]}
except Exception as e:
logger.error(f"获取昨收价失败: {e}")
print(f"获取昨收价失败: {e}")
return {}
@@ -264,7 +268,7 @@ def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) ->
if result and result[0]:
return float(result[0])
except Exception as e:
logger.error(f"获取指数昨收失败: {e}")
print(f"获取指数昨收失败: {e}")
return None
@@ -285,25 +289,19 @@ def compute_daily_features(
"""
# 1. 获取数据
logger.info(f" 获取股票数据...")
stock_df = get_daily_stock_data(trade_date, all_stocks)
if stock_df.empty:
logger.warning(f" 无股票数据")
return pd.DataFrame()
logger.info(f" 获取指数数据...")
index_df = get_daily_index_data(trade_date)
if index_df.empty:
logger.warning(f" 无指数数据")
return pd.DataFrame()
# 2. 获取昨收价
logger.info(f" 获取昨收价...")
prev_close = get_prev_close(all_stocks, trade_date)
index_prev_close = get_index_prev_close(trade_date)
if not prev_close or not index_prev_close:
logger.warning(f" 无昨收价数据")
return pd.DataFrame()
# 3. 计算股票涨跌幅和成交额
@@ -317,7 +315,6 @@ def compute_daily_features(
# 5. 获取所有时间点
timestamps = sorted(stock_df['timestamp'].unique())
logger.info(f" 时间点数: {len(timestamps)}")
# 6. 按时间点计算概念特征
results = []
@@ -414,87 +411,126 @@ def compute_daily_features(
if amt_delta_std > 0:
final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std
logger.info(f" 计算完成: {len(final_df)} 条记录")
return final_df
# ==================== 主流程 ====================
def process_single_day(trade_date: str, concepts: List[dict], all_stocks: List[str]) -> str:
"""处理单个交易日"""
def process_single_day(args) -> Tuple[str, bool]:
"""
处理单个交易日(多进程版本)
Args:
args: (trade_date, concepts, all_stocks) 元组
Returns:
(trade_date, success) 元组
"""
trade_date, concepts, all_stocks = args
output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
# 检查是否已处理
if os.path.exists(output_file):
logger.info(f"[{trade_date}] 已存在,跳过")
return output_file
print(f"[{trade_date}] 已存在,跳过")
return (trade_date, True)
logger.info(f"[{trade_date}] 开始处理...")
print(f"[{trade_date}] 开始处理...")
try:
df = compute_daily_features(trade_date, concepts, all_stocks)
if df.empty:
logger.warning(f"[{trade_date}] 无数据")
return None
print(f"[{trade_date}] 无数据")
return (trade_date, False)
# 保存
df.to_parquet(output_file, index=False)
logger.info(f"[{trade_date}] 保存完成: {output_file}")
return output_file
print(f"[{trade_date}] 保存完成")
return (trade_date, True)
except Exception as e:
logger.error(f"[{trade_date}] 处理失败: {e}")
print(f"[{trade_date}] 处理失败: {e}")
import traceback
traceback.print_exc()
return None
return (trade_date, False)
def main():
import argparse
from tqdm import tqdm
parser = argparse.ArgumentParser(description='准备训练数据')
parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
parser.add_argument('--workers', type=int, default=1, help='并行建议1避免数据库压力')
parser.add_argument('--workers', type=int, default=18, help='并行进程数默认18')
parser.add_argument('--force', action='store_true', help='强制重新处理已存在的文件')
args = parser.parse_args()
end_date = args.end or datetime.now().strftime('%Y-%m-%d')
logger.info("=" * 60)
logger.info("数据准备 - Transformer Autoencoder 训练数据")
logger.info("=" * 60)
logger.info(f"日期范围: {args.start} ~ {end_date}")
print("=" * 60)
print("数据准备 - Transformer Autoencoder 训练数据")
print("=" * 60)
print(f"日期范围: {args.start} ~ {end_date}")
print(f"并行进程数: {args.workers}")
# 1. 获取概念列表
concepts = get_all_concepts()
# 收集所有股票
all_stocks = list(set(s for c in concepts for s in c['stocks']))
logger.info(f"股票总数: {len(all_stocks)}")
print(f"股票总数: {len(all_stocks)}")
# 2. 获取交易日列表
trading_days = get_trading_days(args.start, end_date)
if not trading_days:
logger.error("无交易日数据")
print("无交易日数据")
return
# 3. 处理每个交易日
logger.info(f"\n开始处理 {len(trading_days)} 个交易日...")
# 如果强制模式,删除已有文件
if args.force:
for trade_date in trading_days:
output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
if os.path.exists(output_file):
os.remove(output_file)
print(f"删除已有文件: {output_file}")
# 3. 准备任务参数
tasks = [(trade_date, concepts, all_stocks) for trade_date in trading_days]
print(f"\n开始处理 {len(trading_days)} 个交易日({args.workers} 进程并行)...")
# 4. 多进程处理
success_count = 0
for i, trade_date in enumerate(trading_days):
logger.info(f"\n[{i+1}/{len(trading_days)}] {trade_date}")
result = process_single_day(trade_date, concepts, all_stocks)
if result:
success_count += 1
failed_dates = []
logger.info("\n" + "=" * 60)
logger.info(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
logger.info(f"数据保存在: {OUTPUT_DIR}")
logger.info("=" * 60)
with ProcessPoolExecutor(max_workers=args.workers) as executor:
# 提交所有任务
futures = {executor.submit(process_single_day, task): task[0] for task in tasks}
# 使用 tqdm 显示进度
with tqdm(total=len(futures), desc="处理进度", unit="") as pbar:
for future in as_completed(futures):
trade_date = futures[future]
try:
result_date, success = future.result()
if success:
success_count += 1
else:
failed_dates.append(result_date)
except Exception as e:
print(f"\n[{trade_date}] 进程异常: {e}")
failed_dates.append(trade_date)
pbar.update(1)
print("\n" + "=" * 60)
print(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
if failed_dates:
print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
print(f"数据保存在: {OUTPUT_DIR}")
print("=" * 60)
if __name__ == "__main__":