Files
vf_react/get_related_chg.py

491 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from clickhouse_driver import Client as Cclient
from sqlalchemy import create_engine, text
from datetime import datetime, time as dt_time, timedelta
import time
import pandas as pd
import os
# Load the trading calendar once at import time. tdays.csv sits next to this
# script and must contain a 'DateTime' column; TRADING_DAYS ends up as an
# ascending list of datetime.date values that the scheduling helpers search.
script_dir = os.path.dirname(os.path.abspath(__file__))
TRADING_DAYS_FILE = os.path.join(script_dir, 'tdays.csv')
trading_days_df = pd.read_csv(TRADING_DAYS_FILE)
trading_days_df['DateTime'] = trading_days_df['DateTime'].pipe(pd.to_datetime).dt.date
TRADING_DAYS = list(trading_days_df['DateTime'])
TRADING_DAYS.sort()  # keep ascending order; callers rely on it
def get_clickhouse_client():
    """Create a ClickHouse client for the ``stock`` database.

    Connection settings are read from the ``CLICKHOUSE_HOST``,
    ``CLICKHOUSE_PORT``, ``CLICKHOUSE_USER``, ``CLICKHOUSE_PASSWORD`` and
    ``CLICKHOUSE_DATABASE`` environment variables when set. Otherwise the
    previously hard-coded values are used, so existing deployments keep
    working unchanged.

    Returns:
        clickhouse_driver.Client: client configured for the stock database.
    """
    # NOTE(security): the fallback password below is committed to source
    # control. Set CLICKHOUSE_PASSWORD in the environment, rotate the
    # credential, and then remove the fallback value.
    return Cclient(
        host=os.environ.get('CLICKHOUSE_HOST', '127.0.0.1'),
        port=int(os.environ.get('CLICKHOUSE_PORT', '9000')),
        user=os.environ.get('CLICKHOUSE_USER', 'default'),
        password=os.environ.get('CLICKHOUSE_PASSWORD', 'Zzl33818!'),
        database=os.environ.get('CLICKHOUSE_DATABASE', 'stock'),
    )
def get_mysql_engine():
    """Create a SQLAlchemy engine for the MySQL ``stock`` database (PyMySQL driver).

    Connection settings are read from the ``MYSQL_HOST``, ``MYSQL_PORT``,
    ``MYSQL_USER``, ``MYSQL_PASSWORD`` and ``MYSQL_DATABASE`` environment
    variables when set. Otherwise the previously hard-coded values are used.

    The URL is built with ``URL.create`` rather than string formatting, so a
    password containing URL-special characters (``@``, ``/``, ``#``, ...) is
    escaped correctly instead of corrupting the connection string.

    Returns:
        sqlalchemy.engine.Engine: engine with SQL echo disabled.
    """
    from sqlalchemy.engine import URL

    # NOTE(security): the fallback password below is committed to source
    # control. Set MYSQL_PASSWORD in the environment, rotate the credential,
    # and then remove the fallback value.
    url = URL.create(
        drivername="mysql+pymysql",
        username=os.environ.get("MYSQL_USER", "root"),
        password=os.environ.get("MYSQL_PASSWORD", "Zzl33818!"),
        host=os.environ.get("MYSQL_HOST", "127.0.0.1"),
        port=int(os.environ.get("MYSQL_PORT", "3306")),
        database=os.environ.get("MYSQL_DATABASE", "stock"),
    )
    return create_engine(url, echo=False)
def is_trading_time(check_datetime=None):
    """Tell whether a moment falls inside a trading session.

    Sessions are 09:30-11:30 and 13:00-15:00, both ends inclusive, and only
    on dates listed in TRADING_DAYS.

    Args:
        check_datetime: datetime to test; defaults to datetime.now().

    Returns:
        bool: True inside a session on a trading day, False otherwise.
    """
    moment = datetime.now() if check_datetime is None else check_datetime

    # Non-trading dates are rejected before looking at the clock.
    if moment.date() not in TRADING_DAYS:
        return False

    clock = moment.time()
    sessions = (
        (dt_time(9, 30), dt_time(11, 30)),  # morning session
        (dt_time(13, 0), dt_time(15, 0)),   # afternoon session
    )
    return any(opens <= clock <= closes for opens, closes in sessions)
def get_next_trading_time():
    """Return the start time of the next trading session after now.

    On a trading day:
      * before 09:30 -> today 09:30;
      * before 13:00 (morning session, exactly 11:30, or lunch break) -> today 13:00;
      * otherwise -> 09:30 on the next date in TRADING_DAYS.
    On a non-trading day the result is 09:30 on the next date in TRADING_DAYS.

    If TRADING_DAYS has no later date, this falls back to tomorrow 09:30.
    That fallback may not be a real trading day, so the calendar file probably
    needs refreshing when it is hit.

    Returns:
        datetime: naive datetime of the next session open.
    """
    now = datetime.now()
    today = now.date()
    clock = now.time()
    morning_start = dt_time(9, 30)
    afternoon_start = dt_time(13, 0)

    if today in TRADING_DAYS:
        if clock < morning_start:
            return datetime.combine(today, morning_start)
        # Every time before the afternoon open maps to today's 13:00 session.
        # A strict (11:30, 13:00) window would send 11:30:00 and in-session
        # calls to the next trading day, skipping this afternoon.
        if clock < afternoon_start:
            return datetime.combine(today, afternoon_start)

    next_day = next((day for day in TRADING_DAYS if day > today), None)
    if next_day is not None:
        return datetime.combine(next_day, morning_start)

    # No future trading day known: best-effort guess (calendar likely stale).
    return datetime.combine(today + timedelta(days=1), morning_start)
def get_next_trading_day(date):
    """Return the first date in TRADING_DAYS strictly after *date*.

    Returns:
        datetime.date or None: None when the calendar has no later date.
    """
    later_days = (day for day in TRADING_DAYS if day > date)
    return next(later_days, None)
def get_nth_trading_day_after(start_date, n=7):
    """Return the n-th trading day strictly after *start_date*.

    The 1st trading day after *start_date* is the first date in TRADING_DAYS
    greater than it. This holds whether or not *start_date* is itself a
    trading day. For n=0 the result is *start_date* when it is a trading day,
    otherwise the last trading day before it. Negative n steps further back.

    Args:
        start_date: datetime.date to count from.
        n: number of trading days to step forward (default 7).

    Returns:
        datetime.date or None: one of the following.
            * The target trading day.
            * The last known trading day, if the calendar ends before the
              target but still has at least one day after *start_date*.
            * None if the calendar has no day after *start_date*, or the
              target lies before the start of the calendar.
    """
    from bisect import bisect_right

    # TRADING_DAYS is sorted ascending, so this is the index of the first
    # trading day strictly after start_date. It works whether or not
    # start_date is in the list. Negative indices are never produced, so
    # there is no accidental wrap-around to the end of the list.
    first_after = bisect_right(TRADING_DAYS, start_date)
    target_idx = first_after + n - 1

    if 0 <= target_idx < len(TRADING_DAYS):
        return TRADING_DAYS[target_idx]

    # The calendar ends before the target: fall back to the last available
    # trading day, but only if at least one trading day follows start_date.
    if target_idx >= len(TRADING_DAYS) and first_after < len(TRADING_DAYS):
        return TRADING_DAYS[-1]
    return None
def get_trading_day_info(event_datetime):
    """Map an event timestamp to the trading day it is attributed to.

    An event on a trading day at or before the 15:00 close belongs to that
    day. Anything later, or on a non-trading day, belongs to the next
    trading day.

    Returns:
        datetime.date or None: None when the calendar has no later trading day.
    """
    day = event_datetime.date()
    before_close = event_datetime.time() <= dt_time(15, 0)
    if before_close and day in TRADING_DAYS:
        return day
    return get_next_trading_day(day)
def calculate_stock_changes(stock_codes, event_datetime, ch_client, debug=False):
    """Compute price-change statistics for all stocks related to one event.

    The start price is the first minute close at or after the start point.
    The start point is one of:
      * the event time, if the event is on a trading day between 09:30 and
        15:00 (inclusive);
      * otherwise, 09:30 of the trading day returned by get_trading_day_info.

    Each stock's start price is compared with two later prices:
      * the last close at or before 15:00 on that trading day (day change);
      * the last close at or before 15:00 on the 7th trading day after it
        (week change).

    All stocks are fetched in a single ClickHouse query on ``stock_minute``.

    Args:
        stock_codes: list of stock code strings; empty -> (None, None, None).
        event_datetime: naive datetime of the event.
        ch_client: ClickHouse client whose ``execute`` runs the query.
        debug: print the resolved time window and results when True.

    Returns:
        tuple: (avg_change, max_change, avg_week_change), all in percent.
            avg_change and max_change are the mean and maximum of the
            per-stock day changes; avg_week_change is the mean of the week
            changes. Each element is None when no stock produced a valid value.
    """
    if not stock_codes:
        return None, None, None

    def _pct(value):
        # Debug formatting for a percentage that may be None.
        return "None" if value is None else f"{value:.2f}%"

    event_date = event_datetime.date()
    event_time = event_datetime.time()
    market_open = dt_time(9, 30)
    market_close = dt_time(15, 0)

    # Determine the trading day and the starting point for the comparison.
    if event_date in TRADING_DAYS and market_open <= event_time <= market_close:
        # Event during the trading day (lunch break included): start from the
        # first bar at or after the event itself.
        trading_date = event_date
        start_datetime = event_datetime
        if debug:
            print(f" 事件在交易时间内: {event_datetime} -> 起点={start_datetime}")
    else:
        # Event outside trading hours: start at the open of the trading day it
        # maps to (the same day if before 09:30, otherwise the next one).
        trading_date = get_trading_day_info(event_datetime)
        if not trading_date:
            if debug:
                print(f" 找不到交易日: {event_datetime}")
            return None, None, None
        start_datetime = datetime.combine(trading_date, market_open)
        if debug:
            print(f" 事件在非交易时间: {event_datetime} -> 下一交易日={trading_date}, 起点={start_datetime}")
    end_datetime = datetime.combine(trading_date, market_close)

    # End of the one-week window: close of the 7th trading day after trading_date.
    week_trading_date = get_nth_trading_day_after(trading_date, 7)
    if not week_trading_date:
        # Fallback when the calendar has no later trading days.
        week_trading_date = trading_date + timedelta(days=10)
    week_end_datetime = datetime.combine(week_trading_date, market_close)

    if debug:
        print(f" 查询范围: {start_datetime} -> 当日={end_datetime}, 周末={week_end_datetime}")
        print(f" 股票代码: {stock_codes}")

    # One query for all stocks. argMaxIf returns the column default (0) when a
    # stock has no bar at or before the day's close, e.g. a suspended stock.
    # The previous argMax(close, if(ts <= end, ts, 1970-01-01)) instead picked
    # an arbitrary bar from later days in that case. The WHERE clause already
    # bounds timestamps by week_end, so a plain argMax gives the week close.
    results = ch_client.execute("""
        SELECT code,
               -- start price: first close at or after the start point
               argMin(close, timestamp) AS start_price,
               -- same-day close: last close at or before that day's close
               argMaxIf(close, timestamp, timestamp <= %(end)s) AS day_close_price,
               -- week close: last close up to the 7th trading day's close
               argMax(close, timestamp) AS week_close_price
        FROM stock_minute
        WHERE code IN %(codes)s
          AND timestamp >= %(start)s
          AND timestamp <= %(week_end)s
        GROUP BY code
        HAVING start_price > 0
    """, {
        'codes': tuple(stock_codes),
        'start': start_datetime,
        'end': end_datetime,
        'week_end': week_end_datetime
    })

    if debug:
        print(f" 查询到 {len(results)} 只股票的数据")
    if not results:
        return None, None, None

    day_changes = []
    week_changes = []
    for _code, start_price, day_close, week_close in results:
        if not start_price or start_price <= 0:
            continue
        # Day change: from the start point to that trading day's close.
        if day_close and day_close > 0:
            day_changes.append((day_close - start_price) / start_price * 100)
        # Week change: from the start point to the 7th trading day's close.
        if week_close and week_close > 0:
            week_changes.append((week_close - start_price) / start_price * 100)

    avg_change = sum(day_changes) / len(day_changes) if day_changes else None
    max_change = max(day_changes) if day_changes else None
    avg_week_change = sum(week_changes) / len(week_changes) if week_changes else None

    if debug:
        # Check for None explicitly: a 0.0% average is a valid result, and any
        # single statistic may be None independently of the others.
        if avg_change is None and avg_week_change is None:
            print(" 结果: 无有效数据")
        else:
            print(f" 结果: 日均={_pct(avg_change)} 日最大={_pct(max_change)} 周均={_pct(avg_week_change)}")

    return avg_change, max_change, avg_week_change
def update_event_statistics(start_date=None, end_date=None, force_update=False, debug_mode=False):
    """Recompute and store related-stock price changes for events.

    Loads events joined with their related stocks from MySQL and computes the
    statistics for each event with calculate_stock_changes (ClickHouse). The
    results are written back to ``event.related_avg_chg``,
    ``related_max_chg`` and ``related_week_chg`` in one batched UPDATE,
    followed by an explicit commit.

    Args:
        start_date: only events with created_at >= start_date (optional).
        end_date: only events with created_at <= end_date (optional).
        force_update: if False, only events whose related_avg_chg or
            related_max_chg is NULL are selected; if True, every matching
            event is recomputed.
        debug_mode: print per-event details for the first three events and
            full tracebacks for per-event errors.

    Raises:
        Any exception from connecting, querying or writing is printed and
        re-raised. An error for a single event is printed, and that event is
        skipped.
    """
    def _pct(value):
        # Format a stored percentage for debug output; NULL columns arrive as None.
        return "None" if value is None else f"{value:.2f}%"

    mysql_engine = None
    ch_client = None
    try:
        print("[DEBUG] 开始 update_event_statistics")
        print(f"[DEBUG] 参数: start_date={start_date}, end_date={end_date}, force_update={force_update}")
        mysql_engine = get_mysql_engine()
        print("[DEBUG] MySQL 引擎创建成功")
        ch_client = get_clickhouse_client()
        print("[DEBUG] ClickHouse 客户端创建成功")
        with mysql_engine.connect() as mysql_conn:
            print("[DEBUG] MySQL 连接已建立")
            # MySQL cuts GROUP_CONCAT output at group_concat_max_len (1024
            # bytes by default) with only a warning. That would silently drop
            # or truncate stock codes for events with many related stocks.
            mysql_conn.execute(text("SET SESSION group_concat_max_len = 1048576"))

            # Build the query from fixed SQL fragments; all values are bound
            # parameters.
            query = """
            SELECT e.id, \
            e.created_at, \
            GROUP_CONCAT(rs.stock_code) as stock_codes,
            e.related_avg_chg, \
            e.related_max_chg, \
            e.related_week_chg
            FROM event e
            JOIN related_stock rs ON e.id = rs.event_id \
            """
            conditions = []
            params = {}
            if start_date:
                conditions.append("e.created_at >= :start_date")
                params["start_date"] = start_date
            if end_date:
                conditions.append("e.created_at <= :end_date")
                params["end_date"] = end_date
            if not force_update:
                # Only events that are still missing statistics.
                conditions.append("(e.related_avg_chg IS NULL OR e.related_max_chg IS NULL)")
            if conditions:
                query += " WHERE " + " AND ".join(conditions)
            query += """
            GROUP BY e.id, e.created_at, e.related_avg_chg, e.related_max_chg, e.related_week_chg
            ORDER BY e.created_at DESC
            """
            print(f"[DEBUG] 执行查询SQL:\n{query}")
            print(f"[DEBUG] 查询参数: {params}")
            events = mysql_conn.execute(text(query), params).fetchall()
            print(f"[DEBUG] 查询返回 {len(events)} 条事件记录")
            print(f"Found {len(events)} events to update (force_update={force_update})")
            if debug_mode and len(events) > 0:
                print(f"Date range: {events[-1][1]} to {events[0][1]}")

            # Collect parameter dicts for one executemany-style UPDATE.
            update_data = []
            for idx, event in enumerate(events, 1):
                try:
                    event_id = event[0]
                    created_at = event[1]
                    # Drop whitespace and empty entries from the concatenated code list.
                    stock_codes = [c.strip() for c in event[2].split(',') if c.strip()] if event[2] else []
                    existing_avg = event[3]
                    existing_max = event[4]
                    existing_week = event[5]
                    if not stock_codes:
                        continue
                    if debug_mode and idx <= 3:  # detailed debug output for the first 3 events only
                        print(f"\n[Event {event_id}] created_at={created_at}")
                        if not force_update and existing_avg is not None:
                            # In non-force mode the query only returns rows where avg or max
                            # is NULL, so existing_max (and possibly existing_week) is
                            # None here. Formatting them with :.2f raised TypeError and
                            # skipped the event.
                            print(
                                f" 已有数据: avg={_pct(existing_avg)} max={_pct(existing_max)} week={_pct(existing_week)}")
                    avg_change, max_change, week_change = calculate_stock_changes(
                        stock_codes, created_at, ch_client, debug=(debug_mode and idx <= 3)
                    )
                    if any(x is not None for x in (avg_change, max_change, week_change)):
                        update_data.append({
                            "avg_chg": avg_change,
                            "max_chg": max_change,
                            "week_chg": week_change,
                            "event_id": event_id
                        })
                        if idx <= 5:  # show details for the first 5 events
                            print(f"[DEBUG] 事件 {event_id}: avg={avg_change}, max={max_change}, week={week_change}")
                    else:
                        if idx <= 5:
                            print(f"[DEBUG] 事件 {event_id}: 计算结果全为None跳过")
                    # Progress report every 10 events.
                    if idx % 10 == 0:
                        print(f"Processed {idx}/{len(events)} events...")
                except Exception as e:
                    # Best-effort per event: report and continue with the next one.
                    print(f"Error processing event {event[0]}: {str(e)}")
                    if debug_mode:
                        import traceback
                        traceback.print_exc()
                    continue

            print(f"\n[DEBUG] ====== 准备写入数据库 ======")
            print(f"[DEBUG] update_data 长度: {len(update_data)}")
            if update_data:
                print(f"[DEBUG] 前3条待更新数据: {update_data[:3]}")
                print(f"[DEBUG] 执行 UPDATE 语句...")
                result = mysql_conn.execute(text("""
                    UPDATE event
                    SET related_avg_chg = :avg_chg,
                        related_max_chg = :max_chg,
                        related_week_chg = :week_chg
                    WHERE id = :event_id
                """), update_data)
                print(f"[DEBUG] UPDATE 执行完成, rowcount={result.rowcount}")
                # Connection.execute auto-begins a transaction, so it must be
                # committed explicitly or the UPDATE is rolled back on close.
                print("[DEBUG] 准备提交事务 (commit)...")
                mysql_conn.commit()
                print("[DEBUG] 事务已提交!")
                print(f"Successfully updated {len(update_data)} events")
            else:
                print("[DEBUG] update_data 为空,没有数据需要更新!")
    except Exception as e:
        print(f"Error in update_event_statistics: {str(e)}")
        raise
    finally:
        # run_monitor calls this every 2 minutes. Release the ClickHouse
        # connection and the engine's pool now instead of leaving them to
        # garbage collection.
        if ch_client is not None:
            ch_client.disconnect()
        if mysql_engine is not None:
            mysql_engine.dispose()
def run_monitor():
    """Loop forever, recomputing the last 7 days of events during trading sessions.

    Inside a session: call update_event_statistics(start_date=now - 7 days,
    force_update=True), then sleep 120 s.

    Outside a session: report the next session start and poll
    is_trading_time() in sleeps of at most 300 s until it becomes true.

    Ctrl+C ends the loop. Any other error is printed with a traceback and the
    loop retries after 60 s.
    """
    separator = "=" * 60
    for line in (
        separator,
        "启动交易时段监控模式",
        "运行规则: 仅在交易日的9:30-11:30和13:00-15:00运行",
        "更新频率: 每2分钟一次",
        "更新模式: 强制更新(force_update=True)",
        "更新范围: 最近7天的事件数据",
        separator,
    ):
        print(line)

    while True:
        try:
            now = datetime.now()
            stamp = now.strftime('%Y-%m-%d %H:%M:%S')

            if is_trading_time(now):
                print(f"\n{separator}")
                print(f"[{stamp}] 交易时段 - 开始更新...")
                print(separator)
                update_event_statistics(
                    start_date=now - timedelta(days=7),
                    force_update=True,  # recompute everything in the window
                    debug_mode=False
                )
                print(f"\n[{stamp}] 更新完成")
                print("等待2分钟后执行下次更新...\n")
                time.sleep(120)
                continue

            # Outside trading hours: report when the next session starts.
            next_open = get_next_trading_time()
            remaining = (next_open - now).total_seconds()
            print(f"\n{separator}")
            print(f"[{stamp}] 非交易时段")
            print(f"下次交易时间: {next_open.strftime('%Y-%m-%d %H:%M:%S')}")
            print(f"等待时长: {int(remaining / 60)} 分钟")
            print(f"{separator}\n")

            # Sleep in chunks of at most 5 minutes so the process keeps
            # re-checking the clock rather than blocking for hours.
            poll_seconds = 300
            while not is_trading_time():
                time.sleep(min(poll_seconds, max(1, remaining)))
                remaining = (get_next_trading_time() - datetime.now()).total_seconds()
                if remaining <= 0:
                    break
        except KeyboardInterrupt:
            print("\n程序被用户中断")
            break
        except Exception as e:
            print(f"Error in monitor loop: {str(e)}")
            import traceback
            traceback.print_exc()
            print("等待1分钟后重试...")
            time.sleep(60)
if __name__ == "__main__":
    import sys

    # Command-line modes:
    #   python get_related_chg.py --test   # test mode: recent events only, debug output on
    #   python get_related_chg.py --once   # one forced recompute of the last 7 days
    #   python get_related_chg.py          # monitor mode: forced update every 2 minutes in trading hours
    mode = sys.argv[1] if len(sys.argv) > 1 else None
    banner = "=" * 60

    if mode is None:
        # Default: monitoring loop that only works during trading sessions.
        run_monitor()
    elif mode == '--test':
        print(banner)
        print("测试模式:更新昨天和今天的数据")
        print(banner)
        # NOTE(review): despite the variable name and the banner text, this
        # window starts at 15:00 two days ago, not yesterday. Confirm intent.
        yesterday = (datetime.now() - timedelta(days=2)).replace(hour=15, minute=0, second=0)
        tomorrow = datetime.now() + timedelta(days=1)
        update_event_statistics(
            start_date=yesterday,
            end_date=tomorrow,
            force_update=True,
            debug_mode=True
        )
        print("\n测试完成!")
    elif mode == '--once':
        print(banner)
        print("单次强制更新模式重新计算最近7天所有数据")
        print(banner)
        update_event_statistics(
            start_date=datetime.now() - timedelta(days=7),
            force_update=True,
            debug_mode=False
        )
        print("\n强制更新完成!")
    else:
        for usage_line in (
            "未知参数。支持的参数:",
            " --test : 测试模式(更新昨天和今天,开启调试)",
            " --once : 单次强制更新最近7天",
            " (无参数): 交易时段监控模式(每2分钟强制更新)",
        ):
            print(usage_line)