update pay ui
This commit is contained in:
@@ -26,7 +26,9 @@ import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from multiprocessing import Manager
|
||||
import multiprocessing
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
@@ -128,7 +130,7 @@ def get_all_concepts() -> List[dict]:
|
||||
hits = resp['hits']['hits']
|
||||
|
||||
ES_CLIENT.clear_scroll(scroll_id=scroll_id)
|
||||
logger.info(f"获取到 {len(concepts)} 个概念")
|
||||
print(f"获取到 {len(concepts)} 个概念")
|
||||
return concepts
|
||||
|
||||
|
||||
@@ -148,7 +150,7 @@ def get_trading_days(start_date: str, end_date: str) -> List[str]:
|
||||
|
||||
result = client.execute(query)
|
||||
days = [row[0].strftime('%Y-%m-%d') for row in result]
|
||||
logger.info(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
|
||||
print(f"找到 {len(days)} 个交易日: {days[0]} ~ {days[-1]}")
|
||||
return days
|
||||
|
||||
|
||||
@@ -223,21 +225,23 @@ def get_daily_index_data(trade_date: str, index_code: str = REFERENCE_INDEX) ->
|
||||
|
||||
|
||||
def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
||||
"""获取昨收价"""
|
||||
"""获取昨收价(上一交易日的收盘价 F007N)"""
|
||||
valid_codes = [c for c in stock_codes if c and len(c) == 6 and c.isdigit()]
|
||||
if not valid_codes:
|
||||
return {}
|
||||
|
||||
codes_str = "','".join(valid_codes)
|
||||
|
||||
# 注意:F007N 是"最近成交价"即当日收盘价,F002N 是"昨日收盘价"
|
||||
# 我们需要查上一交易日的 F007N(那天的收盘价)作为今天的昨收
|
||||
query = f"""
|
||||
SELECT SECCODE, F002N
|
||||
SELECT SECCODE, F007N
|
||||
FROM ea_trade
|
||||
WHERE SECCODE IN ('{codes_str}')
|
||||
AND TRADEDATE = (
|
||||
SELECT MAX(TRADEDATE) FROM ea_trade WHERE TRADEDATE < '{trade_date}'
|
||||
)
|
||||
AND F002N IS NOT NULL AND F002N > 0
|
||||
AND F007N IS NOT NULL AND F007N > 0
|
||||
"""
|
||||
|
||||
try:
|
||||
@@ -245,7 +249,7 @@ def get_prev_close(stock_codes: List[str], trade_date: str) -> Dict[str, float]:
|
||||
result = conn.execute(text(query))
|
||||
return {row[0]: float(row[1]) for row in result if row[1]}
|
||||
except Exception as e:
|
||||
logger.error(f"获取昨收价失败: {e}")
|
||||
print(f"获取昨收价失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
@@ -264,7 +268,7 @@ def get_index_prev_close(trade_date: str, index_code: str = REFERENCE_INDEX) ->
|
||||
if result and result[0]:
|
||||
return float(result[0])
|
||||
except Exception as e:
|
||||
logger.error(f"获取指数昨收失败: {e}")
|
||||
print(f"获取指数昨收失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
@@ -285,25 +289,19 @@ def compute_daily_features(
|
||||
"""
|
||||
|
||||
# 1. 获取数据
|
||||
logger.info(f" 获取股票数据...")
|
||||
stock_df = get_daily_stock_data(trade_date, all_stocks)
|
||||
if stock_df.empty:
|
||||
logger.warning(f" 无股票数据")
|
||||
return pd.DataFrame()
|
||||
|
||||
logger.info(f" 获取指数数据...")
|
||||
index_df = get_daily_index_data(trade_date)
|
||||
if index_df.empty:
|
||||
logger.warning(f" 无指数数据")
|
||||
return pd.DataFrame()
|
||||
|
||||
# 2. 获取昨收价
|
||||
logger.info(f" 获取昨收价...")
|
||||
prev_close = get_prev_close(all_stocks, trade_date)
|
||||
index_prev_close = get_index_prev_close(trade_date)
|
||||
|
||||
if not prev_close or not index_prev_close:
|
||||
logger.warning(f" 无昨收价数据")
|
||||
return pd.DataFrame()
|
||||
|
||||
# 3. 计算股票涨跌幅和成交额
|
||||
@@ -317,7 +315,6 @@ def compute_daily_features(
|
||||
|
||||
# 5. 获取所有时间点
|
||||
timestamps = sorted(stock_df['timestamp'].unique())
|
||||
logger.info(f" 时间点数: {len(timestamps)}")
|
||||
|
||||
# 6. 按时间点计算概念特征
|
||||
results = []
|
||||
@@ -414,87 +411,126 @@ def compute_daily_features(
|
||||
if amt_delta_std > 0:
|
||||
final_df['amt_delta'] = final_df['amt_delta'] / amt_delta_std
|
||||
|
||||
logger.info(f" 计算完成: {len(final_df)} 条记录")
|
||||
return final_df
|
||||
|
||||
|
||||
# ==================== 主流程 ====================
|
||||
|
||||
def process_single_day(args) -> Tuple[str, bool]:
    """
    Compute and persist the feature table for a single trading day.

    The three inputs arrive packed in one tuple because this function is
    submitted as a single-argument task to a process pool.

    Args:
        args: ``(trade_date, concepts, all_stocks)`` tuple. ``trade_date`` is
            used verbatim in the output file name; all three values are
            forwarded unchanged to ``compute_daily_features``.

    Returns:
        ``(trade_date, success)``. ``success`` is True when the output file
        already existed or was written by this call, and False when the
        computed frame was empty or any exception was raised. Exceptions are
        reported on stdout/stderr and never propagate to the caller.
    """
    trade_date, concepts, all_stocks = args
    output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')

    # An existing file is treated as "already processed".
    if os.path.exists(output_file):
        print(f"[{trade_date}] 已存在,跳过")
        return (trade_date, True)

    print(f"[{trade_date}] 开始处理...")

    # Write to a per-process temp path first and rename afterwards: because
    # the skip check above only tests for existence, a half-written file left
    # at the final path by an interrupted run would be skipped forever.
    tmp_file = f"{output_file}.tmp.{os.getpid()}"

    try:
        df = compute_daily_features(trade_date, concepts, all_stocks)

        if df.empty:
            print(f"[{trade_date}] 无数据")
            return (trade_date, False)

        df.to_parquet(tmp_file, index=False)
        os.replace(tmp_file, output_file)
        print(f"[{trade_date}] 保存完成")
        return (trade_date, True)

    except Exception as e:
        # Task boundary: report the failure and signal it through the return
        # value so that one bad day does not abort the whole batch.
        print(f"[{trade_date}] 处理失败: {e}")
        import traceback
        traceback.print_exc()

        # Best-effort cleanup of a partially written temp file.
        try:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
        except OSError as cleanup_error:
            print(f"[{trade_date}] 临时文件清理失败: {cleanup_error}")

        return (trade_date, False)
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point: build one feature file per trading day.

    Steps:
        1. Parse ``--start``, ``--end``, ``--workers`` and ``--force``.
        2. Load the concept list and collect the de-duplicated stock codes
           found under each concept's ``'stocks'`` key.
        3. Look up the trading days in the requested date range.
        4. With ``--force``, delete the existing output files for those days
           so they are recomputed.
        5. Run ``process_single_day`` for every day in a process pool and
           print a summary, including a preview of the failed dates.

    Returns:
        None. Progress and results are reported on stdout.
    """
    import argparse
    from tqdm import tqdm

    parser = argparse.ArgumentParser(description='准备训练数据')
    parser.add_argument('--start', type=str, default='2022-01-01', help='开始日期')
    parser.add_argument('--end', type=str, default=None, help='结束日期(默认今天)')
    parser.add_argument('--workers', type=int, default=18, help='并行进程数(默认18)')
    parser.add_argument('--force', action='store_true', help='强制重新处理已存在的文件')

    args = parser.parse_args()

    # ProcessPoolExecutor raises ValueError for max_workers <= 0; reject the
    # value up front with a normal usage error instead.
    if args.workers < 1:
        parser.error('--workers 必须为正整数')

    end_date = args.end or datetime.now().strftime('%Y-%m-%d')

    print("=" * 60)
    print("数据准备 - Transformer Autoencoder 训练数据")
    print("=" * 60)
    print(f"日期范围: {args.start} ~ {end_date}")
    print(f"并行进程数: {args.workers}")

    # 1. Load the concept list.
    concepts = get_all_concepts()

    # Union of all stock codes across concepts, de-duplicated.
    all_stocks = list({s for c in concepts for s in c['stocks']})
    print(f"股票总数: {len(all_stocks)}")

    # 2. Resolve the trading days in range.
    trading_days = get_trading_days(args.start, end_date)

    if not trading_days:
        print("无交易日数据")
        return

    # In force mode, remove existing outputs so the workers' "already exists"
    # check does not skip them.
    if args.force:
        for trade_date in trading_days:
            output_file = os.path.join(OUTPUT_DIR, f'features_{trade_date}.parquet')
            if os.path.exists(output_file):
                os.remove(output_file)
                print(f"删除已有文件: {output_file}")

    # 3. One task tuple per day; process_single_day unpacks it.
    tasks = [(trade_date, concepts, all_stocks) for trade_date in trading_days]

    print(f"\n开始处理 {len(trading_days)} 个交易日({args.workers} 进程并行)...")

    # 4. Fan the days out over the process pool.
    success_count = 0
    failed_dates = []

    with ProcessPoolExecutor(max_workers=args.workers) as executor:
        # Map each future back to its date so a crashed worker can still be
        # attributed to the right day.
        futures = {executor.submit(process_single_day, task): task[0] for task in tasks}

        with tqdm(total=len(futures), desc="处理进度", unit="天") as pbar:
            for future in as_completed(futures):
                trade_date = futures[future]
                try:
                    result_date, success = future.result()
                    if success:
                        success_count += 1
                    else:
                        failed_dates.append(result_date)
                except Exception as e:
                    # Raised when the worker itself failed (e.g. the process
                    # died) rather than returning a (date, False) result.
                    print(f"\n[{trade_date}] 进程异常: {e}")
                    failed_dates.append(trade_date)
                pbar.update(1)

    # Futures complete in arbitrary order; sort so that the truncated preview
    # below shows the earliest failed dates.
    failed_dates.sort()

    print("\n" + "=" * 60)
    print(f"处理完成: {success_count}/{len(trading_days)} 个交易日")
    if failed_dates:
        print(f"失败日期: {failed_dates[:10]}{'...' if len(failed_dates) > 10 else ''}")
    print(f"数据保存在: {OUTPUT_DIR}")
    print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user