# 462 lines | 18 KiB | Python
import os
import time
import traceback
from datetime import datetime, time as dt_time, timedelta

import pandas as pd
from clickhouse_driver import Client as Cclient
from sqlalchemy import create_engine, text

# Load the trading-day calendar from the tdays.csv file stored next to this script.
script_dir = os.path.dirname(os.path.abspath(__file__))
TRADING_DAYS_FILE = os.path.join(script_dir, 'tdays.csv')

trading_days_df = pd.read_csv(TRADING_DAYS_FILE)
# Reduce every 'DateTime' entry to its calendar date.
trading_days_df['DateTime'] = pd.to_datetime(trading_days_df['DateTime']).dt.date

# Trading dates in ascending order.
TRADING_DAYS = trading_days_df['DateTime'].tolist()
TRADING_DAYS.sort()


def get_clickhouse_client():
    """Create a ClickHouse client for the local ``stock`` database.

    Returns:
        A ``clickhouse_driver.Client`` built from the fixed settings below.
    """
    # NOTE(review): credentials are hard-coded in source; consider moving
    # them to configuration or environment variables.
    connection_settings = {
        'host': '127.0.0.1',
        'port': 9000,
        'user': 'default',
        'password': 'Zzl33818!',
        'database': 'stock',
    }
    return Cclient(**connection_settings)


def get_mysql_engine():
    """Create a SQLAlchemy engine for the local MySQL ``stock`` database.

    Returns:
        The engine produced by ``create_engine`` with SQL echoing disabled.
    """
    # NOTE(review): credentials are hard-coded in source; consider moving
    # them to configuration or environment variables.
    connection_url = "mysql+pymysql://root:Zzl33818!@127.0.0.1:3306/stock"
    engine = create_engine(connection_url, echo=False)
    return engine


def is_trading_time(check_datetime=None):
    """Tell whether a moment falls inside a trading session.

    Args:
        check_datetime: The datetime to test; ``datetime.now()`` is used
            when omitted.

    Returns:
        bool: True when the date is listed in ``TRADING_DAYS`` and the time
        lies within 09:30-11:30 or 13:00-15:00 (both ends inclusive).
    """
    moment = datetime.now() if check_datetime is None else check_datetime

    # A non-trading day never counts, whatever the time of day.
    if moment.date() not in TRADING_DAYS:
        return False

    clock = moment.time()
    # (open, close) pairs for the morning and afternoon sessions.
    sessions = (
        (dt_time(9, 30), dt_time(11, 30)),
        (dt_time(13, 0), dt_time(15, 0)),
    )
    return any(opens <= clock <= closes for opens, closes in sessions)


def get_next_trading_time():
    """Return the datetime at which the next trading session opens.

    On a trading day this is today's 09:30 (before the morning open) or
    today's 13:00 (during the lunch break). Otherwise it is 09:30 on the
    first date in ``TRADING_DAYS`` after today, or 09:30 tomorrow when the
    calendar holds no later date (the calendar file may need refreshing).
    """
    now = datetime.now()
    today = now.date()
    clock = now.time()
    morning_open = dt_time(9, 30)
    afternoon_open = dt_time(13, 0)

    if today in TRADING_DAYS:
        # Before the morning open.
        if clock < morning_open:
            return datetime.combine(today, morning_open)
        # Between the morning close and the afternoon open.
        if dt_time(11, 30) < clock < afternoon_open:
            return datetime.combine(today, afternoon_open)

    # Otherwise use the morning open of the next listed trading day.
    upcoming = next((day for day in TRADING_DAYS if day > today), None)
    if upcoming is None:
        # Calendar exhausted: fall back to tomorrow.
        upcoming = today + timedelta(days=1)
    return datetime.combine(upcoming, morning_open)


def get_next_trading_day(date):
    """Return the first date in ``TRADING_DAYS`` strictly after ``date``.

    Returns None when the calendar holds no later date.
    """
    return next((day for day in TRADING_DAYS if day > date), None)


def get_nth_trading_day_after(start_date, n=7):
    """Return the n-th trading day after ``start_date``.

    When ``start_date`` is itself a trading day the result is found by
    offsetting its position in ``TRADING_DAYS``. Otherwise (or when that
    offset runs past the calendar) the later trading days are counted
    directly; if fewer than ``n`` remain, the last available one is
    returned, and None when there is none at all.
    """
    try:
        position = TRADING_DAYS.index(start_date) + n
        if position < len(TRADING_DAYS):
            return TRADING_DAYS[position]
    except (ValueError, IndexError):
        # start_date is not a listed trading day; use the scan below.
        pass

    later_days = [day for day in TRADING_DAYS if day > start_date]
    if len(later_days) >= n:
        return later_days[n - 1]
    # Not enough future days: use the furthest one known, if any.
    return later_days[-1] if later_days else None


def get_trading_day_info(event_datetime):
    """Return the trading day an event is attributed to.

    An event on a trading day at or before 15:00 belongs to that day;
    any other event belongs to the next trading day (None if the calendar
    has no later date).
    """
    event_date = event_datetime.date()
    at_or_before_close = event_datetime.time() <= dt_time(15, 0)

    if event_date in TRADING_DAYS and at_or_before_close:
        return event_date
    return get_next_trading_day(event_date)


def calculate_stock_changes(stock_codes, event_datetime, ch_client, debug=False):
    """Compute aggregate price changes for all stocks related to one event.

    The starting price is the first ``close`` at or after the event when the
    event falls inside 09:30-15:00 on a trading day; otherwise it is the
    first ``close`` at or after 09:30 on the trading day returned by
    ``get_trading_day_info``.

    Args:
        stock_codes: Stock codes linked to the event.
        event_datetime: Datetime at which the event was created.
        ch_client: ClickHouse client used to query the ``stock_minute`` table.
        debug: When True, print the query window and the results.

    Returns:
        tuple: ``(avg_change, max_change, avg_week_change)`` in percent.
        ``avg_change`` / ``max_change`` are the mean / maximum change from
        the starting price to the same-day close; ``avg_week_change`` is the
        mean change up to the close of the 7th following trading day. An
        entry is None when no stock yielded a usable value, and all three
        are None when there is nothing to compute.
    """
    if not stock_codes:
        return None, None, None

    event_date = event_datetime.date()
    event_time = event_datetime.time()
    market_open = dt_time(9, 30)
    market_close = dt_time(15, 0)

    # Pick the starting point: the first valid price at/after the event.
    if event_date in TRADING_DAYS and market_open <= event_time <= market_close:
        # Event happened during trading hours -> start from the event itself.
        start_datetime = event_datetime
        trading_date = event_date
        end_datetime = datetime.combine(trading_date, market_close)
        if debug:
            print(f" 事件在交易时间内: {event_datetime} -> 起点={start_datetime}")
    else:
        # Event happened outside trading hours -> start from the open of the
        # trading day it is attributed to.
        trading_date = get_trading_day_info(event_datetime)
        if not trading_date:
            if debug:
                print(f" 找不到交易日: {event_datetime}")
            return None, None, None
        start_datetime = datetime.combine(trading_date, market_open)
        end_datetime = datetime.combine(trading_date, market_close)
        if debug:
            print(f" 事件在非交易时间: {event_datetime} -> 下一交易日={trading_date}, 起点={start_datetime}")

    # Date of the 7th trading day after the attributed trading day.
    week_trading_date = get_nth_trading_day_after(trading_date, 7)
    if not week_trading_date:
        # Fallback when the calendar has no future trading day at all.
        week_trading_date = trading_date + timedelta(days=10)

    week_end_datetime = datetime.combine(week_trading_date, market_close)

    if debug:
        print(f" 查询范围: {start_datetime} -> 当日={end_datetime}, 周末={week_end_datetime}")
        print(f" 股票代码: {stock_codes}")

    # One query for all codes. day_close_price uses argMaxIf so that a code
    # with no bar on or before `end` yields the type default instead of the
    # close of an arbitrary later bar; such values are filtered out below by
    # the `> 0` check.
    results = ch_client.execute("""
        SELECT code,
               -- 起始价格:事件发生时或之后的第一个价格
               argMin(close, timestamp) as start_price,
               -- 当日收盘价:当日交易结束时的最后一个价格
               argMaxIf(close, timestamp, timestamp <= %(end)s) as day_close_price,
               -- 周后收盘价:7个交易日后的收盘价
               argMax(
                   close, if(timestamp <= %(week_end)s, timestamp, toDateTime('1970-01-01'))
               ) as week_close_price
        FROM stock_minute
        WHERE code IN %(codes)s
          AND timestamp >= %(start)s
          AND timestamp <= %(week_end)s
        GROUP BY code
        HAVING start_price > 0
        """, {
        'codes': tuple(stock_codes),
        'start': start_datetime,
        'end': end_datetime,
        'week_end': week_end_datetime
    })

    if debug:
        print(f" 查询到 {len(results)} 只股票的数据")

    if not results:
        return None, None, None

    # Per-stock percentage changes relative to the starting price.
    day_changes = []
    week_changes = []

    for code, start_price, day_close, week_close in results:
        if not (start_price and start_price > 0):
            continue
        # Same-day change (event -> that day's close).
        if day_close and day_close > 0:
            day_changes.append((day_close - start_price) / start_price * 100)
        # Weekly change (event -> close of the 7th trading day).
        if week_close and week_close > 0:
            week_changes.append((week_close - start_price) / start_price * 100)

    avg_change = sum(day_changes) / len(day_changes) if day_changes else None
    max_change = max(day_changes) if day_changes else None
    avg_week_change = sum(week_changes) / len(week_changes) if week_changes else None

    if debug:
        # Format each value on its own: any of them may be None, and 0.0 is
        # a valid result that must not be reported as "no data".
        if avg_change is None and max_change is None and avg_week_change is None:
            print(" 结果: 无有效数据")
        else:
            print(
                f" 结果: 日均={_format_pct(avg_change)} "
                f"日最大={_format_pct(max_change)} "
                f"周均={_format_pct(avg_week_change)}"
            )

    return avg_change, max_change, avg_week_change


def _format_pct(value):
    """Render a percentage with two decimals, or 'N/A' when it is None."""
    return "N/A" if value is None else f"{value:.2f}%"


def update_event_statistics(start_date=None, end_date=None, force_update=False, debug_mode=False):
    """Recompute and store the related-stock change statistics of events.

    Events are read from MySQL (``event`` joined with ``related_stock``),
    their changes are computed through ``calculate_stock_changes`` and the
    results are written back to ``event.related_avg_chg`` /
    ``related_max_chg`` / ``related_week_chg`` in one batch.

    Args:
        start_date: Only events with ``created_at >= start_date`` (optional).
        end_date: Only events with ``created_at <= end_date`` (optional).
        force_update: When False, only events whose average or maximum
            change is still NULL are processed.
        debug_mode: When True, print details for the first 3 events and
            tracebacks for per-event failures.

    Raises:
        Exception: Any error outside the per-event loop is printed and
            re-raised. Errors for a single event are printed and skipped.
    """

    def _fmt_pct(value):
        # Stored statistics may be NULL; never feed None to a ':.2f' format.
        return "N/A" if value is None else f"{value:.2f}%"

    mysql_engine = None
    ch_client = None
    try:
        mysql_engine = get_mysql_engine()
        ch_client = get_clickhouse_client()

        # Build the SELECT.
        # NOTE(review): GROUP_CONCAT output is capped by MySQL's
        # group_concat_max_len (1024 by default) - confirm this is large
        # enough for events with many related stocks.
        query = """
            SELECT e.id,
                   e.created_at,
                   GROUP_CONCAT(rs.stock_code) as stock_codes,
                   e.related_avg_chg,
                   e.related_max_chg,
                   e.related_week_chg
            FROM event e
                     JOIN related_stock rs ON e.id = rs.event_id
            """

        conditions = []
        params = {}

        if start_date:
            conditions.append("e.created_at >= :start_date")
            params["start_date"] = start_date

        if end_date:
            conditions.append("e.created_at <= :end_date")
            params["end_date"] = end_date

        if not force_update:
            # Only pick up events that have no statistics yet.
            conditions.append("(e.related_avg_chg IS NULL OR e.related_max_chg IS NULL)")

        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        query += """
            GROUP BY e.id, e.created_at, e.related_avg_chg, e.related_max_chg, e.related_week_chg
            ORDER BY e.created_at DESC
            """

        # Read phase: the connection is released before the (slow)
        # ClickHouse processing starts.
        with mysql_engine.connect() as mysql_conn:
            events = mysql_conn.execute(text(query), params).fetchall()

        print(f"Found {len(events)} events to update (force_update={force_update})")
        if debug_mode and len(events) > 0:
            print(f"Date range: {events[-1][1]} to {events[0][1]}")

        # Rows for the batched UPDATE.
        update_data = []

        for idx, event in enumerate(events, 1):
            try:
                event_id = event[0]
                created_at = event[1]
                stock_codes = event[2].split(',') if event[2] else []
                existing_avg = event[3]
                existing_max = event[4]
                existing_week = event[5]

                if not stock_codes:
                    continue

                verbose = debug_mode and idx <= 3  # only trace the first 3 events
                if verbose:
                    print(f"\n[Event {event_id}] created_at={created_at}")
                    if not force_update and existing_avg is not None:
                        print(
                            f" 已有数据: avg={_fmt_pct(existing_avg)} "
                            f"max={_fmt_pct(existing_max)} "
                            f"week={_fmt_pct(existing_week)}"
                        )

                # Compute the changes of all stocks of this event in one go.
                avg_change, max_change, week_change = calculate_stock_changes(
                    stock_codes, created_at, ch_client, debug=verbose
                )

                if any(x is not None for x in (avg_change, max_change, week_change)):
                    update_data.append({
                        "avg_chg": avg_change,
                        "max_chg": max_change,
                        "week_chg": week_change,
                        "event_id": event_id
                    })

                # Progress report every 10 events.
                if idx % 10 == 0:
                    print(f"Processed {idx}/{len(events)} events...")

            except Exception as e:
                # Best effort: one bad event must not abort the whole batch.
                print(f"Error processing event {event[0]}: {str(e)}")
                if debug_mode:
                    traceback.print_exc()
                continue

        # Write phase: engine.begin() commits on success and rolls back on
        # error, so the UPDATE does not depend on implicit autocommit.
        if update_data:
            with mysql_engine.begin() as mysql_conn:
                mysql_conn.execute(text("""
                    UPDATE event
                    SET related_avg_chg = :avg_chg,
                        related_max_chg = :max_chg,
                        related_week_chg = :week_chg
                    WHERE id = :event_id
                    """), update_data)
            print(f"Successfully updated {len(update_data)} events")

    except Exception as e:
        print(f"Error in update_event_statistics: {str(e)}")
        raise
    finally:
        # This function is called repeatedly by the monitor loop; release
        # the connections created for this run.
        if ch_client is not None:
            ch_client.disconnect()
        if mysql_engine is not None:
            mysql_engine.dispose()


def run_monitor():
    """Run the monitoring loop until interrupted.

    During trading sessions the statistics of events from the last 7 days
    are force-refreshed every 2 minutes. Outside trading sessions the loop
    sleeps, re-checking at most every 5 minutes, until the next session.
    Unexpected errors are printed and retried after 1 minute; Ctrl-C stops
    the loop.
    """
    banner = "=" * 60
    print(banner)
    print("启动交易时段监控模式")
    print("运行规则: 仅在交易日的9:30-11:30和13:00-15:00运行")
    print("更新频率: 每2分钟一次")
    print("更新模式: 强制更新(force_update=True)")
    print("更新范围: 最近7天的事件数据")
    print(banner)

    while True:
        try:
            now = datetime.now()
            stamp = now.strftime('%Y-%m-%d %H:%M:%S')

            if not is_trading_time(now):
                # Outside trading hours: work out when the next session opens.
                next_trading_time = get_next_trading_time()
                wait_seconds = (next_trading_time - now).total_seconds()
                wait_minutes = int(wait_seconds / 60)

                print(f"\n{banner}")
                print(f"[{stamp}] 非交易时段")
                print(f"下次交易时间: {next_trading_time.strftime('%Y-%m-%d %H:%M:%S')}")
                print(f"等待时长: {wait_minutes} 分钟")
                print(f"{banner}\n")

                # Sleep towards the next session in slices of at most
                # 5 minutes so the process never blocks for hours at once.
                poll_interval = 300
                while not is_trading_time():
                    time.sleep(min(poll_interval, max(1, wait_seconds)))
                    wait_seconds = (get_next_trading_time() - datetime.now()).total_seconds()
                    if wait_seconds <= 0:
                        break
                continue

            # Inside a trading session: refresh the last 7 days.
            print(f"\n{banner}")
            print(f"[{stamp}] 交易时段 - 开始更新...")
            print(f"{banner}")

            update_event_statistics(
                start_date=now - timedelta(days=7),
                force_update=True,  # recompute everything in range
                debug_mode=False
            )

            print(f"\n[{stamp}] 更新完成")
            print(f"等待2分钟后执行下次更新...\n")
            time.sleep(120)  # 2 minutes

        except KeyboardInterrupt:
            print("\n程序被用户中断")
            break
        except Exception as e:
            print(f"Error in monitor loop: {str(e)}")
            import traceback
            traceback.print_exc()
            print("等待1分钟后重试...")
            time.sleep(60)  # retry one minute after a failure


if __name__ == "__main__":
    import sys

    # Command-line usage:
    #   python get_related_chg.py --test   # test mode: small recent window, debug output on
    #   python get_related_chg.py --once   # one forced update of the last 7 days
    #   python get_related_chg.py          # normal run: forced update every 2 minutes in trading hours
    cli_args = sys.argv[1:]
    banner = "=" * 60

    if not cli_args:
        # Normal monitoring mode: only active during trading sessions.
        run_monitor()
    elif cli_args[0] == '--test':
        # Test mode: forced update of a short recent window with debugging.
        print(banner)
        print("测试模式:更新昨天和今天的数据")
        print(banner)
        # Window runs from 15:00 two days ago until this time tomorrow.
        window_start = (datetime.now() - timedelta(days=2)).replace(hour=15, minute=0, second=0)
        window_end = datetime.now() + timedelta(days=1)
        update_event_statistics(
            start_date=window_start,
            end_date=window_end,
            force_update=True,
            debug_mode=True
        )
        print("\n测试完成!")
    elif cli_args[0] == '--once':
        # Single forced update of the last 7 days.
        print(banner)
        print("单次强制更新模式:重新计算最近7天所有数据")
        print(banner)
        update_event_statistics(
            start_date=datetime.now() - timedelta(days=7),
            force_update=True,
            debug_mode=False
        )
        print("\n强制更新完成!")
    else:
        print("未知参数。支持的参数:")
        print(" --test : 测试模式(更新昨天和今天,开启调试)")
        print(" --once : 单次强制更新最近7天")
        print(" (无参数): 交易时段监控模式(每2分钟强制更新)")