# ============ Eventlet/Gevent monkey-patching check (must stay above all other imports!) ============
# Supports running under Gunicorn with eventlet/gevent workers, where the worker's
# monkey patching turns blocking calls (e.g. `requests`) into cooperative ones.
import os
import sys


def _detect_async_env():
    """Report which async library is installed and whether it has patched ``socket``.

    eventlet is probed first, then gevent; the first one found decides the result.

    Returns:
        str: ``'eventlet_patched'``, ``'eventlet_available'``, ``'gevent_patched'``,
        ``'gevent_available'`` or ``'none'``.
    """
    # eventlet
    try:
        import eventlet
        # hasattr guard: in case the installed eventlet does not expose this helper at top level.
        if hasattr(eventlet, 'is_monkey_patched') and eventlet.is_monkey_patched('socket'):
            return 'eventlet_patched'
        return 'eventlet_available'
    except ImportError:
        pass

    # gevent
    try:
        from gevent import monkey
        if monkey.is_module_patched('socket'):
            return 'gevent_patched'
        return 'gevent_available'
    except ImportError:
        pass

    return 'none'


_async_env = _detect_async_env()

# The Gunicorn eventlet/gevent worker does the patching itself; only report the status here.
if _async_env == 'eventlet_patched':
    print("✅ Eventlet monkey patching 已由 Worker 启用")
elif _async_env == 'gevent_patched':
    print("✅ Gevent monkey patching 已由 Worker 启用")
elif _async_env == 'eventlet_available':
    print("📡 Eventlet 可用,等待 Gunicorn worker 初始化")
elif _async_env == 'gevent_available':
    print("📡 Gevent 可用,等待 Gunicorn worker 初始化")
else:
    print("⚠️ 未检测到 eventlet 或 gevent,将使用 threading 模式")
# ============ End of monkey-patching check ============

import base64
import bisect
import csv
import hashlib
import io
import threading
import time
import urllib
import uuid
from functools import wraps
import qrcode
from flask_mail import Mail, Message
from flask_socketio import SocketIO, emit, join_room, leave_room
import pytz
import requests
from celery import Celery
from flask_compress import Compress
from pathlib import Path
import json
from sqlalchemy import Column, Integer, String, Boolean, DateTime, create_engine, text, func, or_
from flask import Flask, render_template, request, jsonify, redirect, url_for, flash, session, render_template_string, \
    current_app, make_response
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user
import random
from werkzeug.security import generate_password_hash, check_password_hash
from werkzeug.middleware.proxy_fix import ProxyFix
import re
import string
from datetime import datetime, timedelta, time as dt_time, date
from clickhouse_driver import Client as Cclient
from elasticsearch import Elasticsearch
from flask_cors import CORS
import redis
from flask_session import Session
from collections import defaultdict
from functools import lru_cache
import jieba
import jieba.analyse
from flask_cors import cross_origin
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.sms.v20210111 import sms_client, models
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from sqlalchemy import text, desc, and_
import pandas as pd
from decimal import Decimal
from apscheduler.schedulers.background import BackgroundScheduler

# Trading-calendar cache (filled by load_trading_days).
trading_days = []  # sorted ascending list of datetime.date
trading_days_set = set()  # same dates, for O(1) membership tests


def load_trading_days():
    """Load the trading calendar from ``tdays.csv`` into the module-level caches.

    The CSV is read with ``csv.DictReader``; each row's ``DateTime`` column is
    parsed with the format ``%Y/%m/%d`` (e.g. ``2010/1/4``). The parsed dates are
    sorted into ``trading_days`` and mirrored into ``trading_days_set``.

    The new calendar is built completely before it replaces the cached one, so a
    repeated call does not append duplicates and a parse error part-way through
    the file leaves the previous cache untouched. Errors are printed, not raised.
    """
    try:
        with open('tdays.csv', 'r') as f:
            reader = csv.DictReader(f)
            loaded = [datetime.strptime(row['DateTime'], '%Y/%m/%d').date() for row in reader]
        loaded.sort()
        # Mutate the existing containers instead of rebinding the globals, so any
        # other reference to these objects sees the refreshed data.
        trading_days[:] = loaded
        trading_days_set.clear()
        trading_days_set.update(loaded)
        print(f"成功加载 {len(trading_days)} 个交易日数据")
    except Exception as e:
        print(f"加载交易日数据失败: {e}")


def row_to_dict(row):
    """Convert a SQLAlchemy ``Row`` to a plain dict (SQLAlchemy 1.4+).

    Returns ``None`` when ``row`` is ``None``.
    """
    if row is None:
        return None
    # Column values are exposed through the Row's ``_mapping`` attribute.
    return dict(row._mapping)


def _previous_trading_day(day):
    """Return the last trading day strictly before ``day``, or ``None`` if there is none.

    Uses binary search on ``trading_days``, which ``load_trading_days`` keeps sorted.
    Does not load the calendar itself; callers are responsible for that.
    """
    idx = bisect.bisect_left(trading_days, day)
    return trading_days[idx - 1] if idx > 0 else None


def get_trading_day_near_date(target_date):
    """Return the trading day on or after ``target_date``.

    If ``target_date`` is itself a trading day it is returned unchanged; otherwise
    the next trading day is returned. If ``target_date`` is past the end of the
    calendar, the last known trading day is returned.

    Args:
        target_date: ``datetime.date`` or ``datetime.datetime`` (the time part is dropped).

    Returns:
        datetime.date or None: ``None`` when the calendar could not be loaded.
    """
    if not trading_days:
        load_trading_days()
    if not trading_days:
        return None

    if isinstance(target_date, datetime):
        target_date = target_date.date()

    if target_date in trading_days_set:
        return target_date

    # First calendar entry >= target_date (trading_days is sorted).
    idx = bisect.bisect_left(trading_days, target_date)
    if idx < len(trading_days):
        return trading_days[idx]
    # Beyond the end of the calendar: fall back to the last known trading day.
    return trading_days[-1]


def get_target_and_prev_trading_day(event_datetime):
    """Pick the trading day to chart for an event and the day whose close is the change baseline.

    Rules (market close is 15:00; an event at exactly 15:00 counts as intraday):
    - Trading day, at or before 15:00 (including pre-open): chart that day; the
      baseline is the previous trading day.
    - Trading day, after 15:00: chart the next trading day; the baseline is the
      event day itself.
    - Non-trading day (weekend/holiday): chart the next trading day; the baseline
      is the last trading day before the event.
    So an event between Friday 15:00 and Monday 15:00 charts Monday against
    Friday's close.

    Args:
        event_datetime: ``datetime.datetime``, or a ``datetime.date``, which is
            treated as 12:00 (intraday).

    Returns:
        tuple: ``(target_date, prev_close_date)``. Either element may be ``None``.
        Both are ``None`` when the calendar could not be loaded.
    """
    if not trading_days:
        load_trading_days()
    if not trading_days:
        return None, None

    if isinstance(event_datetime, datetime):
        event_date = event_datetime.date()
        event_time = event_datetime.time()
    else:
        event_date = event_datetime
        event_time = dt_time(12, 0)  # default to noon, i.e. intraday

    is_trading_day = event_date in trading_days_set

    market_close_time = dt_time(15, 0)
    is_after_market = event_time > market_close_time

    if is_trading_day:
        if is_after_market:
            # After the close: chart the next trading day against today's close.
            target_date = get_trading_day_near_date(event_date + timedelta(days=1))
            prev_close_date = event_date
        else:
            # Intraday or pre-open: chart today against the previous trading day.
            target_date = event_date
            prev_close_date = _previous_trading_day(event_date)
    else:
        # Weekend/holiday: chart the next trading day against the last one before the event.
        target_date = get_trading_day_near_date(event_date)
        prev_close_date = _previous_trading_day(event_date)

    return target_date, prev_close_date


# Load the trading calendar once at import time.
load_trading_days()


def is_trading_hours():
    """Return True if "now" falls within 09:00-15:00 (inclusive) on a trading day.

    The lunch break is deliberately included because events may be published then.
    NOTE(review): uses ``datetime.now()``, i.e. the server's local time zone;
    this assumes the server runs on Asia/Shanghai time — confirm.
    """
    now = datetime.now()
    today = now.date()
    current_time = now.time()

    if today not in trading_days_set:
        return False

    market_open = dt_time(9, 0)
    market_close = dt_time(15, 0)
    return market_open <= current_time <= market_close


engine = create_engine(
    "mysql+pymysql://root:Zzl33818!@127.0.0.1:3306/stock?charset=utf8mb4",
    echo=False,
    pool_size=50,  # persistent connections per worker
    pool_recycle=1800,  # recycle connections after 30 minutes (previously 1 hour)
    pool_pre_ping=True,  # check that a connection is alive before using it
    pool_timeout=20,  # seconds to wait for a pooled connection
    max_overflow=100  # extra temporary connections per worker
    # Up to 150 connections per worker; with 32 workers that is up to 4800 in total.
)

# Elasticsearch client
es_client = Elasticsearch(
    hosts=["http://222.128.1.157:19200"],
    request_timeout=30,
    max_retries=3,
    retry_on_timeout=True
)

app = Flask(__name__)

# ============ ProxyFix (trust reverse-proxy headers) ============
# Without this, Flask behind Nginx does not see that the request was HTTPS, and
# with SESSION_COOKIE_SECURE=True the session cookie gets dropped.
# x_for=1:    trust one proxy hop of X-Forwarded-For (real client IP)
# x_proto=1:  trust one proxy hop of X-Forwarded-Proto (detect HTTPS)
# x_host=1:   trust one proxy hop of X-Forwarded-Host (original Host)
# x_prefix=1: trust one proxy hop of X-Forwarded-Prefix (URL prefix)
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1)
print("✅ ProxyFix 已配置,Flask 将信任反向代理头(X-Forwarded-Proto 等)")

# ============ Redis connection (overridable via environment variables) ============
_REDIS_HOST = os.environ.get('REDIS_HOST', 'localhost')
_REDIS_PORT = int(os.environ.get('REDIS_PORT', 6379))
_REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD', 'VF_Redis_2024')  # Redis password
redis_client = redis.Redis(host=_REDIS_HOST, port=_REDIS_PORT, db=0, password=_REDIS_PASSWORD, decode_responses=True)
print(f"📦 Redis 配置: {_REDIS_HOST}:{_REDIS_PORT}/db=0 (已启用密码认证)")

# ============ Verification codes in Redis (shared across processes/workers) ============
VERIFICATION_CODE_PREFIX = "vf_code:"
VERIFICATION_CODE_EXPIRE = 300  # verification code lifetime in seconds (5 minutes)


def set_verification_code(key, code, expires_in=VERIFICATION_CODE_EXPIRE):
    """Store ``code`` under ``key`` with a TTL of ``expires_in`` seconds.

    The stored JSON also carries an absolute ``expires`` timestamp.
    Returns True on success, False if Redis raised.
    """
    try:
        data = {
            'code': code,
            'expires': time.time() + expires_in
        }
        redis_client.setex(
            f"{VERIFICATION_CODE_PREFIX}{key}",
            expires_in,
            json.dumps(data)
        )
        return True
    except Exception as e:
        print(f"❌ Redis 存储验证码失败: {e}")
        return False


def get_verification_code(key):
    """Return the stored ``{'code', 'expires'}`` dict for ``key``, or None if missing/on error."""
    try:
        data = redis_client.get(f"{VERIFICATION_CODE_PREFIX}{key}")
        if data:
            return json.loads(data)
        return None
    except Exception as e:
        print(f"❌ Redis 获取验证码失败: {e}")
        return None


def delete_verification_code(key):
    """Delete the verification code for ``key``; Redis errors are printed, not raised."""
    try:
        redis_client.delete(f"{VERIFICATION_CODE_PREFIX}{key}")
    except Exception as e:
        print(f"❌ Redis 删除验证码失败: {e}")


print(f"📦 验证码存储: Redis, 过期时间: {VERIFICATION_CODE_EXPIRE}秒")

# ============ Event-list cache in Redis (TTL depends on trading hours) ============
EVENTS_CACHE_PREFIX = "events:cache:"
EVENTS_CACHE_TTL_TRADING = 20  # TTL in seconds during trading hours
EVENTS_CACHE_TTL_NON_TRADING = 600  # TTL in seconds outside trading hours (10 minutes)


def generate_events_cache_key(args_dict):
    """Build a fixed-length cache key from request parameters.

    Parameters whose value is ``None``, ``''`` or ``'all'`` are dropped; the rest
    are JSON-encoded with sorted keys and hashed with MD5 (used only as a key
    digest, not for security).

    Args:
        args_dict: request-parameter mapping.

    Returns:
        str: ``events:cache:{md5_hex}``.
    """
    filtered_params = {k: v for k, v in sorted(args_dict.items())
                       if v is not None and v != '' and v != 'all'}
    params_str = json.dumps(filtered_params, sort_keys=True)
    params_hash = hashlib.md5(params_str.encode()).hexdigest()
    return f"{EVENTS_CACHE_PREFIX}{params_hash}"


def get_events_cache(cache_key):
    """Return the cached event-list response for ``cache_key``, or None if absent/on error."""
    try:
        cached = redis_client.get(cache_key)
        if cached:
            return json.loads(cached)
        return None
    except Exception as e:
        print(f"❌ Redis 获取事件缓存失败: {e}")
        return None


def set_events_cache(cache_key, data):
    """Cache ``data`` under ``cache_key``; the TTL is shorter during trading hours.

    Returns True on success, False if serialisation or Redis raised.
    """
    try:
        ttl = EVENTS_CACHE_TTL_TRADING if is_trading_hours() else EVENTS_CACHE_TTL_NON_TRADING
        redis_client.setex(cache_key, ttl, json.dumps(data, ensure_ascii=False))
        return True
    except Exception as e:
        print(f"❌ Redis 存储事件缓存失败: {e}")
        return False


def clear_events_cache():
    """Delete every ``events:cache:*`` key; used after event data changes.

    Iterates with SCAN (batches of 100) rather than KEYS to avoid blocking Redis.

    Returns:
        int: number of keys deleted (0 on error).
    """
    try:
        cursor = 0
        deleted_count = 0
        while True:
            cursor, keys = redis_client.scan(cursor, match=f"{EVENTS_CACHE_PREFIX}*", count=100)
            if keys:
                redis_client.delete(*keys)
                deleted_count += len(keys)
            if cursor == 0:
                break
        if deleted_count > 0:
            print(f"🗑️ 已清除 {deleted_count} 个事件缓存")
        return deleted_count
    except Exception as e:
        print(f"❌ 清除事件缓存失败: {e}")
        return 0


print(f"📦 事件列表缓存: 交易时间 {EVENTS_CACHE_TTL_TRADING}s / 非交易时间 {EVENTS_CACHE_TTL_NON_TRADING}s")

# ============ WeChat login sessions in Redis (shared across processes) ============
WECHAT_SESSION_EXPIRE = 300  # session lifetime in seconds (5 minutes)
WECHAT_SESSION_PREFIX = "wechat_session:"


def set_wechat_session(state, data):
    """Store the WeChat login session ``data`` for ``state``; returns True/False."""
    try:
        redis_client.setex(
            f"{WECHAT_SESSION_PREFIX}{state}",
            WECHAT_SESSION_EXPIRE,
            json.dumps(data)
        )
        return True
    except Exception as e:
        print(f"❌ Redis 存储 wechat session 失败: {e}")
        return False


def get_wechat_session(state):
    """Return the WeChat login session dict for ``state``, or None if absent/on error."""
    try:
        data = redis_client.get(f"{WECHAT_SESSION_PREFIX}{state}")
        if data:
            return json.loads(data)
        return None
    except Exception as e:
        print(f"❌ Redis 获取 wechat session 失败: {e}")
        return None


def update_wechat_session(state, updates):
    """Merge ``updates`` into an existing WeChat session, keeping its remaining TTL.

    Returns True if a session existed and was rewritten, False otherwise or on error.
    Note: read-modify-write, not atomic.
    """
    try:
        data = get_wechat_session(state)
        if data:
            data.update(updates)
            # Preserve the original expiry by reusing the key's remaining TTL.
            ttl = redis_client.ttl(f"{WECHAT_SESSION_PREFIX}{state}")
            if ttl > 0:
                redis_client.setex(
                    f"{WECHAT_SESSION_PREFIX}{state}",
                    ttl,
                    json.dumps(data)
                )
            else:
                # TTL unavailable (key vanished or has no expiry): use the default lifetime.
                set_wechat_session(state, data)
            return True
        return False
    except Exception as e:
        print(f"❌ Redis 更新 wechat session 失败: {e}")
        return False


def delete_wechat_session(state):
    """Delete the WeChat login session for ``state``; returns True/False."""
    try:
        redis_client.delete(f"{WECHAT_SESSION_PREFIX}{state}")
        return True
    except Exception as e:
        print(f"❌ Redis 删除 wechat session 失败: {e}")
        return False


def wechat_session_exists(state):
    """Return True if a WeChat login session exists for ``state`` (False on error)."""
    try:
        return redis_client.exists(f"{WECHAT_SESSION_PREFIX}{state}") > 0
    except Exception as e:
        print(f"❌ Redis 检查 wechat session 失败: {e}")
        return False
# ============ End of WeChat login sessions ============


# ============ Stock data cache in Redis (names + previous close) ============
STOCK_NAME_PREFIX = "vf:stock:name:"  # key prefix for stock names
STOCK_NAME_EXPIRE = 86400  # stock names cached for 24 hours
PREV_CLOSE_PREFIX = "vf:stock:prev_close:"  # key prefix for previous close prices
PREV_CLOSE_EXPIRE = 86400  # previous close cached for 24 hours (keys also contain the date)


def get_cached_stock_names(base_codes):
    """Look up stock names in batch, Redis first, then ``ea_stocklist`` for misses.

    Names fetched from the database are written back to Redis (NULL names are
    returned but not cached). If Redis is unavailable, every code is looked up in
    the database.

    :param base_codes: codes without exchange suffix, e.g. ['600000', '000001']
    :return: dict {code: name}; codes found nowhere are absent
    """
    if not base_codes:
        return {}

    result = {}
    missing_codes = []

    try:
        pipe = redis_client.pipeline()
        for code in base_codes:
            pipe.get(f"{STOCK_NAME_PREFIX}{code}")
        cached_values = pipe.execute()

        for code, cached_name in zip(base_codes, cached_values):
            if cached_name:
                result[code] = cached_name
            else:
                missing_codes.append(code)
    except Exception as e:
        print(f"⚠️ Redis 批量获取股票名称失败: {e},降级为数据库查询")
        missing_codes = base_codes

    if missing_codes:
        try:
            with engine.connect() as conn:
                # Bound parameters (:code0, :code1, ...) rather than string interpolation of the codes.
                placeholders = ','.join([f':code{i}' for i in range(len(missing_codes))])
                params = {f'code{i}': code for i, code in enumerate(missing_codes)}
                db_result = conn.execute(text(
                    f"SELECT SECCODE, SECNAME FROM ea_stocklist WHERE SECCODE IN ({placeholders})"
                ), params).fetchall()

            pipe = redis_client.pipeline()
            for row in db_result:
                code, name = row[0], row[1]
                result[code] = name
                # Skip NULL names: a None value would make the whole pipeline fail on execute().
                if name:
                    pipe.setex(f"{STOCK_NAME_PREFIX}{code}", STOCK_NAME_EXPIRE, name)
            try:
                pipe.execute()
            except Exception as e:
                print(f"⚠️ Redis 缓存股票名称失败: {e}")
        except Exception as e:
            print(f"❌ 数据库查询股票名称失败: {e}")

    return result


def get_cached_prev_close(base_codes, trade_date_str):
    """Look up closing prices for ``trade_date_str`` in batch, Redis first, then ``ea_trade``.

    Cache keys include the date so values never leak across days. Only rows with
    ``F007N > 0`` are returned/cached.

    :param base_codes: codes without exchange suffix, e.g. ['600000', '000001']
    :param trade_date_str: trading date string, YYYYMMDD
    :return: dict {code: close_price as float}
    """
    if not base_codes or not trade_date_str:
        return {}

    result = {}
    missing_codes = []

    try:
        pipe = redis_client.pipeline()
        for code in base_codes:
            pipe.get(f"{PREV_CLOSE_PREFIX}{trade_date_str}:{code}")
        cached_values = pipe.execute()

        for code, cached_price in zip(base_codes, cached_values):
            if cached_price:
                result[code] = float(cached_price)
            else:
                missing_codes.append(code)
    except Exception as e:
        print(f"⚠️ Redis 批量获取前收盘价失败: {e},降级为数据库查询")
        missing_codes = base_codes

    if missing_codes:
        try:
            with engine.connect() as conn:
                placeholders = ','.join([f':code{i}' for i in range(len(missing_codes))])
                params = {f'code{i}': code for i, code in enumerate(missing_codes)}
                params['trade_date'] = trade_date_str
                db_result = conn.execute(text(f"""
                    SELECT SECCODE, F007N as close_price
                    FROM ea_trade
                    WHERE SECCODE IN ({placeholders})
                      AND TRADEDATE = :trade_date
                      AND F007N > 0
                """), params).fetchall()

            pipe = redis_client.pipeline()
            for row in db_result:
                code, close_price = row[0], float(row[1]) if row[1] else None
                if close_price:
                    result[code] = close_price
                    pipe.setex(f"{PREV_CLOSE_PREFIX}{trade_date_str}:{code}", PREV_CLOSE_EXPIRE, str(close_price))
            try:
                pipe.execute()
            except Exception as e:
                print(f"⚠️ Redis 缓存前收盘价失败: {e}")
        except Exception as e:
            print(f"❌ 数据库查询前收盘价失败: {e}")

    return result


def preload_stock_cache():
    """Warm the Redis stock caches.

    Originally documented as a scheduled job run daily at 09:25 (the scheduling
    is not set up in this function).
    1. Loads every stock name from ``ea_stocklist``.
    2. Loads every closing price of the trading day before today from
       ``ea_trade`` (TRADEDATE as YYYYMMDD).
    Errors are printed with a traceback, not raised.
    """
    print(f"[缓存预热] 开始预热股票缓存... {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    try:
        # 1. Stock names (full load)
        with engine.connect() as conn:
            result = conn.execute(text("SELECT SECCODE, SECNAME FROM ea_stocklist")).fetchall()
        pipe = redis_client.pipeline()
        count = 0
        for row in result:
            code, name = row[0], row[1]
            if code and name:
                pipe.setex(f"{STOCK_NAME_PREFIX}{code}", STOCK_NAME_EXPIRE, name)
                count += 1
        pipe.execute()
        print(f"[缓存预热] 股票名称: {count} 条已加载到 Redis")

        # 2. Previous close (previous trading day).
        # trading_days holds datetime.date objects, so compare against a date,
        # not a 'YYYY-MM-DD' string (that comparison raised TypeError).
        today = datetime.now().date()
        if not trading_days:
            load_trading_days()
        prev_trading_day = _previous_trading_day(today)

        if prev_trading_day:
            prev_date_str = prev_trading_day.strftime('%Y%m%d')  # YYYYMMDD, matching ea_trade.TRADEDATE
            with engine.connect() as conn:
                result = conn.execute(text("""
                    SELECT SECCODE, F007N as close_price
                    FROM ea_trade
                    WHERE TRADEDATE = :trade_date
                      AND F007N > 0
                """), {'trade_date': prev_date_str}).fetchall()
            pipe = redis_client.pipeline()
            count = 0
            for row in result:
                code, close_price = row[0], row[1]
                if code and close_price:
                    pipe.setex(f"{PREV_CLOSE_PREFIX}{prev_date_str}:{code}", PREV_CLOSE_EXPIRE, str(close_price))
                    count += 1
            pipe.execute()
            print(f"[缓存预热] 前收盘价({prev_trading_day}): {count} 条已加载到 Redis")
        else:
            print(f"[缓存预热] 未找到前一交易日,跳过前收盘价预热")

        print(f"[缓存预热] 预热完成 ✅ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    except Exception as e:
        print(f"[缓存预热] 预热失败 ❌: {e}")
        import traceback
        traceback.print_exc()


print(f"📦 股票缓存: Redis, 名称过期 {STOCK_NAME_EXPIRE}秒, 收盘价过期 {PREV_CLOSE_EXPIRE}秒")
# ============ End of stock data cache ============

# Tencent Cloud SMS.
# SECURITY NOTE(review): credentials below are hard-coded defaults; prefer setting the
# environment variables and rotating/removing these values from source control.
SMS_SECRET_ID = os.environ.get('SMS_SECRET_ID', 'AKID2we9TacdTAhCjCSYTErHVimeJo9Yr00s')
SMS_SECRET_KEY = os.environ.get('SMS_SECRET_KEY', 'pMlBWijlkgT9fz5ziEXdWEnAPTJzRfkf')
SMS_SDK_APP_ID = "1400972398"
SMS_SIGN_NAME = "价值前沿科技"
SMS_TEMPLATE_REGISTER = "2386557"  # registration template
SMS_TEMPLATE_LOGIN = "2386540"  # login template

# WeChat Open Platform (PC QR-code login)
WECHAT_OPEN_APPID = 'wxa8d74c47041b5f87'
WECHAT_OPEN_APPSECRET = os.environ.get('WECHAT_OPEN_APPSECRET', 'eedef95b11787fd7ca7f1acc6c9061bc')

# WeChat Official Account (H5 web authorisation)
WECHAT_MP_APPID = 'wx8afd36f7c7b21ba0'
WECHAT_MP_APPSECRET = os.environ.get('WECHAT_MP_APPSECRET', 'c3ec5a227ddb26ad8a1d4c55efa1cf86')

# WeChat Mini Program (H5 -> mini program jump)
WECHAT_MINIPROGRAM_APPID = 'wx0edeaab76d4fa414'
WECHAT_MINIPROGRAM_APPSECRET = os.environ.get('WECHAT_MINIPROGRAM_APPSECRET', '0d0c70084f05a8c1411f6b89da7e815d')
WECHAT_MINIPROGRAM_ORIGINAL_ID = 'gh_fd2fd8dd2fb5'

# Redis key prefixes for WeChat tokens
WECHAT_ACCESS_TOKEN_PREFIX = "wechat:access_token:"
WECHAT_JSAPI_TICKET_PREFIX = "wechat:jsapi_ticket:"

# WeChat OAuth callback URL
WECHAT_REDIRECT_URI = 'https://api.valuefrontier.cn/api/auth/wechat/callback'

# Frontend origin (redirect target after a successful login)
FRONTEND_URL = 'https://valuefrontier.cn'

# Mail service (Tencent Exmail)
MAIL_SERVER = 'smtp.exmail.qq.com'
MAIL_PORT = 465
MAIL_USE_SSL = True
MAIL_USE_TLS = False
MAIL_USERNAME = 'admin@valuefrontier.cn'
MAIL_PASSWORD = os.environ.get('MAIL_PASSWORD', 'QYncRu6WUdASvTg4')
MAIL_DEFAULT_SENDER = 'admin@valuefrontier.cn'

# Session and security settings.
# A fixed SECRET_KEY keeps users logged in across server restarts.
# IMPORTANT: configure it through the environment in production; do not hard-code it.
app.config['SECRET_KEY'] = os.environ.get('FLASK_SECRET_KEY', 'vf_production_secret_key_2024_valuefrontier_cn')

# ============ Redis-backed sessions (shared across processes/workers) ============
# Storing sessions in Redis lets all Gunicorn workers share them.
# Can be disabled via the environment for troubleshooting.
USE_REDIS_SESSION = os.environ.get('USE_REDIS_SESSION', 'true').lower() == 'true'

if USE_REDIS_SESSION:
    app.config['SESSION_TYPE'] = 'redis'
    app.config['SESSION_REDIS'] = redis.Redis(host=_REDIS_HOST, port=_REDIS_PORT, db=1, password=_REDIS_PASSWORD)  # db=1 is for sessions
    app.config['SESSION_PERMANENT'] = True
    app.config['SESSION_USE_SIGNER'] = True  # sign the session-id cookie
    app.config['SESSION_KEY_PREFIX'] = 'vf_session:'  # Redis key prefix for sessions
    app.config['SESSION_REFRESH_EACH_REQUEST'] = True  # refresh the session TTL on each request
    # Flask-Session uses PERMANENT_SESSION_LIFETIME (configured elsewhere as 7 days) as the Redis TTL.
    print(f"📦 Flask Session 配置: Redis {_REDIS_HOST}:{_REDIS_PORT}/db=1, 过期时间: 7天")
else:
    # Plain cookie sessions (only safe with a single worker).
    app.config['SESSION_TYPE'] = 'null'
    print(f"📦 Flask Session 配置: Cookie 模式(单 Worker)")
# ============ End of Redis-backed sessions ============

# Cookie settings - IMPORTANT: SECURE must be True when served over HTTPS
app.config['SESSION_COOKIE_SECURE'] = True  # production is HTTPS
app.config['SESSION_COOKIE_HTTPONLY'] = True  # not readable from JavaScript (XSS mitigation)
# SameSite='None' lets the WeChat in-app browser send the cookie after the OAuth
# redirect; it has to be combined with Secure=True (set above).
app.config['SESSION_COOKIE_SAMESITE'] = 'None'
app.config['SESSION_COOKIE_DOMAIN'] = None  # do not restrict the cookie domain
app.config['SESSION_COOKIE_PATH'] = '/'  # cookie path
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(days=7)  # sessions last 7 days
app.config['REMEMBER_COOKIE_DURATION'] = timedelta(days=30)  # "remember me" lasts 30 days
app.config['REMEMBER_COOKIE_SECURE'] = True  # production is HTTPS, must be True
app.config['REMEMBER_COOKIE_HTTPONLY'] = True  # not readable from JavaScript (XSS mitigation)
app.config['REMEMBER_COOKIE_SAMESITE'] = 'None'  # WeChat in-app browser compatibility

# Initialise Flask-Session (only when the Redis session backend is enabled)
if USE_REDIS_SESSION:
    Session(app)
    print("✅ Flask-Session (Redis) 已初始化,支持多 Worker 共享 session")


@app.before_request
def refresh_session_ttl():
    """Keep logged-in sessions permanent and slide their expiry on every request.

    Works around the Flask-Session 0.8.0 TTL issue noted by the original authors.

    Anonymous sessions (no ``user_id`` / ``_user_id``) are left untouched.
    Assigning ``session.permanent`` writes the ``_permanent`` key, which marks the
    session modified. Doing that unconditionally would make the session backend
    persist a session, and set a cookie, for every anonymous request.

    ``session.modified`` must be set here rather than in ``after_request``: per
    the original note, Flask-Session saves the session before ``after_request``
    handlers run, so setting it there would not refresh the TTL.
    """
    if not (session.get('user_id') or session.get('_user_id')):
        return None
    session.permanent = True
    session.modified = True
    return None


# Mail configuration
app.config['MAIL_SERVER'] = MAIL_SERVER
app.config['MAIL_PORT'] = MAIL_PORT
app.config['MAIL_USE_SSL'] = MAIL_USE_SSL
app.config['MAIL_USE_TLS'] = MAIL_USE_TLS
app.config['MAIL_USERNAME'] = MAIL_USERNAME
app.config['MAIL_PASSWORD'] = MAIL_PASSWORD
app.config['MAIL_DEFAULT_SENDER'] = MAIL_DEFAULT_SENDER

# Cross-origin access for the frontends. flask_cors is imported at module level,
# so a missing package already fails at import time; no ImportError guard is needed here.
CORS(app,
     origins=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5173",
              "https://valuefrontier.cn", "http://valuefrontier.cn",
              "https://www.valuefrontier.cn", "http://www.valuefrontier.cn"],  # explicit allow-list
     methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
     allow_headers=["Content-Type", "Authorization", "X-Requested-With", "Cache-Control"],
     supports_credentials=True,  # allow cookies/credentials
     expose_headers=["Content-Type", "Authorization"])

# Flask-Login
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = 'login'
login_manager.login_message = '请先登录访问此页面' login_manager.remember_cookie_duration = timedelta(days=30) # 记住登录持续时间 Compress(app) MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB max file size # Configure Flask-Compress app.config['COMPRESS_ALGORITHM'] = ['gzip', 'br'] app.config['COMPRESS_MIMETYPES'] = [ 'text/html', 'text/css', 'text/xml', 'application/json', 'application/javascript', 'application/x-javascript' ] app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://root:Zzl33818!@127.0.0.1:3306/stock?charset=utf8mb4' app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False app.config['SQLALCHEMY_ENGINE_OPTIONS'] = { 'pool_size': 50, # 每个 worker 常驻连接数 'pool_recycle': 1800, # 连接回收时间 30 分钟(原 1 小时) 'pool_pre_ping': True, # 使用前检测连接是否有效 'pool_timeout': 20, # 获取连接超时时间(秒) 'max_overflow': 100 # 每个 worker 临时溢出连接数 # 每个 worker 最多 150 个连接,32 workers 总共最多 4800 个连接 } # Cache directory setup CACHE_DIR = Path('cache') CACHE_DIR.mkdir(exist_ok=True) def beijing_now(): # 使用 pytz 处理时区,但返回 naive datetime(适合数据库存储) beijing_tz = pytz.timezone('Asia/Shanghai') return datetime.now(beijing_tz).replace(tzinfo=None) # 检查用户是否登录的装饰器 def login_required(f): @wraps(f) def decorated_function(*args, **kwargs): if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 return f(*args, **kwargs) return decorated_function # Memory management constants MAX_MEMORY_PERCENT = 75 MEMORY_CHECK_INTERVAL = 300 MAX_CACHE_ITEMS = 50 db = SQLAlchemy(app) # 初始化邮件服务 mail = Mail(app) # 初始化 Flask-SocketIO(用于实时事件推送) # 支持通过环境变量指定模式: SOCKETIO_ASYNC_MODE=gevent|threading def _detect_async_mode(): """检测可用的异步模式""" # 允许通过环境变量强制指定 forced_mode = os.environ.get('SOCKETIO_ASYNC_MODE', '').lower() if forced_mode in ('gevent', 'threading', 'eventlet'): return forced_mode # 检测 gevent 是否已被 patch(Gunicorn -k gevent 会自动 patch) try: from gevent import monkey if monkey.is_module_patched('socket'): return 'gevent' except ImportError: pass # 默认使用 threading(最稳定,配合 simple-websocket) return 'threading' _async_mode = 
_detect_async_mode() print(f"📡 Flask-SocketIO async_mode: {_async_mode}") # Redis 消息队列 URL(支持多 Worker 之间的消息同步) # 使用 127.0.0.1 而非 localhost,避免 eventlet DNS 问题 # 格式: redis://:password@host:port/db SOCKETIO_MESSAGE_QUEUE = os.environ.get('SOCKETIO_REDIS_URL', f'redis://:{_REDIS_PASSWORD}@{_REDIS_HOST}:{_REDIS_PORT}/2') # 检测是否需要启用消息队列 # 默认启用(多 Worker 模式需要,单 Worker 模式也兼容) _use_message_queue = os.environ.get('SOCKETIO_USE_QUEUE', 'true').lower() == 'true' socketio = SocketIO( app, cors_allowed_origins=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5173", "https://valuefrontier.cn", "http://valuefrontier.cn"], async_mode=_async_mode, message_queue=SOCKETIO_MESSAGE_QUEUE if _use_message_queue else None, manage_session=False, # 让 Flask-Session 管理 session,避免与 SocketIO 冲突 logger=True, engineio_logger=False, ping_timeout=120, # 心跳超时时间(秒),客户端120秒内无响应才断开 ping_interval=25 # 心跳检测间隔(秒),每25秒发送一次ping ) if _use_message_queue: print(f"✅ Flask-SocketIO 已配置 Redis 消息队列: {SOCKETIO_MESSAGE_QUEUE}") else: print(f"📡 Flask-SocketIO 单 Worker 模式(无消息队列)") @login_manager.user_loader def load_user(user_id): """Flask-Login 用户加载回调""" try: return User.query.get(int(user_id)) except Exception as e: app.logger.error(f"用户加载错误: {e}") return None # 全局错误处理器 - 确保API接口始终返回JSON @app.errorhandler(404) def not_found_error(error): """404错误处理""" if request.path.startswith('/api/'): return jsonify({'success': False, 'error': '接口不存在'}), 404 return error @app.errorhandler(500) def internal_error(error): """500错误处理""" db.session.rollback() if request.path.startswith('/api/'): return jsonify({'success': False, 'error': '服务器内部错误'}), 500 return error @app.errorhandler(405) def method_not_allowed_error(error): """405错误处理""" if request.path.startswith('/api/'): return jsonify({'success': False, 'error': '请求方法不被允许'}), 405 return error class Post(db.Model): """帖子模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) user_id = 
db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 内容 title = db.Column(db.String(200)) # 标题(可选) content = db.Column(db.Text, nullable=False) # 内容 content_type = db.Column(db.String(20), default='text') # 内容类型:text/rich_text/link # 时间 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 统计 likes_count = db.Column(db.Integer, default=0) comments_count = db.Column(db.Integer, default=0) view_count = db.Column(db.Integer, default=0) # 状态 status = db.Column(db.String(20), default='active') # active/hidden/deleted is_top = db.Column(db.Boolean, default=False) # 是否置顶 # 关系 user = db.relationship('User', backref='posts') likes = db.relationship('PostLike', backref='post', lazy='dynamic') comments = db.relationship('Comment', backref='post', lazy='dynamic') class Comment(db.Model): """帖子评论模型""" id = db.Column(db.Integer, primary_key=True) post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 内容 content = db.Column(db.Text, nullable=False) parent_id = db.Column(db.Integer, db.ForeignKey('comment.id')) # 时间 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 统计 likes_count = db.Column(db.Integer, default=0) # 状态 status = db.Column(db.String(20), default='active') # active/hidden/deleted # 关系 user = db.relationship('User', backref='comments') replies = db.relationship('Comment', backref=db.backref('parent', remote_side=[id])) class User(UserMixin, db.Model): """用户模型 - 完全匹配现有数据库表结构""" __tablename__ = 'user' # 主键 id = db.Column(db.Integer, primary_key=True, autoincrement=True) # 基础账号信息 username = db.Column(db.String(80), unique=True, nullable=False) email = db.Column(db.String(120), unique=True, nullable=True) password_hash = db.Column(db.String(255), nullable=True) email_confirmed = 
db.Column(db.Boolean, nullable=True, default=True) # 时间字段 created_at = db.Column(db.DateTime, nullable=True, default=beijing_now) last_seen = db.Column(db.DateTime, nullable=True, default=beijing_now) # 账号状态 status = db.Column(db.String(20), nullable=True, default='active') # 个人资料信息 nickname = db.Column(db.String(30), nullable=True) avatar_url = db.Column(db.String(200), nullable=True) banner_url = db.Column(db.String(200), nullable=True) bio = db.Column(db.String(200), nullable=True) gender = db.Column(db.String(10), nullable=True) birth_date = db.Column(db.Date, nullable=True) location = db.Column(db.String(100), nullable=True) # 联系方式 phone = db.Column(db.String(20), nullable=True) wechat_id = db.Column(db.String(80), nullable=True) # 微信号 # 实名认证 real_name = db.Column(db.String(30), nullable=True) id_number = db.Column(db.String(18), nullable=True) is_verified = db.Column(db.Boolean, nullable=True, default=False) verify_time = db.Column(db.DateTime, nullable=True) # 投资偏好 trading_experience = db.Column(db.String(200), nullable=True) investment_style = db.Column(db.String(50), nullable=True) risk_preference = db.Column(db.String(20), nullable=True) investment_amount = db.Column(db.String(20), nullable=True) preferred_markets = db.Column(db.String(200), nullable=True) # 社区数据 user_level = db.Column(db.Integer, nullable=True, default=1) reputation_score = db.Column(db.Integer, nullable=True, default=0) contribution_point = db.Column(db.Integer, nullable=True, default=0) post_count = db.Column(db.Integer, nullable=True, default=0) comment_count = db.Column(db.Integer, nullable=True, default=0) follower_count = db.Column(db.Integer, nullable=True, default=0) following_count = db.Column(db.Integer, nullable=True, default=0) # 创作者相关 is_creator = db.Column(db.Boolean, nullable=True, default=False) creator_type = db.Column(db.String(20), nullable=True) creator_tags = db.Column(db.String(200), nullable=True) # 通知设置 email_notifications = db.Column(db.Boolean, nullable=True, 
default=True) sms_notifications = db.Column(db.Boolean, nullable=True, default=False) wechat_notifications = db.Column(db.Boolean, nullable=True, default=False) notification_preferences = db.Column(db.String(500), nullable=True) # 隐私和界面设置 privacy_level = db.Column(db.String(20), nullable=True, default='public') theme_preference = db.Column(db.String(20), nullable=True, default='light') blocked_keywords = db.Column(db.String(500), nullable=True) # 手机验证相关 phone_confirmed = db.Column(db.Boolean, nullable=True, default=False) # 注意:原表中是blob,这里改为Boolean更合理 phone_confirm_time = db.Column(db.DateTime, nullable=True) # 微信登录相关字段 wechat_union_id = db.Column(db.String(100), nullable=True) # 微信UnionID wechat_open_id = db.Column(db.String(100), nullable=True) # 微信OpenID def __init__(self, username, email=None, password=None, phone=None): """初始化用户""" self.username = username if email: self.email = email if phone: self.phone = phone if password: self.set_password(password) self.nickname = username # 默认昵称为用户名 self.created_at = beijing_now() self.last_seen = beijing_now() def set_password(self, password): """设置密码""" if password: self.password_hash = generate_password_hash(password) def check_password(self, password): """验证密码""" if not password or not self.password_hash: return False return check_password_hash(self.password_hash, password) def update_last_seen(self): """更新最后活跃时间""" self.last_seen = beijing_now() db.session.commit() def confirm_email(self): """确认邮箱""" self.email_confirmed = True db.session.commit() def confirm_phone(self): """确认手机号""" self.phone_confirmed = True self.phone_confirm_time = beijing_now() db.session.commit() def bind_wechat(self, open_id, union_id=None, wechat_info=None): """绑定微信账号""" self.wechat_open_id = open_id if union_id: self.wechat_union_id = union_id # 如果提供了微信用户信息,更新头像和昵称 if wechat_info: if not self.avatar_url and wechat_info.get('headimgurl'): self.avatar_url = wechat_info['headimgurl'] if not self.nickname and wechat_info.get('nickname'): # 
确保昵称编码正确且长度合理 nickname = self._sanitize_nickname(wechat_info['nickname']) self.nickname = nickname db.session.commit() def _sanitize_nickname(self, nickname): """清理和验证昵称""" if not nickname: return '微信用户' try: # 确保是正确的UTF-8字符串 sanitized = str(nickname).strip() # 移除可能的控制字符 import re sanitized = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', sanitized) # 限制长度(避免过长的昵称) if len(sanitized) > 50: sanitized = sanitized[:47] + '...' # 如果清理后为空,使用默认值 if not sanitized: sanitized = '微信用户' return sanitized except Exception as e: return '微信用户' def unbind_wechat(self): """解绑微信账号""" self.wechat_open_id = None self.wechat_union_id = None db.session.commit() def increment_post_count(self): """增加发帖数""" self.post_count = (self.post_count or 0) + 1 db.session.commit() def increment_comment_count(self): """增加评论数""" self.comment_count = (self.comment_count or 0) + 1 db.session.commit() def add_reputation(self, points): """增加声誉分数""" self.reputation_score = (self.reputation_score or 0) + points db.session.commit() def to_dict(self, include_sensitive=False): """转换为字典""" data = { 'id': self.id, 'username': self.username, 'nickname': self.nickname or self.username, 'avatar_url': self.avatar_url, 'banner_url': self.banner_url, 'bio': self.bio, 'gender': self.gender, 'location': self.location, 'user_level': self.user_level or 1, 'reputation_score': self.reputation_score or 0, 'contribution_point': self.contribution_point or 0, 'post_count': self.post_count or 0, 'comment_count': self.comment_count or 0, 'follower_count': self.follower_count or 0, 'following_count': self.following_count or 0, 'is_creator': self.is_creator or False, 'creator_type': self.creator_type, 'creator_tags': self.creator_tags, 'is_verified': self.is_verified or False, 'created_at': self.created_at.isoformat() if self.created_at else None, 'last_seen': self.last_seen.isoformat() if self.last_seen else None, 'status': self.status, 'has_wechat': bool(self.wechat_open_id), 'is_authenticated': True } # 获取用户订阅信息(从 user_subscriptions 表) 
subscription = UserSubscription.query.filter_by(user_id=self.id).first() if subscription: data.update({ 'subscription_type': subscription.subscription_type, 'subscription_status': subscription.subscription_status, 'billing_cycle': subscription.billing_cycle, 'start_date': subscription.start_date.isoformat() if subscription.start_date else None, 'end_date': subscription.end_date.isoformat() if subscription.end_date else None, 'auto_renewal': subscription.auto_renewal }) else: # 无订阅时使用默认值 data.update({ 'subscription_type': 'free', 'subscription_status': 'inactive', 'billing_cycle': None, 'start_date': None, 'end_date': None, 'auto_renewal': False }) # 敏感信息只在需要时包含 if include_sensitive: data.update({ 'email': self.email, 'phone': self.phone, 'email_confirmed': self.email_confirmed, 'phone_confirmed': self.phone_confirmed, 'real_name': self.real_name, 'birth_date': self.birth_date.isoformat() if self.birth_date else None, 'trading_experience': self.trading_experience, 'investment_style': self.investment_style, 'risk_preference': self.risk_preference, 'investment_amount': self.investment_amount, 'preferred_markets': self.preferred_markets, 'email_notifications': self.email_notifications, 'sms_notifications': self.sms_notifications, 'wechat_notifications': self.wechat_notifications, 'privacy_level': self.privacy_level, 'theme_preference': self.theme_preference }) return data def to_public_dict(self): """公开信息字典(用于显示给其他用户)""" return { 'id': self.id, 'username': self.username, 'nickname': self.nickname or self.username, 'avatar_url': self.avatar_url, 'bio': self.bio, 'user_level': self.user_level or 1, 'reputation_score': self.reputation_score or 0, 'post_count': self.post_count or 0, 'follower_count': self.follower_count or 0, 'is_creator': self.is_creator or False, 'creator_type': self.creator_type, 'is_verified': self.is_verified or False, 'created_at': self.created_at.isoformat() if self.created_at else None } @staticmethod def find_by_login_info(login_info): 
"""根据登录信息查找用户(支持用户名、邮箱、手机号)""" return User.query.filter( db.or_( User.username == login_info, User.email == login_info, User.phone == login_info ) ).first() @staticmethod def find_by_wechat_openid(open_id): """根据微信OpenID查找用户""" return User.query.filter_by(wechat_open_id=open_id).first() @staticmethod def find_by_wechat_unionid(union_id): """根据微信UnionID查找用户""" return User.query.filter_by(wechat_union_id=union_id).first() @staticmethod def is_username_taken(username): """检查用户名是否已被使用""" return User.query.filter_by(username=username).first() is not None @staticmethod def is_email_taken(email): """检查邮箱是否已被使用""" return User.query.filter_by(email=email).first() is not None @staticmethod def is_phone_taken(phone): """检查手机号是否已被使用""" return User.query.filter_by(phone=phone).first() is not None def __repr__(self): return f'' # ============================================ # 订阅功能模块(安全版本 - 独立表) # ============================================ class UserSubscription(db.Model): """用户订阅表 - 独立于现有User表""" __tablename__ = 'user_subscriptions' id = db.Column(db.Integer, primary_key=True, autoincrement=True) user_id = db.Column(db.Integer, nullable=False, unique=True, index=True) subscription_type = db.Column(db.String(10), nullable=False, default='free') subscription_status = db.Column(db.String(20), nullable=False, default='active') start_date = db.Column(db.DateTime, nullable=True) end_date = db.Column(db.DateTime, nullable=True) billing_cycle = db.Column(db.String(10), nullable=True) auto_renewal = db.Column(db.Boolean, nullable=False, default=False) created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) def is_active(self): if self.subscription_status != 'active': return False if self.subscription_type == 'free': return True if self.end_date: try: now = beijing_now() if self.end_date < now: return False except Exception as e: return False return True def days_left(self): if self.subscription_type == 
'free' or not self.end_date: return 999 try: now = beijing_now() delta = self.end_date - now return max(0, delta.days) except Exception as e: return 0 def to_dict(self): return { 'type': self.subscription_type, 'status': self.subscription_status, 'is_active': self.is_active(), 'days_left': self.days_left(), 'start_date': self.start_date.isoformat() if self.start_date else None, 'end_date': self.end_date.isoformat() if self.end_date else None, 'billing_cycle': self.billing_cycle, 'auto_renewal': self.auto_renewal } class SubscriptionPlan(db.Model): """订阅套餐表""" __tablename__ = 'subscription_plans' id = db.Column(db.Integer, primary_key=True, autoincrement=True) name = db.Column(db.String(50), nullable=False, unique=True) display_name = db.Column(db.String(100), nullable=False) description = db.Column(db.Text, nullable=True) monthly_price = db.Column(db.Numeric(10, 2), nullable=False) yearly_price = db.Column(db.Numeric(10, 2), nullable=False) features = db.Column(db.Text, nullable=True) pricing_options = db.Column(db.Text, nullable=True) # JSON格式:[{"months": 1, "price": 99}, {"months": 12, "price": 999}] is_active = db.Column(db.Boolean, default=True) sort_order = db.Column(db.Integer, default=0) created_at = db.Column(db.DateTime, default=beijing_now) def to_dict(self): # 解析pricing_options(如果存在) pricing_opts = None if self.pricing_options: try: pricing_opts = json.loads(self.pricing_options) except: pricing_opts = None # 如果没有pricing_options,则从monthly_price和yearly_price生成默认选项 if not pricing_opts: pricing_opts = [ { 'months': 1, 'price': float(self.monthly_price) if self.monthly_price else 0, 'label': '月付', 'cycle_key': 'monthly' }, { 'months': 12, 'price': float(self.yearly_price) if self.yearly_price else 0, 'label': '年付', 'cycle_key': 'yearly', 'discount_percent': 20 # 年付默认20%折扣 } ] return { 'id': self.id, 'name': self.name, 'display_name': self.display_name, 'description': self.description, 'monthly_price': float(self.monthly_price) if self.monthly_price else 0, 
'yearly_price': float(self.yearly_price) if self.yearly_price else 0, 'pricing_options': pricing_opts, # 新增:灵活计费周期选项 'features': json.loads(self.features) if self.features else [], 'is_active': self.is_active, 'sort_order': self.sort_order } class PaymentOrder(db.Model): """支付订单表""" __tablename__ = 'payment_orders' id = db.Column(db.Integer, primary_key=True, autoincrement=True) order_no = db.Column(db.String(32), unique=True, nullable=False) user_id = db.Column(db.Integer, nullable=False) plan_name = db.Column(db.String(20), nullable=False) billing_cycle = db.Column(db.String(10), nullable=False) amount = db.Column(db.Numeric(10, 2), nullable=False) original_amount = db.Column(db.Numeric(10, 2), nullable=True) # 原价 discount_amount = db.Column(db.Numeric(10, 2), nullable=True, default=0) # 折扣金额 promo_code_id = db.Column(db.Integer, db.ForeignKey('promo_codes.id'), nullable=True) # 优惠码ID payment_method = db.Column(db.String(20), default='wechat') # 支付方式: wechat/alipay wechat_order_id = db.Column(db.String(64), nullable=True) # 微信交易号 alipay_trade_no = db.Column(db.String(64), nullable=True) # 支付宝交易号 prepay_id = db.Column(db.String(64), nullable=True) qr_code_url = db.Column(db.String(200), nullable=True) # 微信支付二维码URL pay_url = db.Column(db.String(2000), nullable=True) # 支付宝支付链接(较长) status = db.Column(db.String(20), default='pending') created_at = db.Column(db.DateTime, default=beijing_now) paid_at = db.Column(db.DateTime, nullable=True) expired_at = db.Column(db.DateTime, nullable=True) remark = db.Column(db.String(200), nullable=True) # 关联优惠码 promo_code = db.relationship('PromoCode', backref='orders', lazy=True, foreign_keys=[promo_code_id]) def __init__(self, user_id, plan_name, billing_cycle, amount, original_amount=None, discount_amount=0): self.user_id = user_id self.plan_name = plan_name self.billing_cycle = billing_cycle self.amount = amount self.original_amount = original_amount if original_amount is not None else amount self.discount_amount = discount_amount 
or 0 import random timestamp = int(beijing_now().timestamp() * 1000000) random_suffix = random.randint(1000, 9999) self.order_no = f"{timestamp}{user_id:04d}{random_suffix}" self.expired_at = beijing_now() + timedelta(minutes=30) def is_expired(self): if not self.expired_at: return False try: now = beijing_now() return now > self.expired_at except Exception as e: return False def mark_as_paid(self, transaction_id, payment_method=None): """ 标记订单为已支付 Args: transaction_id: 交易号(微信或支付宝) payment_method: 支付方式(可选,如果已设置则不覆盖) """ self.status = 'paid' self.paid_at = beijing_now() # 根据支付方式存储交易号 if payment_method: self.payment_method = payment_method if self.payment_method == 'alipay': self.alipay_trade_no = transaction_id else: self.wechat_order_id = transaction_id def to_dict(self): return { 'id': self.id, 'order_no': self.order_no, 'user_id': self.user_id, 'plan_name': self.plan_name, 'billing_cycle': self.billing_cycle, 'amount': float(self.amount) if self.amount else 0, 'original_amount': float(self.original_amount) if self.original_amount else None, 'discount_amount': float(self.discount_amount) if self.discount_amount else 0, 'promo_code': self.promo_code.code if self.promo_code else None, 'payment_method': self.payment_method or 'wechat', 'qr_code_url': self.qr_code_url, 'pay_url': self.pay_url, 'status': self.status, 'is_expired': self.is_expired(), 'created_at': self.created_at.isoformat() if self.created_at else None, 'paid_at': self.paid_at.isoformat() if self.paid_at else None, 'expired_at': self.expired_at.isoformat() if self.expired_at else None, 'remark': self.remark } class PromoCode(db.Model): """优惠码表""" __tablename__ = 'promo_codes' id = db.Column(db.Integer, primary_key=True, autoincrement=True) code = db.Column(db.String(50), unique=True, nullable=False, index=True) description = db.Column(db.String(200), nullable=True) # 折扣类型和值 discount_type = db.Column(db.String(20), nullable=False) # 'percentage' 或 'fixed_amount' discount_value = db.Column(db.Numeric(10, 
2), nullable=False) # 适用范围 applicable_plans = db.Column(db.String(200), nullable=True) # JSON格式 applicable_cycles = db.Column(db.String(50), nullable=True) # JSON格式 min_amount = db.Column(db.Numeric(10, 2), nullable=True) # 使用限制 max_uses = db.Column(db.Integer, nullable=True) max_uses_per_user = db.Column(db.Integer, default=1) current_uses = db.Column(db.Integer, default=0) # 有效期 valid_from = db.Column(db.DateTime, nullable=False) valid_until = db.Column(db.DateTime, nullable=False) # 状态 is_active = db.Column(db.Boolean, default=True) created_by = db.Column(db.Integer, nullable=True) created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) def to_dict(self): return { 'id': self.id, 'code': self.code, 'description': self.description, 'discount_type': self.discount_type, 'discount_value': float(self.discount_value) if self.discount_value else 0, 'applicable_plans': json.loads(self.applicable_plans) if self.applicable_plans else None, 'applicable_cycles': json.loads(self.applicable_cycles) if self.applicable_cycles else None, 'min_amount': float(self.min_amount) if self.min_amount else None, 'max_uses': self.max_uses, 'max_uses_per_user': self.max_uses_per_user, 'current_uses': self.current_uses, 'valid_from': self.valid_from.isoformat() if self.valid_from else None, 'valid_until': self.valid_until.isoformat() if self.valid_until else None, 'is_active': self.is_active } class PromoCodeUsage(db.Model): """优惠码使用记录表""" __tablename__ = 'promo_code_usage' id = db.Column(db.Integer, primary_key=True, autoincrement=True) promo_code_id = db.Column(db.Integer, db.ForeignKey('promo_codes.id'), nullable=False) user_id = db.Column(db.Integer, nullable=False, index=True) order_id = db.Column(db.Integer, db.ForeignKey('payment_orders.id'), nullable=False) original_amount = db.Column(db.Numeric(10, 2), nullable=False) discount_amount = db.Column(db.Numeric(10, 2), nullable=False) final_amount = 
db.Column(db.Numeric(10, 2), nullable=False) used_at = db.Column(db.DateTime, default=beijing_now) # 关系 promo_code = db.relationship('PromoCode', backref='usages') order = db.relationship('PaymentOrder', backref='promo_usage') class SubscriptionUpgrade(db.Model): """订阅升级/降级记录表""" __tablename__ = 'subscription_upgrades' id = db.Column(db.Integer, primary_key=True, autoincrement=True) user_id = db.Column(db.Integer, nullable=False, index=True) order_id = db.Column(db.Integer, db.ForeignKey('payment_orders.id'), nullable=False) # 原订阅信息 from_plan = db.Column(db.String(20), nullable=False) from_cycle = db.Column(db.String(10), nullable=False) from_end_date = db.Column(db.DateTime, nullable=True) # 新订阅信息 to_plan = db.Column(db.String(20), nullable=False) to_cycle = db.Column(db.String(10), nullable=False) to_end_date = db.Column(db.DateTime, nullable=False) # 价格计算 remaining_value = db.Column(db.Numeric(10, 2), nullable=False) upgrade_amount = db.Column(db.Numeric(10, 2), nullable=False) actual_amount = db.Column(db.Numeric(10, 2), nullable=False) upgrade_type = db.Column(db.String(20), nullable=False) # 'plan_upgrade', 'cycle_change', 'both' created_at = db.Column(db.DateTime, default=beijing_now) # 关系 order = db.relationship('PaymentOrder', backref='upgrade_record') # ============================================ # 模拟盘相关模型 # ============================================ class SimulationAccount(db.Model): """模拟账户""" __tablename__ = 'simulation_accounts' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False, unique=True) account_name = db.Column(db.String(100), default='我的模拟账户') initial_capital = db.Column(db.Numeric(15, 2), default=1000000.00) # 初始资金 available_cash = db.Column(db.Numeric(15, 2), default=1000000.00) # 可用资金 frozen_cash = db.Column(db.Numeric(15, 2), default=0.00) # 冻结资金 position_value = db.Column(db.Numeric(15, 2), default=0.00) # 持仓市值 total_assets = db.Column(db.Numeric(15, 2), 
default=1000000.00) # 总资产 total_profit = db.Column(db.Numeric(15, 2), default=0.00) # 总盈亏 total_profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 总收益率 daily_profit = db.Column(db.Numeric(15, 2), default=0.00) # 日盈亏 daily_profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 日收益率 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) last_settlement_date = db.Column(db.Date) # 最后结算日期 # 关系 user = db.relationship('User', backref='simulation_account') positions = db.relationship('SimulationPosition', backref='account', lazy='dynamic') orders = db.relationship('SimulationOrder', backref='account', lazy='dynamic') transactions = db.relationship('SimulationTransaction', backref='account', lazy='dynamic') def calculate_total_assets(self): """计算总资产""" self.total_assets = self.available_cash + self.frozen_cash + self.position_value self.total_profit = self.total_assets - self.initial_capital self.total_profit_rate = (self.total_profit / self.initial_capital) * 100 if self.initial_capital > 0 else 0 return self.total_assets class SimulationPosition(db.Model): """模拟持仓""" __tablename__ = 'simulation_positions' id = db.Column(db.Integer, primary_key=True) account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) stock_code = db.Column(db.String(20), nullable=False) stock_name = db.Column(db.String(100)) position_qty = db.Column(db.Integer, default=0) # 持仓数量 available_qty = db.Column(db.Integer, default=0) # 可用数量(T+1) frozen_qty = db.Column(db.Integer, default=0) # 冻结数量 avg_cost = db.Column(db.Numeric(10, 3), default=0.00) # 平均成本 current_price = db.Column(db.Numeric(10, 3), default=0.00) # 当前价格 market_value = db.Column(db.Numeric(15, 2), default=0.00) # 市值 profit = db.Column(db.Numeric(15, 2), default=0.00) # 盈亏 profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 盈亏比例 today_profit = db.Column(db.Numeric(15, 2), default=0.00) # 今日盈亏 
today_profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 今日盈亏比例 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) __table_args__ = ( db.UniqueConstraint('account_id', 'stock_code', name='unique_account_stock'), ) def update_market_value(self, current_price): """更新市值和盈亏""" self.current_price = current_price self.market_value = self.position_qty * current_price total_cost = self.position_qty * self.avg_cost self.profit = self.market_value - total_cost self.profit_rate = (self.profit / total_cost * 100) if total_cost > 0 else 0 return self.market_value class SimulationOrder(db.Model): """模拟订单""" __tablename__ = 'simulation_orders' id = db.Column(db.Integer, primary_key=True) account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) order_no = db.Column(db.String(32), unique=True, nullable=False) stock_code = db.Column(db.String(20), nullable=False) stock_name = db.Column(db.String(100)) order_type = db.Column(db.String(10), nullable=False) # BUY/SELL price_type = db.Column(db.String(10), default='MARKET') # MARKET/LIMIT order_price = db.Column(db.Numeric(10, 3)) # 委托价格 order_qty = db.Column(db.Integer, nullable=False) # 委托数量 filled_qty = db.Column(db.Integer, default=0) # 成交数量 filled_price = db.Column(db.Numeric(10, 3)) # 成交价格 filled_amount = db.Column(db.Numeric(15, 2)) # 成交金额 commission = db.Column(db.Numeric(10, 2), default=0.00) # 手续费 stamp_tax = db.Column(db.Numeric(10, 2), default=0.00) # 印花税 transfer_fee = db.Column(db.Numeric(10, 2), default=0.00) # 过户费 total_fee = db.Column(db.Numeric(10, 2), default=0.00) # 总费用 status = db.Column(db.String(20), default='PENDING') # PENDING/PARTIAL/FILLED/CANCELLED/REJECTED reject_reason = db.Column(db.String(200)) order_time = db.Column(db.DateTime, default=beijing_now) filled_time = db.Column(db.DateTime) cancel_time = db.Column(db.DateTime) def calculate_fees(self): """计算交易费用""" if not 
self.filled_amount: return 0 # 佣金(万分之2.5,最低5元) self.commission = max(float(self.filled_amount) * 0.00025, 5.0) # 印花税(卖出时收取千分之1) if self.order_type == 'SELL': self.stamp_tax = float(self.filled_amount) * 0.001 else: self.stamp_tax = 0 # 过户费(双向收取,万分之0.2) self.transfer_fee = float(self.filled_amount) * 0.00002 # 总费用 self.total_fee = self.commission + self.stamp_tax + self.transfer_fee return self.total_fee class SimulationTransaction(db.Model): """模拟成交记录""" __tablename__ = 'simulation_transactions' id = db.Column(db.Integer, primary_key=True) account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) order_id = db.Column(db.Integer, db.ForeignKey('simulation_orders.id'), nullable=False) transaction_no = db.Column(db.String(32), unique=True, nullable=False) stock_code = db.Column(db.String(20), nullable=False) stock_name = db.Column(db.String(100)) transaction_type = db.Column(db.String(10), nullable=False) # BUY/SELL transaction_price = db.Column(db.Numeric(10, 3), nullable=False) transaction_qty = db.Column(db.Integer, nullable=False) transaction_amount = db.Column(db.Numeric(15, 2), nullable=False) commission = db.Column(db.Numeric(10, 2), default=0.00) stamp_tax = db.Column(db.Numeric(10, 2), default=0.00) transfer_fee = db.Column(db.Numeric(10, 2), default=0.00) total_fee = db.Column(db.Numeric(10, 2), default=0.00) transaction_time = db.Column(db.DateTime, default=beijing_now) settlement_date = db.Column(db.Date) # T+1结算日期 # 关系 order = db.relationship('SimulationOrder', backref='transactions') class SimulationDailyStats(db.Model): """模拟账户日统计""" __tablename__ = 'simulation_daily_stats' id = db.Column(db.Integer, primary_key=True) account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) stat_date = db.Column(db.Date, nullable=False) opening_assets = db.Column(db.Numeric(15, 2)) # 期初资产 closing_assets = db.Column(db.Numeric(15, 2)) # 期末资产 daily_profit = db.Column(db.Numeric(15, 2)) # 日盈亏 
daily_profit_rate = db.Column(db.Numeric(10, 4)) # 日收益率 total_profit = db.Column(db.Numeric(15, 2)) # 累计盈亏 total_profit_rate = db.Column(db.Numeric(10, 4)) # 累计收益率 trade_count = db.Column(db.Integer, default=0) # 交易次数 win_count = db.Column(db.Integer, default=0) # 盈利次数 loss_count = db.Column(db.Integer, default=0) # 亏损次数 max_profit = db.Column(db.Numeric(15, 2)) # 最大盈利 max_loss = db.Column(db.Numeric(15, 2)) # 最大亏损 created_at = db.Column(db.DateTime, default=beijing_now) __table_args__ = ( db.UniqueConstraint('account_id', 'stat_date', name='unique_account_date'), ) def get_user_subscription_safe(user_id): """安全地获取用户订阅信息""" try: subscription = UserSubscription.query.filter_by(user_id=user_id).first() if not subscription: subscription = UserSubscription(user_id=user_id) db.session.add(subscription) db.session.commit() return subscription except Exception as e: # 返回默认免费版本对象 class DefaultSub: def to_dict(self): return { 'type': 'free', 'status': 'active', 'is_active': True, 'days_left': 999, 'billing_cycle': None, 'auto_renewal': False } return DefaultSub() def activate_user_subscription(user_id, plan_type, billing_cycle, extend_from_now=False): """ 激活用户订阅(新版:续费时从当前订阅结束时间开始延长) Args: user_id: 用户ID plan_type: 套餐类型 (pro/max) billing_cycle: 计费周期 (monthly/quarterly/semiannual/yearly) extend_from_now: 废弃参数,保留以兼容(现在自动判断) Returns: UserSubscription 对象 或 None """ try: subscription = UserSubscription.query.filter_by(user_id=user_id).first() if not subscription: # 新用户,创建订阅记录 subscription = UserSubscription(user_id=user_id) db.session.add(subscription) # 更新订阅类型和状态 subscription.subscription_type = plan_type subscription.subscription_status = 'active' subscription.billing_cycle = billing_cycle # 计算订阅周期天数 cycle_days_map = { 'monthly': 30, 'quarterly': 90, # 3个月 'semiannual': 180, # 6个月 'yearly': 365 } days = cycle_days_map.get(billing_cycle, 30) now = beijing_now() # 判断是新购还是续费 if subscription.end_date and subscription.end_date > now: # 续费:从当前订阅结束时间开始延长 start_date = 
subscription.end_date end_date = start_date + timedelta(days=days) else: # 新购或过期后重新购买:从当前时间开始 start_date = now end_date = now + timedelta(days=days) subscription.start_date = start_date subscription.end_date = end_date subscription.updated_at = now db.session.commit() return subscription except Exception as e: print(f"激活订阅失败: {e}") db.session.rollback() return None def validate_promo_code(code, plan_name, billing_cycle, amount, user_id): """验证优惠码 Returns: tuple: (promo_code_obj, error_message) """ try: promo = PromoCode.query.filter_by(code=code.upper(), is_active=True).first() if not promo: return None, "优惠码不存在或已失效" # 检查有效期 now = beijing_now() if now < promo.valid_from: return None, "优惠码尚未生效" if now > promo.valid_until: return None, "优惠码已过期" # 检查使用次数 if promo.max_uses and promo.current_uses >= promo.max_uses: return None, "优惠码已被使用完" # 检查每用户使用次数 if promo.max_uses_per_user: user_usage_count = PromoCodeUsage.query.filter_by( promo_code_id=promo.id, user_id=user_id ).count() if user_usage_count >= promo.max_uses_per_user: return None, f"您已使用过此优惠码(限用{promo.max_uses_per_user}次)" # 检查适用套餐 if promo.applicable_plans: try: applicable = json.loads(promo.applicable_plans) if plan_name not in applicable: return None, "该优惠码不适用于此套餐" except: pass # 检查适用周期 if promo.applicable_cycles: try: applicable = json.loads(promo.applicable_cycles) if billing_cycle not in applicable: return None, "该优惠码不适用于此计费周期" except: pass # 检查最低消费 if promo.min_amount and amount < float(promo.min_amount): return None, f"需满{float(promo.min_amount):.2f}元才可使用此优惠码" return promo, None except Exception as e: return None, f"验证优惠码时出错: {str(e)}" def calculate_discount(promo_code, amount): """计算优惠金额""" try: if promo_code.discount_type == 'percentage': discount = amount * (float(promo_code.discount_value) / 100) else: # fixed_amount discount = float(promo_code.discount_value) # 确保折扣不超过总金额 return min(discount, amount) except: return 0 def calculate_subscription_price_simple(user_id, to_plan_name, to_cycle, 
promo_code=None): """ 简化版价格计算:续费用户和新用户价格完全一致,不计算剩余价值 Args: user_id: 用户ID to_plan_name: 目标套餐名称 (pro/max) to_cycle: 计费周期 (monthly/quarterly/semiannual/yearly) promo_code: 优惠码(可选) Returns: dict: { 'is_renewal': False/True, # 是否为续费 'subscription_type': 'new'/'renew', # 订阅类型 'current_plan': 'pro', # 当前套餐(如果有) 'current_cycle': 'yearly', # 当前周期(如果有) 'new_plan_price': 2699.00, # 新套餐价格 'original_amount': 2699.00, # 原价 'discount_amount': 0, # 优惠金额 'final_amount': 2699.00, # 实付金额 'promo_code': None, # 使用的优惠码 'promo_error': None # 优惠码错误信息 } """ try: # 1. 获取当前订阅 current_sub = UserSubscription.query.filter_by(user_id=user_id).first() # 2. 获取目标套餐 to_plan = SubscriptionPlan.query.filter_by(name=to_plan_name, is_active=True).first() if not to_plan: return {'error': '目标套餐不存在'} # 3. 根据计费周期获取价格 # 优先从 pricing_options 获取价格 price = None if to_plan.pricing_options: try: pricing_opts = json.loads(to_plan.pricing_options) # 查找匹配的周期 for opt in pricing_opts: cycle_key = opt.get('cycle_key', '') months = opt.get('months', 0) # 匹配逻辑 if (cycle_key == to_cycle or (to_cycle == 'monthly' and months == 1) or (to_cycle == 'quarterly' and months == 3) or (to_cycle == 'semiannual' and months == 6) or (to_cycle == 'yearly' and months == 12)): price = float(opt.get('price', 0)) break except: pass # 如果 pricing_options 中没有找到,使用旧的 monthly_price/yearly_price if price is None: if to_cycle == 'yearly': price = float(to_plan.yearly_price) if to_plan.yearly_price else 0 else: # 默认月付 price = float(to_plan.monthly_price) if to_plan.monthly_price else 0 if price <= 0: return {'error': f'{to_cycle} 周期价格未配置'} # 4. 
判断订阅类型和计算价格 is_renewal = False is_upgrade = False is_downgrade = False subscription_type = 'new' current_plan = None current_cycle = None remaining_value = 0 final_price = price if current_sub and current_sub.subscription_type in ['pro', 'max']: current_plan = current_sub.subscription_type current_cycle = current_sub.billing_cycle if current_plan == to_plan_name: # 同级续费:延长时长,全价购买 is_renewal = True subscription_type = 'renew' elif current_plan == 'pro' and to_plan_name == 'max': # 升级:Pro → Max,需要计算差价 is_upgrade = True subscription_type = 'upgrade' # 计算当前订阅的剩余价值 if current_sub.end_date and current_sub.end_date > datetime.utcnow(): # 获取当前套餐的原始价格 current_plan_obj = SubscriptionPlan.query.filter_by(name=current_plan, is_active=True).first() if current_plan_obj: current_price = None # 优先从 pricing_options 获取价格 if current_plan_obj.pricing_options: try: pricing_opts = json.loads(current_plan_obj.pricing_options) # 如果 current_cycle 为空或无效,根据剩余天数推断计费周期 if not current_cycle or current_cycle.strip() == '': remaining_days_total = (current_sub.end_date - current_sub.start_date).days if current_sub.start_date else 365 # 根据总天数推断计费周期 if remaining_days_total <= 35: inferred_cycle = 'monthly' elif remaining_days_total <= 100: inferred_cycle = 'quarterly' elif remaining_days_total <= 200: inferred_cycle = 'semiannual' else: inferred_cycle = 'yearly' else: inferred_cycle = current_cycle for opt in pricing_opts: if opt.get('cycle_key') == inferred_cycle: current_price = float(opt.get('price', 0)) current_cycle = inferred_cycle # 更新周期信息 break except: pass # 如果 pricing_options 中没找到,使用 yearly_price 作为默认 if current_price is None or current_price <= 0: current_price = float(current_plan_obj.yearly_price) if current_plan_obj.yearly_price else 0 current_cycle = 'yearly' if current_price and current_price > 0: # 计算剩余天数 remaining_days = (current_sub.end_date - datetime.utcnow()).days # 计算总天数 cycle_days_map = { 'monthly': 30, 'quarterly': 90, 'semiannual': 180, 'yearly': 365 } total_days = 
cycle_days_map.get(current_cycle, 365) # 计算剩余价值 if total_days > 0 and remaining_days > 0: remaining_value = current_price * (remaining_days / total_days) # 实付金额 = 新套餐价格 - 剩余价值 final_price = max(0, price - remaining_value) # 如果剩余价值 >= 新套餐价格,标记为免费升级 if remaining_value >= price: final_price = 0 elif current_plan == 'max' and to_plan_name == 'pro': # 降级:Max → Pro,到期后切换,全价购买 is_downgrade = True subscription_type = 'downgrade' else: # 其他情况视为新购 subscription_type = 'new' # 5. 构建结果 result = { 'is_renewal': is_renewal, 'is_upgrade': is_upgrade, 'is_downgrade': is_downgrade, 'subscription_type': subscription_type, 'current_plan': current_plan, 'current_cycle': current_cycle, 'new_plan_price': price, 'original_price': price, # 新套餐原价 'remaining_value': remaining_value, # 当前订阅剩余价值(仅升级时有效) 'original_amount': price, 'discount_amount': 0, 'final_amount': final_price, 'promo_code': None, 'promo_error': None } # 6. 应用优惠码(基于差价后的金额) if promo_code and promo_code.strip(): # 优惠码作用于差价后的金额 promo, error = validate_promo_code(promo_code, to_plan_name, to_cycle, final_price, user_id) if promo: discount = calculate_discount(promo, final_price) result['discount_amount'] = float(discount) result['final_amount'] = final_price - float(discount) result['promo_code'] = promo.code elif error: result['promo_error'] = error return result except Exception as e: return {'error': f'价格计算失败: {str(e)}'} # 保留旧函数以兼容(标记为废弃) def calculate_upgrade_price(user_id, to_plan_name, to_cycle, promo_code=None): """ 【已废弃】旧版升级价格计算函数,保留以兼容旧代码 新代码请使用 calculate_subscription_price_simple """ # 直接调用新函数 return calculate_subscription_price_simple(user_id, to_plan_name, to_cycle, promo_code) def initialize_subscription_plans_safe(): """安全地初始化订阅套餐""" try: if SubscriptionPlan.query.first(): return pro_plan = SubscriptionPlan( name='pro', display_name='Pro 专业版', description='事件关联股票深度分析 | 历史事件智能对比复盘 | 事件概念关联与挖掘 | 概念板块个股追踪 | 概念深度研报与解读 | 个股异动实时预警', monthly_price=0.01, yearly_price=0.08, features=json.dumps([ "基础股票分析工具", "历史数据查询", 
"基础财务报表", "简单投资计划记录", "标准客服支持" ]), sort_order=1 ) max_plan = SubscriptionPlan( name='max', display_name='Max 旗舰版', description='包含Pro版全部功能 | 事件传导链路智能分析 | 概念演变时间轴追溯 | 个股全方位深度研究 | 价小前投研助手无限使用 | 新功能优先体验权 | 专属客服一对一服务', monthly_price=0.1, yearly_price=0.8, features=json.dumps([ "全部Pro版本功能", "高级分析工具", "实时数据推送", "专业财务分析报告", "AI投资建议", "无限投资计划存储", "优先客服支持", "独家研报访问" ]), sort_order=2 ) db.session.add(pro_plan) db.session.add(max_plan) db.session.commit() except Exception as e: pass # -------------------------------------------- # 订阅等级工具函数 # -------------------------------------------- def _get_current_subscription_info(): """获取当前登录用户订阅信息的字典形式,未登录或异常时视为免费用户。""" try: user_id = session.get('user_id') if not user_id: return { 'type': 'free', 'status': 'active', 'is_active': True } sub = get_user_subscription_safe(user_id) data = sub.to_dict() # 标准化字段名 return { 'type': data.get('type') or data.get('subscription_type') or 'free', 'status': data.get('status') or data.get('subscription_status') or 'active', 'is_active': data.get('is_active', True) } except Exception: return { 'type': 'free', 'status': 'active', 'is_active': True } def _subscription_level(sub_type): """将订阅类型映射到等级数值,free=0, pro=1, max=2。""" mapping = {'free': 0, 'pro': 1, 'max': 2} return mapping.get((sub_type or 'free').lower(), 0) def _has_required_level(required: str) -> bool: """判断当前用户是否达到所需订阅级别。""" info = _get_current_subscription_info() if not info.get('is_active', True): return False return _subscription_level(info.get('type')) >= _subscription_level(required) # ============================================ # 微信开放平台域名校验 # ============================================ @app.route('/gvQnxIQ5Rs.txt', methods=['GET']) def wechat_domain_verify(): """微信开放平台域名校验文件""" return 'd526e9e857dbd2621e5100811972e8c5', 200, {'Content-Type': 'text/plain'} @app.route('/MP_verify_17Fo4JhapMw6vtNa.txt', methods=['GET']) def wechat_mp_domain_verify(): """微信公众号网页授权域名校验文件""" return '17Fo4JhapMw6vtNa', 200, {'Content-Type': 'text/plain'} 
# ============================================ # 订阅相关API接口 # ============================================ @app.route('/api/subscription/plans', methods=['GET']) def get_subscription_plans(): """获取订阅套餐列表""" try: plans = SubscriptionPlan.query.filter_by(is_active=True).order_by(SubscriptionPlan.sort_order).all() return jsonify({ 'success': True, 'data': [plan.to_dict() for plan in plans] }) except Exception as e: # 返回默认套餐(包含pricing_options以兼容新前端) default_plans = [ { 'id': 1, 'name': 'pro', 'display_name': 'Pro版本', 'description': '适合个人投资者的基础功能套餐', 'monthly_price': 198, 'yearly_price': 2000, 'pricing_options': [ {'months': 1, 'price': 198, 'label': '月付', 'cycle_key': 'monthly'}, {'months': 3, 'price': 534, 'label': '3个月', 'cycle_key': '3months', 'discount_percent': 10}, {'months': 6, 'price': 950, 'label': '半年', 'cycle_key': '6months', 'discount_percent': 20}, {'months': 12, 'price': 2000, 'label': '1年', 'cycle_key': 'yearly', 'discount_percent': 16}, {'months': 24, 'price': 3600, 'label': '2年', 'cycle_key': '2years', 'discount_percent': 24}, {'months': 36, 'price': 5040, 'label': '3年', 'cycle_key': '3years', 'discount_percent': 29} ], 'features': ['基础股票分析工具', '历史数据查询', '基础财务报表', '简单投资计划记录', '标准客服支持'], 'is_active': True, 'sort_order': 1 }, { 'id': 2, 'name': 'max', 'display_name': 'Max版本', 'description': '适合专业投资者的全功能套餐', 'monthly_price': 998, 'yearly_price': 10000, 'pricing_options': [ {'months': 1, 'price': 998, 'label': '月付', 'cycle_key': 'monthly'}, {'months': 3, 'price': 2695, 'label': '3个月', 'cycle_key': '3months', 'discount_percent': 10}, {'months': 6, 'price': 4790, 'label': '半年', 'cycle_key': '6months', 'discount_percent': 20}, {'months': 12, 'price': 10000, 'label': '1年', 'cycle_key': 'yearly', 'discount_percent': 17}, {'months': 24, 'price': 18000, 'label': '2年', 'cycle_key': '2years', 'discount_percent': 25}, {'months': 36, 'price': 25200, 'label': '3年', 'cycle_key': '3years', 'discount_percent': 30} ], 'features': ['全部Pro版本功能', '高级分析工具', '实时数据推送', 'API访问', 
'优先客服支持'], 'is_active': True, 'sort_order': 2 } ] return jsonify({ 'success': True, 'data': default_plans }) @app.route('/api/subscription/current', methods=['GET']) def get_current_subscription(): """获取当前用户的订阅信息""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 subscription = get_user_subscription_safe(session['user_id']) return jsonify({ 'success': True, 'data': subscription.to_dict() }) except Exception as e: return jsonify({ 'success': True, 'data': { 'type': 'free', 'status': 'active', 'is_active': True, 'days_left': 999 } }) @app.route('/api/subscription/info', methods=['GET']) def get_subscription_info(): """获取当前用户的订阅信息 - 前端专用接口""" try: info = _get_current_subscription_info() return jsonify({ 'success': True, 'data': info }) except Exception as e: print(f"获取订阅信息错误: {e}") return jsonify({ 'success': True, 'data': { 'type': 'free', 'status': 'active', 'is_active': True, 'days_left': 999 } }) @app.route('/api/promo-code/validate', methods=['POST']) def validate_promo_code_api(): """验证优惠码""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 data = request.get_json() code = data.get('code', '').strip() plan_name = data.get('plan_name') billing_cycle = data.get('billing_cycle') amount = data.get('amount', 0) if not code or not plan_name or not billing_cycle: return jsonify({'success': False, 'error': '参数不完整'}), 400 # 验证优惠码 promo, error = validate_promo_code(code, plan_name, billing_cycle, amount, session['user_id']) if error: return jsonify({ 'success': False, 'valid': False, 'error': error }) # 计算折扣 discount_amount = calculate_discount(promo, amount) final_amount = amount - discount_amount return jsonify({ 'success': True, 'valid': True, 'promo_code': promo.to_dict(), 'discount_amount': discount_amount, 'final_amount': final_amount }) except Exception as e: return jsonify({ 'success': False, 'error': f'验证失败: {str(e)}' }), 500 @app.route('/api/subscription/calculate-price', 
methods=['POST']) def calculate_subscription_price(): """ 计算订阅价格(新版:续费和新购价格一致) Request Body: { "to_plan": "pro", "to_cycle": "yearly", "promo_code": "WELCOME2025" // 可选 } Response: { "success": true, "data": { "is_renewal": true, // 是否为续费 "subscription_type": "renew", // new 或 renew "current_plan": "pro", // 当前套餐(如果有) "current_cycle": "monthly", // 当前周期(如果有) "new_plan_price": 2699.00, "original_amount": 2699.00, "discount_amount": 0, "final_amount": 2699.00, "promo_code": null, "promo_error": null } } """ try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 data = request.get_json() to_plan = data.get('to_plan') to_cycle = data.get('to_cycle') promo_code = (data.get('promo_code') or '').strip() or None if not to_plan or not to_cycle: return jsonify({'success': False, 'error': '参数不完整'}), 400 # 使用新的简化价格计算函数 result = calculate_subscription_price_simple(session['user_id'], to_plan, to_cycle, promo_code) if 'error' in result: return jsonify({ 'success': False, 'error': result['error'] }), 400 return jsonify({ 'success': True, 'data': result }) except Exception as e: return jsonify({ 'success': False, 'error': f'计算失败: {str(e)}' }), 500 @app.route('/api/subscription/free-upgrade', methods=['POST']) @login_required def free_upgrade_subscription(): """ 免费升级订阅(当剩余价值 >= 新套餐价格时) Request Body: { "plan_name": "max", "billing_cycle": "yearly" } """ try: data = request.get_json() plan_name = data.get('plan_name') billing_cycle = data.get('billing_cycle') if not plan_name or not billing_cycle: return jsonify({'success': False, 'error': '参数不完整'}), 400 user_id = current_user.id # 计算价格,验证是否可以免费升级 price_result = calculate_subscription_price_simple(user_id, plan_name, billing_cycle, None) if 'error' in price_result: return jsonify({'success': False, 'error': price_result['error']}), 400 # 检查是否为升级且实付金额为0 if not price_result.get('is_upgrade') or price_result.get('final_amount', 1) > 0: return jsonify({'success': False, 'error': '当前情况不符合免费升级条件'}), 400 # 
获取当前订阅 subscription = UserSubscription.query.filter_by(user_id=user_id).first() if not subscription: return jsonify({'success': False, 'error': '未找到订阅记录'}), 404 # 计算新的到期时间(按剩余价值折算) remaining_value = price_result.get('remaining_value', 0) new_plan_price = price_result.get('new_plan_price', 0) if new_plan_price > 0: # 计算可以兑换的新套餐天数 value_ratio = remaining_value / new_plan_price cycle_days_map = { 'monthly': 30, 'quarterly': 90, 'semiannual': 180, 'yearly': 365 } new_cycle_days = cycle_days_map.get(billing_cycle, 365) # 新的到期天数 = 周期天数 × 价值比例 new_days = int(new_cycle_days * value_ratio) # 更新订阅信息 subscription.subscription_type = plan_name subscription.billing_cycle = billing_cycle subscription.start_date = datetime.utcnow() subscription.end_date = datetime.utcnow() + timedelta(days=new_days) subscription.subscription_status = 'active' subscription.updated_at = datetime.utcnow() db.session.commit() return jsonify({ 'success': True, 'message': f'升级成功!您的{plan_name.upper()}版本将持续{new_days}天', 'data': { 'subscription_type': plan_name, 'end_date': subscription.end_date.isoformat(), 'days': new_days } }) else: return jsonify({'success': False, 'error': '价格计算异常'}), 500 except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': f'升级失败: {str(e)}'}), 500 @app.route('/api/payment/create-order', methods=['POST']) def create_payment_order(): """ 创建支付订单(新版:简化逻辑,不再记录升级) Request Body: { "plan_name": "pro", "billing_cycle": "yearly", "promo_code": "WELCOME2025" // 可选 } """ try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 data = request.get_json() plan_name = data.get('plan_name') billing_cycle = data.get('billing_cycle') promo_code = (data.get('promo_code') or '').strip() or None if not plan_name or not billing_cycle: return jsonify({'success': False, 'error': '参数不完整'}), 400 # 使用新的简化价格计算 price_result = calculate_subscription_price_simple(session['user_id'], plan_name, billing_cycle, promo_code) if 'error' in price_result: 
return jsonify({'success': False, 'error': price_result['error']}), 400 amount = price_result['final_amount'] subscription_type = price_result.get('subscription_type', 'new') # new 或 renew # 检查是否为免费升级(金额为0) if amount <= 0 and price_result.get('is_upgrade'): return jsonify({ 'success': False, 'error': '当前剩余价值可直接免费升级,请使用免费升级功能', 'should_free_upgrade': True, 'price_info': price_result }), 400 # 创建订单 try: # 获取原价和折扣金额 original_amount = price_result.get('original_amount', amount) discount_amount = price_result.get('discount_amount', 0) order = PaymentOrder( user_id=session['user_id'], plan_name=plan_name, billing_cycle=billing_cycle, amount=amount, original_amount=original_amount, discount_amount=discount_amount ) # 添加订阅类型标记(用于前端展示) order.remark = f"{subscription_type}订阅" if subscription_type == 'renew' else "新购订阅" # 如果使用了优惠码,关联优惠码 if promo_code and price_result.get('promo_code'): promo_obj = PromoCode.query.filter_by(code=promo_code.upper()).first() if promo_obj: order.promo_code_id = promo_obj.id print(f"📦 订单关联优惠码: {promo_obj.code} (ID: {promo_obj.id})") db.session.add(order) db.session.commit() except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': f'订单创建失败: {str(e)}'}), 500 # 尝试调用真实的微信支付API(使用 subprocess 绕过 eventlet DNS 问题) try: import subprocess import urllib.parse # 使用独立脚本检查配置 script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'wechat_pay_worker.py') # 先检查配置 check_result = subprocess.run( [sys.executable, script_path, 'check'], capture_output=True, text=True, timeout=10 ) if check_result.returncode != 0: check_data = json.loads(check_result.stdout) if check_result.stdout else {} ready_msg = check_data.get('error', check_data.get('message', '未知错误')) order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" order.remark = f"演示模式 - {ready_msg}" else: # 创建微信支付订单 plan_display_name = f"{plan_name.upper()}版本-{billing_cycle}" body = f"VFr-{plan_display_name}" 
product_id = f"{plan_name}_{billing_cycle}" create_result = subprocess.run( [sys.executable, script_path, 'create', order.order_no, str(float(amount)), body, product_id], capture_output=True, text=True, timeout=60 ) print(f"[微信支付] 创建订单返回: {create_result.stdout}") if create_result.stderr: print(f"[微信支付] 错误输出: {create_result.stderr}") wechat_result = json.loads(create_result.stdout) if create_result.stdout else {'success': False, 'error': '无返回'} if wechat_result.get('success'): # 获取微信返回的原始code_url wechat_code_url = wechat_result['code_url'] # 将微信协议URL转换为二维码图片URL encoded_url = urllib.parse.quote(wechat_code_url, safe='') qr_image_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data={encoded_url}" order.qr_code_url = qr_image_url order.prepay_id = wechat_result.get('prepay_id') order.remark = f"微信支付 - {wechat_code_url}" else: order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" order.remark = f"微信支付失败: {wechat_result.get('error')}" except subprocess.TimeoutExpired: order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" order.remark = "微信支付超时" except json.JSONDecodeError as e: order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" order.remark = f"微信支付返回解析失败: {str(e)}" except Exception as e: import traceback print(f"[微信支付] Exception: {e}") traceback.print_exc() order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" order.remark = f"支付异常: {str(e)}" db.session.commit() return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '订单创建成功' }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': '创建订单失败'}), 500 @app.route('/api/payment/order//status', methods=['GET']) def check_order_status(order_id): """查询订单支付状态""" try: if 'user_id' not in session: return 
jsonify({'success': False, 'error': '未登录'}), 401 # 查找订单 order = PaymentOrder.query.filter_by( id=order_id, user_id=session['user_id'] ).first() if not order: return jsonify({'success': False, 'error': '订单不存在'}), 404 # 如果订单已经是已支付状态,直接返回 if order.status == 'paid': return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '订单已支付', 'payment_success': True }) # 如果订单过期,标记为过期 if order.is_expired(): order.status = 'expired' db.session.commit() return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '订单已过期' }) # 调用微信支付API查询真实状态(使用 subprocess 绕过 eventlet DNS 问题) try: import subprocess script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'wechat_pay_worker.py') query_proc = subprocess.run( [sys.executable, script_path, 'query', order.order_no], capture_output=True, text=True, timeout=30 ) query_result = json.loads(query_proc.stdout) if query_proc.stdout else {'success': False, 'error': '无返回'} if query_result.get('success'): trade_state = query_result.get('trade_state') transaction_id = query_result.get('transaction_id') if trade_state == 'SUCCESS': # 支付成功,更新订单状态 order.mark_as_paid(transaction_id) # 激活用户订阅 activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) # 记录优惠码使用情况 if order.promo_code_id: try: existing_usage = PromoCodeUsage.query.filter_by(order_id=order.id).first() if not existing_usage: usage = PromoCodeUsage( promo_code_id=order.promo_code_id, user_id=order.user_id, order_id=order.id, original_amount=order.original_amount or order.amount, discount_amount=order.discount_amount or 0, final_amount=order.amount ) db.session.add(usage) promo = PromoCode.query.get(order.promo_code_id) if promo: promo.current_uses = (promo.current_uses or 0) + 1 print(f"🎫 优惠码使用记录已创建: {promo.code}") except Exception as e: print(f"⚠️ 记录优惠码使用失败: {e}") db.session.commit() return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '支付成功!订阅已激活', 'payment_success': True }) elif trade_state in ['NOTPAY', 'USERPAYING']: 
# 未支付或支付中 return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '等待支付...', 'payment_success': False }) else: # 支付失败或取消 order.status = 'cancelled' db.session.commit() return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '支付已取消', 'payment_success': False }) else: # 微信查询失败,返回当前状态 return jsonify({ 'success': True, 'data': order.to_dict(), 'message': f"查询失败: {query_result.get('error')}", 'payment_success': False }) except Exception as e: # 查询失败,返回当前订单状态 return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '无法查询支付状态,请稍后重试', 'payment_success': False }) except Exception as e: return jsonify({'success': False, 'error': '查询失败'}), 500 @app.route('/api/payment/order//force-update', methods=['POST']) def force_update_order_status(order_id): """强制更新订单支付状态(调试用)""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 # 查找订单 order = PaymentOrder.query.filter_by( id=order_id, user_id=session['user_id'] ).first() if not order: return jsonify({'success': False, 'error': '订单不存在'}), 404 # 检查微信支付状态(使用 subprocess 绕过 eventlet DNS 问题) try: import subprocess script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'wechat_pay_worker.py') query_proc = subprocess.run( [sys.executable, script_path, 'query', order.order_no], capture_output=True, text=True, timeout=30 ) query_result = json.loads(query_proc.stdout) if query_proc.stdout else {'success': False, 'error': '无返回'} if query_result.get('success') and query_result.get('trade_state') == 'SUCCESS': # 强制更新为已支付 old_status = order.status order.mark_as_paid(query_result.get('transaction_id')) # 激活用户订阅 activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) # 记录优惠码使用(如果使用了优惠码) if order.promo_code_id: try: # 检查是否已经记录过(防止重复) existing_usage = PromoCodeUsage.query.filter_by(order_id=order.id).first() if not existing_usage: promo_usage = PromoCodeUsage( promo_code_id=order.promo_code_id, user_id=order.user_id, order_id=order.id, 
original_amount=order.original_amount or order.amount, discount_amount=order.discount_amount or 0, final_amount=order.amount ) db.session.add(promo_usage) # 更新优惠码使用次数 promo = PromoCode.query.get(order.promo_code_id) if promo: promo.current_uses = (promo.current_uses or 0) + 1 print(f"🎫 优惠码使用记录已创建: {promo.code}") else: print(f"ℹ️ 优惠码使用记录已存在,跳过") except Exception as e: print(f"⚠️ 记录优惠码使用失败: {e}") db.session.commit() print(f"✅ 订单状态强制更新成功: {old_status} -> paid") return jsonify({ 'success': True, 'message': f'订单状态已从 {old_status} 更新为 paid', 'data': order.to_dict(), 'payment_success': True }) else: return jsonify({ 'success': False, 'error': '微信支付状态不是成功状态,无法强制更新' }) except Exception as e: print(f"❌ 强制更新失败: {e}") return jsonify({ 'success': False, 'error': f'强制更新失败: {str(e)}' }) except Exception as e: print(f"强制更新订单状态失败: {str(e)}") return jsonify({'success': False, 'error': '操作失败'}), 500 @app.route('/api/payment/wechat/callback', methods=['POST']) def wechat_payment_callback(): """微信支付回调处理""" try: # 获取原始XML数据 raw_data = request.get_data() print(f"📥 收到微信支付回调: {raw_data}") # 验证回调数据 try: from wechat_pay import create_wechat_pay_instance wechat_pay = create_wechat_pay_instance() verify_result = wechat_pay.verify_callback(raw_data.decode('utf-8')) if not verify_result['success']: print(f"❌ 微信支付回调验证失败: {verify_result['error']}") return '' callback_data = verify_result['data'] except Exception as e: print(f"❌ 微信支付回调处理异常: {e}") # 简单解析XML(fallback) callback_data = _parse_xml_callback(raw_data.decode('utf-8')) if not callback_data: return '' # 获取关键字段 return_code = callback_data.get('return_code') result_code = callback_data.get('result_code') order_no = callback_data.get('out_trade_no') transaction_id = callback_data.get('transaction_id') print(f"📦 回调数据解析:") print(f" 返回码: {return_code}") print(f" 结果码: {result_code}") print(f" 订单号: {order_no}") print(f" 交易号: {transaction_id}") if not order_no: return '' # 查找订单 order = PaymentOrder.query.filter_by(order_no=order_no).first() if not 
order: print(f"❌ 订单不存在: {order_no}") return '' # 处理支付成功 if return_code == 'SUCCESS' and result_code == 'SUCCESS': print(f"🎉 支付回调成功: 订单 {order_no}") # 检查订单是否已经处理过 if order.status == 'paid': print(f"ℹ️ 订单已处理过: {order_no}") db.session.commit() return '' # 更新订单状态(无论之前是什么状态) old_status = order.status order.mark_as_paid(transaction_id) print(f"📝 订单状态已更新: {old_status} -> paid") # 激活用户订阅 subscription = activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) if subscription: print(f"✅ 用户订阅已激活: 用户{order.user_id}, 套餐{order.plan_name}") else: print(f"⚠️ 订阅激活失败,但订单已标记为已支付") # 记录优惠码使用情况 if order.promo_code_id: try: # 检查是否已经记录过(防止重复) existing_usage = PromoCodeUsage.query.filter_by( order_id=order.id ).first() if not existing_usage: # 创建优惠码使用记录 usage = PromoCodeUsage( promo_code_id=order.promo_code_id, user_id=order.user_id, order_id=order.id, original_amount=order.original_amount or order.amount, discount_amount=order.discount_amount or 0, final_amount=order.amount ) db.session.add(usage) # 更新优惠码使用次数 promo = PromoCode.query.get(order.promo_code_id) if promo: promo.current_uses = (promo.current_uses or 0) + 1 print(f"🎫 优惠码使用记录已创建: {promo.code}, 当前使用次数: {promo.current_uses}") else: print(f"ℹ️ 优惠码使用记录已存在,跳过") except Exception as e: print(f"⚠️ 记录优惠码使用失败: {e}") # 不影响主流程,继续执行 db.session.commit() # 返回成功响应给微信 return '' except Exception as e: db.session.rollback() print(f"❌ 微信支付回调处理失败: {e}") import traceback app.logger.error(f"回调处理错误: {e}", exc_info=True) return '' def _parse_xml_callback(xml_data): """简单的XML回调数据解析""" try: import xml.etree.ElementTree as ET root = ET.fromstring(xml_data) result = {} for child in root: result[child.tag] = child.text return result except Exception as e: print(f"XML解析失败: {e}") return None # ======================================== # 支付宝支付相关API # ======================================== @app.route('/api/payment/alipay/create-order', methods=['POST']) def create_alipay_order(): """ 创建支付宝支付订单 Request Body: { "plan_name": "pro", 
"billing_cycle": "yearly", "promo_code": "WELCOME2025", // 可选 "is_mobile": true // 可选,是否为手机端(自动使用 WAP 支付) } """ try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 data = request.get_json() plan_name = data.get('plan_name') billing_cycle = data.get('billing_cycle') promo_code = (data.get('promo_code') or '').strip() or None # 前端传入的设备类型,用于决定使用 page 支付还是 wap 支付 is_mobile = data.get('is_mobile', False) if not plan_name or not billing_cycle: return jsonify({'success': False, 'error': '参数不完整'}), 400 # 使用简化价格计算 price_result = calculate_subscription_price_simple(session['user_id'], plan_name, billing_cycle, promo_code) if 'error' in price_result: return jsonify({'success': False, 'error': price_result['error']}), 400 amount = price_result['final_amount'] subscription_type = price_result.get('subscription_type', 'new') # 检查是否为免费升级 if amount <= 0 and price_result.get('is_upgrade'): return jsonify({ 'success': False, 'error': '当前剩余价值可直接免费升级,请使用免费升级功能', 'should_free_upgrade': True, 'price_info': price_result }), 400 # 创建订单 try: original_amount = price_result.get('original_amount', amount) discount_amount = price_result.get('discount_amount', 0) order = PaymentOrder( user_id=session['user_id'], plan_name=plan_name, billing_cycle=billing_cycle, amount=amount, original_amount=original_amount, discount_amount=discount_amount ) # 设置支付方式为支付宝 order.payment_method = 'alipay' order.remark = f"{subscription_type}订阅" if subscription_type == 'renew' else "新购订阅" # 关联优惠码 if promo_code and price_result.get('promo_code'): promo_obj = PromoCode.query.filter_by(code=promo_code.upper()).first() if promo_obj: order.promo_code_id = promo_obj.id print(f"📦 订单关联优惠码: {promo_obj.code} (ID: {promo_obj.id})") db.session.add(order) db.session.commit() except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': f'订单创建失败: {str(e)}'}), 500 # 调用支付宝支付API(使用 subprocess 绕过 eventlet DNS 问题) try: import subprocess script_path = 
os.path.join(os.path.dirname(os.path.abspath(__file__)), 'alipay_pay_worker.py') # 先检查配置 check_result = subprocess.run( [sys.executable, script_path, 'check'], capture_output=True, text=True, timeout=10 ) if check_result.returncode != 0: check_data = json.loads(check_result.stdout) if check_result.stdout else {} error_msg = check_data.get('error', check_data.get('message', '支付宝配置错误')) order.remark = f"支付宝配置错误 - {error_msg}" db.session.commit() return jsonify({ 'success': False, 'error': f'支付宝支付暂不可用: {error_msg}' }), 500 # 创建支付宝订单 plan_display_name = f"{plan_name.upper()}版本-{billing_cycle}" subject = f"VFr-{plan_display_name}" body = f"价值前沿订阅服务-{plan_display_name}" # 金额格式化为两位小数(支付宝要求) amount_str = f"{float(amount):.2f}" # 根据设备类型选择支付方式:wap=手机网站支付,page=电脑网站支付 pay_type = 'wap' if is_mobile else 'page' print(f"[支付宝] 设备类型: {'手机' if is_mobile else '电脑'}, 支付方式: {pay_type}") create_result = subprocess.run( [sys.executable, script_path, 'create', order.order_no, amount_str, subject, body, pay_type], capture_output=True, text=True, timeout=60 ) print(f"[支付宝] 创建订单返回: {create_result.stdout}") if create_result.stderr: print(f"[支付宝] 错误输出: {create_result.stderr}") alipay_result = json.loads(create_result.stdout) if create_result.stdout else {'success': False, 'error': '无返回'} if alipay_result.get('success'): # 获取支付宝返回的支付链接 pay_url = alipay_result['pay_url'] order.pay_url = pay_url order.remark = f"支付宝支付 - 订单已创建" db.session.commit() return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '订单创建成功' }) else: order.remark = f"支付宝支付失败: {alipay_result.get('error')}" db.session.commit() return jsonify({ 'success': False, 'error': f"支付宝订单创建失败: {alipay_result.get('error')}" }), 500 except subprocess.TimeoutExpired: order.remark = "支付宝支付超时" db.session.commit() return jsonify({'success': False, 'error': '支付宝支付超时'}), 500 except json.JSONDecodeError as e: order.remark = f"支付宝返回解析失败: {str(e)}" db.session.commit() return jsonify({'success': False, 'error': '支付宝返回数据异常'}), 500 except 
Exception as e: import traceback print(f"[支付宝] Exception: {e}") traceback.print_exc() order.remark = f"支付异常: {str(e)}" db.session.commit() return jsonify({'success': False, 'error': '支付异常'}), 500 except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': '创建订单失败'}), 500 @app.route('/api/payment/alipay/callback', methods=['POST']) def alipay_payment_callback(): """支付宝异步回调处理""" try: # 获取POST参数 callback_params = request.form.to_dict() print(f"📥 收到支付宝支付回调: {callback_params}") # 验证回调数据 try: from alipay_pay import create_alipay_instance alipay = create_alipay_instance() verify_result = alipay.verify_callback(callback_params.copy()) if not verify_result['success']: print(f"❌ 支付宝回调签名验证失败: {verify_result['error']}") return 'fail' callback_data = verify_result['data'] except Exception as e: print(f"❌ 支付宝回调处理异常: {e}") return 'fail' # 获取关键字段 trade_status = callback_data.get('trade_status') out_trade_no = callback_data.get('out_trade_no') # 商户订单号 trade_no = callback_data.get('trade_no') # 支付宝交易号 total_amount = callback_data.get('total_amount') print(f"📦 支付宝回调数据解析:") print(f" 交易状态: {trade_status}") print(f" 订单号: {out_trade_no}") print(f" 交易号: {trade_no}") print(f" 金额: {total_amount}") if not out_trade_no: print("❌ 缺少订单号") return 'fail' # 查找订单 order = PaymentOrder.query.filter_by(order_no=out_trade_no).first() if not order: print(f"❌ 订单不存在: {out_trade_no}") return 'fail' # 只处理交易成功的回调 if trade_status in ['TRADE_SUCCESS', 'TRADE_FINISHED']: print(f"🎉 支付宝支付成功: 订单 {out_trade_no}") # 检查订单是否已经处理过 if order.status == 'paid': print(f"ℹ️ 订单已处理过: {out_trade_no}") return 'success' # 更新订单状态 old_status = order.status order.mark_as_paid(trade_no, 'alipay') print(f"📝 订单状态已更新: {old_status} -> paid") # 激活用户订阅 subscription = activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) if subscription: print(f"✅ 用户订阅已激活: 用户{order.user_id}, 套餐{order.plan_name}") else: print(f"⚠️ 订阅激活失败,但订单已标记为已支付") # 记录优惠码使用情况 if order.promo_code_id: try: existing_usage 
= PromoCodeUsage.query.filter_by(order_id=order.id).first() if not existing_usage: usage = PromoCodeUsage( promo_code_id=order.promo_code_id, user_id=order.user_id, order_id=order.id, original_amount=order.original_amount or order.amount, discount_amount=order.discount_amount or 0, final_amount=order.amount ) db.session.add(usage) promo = PromoCode.query.get(order.promo_code_id) if promo: promo.current_uses = (promo.current_uses or 0) + 1 print(f"🎫 优惠码使用记录已创建: {promo.code}, 当前使用次数: {promo.current_uses}") else: print(f"ℹ️ 优惠码使用记录已存在,跳过") except Exception as e: print(f"⚠️ 记录优惠码使用失败: {e}") db.session.commit() elif trade_status == 'TRADE_CLOSED': # 交易关闭 if order.status not in ['paid', 'cancelled']: order.status = 'cancelled' db.session.commit() print(f"📝 订单已关闭: {out_trade_no}") # 返回成功响应给支付宝 return 'success' except Exception as e: db.session.rollback() print(f"❌ 支付宝回调处理失败: {e}") import traceback app.logger.error(f"支付宝回调处理错误: {e}", exc_info=True) return 'fail' @app.route('/api/payment/alipay/return', methods=['GET']) def alipay_payment_return(): """支付宝同步返回处理(用户支付后跳转回来)""" try: # 获取GET参数 return_params = request.args.to_dict() print(f"📥 支付宝同步返回: {return_params}") out_trade_no = return_params.get('out_trade_no') if out_trade_no: # 重定向到前端支付结果页面 return redirect(f'{FRONTEND_URL}/pricing?payment_return=alipay&order_no={out_trade_no}') else: return redirect(f'{FRONTEND_URL}/pricing?payment_return=alipay&error=missing_order') except Exception as e: print(f"❌ 支付宝同步返回处理失败: {e}") return redirect(f'{FRONTEND_URL}/pricing?payment_return=alipay&error=exception') @app.route('/api/payment/alipay/order//status', methods=['GET']) def check_alipay_order_status(order_id): """查询支付宝订单支付状态""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 # 查找订单 order = PaymentOrder.query.filter_by( id=order_id, user_id=session['user_id'] ).first() if not order: return jsonify({'success': False, 'error': '订单不存在'}), 404 # 如果订单已经是已支付状态,直接返回 if order.status == 'paid': 
return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '订单已支付', 'payment_success': True }) # 如果订单过期,标记为过期 if order.is_expired(): order.status = 'expired' db.session.commit() return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '订单已过期' }) # 调用支付宝API查询真实状态 try: import subprocess script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'alipay_pay_worker.py') query_proc = subprocess.run( [sys.executable, script_path, 'query', order.order_no], capture_output=True, text=True, timeout=30 ) query_result = json.loads(query_proc.stdout) if query_proc.stdout else {'success': False, 'error': '无返回'} if query_result.get('success'): trade_state = query_result.get('trade_state') trade_no = query_result.get('trade_no') if trade_state == 'SUCCESS': # 支付成功,更新订单状态 order.mark_as_paid(trade_no, 'alipay') # 激活用户订阅 activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) # 记录优惠码使用情况 if order.promo_code_id: try: existing_usage = PromoCodeUsage.query.filter_by(order_id=order.id).first() if not existing_usage: usage = PromoCodeUsage( promo_code_id=order.promo_code_id, user_id=order.user_id, order_id=order.id, original_amount=order.original_amount or order.amount, discount_amount=order.discount_amount or 0, final_amount=order.amount ) db.session.add(usage) promo = PromoCode.query.get(order.promo_code_id) if promo: promo.current_uses = (promo.current_uses or 0) + 1 print(f"🎫 优惠码使用记录已创建: {promo.code}") except Exception as e: print(f"⚠️ 记录优惠码使用失败: {e}") db.session.commit() return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '支付成功!订阅已激活', 'payment_success': True }) elif trade_state in ['NOTPAY', 'WAIT_BUYER_PAY']: # 未支付或等待支付 return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '等待支付...', 'payment_success': False }) elif trade_state in ['CLOSED', 'TRADE_CLOSED']: # 交易关闭 order.status = 'cancelled' db.session.commit() return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '交易已关闭', 
'payment_success': False }) else: # 其他状态 return jsonify({ 'success': True, 'data': order.to_dict(), 'message': f'当前状态: {trade_state}', 'payment_success': False }) else: # 支付宝查询失败,返回当前状态 return jsonify({ 'success': True, 'data': order.to_dict(), 'message': f"查询失败: {query_result.get('error')}", 'payment_success': False }) except Exception as e: # 查询失败,返回当前订单状态 return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '无法查询支付状态,请稍后重试', 'payment_success': False }) except Exception as e: return jsonify({'success': False, 'error': '查询失败'}), 500 @app.route('/api/payment/alipay/order-by-no//status', methods=['GET']) def check_alipay_order_status_by_no(order_no): """通过订单号查询支付宝订单支付状态(用于手机端支付返回)""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 # 通过订单号查找订单 order = PaymentOrder.query.filter_by( order_no=order_no, user_id=session['user_id'] ).first() if not order: return jsonify({'success': False, 'error': '订单不存在'}), 404 # 复用现有的状态检查逻辑 return check_alipay_order_status(str(order.id)) except Exception as e: return jsonify({'success': False, 'error': '查询失败'}), 500 @app.route('/api/auth/session', methods=['GET']) def get_session_info(): """获取当前登录用户信息""" if 'user_id' in session: user = User.query.get(session['user_id']) if user: # 获取用户订阅信息 subscription_info = get_user_subscription_safe(user.id).to_dict() return jsonify({ 'success': True, 'isAuthenticated': True, 'user': { 'id': user.id, 'username': user.username, 'nickname': user.nickname or user.username, 'email': user.email, 'phone': user.phone, 'phone_confirmed': bool(user.phone_confirmed), 'email_confirmed': bool(user.email_confirmed) if hasattr(user, 'email_confirmed') else None, 'avatar_url': user.avatar_url, 'has_wechat': bool(user.wechat_open_id), 'created_at': user.created_at.isoformat() if user.created_at else None, 'last_seen': user.last_seen.isoformat() if user.last_seen else None, # 将订阅字段映射到前端期望的字段名 'subscription_type': subscription_info['type'], 
'subscription_status': subscription_info['status'], 'subscription_end_date': subscription_info['end_date'], 'is_subscription_active': subscription_info['is_active'], 'subscription_days_left': subscription_info['days_left'] } }) return jsonify({ 'success': True, 'isAuthenticated': False, 'user': None }) def generate_verification_code(): """生成6位数字验证码""" return ''.join(random.choices(string.digits, k=6)) @app.route('/api/auth/login', methods=['POST']) def login(): """传统登录 - 使用Session""" try: username = request.form.get('username') email = request.form.get('email') phone = request.form.get('phone') password = request.form.get('password') # 验证必要参数 if not password: return jsonify({'success': False, 'error': '密码不能为空'}), 400 # 根据提供的信息查找用户 user = None if username: # 检查username是否为手机号格式 if re.match(r'^1[3-9]\d{9}$', username): # 如果username是手机号格式,先按手机号查找 user = User.query.filter_by(phone=username).first() if not user: # 如果没找到,再按用户名查找 user = User.find_by_login_info(username) else: # 不是手机号格式,按用户名查找 user = User.find_by_login_info(username) elif email: user = User.query.filter_by(email=email).first() elif phone: user = User.query.filter_by(phone=phone).first() else: return jsonify({'success': False, 'error': '请提供用户名、邮箱或手机号'}), 400 if not user: return jsonify({'success': False, 'error': '用户不存在'}), 404 # 尝试密码验证 password_valid = user.check_password(password) if not password_valid: # 还可以尝试直接验证 if user.password_hash: from werkzeug.security import check_password_hash direct_check = check_password_hash(user.password_hash, password) return jsonify({'success': False, 'error': '密码错误'}), 401 # 设置session session.permanent = True # 使用永久session session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True # Flask-Login 登录 login_user(user, remember=True) # 更新最后登录时间 user.update_last_seen() return jsonify({ 'success': True, 'message': '登录成功', 'user': { 'id': user.id, 'username': user.username, 'nickname': user.nickname or user.username, 'email': user.email, 'phone': 
user.phone, 'avatar_url': user.avatar_url, 'has_wechat': bool(user.wechat_open_id) } }) except Exception as e: import traceback app.logger.error(f"回调处理错误: {e}", exc_info=True) return jsonify({'success': False, 'error': '登录处理失败,请重试'}), 500 # 添加OPTIONS请求处理 @app.before_request def handle_preflight(): if request.method == "OPTIONS": response = make_response() response.headers.add("Access-Control-Allow-Origin", "*") response.headers.add('Access-Control-Allow-Headers', "*") response.headers.add('Access-Control-Allow-Methods', "*") return response # 修改密码API @app.route('/api/account/change-password', methods=['POST']) @login_required def change_password(): """修改当前用户密码""" try: data = request.get_json() or request.form current_password = data.get('currentPassword') or data.get('current_password') new_password = data.get('newPassword') or data.get('new_password') is_first_set = data.get('isFirstSet', False) # 是否为首次设置密码 if not new_password: return jsonify({'success': False, 'error': '新密码不能为空'}), 400 if len(new_password) < 6: return jsonify({'success': False, 'error': '新密码至少需要6个字符'}), 400 # 获取当前用户 user = current_user if not user: return jsonify({'success': False, 'error': '用户未登录'}), 401 # 检查是否为微信用户且首次设置密码 is_wechat_user = bool(user.wechat_open_id) # 如果是微信用户首次设置密码,或者明确标记为首次设置,则跳过当前密码验证 if is_first_set or (is_wechat_user and not current_password): pass # 跳过当前密码验证 else: # 普通用户或非首次设置,需要验证当前密码 if not current_password: return jsonify({'success': False, 'error': '请输入当前密码'}), 400 if not user.check_password(current_password): return jsonify({'success': False, 'error': '当前密码错误'}), 400 # 设置新密码 user.set_password(new_password) db.session.commit() return jsonify({ 'success': True, 'message': '密码设置成功' if (is_first_set or is_wechat_user) else '密码修改成功' }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 # 检查用户密码状态API @app.route('/api/account/password-status', methods=['GET']) @login_required def get_password_status(): """获取当前用户的密码状态信息""" try: user = current_user 
if not user: return jsonify({'success': False, 'error': '用户未登录'}), 401 is_wechat_user = bool(user.wechat_open_id) return jsonify({ 'success': True, 'data': { 'isWechatUser': is_wechat_user, 'hasPassword': bool(user.password_hash), 'needsFirstTimeSetup': is_wechat_user # 微信用户需要首次设置 } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 # 检查用户信息完整性API @app.route('/api/account/profile-completeness', methods=['GET']) @login_required def get_profile_completeness(): try: user = current_user if not user: return jsonify({'success': False, 'error': '用户未登录'}), 401 is_wechat_user = bool(user.wechat_open_id) # 检查各项信息 completeness = { 'hasPassword': bool(user.password_hash), 'hasPhone': bool(user.phone), 'hasEmail': bool(user.email and '@' in user.email and not user.email.endswith('@valuefrontier.temp')), 'isWechatUser': is_wechat_user } # 计算完整度 total_items = 3 completed_items = sum([completeness['hasPassword'], completeness['hasPhone'], completeness['hasEmail']]) completeness_percentage = int((completed_items / total_items) * 100) # 智能判断是否需要提醒 needs_attention = False missing_items = [] # 只在用户首次登录或最近登录时提醒 if is_wechat_user: # 检查用户是否是新用户(注册7天内) is_new_user = (datetime.now() - user.created_at).days < 7 # 检查是否最近没有提醒过(使用session记录) last_reminder = session.get('last_completeness_reminder') should_remind = False if not last_reminder: should_remind = True else: # 每7天最多提醒一次 days_since_reminder = (datetime.now() - datetime.fromisoformat(last_reminder)).days should_remind = days_since_reminder >= 7 # 只对新用户或长时间未完善的用户提醒 if (is_new_user or completeness_percentage < 50) and should_remind: needs_attention = True if not completeness['hasPassword']: missing_items.append('登录密码') if not completeness['hasPhone']: missing_items.append('手机号') if not completeness['hasEmail']: missing_items.append('邮箱') # 记录本次提醒时间 session['last_completeness_reminder'] = datetime.now().isoformat() return jsonify({ 'success': True, 'data': { 'completeness': completeness, 
'completenessPercentage': completeness_percentage, 'needsAttention': needs_attention, 'missingItems': missing_items, 'isComplete': completed_items == total_items, 'showReminder': needs_attention # 前端使用这个字段决定是否显示提醒 } }) except Exception as e: print(f"获取资料完整性错误: {e}") return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/auth/logout', methods=['POST']) def logout(): """登出 - 清除Session""" logout_user() # Flask-Login 登出 session.clear() return jsonify({'success': True, 'message': '已登出'}) @app.route('/api/auth/send-verification-code', methods=['POST']) def send_verification_code(): """发送验证码(支持手机号和邮箱)""" try: data = request.get_json() credential = data.get('credential') # 手机号或邮箱 code_type = data.get('type') # 'phone' 或 'email' purpose = data.get('purpose', 'login') # 'login' 或 'register' if not credential or not code_type: return jsonify({'success': False, 'error': '缺少必要参数'}), 400 # 清理格式字符(空格、横线、括号等) if code_type == 'phone': # 移除手机号中的空格、横线、括号、加号等格式字符 credential = re.sub(r'[\s\-\(\)\+]', '', credential) print(f"📱 清理后的手机号: {credential}") elif code_type == 'email': # 邮箱只移除空格 credential = credential.strip() # 生成验证码 verification_code = generate_verification_code() # 存储验证码到session(实际生产环境建议使用Redis) session_key = f'verification_code_{code_type}_{credential}_{purpose}' session[session_key] = { 'code': verification_code, 'timestamp': time.time(), 'attempts': 0 } if code_type == 'phone': # 手机号验证码发送 if not re.match(r'^1[3-9]\d{9}$', credential): return jsonify({'success': False, 'error': '手机号格式不正确'}), 400 # 发送真实短信验证码 if send_sms_code(credential, verification_code, SMS_TEMPLATE_LOGIN): print(f"[短信已发送] 验证码到 {credential}: {verification_code}") else: return jsonify({'success': False, 'error': '短信发送失败,请稍后重试'}), 500 elif code_type == 'email': # 邮箱验证码发送 if not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', credential): return jsonify({'success': False, 'error': '邮箱格式不正确'}), 400 # 发送真实邮件验证码 if send_email_code(credential, verification_code): print(f"[邮件已发送] 
验证码到 {credential}: {verification_code}") else: return jsonify({'success': False, 'error': '邮件发送失败,请稍后重试'}), 500 else: return jsonify({'success': False, 'error': '不支持的验证码类型'}), 400 return jsonify({ 'success': True, 'message': f'验证码已发送到您的{code_type}' }) except Exception as e: print(f"发送验证码错误: {e}") return jsonify({'success': False, 'error': '发送验证码失败'}), 500 @app.route('/api/auth/login-with-code', methods=['POST']) def login_with_verification_code(): """使用验证码登录/注册(自动注册)""" try: data = request.get_json() credential = data.get('credential') # 手机号或邮箱 verification_code = data.get('verification_code') login_type = data.get('login_type') # 'phone' 或 'email' if not credential or not verification_code or not login_type: return jsonify({'success': False, 'error': '缺少必要参数'}), 400 # 清理格式字符(空格、横线、括号等) if login_type == 'phone': # 移除手机号中的空格、横线、括号、加号等格式字符 original_credential = credential credential = re.sub(r'[\s\-\(\)\+]', '', credential) if original_credential != credential: print(f"📱 登录时清理手机号: {original_credential} -> {credential}") elif login_type == 'email': # 邮箱只移除前后空格 credential = credential.strip() # 检查验证码 session_key = f'verification_code_{login_type}_{credential}_login' stored_code_info = session.get(session_key) if not stored_code_info: return jsonify({'success': False, 'error': '验证码已过期或不存在'}), 400 # 检查验证码是否过期(5分钟) if time.time() - stored_code_info['timestamp'] > 300: session.pop(session_key, None) return jsonify({'success': False, 'error': '验证码已过期'}), 400 # 检查尝试次数 if stored_code_info['attempts'] >= 3: session.pop(session_key, None) return jsonify({'success': False, 'error': '验证码错误次数过多'}), 400 # 验证码错误 if stored_code_info['code'] != verification_code: stored_code_info['attempts'] += 1 session[session_key] = stored_code_info return jsonify({'success': False, 'error': '验证码错误'}), 400 # 验证码正确,查找用户 user = None is_new_user = False if login_type == 'phone': user = User.query.filter_by(phone=credential).first() if not user: # 自动注册新用户 is_new_user = True # 生成唯一用户名 base_username = 
f"user_{credential}" username = base_username counter = 1 while User.query.filter_by(username=username).first(): username = f"{base_username}_{counter}" counter += 1 # 创建新用户 user = User(username=username, phone=credential) user.phone_confirmed = True user.email = f"{username}@valuefrontier.temp" # 临时邮箱 db.session.add(user) db.session.commit() elif login_type == 'email': user = User.query.filter_by(email=credential).first() if not user: # 自动注册新用户 is_new_user = True # 从邮箱生成用户名 email_prefix = credential.split('@')[0] base_username = f"user_{email_prefix}" username = base_username counter = 1 while User.query.filter_by(username=username).first(): username = f"{base_username}_{counter}" counter += 1 # 如果用户不存在,自动创建新用户 if not user: try: # 生成用户名 if login_type == 'phone': # 使用手机号生成用户名 base_username = f"用户{credential[-4:]}" elif login_type == 'email': # 使用邮箱前缀生成用户名 base_username = credential.split('@')[0] else: base_username = "新用户" # 确保用户名唯一 username = base_username counter = 1 while User.is_username_taken(username): username = f"{base_username}_{counter}" counter += 1 # 创建新用户 user = User(username=username) # 设置手机号或邮箱 if login_type == 'phone': user.phone = credential elif login_type == 'email': user.email = credential # 设置默认密码(使用随机密码,用户后续可以修改) user.set_password(uuid.uuid4().hex) user.status = 'active' user.nickname = username db.session.add(user) db.session.commit() is_new_user = True print(f"✅ 自动创建新用户: {username}, {login_type}: {credential}") except Exception as e: print(f"❌ 创建用户失败: {e}") db.session.rollback() return jsonify({'success': False, 'error': '创建用户失败'}), 500 # 清除验证码 session.pop(session_key, None) # 设置session session.permanent = True session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True # Flask-Login 登录 login_user(user, remember=True) # 更新最后登录时间 user.update_last_seen() # 根据是否为新用户返回不同的消息 message = '注册成功,欢迎加入!' 
if is_new_user else '登录成功' return jsonify({ 'success': True, 'message': message, 'is_new_user': is_new_user, 'user': { 'id': user.id, 'username': user.username, 'nickname': user.nickname or user.username, 'email': user.email, 'phone': user.phone, 'avatar_url': user.avatar_url, 'has_wechat': bool(user.wechat_open_id) } }) except Exception as e: print(f"验证码登录错误: {e}") db.session.rollback() return jsonify({'success': False, 'error': '登录失败'}), 500 @app.route('/api/auth/register', methods=['POST']) def register(): """用户注册 - 使用Session""" username = request.form.get('username') email = request.form.get('email') password = request.form.get('password') # 验证输入 if not all([username, email, password]): return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 # 检查用户名和邮箱是否已存在 if User.is_username_taken(username): return jsonify({'success': False, 'error': '用户名已存在'}), 400 if User.is_email_taken(email): return jsonify({'success': False, 'error': '邮箱已被使用'}), 400 try: # 创建新用户 user = User(username=username, email=email) user.set_password(password) user.email_confirmed = True # 暂时默认已确认 db.session.add(user) db.session.flush() # 获取 user.id # 自动创建积分账户,初始10000积分 credit_account = UserCreditAccount( user_id=user.id, balance=10000, frozen=0 ) db.session.add(credit_account) db.session.commit() # 自动登录 session.permanent = True session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True # Flask-Login 登录 login_user(user, remember=True) return jsonify({ 'success': True, 'message': '注册成功', 'user': { 'id': user.id, 'username': user.username, 'nickname': user.nickname or user.username, 'email': user.email } }), 201 except Exception as e: db.session.rollback() print(f"验证码登录/注册错误: {e}") return jsonify({'success': False, 'error': '登录失败'}), 500 def send_sms_code(phone, code, template_id): """发送短信验证码(使用 subprocess 绕过 eventlet DNS 问题)""" import subprocess import os try: # 获取脚本路径 script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'sms_sender.py') 
print(f"[短信] 准备发送验证码到 {phone},模板ID: {template_id}") # 使用 subprocess 在独立进程中发送短信(绕过 eventlet DNS) result = subprocess.run( [ sys.executable, # 使用当前 Python 解释器 script_path, phone, code, template_id, SMS_SECRET_ID, SMS_SECRET_KEY, SMS_SDK_APP_ID, SMS_SIGN_NAME ], capture_output=True, text=True, timeout=30 ) if result.returncode == 0: print(f"[短信] ✓ 发送成功: {result.stdout.strip()}") return True else: print(f"[短信] ✗ 发送失败: {result.stderr.strip()}") return False except subprocess.TimeoutExpired: print(f"[短信] ✗ 发送超时") return False except Exception as e: print(f"[短信] ✗ 发送异常: {type(e).__name__}: {e}") return False def send_email_code(email, code): """发送邮件验证码(使用 subprocess 绕过 eventlet DNS 问题)""" import subprocess import os try: # 获取脚本路径 script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'email_sender.py') subject = '价值前沿 - 验证码' body = f'您的验证码是:{code},有效期5分钟。如非本人操作,请忽略此邮件。' print(f"[邮件] 准备发送验证码到 {email}") print(f"[邮件] 服务器: {MAIL_SERVER}:{MAIL_PORT}, SSL: {MAIL_USE_SSL}") # 使用 subprocess 在独立进程中发送邮件(绕过 eventlet DNS) result = subprocess.run( [ sys.executable, # 使用当前 Python 解释器 script_path, email, subject, body, MAIL_SERVER, str(MAIL_PORT), MAIL_USERNAME, MAIL_PASSWORD, str(MAIL_USE_SSL).lower() ], capture_output=True, text=True, timeout=60 ) if result.returncode == 0: print(f"[邮件] ✓ 发送成功: {result.stdout.strip()}") return True else: print(f"[邮件] ✗ 发送失败: {result.stderr.strip()}") return False except subprocess.TimeoutExpired: print(f"[邮件] ✗ 发送超时") return False except Exception as e: print(f"[邮件] ✗ 发送异常: {type(e).__name__}: {e}") return False @app.route('/api/auth/send-sms-code', methods=['POST']) def send_sms_verification(): """发送手机验证码""" data = request.get_json() phone = data.get('phone') if not phone: return jsonify({'error': '手机号不能为空'}), 400 # 注册时验证是否已注册;若用于绑定手机,需要另外接口 # 这里保留原逻辑,新增绑定接口处理不同规则 if User.query.filter_by(phone=phone).first(): return jsonify({'error': '该手机号已注册'}), 400 # 生成验证码 code = generate_verification_code() # 发送短信 if send_sms_code(phone, code, 
SMS_TEMPLATE_REGISTER): # 存储验证码到 Redis(5分钟有效) set_verification_code(f'phone_{phone}', code) return jsonify({'message': '验证码已发送'}), 200 else: return jsonify({'error': '验证码发送失败'}), 500 @app.route('/api/auth/send-email-code', methods=['POST']) def send_email_verification(): """发送邮箱验证码""" data = request.get_json() email = data.get('email') if not email: return jsonify({'error': '邮箱不能为空'}), 400 if User.query.filter_by(email=email).first(): return jsonify({'error': '该邮箱已注册'}), 400 # 生成验证码 code = generate_verification_code() # 发送邮件 if send_email_code(email, code): # 存储验证码到 Redis(5分钟有效) set_verification_code(f'email_{email}', code) return jsonify({'message': '验证码已发送'}), 200 else: return jsonify({'error': '验证码发送失败'}), 500 @app.route('/api/auth/register/phone', methods=['POST']) def register_with_phone(): """手机号注册 - 使用Session""" data = request.get_json() phone = data.get('phone') code = data.get('code') password = data.get('password') username = data.get('username') if not all([phone, code, password, username]): return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 # 验证验证码(从 Redis 获取) stored_code = get_verification_code(f'phone_{phone}') if not stored_code or stored_code['expires'] < time.time(): return jsonify({'success': False, 'error': '验证码已过期'}), 400 if stored_code['code'] != code: return jsonify({'success': False, 'error': '验证码错误'}), 400 if User.query.filter_by(username=username).first(): return jsonify({'success': False, 'error': '用户名已存在'}), 400 try: # 创建用户 user = User(username=username, phone=phone) user.email = f"{username}@valuefrontier.temp" user.set_password(password) user.phone_confirmed = True db.session.add(user) db.session.flush() # 获取 user.id # 自动创建积分账户,初始10000积分 credit_account = UserCreditAccount( user_id=user.id, balance=10000, frozen=0 ) db.session.add(credit_account) db.session.commit() # 清除验证码(从 Redis 删除) delete_verification_code(f'phone_{phone}') # 自动登录 session.permanent = True session['user_id'] = user.id session['username'] = user.username 
session['logged_in'] = True # Flask-Login 登录 login_user(user, remember=True) return jsonify({ 'success': True, 'message': '注册成功', 'user': { 'id': user.id, 'username': user.username, 'phone': user.phone } }), 201 except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': '注册失败,请重试'}), 500 @app.route('/api/account/phone/send-code', methods=['POST']) def send_sms_bind_code(): """发送绑定手机验证码(需已登录)""" # 调试日志:检查 session 状态 user_agent = request.headers.get('User-Agent', '') is_wechat = 'MicroMessenger' in user_agent print(f"[绑定手机验证码] User-Agent: {user_agent[:100]}...") print(f"[绑定手机验证码] 是否微信浏览器: {is_wechat}") print(f"[绑定手机验证码] session 内容: logged_in={session.get('logged_in')}, user_id={session.get('user_id')}") print(f"[绑定手机验证码] Cookie: {request.cookies.get('session', 'None')[:20] if request.cookies.get('session') else 'None'}...") if not session.get('logged_in'): print(f"[绑定手机验证码] ❌ 未登录,拒绝请求") return jsonify({'error': '未登录'}), 401 data = request.get_json() phone = data.get('phone') if not phone: return jsonify({'error': '手机号不能为空'}), 400 # 绑定时要求手机号未被占用 if User.query.filter_by(phone=phone).first(): return jsonify({'error': '该手机号已被其他账号使用'}), 400 code = generate_verification_code() if send_sms_code(phone, code, SMS_TEMPLATE_REGISTER): # 存储验证码到 Redis(5分钟有效) set_verification_code(f'bind_{phone}', code) return jsonify({'message': '验证码已发送'}), 200 else: return jsonify({'error': '验证码发送失败'}), 500 @app.route('/api/account/phone/bind', methods=['POST']) def bind_phone(): """当前登录用户绑定手机号""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 data = request.get_json() phone = data.get('phone') code = data.get('code') if not phone or not code: return jsonify({'error': '手机号和验证码不能为空'}), 400 # 从 Redis 获取验证码 stored = get_verification_code(f'bind_{phone}') if not stored or stored['expires'] < time.time(): return jsonify({'error': '验证码已过期'}), 400 if stored['code'] != code: return jsonify({'error': '验证码错误'}), 400 if 
User.query.filter_by(phone=phone).first(): return jsonify({'error': '该手机号已被其他账号使用'}), 400 try: user = User.query.get(session.get('user_id')) if not user: return jsonify({'error': '用户不存在'}), 404 user.phone = phone user.confirm_phone() # 清除验证码(从 Redis 删除) delete_verification_code(f'bind_{phone}') return jsonify({'message': '绑定成功', 'success': True}), 200 except Exception as e: print(f"Bind phone error: {e}") db.session.rollback() return jsonify({'error': '绑定失败,请重试'}), 500 @app.route('/api/account/phone/unbind', methods=['POST']) def unbind_phone(): """解绑手机号(需已登录)""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 try: user = User.query.get(session.get('user_id')) if not user: return jsonify({'error': '用户不存在'}), 404 user.phone = None user.phone_confirmed = False user.phone_confirm_time = None db.session.commit() return jsonify({'message': '解绑成功', 'success': True}), 200 except Exception as e: print(f"Unbind phone error: {e}") db.session.rollback() return jsonify({'error': '解绑失败,请重试'}), 500 @app.route('/api/account/email/send-bind-code', methods=['POST']) def send_email_bind_code(): """发送绑定邮箱验证码(需已登录)""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 data = request.get_json() email = data.get('email') if not email: return jsonify({'error': '邮箱不能为空'}), 400 # 邮箱格式验证 if not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', email): return jsonify({'error': '邮箱格式不正确'}), 400 # 检查邮箱是否已被其他账号使用 if User.query.filter_by(email=email).first(): return jsonify({'error': '该邮箱已被其他账号使用'}), 400 # 生成验证码 code = ''.join(random.choices(string.digits, k=6)) if send_email_code(email, code): # 存储验证码到 Redis(5分钟有效) set_verification_code(f'bind_{email}', code) return jsonify({'message': '验证码已发送'}), 200 else: return jsonify({'error': '验证码发送失败'}), 500 @app.route('/api/account/email/bind', methods=['POST']) def bind_email(): """当前登录用户绑定邮箱""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 data = request.get_json() email = 
data.get('email') code = data.get('code') if not email or not code: return jsonify({'error': '邮箱和验证码不能为空'}), 400 # 从 Redis 获取验证码 stored = get_verification_code(f'bind_{email}') if not stored or stored['expires'] < time.time(): return jsonify({'error': '验证码已过期'}), 400 if stored['code'] != code: return jsonify({'error': '验证码错误'}), 400 if User.query.filter_by(email=email).first(): return jsonify({'error': '该邮箱已被其他账号使用'}), 400 try: user = User.query.get(session.get('user_id')) if not user: return jsonify({'error': '用户不存在'}), 404 user.email = email user.confirm_email() db.session.commit() # 清除验证码(从 Redis 删除) delete_verification_code(f'bind_{email}') return jsonify({ 'message': '邮箱绑定成功', 'success': True, 'user': { 'email': user.email, 'email_confirmed': user.email_confirmed } }), 200 except Exception as e: print(f"Bind email error: {e}") db.session.rollback() return jsonify({'error': '绑定失败,请重试'}), 500 @app.route('/api/account/email/unbind', methods=['POST']) def unbind_email(): """解绑邮箱(需已登录)""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 try: user = User.query.get(session.get('user_id')) if not user: return jsonify({'error': '用户不存在'}), 404 user.email = None user.email_confirmed = False db.session.commit() return jsonify({'message': '解绑成功', 'success': True}), 200 except Exception as e: print(f"Unbind email error: {e}") db.session.rollback() return jsonify({'error': '解绑失败,请重试'}), 500 @app.route('/api/auth/register/email', methods=['POST']) def register_with_email(): """邮箱注册 - 使用Session""" data = request.get_json() email = data.get('email') code = data.get('code') password = data.get('password') username = data.get('username') if not all([email, code, password, username]): return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 # 验证验证码(从 Redis 获取) stored_code = get_verification_code(f'email_{email}') if not stored_code or stored_code['expires'] < time.time(): return jsonify({'success': False, 'error': '验证码已过期'}), 400 if stored_code['code'] 
!= code: return jsonify({'success': False, 'error': '验证码错误'}), 400 if User.query.filter_by(username=username).first(): return jsonify({'success': False, 'error': '用户名已存在'}), 400 try: # 创建用户 user = User(username=username, email=email) user.set_password(password) user.email_confirmed = True db.session.add(user) db.session.flush() # 获取 user.id # 自动创建积分账户,初始10000积分 credit_account = UserCreditAccount( user_id=user.id, balance=10000, frozen=0 ) db.session.add(credit_account) db.session.commit() # 清除验证码(从 Redis 删除) delete_verification_code(f'email_{email}') # 自动登录 session.permanent = True session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True # Flask-Login 登录 login_user(user, remember=True) return jsonify({ 'success': True, 'message': '注册成功', 'user': { 'id': user.id, 'username': user.username, 'email': user.email } }), 201 except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': '注册失败,请重试'}), 500 def _safe_http_get(url, params=None, timeout=10): """安全的 HTTP GET 请求(绕过 eventlet DNS 问题) 使用 subprocess 调用 curl,完全绕过 Python/eventlet 的网络栈 """ import subprocess import urllib.parse # 构建完整 URL if params: query_string = urllib.parse.urlencode(params) full_url = f"{url}?{query_string}" else: full_url = url try: # 使用 curl 发起请求,绕过 eventlet DNS 问题 result = subprocess.run( ['curl', '-s', '-m', str(timeout), full_url], capture_output=True, text=True, timeout=timeout + 5 ) if result.returncode != 0: print(f"❌ curl 请求失败: returncode={result.returncode}, stderr={result.stderr}") return None # 返回一个模拟 Response 对象 class MockResponse: def __init__(self, text): self.text = text self.content = text.encode('utf-8') self.encoding = 'utf-8' def json(self): return json.loads(self.text) return MockResponse(result.stdout) except subprocess.TimeoutExpired: print(f"❌ curl 请求超时: {full_url}") return None except Exception as e: print(f"❌ curl 请求异常: {type(e).__name__}: {e}") return None def get_wechat_access_token(code, appid=None, 
appsecret=None): """通过code获取微信access_token Args: code: 微信授权后返回的 code appid: 微信 AppID(可选,默认使用开放平台配置) appsecret: 微信 AppSecret(可选,默认使用开放平台配置) """ url = "https://api.weixin.qq.com/sns/oauth2/access_token" params = { 'appid': appid or WECHAT_OPEN_APPID, 'secret': appsecret or WECHAT_OPEN_APPSECRET, 'code': code, 'grant_type': 'authorization_code' } try: print(f"🔄 正在获取微信 access_token... (appid={params['appid'][:8]}...)") response = _safe_http_get(url, params=params, timeout=15) data = response.json() if 'errcode' in data: print(f"❌ WeChat access token error: {data}") return None print(f"✅ 成功获取 access_token: openid={data.get('openid', 'N/A')}") return data except Exception as e: print(f"❌ WeChat access token request error: {type(e).__name__}: {e}") import traceback traceback.print_exc() return None def get_wechat_userinfo(access_token, openid): """获取微信用户信息(包含UnionID)""" url = "https://api.weixin.qq.com/sns/userinfo" params = { 'access_token': access_token, 'openid': openid, 'lang': 'zh_CN' } try: print(f"🔄 正在获取微信用户信息... 
(openid={openid})") response = _safe_http_get(url, params=params, timeout=15) response.encoding = 'utf-8' # 明确设置编码为UTF-8 data = response.json() if 'errcode' in data: print(f"❌ WeChat userinfo error: {data}") return None # 确保nickname字段的编码正确 if 'nickname' in data and data['nickname']: # 确保昵称是正确的UTF-8编码 try: # 检查是否已经是正确的UTF-8字符串 data['nickname'] = data['nickname'].encode('utf-8').decode('utf-8') except (UnicodeEncodeError, UnicodeDecodeError) as e: print(f"Nickname encoding error: {e}, using default") data['nickname'] = '微信用户' print(f"✅ 成功获取用户信息: nickname={data.get('nickname', 'N/A')}") return data except Exception as e: print(f"❌ WeChat userinfo request error: {type(e).__name__}: {e}") import traceback traceback.print_exc() return None @app.route('/api/auth/wechat/qrcode', methods=['GET']) def get_wechat_qrcode(): """返回微信授权URL,前端使用iframe展示""" # 生成唯一state参数 state = uuid.uuid4().hex # URL编码回调地址 redirect_uri = urllib.parse.quote_plus(WECHAT_REDIRECT_URI) # 构建微信授权URL(PC 扫码登录使用开放平台 AppID) wechat_auth_url = ( f"https://open.weixin.qq.com/connect/qrconnect?" f"appid={WECHAT_OPEN_APPID}&redirect_uri={redirect_uri}" f"&response_type=code&scope=snsapi_login&state={state}" "#wechat_redirect" ) # 存储session信息到 Redis if not set_wechat_session(state, { 'status': 'waiting', 'user_info': None, 'wechat_openid': None, 'wechat_unionid': None }): return jsonify({'error': '服务暂时不可用,请稍后重试'}), 500 return jsonify({"code":0, "data": { 'auth_url': wechat_auth_url, 'session_id': state, 'expires_in': 300 }}), 200 @app.route('/api/auth/wechat/h5-auth', methods=['POST']) def get_wechat_h5_auth_url(): """ 获取微信 H5 网页授权 URL 用于手机浏览器跳转微信 App 授权 """ data = request.get_json() or {} frontend_redirect = data.get('redirect_url', '/home') # 生成唯一 state state = uuid.uuid4().hex # 编码回调地址 redirect_uri = urllib.parse.quote_plus(WECHAT_REDIRECT_URI) # 构建授权 URL(H5 网页授权使用公众号 AppID) auth_url = ( f"https://open.weixin.qq.com/connect/oauth2/authorize?" 
f"appid={WECHAT_MP_APPID}&redirect_uri={redirect_uri}" f"&response_type=code&scope=snsapi_userinfo&state={state}" "#wechat_redirect" ) # 存储 session 信息到 Redis if not set_wechat_session(state, { 'status': 'waiting', 'mode': 'h5', # 标记为 H5 模式 'frontend_redirect': frontend_redirect, 'user_info': None, 'wechat_openid': None, 'wechat_unionid': None }): return jsonify({'error': '服务暂时不可用,请稍后重试'}), 500 return jsonify({ 'auth_url': auth_url, 'state': state }), 200 @app.route('/api/account/wechat/qrcode', methods=['GET']) def get_wechat_bind_qrcode(): """发起微信绑定二维码,会话标记为绑定模式""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 # 生成唯一state参数 state = uuid.uuid4().hex # URL编码回调地址 redirect_uri = urllib.parse.quote_plus(WECHAT_REDIRECT_URI) # 构建微信授权URL(PC 扫码绑定使用开放平台 AppID) wechat_auth_url = ( f"https://open.weixin.qq.com/connect/qrconnect?" f"appid={WECHAT_OPEN_APPID}&redirect_uri={redirect_uri}" f"&response_type=code&scope=snsapi_login&state={state}" "#wechat_redirect" ) # 存储session信息到 Redis,标记为绑定模式并记录目标用户 if not set_wechat_session(state, { 'status': 'waiting', 'mode': 'bind', 'bind_user_id': session.get('user_id'), 'user_info': None, 'wechat_openid': None, 'wechat_unionid': None }): return jsonify({'error': '服务暂时不可用,请稍后重试'}), 500 return jsonify({ 'auth_url': wechat_auth_url, 'session_id': state, 'expires_in': 300 }), 200 @app.route('/api/auth/wechat/check', methods=['POST']) def check_wechat_scan(): """检查微信扫码状态""" data = request.get_json() session_id = data.get('session_id') if not session_id: return jsonify({'status': 'invalid', 'error': '无效的session'}), 400 # 从 Redis 获取 session sess = get_wechat_session(session_id) if not sess: return jsonify({'status': 'expired'}), 200 # Redis 自动过期,返回 expired # 获取剩余 TTL ttl = redis_client.ttl(f"{WECHAT_SESSION_PREFIX}{session_id}") expires_in = max(0, ttl) if ttl > 0 else 0 return jsonify({ 'status': sess['status'], 'user_info': sess.get('user_info'), 'expires_in': expires_in }), 200 @app.route('/api/account/wechat/check', 
methods=['POST']) def check_wechat_bind_scan(): """检查微信扫码绑定状态""" data = request.get_json() session_id = data.get('session_id') if not session_id: return jsonify({'status': 'invalid', 'error': '无效的session'}), 400 # 从 Redis 获取 session sess = get_wechat_session(session_id) if not sess: return jsonify({'status': 'expired'}), 200 # Redis 自动过期,返回 expired # 绑定模式限制 if sess.get('mode') != 'bind': return jsonify({'status': 'invalid', 'error': '会话模式错误'}), 400 # 获取剩余 TTL ttl = redis_client.ttl(f"{WECHAT_SESSION_PREFIX}{session_id}") expires_in = max(0, ttl) if ttl > 0 else 0 return jsonify({ 'status': sess['status'], 'user_info': sess.get('user_info'), 'expires_in': expires_in }), 200 @app.route('/api/auth/wechat/callback', methods=['GET']) def wechat_callback(): """微信授权回调处理 - 使用Session""" code = request.args.get('code') state = request.args.get('state') error = request.args.get('error') # 错误处理:用户拒绝授权 if error: if state and wechat_session_exists(state): update_wechat_session(state, {'status': 'auth_denied', 'error': '用户拒绝授权'}) print(f"❌ 用户拒绝授权: state={state}") return redirect(f'{FRONTEND_URL}/home?error=wechat_auth_denied') # 参数验证 if not code or not state: if state and wechat_session_exists(state): update_wechat_session(state, {'status': 'auth_failed', 'error': '授权参数缺失'}) return redirect(f'{FRONTEND_URL}/home?error=wechat_auth_failed') # 从 Redis 获取 session(自动处理过期) session_data = get_wechat_session(state) if not session_data: return redirect(f'{FRONTEND_URL}/home?error=session_expired') try: # 步骤1: 用户已扫码并授权(微信回调过来说明用户已完成扫码+授权) update_wechat_session(state, {'status': 'scanned'}) print(f"✅ 微信扫码回调: state={state}, code={code[:10]}...") # 步骤2: 根据授权模式选择对应的 AppID/AppSecret # H5 模式使用公众号配置,PC 扫码和绑定模式使用开放平台配置 if session_data.get('mode') == 'h5': appid = WECHAT_MP_APPID appsecret = WECHAT_MP_APPSECRET print(f"📱 H5 模式授权,使用公众号配置") else: appid = WECHAT_OPEN_APPID appsecret = WECHAT_OPEN_APPSECRET print(f"💻 PC 模式授权,使用开放平台配置") # 步骤3: 获取access_token token_data = get_wechat_access_token(code, 
appid, appsecret) if not token_data: update_wechat_session(state, {'status': 'auth_failed', 'error': '获取访问令牌失败'}) print(f"❌ 获取微信access_token失败: state={state}") return redirect(f'{FRONTEND_URL}/home?error=token_failed') # 步骤3: Token获取成功,标记为已授权 update_wechat_session(state, {'status': 'authorized'}) print(f"✅ 微信授权成功: openid={token_data['openid']}") # 步骤4: 获取用户信息 user_info = get_wechat_userinfo(token_data['access_token'], token_data['openid']) if not user_info: update_wechat_session(state, {'status': 'auth_failed', 'error': '获取用户信息失败'}) print(f"❌ 获取微信用户信息失败: openid={token_data['openid']}") return redirect(f'{FRONTEND_URL}/home?error=userinfo_failed') # 查找或创建用户 / 或处理绑定 openid = token_data['openid'] unionid = user_info.get('unionid') or token_data.get('unionid') # 如果是绑定流程 if session_data.get('mode') == 'bind': try: target_user_id = session.get('user_id') or session_data.get('bind_user_id') if not target_user_id: return redirect(f'{FRONTEND_URL}/home?error=bind_no_user') target_user = User.query.get(target_user_id) if not target_user: return redirect(f'{FRONTEND_URL}/home?error=bind_user_missing') # 检查该微信是否已被其他账户绑定 existing = None if unionid: existing = User.query.filter_by(wechat_union_id=unionid).first() if not existing: existing = User.query.filter_by(wechat_open_id=openid).first() if existing and existing.id != target_user.id: update_wechat_session(state, {'status': 'bind_conflict'}) return redirect(f'{FRONTEND_URL}/home?bind=conflict') # 执行绑定 target_user.bind_wechat(openid, unionid, wechat_info=user_info) # 标记绑定完成,供前端轮询 update_wechat_session(state, {'status': 'bind_ready', 'user_info': {'user_id': target_user.id}}) return redirect(f'{FRONTEND_URL}/home?bind=success') except Exception as e: print(f"❌ 微信绑定失败: {e}") db.session.rollback() update_wechat_session(state, {'status': 'bind_failed'}) return redirect(f'{FRONTEND_URL}/home?bind=failed') user = None is_new_user = False # 统一使用 unionid 匹配用户(H5 和 PC 模式都一样) if not unionid: # 没有获取到 unionid,无法关联账号 mode_name = 'H5' if 
session_data.get('mode') == 'h5' else 'PC' update_wechat_session(state, {'status': 'auth_failed', 'error': f'{mode_name}授权未返回unionid'}) print(f"❌ {mode_name} 授权未返回 unionid, openid={openid}, user_info={user_info}") # 调试信息:将微信返回的数据通过 URL 传给前端 debug_params = urllib.parse.urlencode({ 'error': 'no_unionid', 'debug_mode': mode_name, 'debug_openid': openid[:10] + '...' if openid else 'null', 'debug_has_unionid_in_token': '1' if token_data.get('unionid') else '0', 'debug_has_unionid_in_userinfo': '1' if user_info.get('unionid') else '0', 'debug_nickname': user_info.get('nickname', '')[:10], 'debug_keys_in_userinfo': ','.join(user_info.keys()) if user_info else 'null', }) return redirect(f'{FRONTEND_URL}/home?{debug_params}') user = User.query.filter_by(wechat_union_id=unionid).first() if not user: # 创建新用户 # 先清理微信昵称 raw_nickname = user_info.get('nickname', '微信用户') # 创建临时用户实例以使用清理方法 temp_user = User.__new__(User) sanitized_nickname = temp_user._sanitize_nickname(raw_nickname) username = sanitized_nickname counter = 1 while User.is_username_taken(username): username = f"{sanitized_nickname}_{counter}" counter += 1 user = User(username=username) user.nickname = sanitized_nickname user.avatar_url = user_info.get('headimgurl') user.wechat_open_id = openid user.wechat_union_id = unionid user.set_password(uuid.uuid4().hex) user.status = 'active' db.session.add(user) db.session.commit() is_new_user = True print(f"✅ 微信扫码自动创建新用户: {username}, openid: {openid}") # 更新最后登录时间 user.update_last_seen() # 设置session session.permanent = True session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True session['wechat_login'] = True # 标记是微信登录 # Flask-Login 登录 login_user(user, remember=True) # 更新微信session状态,供前端轮询检测 mode = session_data.get('mode') # H5 模式:重定向到前端回调页面 if mode == 'h5': frontend_redirect = session_data.get('frontend_redirect', '/home/wechat-callback') # 清理 session delete_wechat_session(state) print(f"✅ H5 微信登录成功,重定向到: {frontend_redirect}") # 
调试信息:携带微信返回的关键数据 debug_params = urllib.parse.urlencode({ 'wechat_login': 'success', 'debug_is_new_user': '1' if is_new_user else '0', 'debug_has_unionid': '1' if unionid else '0', 'debug_unionid': (unionid[:10] + '...') if unionid else 'null', 'debug_openid': (openid[:10] + '...') if openid else 'null', 'debug_user_id': user.id, 'debug_nickname': user_info.get('nickname', '')[:10], }) # ⚡ 修复:正确处理已有查询参数的 URL separator = '&' if '?' in frontend_redirect else '?' return redirect(f"{frontend_redirect}{separator}{debug_params}") # PC 扫码模式:更新状态供前端轮询 if not mode: new_status = 'register_ready' if is_new_user else 'login_ready' update_wechat_session(state, {'status': new_status, 'user_info': {'user_id': user.id}}) print(f"✅ 微信扫码状态已更新: {new_status}, user_id: {user.id}") # ⚡ PC 扫码模式:重定向到前端回调页面 # 微信扫码登录会跳转整个页面,所以需要重定向到前端处理 pc_redirect_params = urllib.parse.urlencode({ 'wechat_login': 'success', 'state': state, 'is_new_user': '1' if is_new_user else '0', }) print(f"✅ PC 微信登录成功,重定向到前端回调页面") return redirect(f"{FRONTEND_URL}/home/wechat-callback?{pc_redirect_params}") except Exception as e: print(f"❌ 微信登录失败: {e}") import traceback traceback.print_exc() db.session.rollback() # 更新session状态为失败 if wechat_session_exists(state): update_wechat_session(state, {'status': 'auth_failed', 'error': str(e)}) # ⚡ 重定向到首页并显示错误 return redirect(f'{FRONTEND_URL}/home?error=wechat_login_failed') @app.route('/api/auth/login/wechat', methods=['POST']) def login_with_wechat(): """微信登录 - 修复版本""" data = request.get_json() session_id = data.get('session_id') if not session_id: return jsonify({'success': False, 'error': 'session_id不能为空'}), 400 # 从 Redis 获取 session wechat_sess = get_wechat_session(session_id) if not wechat_sess: return jsonify({'success': False, 'error': '会话不存在或已过期'}), 400 # 检查session状态 if wechat_sess['status'] not in ['login_ready', 'register_ready']: return jsonify({'success': False, 'error': '会话状态无效'}), 400 # 检查是否有用户信息 user_info = wechat_sess.get('user_info') if not user_info or not 
user_info.get('user_id'): return jsonify({'success': False, 'error': '用户信息不完整'}), 400 try: user = User.query.get(user_info['user_id']) if not user: return jsonify({'success': False, 'error': '用户不存在'}), 404 # 更新最后登录时间 user.update_last_seen() # Redis 会自动过期,无需手动延迟删除 # 保留 session 状态供前端轮询,Redis TTL 会自动清理 # 生成登录响应 response_data = { 'success': True, 'message': '登录成功' if wechat_sess['status'] == 'login_ready' else '注册并登录成功', 'user': { 'id': user.id, 'username': user.username, 'nickname': user.nickname or user.username, 'email': user.email, 'phone': user.phone, 'phone_confirmed': bool(user.phone_confirmed), 'avatar_url': user.avatar_url, 'has_wechat': True, 'wechat_open_id': user.wechat_open_id, 'wechat_union_id': user.wechat_union_id, 'created_at': user.created_at.isoformat() if user.created_at else None, 'last_seen': user.last_seen.isoformat() if user.last_seen else None }, 'isNewUser': wechat_sess['status'] == 'register_ready' # 标记是否为新用户 } # 如果需要token认证,可以在这里生成 # response_data['token'] = generate_token(user.id) return jsonify(response_data), 200 except Exception as e: print(f"❌ 微信登录错误: {e}") import traceback app.logger.error(f"回调处理错误: {e}", exc_info=True) return jsonify({ 'success': False, 'error': '登录失败,请重试' }), 500 @app.route('/api/account/wechat/unbind', methods=['POST']) def unbind_wechat_account(): """解绑当前登录用户的微信""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 try: user = User.query.get(session.get('user_id')) if not user: return jsonify({'error': '用户不存在'}), 404 user.unbind_wechat() return jsonify({'message': '解绑成功', 'success': True}), 200 except Exception as e: print(f"Unbind wechat error: {e}") db.session.rollback() return jsonify({'error': '解绑失败,请重试'}), 500 # ============ H5 跳转小程序相关 API ============ def get_wechat_access_token_cached(appid, appsecret): """ 获取微信 access_token(Redis 缓存,支持多 Worker) Args: appid: 微信 AppID(公众号或小程序) appsecret: 对应的 AppSecret Returns: access_token 字符串,失败返回 None """ cache_key = 
f"{WECHAT_ACCESS_TOKEN_PREFIX}{appid}" # 1. 尝试从 Redis 获取缓存 try: cached = redis_client.get(cache_key) if cached: data = json.loads(cached) # 提前 5 分钟刷新,避免临界问题 if data.get('expires_at', 0) > time.time() + 300: print(f"[access_token] 使用缓存: appid={appid[:8]}...") return data['token'] except Exception as e: print(f"[access_token] Redis 读取失败: {e}") # 2. 请求新 token url = "https://api.weixin.qq.com/cgi-bin/token" params = { 'grant_type': 'client_credential', 'appid': appid, 'secret': appsecret } try: response = requests.get(url, params=params, timeout=10) result = response.json() if 'access_token' in result: token = result['access_token'] expires_in = result.get('expires_in', 7200) # 3. 存入 Redis(TTL 比 token 有效期短 60 秒) cache_data = { 'token': token, 'expires_at': time.time() + expires_in } redis_client.setex( cache_key, expires_in - 60, json.dumps(cache_data) ) print(f"[access_token] 获取成功: appid={appid[:8]}..., expires_in={expires_in}s") return token else: print(f"[access_token] 获取失败: errcode={result.get('errcode')}, errmsg={result.get('errmsg')}") return None except Exception as e: print(f"[access_token] 请求异常: {e}") return None def get_jsapi_ticket_cached(appid, appsecret): """ 获取 jsapi_ticket(Redis 缓存) 用于 JS-SDK 签名 """ cache_key = f"{WECHAT_JSAPI_TICKET_PREFIX}{appid}" # 1. 尝试从缓存获取 try: cached = redis_client.get(cache_key) if cached: data = json.loads(cached) if data.get('expires_at', 0) > time.time() + 300: print(f"[jsapi_ticket] 使用缓存") return data['ticket'] except Exception as e: print(f"[jsapi_ticket] Redis 读取失败: {e}") # 2. 获取 access_token access_token = get_wechat_access_token_cached(appid, appsecret) if not access_token: return None # 3. 
请求 jsapi_ticket url = "https://api.weixin.qq.com/cgi-bin/ticket/getticket" params = { 'access_token': access_token, 'type': 'jsapi' } try: response = requests.get(url, params=params, timeout=10) result = response.json() if result.get('errcode') == 0: ticket = result['ticket'] expires_in = result.get('expires_in', 7200) # 存入 Redis cache_data = { 'ticket': ticket, 'expires_at': time.time() + expires_in } redis_client.setex( cache_key, expires_in - 60, json.dumps(cache_data) ) print(f"[jsapi_ticket] 获取成功, expires_in={expires_in}s") return ticket else: print(f"[jsapi_ticket] 获取失败: errcode={result.get('errcode')}, errmsg={result.get('errmsg')}") return None except Exception as e: print(f"[jsapi_ticket] 请求异常: {e}") return None def generate_jssdk_signature(url, appid, appsecret): """ 生成 JS-SDK 签名配置 Args: url: 当前页面 URL(不含 # 及其后的部分) appid: 公众号 AppID appsecret: 公众号 AppSecret Returns: 签名配置字典,失败返回 None """ import hashlib # 获取 jsapi_ticket ticket = get_jsapi_ticket_cached(appid, appsecret) if not ticket: return None # 生成签名参数 timestamp = int(time.time()) nonce_str = uuid.uuid4().hex # 签名字符串(必须按字典序排序!) 
sign_str = f"jsapi_ticket={ticket}&noncestr={nonce_str}×tamp={timestamp}&url={url}" # SHA1 签名 signature = hashlib.sha1(sign_str.encode('utf-8')).hexdigest() return { 'appId': appid, 'timestamp': timestamp, 'nonceStr': nonce_str, 'signature': signature, 'jsApiList': ['updateAppMessageShareData', 'updateTimelineShareData'], 'openTagList': ['wx-open-launch-weapp'] } @app.route('/api/wechat/jssdk-config', methods=['POST']) def api_wechat_jssdk_config(): """获取微信 JS-SDK 签名配置(用于开放标签)""" try: print(f"[JS-SDK Config] 收到请求") data = request.get_json() or {} url = data.get('url') print(f"[JS-SDK Config] URL: {url}") if not url: print(f"[JS-SDK Config] 错误: 缺少 url 参数") return jsonify({ 'code': 400, 'message': '缺少必要参数 url', 'data': None }), 400 # URL 校验:必须是允许的域名 from urllib.parse import urlparse parsed = urlparse(url) # 扩展允许的域名列表,包括 API 域名 allowed_domains = ['valuefrontier.cn', 'www.valuefrontier.cn', 'api.valuefrontier.cn', 'localhost', '127.0.0.1'] domain = parsed.netloc.split(':')[0] print(f"[JS-SDK Config] 解析域名: {domain}") if domain not in allowed_domains: return jsonify({ 'code': 400, 'message': 'URL 域名不在允许范围内', 'data': None }), 400 # URL 处理:移除 hash 部分 if '#' in url: url = url.split('#')[0] # 生成签名(使用公众号配置) print(f"[JS-SDK Config] 开始生成签名...") config = generate_jssdk_signature( url=url, appid=WECHAT_MP_APPID, appsecret=WECHAT_MP_APPSECRET ) print(f"[JS-SDK Config] 签名生成完成: {config is not None}") if not config: return jsonify({ 'code': 500, 'message': '获取签名配置失败,请稍后重试', 'data': None }), 500 return jsonify({ 'code': 200, 'message': 'success', 'data': config }) except Exception as e: print(f"[JS-SDK Config] 异常: {e}") import traceback traceback.print_exc() return jsonify({ 'code': 500, 'message': '服务器内部错误', 'data': None }), 500 @app.route('/api/miniprogram/url-scheme', methods=['POST']) def api_miniprogram_url_scheme(): """生成小程序 URL Scheme(外部浏览器跳转小程序用)""" try: # 频率限制 client_ip = request.headers.get('X-Forwarded-For', request.remote_addr) if client_ip: client_ip = 
client_ip.split(',')[0].strip() rate_key = f"rate_limit:urlscheme:{client_ip}" current = redis_client.incr(rate_key) if current == 1: redis_client.expire(rate_key, 60) if current > 30: # 每分钟最多 30 次 return jsonify({ 'code': 429, 'message': '请求过于频繁,请稍后再试', 'data': None }), 429 data = request.get_json() or {} # 参数校验 path = data.get('path') if path and not path.startswith('/'): path = '/' + path # 自动补全 / # 获取小程序 access_token access_token = get_wechat_access_token_cached( WECHAT_MINIPROGRAM_APPID, WECHAT_MINIPROGRAM_APPSECRET ) if not access_token: return jsonify({ 'code': 500, 'message': '获取访问令牌失败', 'data': None }), 500 # 构建请求参数 wx_url = f"https://api.weixin.qq.com/wxa/generatescheme?access_token={access_token}" expire_type = data.get('expire_type', 1) expire_interval = min(data.get('expire_interval', 30), 30) # 最长30天 payload = { "is_expire": expire_type == 1 } # 跳转信息 if path or data.get('query'): payload["jump_wxa"] = {} if path: payload["jump_wxa"]["path"] = path if data.get('query'): payload["jump_wxa"]["query"] = data.get('query') # 有效期设置 if expire_type == 1: if data.get('expire_time'): payload["expire_time"] = data.get('expire_time') else: payload["expire_interval"] = expire_interval response = requests.post(wx_url, json=payload, timeout=10) result = response.json() if result.get('errcode') == 0: return jsonify({ 'code': 200, 'message': 'success', 'data': { 'openlink': result['openlink'], 'expire_time': data.get('expire_time') or (int(time.time()) + expire_interval * 86400), 'created_at': datetime.utcnow().isoformat() + 'Z' } }) else: print(f"[URL Scheme] 生成失败: errcode={result.get('errcode')}, errmsg={result.get('errmsg')}") return jsonify({ 'code': 500, 'message': f"生成 URL Scheme 失败: {result.get('errmsg', '未知错误')}", 'data': None }), 500 except Exception as e: print(f"[URL Scheme] 异常: {e}") import traceback traceback.print_exc() return jsonify({ 'code': 500, 'message': '服务器内部错误', 'data': None }), 500 # 评论模型 class EventComment(db.Model): """事件评论""" __tablename__ = 
'event_comment' id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=True) author = db.Column(db.String(50), default='匿名用户') content = db.Column(db.Text, nullable=False) parent_id = db.Column(db.Integer, db.ForeignKey('event_comment.id')) likes = db.Column(db.Integer, default=0) created_at = db.Column(db.DateTime, default=beijing_now) status = db.Column(db.String(20), default='active') user = db.relationship('User', backref='event_comments') replies = db.relationship('EventComment', backref=db.backref('parent', remote_side=[id])) def to_dict(self, user_session_id=None, current_user_id=None): # 检查当前用户是否已点赞 user_liked = False if user_session_id: like_record = CommentLike.query.filter_by( comment_id=self.id, session_id=user_session_id ).first() user_liked = like_record is not None # 检查当前用户是否可以删除此评论 can_delete = current_user_id is not None and self.user_id == current_user_id return { 'id': self.id, 'event_id': self.event_id, 'author': self.author, 'content': self.content, 'parent_id': self.parent_id, 'likes': self.likes, 'created_at': self.created_at.isoformat() if self.created_at else None, 'user_liked': user_liked, 'can_delete': can_delete, 'user_id': self.user_id, 'replies': [reply.to_dict(user_session_id, current_user_id) for reply in self.replies if reply.status == 'active'] } class CommentLike(db.Model): """评论点赞记录""" __tablename__ = 'comment_like' id = db.Column(db.Integer, primary_key=True) comment_id = db.Column(db.Integer, db.ForeignKey('event_comment.id'), nullable=False) session_id = db.Column(db.String(100), nullable=False) created_at = db.Column(db.DateTime, default=beijing_now) __table_args__ = (db.UniqueConstraint('comment_id', 'session_id'),) @app.after_request def after_request(response): """处理所有响应,添加CORS头部和安全头部""" origin = request.headers.get('Origin') allowed_origins = ['http://localhost:3000', 'http://127.0.0.1:3000', 'http://localhost:5173', 
'https://valuefrontier.cn', 'http://valuefrontier.cn', 'https://www.valuefrontier.cn', 'http://www.valuefrontier.cn'] if origin in allowed_origins: response.headers['Access-Control-Allow-Origin'] = origin response.headers['Access-Control-Allow-Credentials'] = 'true' response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization,X-Requested-With,Cache-Control' response.headers['Access-Control-Allow-Methods'] = 'GET,PUT,POST,DELETE,OPTIONS' response.headers['Access-Control-Expose-Headers'] = 'Content-Type,Authorization' # 处理预检请求 if request.method == 'OPTIONS': response.status_code = 200 return response def add_cors_headers(response): """添加CORS头(保留原有函数以兼容)""" origin = request.headers.get('Origin') allowed_origins = ['http://localhost:3000', 'http://127.0.0.1:3000', 'http://localhost:5173', 'https://valuefrontier.cn', 'http://valuefrontier.cn', 'https://www.valuefrontier.cn', 'http://www.valuefrontier.cn'] if origin in allowed_origins: response.headers['Access-Control-Allow-Origin'] = origin else: response.headers['Access-Control-Allow-Origin'] = 'http://localhost:3000' response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization,X-Requested-With,Cache-Control' response.headers['Access-Control-Allow-Methods'] = 'GET,PUT,POST,DELETE,OPTIONS' response.headers['Access-Control-Allow-Credentials'] = 'true' return response class EventFollow(db.Model): """事件关注""" id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) created_at = db.Column(db.DateTime, default=beijing_now) user = db.relationship('User', backref='event_follows') __table_args__ = (db.UniqueConstraint('user_id', 'event_id'),) class FutureEventFollow(db.Model): """未来事件关注""" __tablename__ = 'future_event_follow' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) 
future_event_id = db.Column(db.Integer, nullable=False) # future_events表的id created_at = db.Column(db.DateTime, default=beijing_now) user = db.relationship('User', backref='future_event_follows') __table_args__ = (db.UniqueConstraint('user_id', 'future_event_id'),) # —— 自选股输入统一化与名称补全工具 —— def _normalize_stock_input(raw_input: str): """解析用户输入为标准6位股票代码与可选名称。 支持: - 6位代码: "600519",或带后缀 "600519.SH"/"600519.SZ" - 名称(代码): "贵州茅台(600519)" 或 "贵州茅台(600519)" 返回 (code6, name_or_none) """ if not raw_input: return None, None s = str(raw_input).strip() # 名称(600519) 或 名称(600519) m = re.match(r"^(.+?)[\((]\s*(\d{6})\s*[\))]\s*$", s) if m: name = m.group(1).strip() code = m.group(2) return code, (name if name else None) # 600519 或 600519.SH / 600519.SZ m2 = re.match(r"^(\d{6})(?:\.(?:SH|SZ))?$", s, re.IGNORECASE) if m2: return m2.group(1), None # SH600519 / SZ000001 m3 = re.match(r"^(SH|SZ)(\d{6})$", s, re.IGNORECASE) if m3: return m3.group(2), None return None, None def _query_stock_name_by_code(code6: str): """根据6位代码查询股票名称,查不到返回None。""" try: with engine.connect() as conn: q = text(""" SELECT SECNAME FROM ea_baseinfo WHERE SECCODE = :c LIMIT 1 """) row = conn.execute(q, {'c': code6}).fetchone() if row: return row[0] except Exception: pass return None class Watchlist(db.Model): """用户自选股""" __tablename__ = 'watchlist' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) stock_code = db.Column(db.String(20), nullable=False) stock_name = db.Column(db.String(100), nullable=True) created_at = db.Column(db.DateTime, default=beijing_now) user = db.relationship('User', backref='watchlist') __table_args__ = (db.UniqueConstraint('user_id', 'stock_code'),) @app.route('/api/account/watchlist', methods=['GET']) def get_my_watchlist(): """获取当前用户的自选股列表""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 items = 
Watchlist.query.filter_by(user_id=session['user_id']).order_by(Watchlist.created_at.desc()).all() # 懒更新:统一代码为6位、补全缺失的名称,并去重(同一代码保留一个记录) from collections import defaultdict groups = defaultdict(list) for i in items: code6, _ = _normalize_stock_input(i.stock_code) normalized_code = code6 or (i.stock_code.strip().upper() if isinstance(i.stock_code, str) else i.stock_code) groups[normalized_code].append(i) dirty = False to_delete = [] for code6, group in groups.items(): # 选择保留记录:优先有名称的,其次创建时间早的 def sort_key(x): return (x.stock_name is None, x.created_at or datetime.min) group_sorted = sorted(group, key=sort_key) keep = group_sorted[0] # 规范保留项 if keep.stock_code != code6: keep.stock_code = code6 dirty = True if not keep.stock_name and code6: nm = _query_stock_name_by_code(code6) if nm: keep.stock_name = nm dirty = True # 其余删除 for g in group_sorted[1:]: to_delete.append(g) if to_delete: for g in to_delete: db.session.delete(g) dirty = True if dirty: db.session.commit() return jsonify({'success': True, 'data': [ { 'id': i.id, 'stock_code': i.stock_code, 'stock_name': i.stock_name, 'created_at': i.created_at.isoformat() if i.created_at else None } for i in items ]}) except Exception as e: print(f"Error in get_my_watchlist: {str(e)}") return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/account/watchlist', methods=['POST']) def add_to_watchlist(): """添加到自选股""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 data = request.get_json() or {} raw_code = data.get('stock_code') raw_name = data.get('stock_name') code6, name_from_input = _normalize_stock_input(raw_code) if not code6: return jsonify({'success': False, 'error': '无效的股票标识'}), 400 # 优先使用传入名称,其次从输入解析中获得,最后查库补全 final_name = raw_name or name_from_input or _query_stock_name_by_code(code6) # 查找已存在记录,兼容历史:6位/带后缀 candidates = [code6, f"{code6}.SH", f"{code6}.SZ"] existing = Watchlist.query.filter( Watchlist.user_id == session['user_id'], 
Watchlist.stock_code.in_(candidates) ).first() if existing: # 统一为6位,补全名称 updated = False if existing.stock_code != code6: existing.stock_code = code6 updated = True if (not existing.stock_name) and final_name: existing.stock_name = final_name updated = True if updated: db.session.commit() return jsonify({'success': True, 'data': {'id': existing.id}}) item = Watchlist(user_id=session['user_id'], stock_code=code6, stock_name=final_name) db.session.add(item) db.session.commit() return jsonify({'success': True, 'data': {'id': item.id}}) # 注意:/realtime 路由必须在 / 之前定义,否则会被错误匹配 @app.route('/api/account/watchlist/realtime', methods=['GET']) def get_watchlist_realtime(): """获取自选股实时行情数据(基于分钟线)- 优化为批量查询""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 # 获取用户自选股列表 watchlist = Watchlist.query.filter_by(user_id=session['user_id']).all() if not watchlist: return jsonify({'success': True, 'data': []}) # 获取股票代码列表并标准化 code_mapping = {} # code6 -> full_code 映射 full_codes = [] for item in watchlist: code6, _ = _normalize_stock_input(item.stock_code) normalized = code6 or str(item.stock_code).strip().upper() # 转换为带后缀的完整代码 if '.' 
in normalized: full_code = normalized elif normalized.startswith('6'): full_code = f"{normalized}.SH" elif normalized.startswith(('8', '9', '4')): full_code = f"{normalized}.BJ" else: full_code = f"{normalized}.SZ" code_mapping[normalized] = full_code full_codes.append(full_code) if not full_codes: return jsonify({'success': True, 'data': []}) # 使用批量查询获取最新行情(单次查询) client = get_clickhouse_client() today = datetime.now().date() start_date = today - timedelta(days=7) # 批量查询:获取每只股票的最新一条分钟数据 batch_query = """ WITH latest AS ( SELECT code, close, timestamp, high, low, volume, amt, ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn FROM stock_minute WHERE code IN %(codes)s AND timestamp >= %(start)s ) SELECT code, close, timestamp, high, low, volume, amt FROM latest WHERE rn = 1 """ result = client.execute(batch_query, { 'codes': full_codes, 'start': datetime.combine(start_date, dt_time(9, 30)) }) # 构建最新价格映射 latest_data_map = {} for row in result: code, close, ts, high, low, volume, amt = row latest_data_map[code] = { 'close': float(close), 'timestamp': ts, 'high': float(high), 'low': float(low), 'volume': int(volume), 'amount': float(amt) } # 批量查询前收盘价(使用 ea_trade 表,更准确) prev_close_map = {} if latest_data_map: # 获取前一交易日 prev_trading_day = None for td in reversed(trading_days): if td < today: prev_trading_day = td break if prev_trading_day: base_codes = [code.split('.')[0] for code in full_codes] prev_day_str = prev_trading_day.strftime('%Y%m%d') with engine.connect() as conn: placeholders = ','.join([f':code{i}' for i in range(len(base_codes))]) params = {f'code{i}': code for i, code in enumerate(base_codes)} params['trade_date'] = prev_day_str prev_result = conn.execute(text(f""" SELECT SECCODE, F007N as close_price FROM ea_trade WHERE SECCODE IN ({placeholders}) AND TRADEDATE = :trade_date """), params).fetchall() for row in prev_result: base_code, close_price = row[0], row[1] if close_price: prev_close_map[base_code] = float(close_price) # 构建响应数据 
quotes_data = {} for code6, full_code in code_mapping.items(): latest = latest_data_map.get(full_code) if latest: base_code = full_code.split('.')[0] prev_close = prev_close_map.get(base_code, latest['close']) change = latest['close'] - prev_close change_percent = (change / prev_close * 100) if prev_close > 0 else 0.0 quotes_data[code6] = { 'price': latest['close'], 'prev_close': prev_close, 'change': change, 'change_percent': change_percent, 'high': latest['high'], 'low': latest['low'], 'volume': latest['volume'], 'amount': latest['amount'], 'update_time': latest['timestamp'].strftime('%H:%M:%S') } response_data = [] for item in watchlist: code6, _ = _normalize_stock_input(item.stock_code) quote = quotes_data.get(code6 or item.stock_code, {}) response_data.append({ 'stock_code': code6 or item.stock_code, 'stock_name': item.stock_name or (code6 and _query_stock_name_by_code(code6)) or None, 'current_price': quote.get('price', 0), 'prev_close': quote.get('prev_close', 0), 'change': quote.get('change', 0), 'change_percent': quote.get('change_percent', 0), 'high': quote.get('high', 0), 'low': quote.get('low', 0), 'volume': quote.get('volume', 0), 'amount': quote.get('amount', 0), 'update_time': quote.get('update_time', ''), }) return jsonify({ 'success': True, 'data': response_data }) except Exception as e: print(f"获取实时行情失败: {str(e)}") import traceback traceback.print_exc() return jsonify({'success': False, 'error': '获取实时行情失败'}), 500 @app.route('/api/account/watchlist/', methods=['DELETE']) def remove_from_watchlist(stock_code): """从自选股移除""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 code6, _ = _normalize_stock_input(stock_code) candidates = [] if code6: candidates = [code6, f"{code6}.SH", f"{code6}.SZ"] # 包含原始传入(以兼容历史) if stock_code not in candidates: candidates.append(stock_code) item = Watchlist.query.filter( Watchlist.user_id == session['user_id'], Watchlist.stock_code.in_(candidates) ).first() if not item: return 
jsonify({'success': False, 'error': '未找到自选项'}), 404 db.session.delete(item) db.session.commit() return jsonify({'success': True}) # 投资计划和复盘相关的模型 class InvestmentPlan(db.Model): __tablename__ = 'investment_plans' id = db.Column(db.Integer, primary_key=True, autoincrement=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) date = db.Column(db.Date, nullable=False) title = db.Column(db.String(200), nullable=False) content = db.Column(db.Text) type = db.Column(db.String(20)) # 'plan' or 'review' stocks = db.Column(db.Text) # JSON array of stock codes tags = db.Column(db.String(500)) # JSON array of tags status = db.Column(db.String(20), default='active') # active, completed, cancelled created_at = db.Column(db.DateTime, default=datetime.utcnow) updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) def to_dict(self): return { 'id': self.id, 'date': self.date.isoformat() if self.date else None, 'title': self.title, 'content': self.content, 'type': self.type, 'stocks': json.loads(self.stocks) if self.stocks else [], 'tags': json.loads(self.tags) if self.tags else [], 'status': self.status, 'created_at': self.created_at.isoformat() if self.created_at else None, 'updated_at': self.updated_at.isoformat() if self.updated_at else None } @app.route('/api/account/investment-plans', methods=['GET']) def get_investment_plans(): """获取投资计划和复盘记录""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 plan_type = request.args.get('type') # 'plan', 'review', or None for all start_date = request.args.get('start_date') end_date = request.args.get('end_date') query = InvestmentPlan.query.filter_by(user_id=session['user_id']) if plan_type: query = query.filter_by(type=plan_type) if start_date: query = query.filter(InvestmentPlan.date >= datetime.fromisoformat(start_date).date()) if end_date: query = query.filter(InvestmentPlan.date <= datetime.fromisoformat(end_date).date()) plans = 
query.order_by(InvestmentPlan.date.desc()).all() return jsonify({ 'success': True, 'data': [plan.to_dict() for plan in plans] }) except Exception as e: print(f"获取投资计划失败: {str(e)}") return jsonify({'success': False, 'error': '获取数据失败'}), 500 @app.route('/api/account/investment-plans', methods=['POST']) def create_investment_plan(): """创建投资计划或复盘记录""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 data = request.get_json() # 验证必要字段 if not data.get('date') or not data.get('title') or not data.get('type'): return jsonify({'success': False, 'error': '缺少必要字段'}), 400 plan = InvestmentPlan( user_id=session['user_id'], date=datetime.fromisoformat(data['date']).date(), title=data['title'], content=data.get('content', ''), type=data['type'], stocks=json.dumps(data.get('stocks', [])), tags=json.dumps(data.get('tags', [])), status=data.get('status', 'active') ) db.session.add(plan) db.session.commit() return jsonify({ 'success': True, 'data': plan.to_dict() }) except Exception as e: db.session.rollback() print(f"创建投资计划失败: {str(e)}") return jsonify({'success': False, 'error': '创建失败'}), 500 @app.route('/api/account/investment-plans/', methods=['PUT']) def update_investment_plan(plan_id): """更新投资计划或复盘记录""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 plan = InvestmentPlan.query.filter_by(id=plan_id, user_id=session['user_id']).first() if not plan: return jsonify({'success': False, 'error': '未找到该记录'}), 404 data = request.get_json() if 'date' in data: plan.date = datetime.fromisoformat(data['date']).date() if 'title' in data: plan.title = data['title'] if 'content' in data: plan.content = data['content'] if 'stocks' in data: plan.stocks = json.dumps(data['stocks']) if 'tags' in data: plan.tags = json.dumps(data['tags']) if 'status' in data: plan.status = data['status'] plan.updated_at = datetime.utcnow() db.session.commit() return jsonify({ 'success': True, 'data': plan.to_dict() }) except 
Exception as e: db.session.rollback() print(f"更新投资计划失败: {str(e)}") return jsonify({'success': False, 'error': '更新失败'}), 500 @app.route('/api/account/investment-plans/', methods=['DELETE']) def delete_investment_plan(plan_id): """删除投资计划或复盘记录""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 plan = InvestmentPlan.query.filter_by(id=plan_id, user_id=session['user_id']).first() if not plan: return jsonify({'success': False, 'error': '未找到该记录'}), 404 db.session.delete(plan) db.session.commit() return jsonify({'success': True}) except Exception as e: db.session.rollback() print(f"删除投资计划失败: {str(e)}") return jsonify({'success': False, 'error': '删除失败'}), 500 @app.route('/api/account/events/following', methods=['GET']) def get_my_following_events(): """获取我关注的事件列表""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 follows = EventFollow.query.filter_by(user_id=session['user_id']).order_by(EventFollow.created_at.desc()).all() event_ids = [f.event_id for f in follows] if not event_ids: return jsonify({'success': True, 'data': []}) events = Event.query.filter(Event.id.in_(event_ids)).all() data = [] for ev in events: data.append({ 'id': ev.id, 'title': ev.title, 'event_type': ev.event_type, 'start_time': ev.start_time.isoformat() if ev.start_time else None, 'view_count': ev.view_count or 0, 'related_avg_chg': ev.related_avg_chg, 'follower_count': ev.follower_count, }) return jsonify({'success': True, 'data': data}) @app.route('/api/account/events/comments', methods=['GET']) def get_my_event_comments(): """获取我在事件上的评论(EventComment)""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 comments = EventComment.query.filter_by(user_id=session['user_id']).order_by(EventComment.created_at.desc()).limit( 100).all() return jsonify({'success': True, 'data': [c.to_dict() for c in comments]}) @app.route('/api/account/events/posts', methods=['GET']) def get_my_event_posts(): 
"""获取我在事件上的帖子(Post)- 用于个人中心显示""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 try: # 查询当前用户的所有 Post(按创建时间倒序) posts = Post.query.filter_by( user_id=session['user_id'], status='active' ).order_by(Post.created_at.desc()).limit(100).all() posts_data = [] for post in posts: # 获取关联的事件信息 event = Event.query.get(post.event_id) event_title = event.title if event else '未知事件' # 获取用户信息 user = User.query.get(post.user_id) author = user.username if user else '匿名用户' # ⚡ 返回格式兼容旧 EventComment.to_dict() posts_data.append({ 'id': post.id, 'event_id': post.event_id, 'event_title': event_title, # ⚡ 新增字段(旧 API 没有) 'user_id': post.user_id, 'author': author, # ⚡ 兼容旧格式(字符串类型) 'content': post.content, 'title': post.title, # Post 独有字段(可选) 'content_type': post.content_type, # Post 独有字段 'likes': post.likes_count, # ⚡ 兼容旧字段名 'created_at': post.created_at.isoformat(), 'updated_at': post.updated_at.isoformat(), 'status': post.status, }) return jsonify({'success': True, 'data': posts_data}) except Exception as e: print(f"获取用户帖子失败: {e}") return jsonify({'success': False, 'error': '获取帖子失败'}), 500 @app.route('/api/account/future-events/following', methods=['GET']) def get_my_following_future_events(): """获取当前用户关注的未来事件""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 try: # 获取用户关注的未来事件ID列表 follows = FutureEventFollow.query.filter_by(user_id=session['user_id']).all() future_event_ids = [f.future_event_id for f in follows] if not future_event_ids: return jsonify({'success': True, 'data': []}) # 查询未来事件详情 sql = """ SELECT * FROM future_events WHERE data_id IN :event_ids ORDER BY calendar_time \ """ result = db.session.execute( text(sql), {'event_ids': tuple(future_event_ids)} ) events = [] # 所有返回的事件都是已关注的 following_ids = set(future_event_ids) for row in result: event_data = process_future_event_row(row, following_ids) events.append(event_data) return jsonify({'success': True, 'data': events}) except Exception as e: return 
jsonify({'success': False, 'error': str(e)}), 500 class PostLike(db.Model): """帖子点赞""" id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) created_at = db.Column(db.DateTime, default=beijing_now) user = db.relationship('User', backref='post_likes') __table_args__ = (db.UniqueConstraint('user_id', 'post_id'),) # =========================== # 预测市场系统模型 # =========================== class UserCreditAccount(db.Model): """用户积分账户""" __tablename__ = 'user_credit_account' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False, unique=True) # 积分余额 balance = db.Column(db.Float, default=10000.0, nullable=False) # 初始10000积分 frozen_balance = db.Column(db.Float, default=0.0, nullable=False) # 冻结积分 total_earned = db.Column(db.Float, default=0.0, nullable=False) # 累计获得 total_spent = db.Column(db.Float, default=0.0, nullable=False) # 累计消费 # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) last_daily_bonus_at = db.Column(db.DateTime) # 最后一次领取每日奖励时间 # 关系 user = db.relationship('User', backref=db.backref('credit_account', uselist=False)) def __repr__(self): return f'' class PredictionTopic(db.Model): """预测话题""" __tablename__ = 'prediction_topic' id = db.Column(db.Integer, primary_key=True) creator_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 基本信息 title = db.Column(db.String(200), nullable=False) description = db.Column(db.Text) category = db.Column(db.String(50), default='stock') # stock/event/market # 市场数据 yes_total_shares = db.Column(db.Integer, default=0, nullable=False) # YES方总份额 no_total_shares = db.Column(db.Integer, default=0, nullable=False) # NO方总份额 yes_price = db.Column(db.Float, default=500.0, nullable=False) # YES方价格(0-1000) no_price = 
db.Column(db.Float, default=500.0, nullable=False) # NO方价格(0-1000) # 奖池 total_pool = db.Column(db.Float, default=0.0, nullable=False) # 总奖池(2%交易税累积) # 领主信息 yes_lord_id = db.Column(db.Integer, db.ForeignKey('user.id')) # YES方领主 no_lord_id = db.Column(db.Integer, db.ForeignKey('user.id')) # NO方领主 # 状态 status = db.Column(db.String(20), default='active', nullable=False) # active/settled/cancelled result = db.Column(db.String(10)) # yes/no/draw(结算结果) # 时间 deadline = db.Column(db.DateTime, nullable=False) # 截止时间 settled_at = db.Column(db.DateTime) # 结算时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 统计 views_count = db.Column(db.Integer, default=0) comments_count = db.Column(db.Integer, default=0) participants_count = db.Column(db.Integer, default=0) # 关系 creator = db.relationship('User', foreign_keys=[creator_id], backref='created_topics') yes_lord = db.relationship('User', foreign_keys=[yes_lord_id], backref='yes_lord_topics') no_lord = db.relationship('User', foreign_keys=[no_lord_id], backref='no_lord_topics') positions = db.relationship('PredictionPosition', backref='topic', lazy='dynamic') transactions = db.relationship('PredictionTransaction', backref='topic', lazy='dynamic') comments = db.relationship('TopicComment', backref='topic', lazy='dynamic') def __repr__(self): return f'' class PredictionPosition(db.Model): """用户持仓""" __tablename__ = 'prediction_position' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) topic_id = db.Column(db.Integer, db.ForeignKey('prediction_topic.id'), nullable=False) # 持仓信息 direction = db.Column(db.String(3), nullable=False) # yes/no shares = db.Column(db.Integer, default=0, nullable=False) # 持有份额 avg_cost = db.Column(db.Float, default=0.0, nullable=False) # 平均成本 total_invested = db.Column(db.Float, default=0.0, nullable=False) # 总投入 # 时间 created_at = 
db.Column(db.DateTime, default=beijing_now, nullable=False) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 关系 user = db.relationship('User', backref='prediction_positions') # 唯一约束:每个用户在每个话题的每个方向只能有一个持仓 __table_args__ = (db.UniqueConstraint('user_id', 'topic_id', 'direction'),) def __repr__(self): return f'' class PredictionTransaction(db.Model): """预测交易记录""" __tablename__ = 'prediction_transaction' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) topic_id = db.Column(db.Integer, db.ForeignKey('prediction_topic.id'), nullable=False) # 交易信息 trade_type = db.Column(db.String(10), nullable=False) # buy/sell direction = db.Column(db.String(3), nullable=False) # yes/no shares = db.Column(db.Integer, nullable=False) # 份额数量 price = db.Column(db.Float, nullable=False) # 成交价格 # 费用 amount = db.Column(db.Float, nullable=False) # 交易金额 tax = db.Column(db.Float, default=0.0, nullable=False) # 手续费(2%) total_cost = db.Column(db.Float, nullable=False) # 总成本(amount + tax) # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) # 关系 user = db.relationship('User', backref='prediction_transactions') def __repr__(self): return f'' class CreditTransaction(db.Model): """积分交易记录""" __tablename__ = 'credit_transaction' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 交易信息 transaction_type = db.Column(db.String(30), nullable=False) # prediction_buy/prediction_sell/daily_bonus/create_topic/settle_win amount = db.Column(db.Float, nullable=False) # 金额(正数=增加,负数=减少) balance_after = db.Column(db.Float, nullable=False) # 交易后余额 # 关联 related_topic_id = db.Column(db.Integer, db.ForeignKey('prediction_topic.id')) # 相关话题 related_transaction_id = db.Column(db.Integer, db.ForeignKey('prediction_transaction.id')) # 相关预测交易 # 描述 description = db.Column(db.String(200)) # 交易描述 # 时间 created_at = db.Column(db.DateTime, 
default=beijing_now, nullable=False) # 关系 user = db.relationship('User', backref='credit_transactions') related_topic = db.relationship('PredictionTopic', backref='credit_transactions') def __repr__(self): return f'' class TopicComment(db.Model): """话题评论""" __tablename__ = 'topic_comment' id = db.Column(db.Integer, primary_key=True) topic_id = db.Column(db.Integer, db.ForeignKey('prediction_topic.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 内容 content = db.Column(db.Text, nullable=False) parent_id = db.Column(db.Integer, db.ForeignKey('topic_comment.id')) # 父评论ID(回复功能) # 状态 is_pinned = db.Column(db.Boolean, default=False, nullable=False) # 是否置顶(领主特权) status = db.Column(db.String(20), default='active') # active/hidden/deleted # 统计 likes_count = db.Column(db.Integer, default=0, nullable=False) # 观点IPO 相关 total_investment = db.Column(db.Integer, default=0, nullable=False) # 总投资额 investor_count = db.Column(db.Integer, default=0, nullable=False) # 投资人数 is_verified = db.Column(db.Boolean, default=False, nullable=False) # 是否已验证 verification_result = db.Column(db.String(20)) # 验证结果:correct/incorrect/null position_rank = db.Column(db.Integer) # 评论位置排名(用于首发权拍卖) # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 关系 user = db.relationship('User', backref='topic_comments') replies = db.relationship('TopicComment', backref=db.backref('parent', remote_side=[id]), lazy='dynamic') likes = db.relationship('TopicCommentLike', backref='comment', lazy='dynamic') def __repr__(self): return f'' class TopicCommentLike(db.Model): """话题评论点赞""" __tablename__ = 'topic_comment_like' id = db.Column(db.Integer, primary_key=True) comment_id = db.Column(db.Integer, db.ForeignKey('topic_comment.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) created_at = db.Column(db.DateTime, 
default=beijing_now, nullable=False) # 关系 user = db.relationship('User', backref='topic_comment_likes') # 唯一约束 __table_args__ = (db.UniqueConstraint('comment_id', 'user_id'),) def __repr__(self): return f'' class CommentInvestment(db.Model): """评论投资记录(观点IPO)""" __tablename__ = 'comment_investment' id = db.Column(db.Integer, primary_key=True) comment_id = db.Column(db.Integer, db.ForeignKey('topic_comment.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 投资数据 shares = db.Column(db.Integer, nullable=False) # 投资份额 amount = db.Column(db.Integer, nullable=False) # 投资金额 avg_price = db.Column(db.Float, nullable=False) # 平均价格 # 状态 status = db.Column(db.String(20), default='active', nullable=False) # active/settled # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) # 关系 user = db.relationship('User', backref='comment_investments') comment = db.relationship('TopicComment', backref='investments') def __repr__(self): return f'' class CommentPositionBid(db.Model): """评论位置竞拍记录(首发权拍卖)""" __tablename__ = 'comment_position_bid' id = db.Column(db.Integer, primary_key=True) topic_id = db.Column(db.Integer, db.ForeignKey('prediction_topic.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 竞拍数据 position = db.Column(db.Integer, nullable=False) # 位置:1/2/3 bid_amount = db.Column(db.Integer, nullable=False) # 出价金额 status = db.Column(db.String(20), default='pending', nullable=False) # pending/won/lost # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) expires_at = db.Column(db.DateTime, nullable=False) # 竞拍截止时间 # 关系 user = db.relationship('User', backref='comment_position_bids') topic = db.relationship('PredictionTopic', backref='position_bids') def __repr__(self): return f'' class TimeCapsuleTopic(db.Model): """时间胶囊话题(长期预测)""" __tablename__ = 'time_capsule_topic' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, 
db.ForeignKey('user.id'), nullable=False) # 话题内容 title = db.Column(db.String(200), nullable=False) description = db.Column(db.Text) encrypted_content = db.Column(db.Text) # 加密的预测内容 encryption_key = db.Column(db.String(500)) # 加密密钥(后端存储) # 时间范围 start_year = db.Column(db.Integer, nullable=False) # 起始年份 end_year = db.Column(db.Integer, nullable=False) # 结束年份 # 状态 status = db.Column(db.String(20), default='active', nullable=False) # active/settled is_decrypted = db.Column(db.Boolean, default=False, nullable=False) # 是否已解密 actual_happened_year = db.Column(db.Integer) # 实际发生年份 # 统计 total_pool = db.Column(db.Integer, default=0, nullable=False) # 总奖池 # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 关系 user = db.relationship('User', backref='time_capsule_topics') time_slots = db.relationship('TimeCapsuleTimeSlot', backref='topic', lazy='dynamic') def __repr__(self): return f'' class TimeCapsuleTimeSlot(db.Model): """时间胶囊时间段""" __tablename__ = 'time_capsule_time_slot' id = db.Column(db.Integer, primary_key=True) topic_id = db.Column(db.Integer, db.ForeignKey('time_capsule_topic.id'), nullable=False) # 时间段 year_start = db.Column(db.Integer, nullable=False) year_end = db.Column(db.Integer, nullable=False) # 竞拍数据 current_holder_id = db.Column(db.Integer, db.ForeignKey('user.id')) # 当前持有者 current_price = db.Column(db.Integer, default=100, nullable=False) # 当前价格 total_bids = db.Column(db.Integer, default=0, nullable=False) # 总竞拍次数 # 状态 status = db.Column(db.String(20), default='active', nullable=False) # active/won/expired # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 关系 current_holder = db.relationship('User', foreign_keys=[current_holder_id]) bids = db.relationship('TimeSlotBid', backref='time_slot', lazy='dynamic') def __repr__(self): return f'' class 
TimeSlotBid(db.Model): """时间段竞拍记录""" __tablename__ = 'time_slot_bid' id = db.Column(db.Integer, primary_key=True) slot_id = db.Column(db.Integer, db.ForeignKey('time_capsule_time_slot.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 竞拍数据 bid_amount = db.Column(db.Integer, nullable=False) status = db.Column(db.String(20), default='outbid', nullable=False) # outbid/holding/won # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) # 关系 user = db.relationship('User', backref='time_slot_bids') def __repr__(self): return f'' class Event(db.Model): """事件模型""" id = db.Column(db.Integer, primary_key=True) title = db.Column(db.String(200), nullable=False) description = db.Column(db.Text) # 事件类型与状态 event_type = db.Column(db.String(50)) status = db.Column(db.String(20), default='active') # 时间相关 start_time = db.Column(db.DateTime, default=beijing_now) end_time = db.Column(db.DateTime) created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now) # 热度与统计 hot_score = db.Column(db.Float, default=0) view_count = db.Column(db.Integer, default=0) trending_score = db.Column(db.Float, default=0) post_count = db.Column(db.Integer, default=0) follower_count = db.Column(db.Integer, default=0) # 关联信息 related_industries = db.Column(db.String(20)) # 申万行业代码,如 "S640701" keywords = db.Column(db.JSON) files = db.Column(db.JSON) importance = db.Column(db.String(20)) related_avg_chg = db.Column(db.Float, default=0) related_max_chg = db.Column(db.Float, default=0) related_week_chg = db.Column(db.Float, default=0) # 新增字段 invest_score = db.Column(db.Integer) # 超预期得分 expectation_surprise_score = db.Column(db.Integer) # 创建者信息 creator_id = db.Column(db.Integer, db.ForeignKey('user.id')) creator = db.relationship('User', backref='created_events') # 关系 posts = db.relationship('Post', backref='event', lazy='dynamic') followers = db.relationship('EventFollow', backref='event', 
lazy='dynamic') related_stocks = db.relationship('RelatedStock', backref='event', lazy='dynamic') historical_events = db.relationship('HistoricalEvent', backref='event', lazy='dynamic') related_data = db.relationship('RelatedData', backref='event', lazy='dynamic') related_concepts = db.relationship('RelatedConcepts', backref='event', lazy='dynamic') @property def keywords_list(self): """返回解析后的关键词列表""" if not self.keywords: return [] if isinstance(self.keywords, list): return self.keywords try: # 如果是字符串,尝试解析JSON if isinstance(self.keywords, str): decoded = json.loads(self.keywords) # 处理Unicode编码的情况 if isinstance(decoded, list): return [ keyword.encode('utf-8').decode('unicode_escape') if isinstance(keyword, str) and '\\u' in keyword else keyword for keyword in decoded ] return [] # 如果已经是字典或其他格式,尝试转换为列表 return list(self.keywords) except (json.JSONDecodeError, AttributeError, TypeError): return [] def set_keywords(self, keywords): """设置关键词列表""" if isinstance(keywords, list): self.keywords = json.dumps(keywords, ensure_ascii=False) elif isinstance(keywords, str): try: # 尝试解析JSON字符串 parsed = json.loads(keywords) if isinstance(parsed, list): self.keywords = json.dumps(parsed, ensure_ascii=False) else: self.keywords = json.dumps([keywords], ensure_ascii=False) except json.JSONDecodeError: # 如果不是有效的JSON,将其作为单个关键词 self.keywords = json.dumps([keywords], ensure_ascii=False) class RelatedStock(db.Model): """相关标的模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) stock_code = db.Column(db.String(20)) # 股票代码 stock_name = db.Column(db.String(100)) # 股票名称 sector = db.Column(db.String(100)) # 关联类型 relation_desc = db.Column(db.String(1024)) # 关联原因描述 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) correlation = db.Column(db.Float()) momentum = db.Column(db.String(1024)) # 动量 retrieved_sources = db.Column(db.JSON) # 动量 class 
RelatedData(db.Model): """关联数据模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) title = db.Column(db.String(200)) # 数据标题 data_type = db.Column(db.String(50)) # 数据类型 data_content = db.Column(db.JSON) # 数据内容(JSON格式) description = db.Column(db.Text) # 数据描述 created_at = db.Column(db.DateTime, default=beijing_now) class RelatedConcepts(db.Model): """相关概念模型(AI分析结果)""" __tablename__ = 'related_concepts' id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) concept = db.Column(db.String(255)) # 概念名称 reason = db.Column(db.Text) # 关联原因(AI分析) created_at = db.Column(db.DateTime, default=beijing_now) class EventHotHistory(db.Model): """事件热度历史记录""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) score = db.Column(db.Float) # 总分 interaction_score = db.Column(db.Float) # 互动分数 follow_score = db.Column(db.Float) # 关注度分数 view_score = db.Column(db.Float) # 浏览量分数 recent_activity_score = db.Column(db.Float) # 最近活跃度分数 time_decay = db.Column(db.Float) # 时间衰减因子 created_at = db.Column(db.DateTime, default=beijing_now) event = db.relationship('Event', backref='hot_history') class EventTransmissionNode(db.Model): """事件传导节点模型""" __tablename__ = 'event_transmission_nodes' id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) node_type = db.Column(db.Enum('company', 'industry', 'policy', 'technology', 'market', 'event', 'other'), nullable=False) node_name = db.Column(db.String(200), nullable=False) node_description = db.Column(db.Text) importance_score = db.Column(db.Integer, default=50) stock_code = db.Column(db.String(20)) is_main_event = db.Column(db.Boolean, default=False) created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # Relationships event = db.relationship('Event', 
backref='transmission_nodes') outgoing_edges = db.relationship('EventTransmissionEdge', foreign_keys='EventTransmissionEdge.from_node_id', backref='from_node', cascade='all, delete-orphan') incoming_edges = db.relationship('EventTransmissionEdge', foreign_keys='EventTransmissionEdge.to_node_id', backref='to_node', cascade='all, delete-orphan') __table_args__ = ( db.Index('idx_event_id', 'event_id'), db.Index('idx_node_type', 'node_type'), db.Index('idx_main_event', 'is_main_event'), ) class EventTransmissionEdge(db.Model): """事件传导边模型""" __tablename__ = 'event_transmission_edges' id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) from_node_id = db.Column(db.Integer, db.ForeignKey('event_transmission_nodes.id'), nullable=False) to_node_id = db.Column(db.Integer, db.ForeignKey('event_transmission_nodes.id'), nullable=False) transmission_type = db.Column(db.Enum('supply_chain', 'competition', 'policy', 'technology', 'capital_flow', 'expectation', 'cyclic_effect', 'other'), nullable=False) transmission_mechanism = db.Column(db.Text) direction = db.Column(db.Enum('positive', 'negative', 'neutral', 'mixed'), default='neutral') strength = db.Column(db.Integer, default=50) impact = db.Column(db.Text) is_circular = db.Column(db.Boolean, default=False) created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # Relationship event = db.relationship('Event', backref='transmission_edges') __table_args__ = ( db.Index('idx_event_id', 'event_id'), db.Index('idx_strength', 'strength'), db.Index('idx_from_to', 'from_node_id', 'to_node_id'), db.Index('idx_circular', 'is_circular'), ) # 在 paste-2.txt 的模型定义部分添加 class EventSankeyFlow(db.Model): """事件桑基流模型""" __tablename__ = 'event_sankey_flows' id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) # 流的基本信息 source_node = 
db.Column(db.String(200), nullable=False) source_type = db.Column(db.Enum('event', 'policy', 'technology', 'industry', 'company', 'product'), nullable=False) source_level = db.Column(db.Integer, nullable=False, default=0) target_node = db.Column(db.String(200), nullable=False) target_type = db.Column(db.Enum('policy', 'technology', 'industry', 'company', 'product'), nullable=False) target_level = db.Column(db.Integer, nullable=False, default=1) # 流量信息 flow_value = db.Column(db.Numeric(10, 2), nullable=False) flow_ratio = db.Column(db.Numeric(5, 4), nullable=False) # 传导机制 transmission_path = db.Column(db.String(500)) impact_description = db.Column(db.Text) evidence_strength = db.Column(db.Integer, default=50) # 时间戳 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 关系 event = db.relationship('Event', backref='sankey_flows') __table_args__ = ( db.Index('idx_event_id', 'event_id'), db.Index('idx_source_target', 'source_node', 'target_node'), db.Index('idx_levels', 'source_level', 'target_level'), db.Index('idx_flow_value', 'flow_value'), ) class HistoricalEvent(db.Model): """历史事件模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) title = db.Column(db.String(200)) content = db.Column(db.Text) event_date = db.Column(db.DateTime) relevance = db.Column(db.Integer) # 相关性 importance = db.Column(db.Integer) # 重要程度 related_stock = db.Column(db.JSON) # 保留JSON字段 created_at = db.Column(db.DateTime, default=beijing_now) # 新增关系 stocks = db.relationship('HistoricalEventStock', backref='historical_event', lazy='dynamic', cascade='all, delete-orphan') class HistoricalEventStock(db.Model): """历史事件相关股票模型""" __tablename__ = 'historical_event_stocks' id = db.Column(db.Integer, primary_key=True) historical_event_id = db.Column(db.Integer, db.ForeignKey('historical_event.id'), nullable=False) stock_code = db.Column(db.String(20), nullable=False) 
stock_name = db.Column(db.String(50)) relation_desc = db.Column(db.Text) correlation = db.Column(db.Float, default=0.5) sector = db.Column(db.String(100)) created_at = db.Column(db.DateTime, default=beijing_now) __table_args__ = ( db.UniqueConstraint('historical_event_id', 'stock_code', name='unique_event_stock'), ) # === 股票盈利预测(自有表) === class StockForecastData(db.Model): """股票盈利预测数据 源于本地表 stock_forecast_data,由独立离线程序写入。 字段与表结构保持一致,仅用于读取聚合后输出前端报表所需的结构。 """ __tablename__ = 'stock_forecast_data' id = db.Column(db.Integer, primary_key=True) stock_code = db.Column(db.String(6), nullable=False) indicator_name = db.Column(db.String(50), nullable=False) year_2022a = db.Column(db.Numeric(15, 2)) year_2023a = db.Column(db.Numeric(15, 2)) year_2024a = db.Column(db.Numeric(15, 2)) year_2025e = db.Column(db.Numeric(15, 2)) year_2026e = db.Column(db.Numeric(15, 2)) year_2027e = db.Column(db.Numeric(15, 2)) process_time = db.Column(db.DateTime, nullable=False) __table_args__ = ( db.UniqueConstraint('stock_code', 'indicator_name', name='unique_stock_indicator'), ) def values_by_year(self): years = ['2022A', '2023A', '2024A', '2025E', '2026E', '2027E'] vals = [self.year_2022a, self.year_2023a, self.year_2024a, self.year_2025e, self.year_2026e, self.year_2027e] def _to_float(x): try: return float(x) if x is not None else None except Exception: return None return years, [_to_float(v) for v in vals] @app.route('/api/events/', methods=['GET']) def get_event_detail(event_id): """获取事件详情""" try: event = Event.query.get_or_404(event_id) # 增加浏览计数 event.view_count += 1 db.session.commit() return jsonify({ 'success': True, 'data': { 'id': event.id, 'title': event.title, 'description': event.description, 'event_type': event.event_type, 'status': event.status, 'start_time': event.start_time.isoformat() if event.start_time else None, 'end_time': event.end_time.isoformat() if event.end_time else None, 'created_at': event.created_at.isoformat() if event.created_at else None, 'hot_score': 
event.hot_score, 'view_count': event.view_count, 'trending_score': event.trending_score, 'post_count': event.post_count, 'follower_count': event.follower_count, 'related_industries': event.related_industries, 'keywords': event.keywords_list, 'importance': event.importance, 'related_avg_chg': event.related_avg_chg, 'related_max_chg': event.related_max_chg, 'related_week_chg': event.related_week_chg, 'invest_score': event.invest_score, 'expectation_surprise_score': event.expectation_surprise_score, 'creator_id': event.creator_id, 'has_chain_analysis': ( EventTransmissionNode.query.filter_by(event_id=event_id).first() is not None or EventSankeyFlow.query.filter_by(event_id=event_id).first() is not None ), 'is_following': False, # 需要根据当前用户状态判断 } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//stocks', methods=['GET']) def get_related_stocks(event_id): """获取相关股票列表""" try: # 订阅控制:相关标的需要 Pro 及以上 if not _has_required_level('pro'): return jsonify({'success': False, 'error': '需要Pro订阅', 'required_level': 'pro'}), 403 event = Event.query.get_or_404(event_id) stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all() stocks_data = [] for stock in stocks: # 处理 relation_desc:只有当 retrieved_sources 是数组时才使用新格式 if stock.retrieved_sources is not None and isinstance(stock.retrieved_sources, list): # retrieved_sources 是有效数组,使用新格式 relation_desc_value = {"data": stock.retrieved_sources} else: # retrieved_sources 不是数组(可能是 {"raw": "..."} 等异常格式),回退到原始文本 relation_desc_value = stock.relation_desc stocks_data.append({ 'id': stock.id, 'stock_code': stock.stock_code, 'stock_name': stock.stock_name, 'sector': stock.sector, 'relation_desc': relation_desc_value, 'retrieved_sources': stock.retrieved_sources, 'correlation': stock.correlation, 'momentum': stock.momentum, 'created_at': stock.created_at.isoformat() if stock.created_at else None, 'updated_at': stock.updated_at.isoformat() if stock.updated_at else None }) return 
jsonify({ 'success': True, 'data': stocks_data }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//stocks', methods=['POST']) def add_related_stock(event_id): """添加相关股票""" try: event = Event.query.get_or_404(event_id) data = request.get_json() # 验证必要字段 if not data.get('stock_code') or not data.get('relation_desc'): return jsonify({'success': False, 'error': '缺少必要字段'}), 400 # 检查是否已存在 existing = RelatedStock.query.filter_by( event_id=event_id, stock_code=data['stock_code'] ).first() if existing: return jsonify({'success': False, 'error': '该股票已存在'}), 400 # 创建新的相关股票记录 new_stock = RelatedStock( event_id=event_id, stock_code=data['stock_code'], stock_name=data.get('stock_name', ''), sector=data.get('sector', ''), relation_desc=data['relation_desc'], correlation=data.get('correlation', 0.5), momentum=data.get('momentum', '') ) db.session.add(new_stock) db.session.commit() return jsonify({ 'success': True, 'data': { 'id': new_stock.id, 'stock_code': new_stock.stock_code, 'relation_desc': new_stock.relation_desc } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stocks/', methods=['DELETE']) def delete_related_stock(stock_id): """删除相关股票""" try: stock = RelatedStock.query.get_or_404(stock_id) db.session.delete(stock) db.session.commit() return jsonify({'success': True, 'message': '删除成功'}) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events/by-stocks', methods=['POST']) def get_events_by_stocks(): """ 通过股票代码列表获取关联的事件(新闻) 用于概念中心时间轴:聚合概念下所有股票的相关新闻 请求体: { "stock_codes": ["000001.SZ", "600000.SH", ...], # 股票代码列表 "start_date": "2024-01-01", # 可选,开始日期 "end_date": "2024-12-31", # 可选,结束日期 "limit": 100 # 可选,限制返回数量,默认100 } """ try: data = request.get_json() stock_codes = data.get('stock_codes', []) start_date_str = data.get('start_date') end_date_str = data.get('end_date') limit = 
data.get('limit', 100) if not stock_codes: return jsonify({'success': False, 'error': '缺少股票代码列表'}), 400 # 转换股票代码格式:概念API返回的是不带后缀的(如600000), # 但related_stock表中存储的是带后缀的(如600000.SH) def normalize_stock_code(code): """将股票代码标准化为带后缀的格式""" if not code: return code # 如果已经带后缀,直接返回 if '.' in str(code): return code code = str(code).strip() # 根据代码前缀判断交易所 if code.startswith('6'): return f"{code}.SH" # 上海 elif code.startswith('0') or code.startswith('3'): return f"{code}.SZ" # 深圳 elif code.startswith('8') or code.startswith('4'): return f"{code}.BJ" # 北交所 else: return code # 未知格式,保持原样 # 同时包含带后缀和不带后缀的版本,提高匹配率 normalized_codes = set() for code in stock_codes: if code: normalized_codes.add(str(code)) # 原始格式 normalized_codes.add(normalize_stock_code(code)) # 带后缀格式 # 如果原始带后缀,也加入不带后缀的版本 if '.' in str(code): normalized_codes.add(str(code).split('.')[0]) # 构建查询:通过 RelatedStock 表找到关联的事件 query = db.session.query(Event).join( RelatedStock, Event.id == RelatedStock.event_id ).filter( RelatedStock.stock_code.in_(list(normalized_codes)) ) # 日期过滤(使用 start_time 字段) if start_date_str: try: start_date = datetime.strptime(start_date_str, '%Y-%m-%d') query = query.filter(Event.start_time >= start_date) except ValueError: pass if end_date_str: try: end_date = datetime.strptime(end_date_str, '%Y-%m-%d') # 设置为当天结束 end_date = end_date.replace(hour=23, minute=59, second=59) query = query.filter(Event.start_time <= end_date) except ValueError: pass # 去重并排序(使用 start_time 字段) query = query.distinct().order_by(Event.start_time.desc()) # 限制数量 if limit: query = query.limit(limit) events = query.all() # 构建返回数据 events_data = [] for event in events: # 获取该事件关联的股票信息(在请求的股票列表中的) related_stocks_in_list = [ { 'stock_code': rs.stock_code, 'stock_name': rs.stock_name, 'sector': rs.sector } for rs in event.related_stocks if rs.stock_code in stock_codes ] events_data.append({ 'id': event.id, 'title': event.title, 'description': event.description, 'event_date': event.start_time.isoformat() if event.start_time else None, 
'published_time': event.start_time.strftime('%Y-%m-%d %H:%M:%S') if event.start_time else None, 'source': 'event', # 标记来源为事件系统 'importance': event.importance, 'view_count': event.view_count, 'hot_score': event.hot_score, 'related_stocks': related_stocks_in_list, 'event_type': event.event_type, 'created_at': event.created_at.isoformat() if event.created_at else None }) return jsonify({ 'success': True, 'data': events_data, 'total': len(events_data) }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//concepts', methods=['GET']) def get_related_concepts(event_id): """获取相关概念列表(AI分析结果)""" try: # 订阅控制:相关概念需要 Pro 及以上 if not _has_required_level('pro'): return jsonify({'success': False, 'error': '需要Pro订阅', 'required_level': 'pro'}), 403 # 直接查询 related_concepts 表 concepts = RelatedConcepts.query.filter_by(event_id=event_id).all() concepts_data = [] for concept in concepts: concepts_data.append({ 'id': concept.id, 'concept': concept.concept, 'reason': concept.reason, 'created_at': concept.created_at.isoformat() if concept.created_at else None }) return jsonify({ 'success': True, 'data': concepts_data }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//historical', methods=['GET']) def get_historical_events(event_id): """获取历史事件对比""" try: event = Event.query.get_or_404(event_id) historical_events = event.historical_events.order_by(HistoricalEvent.event_date.desc()).all() events_data = [] for hist_event in historical_events: events_data.append({ 'id': hist_event.id, 'title': hist_event.title, 'content': hist_event.content, 'event_date': hist_event.event_date.isoformat() if hist_event.event_date else None, 'importance': hist_event.importance, 'relevance': hist_event.relevance, 'created_at': hist_event.created_at.isoformat() if hist_event.created_at else None }) # 订阅控制:免费用户仅返回前2条;Pro/Max返回全部 info = _get_current_subscription_info() sub_type = (info.get('type') or 
'free').lower() if sub_type == 'free': return jsonify({ 'success': True, 'data': events_data[:2], 'truncated': len(events_data) > 2, 'required_level': 'pro' }) return jsonify({'success': True, 'data': events_data}) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/historical-events//stocks', methods=['GET']) def get_historical_event_stocks(event_id): """获取历史事件相关股票列表""" try: # 直接查询历史事件,不需要通过主事件 hist_event = HistoricalEvent.query.get_or_404(event_id) stocks = hist_event.stocks.order_by(HistoricalEventStock.correlation.desc()).all() # 获取事件对应的交易日 event_trading_date = None if hist_event.event_date: event_trading_date = get_trading_day_near_date(hist_event.event_date) stocks_data = [] for stock in stocks: stock_data = { 'id': stock.id, 'stock_code': stock.stock_code, 'stock_name': stock.stock_name, 'sector': stock.sector, 'relation_desc': stock.relation_desc, 'correlation': stock.correlation, 'created_at': stock.created_at.isoformat() if stock.created_at else None } # 添加涨幅数据 if event_trading_date: try: # 查询股票在事件对应交易日的数据 # ea_trade 表字段:F007N=最近成交价(收盘价), F010N=涨跌幅 base_stock_code = stock.stock_code.split('.')[0] if stock.stock_code else '' # 日期格式转换为 YYYYMMDD 整数(ea_trade.TRADEDATE 是 int 类型) if hasattr(event_trading_date, 'strftime'): trade_date_int = int(event_trading_date.strftime('%Y%m%d')) else: trade_date_int = int(str(event_trading_date).replace('-', '')) with engine.connect() as conn: query = text(""" SELECT F007N as close_price, F010N as change_pct FROM ea_trade WHERE SECCODE = :stock_code AND TRADEDATE = :trading_date LIMIT 1 """) result = conn.execute(query, { 'stock_code': base_stock_code, 'trading_date': trade_date_int }).fetchone() if result: stock_data['event_day_close'] = float(result[0]) if result[0] else None stock_data['event_day_change_pct'] = float(result[1]) if result[1] else None print(f"[DEBUG] 股票{base_stock_code}在{trade_date_int}: close={result[0]}, change_pct={result[1]}") else: 
stock_data['event_day_close'] = None stock_data['event_day_change_pct'] = None except Exception as e: print(f"查询股票{stock.stock_code}在{event_trading_date}的数据失败: {e}") stock_data['event_day_close'] = None stock_data['event_day_change_pct'] = None else: stock_data['event_day_close'] = None stock_data['event_day_change_pct'] = None stocks_data.append(stock_data) return jsonify({ 'success': True, 'data': stocks_data, 'event_trading_date': event_trading_date.isoformat() if event_trading_date else None }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//expectation-score', methods=['GET']) def get_expectation_score(event_id): """获取超预期得分""" try: event = Event.query.get_or_404(event_id) # 如果事件有超预期得分,直接返回 if event.expectation_surprise_score is not None: score = event.expectation_surprise_score else: # 如果没有,根据历史事件计算一个模拟得分 historical_events = event.historical_events.all() if historical_events: # 基于历史事件数量和重要性计算得分 total_importance = sum(ev.importance or 0 for ev in historical_events) avg_importance = total_importance / len(historical_events) if historical_events else 0 score = min(100, max(0, int(avg_importance * 20 + len(historical_events) * 5))) else: # 默认得分 score = 65 return jsonify({ 'success': True, 'data': { 'score': score, 'description': '基于历史事件判断当前事件的超预期情况,满分100分' } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//follow', methods=['POST']) def toggle_event_follow(event_id): """切换事件关注状态(需登录)""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 try: event = Event.query.get_or_404(event_id) user_id = session['user_id'] existing = EventFollow.query.filter_by(user_id=user_id, event_id=event_id).first() if existing: # 取消关注 db.session.delete(existing) event.follower_count = max(0, (event.follower_count or 0) - 1) db.session.commit() return jsonify({'success': True, 'data': {'is_following': False, 'follower_count': 
event.follower_count}}) else: # 关注 follow = EventFollow(user_id=user_id, event_id=event_id) db.session.add(follow) event.follower_count = (event.follower_count or 0) + 1 db.session.commit() return jsonify({'success': True, 'data': {'is_following': True, 'follower_count': event.follower_count}}) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//transmission', methods=['GET']) def get_transmission_chain(event_id): try: # 订阅控制:传导链分析需要 Max 及以上 if not _has_required_level('max'): return jsonify({'success': False, 'error': '需要Max订阅', 'required_level': 'max'}), 403 # 确保数据库连接是活跃的 db.session.execute(text('SELECT 1')) event = Event.query.get_or_404(event_id) nodes = EventTransmissionNode.query.filter_by(event_id=event_id).all() edges = EventTransmissionEdge.query.filter_by(event_id=event_id).all() # 过滤孤立节点 connected_node_ids = set() for edge in edges: connected_node_ids.add(edge.from_node_id) connected_node_ids.add(edge.to_node_id) # 只保留有连接的节点 connected_nodes = [node for node in nodes if node.id in connected_node_ids] # 如果没有主事件节点,也保留主事件节点 main_event_node = next((node for node in nodes if node.is_main_event), None) if main_event_node and main_event_node not in connected_nodes: connected_nodes.append(main_event_node) if not connected_nodes: return jsonify({'success': False, 'message': '暂无传导链分析数据'}) # 节点类型到中文类别的映射 categories = { 'event': "事件", 'industry': "行业", 'company': "公司", 'policy': "政策", 'technology': "技术", 'market': "市场", 'other': "其他" } nodes_data = [] for node in connected_nodes: node_category = categories.get(node.node_type, "其他") nodes_data.append({ 'id': str(node.id), # 转换为字符串以保持一致性 'name': node.node_name, 'category': node_category, 'value': node.importance_score or 20, 'extra': { 'node_type': node.node_type, 'description': node.node_description, 'importance_score': node.importance_score, 'stock_code': node.stock_code, 'is_main_event': node.is_main_event } }) edges_data = [] for edge in 
edges: # 确保边的两端节点都在连接节点列表中 if edge.from_node_id in connected_node_ids and edge.to_node_id in connected_node_ids: edges_data.append({ 'source': str(edge.from_node_id), # 转换为字符串以保持一致性 'target': str(edge.to_node_id), # 转换为字符串以保持一致性 'value': edge.strength or 50, 'extra': { 'transmission_type': edge.transmission_type, 'transmission_mechanism': edge.transmission_mechanism, 'direction': edge.direction, 'strength': edge.strength, 'impact': edge.impact, 'is_circular': edge.is_circular, } }) return jsonify({ 'success': True, 'data': { 'nodes': nodes_data, 'edges': edges_data } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 # 修复股票报价API - 支持GET和POST方法 @app.route('/api/stock/quotes', methods=['GET', 'POST']) def get_stock_quotes(): """ 获取股票行情数据(使用全局交易日数据,与 batch-kline 保持一致) - 股票名称:从 MySQL ea_stocklist 查询 - 交易日数据:使用全局 trading_days(从 tdays.csv 加载) - 前一交易日收盘价:从 MySQL ea_trade 查询 - 实时价格:从 ClickHouse stock_minute 查询 """ try: if request.method == 'GET': codes_str = request.args.get('codes', '') codes = [code.strip() for code in codes_str.split(',') if code.strip()] event_time_str = request.args.get('event_time') else: codes = request.json.get('codes', []) event_time_str = request.json.get('event_time') if not codes: return jsonify({'success': False, 'error': '请提供股票代码'}), 400 # 标准化股票代码 def normalize_stock_code(code): if '.' 
in code: return code if code.startswith(('6',)): return f"{code}.SH" elif code.startswith(('8', '9', '4')): return f"{code}.BJ" else: return f"{code}.SZ" original_codes = codes normalized_codes = [normalize_stock_code(code) for code in codes] code_mapping = dict(zip(original_codes, normalized_codes)) # 处理事件时间 if event_time_str: try: event_time = datetime.fromisoformat(event_time_str.replace('Z', '+00:00')) except: event_time = datetime.now() else: event_time = datetime.now() current_time = datetime.now() # ==================== 查询股票名称(使用 Redis 缓存) ==================== base_codes = list(set([code.split('.')[0] for code in codes])) stock_names = get_cached_stock_names(base_codes) # 构建完整的名称映射 full_stock_names = {} for orig_code, norm_code in code_mapping.items(): base_code = orig_code.split('.')[0] name = stock_names.get(base_code, f"股票{base_code}") full_stock_names[orig_code] = name full_stock_names[norm_code] = name # ==================== 使用全局交易日数据(处理跨周末场景) ==================== # 使用新的辅助函数处理跨周末场景: # - 周五15:00后到周一15:00前,分时图显示周一行情,涨跌幅基于周五收盘价 target_date, prev_trading_day = get_target_and_prev_trading_day(event_time) if not target_date: return jsonify({ 'success': True, 'data': {code: {'name': full_stock_names.get(code, f'股票{code}'), 'price': None, 'change': None} for code in original_codes} }) start_datetime = datetime.combine(target_date, dt_time(9, 30)) end_datetime = datetime.combine(target_date, dt_time(15, 0)) results = {} print(f"批量处理 {len(codes)} 只股票: {codes[:5]}{'...' 
if len(codes) > 5 else ''}, 目标交易日: {target_date}, 涨跌幅基准日: {prev_trading_day}, 时间范围: {start_datetime} - {end_datetime}") # 初始化 ClickHouse 客户端 client = get_clickhouse_client() # ==================== 查询前一交易日收盘价(使用 Redis 缓存) ==================== try: prev_close_map = {} if prev_trading_day: # ea_trade 表的 TRADEDATE 格式是 YYYYMMDD(无连字符) prev_day_str = prev_trading_day.strftime('%Y%m%d') if hasattr(prev_trading_day, 'strftime') else str(prev_trading_day).replace('-', '') base_codes = list(set([code.split('.')[0] for code in codes])) # 使用 Redis 缓存获取前收盘价 base_close_map = get_cached_prev_close(base_codes, prev_day_str) print(f"前一交易日({prev_day_str})收盘价: 获取到 {len(base_close_map)} 条(Redis缓存)") # 为每个标准化代码分配收盘价 for norm_code in normalized_codes: base_code = norm_code.split('.')[0] if base_code in base_close_map: prev_close_map[norm_code] = base_close_map[base_code] # 批量查询当前价格数据(从 ClickHouse) # 使用 argMax 函数获取最新价格,比窗口函数效率高很多 batch_price_query = """ SELECT code, argMax(close, timestamp) as last_price FROM stock_minute WHERE code IN %(codes)s AND timestamp >= %(start)s AND timestamp <= %(end)s GROUP BY code """ batch_data = client.execute(batch_price_query, { 'codes': normalized_codes, 'start': start_datetime, 'end': end_datetime }) print(f"批量查询返回 {len(batch_data)} 条价格数据") # 解析批量查询结果 price_data_map = {} for row in batch_data: code = row[0] last_price = float(row[1]) if row[1] is not None else None prev_close = prev_close_map.get(code) # 计算涨跌幅 change_pct = None if last_price is not None and prev_close is not None and prev_close > 0: change_pct = (last_price - prev_close) / prev_close * 100 price_data_map[code] = { 'price': last_price, 'change': change_pct } # 组装结果 for orig_code in original_codes: norm_code = code_mapping[orig_code] price_info = price_data_map.get(norm_code) if price_info: results[orig_code] = { 'price': price_info['price'], 'change': price_info['change'], 'name': full_stock_names.get(orig_code, f'股票{orig_code.split(".")[0]}') } else: results[orig_code] = { 'price': 
None, 'change': None, 'name': full_stock_names.get(orig_code, f'股票{orig_code.split(".")[0]}') } except Exception as e: print(f"批量查询失败: {e},回退到逐只查询") # 降级方案:逐只股票查询 for orig_code in original_codes: norm_code = code_mapping[orig_code] try: # 查询当前价格 current_data = client.execute(""" SELECT close FROM stock_minute WHERE code = %(code)s AND timestamp >= %(start)s AND timestamp <= %(end)s ORDER BY timestamp DESC LIMIT 1 """, {'code': norm_code, 'start': start_datetime, 'end': end_datetime}) last_price = float(current_data[0][0]) if current_data and current_data[0] and current_data[0][0] else None # 查询前一交易日收盘价 prev_close = None if prev_trading_day and last_price is not None: base_code = orig_code.split('.')[0] # ea_trade 表的 TRADEDATE 格式是 YYYYMMDD(无连字符) prev_day_str = prev_trading_day.strftime('%Y%m%d') if hasattr(prev_trading_day, 'strftime') else str(prev_trading_day).replace('-', '') with engine.connect() as conn: prev_result = conn.execute(text(""" SELECT F007N as close_price FROM ea_trade WHERE SECCODE = :code AND TRADEDATE = :trade_date """), {'code': base_code, 'trade_date': prev_day_str}).fetchone() prev_close = float(prev_result[0]) if prev_result and prev_result[0] else None # 计算涨跌幅 change_pct = None if last_price is not None and prev_close is not None and prev_close > 0: change_pct = (last_price - prev_close) / prev_close * 100 results[orig_code] = { 'price': last_price, 'change': change_pct, 'name': full_stock_names.get(orig_code, f'股票{orig_code.split(".")[0]}') } except Exception as inner_e: print(f"Error processing stock {orig_code}: {inner_e}") results[orig_code] = {'price': None, 'change': None, 'name': full_stock_names.get(orig_code, f'股票{orig_code.split(".")[0]}')} # 返回标准格式 return jsonify({'success': True, 'data': results}) except Exception as e: print(f"Stock quotes API error: {e}") return jsonify({'success': False, 'error': str(e)}), 500 # ==================== ClickHouse 连接池(单例模式) ==================== _clickhouse_client = None _clickhouse_client_lock = 
threading.Lock() def get_clickhouse_client(): """获取 ClickHouse 客户端(单例模式,避免重复创建连接)""" global _clickhouse_client if _clickhouse_client is None: with _clickhouse_client_lock: if _clickhouse_client is None: _clickhouse_client = Cclient( host='127.0.0.1', port=9000, user='default', password='Zzl33818!', database='stock' ) print("[ClickHouse] 创建新连接(单例)") return _clickhouse_client @app.route('/api/account/calendar/events', methods=['GET', 'POST']) def account_calendar_events(): """返回当前用户的投资计划与关注的未来事件(合并)。 GET: 可按日期范围/月份过滤;POST: 新增投资计划(写入 InvestmentPlan)。 """ try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 if request.method == 'POST': data = request.get_json() or {} title = data.get('title') event_date_str = data.get('event_date') or data.get('date') plan_type = data.get('type') or 'plan' description = data.get('description') or data.get('content') or '' stocks = data.get('stocks') or [] if not title or not event_date_str: return jsonify({'success': False, 'error': '缺少必填字段'}), 400 try: event_date = datetime.fromisoformat(event_date_str).date() except Exception: return jsonify({'success': False, 'error': '日期格式错误'}), 400 plan = InvestmentPlan( user_id=session['user_id'], date=event_date, title=title, content=description, type=plan_type, stocks=json.dumps(stocks), tags=json.dumps(data.get('tags', [])), status=data.get('status', 'active') ) db.session.add(plan) db.session.commit() return jsonify({'success': True, 'data': { 'id': plan.id, 'title': plan.title, 'event_date': plan.date.isoformat(), 'type': plan.type, 'description': plan.content, 'stocks': json.loads(plan.stocks) if plan.stocks else [], 'source': 'plan' }}) # GET # 解析过滤参数:date 或 (year, month) 或 (start_date, end_date) date_str = request.args.get('date') year = request.args.get('year', type=int) month = request.args.get('month', type=int) start_date_str = request.args.get('start_date') end_date_str = request.args.get('end_date') start_date = None end_date = None if date_str: 
try: d = datetime.fromisoformat(date_str).date() start_date = d end_date = d except Exception: pass elif year and month: # 月份范围 start_date = datetime(year, month, 1).date() if month == 12: end_date = datetime(year + 1, 1, 1).date() - timedelta(days=1) else: end_date = datetime(year, month + 1, 1).date() - timedelta(days=1) elif start_date_str and end_date_str: try: start_date = datetime.fromisoformat(start_date_str).date() end_date = datetime.fromisoformat(end_date_str).date() except Exception: start_date = None end_date = None # 查询投资计划 plans_query = InvestmentPlan.query.filter_by(user_id=session['user_id']) if start_date and end_date: plans_query = plans_query.filter(InvestmentPlan.date >= start_date, InvestmentPlan.date <= end_date) elif start_date: plans_query = plans_query.filter(InvestmentPlan.date == start_date) plans = plans_query.order_by(InvestmentPlan.date.asc()).all() plan_events = [{ 'id': p.id, 'title': p.title, 'event_date': p.date.isoformat(), 'type': p.type or 'plan', 'description': p.content, 'importance': 3, 'stocks': json.loads(p.stocks) if p.stocks else [], 'source': 'plan' } for p in plans] # 查询关注的未来事件 follows = FutureEventFollow.query.filter_by(user_id=session['user_id']).all() future_event_ids = [f.future_event_id for f in follows] future_events = [] if future_event_ids: # 使用 SELECT * 以便获取所有字段(包括新字段) base_sql = """ SELECT * FROM future_events WHERE data_id IN :event_ids \ """ params = {'event_ids': tuple(future_event_ids)} # 日期过滤(按 calendar_time 的日期) if start_date and end_date: base_sql += " AND DATE(calendar_time) BETWEEN :start_date AND :end_date" params.update({'start_date': start_date, 'end_date': end_date}) elif start_date: base_sql += " AND DATE(calendar_time) = :start_date" params.update({'start_date': start_date}) base_sql += " ORDER BY calendar_time" result = db.session.execute(text(base_sql), params) for row in result: # 使用新字段回退逻辑获取 former former_value = get_future_event_field(row, 'second_modified_text', 'former') # 获取 
related_stocks,优先使用 best_matches best_matches = getattr(row, 'best_matches', None) if hasattr(row, 'best_matches') else None if best_matches and str(best_matches).strip(): rs = parse_best_matches(best_matches) else: rs = parse_json_field(getattr(row, 'related_stocks', None)) # 生成股票标签列表 stock_tags = [] try: for it in rs: if isinstance(it, dict): # 新结构 stock_tags.append(f"{it.get('code', '')} {it.get('name', '')}") elif isinstance(it, (list, tuple)) and len(it) >= 2: stock_tags.append(f"{it[0]} {it[1]}") elif isinstance(it, str): stock_tags.append(it) except Exception: pass future_events.append({ 'id': row.data_id, 'title': row.title, 'event_date': (row.calendar_time.date().isoformat() if row.calendar_time else None), 'type': 'future_event', 'importance': int(row.star) if getattr(row, 'star', None) is not None else 3, 'description': former_value or '', 'stocks': stock_tags, 'is_following': True, 'source': 'future' }) return jsonify({'success': True, 'data': plan_events + future_events}) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/account/calendar/events/', methods=['DELETE']) def delete_account_calendar_event(event_id): """删除用户创建的投资计划事件(不影响关注的未来事件)。""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 plan = InvestmentPlan.query.filter_by(id=event_id, user_id=session['user_id']).first() if not plan: return jsonify({'success': False, 'error': '未找到该记录'}), 404 db.session.delete(plan) db.session.commit() return jsonify({'success': True}) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 # ==================== 灵活屏实时行情 API ==================== # 从 ClickHouse 实时行情表获取最新数据(用于盘后/WebSocket 无数据时的回退) @app.route('/api/flex-screen/quotes', methods=['POST']) def get_flex_screen_quotes(): """ 获取灵活屏行情数据 优先从实时行情表查询,如果没有则从分钟线表查询 请求体: { "codes": ["000001.SZ", "399001.SZ", "600519.SH"], "include_order_book": false // 是否包含五档盘口 } 返回: { "success": 
true, "data": { "000001.SZ": { "security_id": "000001", "name": "平安银行", "last_px": 10.50, "prev_close_px": 10.20, "open_px": 10.30, "high_px": 10.55, "low_px": 10.15, "total_volume_trade": 1000000, "total_value_trade": 10500000, "change": 0.30, "change_pct": 2.94, "bid_prices": [10.49, 10.48, ...], "bid_volumes": [1000, 2000, ...], "ask_prices": [10.50, 10.51, ...], "ask_volumes": [800, 1200, ...], "update_time": "2024-12-11 15:00:00" }, ... }, "source": "realtime" | "minute" // 数据来源 } """ try: data = request.json or {} codes = data.get('codes', []) include_order_book = data.get('include_order_book', False) if not codes: return jsonify({'success': False, 'error': '请提供股票代码'}), 400 client = get_clickhouse_client() results = {} source = 'realtime' # 分离上交所和深交所代码 sse_codes = [] # 上交所 szse_stock_codes = [] # 深交所股票 szse_index_codes = [] # 深交所指数 for code in codes: base_code = code.split('.')[0] if code.endswith('.SH'): sse_codes.append(base_code) elif code.endswith('.SZ'): # 399 开头是指数 if base_code.startswith('399'): szse_index_codes.append(base_code) else: szse_stock_codes.append(base_code) # 获取股票名称 stock_names = {} with engine.connect() as conn: base_codes = list(set([code.split('.')[0] for code in codes])) if base_codes: placeholders = ','.join([f':code{i}' for i in range(len(base_codes))]) params = {f'code{i}': code for i, code in enumerate(base_codes)} result = conn.execute(text( f"SELECT SECCODE, SECNAME FROM ea_stocklist WHERE SECCODE IN ({placeholders})" ), params).fetchall() stock_names = {row[0]: row[1] for row in result} # 查询深交所股票实时行情 if szse_stock_codes: try: order_book_cols = "" if include_order_book: order_book_cols = """, bid_price1, bid_volume1, bid_price2, bid_volume2, bid_price3, bid_volume3, bid_price4, bid_volume4, bid_price5, bid_volume5, ask_price1, ask_volume1, ask_price2, ask_volume2, ask_price3, ask_volume3, ask_price4, ask_volume4, ask_price5, ask_volume5""" szse_stock_query = f""" SELECT security_id, last_price, prev_close, open_price, high_price, 
low_price, volume, amount, num_trades, upper_limit_price, lower_limit_price, trading_phase_code, trade_time {order_book_cols} FROM stock.szse_stock_realtime WHERE trade_date = today() AND security_id IN %(codes)s ORDER BY security_id, trade_time DESC LIMIT 1 BY security_id """ szse_stock_data = client.execute(szse_stock_query, {'codes': szse_stock_codes}) for row in szse_stock_data: security_id = row[0] full_code = f"{security_id}.SZ" last_px = float(row[1]) if row[1] else 0 prev_close = float(row[2]) if row[2] else 0 change = last_px - prev_close if last_px and prev_close else 0 change_pct = (change / prev_close * 100) if prev_close else 0 quote = { 'security_id': security_id, 'name': stock_names.get(security_id, ''), 'last_px': last_px, 'prev_close_px': prev_close, 'open_px': float(row[3]) if row[3] else 0, 'high_px': float(row[4]) if row[4] else 0, 'low_px': float(row[5]) if row[5] else 0, 'total_volume_trade': float(row[6]) if row[6] else 0, 'total_value_trade': float(row[7]) if row[7] else 0, 'num_trades': int(row[8]) if row[8] else 0, 'upper_limit_px': float(row[9]) if row[9] else None, 'lower_limit_px': float(row[10]) if row[10] else None, 'trading_phase_code': row[11], 'change': change, 'change_pct': change_pct, 'update_time': str(row[12]) if row[12] else None, } if include_order_book and len(row) > 13: quote['bid_prices'] = [float(row[i]) if row[i] else 0 for i in range(13, 23, 2)] quote['bid_volumes'] = [float(row[i]) if row[i] else 0 for i in range(14, 24, 2)] quote['ask_prices'] = [float(row[i]) if row[i] else 0 for i in range(23, 33, 2)] quote['ask_volumes'] = [float(row[i]) if row[i] else 0 for i in range(24, 34, 2)] results[full_code] = quote except Exception as e: print(f"查询深交所实时行情失败: {e}") # 查询深交所指数实时行情 if szse_index_codes: try: szse_index_query = """ SELECT security_id, current_index, prev_close, open_index, high_index, low_index, close_index, volume, amount, num_trades, trade_time FROM stock.szse_index_realtime WHERE trade_date = today() AND 
security_id IN %(codes)s ORDER BY security_id, trade_time DESC LIMIT 1 BY security_id """ szse_index_data = client.execute(szse_index_query, {'codes': szse_index_codes}) for row in szse_index_data: security_id = row[0] full_code = f"{security_id}.SZ" current_index = float(row[1]) if row[1] else 0 prev_close = float(row[2]) if row[2] else 0 change = current_index - prev_close if current_index and prev_close else 0 change_pct = (change / prev_close * 100) if prev_close else 0 results[full_code] = { 'security_id': security_id, 'name': stock_names.get(security_id, ''), 'last_px': current_index, 'prev_close_px': prev_close, 'open_px': float(row[3]) if row[3] else 0, 'high_px': float(row[4]) if row[4] else 0, 'low_px': float(row[5]) if row[5] else 0, 'close_px': float(row[6]) if row[6] else None, 'total_volume_trade': float(row[7]) if row[7] else 0, 'total_value_trade': float(row[8]) if row[8] else 0, 'num_trades': int(row[9]) if row[9] else 0, 'change': change, 'change_pct': change_pct, 'update_time': str(row[10]) if row[10] else None, 'bid_prices': [], 'bid_volumes': [], 'ask_prices': [], 'ask_volumes': [], } except Exception as e: print(f"查询深交所指数实时行情失败: {e}") # 查询上交所实时行情(如果有 sse_stock_realtime 表) if sse_codes: try: sse_query = """ SELECT security_id, last_price, prev_close, open_price, high_price, low_price, volume, amount, trade_time FROM stock.sse_stock_realtime WHERE trade_date = today() AND security_id IN %(codes)s ORDER BY security_id, trade_time DESC LIMIT 1 BY security_id """ sse_data = client.execute(sse_query, {'codes': sse_codes}) for row in sse_data: security_id = row[0] full_code = f"{security_id}.SH" last_px = float(row[1]) if row[1] else 0 prev_close = float(row[2]) if row[2] else 0 change = last_px - prev_close if last_px and prev_close else 0 change_pct = (change / prev_close * 100) if prev_close else 0 results[full_code] = { 'security_id': security_id, 'name': stock_names.get(security_id, ''), 'last_px': last_px, 'prev_close_px': prev_close, 
'open_px': float(row[3]) if row[3] else 0, 'high_px': float(row[4]) if row[4] else 0, 'low_px': float(row[5]) if row[5] else 0, 'total_volume_trade': float(row[6]) if row[6] else 0, 'total_value_trade': float(row[7]) if row[7] else 0, 'change': change, 'change_pct': change_pct, 'update_time': str(row[8]) if row[8] else None, 'bid_prices': [], 'bid_volumes': [], 'ask_prices': [], 'ask_volumes': [], } except Exception as e: print(f"查询上交所实时行情失败: {e},尝试从分钟线表查询") # 对于实时表中没有数据的股票,从分钟线表查询 missing_codes = [code for code in codes if code not in results] if missing_codes: source = 'minute' if not results else 'mixed' try: # 从分钟线表查询最新数据 minute_query = """ SELECT code, close, open, high, low, volume, amt, timestamp FROM stock.stock_minute WHERE toDate(timestamp) = today() AND code IN %(codes)s ORDER BY code, timestamp DESC LIMIT 1 BY code """ minute_data = client.execute(minute_query, {'codes': missing_codes}) # 获取昨收价 prev_close_map = {} with engine.connect() as conn: base_codes = list(set([code.split('.')[0] for code in missing_codes])) if base_codes: # 获取上一交易日 prev_day_result = conn.execute(text(""" SELECT EXCHANGE_DATE FROM trading_days WHERE EXCHANGE_DATE < CURDATE() ORDER BY EXCHANGE_DATE DESC LIMIT 1 """)).fetchone() if prev_day_result: prev_day = prev_day_result[0] placeholders = ','.join([f':code{i}' for i in range(len(base_codes))]) params = {f'code{i}': code for i, code in enumerate(base_codes)} params['trade_date'] = prev_day prev_result = conn.execute(text(f""" SELECT SECCODE, F007N as close_price FROM ea_trade WHERE SECCODE IN ({placeholders}) AND TRADEDATE = :trade_date """), params).fetchall() prev_close_map = {row[0]: float(row[1]) if row[1] else 0 for row in prev_result} for row in minute_data: code = row[0] base_code = code.split('.')[0] last_px = float(row[1]) if row[1] else 0 prev_close = prev_close_map.get(base_code, 0) change = last_px - prev_close if last_px and prev_close else 0 change_pct = (change / prev_close * 100) if prev_close else 0 results[code] 
= { 'security_id': base_code, 'name': stock_names.get(base_code, ''), 'last_px': last_px, 'prev_close_px': prev_close, 'open_px': float(row[2]) if row[2] else 0, 'high_px': float(row[3]) if row[3] else 0, 'low_px': float(row[4]) if row[4] else 0, 'total_volume_trade': float(row[5]) if row[5] else 0, 'total_value_trade': float(row[6]) if row[6] else 0, 'change': change, 'change_pct': change_pct, 'update_time': str(row[7]) if row[7] else None, 'bid_prices': [], 'bid_volumes': [], 'ask_prices': [], 'ask_volumes': [], } except Exception as e: print(f"查询分钟线数据失败: {e}") return jsonify({ 'success': True, 'data': results, 'source': source }) except Exception as e: print(f"灵活屏行情查询失败: {e}") import traceback traceback.print_exc() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//kline') def get_stock_kline(stock_code): chart_type = request.args.get('type', 'minute') event_time = request.args.get('event_time') # 是否跳过"下一个交易日"逻辑: # - 如果没有传 event_time(灵活屏等实时行情场景),盘后应显示当天数据 # - 如果传了 event_time(Community 事件等场景),使用原逻辑 skip_next_day = event_time is None try: event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now() except ValueError: return jsonify({'error': 'Invalid event_time format'}), 400 # 确保股票代码包含后缀(ClickHouse 中数据带后缀) if '.' 
not in stock_code: if stock_code.startswith('6'): stock_code = f"{stock_code}.SH" # 上海 elif stock_code.startswith(('8', '9', '4')): stock_code = f"{stock_code}.BJ" # 北交所 else: stock_code = f"{stock_code}.SZ" # 深圳 # 获取股票名称 with engine.connect() as conn: result = conn.execute(text( "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" ), {"code": stock_code.split('.')[0]}).fetchone() stock_name = result[0] if result else 'Unknown' if chart_type == 'daily': return get_daily_kline(stock_code, event_datetime, stock_name) elif chart_type == 'minute': return get_minute_kline(stock_code, event_datetime, stock_name, skip_next_day=skip_next_day) elif chart_type == 'timeline': return get_timeline_data(stock_code, event_datetime, stock_name) else: # 对于未知的类型,返回错误 return jsonify({'error': f'Unsupported chart type: {chart_type}'}), 400 @app.route('/api/stock/batch-kline', methods=['POST']) def get_batch_kline_data(): """批量获取多只股票的K线/分时数据 请求体:{ codes: string[], type: 'timeline'|'daily', event_time?: string, days_before?: number, # 查询事件日期前多少天的数据,默认60,最大365 end_date?: string # 分页加载时指定结束日期(用于加载更早的数据) } 返回:{ success: true, data: { [code]: { data: [], trade_date: '', ... } }, has_more: boolean } """ try: data = request.json codes = data.get('codes', []) chart_type = data.get('type', 'timeline') event_time = data.get('event_time') days_before = min(int(data.get('days_before', 60)), 365) # 默认60天,最多365天 custom_end_date = data.get('end_date') # 用于分页加载更早数据 if not codes: return jsonify({'success': False, 'error': '请提供股票代码列表'}), 400 if len(codes) > 50: return jsonify({'success': False, 'error': '单次最多查询50只股票'}), 400 # 标准化股票代码(确保带后缀,用于 ClickHouse 查询) def normalize_stock_code(code): """将股票代码标准化为带后缀格式(如 300274.SZ)""" if '.' 
in code: return code # 已经带后缀 # 根据代码规则添加后缀 if code.startswith('6'): return f"{code}.SH" # 上海 elif code.startswith(('8', '9', '4')): return f"{code}.BJ" # 北交所 else: return f"{code}.SZ" # 深圳 # 保留原始代码用于返回结果,同时创建标准化代码用于 ClickHouse 查询 original_codes = codes normalized_codes = [normalize_stock_code(code) for code in codes] code_mapping = dict(zip(original_codes, normalized_codes)) reverse_mapping = dict(zip(normalized_codes, original_codes)) try: event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now() except ValueError: return jsonify({'success': False, 'error': 'Invalid event_time format'}), 400 client = get_clickhouse_client() # 批量获取股票名称(使用 Redis 缓存) base_codes = list(set([code.split('.')[0] for code in codes])) stock_names = get_cached_stock_names(base_codes) # 确定目标交易日和涨跌幅基准日(处理跨周末场景) # - 周五15:00后到周一15:00前,分时图显示周一行情,涨跌幅基于周五收盘价 target_date, prev_trading_day = get_target_and_prev_trading_day(event_datetime) if not target_date: # 返回空数据(使用原始代码作为 key) return jsonify({ 'success': True, 'data': {code: {'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': chart_type} for code in original_codes} }) start_time = datetime.combine(target_date, dt_time(9, 30)) end_time = datetime.combine(target_date, dt_time(15, 0)) results = {} if chart_type == 'timeline': # 批量获取前收盘价(使用 Redis 缓存) # 使用 prev_trading_day 作为基准日期(处理跨周末场景) prev_close_map = {} if prev_trading_day: prev_date_str = prev_trading_day.strftime('%Y%m%d') base_codes = list(set([code.split('.')[0] for code in codes])) prev_close_map = get_cached_prev_close(base_codes, prev_date_str) print(f"分时图基准日期: {prev_trading_day}, 获取到 {len(prev_close_map)} 条前收盘价(Redis缓存)") # 批量查询分时数据(使用标准化代码查询 ClickHouse) batch_data = client.execute(""" SELECT code, timestamp, close, volume FROM stock_minute WHERE code IN %(codes)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY code, timestamp """, { 'codes': normalized_codes, # 使用标准化代码 'start': start_time, 'end': end_time }) # 按股票代码分组,同时计算均价和涨跌幅 
stock_data = {} stock_accum = {} # 用于计算均价的累计值 for row in batch_data: norm_code = row[0] base_code = norm_code.split('.')[0] price = float(row[2]) volume = float(row[3]) if norm_code not in stock_data: stock_data[norm_code] = [] stock_accum[norm_code] = {'total_amount': 0, 'total_volume': 0} # 累计计算均价 stock_accum[norm_code]['total_amount'] += price * volume stock_accum[norm_code]['total_volume'] += volume total_vol = stock_accum[norm_code]['total_volume'] avg_price = stock_accum[norm_code]['total_amount'] / total_vol if total_vol > 0 else price # 计算涨跌幅 prev_close = prev_close_map.get(base_code) change_percent = ((price - prev_close) / prev_close * 100) if prev_close and prev_close > 0 else 0 stock_data[norm_code].append({ 'time': row[1].strftime('%H:%M'), 'price': price, 'avg_price': round(avg_price, 2), 'volume': volume, 'change_percent': round(change_percent, 2) }) # 组装结果(使用原始代码作为 key 返回) for orig_code in original_codes: norm_code = code_mapping[orig_code] base_code = orig_code.split('.')[0] stock_name = stock_names.get(base_code, f'股票{base_code}') data_list = stock_data.get(norm_code, []) prev_close = prev_close_map.get(base_code) results[orig_code] = { 'code': orig_code, 'name': stock_name, 'data': data_list, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'timeline', 'prev_close': prev_close } elif chart_type == 'daily': # 批量查询日线数据(从MySQL ea_trade表) with engine.connect() as conn: base_codes = list(set([code.split('.')[0] for code in codes])) if base_codes: placeholders = ','.join([f':code{i}' for i in range(len(base_codes))]) params = {f'code{i}': code for i, code in enumerate(base_codes)} # 确定查询的日期范围 # 如果指定了 custom_end_date,用于分页加载更早的数据 if custom_end_date: try: end_date_obj = datetime.strptime(custom_end_date, '%Y-%m-%d').date() except ValueError: end_date_obj = target_date else: end_date_obj = target_date # TRADEDATE 是整数格式 YYYYMMDD,需要转换日期格式 start_date = end_date_obj - timedelta(days=days_before) params['start_date'] = int(start_date.strftime('%Y%m%d')) 
params['end_date'] = int(end_date_obj.strftime('%Y%m%d')) daily_result = conn.execute(text(f""" SELECT SECCODE, TRADEDATE, F003N as open, F005N as high, F006N as low, F007N as close, F004N as volume FROM ea_trade WHERE SECCODE IN ({placeholders}) AND TRADEDATE BETWEEN :start_date AND :end_date ORDER BY SECCODE, TRADEDATE """), params).fetchall() # 按股票代码分组 stock_data = {} for row in daily_result: code_base = row[0] if code_base not in stock_data: stock_data[code_base] = [] # 日期格式处理:TRADEDATE 可能是 datetime 或 int(YYYYMMDD) trade_date_val = row[1] if hasattr(trade_date_val, 'strftime'): date_str = trade_date_val.strftime('%Y-%m-%d') elif isinstance(trade_date_val, int): # 整数格式 YYYYMMDD -> YYYY-MM-DD date_str = f"{str(trade_date_val)[:4]}-{str(trade_date_val)[4:6]}-{str(trade_date_val)[6:8]}" else: date_str = str(trade_date_val) stock_data[code_base].append({ 'time': date_str, # 统一使用 time 字段,与前端期望一致 'open': float(row[2]) if row[2] else 0, 'high': float(row[3]) if row[3] else 0, 'low': float(row[4]) if row[4] else 0, 'close': float(row[5]) if row[5] else 0, 'volume': float(row[6]) if row[6] else 0 }) # 组装结果(使用原始代码作为 key 返回) # 同时计算最早日期,用于判断是否还有更多数据 earliest_dates = {} for orig_code in original_codes: base_code = orig_code.split('.')[0] stock_name = stock_names.get(base_code, f'股票{base_code}') data_list = stock_data.get(base_code, []) # 记录每只股票的最早日期 if data_list: earliest_dates[orig_code] = data_list[0]['time'] results[orig_code] = { 'code': orig_code, 'name': stock_name, 'data': data_list, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'daily', 'earliest_date': data_list[0]['time'] if data_list else None } # 计算是否还有更多历史数据(基于事件日期往前推365天) event_date = event_datetime.date() one_year_ago = event_date - timedelta(days=365) # 如果当前查询的起始日期还没到一年前,则还有更多数据 has_more = start_date > one_year_ago if chart_type == 'daily' else False print(f"批量K线查询完成: {len(codes)} 只股票, 类型: {chart_type}, 交易日: {target_date}, days_before: {days_before}, has_more: {has_more}") return jsonify({ 
'success': True, 'data': results, 'has_more': has_more, 'query_start_date': start_date.strftime('%Y-%m-%d') if chart_type == 'daily' else None, 'query_end_date': end_date_obj.strftime('%Y-%m-%d') if chart_type == 'daily' else None }) except Exception as e: print(f"批量K线查询错误: {e}") return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//latest-minute', methods=['GET']) def get_latest_minute_data(stock_code): """获取最新交易日的分钟频数据""" client = get_clickhouse_client() # 确保股票代码包含后缀 if '.' not in stock_code: if stock_code.startswith('6'): stock_code = f"{stock_code}.SH" # 上海 elif stock_code.startswith(('8', '9', '4')): stock_code = f"{stock_code}.BJ" # 北交所 else: stock_code = f"{stock_code}.SZ" # 深圳 # 获取股票名称 with engine.connect() as conn: result = conn.execute(text( "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" ), {"code": stock_code.split('.')[0]}).fetchone() stock_name = result[0] if result else 'Unknown' # 查找最近30天内有数据的最新交易日 target_date = None current_date = datetime.now().date() for i in range(30): check_date = current_date - timedelta(days=i) trading_day = get_trading_day_near_date(check_date) if trading_day and trading_day <= current_date: # 检查这个交易日是否有分钟数据 test_data = client.execute(""" SELECT COUNT(*) FROM stock_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s LIMIT 1 """, { 'code': stock_code, 'start': datetime.combine(trading_day, dt_time(9, 30)), 'end': datetime.combine(trading_day, dt_time(15, 0)) }) if test_data and test_data[0][0] > 0: target_date = trading_day break if not target_date: return jsonify({ 'error': 'No data available', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': current_date.strftime('%Y-%m-%d'), 'type': 'minute' }) # 获取目标日期的完整交易时段数据 data = client.execute(""" SELECT timestamp, open, high, low, close, volume, amt FROM stock_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': stock_code, 'start': 
datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)) }) kline_data = [{ 'time': row[0].strftime('%H:%M'), 'open': float(row[1]), 'high': float(row[2]), 'low': float(row[3]), 'close': float(row[4]), 'volume': float(row[5]), 'amount': float(row[6]) } for row in data] return jsonify({ 'code': stock_code, 'name': stock_name, 'data': kline_data, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'minute', 'is_latest': True }) @app.route('/api/stock//forecast-report', methods=['GET']) def get_stock_forecast_report(stock_code): """基于 stock_forecast_data 输出报表所需数据结构 返回: - income_profit_trend: 营业收入/归母净利润趋势 - growth_bars: 增长率柱状图数据(基于营业收入同比) - eps_trend: EPS 折线 - pe_peg_axes: PE/PEG 双轴 - detail_table: 详细数据表格(与附件结构一致) """ try: # 读取该股票所有指标 rows = StockForecastData.query.filter_by(stock_code=stock_code).all() if not rows: return jsonify({'success': False, 'error': 'no_data'}), 404 # 将指标映射为字典 indicators = {} for r in rows: years, vals = r.values_by_year() indicators[r.indicator_name] = dict(zip(years, vals)) def safe(x): return x if x is not None else None years = ['2022A', '2023A', '2024A', '2025E', '2026E', '2027E'] # 营业收入与净利润趋势 income = indicators.get('营业总收入(百万元)', {}) profit = indicators.get('归母净利润(百万元)', {}) income_profit_trend = { 'years': years, 'income': [safe(income.get(y)) for y in years], 'profit': [safe(profit.get(y)) for y in years] } # 增长率柱状(若表内已有"增长率(%)",直接使用;否则按营业收入同比计算) growth = indicators.get('增长率(%)') if growth is None: # 计算同比: (curr - prev)/prev*100 growth_vals = [] prev = None for y in years: curr = income.get(y) if prev is not None and prev not in (None, 0) and curr is not None: growth_vals.append(round((float(curr) - float(prev)) / float(prev) * 100, 2)) else: growth_vals.append(None) prev = curr else: growth_vals = [safe(growth.get(y)) for y in years] growth_bars = { 'years': years, 'revenue_growth_pct': growth_vals, 'net_profit_growth_pct': None # 如后续需要可扩展 } # EPS 趋势 eps = indicators.get('EPS(稀释)') or 
indicators.get('EPS(元/股)') or {} eps_trend = { 'years': years, 'eps': [safe(eps.get(y)) for y in years] } # PE / PEG 双轴 pe = indicators.get('PE') or {} peg = indicators.get('PEG') or {} pe_peg_axes = { 'years': years, 'pe': [safe(pe.get(y)) for y in years], 'peg': [safe(peg.get(y)) for y in years] } # 详细数据表格(列顺序固定) def fmt(val): try: return None if val is None else round(float(val), 2) except Exception: return None detail_rows = [ { '指标': '营业总收入(百万元)', **{y: fmt(income.get(y)) for y in years}, }, { '指标': '增长率(%)', **{y: fmt(v) for y, v in zip(years, growth_vals)}, }, { '指标': '归母净利润(百万元)', **{y: fmt(profit.get(y)) for y in years}, }, { '指标': 'EPS(稀释)', **{y: fmt(eps.get(y)) for y in years}, }, { '指标': 'PE', **{y: fmt(pe.get(y)) for y in years}, }, { '指标': 'PEG', **{y: fmt(peg.get(y)) for y in years}, }, ] return jsonify({ 'success': True, 'data': { 'income_profit_trend': income_profit_trend, 'growth_bars': growth_bars, 'eps_trend': eps_trend, 'pe_peg_axes': pe_peg_axes, 'detail_table': { 'years': years, 'rows': detail_rows } } }) except Exception as e: app.logger.error(f"forecast report error: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//basic-info', methods=['GET']) def get_stock_basic_info(stock_code): """获取股票基本信息(来自ea_baseinfo表)""" try: with engine.connect() as conn: query = text(""" SELECT SECCODE, SECNAME, ORGNAME, F001V as en_name, F002V as en_short_name, F003V as legal_representative, F004V as reg_address, F005V as office_address, F006V as post_code, F007N as reg_capital, F009V as currency, F010D as establish_date, F011V as website, F012V as email, F013V as tel, F014V as fax, F015V as main_business, F016V as business_scope, F017V as company_intro, F018V as secretary, F019V as secretary_tel, F020V as secretary_fax, F021V as secretary_email, F024V as listing_status, F026V as province, F028V as city, F030V as industry_l1, F032V as industry_l2, F034V as sw_industry_l1, F036V as sw_industry_l2, F038V as 
sw_industry_l3, F039V as accounting_firm, F040V as law_firm, F041V as chairman, F042V as general_manager, F043V as independent_directors, F050V as credit_code, F054V as company_size, UPDATE_DATE FROM ea_baseinfo WHERE SECCODE = :stock_code LIMIT 1 """) result = conn.execute(query, {'stock_code': stock_code}).fetchone() if not result: return jsonify({ 'success': False, 'error': f'未找到股票代码 {stock_code} 的基本信息' }), 404 # 转换为字典 basic_info = {} result_dict = row_to_dict(result) for key, value in result_dict.items(): if isinstance(value, datetime): basic_info[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): basic_info[key] = float(value) else: basic_info[key] = value return jsonify({ 'success': True, 'data': basic_info }) except Exception as e: app.logger.error(f"Error getting stock basic info: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//quote-detail', methods=['GET']) def get_stock_quote_detail(stock_code): """获取股票完整行情数据 - 供 StockQuoteCard 使用 返回数据包括: - 基础信息:名称、代码、行业分类 - 价格信息:现价、涨跌幅、开盘、收盘、最高、最低 - 关键指标:市盈率、市净率、流通市值、52周高低 - 主力动态:主力净流入、机构持仓(如有) """ try: # 标准化股票代码(去除后缀) base_code = stock_code.split('.')[0] if '.' in stock_code else stock_code result_data = { 'code': stock_code, 'name': '', 'industry': '', 'industry_l1': '', 'sw_industry_l1': '', 'sw_industry_l2': '', # 价格信息 'current_price': None, 'change_percent': None, 'today_open': None, 'yesterday_close': None, 'today_high': None, 'today_low': None, # 关键指标 'pe': None, 'pb': None, 'eps': None, 'market_cap': None, 'circ_mv': None, 'turnover_rate': None, 'week52_high': None, 'week52_low': None, # 主力动态(预留字段) 'main_net_inflow': None, 'institution_holding': None, 'buy_ratio': None, 'sell_ratio': None, 'update_time': None } with engine.connect() as conn: # 1. 
获取最新交易数据(来自 ea_trade) trade_query = text(""" SELECT t.SECCODE, t.SECNAME, t.TRADEDATE, t.F002N as pre_close, t.F003N as open_price, t.F004N as volume, t.F005N as high, t.F006N as low, t.F007N as close_price, t.F010N as change_pct, t.F011N as amount, t.F012N as turnover_rate, t.F020N as total_shares, t.F021N as float_shares, t.F026N as pe_ratio, b.F034V as sw_industry_l1, b.F036V as sw_industry_l2, b.F030V as industry_l1 FROM ea_trade t LEFT JOIN ea_baseinfo b ON t.SECCODE = b.SECCODE WHERE t.SECCODE = :stock_code ORDER BY t.TRADEDATE DESC LIMIT 1 """) trade_result = conn.execute(trade_query, {'stock_code': base_code}).fetchone() if trade_result: row = row_to_dict(trade_result) result_data['name'] = row.get('SECNAME') or '' result_data['current_price'] = float(row.get('close_price') or 0) result_data['change_percent'] = float(row.get('change_pct') or 0) result_data['today_open'] = float(row.get('open_price') or 0) result_data['yesterday_close'] = float(row.get('pre_close') or 0) result_data['today_high'] = float(row.get('high') or 0) result_data['today_low'] = float(row.get('low') or 0) result_data['pe'] = float(row.get('pe_ratio') or 0) if row.get('pe_ratio') else None result_data['turnover_rate'] = float(row.get('turnover_rate') or 0) result_data['sw_industry_l1'] = row.get('sw_industry_l1') or '' result_data['sw_industry_l2'] = row.get('sw_industry_l2') or '' result_data['industry_l1'] = row.get('industry_l1') or '' result_data['industry'] = row.get('sw_industry_l2') or row.get('sw_industry_l1') or '' # 计算流通市值(亿元) float_shares = float(row.get('float_shares') or 0) close_price = float(row.get('close_price') or 0) if float_shares > 0 and close_price > 0: circ_mv = (float_shares * close_price) / 100000000 # 转为亿 result_data['circ_mv'] = round(circ_mv, 2) result_data['market_cap'] = f"{round(circ_mv, 2)}亿" trade_date = row.get('TRADEDATE') if trade_date: if hasattr(trade_date, 'strftime'): result_data['update_time'] = trade_date.strftime('%Y-%m-%d') else: 
result_data['update_time'] = str(trade_date) # 2. 获取52周高低价 week52_query = text(""" SELECT MAX(F005N) as week52_high, MIN(F006N) as week52_low FROM ea_trade WHERE SECCODE = :stock_code AND TRADEDATE >= DATE_SUB(CURDATE(), INTERVAL 52 WEEK) AND F005N > 0 AND F006N > 0 """) week52_result = conn.execute(week52_query, {'stock_code': base_code}).fetchone() if week52_result: w52 = row_to_dict(week52_result) result_data['week52_high'] = float(w52.get('week52_high') or 0) result_data['week52_low'] = float(w52.get('week52_low') or 0) return jsonify({ 'success': True, 'data': result_data }) except Exception as e: app.logger.error(f"Error getting stock quote detail: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//announcements', methods=['GET']) def get_stock_announcements(stock_code): """获取股票公告列表""" try: limit = request.args.get('limit', 50, type=int) with engine.connect() as conn: query = text(""" SELECT F001D as announce_date, F002V as title, F003V as url, F004V as format, F005N as file_size, F006V as info_type, UPDATE_DATE FROM ea_baseinfolist WHERE SECCODE = :stock_code ORDER BY F001D DESC LIMIT :limit """) result = conn.execute(query, {'stock_code': stock_code, 'limit': limit}).fetchall() announcements = [] for row in result: announcement = {} for key, value in row_to_dict(row).items(): if value is None: announcement[key] = None elif isinstance(value, datetime): announcement[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): announcement[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): announcement[key] = float(value) else: announcement[key] = value announcements.append(announcement) return jsonify({ 'success': True, 'data': announcements, 'total': len(announcements) }) except Exception as e: app.logger.error(f"Error getting stock announcements: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//disclosure-schedule', 
methods=['GET']) def get_stock_disclosure_schedule(stock_code): """获取股票财报预披露时间表""" try: with engine.connect() as conn: query = text(""" SELECT distinct F001D as report_period, F002D as scheduled_date, F003D as change_date1, F004D as change_date2, F005D as change_date3, F006D as actual_date, F007D as change_date4, F008D as change_date5, MODTIME as mod_time FROM ea_pretime WHERE SECCODE = :stock_code ORDER BY F001D DESC LIMIT 20 """) result = conn.execute(query, {'stock_code': stock_code}).fetchall() schedules = [] for row in result: schedule = {} for key, value in row_to_dict(row).items(): if value is None: schedule[key] = None elif isinstance(value, datetime): schedule[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): schedule[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): schedule[key] = float(value) else: schedule[key] = value # 计算最新的预约日期 latest_scheduled = schedule.get('scheduled_date') for change_field in ['change_date5', 'change_date4', 'change_date3', 'change_date2', 'change_date1']: if schedule.get(change_field): latest_scheduled = schedule[change_field] break schedule['latest_scheduled_date'] = latest_scheduled schedule['is_disclosed'] = bool(schedule.get('actual_date')) # 格式化报告期名称 if schedule.get('report_period'): period_date = schedule['report_period'] if period_date.endswith('-03-31'): schedule['report_name'] = f"{period_date[:4]}年一季报" elif period_date.endswith('-06-30'): schedule['report_name'] = f"{period_date[:4]}年中报" elif period_date.endswith('-09-30'): schedule['report_name'] = f"{period_date[:4]}年三季报" elif period_date.endswith('-12-31'): schedule['report_name'] = f"{period_date[:4]}年年报" else: schedule['report_name'] = period_date schedules.append(schedule) return jsonify({ 'success': True, 'data': schedules, 'total': len(schedules) }) except Exception as e: app.logger.error(f"Error getting disclosure schedule: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 
@app.route('/api/stock//actual-control', methods=['GET']) def get_stock_actual_control(stock_code): """获取股票实际控制人信息""" try: with engine.connect() as conn: query = text(""" SELECT DECLAREDATE as declare_date, ENDDATE as end_date, F001V as direct_holder_id, F002V as direct_holder_name, F003V as actual_controller_id, F004V as actual_controller_name, F005N as holding_shares, F006N as holding_ratio, F007V as control_type_code, F008V as control_type, F012V as direct_controller_id, F013V as direct_controller_name, F014V as controller_type, ORGNAME as org_name, SECCODE as sec_code, SECNAME as sec_name FROM ea_actualcon WHERE SECCODE = :stock_code ORDER BY ENDDATE DESC, DECLAREDATE DESC LIMIT 20 """) result = conn.execute(query, {'stock_code': stock_code}).fetchall() control_info = [] for row in result: control_record = {} for key, value in row_to_dict(row).items(): if value is None: control_record[key] = None elif isinstance(value, datetime): control_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): control_record[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): control_record[key] = float(value) else: control_record[key] = value control_info.append(control_record) return jsonify({ 'success': True, 'data': control_info, 'total': len(control_info) }) except Exception as e: app.logger.error(f"Error getting actual control info: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//concentration', methods=['GET']) def get_stock_concentration(stock_code): """获取股票股权集中度信息""" try: with engine.connect() as conn: query = text(""" SELECT ENDDATE as end_date, F001V as stat_item, F002N as holding_shares, F003N as holding_ratio, F004N as ratio_change, ORGNAME as org_name, SECCODE as sec_code, SECNAME as sec_name FROM ea_concentration WHERE SECCODE = :stock_code ORDER BY ENDDATE DESC LIMIT 20 """) result = conn.execute(query, {'stock_code': stock_code}).fetchall() concentration_info = [] for 
row in result: concentration_record = {} for key, value in row_to_dict(row).items(): if value is None: concentration_record[key] = None elif isinstance(value, datetime): concentration_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): concentration_record[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): concentration_record[key] = float(value) else: concentration_record[key] = value concentration_info.append(concentration_record) return jsonify({ 'success': True, 'data': concentration_info, 'total': len(concentration_info) }) except Exception as e: app.logger.error(f"Error getting concentration info: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//management', methods=['GET']) def get_stock_management(stock_code): """获取股票管理层信息""" try: # 获取是否只显示在职人员参数 active_only = request.args.get('active_only', 'true').lower() == 'true' with engine.connect() as conn: base_query = """ SELECT DECLAREDATE as declare_date, \ F001V as person_id, \ F002V as name, \ F007D as start_date, \ F008D as end_date, \ F009V as position_name, \ F010V as gender, \ F011V as education, \ F012V as birth_year, \ F013V as nationality, \ F014V as position_category_code, \ F015V as position_category, \ F016V as position_code, \ F017V as highest_degree, \ F019V as resume, \ F020C as is_active, \ ORGNAME as org_name, \ SECCODE as sec_code, \ SECNAME as sec_name FROM ea_management WHERE SECCODE = :stock_code \ """ if active_only: base_query += " AND F020C = '1'" base_query += " ORDER BY DECLAREDATE DESC, F007D DESC" query = text(base_query) result = conn.execute(query, {'stock_code': stock_code}).fetchall() management_info = [] for row in result: management_record = {} for key, value in row_to_dict(row).items(): if value is None: management_record[key] = None elif isinstance(value, datetime): management_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): management_record[key] = 
value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): management_record[key] = float(value) else: management_record[key] = value management_info.append(management_record) return jsonify({ 'success': True, 'data': management_info, 'total': len(management_info) }) except Exception as e: app.logger.error(f"Error getting management info: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//top-circulation-shareholders', methods=['GET']) def get_stock_top_circulation_shareholders(stock_code): """获取股票十大流通股东信息""" try: limit = request.args.get('limit', 10, type=int) with engine.connect() as conn: query = text(""" SELECT DECLAREDATE as declare_date, ENDDATE as end_date, F001N as shareholder_rank, F002V as shareholder_id, F003V as shareholder_name, F004V as shareholder_type, F005N as holding_shares, F006N as total_share_ratio, F007N as circulation_share_ratio, F011V as share_nature, F012N as b_shares, F013N as h_shares, F014N as other_shares, ORGNAME as org_name, SECCODE as sec_code, SECNAME as sec_name FROM ea_tencirculation WHERE SECCODE = :stock_code ORDER BY ENDDATE DESC, F001N ASC LIMIT :limit """) result = conn.execute(query, {'stock_code': stock_code, 'limit': limit}).fetchall() shareholders_info = [] for row in result: shareholder_record = {} for key, value in row_to_dict(row).items(): if value is None: shareholder_record[key] = None elif isinstance(value, datetime): shareholder_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): shareholder_record[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): shareholder_record[key] = float(value) else: shareholder_record[key] = value shareholders_info.append(shareholder_record) return jsonify({ 'success': True, 'data': shareholders_info, 'total': len(shareholders_info) }) except Exception as e: app.logger.error(f"Error getting top circulation shareholders: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 
500 @app.route('/api/stock//top-shareholders', methods=['GET']) def get_stock_top_shareholders(stock_code): """获取股票十大股东信息""" try: limit = request.args.get('limit', 10, type=int) with engine.connect() as conn: query = text(""" SELECT DECLAREDATE as declare_date, ENDDATE as end_date, F001N as shareholder_rank, F002V as shareholder_name, F003V as shareholder_id, F004V as shareholder_type, F005N as holding_shares, F006N as total_share_ratio, F007N as circulation_share_ratio, F011V as share_nature, F016N as restricted_shares, F017V as concert_party_group, F018N as circulation_shares, ORGNAME as org_name, SECCODE as sec_code, SECNAME as sec_name FROM ea_tenshareholder WHERE SECCODE = :stock_code ORDER BY ENDDATE DESC, F001N ASC LIMIT :limit """) result = conn.execute(query, {'stock_code': stock_code, 'limit': limit}).fetchall() shareholders_info = [] for row in result: shareholder_record = {} for key, value in row_to_dict(row).items(): if value is None: shareholder_record[key] = None elif isinstance(value, datetime): shareholder_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): shareholder_record[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): shareholder_record[key] = float(value) else: shareholder_record[key] = value shareholders_info.append(shareholder_record) return jsonify({ 'success': True, 'data': shareholders_info, 'total': len(shareholders_info) }) except Exception as e: app.logger.error(f"Error getting top shareholders: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//branches', methods=['GET']) def get_stock_branches(stock_code): """获取股票分支机构信息""" try: with engine.connect() as conn: query = text(""" SELECT CRECODE as cre_code, F001V as branch_name, F002V as register_capital, F003V as business_status, F004D as register_date, F005N as related_company_count, F006V as legal_person, ORGNAME as org_name, SECCODE as sec_code, SECNAME as sec_name FROM ea_branch WHERE 
SECCODE = :stock_code ORDER BY F004D DESC """) result = conn.execute(query, {'stock_code': stock_code}).fetchall() branches_info = [] for row in result: branch_record = {} for key, value in row_to_dict(row).items(): if value is None: branch_record[key] = None elif isinstance(value, datetime): branch_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): branch_record[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): branch_record[key] = float(value) else: branch_record[key] = value branches_info.append(branch_record) return jsonify({ 'success': True, 'data': branches_info, 'total': len(branches_info) }) except Exception as e: app.logger.error(f"Error getting branches info: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//patents', methods=['GET']) def get_stock_patents(stock_code): """获取股票专利信息""" try: limit = request.args.get('limit', 50, type=int) patent_type = request.args.get('type', None) # 专利类型筛选 with engine.connect() as conn: base_query = """ SELECT CRECODE as cre_code, \ F001V as patent_name, \ F002V as application_number, \ F003V as publication_number, \ F004V as classification_number, \ F005D as publication_date, \ F006D as application_date, \ F007V as patent_type, \ F008V as applicant, \ F009V as inventor, \ ID as id, \ ORGNAME as org_name, \ SECCODE as sec_code, \ SECNAME as sec_name FROM ea_patent WHERE SECCODE = :stock_code \ """ params = {'stock_code': stock_code, 'limit': limit} if patent_type: base_query += " AND F007V = :patent_type" params['patent_type'] = patent_type base_query += " ORDER BY F006D DESC, F005D DESC LIMIT :limit" query = text(base_query) result = conn.execute(query, params).fetchall() patents_info = [] for row in result: patent_record = {} for key, value in row_to_dict(row).items(): if value is None: patent_record[key] = None elif isinstance(value, datetime): patent_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, 
date): patent_record[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): patent_record[key] = float(value) else: patent_record[key] = value patents_info.append(patent_record) return jsonify({ 'success': True, 'data': patents_info, 'total': len(patents_info) }) except Exception as e: app.logger.error(f"Error getting patents info: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 def get_daily_kline(stock_code, event_datetime, stock_name): """处理日K线数据""" stock_code = stock_code.split('.')[0] with engine.connect() as conn: # 获取事件日期前后的数据(前365天/1年,后30天) kline_sql = """ WITH date_range AS (SELECT TRADEDATE \ FROM ea_trade \ WHERE SECCODE = :stock_code \ AND TRADEDATE BETWEEN DATE_SUB(:trade_date, INTERVAL 365 DAY) \ AND DATE_ADD(:trade_date, INTERVAL 30 DAY) \ GROUP BY TRADEDATE \ ORDER BY TRADEDATE) SELECT t.TRADEDATE, CAST(t.F003N AS FLOAT) as open, CAST(t.F007N AS FLOAT) as close, CAST(t.F005N AS FLOAT) as high, CAST(t.F006N AS FLOAT) as low, CAST(t.F004N AS FLOAT) as volume FROM ea_trade t JOIN date_range d \ ON t.TRADEDATE = d.TRADEDATE WHERE t.SECCODE = :stock_code ORDER BY t.TRADEDATE \ """ result = conn.execute(text(kline_sql), { "stock_code": stock_code, "trade_date": event_datetime.date() }).fetchall() if not result: return jsonify({ 'error': 'No data available', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'daily' }) kline_data = [{ 'time': row.TRADEDATE.strftime('%Y-%m-%d'), 'open': float(row.open), 'high': float(row.high), 'low': float(row.low), 'close': float(row.close), 'volume': float(row.volume) } for row in result] return jsonify({ 'code': stock_code, 'name': stock_name, 'data': kline_data, 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'daily', 'is_history': True }) def get_minute_kline(stock_code, event_datetime, stock_name, skip_next_day=False): """处理分钟K线数据 Args: stock_code: 股票代码 event_datetime: 事件时间 stock_name: 股票名称 
skip_next_day: 是否跳过"下一个交易日"逻辑(用于灵活屏盘后查看当天数据) """ client = get_clickhouse_client() target_date = get_trading_day_near_date(event_datetime.date()) is_after_market = event_datetime.time() > dt_time(15, 0) # 只有在指定了 event_time 参数时(如 Community 页面事件)才跳转到下一个交易日 # 灵活屏等实时行情场景,盘后应显示当天数据 if target_date and is_after_market and not skip_next_day: # 如果是交易日且已收盘,查找下一个交易日 next_trade_date = get_trading_day_near_date(target_date + timedelta(days=1)) if next_trade_date: target_date = next_trade_date if not target_date: return jsonify({ 'error': 'No data available', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'minute' }) # 获取目标日期的完整交易时段数据 data = client.execute(""" SELECT timestamp, open, high, low, close, volume, amt FROM stock_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': stock_code, 'start': datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)) }) kline_data = [{ 'time': row[0].strftime('%H:%M'), 'open': float(row[1]), 'high': float(row[2]), 'low': float(row[3]), 'close': float(row[4]), 'volume': float(row[5]), 'amount': float(row[6]) } for row in data] return jsonify({ 'code': stock_code, 'name': stock_name, 'data': kline_data, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'minute', 'is_history': target_date < event_datetime.date() }) def get_timeline_data(stock_code, event_datetime, stock_name): """处理分时均价线数据(timeline)。 规则: - 若事件时间在交易日的15:00之后,则展示下一个交易日的分时数据; - 若事件日非交易日,优先展示下一个交易日;如无,则回退到最近一个交易日; - 数据区间固定为 09:30-15:00。 """ client = get_clickhouse_client() target_date = get_trading_day_near_date(event_datetime.date()) is_after_market = event_datetime.time() > dt_time(15, 0) # 与分钟K逻辑保持一致的日期选择规则 if target_date and is_after_market: next_trade_date = get_trading_day_near_date(target_date + timedelta(days=1)) if next_trade_date: target_date = next_trade_date if not target_date: return jsonify({ 
'error': 'No data available', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'timeline' }) # 获取昨收盘价 - 优先从 MySQL ea_trade 表获取(更可靠) prev_close = None base_code = stock_code.split('.')[0] target_date_str = target_date.strftime('%Y%m%d') try: with engine.connect() as conn: # F007N 是昨收价字段 result = conn.execute(text(""" SELECT F007N FROM ea_trade WHERE SECCODE = :code AND TRADEDATE = :trade_date AND F007N > 0 """), {'code': base_code, 'trade_date': target_date_str}).fetchone() if result and result[0]: prev_close = float(result[0]) except Exception as e: logger.warning(f"从 ea_trade 获取昨收价失败: {e}") # 如果 MySQL 没有数据,回退到 ClickHouse if prev_close is None: prev_close_query = """ SELECT close FROM stock_minute WHERE code = %(code)s AND timestamp < %(start)s ORDER BY timestamp DESC LIMIT 1 """ prev_close_result = client.execute(prev_close_query, { 'code': stock_code, 'start': datetime.combine(target_date, dt_time(9, 30)) }) if prev_close_result: prev_close = float(prev_close_result[0][0]) data = client.execute( """ SELECT timestamp, close, volume FROM stock_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': stock_code, 'start': datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)), } ) timeline_data = [] total_amount = 0 total_volume = 0 for row in data: price = float(row[1]) volume = float(row[2]) total_amount += price * volume total_volume += volume avg_price = total_amount / total_volume if total_volume > 0 else price # 计算涨跌幅 change_percent = ((price - prev_close) / prev_close * 100) if prev_close else 0 timeline_data.append({ 'time': row[0].strftime('%H:%M'), 'price': price, 'avg_price': avg_price, 'volume': volume, 'change_percent': change_percent, }) return jsonify({ 'code': stock_code, 'name': stock_name, 'data': timeline_data, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'timeline', 
'is_history': target_date < event_datetime.date(), 'prev_close': prev_close, }) # ==================== 指数行情API(与股票逻辑一致,数据表为 index_minute) ==================== @app.route('/api/index//realtime') def get_index_realtime(index_code): """ 获取指数实时行情(用于交易时间内的行情更新) 从 index_minute 表获取最新的分钟数据 返回: 最新价、涨跌幅、涨跌额、开盘价、最高价、最低价、昨收价 """ # 确保指数代码包含后缀(ClickHouse 中存储的是带后缀的代码) # 上证指数: 000xxx.SH, 深证指数: 399xxx.SZ if '.' not in index_code: if index_code.startswith('399'): index_code = f"{index_code}.SZ" else: # 000开头的上证指数,以及其他指数默认上海 index_code = f"{index_code}.SH" client = get_clickhouse_client() today = date.today() # 判断今天是否是交易日 if today not in trading_days_set: # 非交易日,获取最近一个交易日的收盘数据 target_date = get_trading_day_near_date(today) if not target_date: return jsonify({ 'success': False, 'error': 'No trading day found', 'data': None }) is_trading = False else: target_date = today # 判断是否在交易时间内 now = datetime.now() current_minutes = now.hour * 60 + now.minute # 9:30-11:30 = 570-690, 13:00-15:00 = 780-900 is_trading = (570 <= current_minutes <= 690) or (780 <= current_minutes <= 900) try: # 获取当天/最近交易日的第一条数据(开盘价)和最后一条数据(最新价) # 同时获取最高价和最低价 data = client.execute( """ SELECT min(open) as first_open, max(high) as day_high, min(low) as day_low, argMax(close, timestamp) as latest_close, argMax(timestamp, timestamp) as latest_time FROM index_minute WHERE code = %(code)s AND toDate(timestamp) = %(date)s """, { 'code': index_code, 'date': target_date, } ) if not data or not data[0] or data[0][3] is None: return jsonify({ 'success': False, 'error': 'No data available', 'data': None }) row = data[0] first_open = float(row[0]) if row[0] else None day_high = float(row[1]) if row[1] else None day_low = float(row[2]) if row[2] else None latest_close = float(row[3]) if row[3] else None latest_time = row[4] # 获取昨收价(从 MySQL ea_exchangetrade 表) code_no_suffix = index_code.split('.')[0] prev_close = None with engine.connect() as conn: # 获取前一个交易日的收盘价 prev_result = conn.execute(text( """ SELECT F006N FROM 
ea_exchangetrade WHERE INDEXCODE = :code AND TRADEDATE < :today ORDER BY TRADEDATE DESC LIMIT 1 """ ), { 'code': code_no_suffix, 'today': datetime.combine(target_date, dt_time(0, 0, 0)) }).fetchone() if prev_result and prev_result[0]: prev_close = float(prev_result[0]) # 计算涨跌额和涨跌幅 change_amount = None change_pct = None if latest_close is not None and prev_close is not None and prev_close > 0: change_amount = latest_close - prev_close change_pct = (change_amount / prev_close) * 100 return jsonify({ 'success': True, 'data': { 'code': index_code, 'price': latest_close, 'open': first_open, 'high': day_high, 'low': day_low, 'prev_close': prev_close, 'change': change_amount, 'change_pct': change_pct, 'update_time': latest_time.strftime('%H:%M:%S') if latest_time else None, 'trade_date': target_date.strftime('%Y-%m-%d'), 'is_trading': is_trading, } }) except Exception as e: app.logger.error(f"获取指数实时行情失败: {index_code}, 错误: {str(e)}") return jsonify({ 'success': False, 'error': str(e), 'data': None }), 500 @app.route('/api/index//kline') def get_index_kline(index_code): chart_type = request.args.get('type', 'minute') event_time = request.args.get('event_time') try: event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now() except ValueError: return jsonify({'error': 'Invalid event_time format'}), 400 # 确保指数代码包含后缀(ClickHouse 中数据带后缀) # 399xxx -> 深交所, 其他(000xxx等)-> 上交所 if '.' 
not in index_code: index_code = f"{index_code}.SZ" if index_code.startswith('39') else f"{index_code}.SH" # 指数名称(暂无索引表,先返回代码本身) index_name = index_code.split('.')[0] if chart_type == 'minute': return get_index_minute_kline(index_code, event_datetime, index_name) elif chart_type == 'timeline': return get_index_timeline_data(index_code, event_datetime, index_name) elif chart_type == 'daily': return get_index_daily_kline(index_code, event_datetime, index_name) else: return jsonify({'error': f'Unsupported chart type: {chart_type}'}), 400 def get_index_minute_kline(index_code, event_datetime, index_name): client = get_clickhouse_client() target_date = get_trading_day_near_date(event_datetime.date()) if not target_date: return jsonify({ 'error': 'No data available', 'code': index_code, 'name': index_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'minute' }) data = client.execute( """ SELECT timestamp, open, high, low, close, volume, amt FROM index_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': index_code, 'start': datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)), } ) kline_data = [{ 'time': row[0].strftime('%H:%M'), 'open': float(row[1]), 'high': float(row[2]), 'low': float(row[3]), 'close': float(row[4]), 'volume': float(row[5]), 'amount': float(row[6]), } for row in data] return jsonify({ 'code': index_code, 'name': index_name, 'data': kline_data, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'minute', 'is_history': target_date < event_datetime.date(), }) def get_index_timeline_data(index_code, event_datetime, index_name): client = get_clickhouse_client() target_date = get_trading_day_near_date(event_datetime.date()) if not target_date: return jsonify({ 'error': 'No data available', 'code': index_code, 'name': index_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'timeline' }) 
data = client.execute( """ SELECT timestamp, close, volume FROM index_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': index_code, 'start': datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)), } ) timeline = [] total_amount = 0 total_volume = 0 for row in data: price = float(row[1]) volume = float(row[2]) total_amount += price * volume total_volume += volume avg_price = total_amount / total_volume if total_volume > 0 else price timeline.append({ 'time': row[0].strftime('%H:%M'), 'price': price, 'avg_price': avg_price, 'volume': volume, }) return jsonify({ 'code': index_code, 'name': index_name, 'data': timeline, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'timeline', 'is_history': target_date < event_datetime.date(), }) def get_index_daily_kline(index_code, event_datetime, index_name): """从 MySQL 的 stock.ea_exchangetrade 获取指数日线 注意:表中 INDEXCODE 无后缀,例如 000001.SH -> 000001 字段: F003N 开市指数 -> open F004N 最高指数 -> high F005N 最低指数 -> low F006N 最近指数 -> close(作为当日收盘或最近价使用) F007N 昨日收市指数 -> prev_close """ # 去掉后缀 code_no_suffix = index_code.split('.')[0] # 选择展示的最后交易日 target_date = get_trading_day_near_date(event_datetime.date()) if not target_date: return jsonify({ 'error': 'No data available', 'code': index_code, 'name': index_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'daily' }) # 取最近一段时间的日线(倒序再反转为升序) with engine.connect() as conn: rows = conn.execute(text( """ SELECT TRADEDATE, F003N, F004N, F005N, F006N, F007N FROM ea_exchangetrade WHERE INDEXCODE = :code AND TRADEDATE <= :end_dt ORDER BY TRADEDATE DESC LIMIT 180 """ ), { 'code': code_no_suffix, 'end_dt': datetime.combine(target_date, dt_time(23, 59, 59)) }).fetchall() # 反转为时间升序 rows = list(reversed(rows)) daily = [] for i, r in enumerate(rows): trade_dt = r[0] open_v = r[1] high_v = r[2] low_v = r[3] last_v = r[4] prev_close_v = r[5] # 正确的前收盘价逻辑:使用前一个交易日的F006N(收盘价) 
calculated_prev_close = None if i > 0 and rows[i - 1][4] is not None: # 使用前一个交易日的收盘价作为前收盘价 calculated_prev_close = float(rows[i - 1][4]) else: # 第一条记录,尝试使用F007N字段作为备选 if prev_close_v is not None and prev_close_v > 0: calculated_prev_close = float(prev_close_v) daily.append({ 'time': trade_dt.strftime('%Y-%m-%d') if hasattr(trade_dt, 'strftime') else str(trade_dt), 'open': float(open_v) if open_v is not None else None, 'high': float(high_v) if high_v is not None else None, 'low': float(low_v) if low_v is not None else None, 'close': float(last_v) if last_v is not None else None, 'prev_close': calculated_prev_close, }) return jsonify({ 'code': index_code, 'name': index_name, 'data': daily, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'daily', 'is_history': target_date < event_datetime.date(), }) # ==================== 日历API ==================== @app.route('/api/v1/calendar/event-counts', methods=['GET']) def get_event_counts(): """获取日历事件数量统计""" try: # 获取月份参数 year = request.args.get('year', datetime.now().year, type=int) month = request.args.get('month', datetime.now().month, type=int) # 计算月份的开始和结束日期 start_date = datetime(year, month, 1) if month == 12: end_date = datetime(year + 1, 1, 1) else: end_date = datetime(year, month + 1, 1) # 查询事件数量 query = """ SELECT DATE(calendar_time) as date, COUNT(*) as count FROM future_events WHERE calendar_time BETWEEN :start_date AND :end_date AND type = 'event' GROUP BY DATE(calendar_time) """ result = db.session.execute(text(query), { 'start_date': start_date, 'end_date': end_date }) # 格式化结果 events = [] for day in result: events.append({ 'date': day.date.isoformat(), 'count': day.count, 'className': get_event_class(day.count) }) return jsonify({ 'success': True, 'data': events }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/calendar/events', methods=['GET']) def get_calendar_events(): """获取指定日期的事件列表""" date_str = request.args.get('date') event_type = 
request.args.get('type', 'all') if not date_str: return jsonify({ 'success': False, 'error': 'Date parameter required' }), 400 try: date = datetime.strptime(date_str, '%Y-%m-%d') except ValueError: return jsonify({ 'success': False, 'error': 'Invalid date format' }), 400 # 修复SQL语法:去掉函数名后的空格,去掉参数前的空格 query = """ SELECT * FROM future_events WHERE DATE(calendar_time) = :date """ params = {'date': date} if event_type != 'all': query += " AND type = :type" params['type'] = event_type query += " ORDER BY calendar_time" result = db.session.execute(text(query), params) events = [] user_following_ids = set() if 'user_id' in session: follows = FutureEventFollow.query.filter_by(user_id=session['user_id']).all() user_following_ids = {f.future_event_id for f in follows} for row in result: # 使用统一的处理函数,支持新字段回退和 best_matches 解析 event_data = process_future_event_row(row, user_following_ids) events.append(event_data) return jsonify({ 'success': True, 'data': events }) @app.route('/api/v1/calendar/events/', methods=['GET']) def get_calendar_event_detail(event_id): """获取日历事件详情""" try: sql = """ SELECT * FROM future_events WHERE data_id = :event_id \ """ result = db.session.execute(text(sql), {'event_id': event_id}).first() if not result: return jsonify({ 'success': False, 'error': 'Event not found' }), 404 # 检查当前用户是否关注了该未来事件 user_following_ids = set() if 'user_id' in session: is_following = FutureEventFollow.query.filter_by( user_id=session['user_id'], future_event_id=event_id ).first() is not None if is_following: user_following_ids.add(event_id) # 使用统一的处理函数,支持新字段回退和 best_matches 解析 event_data = process_future_event_row(result, user_following_ids) return jsonify({ 'success': True, 'data': event_data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/calendar/events//follow', methods=['POST']) def toggle_future_event_follow(event_id): """切换未来事件关注状态(需登录)""" if 'user_id' not in session: return jsonify({'success': False, 'error': 
'未登录'}), 401 try: # 检查未来事件是否存在 sql = """ SELECT data_id \ FROM future_events \ WHERE data_id = :event_id \ """ result = db.session.execute(text(sql), {'event_id': event_id}).first() if not result: return jsonify({'success': False, 'error': '未来事件不存在'}), 404 user_id = session['user_id'] # 检查是否已关注 existing = FutureEventFollow.query.filter_by( user_id=user_id, future_event_id=event_id ).first() if existing: # 取消关注 db.session.delete(existing) db.session.commit() return jsonify({ 'success': True, 'data': {'is_following': False} }) else: # 关注 follow = FutureEventFollow( user_id=user_id, future_event_id=event_id ) db.session.add(follow) db.session.commit() return jsonify({ 'success': True, 'data': {'is_following': True} }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 def get_event_class(count): """根据事件数量返回CSS类名""" if count >= 10: return 'event-high' elif count >= 5: return 'event-medium' elif count > 0: return 'event-low' return '' def parse_json_field(field_value): """解析JSON字段""" if not field_value: return [] try: if isinstance(field_value, str): if field_value.startswith('['): return json.loads(field_value) else: return field_value.split(',') else: return field_value except: return [] def get_future_event_field(row, new_field, old_field): """ 获取 future_events 表字段值,支持新旧字段回退 如果新字段存在且不为空,使用新字段;否则使用旧字段 """ new_value = getattr(row, new_field, None) if hasattr(row, new_field) else None old_value = getattr(row, old_field, None) if hasattr(row, old_field) else None # 如果新字段有值(不为空字符串),使用新字段 if new_value is not None and str(new_value).strip(): return new_value return old_value def parse_best_matches(best_matches_value): """ 解析新的 best_matches 数据结构(含研报引用信息) 新结构示例: [ { "stock_code": "300451.SZ", "company_name": "创业慧康", "original_description": "核心标的,医疗信息化...", "best_report_title": "报告标题", "best_report_author": "作者", "best_report_sentences": "相关内容", "best_report_match_score": "好", "best_report_match_ratio": 0.9285714285714286, 
"best_report_declare_date": "2023-04-25T00:00:00", "total_reports": 9, "high_score_reports": 6 }, ... ] 返回统一格式的股票列表,兼容旧格式 """ if not best_matches_value: return [] try: # 解析 JSON if isinstance(best_matches_value, str): data = json.loads(best_matches_value) else: data = best_matches_value if not isinstance(data, list): return [] result = [] for item in data: if isinstance(item, dict): # 新结构:包含研报信息的字典 stock_info = { 'code': item.get('stock_code', ''), 'name': item.get('company_name', ''), 'description': item.get('original_description', ''), 'score': item.get('best_report_match_ratio', 0), # 研报引用信息 'report': { 'title': item.get('best_report_title', ''), 'author': item.get('best_report_author', ''), 'sentences': item.get('best_report_sentences', ''), 'match_score': item.get('best_report_match_score', ''), 'match_ratio': item.get('best_report_match_ratio', 0), 'declare_date': item.get('best_report_declare_date', ''), 'total_reports': item.get('total_reports', 0), 'high_score_reports': item.get('high_score_reports', 0) } if item.get('best_report_title') else None } result.append(stock_info) elif isinstance(item, (list, tuple)) and len(item) >= 2: # 旧结构:[code, name, description, score] result.append({ 'code': item[0], 'name': item[1], 'description': item[2] if len(item) > 2 else '', 'score': item[3] if len(item) > 3 else 0, 'report': None }) return result except Exception as e: print(f"parse_best_matches error: {e}") return [] def process_future_event_row(row, user_following_ids=None): """ 统一处理 future_events 表的行数据 支持新字段回退和 best_matches 解析 """ if user_following_ids is None: user_following_ids = set() # 获取字段值,支持新旧回退 # second_modified_text -> former # second_modified_text.1 -> forecast (MySQL 中用反引号) former_value = get_future_event_field(row, 'second_modified_text', 'former') # 处理 second_modified_text.1 字段(特殊字段名) forecast_new = None if hasattr(row, 'second_modified_text.1'): forecast_new = getattr(row, 'second_modified_text.1', None) # 尝试其他可能的属性名 for attr_name in 
['second_modified_text.1', 'second_modified_text_1']: if hasattr(row, attr_name): val = getattr(row, attr_name, None) if val and str(val).strip(): forecast_new = val break forecast_value = forecast_new if (forecast_new and str(forecast_new).strip()) else getattr(row, 'forecast', None) # best_matches -> related_stocks best_matches = getattr(row, 'best_matches', None) if hasattr(row, 'best_matches') else None if best_matches and str(best_matches).strip(): related_stocks = parse_best_matches(best_matches) else: related_stocks = parse_json_field(getattr(row, 'related_stocks', None)) # 构建事件数据 event_data = { 'id': row.data_id, 'title': row.title, 'type': getattr(row, 'type', None), 'calendar_time': row.calendar_time.isoformat() if row.calendar_time else None, 'star': row.star, 'former': former_value, 'forecast': forecast_value, 'fact': getattr(row, 'fact', None), 'is_following': row.data_id in user_following_ids, 'related_stocks': related_stocks, 'concepts': parse_json_field(getattr(row, 'concepts', None)), 'update_time': getattr(row, 'update_time', None).isoformat() if getattr(row, 'update_time', None) else None } return event_data # ==================== 行业API ==================== @app.route('/api/classifications', methods=['GET']) def get_classifications(): """获取申银万国行业分类树形结构""" try: # 查询申银万国行业分类的所有数据 sql = """ SELECT f003v as code, f004v as level1, f005v as level2, f006v as level3,f007v as level4 FROM ea_sector WHERE f002v = '申银万国行业分类' AND f003v IS NOT NULL AND f004v IS NOT NULL ORDER BY f003v """ result = db.session.execute(text(sql)).all() # 构建树形结构 tree_dict = {} for row in result: code = row.code level1 = row.level1 level2 = row.level2 level3 = row.level3 # 跳过空数据 if not level1: continue # 第一层 if level1 not in tree_dict: # 获取第一层的code(取前3位或前缀) level1_code = code[:3] if len(code) >= 3 else code tree_dict[level1] = { 'value': level1_code, 'label': level1, 'children_dict': {} } # 第二层 if level2: if level2 not in tree_dict[level1]['children_dict']: # 获取第二层的code(取前6位) 
level2_code = code[:6] if len(code) >= 6 else code tree_dict[level1]['children_dict'][level2] = { 'value': level2_code, 'label': level2, 'children_dict': {} } # 第三层 if level3: if level3 not in tree_dict[level1]['children_dict'][level2]['children_dict']: tree_dict[level1]['children_dict'][level2]['children_dict'][level3] = { 'value': code, 'label': level3 } # 转换为最终格式 result_list = [] for level1_name, level1_data in tree_dict.items(): level1_node = { 'value': level1_data['value'], 'label': level1_data['label'] } # 处理第二层 if level1_data['children_dict']: level1_children = [] for level2_name, level2_data in level1_data['children_dict'].items(): level2_node = { 'value': level2_data['value'], 'label': level2_data['label'] } # 处理第三层 if level2_data['children_dict']: level2_children = [] for level3_name, level3_data in level2_data['children_dict'].items(): level2_children.append({ 'value': level3_data['value'], 'label': level3_data['label'] }) if level2_children: level2_node['children'] = level2_children level1_children.append(level2_node) if level1_children: level1_node['children'] = level1_children result_list.append(level1_node) return jsonify({ 'success': True, 'data': result_list }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/stocklist', methods=['GET']) def get_stock_list(): """获取股票列表""" try: sql = """ SELECT DISTINCT SECCODE as code, SECNAME as name FROM ea_stocklist ORDER BY SECCODE """ result = db.session.execute(text(sql)).all() stocks = [{'code': row.code, 'name': row.name} for row in result] return jsonify(stocks) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/events', methods=['GET'], strict_slashes=False) def api_get_events(): """ 获取事件列表API - 支持筛选、排序、分页,兼容前端调用 Redis 缓存策略: - 交易时间(交易日 9:00-15:00):缓存 20 秒 - 非交易时间:缓存 10 分钟 """ try: # ==================== Redis 缓存检查 ==================== # 获取所有请求参数用于生成缓存 key cache_params = dict(request.args) cache_key = 
generate_events_cache_key(cache_params) # 尝试从缓存获取 cached_response = get_events_cache(cache_key) if cached_response: # 添加缓存命中标记(可选,用于调试) cached_response['_cached'] = True cached_response['_cache_ttl'] = EVENTS_CACHE_TTL_TRADING if is_trading_hours() else EVENTS_CACHE_TTL_NON_TRADING return jsonify(cached_response) # ==================== 缓存未命中,执行数据库查询 ==================== # 分页参数 page = max(1, request.args.get('page', 1, type=int)) per_page = min(100, max(1, request.args.get('per_page', 10, type=int))) # 基础筛选参数 event_type = request.args.get('type', 'all') event_status = request.args.get('status', 'active') importance = request.args.get('importance', 'all') # 日期筛选参数 start_date = request.args.get('start_date') end_date = request.args.get('end_date') date_range = request.args.get('date_range') recent_days = request.args.get('recent_days', type=int) # 行业筛选参数(只支持申银万国行业分类) industry_code = request.args.get('industry_code') # 申万行业代码,如 "S370502" # 概念/标签筛选参数 tag = request.args.get('tag') tags = request.args.get('tags') keywords = request.args.get('keywords') # 搜索参数 search_query = request.args.get('q') search_type = request.args.get('search_type', 'topic') search_fields = request.args.get('search_fields', 'title,description').split(',') # 排序参数 sort_by = request.args.get('sort', 'new') return_type = request.args.get('return_type', 'avg') order = request.args.get('order', 'desc') # 收益率筛选参数 min_avg_return = request.args.get('min_avg_return', type=float) max_avg_return = request.args.get('max_avg_return', type=float) min_max_return = request.args.get('min_max_return', type=float) max_max_return = request.args.get('max_max_return', type=float) min_week_return = request.args.get('min_week_return', type=float) max_week_return = request.args.get('max_week_return', type=float) # 其他筛选参数 min_hot_score = request.args.get('min_hot_score', type=float) max_hot_score = request.args.get('max_hot_score', type=float) min_view_count = request.args.get('min_view_count', type=int) creator_id = 
request.args.get('creator_id', type=int) # 返回格式参数 include_creator = request.args.get('include_creator', 'true').lower() == 'true' include_stats = request.args.get('include_stats', 'true').lower() == 'true' include_related_data = request.args.get('include_related_data', 'false').lower() == 'true' # ==================== 构建查询 ==================== from sqlalchemy.orm import joinedload # 使用 joinedload 预加载 creator,解决 N+1 查询问题 query = Event.query.options(joinedload(Event.creator)) # 只返回有关联股票的事件(没有关联股票的事件不计入列表) from sqlalchemy import exists query = query.filter( exists().where(RelatedStock.event_id == Event.id) ) if event_status != 'all': query = query.filter_by(status=event_status) if event_type != 'all': query = query.filter_by(event_type=event_type) # 支持多个重要性级别筛选,用逗号分隔(如 importance=S,A) if importance != 'all': if ',' in importance: # 多个重要性级别 importance_list = [imp.strip() for imp in importance.split(',') if imp.strip()] query = query.filter(Event.importance.in_(importance_list)) else: # 单个重要性级别 query = query.filter_by(importance=importance) if creator_id: query = query.filter_by(creator_id=creator_id) # 新增:行业代码过滤(申银万国行业分类)- 支持前缀匹配 # 申万行业分类层级:一级 Sxx, 二级 Sxxxx, 三级 Sxxxxxx # 搜索 S22 会匹配所有 S22xxxx 的事件(如 S2203, S220309 等) # related_industries 格式: varchar,如 "S640701" if industry_code: # 判断是否需要前缀匹配:一级(3字符)或二级(5字符)行业代码 def is_prefix_code(code): """判断是否为需要前缀匹配的行业代码(一级或二级)""" code = code.strip() # 申万行业代码格式:S + 数字 # 一级: S + 2位数字 (如 S22) = 3字符 # 二级: S + 4位数字 (如 S2203) = 5字符 # 三级: S + 6位数字 (如 S220309) = 7字符 return len(code) < 7 and code.startswith('S') # 如果包含逗号,说明是多个行业代码 if ',' in industry_code: codes = [code.strip() for code in industry_code.split(',') if code.strip()] conditions = [] for code in codes: if is_prefix_code(code): # 前缀匹配:使用 LIKE conditions.append(Event.related_industries.like(f"{code}%")) else: # 精确匹配(三级行业代码) conditions.append(Event.related_industries == code) query = query.filter(db.or_(*conditions)) else: # 单个行业代码 if is_prefix_code(industry_code): # 前缀匹配:使用 LIKE 
query = query.filter(Event.related_industries.like(f"{industry_code}%")) else: # 精确匹配(三级行业代码) query = query.filter(Event.related_industries == industry_code) # 新增:关键词/全文搜索过滤(MySQL JSON) if search_query: like_pattern = f"%{search_query}%" # 子查询:查找关联股票中匹配的事件ID # stock_code 格式:600111.SH / 000001.SZ / 830001.BJ,支持不带后缀搜索 stock_subquery = db.session.query(RelatedStock.event_id).filter( db.or_( RelatedStock.stock_code.ilike(like_pattern), # 支持股票代码搜索 RelatedStock.stock_name.ilike(like_pattern), RelatedStock.relation_desc.ilike(like_pattern) ) ).distinct() # 主查询:搜索事件标题、描述、关键词或关联股票 query = query.filter( db.or_( Event.title.ilike(like_pattern), Event.description.ilike(like_pattern), text(f"JSON_SEARCH(keywords, 'one', '%{search_query}%') IS NOT NULL"), Event.id.in_(stock_subquery) ) ) if recent_days: from datetime import datetime, timedelta cutoff_date = datetime.now() - timedelta(days=recent_days) query = query.filter(Event.created_at >= cutoff_date) else: if date_range and ' 至 ' in date_range: try: start_date_str, end_date_str = date_range.split(' 至 ') start_date = start_date_str.strip() end_date = end_date_str.strip() except ValueError: pass if start_date: from datetime import datetime try: if len(start_date) == 10: start_datetime = datetime.strptime(start_date, '%Y-%m-%d') else: start_datetime = datetime.strptime(start_date, '%Y-%m-%d %H:%M:%S') query = query.filter(Event.created_at >= start_datetime) except ValueError: pass if end_date: from datetime import datetime try: if len(end_date) == 10: end_datetime = datetime.strptime(end_date, '%Y-%m-%d') end_datetime = end_datetime.replace(hour=23, minute=59, second=59) else: end_datetime = datetime.strptime(end_date, '%Y-%m-%d %H:%M:%S') query = query.filter(Event.created_at <= end_datetime) except ValueError: pass if min_view_count is not None: query = query.filter(Event.view_count >= min_view_count) # 排序 from sqlalchemy import desc, asc, case order_func = desc if order.lower() == 'desc' else asc if sort_by == 'hot': query = 
query.order_by(order_func(Event.hot_score)) elif sort_by == 'new': query = query.order_by(order_func(Event.created_at)) elif sort_by == 'returns': if return_type == 'avg': query = query.order_by(order_func(Event.related_avg_chg)) elif return_type == 'max': query = query.order_by(order_func(Event.related_max_chg)) elif return_type == 'week': query = query.order_by(order_func(Event.related_week_chg)) elif sort_by == 'importance': importance_order = case( (Event.importance == 'S', 1), (Event.importance == 'A', 2), (Event.importance == 'B', 3), (Event.importance == 'C', 4), else_=5 ) if order.lower() == 'desc': query = query.order_by(importance_order) else: query = query.order_by(desc(importance_order)) elif sort_by == 'view_count': query = query.order_by(order_func(Event.view_count)) # 分页 paginated = query.paginate(page=page, per_page=per_page, error_out=False) events_data = [] for event in paginated.items: event_dict = { 'id': event.id, 'title': event.title, 'description': event.description, 'event_type': event.event_type, 'importance': event.importance, 'status': event.status, 'created_at': event.created_at.isoformat() if event.created_at else None, 'updated_at': event.updated_at.isoformat() if event.updated_at else None, 'start_time': event.start_time.isoformat() if event.start_time else None, 'end_time': event.end_time.isoformat() if event.end_time else None, } if include_stats: event_dict.update({ 'hot_score': event.hot_score, 'view_count': event.view_count, 'post_count': event.post_count, 'follower_count': event.follower_count, 'related_avg_chg': event.related_avg_chg, 'related_max_chg': event.related_max_chg, 'related_week_chg': event.related_week_chg, 'invest_score': event.invest_score, 'trending_score': event.trending_score, 'expectation_surprise_score': event.expectation_surprise_score, }) if include_creator: event_dict['creator'] = { 'id': event.creator.id if event.creator else None, 'username': event.creator.username if event.creator else 'Anonymous' } 
event_dict['keywords'] = event.keywords_list if hasattr(event, 'keywords_list') else event.keywords event_dict['related_industries'] = event.related_industries if include_related_data: pass events_data.append(event_dict) applied_filters = {} if event_type != 'all': applied_filters['type'] = event_type if importance != 'all': applied_filters['importance'] = importance if start_date: applied_filters['start_date'] = start_date if end_date: applied_filters['end_date'] = end_date if industry_code: applied_filters['industry_code'] = industry_code if tag: applied_filters['tag'] = tag if tags: applied_filters['tags'] = tags if search_query: applied_filters['search_query'] = search_query applied_filters['search_type'] = search_type # 构建响应数据 response_data = { 'success': True, 'data': { 'events': events_data, 'pagination': { 'page': paginated.page, 'per_page': paginated.per_page, 'total': paginated.total, 'pages': paginated.pages, 'has_prev': paginated.has_prev, 'has_next': paginated.has_next }, 'filters': { 'applied_filters': applied_filters, 'total_count': paginated.total } } } # ==================== 存入 Redis 缓存 ==================== set_events_cache(cache_key, response_data) return jsonify(response_data) except Exception as e: app.logger.error(f"获取事件列表出错: {str(e)}", exc_info=True) return jsonify({ 'success': False, 'error': str(e), 'error_type': type(e).__name__ }), 500 @app.route('/api/events/hot', methods=['GET']) def get_hot_events(): """获取热点事件""" try: from datetime import datetime, timedelta days = request.args.get('days', 3, type=int) limit = request.args.get('limit', 4, type=int) since_date = datetime.now() - timedelta(days=days) hot_events = Event.query.filter( Event.status == 'active', Event.created_at >= since_date, Event.related_avg_chg != None, Event.related_avg_chg > 0 ).order_by(Event.related_avg_chg.desc()).limit(limit).all() if len(hot_events) < limit: additional_events = Event.query.filter( Event.status == 'active', Event.created_at >= since_date, 
~Event.id.in_([event.id for event in hot_events]) ).order_by(Event.hot_score.desc()).limit(limit - len(hot_events)).all() hot_events.extend(additional_events) events_data = [] for event in hot_events: events_data.append({ 'id': event.id, 'title': event.title, 'description': event.description, 'importance': event.importance, 'created_at': event.created_at.isoformat() if event.created_at else None, 'related_avg_chg': event.related_avg_chg, 'related_max_chg': event.related_max_chg, 'expectation_surprise_score': event.expectation_surprise_score, 'creator': { 'username': event.creator.username if event.creator else 'Anonymous' } }) return jsonify({'success': True, 'data': events_data}) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events/mainline', methods=['GET']) def get_events_by_mainline(): """ 获取按主线(lv1/lv2/lv3概念)分组的事件列表 逻辑: 1. 根据筛选条件获取事件列表 2. 通过 related_concepts 表关联概念 3. 调用 concept-api/hierarchy 获取概念的层级归属 4. 按指定层级分组返回 参数: - recent_days: 近N天(默认7天) - importance: 重要性筛选(S,A,B,C 或 all) - group_by: 分组方式 (lv1/lv2/lv3/具体概念ID如L2_AI_INFRA),默认lv2 返回: { "success": true, "data": { "mainlines": [ { "lv2_id": "L2_AI_INFRA", "lv2_name": "AI基础设施 (算力/CPO/PCB)", "lv1_name": "TMT (科技/媒体/通信)", "event_count": 15, "events": [...] }, ... ], "total_events": 100, "ungrouped_count": 5, "group_by": "lv2", "hierarchy_options": {...} // 层级选项供前端下拉框使用 } } """ try: import requests from datetime import datetime, timedelta from sqlalchemy.orm import joinedload from sqlalchemy import exists # 获取请求参数 recent_days = request.args.get('recent_days', 7, type=int) importance = request.args.get('importance', 'all') group_by = request.args.get('group_by', 'lv2') # lv1/lv2/lv3 或具体ID # 计算日期范围 since_date = datetime.now() - timedelta(days=recent_days) # ==================== 1. 
获取概念层级映射 ==================== # 调用 concept-api 获取层级结构 concept_hierarchy_map = {} # { concept_name: { lv1, lv2, lv3, lv1_id, lv2_id, lv3_id } } hierarchy_options = {'lv1': [], 'lv2': [], 'lv3': []} # 层级选项供前端下拉框使用 try: # 从本地文件读取概念层级结构 import json import os hierarchy_file = os.path.join(os.path.dirname(__file__), 'concept_hierarchy_v3.json') with open(hierarchy_file, 'r', encoding='utf-8') as f: hierarchy_data = json.load(f) hierarchy_list = hierarchy_data.get('hierarchy', []) # 构建概念名称 -> 完整层级映射 + 层级选项 # 结构: L1 -> L2 -> L3 -> concepts (concepts 只在 L3 层) for lv1 in hierarchy_list: lv1_name = lv1.get('lv1', '') lv1_id = lv1.get('lv1_id', '') # 添加 lv1 选项 if lv1_id and lv1_name: hierarchy_options['lv1'].append({ 'id': lv1_id, 'name': lv1_name }) for lv2 in lv1.get('children', []) or []: lv2_name = lv2.get('lv2', '') lv2_id = lv2.get('lv2_id', '') # 添加 lv2 选项 if lv2_id and lv2_name: hierarchy_options['lv2'].append({ 'id': lv2_id, 'name': lv2_name, 'lv1_id': lv1_id, 'lv1_name': lv1_name }) # L3 层包含 concepts for lv3 in lv2.get('children', []) or []: lv3_name = lv3.get('lv3', '') lv3_id = lv3.get('lv3_id', '') # 添加 lv3 选项 if lv3_id and lv3_name: hierarchy_options['lv3'].append({ 'id': lv3_id, 'name': lv3_name, 'lv2_id': lv2_id, 'lv2_name': lv2_name, 'lv1_id': lv1_id, 'lv1_name': lv1_name }) for concept in lv3.get('concepts', []) or []: concept_name = concept if isinstance(concept, str) else concept.get('name', '') if concept_name: concept_hierarchy_map[concept_name] = { 'lv1': lv1_name, 'lv1_id': lv1_id, 'lv2': lv2_name, 'lv2_id': lv2_id, 'lv3': lv3_name, 'lv3_id': lv3_id } app.logger.info(f'[mainline] 加载概念层级映射: {len(concept_hierarchy_map)} 个概念, lv1: {len(hierarchy_options["lv1"])}, lv2: {len(hierarchy_options["lv2"])}, lv3: {len(hierarchy_options["lv3"])}') except Exception as e: app.logger.warning(f'[mainline] 获取概念层级失败: {e}') # ==================== 2. 
查询事件及其关联概念 ==================== query = Event.query.options(joinedload(Event.creator)) # 只返回有关联股票的事件 query = query.filter( exists().where(RelatedStock.event_id == Event.id) ) # 状态筛选 query = query.filter(Event.status == 'active') # 日期筛选 query = query.filter(Event.created_at >= since_date) # 重要性筛选 if importance != 'all': if ',' in importance: importance_list = [imp.strip() for imp in importance.split(',') if imp.strip()] query = query.filter(Event.importance.in_(importance_list)) else: query = query.filter(Event.importance == importance) # 按时间倒序 query = query.order_by(Event.created_at.desc()) # 获取事件(提高限制以支持主线模式显示更多数据) events = query.limit(2000).all() app.logger.info(f'[mainline] 查询到 {len(events)} 个事件') # ==================== 3. 获取事件的关联概念 ==================== event_ids = [e.id for e in events] # 批量查询 related_concepts related_concepts_query = db.session.query( RelatedConcepts.event_id, RelatedConcepts.concept ).filter(RelatedConcepts.event_id.in_(event_ids)).all() # 构建 event_id -> concepts 映射 event_concepts_map = {} # { event_id: [concept1, concept2, ...] } for event_id, concept in related_concepts_query: if event_id not in event_concepts_map: event_concepts_map[event_id] = [] event_concepts_map[event_id].append(concept) app.logger.warning(f'[mainline] 查询到 {len(related_concepts_query)} 条概念关联') # 调试:输出一些 related_concepts 的样本 sample_concepts = list(set([c for _, c in related_concepts_query[:100]]))[:20] app.logger.warning(f'[mainline] related_concepts 样本: {sample_concepts}') # 调试:输出一些 hierarchy 的样本 hierarchy_sample = list(concept_hierarchy_map.keys())[:20] app.logger.warning(f'[mainline] hierarchy 概念样本: {hierarchy_sample}') # ==================== 4. 按 lv2 分组事件 ==================== mainline_groups = {} # { lv2_id: { info: {...}, events: [...] } } ungrouped_events = [] def find_concept_hierarchy(concept_name): """查找概念的层级信息(支持多种匹配方式)""" if not concept_name: return None # 1. 精确匹配 if concept_name in concept_hierarchy_map: return concept_hierarchy_map[concept_name] # 2. 
去掉常见前缀后缀再匹配 # 例如 "消费电子-玄玑感知系统" -> "消费电子" concept_clean = concept_name.replace('-', ' ').replace('_', ' ').split()[0] if '-' in concept_name or '_' in concept_name else concept_name if concept_clean in concept_hierarchy_map: return concept_hierarchy_map[concept_clean] # 3. 包含匹配(双向) for key in concept_hierarchy_map: if concept_name in key or key in concept_name: return concept_hierarchy_map[key] # 4. 关键词匹配 - 提取关键词进行匹配 # 例如 "华为鸿蒙" 能匹配到包含 "华为" 或 "鸿蒙" 的 hierarchy keywords_to_check = ['华为', '鸿蒙', '特斯拉', '比亚迪', '英伟达', '苹果', '小米', 'AI', '机器人', '光伏', '储能', '锂电', '芯片', '半导体', '无人机', '低空', '汽车', '医药', '消费电子', '算力', 'GPU', '大模型', '智能体', 'DeepSeek', 'KIMI', '固态电池'] for kw in keywords_to_check: if kw in concept_name: # 找 hierarchy 中包含这个关键词的 for key in concept_hierarchy_map: if kw in key: return concept_hierarchy_map[key] return None # 判断分组方式 is_specific_id = group_by.startswith('L1_') or group_by.startswith('L2_') or group_by.startswith('L3_') for event in events: concepts = event_concepts_map.get(event.id, []) # 找出该事件所属的层级信息 event_groups = set() # 存储 (group_id, group_name, parent_info) 元组 for concept in concepts: hierarchy = find_concept_hierarchy(concept) if hierarchy: if is_specific_id: # 筛选特定概念ID if group_by.startswith('L1_') and hierarchy['lv1_id'] == group_by: event_groups.add((hierarchy['lv2_id'], hierarchy['lv2'], hierarchy['lv1'], hierarchy.get('lv3', ''), hierarchy.get('lv3_id', ''))) elif group_by.startswith('L2_') and hierarchy['lv2_id'] == group_by: event_groups.add((hierarchy.get('lv3_id', hierarchy['lv2_id']), hierarchy.get('lv3', hierarchy['lv2']), hierarchy['lv2'], hierarchy['lv1'], '')) elif group_by.startswith('L3_') and hierarchy.get('lv3_id') == group_by: event_groups.add((hierarchy['lv3_id'], hierarchy['lv3'], hierarchy['lv2'], hierarchy['lv1'], '')) elif group_by == 'lv1': event_groups.add((hierarchy['lv1_id'], hierarchy['lv1'], '', '', '')) elif group_by == 'lv3': if hierarchy.get('lv3_id'): event_groups.add((hierarchy['lv3_id'], hierarchy['lv3'], 
hierarchy['lv2'], hierarchy['lv1'], '')) else: # 默认 lv2 event_groups.add((hierarchy['lv2_id'], hierarchy['lv2'], hierarchy['lv1'], '', '')) # 事件数据 event_data = { 'id': event.id, 'title': event.title, 'description': event.description, 'importance': event.importance, 'created_at': event.created_at.isoformat() if event.created_at else None, 'related_avg_chg': event.related_avg_chg, 'related_max_chg': event.related_max_chg, 'expectation_surprise_score': event.expectation_surprise_score, 'hot_score': event.hot_score, 'related_concepts': [{'concept': c} for c in concepts], 'creator': { 'username': event.creator.username if event.creator else 'Anonymous' } } if event_groups: # 添加到每个相关的分组 for group_info in event_groups: group_id = group_info[0] group_name = group_info[1] parent_name = group_info[2] if len(group_info) > 2 else '' grandparent_name = group_info[3] if len(group_info) > 3 else '' if group_id not in mainline_groups: mainline_groups[group_id] = { 'group_id': group_id, 'group_name': group_name, 'parent_name': parent_name, 'grandparent_name': grandparent_name, # 兼容旧字段 'lv2_id': group_id if group_by == 'lv2' or group_by.startswith('L1_') else None, 'lv2_name': group_name if group_by == 'lv2' or group_by.startswith('L1_') else parent_name, 'lv1_name': parent_name if group_by == 'lv2' else grandparent_name, 'lv3_name': group_name if group_by == 'lv3' or group_by.startswith('L2_') else None, 'events': [] } mainline_groups[group_id]['events'].append(event_data) else: ungrouped_events.append(event_data) # ==================== 5. 
获取 lv2 概念涨跌幅 ==================== lv2_price_map = {} try: # 获取所有 lv2 名称 lv2_names = [group['lv2_name'] for group in mainline_groups.values() if group.get('lv2_name')] if lv2_names: # 数据库中的 concept_name 带有 "[二级] " 前缀,需要添加前缀来匹配 lv2_names_with_prefix = [f'[二级] {name}' for name in lv2_names] # 查询 concept_daily_stats 表获取最新涨跌幅 price_sql = text(''' SELECT concept_name, avg_change_pct, trade_date FROM concept_daily_stats WHERE concept_type = 'lv2' AND concept_name IN :names AND trade_date = ( SELECT MAX(trade_date) FROM concept_daily_stats WHERE concept_type = 'lv2' ) ''') price_result = db.session.execute(price_sql, {'names': tuple(lv2_names_with_prefix)}).fetchall() for row in price_result: # 去掉 "[二级] " 前缀,用原始名称作为 key original_name = row.concept_name.replace('[二级] ', '') if row.concept_name else '' lv2_price_map[original_name] = { 'avg_change_pct': float(row.avg_change_pct) if row.avg_change_pct else None, 'trade_date': str(row.trade_date) if row.trade_date else None } app.logger.info(f'[mainline] 获取 lv2 涨跌幅: {len(lv2_price_map)} 条, lv2_names 数量: {len(lv2_names)}') except Exception as price_err: app.logger.warning(f'[mainline] 获取 lv2 涨跌幅失败: {price_err}') # ==================== 6. 
整理返回数据 ==================== mainlines = [] for group_id, group in mainline_groups.items(): # 按时间倒序排列(不限制数量) group['events'] = sorted( group['events'], key=lambda x: x['created_at'] or '', reverse=True ) group['event_count'] = len(group['events']) # 添加涨跌幅数据(目前只支持 lv2) lv2_name = group.get('lv2_name', '') or group.get('group_name', '') if lv2_name in lv2_price_map: group['avg_change_pct'] = lv2_price_map[lv2_name]['avg_change_pct'] group['price_date'] = lv2_price_map[lv2_name]['trade_date'] else: group['avg_change_pct'] = None group['price_date'] = None mainlines.append(group) # 按事件数量排序 mainlines.sort(key=lambda x: x['event_count'], reverse=True) return jsonify({ 'success': True, 'data': { 'mainlines': mainlines, 'total_events': len(events), 'mainline_count': len(mainlines), 'ungrouped_count': len(ungrouped_events), 'group_by': group_by, 'hierarchy_options': hierarchy_options, # 调试信息 '_debug': { 'hierarchy_count': len(concept_hierarchy_map), 'hierarchy_sample': hierarchy_sample[:10], 'related_concepts_sample': sample_concepts[:10], 'related_concepts_count': len(related_concepts_query) } } }) except Exception as e: app.logger.error(f'[mainline] 错误: {e}', exc_info=True) return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/events/keywords/popular', methods=['GET']) def get_popular_keywords(): """获取热门关键词""" try: limit = request.args.get('limit', 20, type=int) sql = ''' WITH RECURSIVE \ numbers AS (SELECT 0 as n \ UNION ALL \ SELECT n + 1 \ FROM numbers \ WHERE n < 100), \ json_array AS (SELECT JSON_UNQUOTE(JSON_EXTRACT(e.keywords, CONCAT('$[', n.n, ']'))) as keyword, \ COUNT(*) as count FROM event e CROSS JOIN numbers n WHERE e.status = 'active' AND JSON_EXTRACT(e.keywords \ , CONCAT('$[' \ , n.n \ , ']')) IS NOT NULL GROUP BY JSON_UNQUOTE(JSON_EXTRACT(e.keywords, CONCAT('$[', n.n, ']'))) HAVING keyword IS NOT NULL ) SELECT keyword, count FROM json_array ORDER BY count DESC, keyword LIMIT :limit \ ''' result = db.session.execute(text(sql), 
{'limit': limit}).all() keywords_data = [{'keyword': row.keyword, 'count': row.count} for row in result] return jsonify({'success': True, 'data': keywords_data}) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//sankey-data') def get_event_sankey_data(event_id): """ 获取事件桑基图数据 (最终优化版) - 处理重名节点 - 检测并打破循环依赖 """ flows = EventSankeyFlow.query.filter_by(event_id=event_id).order_by( EventSankeyFlow.source_level, EventSankeyFlow.target_level ).all() if not flows: return jsonify({'success': False, 'message': '暂无桑基图数据'}) nodes_map = {} links = [] type_colors = { 'event': '#ff4757', 'policy': '#10ac84', 'technology': '#ee5a6f', 'industry': '#00d2d3', 'company': '#54a0ff', 'product': '#ffd93d' } # --- 1. 识别并处理重名节点 (与上一版相同) --- all_node_keys = set() name_counts = {} for flow in flows: source_key = f"{flow.source_node}|{flow.source_level}" target_key = f"{flow.target_node}|{flow.target_level}" all_node_keys.add(source_key) all_node_keys.add(target_key) name_counts.setdefault(flow.source_node, set()).add(flow.source_level) name_counts.setdefault(flow.target_node, set()).add(flow.target_level) duplicate_names = {name for name, levels in name_counts.items() if len(levels) > 1} for flow in flows: source_key = f"{flow.source_node}|{flow.source_level}" if source_key not in nodes_map: display_name = f"{flow.source_node} (L{flow.source_level})" if flow.source_node in duplicate_names else flow.source_node nodes_map[source_key] = {'name': display_name, 'type': flow.source_type, 'level': flow.source_level, 'color': type_colors.get(flow.source_type)} target_key = f"{flow.target_node}|{flow.target_level}" if target_key not in nodes_map: display_name = f"{flow.target_node} (L{flow.target_level})" if flow.target_node in duplicate_names else flow.target_node nodes_map[target_key] = {'name': display_name, 'type': flow.target_type, 'level': flow.target_level, 'color': type_colors.get(flow.target_type)} links.append({ 'source_key': source_key, 
'target_key': target_key, 'value': float(flow.flow_value), 'ratio': float(flow.flow_ratio), 'transmission_path': flow.transmission_path, 'impact_description': flow.impact_description, 'evidence_strength': flow.evidence_strength }) # --- 2. 循环检测与处理 --- # 构建邻接表 adj = defaultdict(list) for link in links: adj[link['source_key']].append(link['target_key']) # 深度优先搜索(DFS)来检测循环 path = set() # 记录当前递归路径上的节点 visited = set() # 记录所有访问过的节点 back_edges = set() # 记录导致循环的"回流边" def detect_cycle_util(node): path.add(node) visited.add(node) for neighbour in adj.get(node, []): if neighbour in path: # 发现了循环,记录这条回流边 (target, source) back_edges.add((neighbour, node)) elif neighbour not in visited: detect_cycle_util(neighbour) path.remove(node) # 从所有节点开始检测 for node_key in list(adj.keys()): if node_key not in visited: detect_cycle_util(node_key) # 过滤掉导致循环的边 if back_edges: print(f"检测到并移除了 {len(back_edges)} 条循环边: {back_edges}") valid_links_no_cycle = [] for link in links: if (link['source_key'], link['target_key']) not in back_edges and \ (link['target_key'], link['source_key']) not in back_edges: # 移除非严格意义上的双向边 valid_links_no_cycle.append(link) # --- 3. 构建最终的 JSON 响应 (与上一版相似) --- node_list = [] node_index_map = {} sorted_node_keys = sorted(nodes_map.keys(), key=lambda k: (nodes_map[k]['level'], nodes_map[k]['name'])) for i, key in enumerate(sorted_node_keys): node_list.append(nodes_map[key]) node_index_map[key] = i final_links = [] for link in valid_links_no_cycle: source_idx = node_index_map.get(link['source_key']) target_idx = node_index_map.get(link['target_key']) if source_idx is not None and target_idx is not None: # 移除临时的 key,只保留 ECharts 需要的字段 link.pop('source_key', None) link.pop('target_key', None) link['source'] = source_idx link['target'] = target_idx final_links.append(link) # ... (统计信息计算部分保持不变) ... 
stats = { 'total_nodes': len(node_list), 'total_flows': len(final_links), 'total_flow_value': sum(link['value'] for link in final_links), 'max_level': max((node['level'] for node in node_list), default=0), 'node_type_counts': {ntype: sum(1 for n in node_list if n['type'] == ntype) for ntype in type_colors} } return jsonify({ 'success': True, 'data': {'nodes': node_list, 'links': final_links, 'stats': stats} }) # 优化后的传导链分析 API @app.route('/api/events//chain-analysis') def get_event_chain_analysis(event_id): """获取事件传导链分析数据""" nodes = EventTransmissionNode.query.filter_by(event_id=event_id).all() if not nodes: return jsonify({'success': False, 'message': '暂无传导链分析数据'}) edges = EventTransmissionEdge.query.filter_by(event_id=event_id).all() # 过滤孤立节点 connected_node_ids = set() for edge in edges: connected_node_ids.add(edge.from_node_id) connected_node_ids.add(edge.to_node_id) # 只保留有连接的节点 connected_nodes = [node for node in nodes if node.id in connected_node_ids] if not connected_nodes: return jsonify({'success': False, 'message': '所有节点都是孤立的,暂无传导关系'}) # 节点分类,用于力导向图的图例 categories = { 'event': "事件", 'industry': "行业", 'company': "公司", 'policy': "政策", 'technology': "技术", 'market': "市场", 'other': "其他" } # 计算每个节点的连接数 node_connection_count = {} for node in connected_nodes: count = sum(1 for edge in edges if edge.from_node_id == node.id or edge.to_node_id == node.id) node_connection_count[node.id] = count nodes_data = [] for node in connected_nodes: connection_count = node_connection_count[node.id] nodes_data.append({ 'id': str(node.id), 'name': node.node_name, 'value': node.importance_score, # 用于控制节点大小的基础值 'category': categories.get(node.node_type), 'extra': { 'node_type': node.node_type, 'description': node.node_description, 'importance_score': node.importance_score, 'stock_code': node.stock_code, 'is_main_event': node.is_main_event, 'connection_count': connection_count, # 添加连接数信息 } }) edges_data = [] for edge in edges: # 确保边的两端节点都在连接节点列表中 if edge.from_node_id in 
connected_node_ids and edge.to_node_id in connected_node_ids: edges_data.append({ 'source': str(edge.from_node_id), 'target': str(edge.to_node_id), 'value': edge.strength, # 用于控制边的宽度 'extra': { 'transmission_type': edge.transmission_type, 'transmission_mechanism': edge.transmission_mechanism, 'direction': edge.direction, 'strength': edge.strength, 'impact': edge.impact, 'is_circular': edge.is_circular, } }) # 重新计算统计信息(基于连接的节点和边) stats = { 'total_nodes': len(connected_nodes), 'total_edges': len(edges_data), 'node_types': {cat: sum(1 for n in connected_nodes if n.node_type == node_type) for node_type, cat in categories.items()}, 'edge_types': {edge.transmission_type: sum(1 for e in edges_data if e['extra']['transmission_type'] == edge.transmission_type) for edge in edges}, 'avg_importance': sum(node.importance_score for node in connected_nodes) / len( connected_nodes) if connected_nodes else 0, 'avg_strength': sum(edge.strength for edge in edges) / len(edges) if edges else 0 } return jsonify({ 'success': True, 'data': { 'nodes': nodes_data, 'edges': edges_data, 'categories': list(categories.values()), 'stats': stats } }) @app.route('/api/events//chain-node/', methods=['GET']) @cross_origin() def get_chain_node_detail(event_id, node_id): """获取传导链节点及其直接关联节点的详细信息""" node = db.session.get(EventTransmissionNode, node_id) if not node or node.event_id != event_id: return jsonify({'success': False, 'message': '节点不存在'}) # 验证节点是否为孤立节点 total_connections = (EventTransmissionEdge.query.filter_by(from_node_id=node_id).count() + EventTransmissionEdge.query.filter_by(to_node_id=node_id).count()) if total_connections == 0 and not node.is_main_event: return jsonify({'success': False, 'message': '该节点为孤立节点,无连接关系'}) # 找出影响当前节点的父节点 parents_info = [] incoming_edges = EventTransmissionEdge.query.filter_by(to_node_id=node_id).all() for edge in incoming_edges: parent = db.session.get(EventTransmissionNode, edge.from_node_id) if parent: parents_info.append({ 'id': parent.id, 'name': 
parent.node_name, 'type': parent.node_type, 'direction': edge.direction, 'strength': edge.strength, 'transmission_type': edge.transmission_type, 'transmission_mechanism': edge.transmission_mechanism, # 修复字段名 'is_circular': edge.is_circular, 'impact': edge.impact }) # 找出被当前节点影响的子节点 children_info = [] outgoing_edges = EventTransmissionEdge.query.filter_by(from_node_id=node_id).all() for edge in outgoing_edges: child = db.session.get(EventTransmissionNode, edge.to_node_id) if child: children_info.append({ 'id': child.id, 'name': child.node_name, 'type': child.node_type, 'direction': edge.direction, 'strength': edge.strength, 'transmission_type': edge.transmission_type, 'transmission_mechanism': edge.transmission_mechanism, # 修复字段名 'is_circular': edge.is_circular, 'impact': edge.impact }) node_data = { 'id': node.id, 'name': node.node_name, 'type': node.node_type, 'description': node.node_description, 'importance_score': node.importance_score, 'stock_code': node.stock_code, 'is_main_event': node.is_main_event, 'total_connections': total_connections, 'incoming_connections': len(incoming_edges), 'outgoing_connections': len(outgoing_edges) } return jsonify({ 'success': True, 'data': { 'node': node_data, 'parents': parents_info, 'children': children_info } }) @app.route('/api/events//posts', methods=['GET']) def get_event_posts(event_id): """获取事件下的帖子""" try: sort_type = request.args.get('sort', 'latest') page = request.args.get('page', 1, type=int) per_page = request.args.get('per_page', 20, type=int) # 查询事件下的帖子 query = Post.query.filter_by(event_id=event_id, status='active') if sort_type == 'hot': query = query.order_by(Post.likes_count.desc(), Post.created_at.desc()) else: # latest query = query.order_by(Post.created_at.desc()) # 分页 pagination = query.paginate(page=page, per_page=per_page, error_out=False) posts = pagination.items posts_data = [] for post in posts: post_dict = { 'id': post.id, 'event_id': post.event_id, 'user_id': post.user_id, 'title': post.title, 
'content': post.content, 'content_type': post.content_type, 'created_at': post.created_at.isoformat(), 'updated_at': post.updated_at.isoformat(), 'likes_count': post.likes_count, 'comments_count': post.comments_count, 'view_count': post.view_count, 'is_top': post.is_top, 'user': { 'id': post.user.id, 'username': post.user.username, 'avatar_url': post.user.avatar_url } if post.user else None, 'liked': False # 后续可以根据当前用户判断 } posts_data.append(post_dict) return jsonify({ 'success': True, 'data': posts_data, 'pagination': { 'page': page, 'per_page': per_page, 'total': pagination.total, 'pages': pagination.pages } }) except Exception as e: print(f"获取帖子失败: {e}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/events//posts', methods=['POST']) @login_required def create_event_post(event_id): """在事件下创建帖子""" try: data = request.get_json() content = data.get('content', '').strip() title = data.get('title', '').strip() content_type = data.get('content_type', 'text') if not content: return jsonify({ 'success': False, 'message': '帖子内容不能为空' }), 400 # 创建新帖子 post = Post( event_id=event_id, user_id=current_user.id, title=title, content=content, content_type=content_type ) db.session.add(post) # 更新事件的帖子数 event = Event.query.get(event_id) if event: event.post_count = Post.query.filter_by(event_id=event_id, status='active').count() # 更新用户发帖数 current_user.post_count = (current_user.post_count or 0) + 1 db.session.commit() return jsonify({ 'success': True, 'data': { 'id': post.id, 'event_id': post.event_id, 'user_id': post.user_id, 'title': post.title, 'content': post.content, 'content_type': post.content_type, 'created_at': post.created_at.isoformat(), 'user': { 'id': current_user.id, 'nickname': current_user.nickname, # 添加昵称,与导航区保持一致 'username': current_user.username, 'avatar_url': current_user.avatar_url } }, 'message': '帖子发布成功' }) except Exception as e: db.session.rollback() print(f"创建帖子失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 
500 @app.route('/api/posts//comments', methods=['GET']) def get_post_comments(post_id): """获取帖子的评论""" try: sort_type = request.args.get('sort', 'latest') # 查询帖子的顶级评论(非回复) query = Comment.query.filter_by(post_id=post_id, parent_id=None, status='active') if sort_type == 'hot': comments = query.order_by(Comment.likes_count.desc(), Comment.created_at.desc()).all() else: # latest comments = query.order_by(Comment.created_at.desc()).all() comments_data = [] for comment in comments: comment_dict = { 'id': comment.id, 'post_id': comment.post_id, 'user_id': comment.user_id, 'content': comment.content, 'created_at': comment.created_at.isoformat(), 'updated_at': comment.updated_at.isoformat(), 'likes_count': comment.likes_count, 'user': { 'id': comment.user.id, 'username': comment.user.username, 'avatar_url': comment.user.avatar_url } if comment.user else None, 'replies': [] # 加载回复 } # 加载回复 replies = Comment.query.filter_by(parent_id=comment.id, status='active').order_by(Comment.created_at).all() for reply in replies: reply_dict = { 'id': reply.id, 'post_id': reply.post_id, 'user_id': reply.user_id, 'content': reply.content, 'parent_id': reply.parent_id, 'created_at': reply.created_at.isoformat(), 'likes_count': reply.likes_count, 'user': { 'id': reply.user.id, 'username': reply.user.username, 'avatar_url': reply.user.avatar_url } if reply.user else None } comment_dict['replies'].append(reply_dict) comments_data.append(comment_dict) return jsonify({ 'success': True, 'data': comments_data }) except Exception as e: print(f"获取评论失败: {e}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/posts//comments', methods=['POST']) @login_required def create_post_comment(post_id): """在帖子下创建评论""" try: data = request.get_json() content = data.get('content', '').strip() parent_id = data.get('parent_id') if not content: return jsonify({ 'success': False, 'message': '评论内容不能为空' }), 400 # 创建新评论 comment = Comment( post_id=post_id, user_id=current_user.id, 
content=content, parent_id=parent_id ) db.session.add(comment) # 更新帖子评论数 post = Post.query.get(post_id) if post: post.comments_count = Comment.query.filter_by(post_id=post_id, status='active').count() # 更新用户评论数 current_user.comment_count = (current_user.comment_count or 0) + 1 db.session.commit() return jsonify({ 'success': True, 'data': { 'id': comment.id, 'post_id': comment.post_id, 'user_id': comment.user_id, 'content': comment.content, 'parent_id': comment.parent_id, 'created_at': comment.created_at.isoformat(), 'user': { 'id': current_user.id, 'username': current_user.username, 'avatar_url': current_user.avatar_url } }, 'message': '评论发布成功' }) except Exception as e: db.session.rollback() print(f"创建评论失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 # 兼容旧的评论接口,转换为帖子模式 @app.route('/api/events//comments', methods=['GET']) def get_event_comments(event_id): """获取事件评论(兼容旧接口)""" # 将事件评论转换为获取事件下所有帖子的评论 return get_event_posts(event_id) @app.route('/api/events//comments', methods=['POST']) @login_required def add_event_comment(event_id): """添加事件评论(兼容旧接口)""" try: data = request.get_json() content = data.get('content', '').strip() parent_id = data.get('parent_id') if not content: return jsonify({ 'success': False, 'message': '评论内容不能为空' }), 400 # 如果有 parent_id,说明是回复,需要找到对应的帖子 if parent_id: # 这是一个回复,需要将其转换为对应帖子的评论 # 首先需要找到 parent_id 对应的帖子 # 这里假设旧的 parent_id 是之前的 EventComment id # 需要在数据迁移时处理这个映射关系 return jsonify({ 'success': False, 'message': '回复功能正在升级中,请稍后再试' }), 503 # 如果没有 parent_id,说明是顶级评论,创建为新帖子 post = Post( event_id=event_id, user_id=current_user.id, content=content, content_type='text' ) db.session.add(post) # 更新事件的帖子数 event = Event.query.get(event_id) if event: event.post_count = Post.query.filter_by(event_id=event_id, status='active').count() # 更新用户发帖数 current_user.post_count = (current_user.post_count or 0) + 1 db.session.commit() # 返回兼容旧接口的数据格式 return jsonify({ 'success': True, 'data': { 'id': post.id, 'event_id': post.event_id, 'user_id': 
post.user_id, 'author': current_user.username, 'content': post.content, 'parent_id': None, 'likes': 0, 'created_at': post.created_at.isoformat(), 'status': 'active', 'user': { 'id': current_user.id, 'username': current_user.username, 'avatar_url': current_user.avatar_url }, 'replies': [] }, 'message': '评论发布成功' }) except Exception as e: db.session.rollback() print(f"添加事件评论失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 # ==================== WebSocket 事件处理器(实时事件推送) ==================== @socketio.on('connect') def handle_connect(): """客户端连接事件""" print(f'\n[WebSocket DEBUG] ========== 客户端连接 ==========') print(f'[WebSocket DEBUG] Socket ID: {request.sid}') print(f'[WebSocket DEBUG] Remote Address: {request.remote_addr if hasattr(request, "remote_addr") else "N/A"}') print(f'[WebSocket] 客户端已连接: {request.sid}') emit('connection_response', { 'status': 'connected', 'sid': request.sid, 'message': '已连接到事件推送服务' }) print(f'[WebSocket DEBUG] ✓ 已发送 connection_response') print(f'[WebSocket DEBUG] ========== 连接完成 ==========\n') @socketio.on('subscribe_events') def handle_subscribe(data): """ 客户端订阅事件推送 data: { 'event_type': 'all' | 'policy' | 'market' | 'tech' | ..., 'importance': 'all' | 'S' | 'A' | 'B' | 'C', 'filters': {...} # 可选的其他筛选条件 } """ try: print(f'\n[WebSocket DEBUG] ========== 收到订阅请求 ==========') print(f'[WebSocket DEBUG] Socket ID: {request.sid}') print(f'[WebSocket DEBUG] 订阅数据: {data}') event_type = data.get('event_type', 'all') importance = data.get('importance', 'all') print(f'[WebSocket DEBUG] 事件类型: {event_type}') print(f'[WebSocket DEBUG] 重要性: {importance}') # 加入对应的房间 room_name = f"events_{event_type}" print(f'[WebSocket DEBUG] 准备加入房间: {room_name}') join_room(room_name) print(f'[WebSocket DEBUG] ✓ 已加入房间: {room_name}') print(f'[WebSocket] 客户端 {request.sid} 订阅了房间: {room_name}') response_data = { 'success': True, 'room': room_name, 'event_type': event_type, 'importance': importance, 'message': f'已订阅 {event_type} 类型的事件推送' } print(f'[WebSocket 
DEBUG] 准备发送 subscription_confirmed: {response_data}') emit('subscription_confirmed', response_data) print(f'[WebSocket DEBUG] ✓ 已发送 subscription_confirmed') print(f'[WebSocket DEBUG] ========== 订阅完成 ==========\n') except Exception as e: print(f'[WebSocket ERROR] 订阅失败: {e}') import traceback traceback.print_exc() emit('subscription_error', { 'success': False, 'error': str(e) }) @socketio.on('unsubscribe_events') def handle_unsubscribe(data): """取消订阅事件推送""" try: print(f'\n[WebSocket DEBUG] ========== 收到取消订阅请求 ==========') print(f'[WebSocket DEBUG] Socket ID: {request.sid}') print(f'[WebSocket DEBUG] 数据: {data}') event_type = data.get('event_type', 'all') room_name = f"events_{event_type}" print(f'[WebSocket DEBUG] 准备离开房间: {room_name}') leave_room(room_name) print(f'[WebSocket DEBUG] ✓ 已离开房间: {room_name}') print(f'[WebSocket] 客户端 {request.sid} 取消订阅房间: {room_name}') emit('unsubscription_confirmed', { 'success': True, 'room': room_name, 'message': f'已取消订阅 {event_type} 类型的事件推送' }) print(f'[WebSocket DEBUG] ========== 取消订阅完成 ==========\n') except Exception as e: print(f'[WebSocket ERROR] 取消订阅失败: {e}') import traceback traceback.print_exc() emit('unsubscription_error', { 'success': False, 'error': str(e) }) @socketio.on('disconnect') def handle_disconnect(): """客户端断开连接事件""" print(f'\n[WebSocket DEBUG] ========== 客户端断开 ==========') print(f'[WebSocket DEBUG] Socket ID: {request.sid}') print(f'[WebSocket] 客户端已断开: {request.sid}') print(f'[WebSocket DEBUG] ========== 断开完成 ==========\n') # ==================== WebSocket 辅助函数 ==================== def broadcast_new_event(event): """ 广播新事件到所有订阅的客户端 在创建新事件时调用此函数 Args: event: Event 模型实例 """ try: print(f'\n[WebSocket DEBUG] ========== 广播新事件 ==========') print(f'[WebSocket DEBUG] 事件ID: {event.id}') print(f'[WebSocket DEBUG] 事件标题: {event.title}') print(f'[WebSocket DEBUG] 事件类型: {event.event_type}') print(f'[WebSocket DEBUG] 重要性: {event.importance}') event_data = { 'id': event.id, 'title': event.title, 'description': event.description, 
'event_type': event.event_type, 'importance': event.importance, 'status': event.status, 'created_at': event.created_at.isoformat() if event.created_at else None, 'hot_score': event.hot_score, 'view_count': event.view_count, 'related_avg_chg': event.related_avg_chg, 'related_max_chg': event.related_max_chg, 'keywords': event.keywords_list if hasattr(event, 'keywords_list') else event.keywords, } print(f'[WebSocket DEBUG] 准备发送的数据: {event_data}') # 发送到所有订阅者(all 房间) print(f'[WebSocket DEBUG] 正在发送到房间: events_all') socketio.emit('new_event', event_data, room='events_all', namespace='/') print(f'[WebSocket DEBUG] ✓ 已发送到 events_all') # 发送到特定类型订阅者 if event.event_type: room_name = f"events_{event.event_type}" print(f'[WebSocket DEBUG] 正在发送到房间: {room_name}') socketio.emit('new_event', event_data, room=room_name, namespace='/') print(f'[WebSocket DEBUG] ✓ 已发送到 {room_name}') print(f'[WebSocket] 已推送新事件到房间: events_all, {room_name}') else: print(f'[WebSocket] 已推送新事件到房间: events_all') # 清除事件列表缓存,确保用户刷新页面时获取最新数据 clear_events_cache() print(f'[WebSocket DEBUG] ========== 广播完成 ==========\n') except Exception as e: print(f'[WebSocket ERROR] 推送新事件失败: {e}') import traceback traceback.print_exc() # ==================== WebSocket 轮询机制(检测新事件) ==================== # Redis Key 用于多 Worker 协调 REDIS_KEY_LAST_MAX_EVENT_ID = 'vf:event_polling:last_max_id' REDIS_KEY_POLLING_LOCK = 'vf:event_polling:lock' REDIS_KEY_PENDING_EVENTS = 'vf:event_polling:pending_events' # 待推送事件集合(没有 related_stocks 的事件) # 本地缓存(减少 Redis 查询) _local_last_max_event_id = 0 _polling_initialized = False def _add_pending_event(event_id): """将事件添加到待推送列表""" try: redis_client.sadd(REDIS_KEY_PENDING_EVENTS, str(event_id)) except Exception as e: print(f'[轮询 WARN] 添加待推送事件失败: {e}') def _remove_pending_event(event_id): """从待推送列表移除事件""" try: redis_client.srem(REDIS_KEY_PENDING_EVENTS, str(event_id)) except Exception as e: print(f'[轮询 WARN] 移除待推送事件失败: {e}') def _get_pending_events(): """获取所有待推送事件ID""" try: pending = 
redis_client.smembers(REDIS_KEY_PENDING_EVENTS) return [int(eid) for eid in pending] if pending else [] except Exception as e: print(f'[轮询 WARN] 获取待推送事件失败: {e}') return [] def _get_last_max_event_id(): """从 Redis 获取最大事件 ID""" try: val = redis_client.get(REDIS_KEY_LAST_MAX_EVENT_ID) return int(val) if val else 0 except Exception as e: print(f'[轮询 WARN] 读取 Redis 失败: {e}') return _local_last_max_event_id def _set_last_max_event_id(new_id): """设置最大事件 ID 到 Redis""" global _local_last_max_event_id try: redis_client.set(REDIS_KEY_LAST_MAX_EVENT_ID, str(new_id)) _local_last_max_event_id = new_id except Exception as e: print(f'[轮询 WARN] 写入 Redis 失败: {e}') _local_last_max_event_id = new_id def poll_new_events(): """ 定期轮询数据库,检查是否有新事件 每 30 秒执行一次 多 Worker 协调机制: 1. 使用 Redis 分布式锁,确保同一时刻只有一个 Worker 执行轮询 2. 使用 Redis 存储 last_max_event_id,所有 Worker 共享状态 3. 通过 Redis 消息队列广播到所有 Worker 的客户端 待推送事件机制: - 当事件首次被检测到但没有 related_stocks 时,加入待推送列表 - 每次轮询时检查待推送列表中的事件是否已有 related_stocks - 有则推送并从列表移除,超过24小时的事件自动清理 """ import os try: # 尝试获取分布式锁(30秒超时,防止死锁) lock_acquired = redis_client.set( REDIS_KEY_POLLING_LOCK, os.getpid(), nx=True, # 只在不存在时设置 ex=30 # 30秒后自动过期 ) if not lock_acquired: # 其他 Worker 正在轮询,跳过本次 return with app.app_context(): from datetime import datetime, timedelta current_time = datetime.now() last_max_event_id = _get_last_max_event_id() print(f'\n[轮询] ========== 开始轮询 (PID: {os.getpid()}) ==========') print(f'[轮询] 当前时间: {current_time.strftime("%Y-%m-%d %H:%M:%S")}') print(f'[轮询] 当前最大事件ID: {last_max_event_id}') # 查询近24小时内的所有活跃事件(按事件发生时间 created_at) time_24h_ago = current_time - timedelta(hours=24) # 查询所有近24小时内的活跃事件 events_in_24h = Event.query.filter( Event.created_at >= time_24h_ago, Event.status == 'active' ).order_by(Event.id.asc()).all() print(f'[轮询] 数据库查询: 找到 {len(events_in_24h)} 个近24小时内的事件') # 创建事件ID到事件对象的映射 events_map = {event.id: event for event in events_in_24h} # === 步骤1: 检查待推送列表中的事件 === pending_event_ids = _get_pending_events() print(f'[轮询] 待推送列表: {len(pending_event_ids)} 个事件') 
pushed_from_pending = 0 for pending_id in pending_event_ids: if pending_id in events_map: event = events_map[pending_id] related_stocks_count = event.related_stocks.count() if related_stocks_count > 0: # 事件现在有 related_stocks 了,推送它 broadcast_new_event(event) _remove_pending_event(pending_id) pushed_from_pending += 1 print(f'[轮询] ✓ 待推送事件 ID={pending_id} 现在有 {related_stocks_count} 个关联股票,已推送') else: print(f'[轮询] - 待推送事件 ID={pending_id} 仍无关联股票,继续等待') else: # 事件已超过24小时或已删除,从待推送列表移除 _remove_pending_event(pending_id) print(f'[轮询] × 待推送事件 ID={pending_id} 已过期或不存在,已移除') if pushed_from_pending > 0: print(f'[轮询] 从待推送列表推送了 {pushed_from_pending} 个事件') # === 步骤2: 检查新事件 === # 找出新插入的事件(ID > last_max_event_id) new_events = [ event for event in events_in_24h if event.id > last_max_event_id ] print(f'[轮询] 新事件数量(ID > {last_max_event_id}): {len(new_events)} 个') if new_events: print(f'[轮询] 发现 {len(new_events)} 个新事件') pushed_count = 0 pending_count = 0 for event in new_events: # 检查事件是否有关联股票(只推送有关联股票的事件) related_stocks_count = event.related_stocks.count() print(f'[轮询] 事件 ID={event.id}: {event.title} (关联股票: {related_stocks_count})') # 只推送有关联股票的事件 if related_stocks_count > 0: broadcast_new_event(event) pushed_count += 1 print(f'[轮询] ✓ 已推送事件 ID={event.id}') else: # 没有关联股票,加入待推送列表 _add_pending_event(event.id) pending_count += 1 print(f'[轮询] → 加入待推送列表(暂无关联股票)') print(f'[轮询] 本轮: 推送 {pushed_count} 个, 加入待推送 {pending_count} 个') # 更新最大事件ID new_max_id = max(event.id for event in events_in_24h) _set_last_max_event_id(new_max_id) print(f'[轮询] 更新最大事件ID: {last_max_event_id} -> {new_max_id}') else: # 即使没有新事件,也要更新最大ID(防止状态不同步) if events_in_24h: current_max_id = max(event.id for event in events_in_24h) if current_max_id != last_max_event_id: _set_last_max_event_id(current_max_id) print(f'[轮询] ========== 轮询结束 ==========\n') except Exception as e: print(f'[轮询 ERROR] 检查新事件时出错: {e}') import traceback traceback.print_exc() finally: # 释放锁 try: redis_client.delete(REDIS_KEY_POLLING_LOCK) except: pass def 
initialize_event_polling(): """ 初始化事件轮询机制 在应用启动时调用(支持 gunicorn 多 Worker) """ global _polling_initialized # 防止重复初始化 if _polling_initialized: print('[轮询] 已经初始化过,跳过') return try: from datetime import datetime, timedelta import os with app.app_context(): current_time = datetime.now() time_24h_ago = current_time - timedelta(hours=24) print(f'\n[轮询] ========== 初始化事件轮询 (PID: {os.getpid()}) ==========') print(f'[轮询] 当前时间: {current_time.strftime("%Y-%m-%d %H:%M:%S")}') # 查询数据库中最大的事件 ID(不限于 24 小时) max_event = Event.query.filter_by(status='active').order_by(Event.id.desc()).first() db_max_id = max_event.id if max_event else 0 # 获取 Redis 中当前保存的最大 ID current_redis_max = _get_last_max_event_id() print(f'[轮询] 数据库最大事件ID: {db_max_id}') print(f'[轮询] Redis 中的最大事件ID: {current_redis_max}') # 始终使用数据库中的最大 ID(避免推送历史事件) # 只在 Redis 值为 0 或小于数据库最大值时更新 if current_redis_max == 0 or db_max_id > current_redis_max: _set_last_max_event_id(db_max_id) print(f'[轮询] 更新最大事件ID为: {db_max_id}(避免推送历史事件)') else: print(f'[轮询] 保持 Redis 中的最大事件ID: {current_redis_max}') # 统计数据库中的事件总数 total_events = Event.query.filter_by(status='active').count() events_in_24h_count = Event.query.filter( Event.created_at >= time_24h_ago, Event.status == 'active' ).count() print(f'[轮询] 数据库中共有 {total_events} 个活跃事件(其中近24小时: {events_in_24h_count} 个)') print(f'[轮询] 只会推送 ID > {max(current_redis_max, db_max_id)} 的新事件') print(f'[轮询] ========== 初始化完成 ==========\n') # 检测是否在 eventlet 环境下运行 is_eventlet = False try: import eventlet # 检查 eventlet 是否已经 monkey patch if hasattr(eventlet, 'patcher') and eventlet.patcher.is_monkey_patched('socket'): is_eventlet = True except ImportError: pass if is_eventlet: # Eventlet 环境:使用 eventlet 的协程定时器 print(f'[轮询] 检测到 Eventlet 环境,使用 eventlet 协程调度器') def eventlet_polling_loop(): """Eventlet 兼容的轮询循环""" import eventlet while True: try: eventlet.sleep(30) # 等待 30 秒 poll_new_events() except Exception as e: print(f'[轮询 ERROR] Eventlet 轮询循环出错: {e}') import traceback traceback.print_exc() eventlet.sleep(30) # 出错后等待 30 
秒再继续 # 启动 eventlet 协程 eventlet.spawn(eventlet_polling_loop) print(f'[轮询] Eventlet 协程调度器已启动 (PID: {os.getpid()}),每 30 秒检查一次新事件') else: # 非 Eventlet 环境:使用 APScheduler print(f'[轮询] 使用 APScheduler BackgroundScheduler') scheduler = BackgroundScheduler() # 每 30 秒执行一次轮询 scheduler.add_job( func=poll_new_events, trigger='interval', seconds=30, id='poll_new_events', name='检查新事件并推送', replace_existing=True ) # 每天 9:25 预热股票缓存(开盘前 5 分钟) from apscheduler.triggers.cron import CronTrigger scheduler.add_job( func=preload_stock_cache, trigger=CronTrigger(hour=9, minute=25), id='preload_stock_cache', name='预热股票缓存(股票名称+前收盘价)', replace_existing=True ) print(f'[缓存] 已添加定时任务: 每天 9:25 预热股票缓存') scheduler.start() print(f'[轮询] APScheduler 调度器已启动 (PID: {os.getpid()}),每 30 秒检查一次新事件') _polling_initialized = True except Exception as e: print(f'[轮询] 初始化失败: {e}') import traceback traceback.print_exc() # ==================== Gunicorn 兼容:自动初始化轮询 ==================== # Redis key 用于确保只有一个 Worker 启动调度器 REDIS_KEY_SCHEDULER_LOCK = 'vf:event_polling:scheduler_lock' def _auto_init_polling(): """ 自动初始化事件轮询(兼容 gunicorn) 使用 Redis 锁确保整个集群只有一个 Worker 启动调度器 """ global _polling_initialized import os if _polling_initialized: return try: # 尝试获取调度器锁(10分钟过期,防止死锁) lock_acquired = redis_client.set( REDIS_KEY_SCHEDULER_LOCK, str(os.getpid()), nx=True, # 只在不存在时设置 ex=600 # 10分钟过期 ) if lock_acquired: print(f'[轮询] Worker {os.getpid()} 获得调度器锁,启动轮询调度器') initialize_event_polling() else: # 其他 Worker 已经启动了调度器 _polling_initialized = True # 标记为已初始化,避免重复尝试 print(f'[轮询] Worker {os.getpid()} 跳过调度器初始化(已由其他 Worker 启动)') except Exception as e: print(f'[轮询] 自动初始化失败: {e}') # 注册 before_request 钩子,确保 gunicorn 启动后也能初始化轮询 @app.before_request def ensure_polling_initialized(): """确保轮询机制已初始化(只执行一次)""" global _polling_initialized if not _polling_initialized: _auto_init_polling() # ==================== 结束 WebSocket 部分 ==================== @app.route('/api/posts//like', methods=['POST']) @login_required def like_post(post_id): """点赞/取消点赞帖子""" try: 
post = Post.query.get_or_404(post_id) # 检查是否已经点赞 existing_like = PostLike.query.filter_by( post_id=post_id, user_id=current_user.id ).first() if existing_like: # 取消点赞 db.session.delete(existing_like) post.likes_count = max(0, post.likes_count - 1) message = '取消点赞成功' liked = False else: # 添加点赞 new_like = PostLike(post_id=post_id, user_id=current_user.id) db.session.add(new_like) post.likes_count += 1 message = '点赞成功' liked = True db.session.commit() return jsonify({ 'success': True, 'message': message, 'likes_count': post.likes_count, 'liked': liked }) except Exception as e: db.session.rollback() print(f"点赞失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 @app.route('/api/comments//like', methods=['POST']) @login_required def like_comment(comment_id): """点赞/取消点赞评论""" try: comment = Comment.query.get_or_404(comment_id) # 检查是否已经点赞(需要创建 CommentLike 关联到新的 Comment 模型) # 暂时使用简单的计数器 comment.likes_count += 1 db.session.commit() return jsonify({ 'success': True, 'message': '点赞成功', 'likes_count': comment.likes_count }) except Exception as e: db.session.rollback() print(f"点赞失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 @app.route('/api/posts/', methods=['DELETE']) @login_required def delete_post(post_id): """删除帖子""" try: post = Post.query.get_or_404(post_id) # 检查权限:只能删除自己的帖子 if post.user_id != current_user.id: return jsonify({ 'success': False, 'message': '您只能删除自己的帖子' }), 403 # 软删除 post.status = 'deleted' # 更新事件的帖子数 event = Event.query.get(post.event_id) if event: event.post_count = Post.query.filter_by(event_id=post.event_id, status='active').count() # 更新用户发帖数 if current_user.post_count > 0: current_user.post_count -= 1 db.session.commit() return jsonify({ 'success': True, 'message': '帖子删除成功' }) except Exception as e: db.session.rollback() print(f"删除帖子失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 @app.route('/api/comments/', methods=['DELETE']) @login_required def delete_comment(comment_id): """删除评论""" 
try: comment = Comment.query.get_or_404(comment_id) # 检查权限:只能删除自己的评论 if comment.user_id != current_user.id: return jsonify({ 'success': False, 'message': '您只能删除自己的评论' }), 403 # 软删除 comment.status = 'deleted' comment.content = '[该评论已被删除]' # 更新帖子评论数 post = Post.query.get(comment.post_id) if post: post.comments_count = Comment.query.filter_by(post_id=comment.post_id, status='active').count() # 更新用户评论数 if current_user.comment_count > 0: current_user.comment_count -= 1 db.session.commit() return jsonify({ 'success': True, 'message': '评论删除成功' }) except Exception as e: db.session.rollback() print(f"删除评论失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 def format_decimal(value): """格式化decimal类型数据""" if value is None: return None if isinstance(value, Decimal): return float(value) return float(value) def format_date(date_obj): """格式化日期""" if date_obj is None: return None if isinstance(date_obj, datetime): return date_obj.strftime('%Y-%m-%d') return str(date_obj) def remove_cycles_from_sankey_flows(flows_data): """ 移除Sankey图数据中的循环边,确保数据是DAG(有向无环图) 使用拓扑排序算法检测循环,优先保留flow_ratio高的边 Args: flows_data: list of flow objects with 'source', 'target', 'flow_metrics' keys Returns: list of flows without cycles """ if not flows_data: return flows_data # 按flow_ratio降序排序,优先保留重要的边 sorted_flows = sorted( flows_data, key=lambda x: x.get('flow_metrics', {}).get('flow_ratio', 0) or 0, reverse=True ) # 构建图的邻接表和入度表 def build_graph(flows): graph = {} # node -> list of successors in_degree = {} # node -> in-degree count all_nodes = set() for flow in flows: source = flow['source']['node_name'] target = flow['target']['node_name'] all_nodes.add(source) all_nodes.add(target) if source not in graph: graph[source] = [] graph[source].append(target) if target not in in_degree: in_degree[target] = 0 in_degree[target] += 1 if source not in in_degree: in_degree[source] = 0 return graph, in_degree, all_nodes # 使用Kahn算法检测是否有环 def has_cycle(graph, in_degree, all_nodes): # 找到所有入度为0的节点 queue = 
[node for node in all_nodes if in_degree.get(node, 0) == 0] visited_count = 0 while queue: node = queue.pop(0) visited_count += 1 # 访问所有邻居 for neighbor in graph.get(node, []): in_degree[neighbor] -= 1 if in_degree[neighbor] == 0: queue.append(neighbor) # 如果访问的节点数等于总节点数,说明没有环 return visited_count < len(all_nodes) # 逐个添加边,如果添加后产生环则跳过 result_flows = [] for flow in sorted_flows: # 尝试添加这条边 temp_flows = result_flows + [flow] # 检查是否产生环 graph, in_degree, all_nodes = build_graph(temp_flows) # 复制in_degree用于检测(因为检测过程会修改它) in_degree_copy = in_degree.copy() if not has_cycle(graph, in_degree_copy, all_nodes): # 没有产生环,可以添加 result_flows.append(flow) else: # 产生环,跳过这条边 print(f"Skipping edge that creates cycle: {flow['source']['node_name']} -> {flow['target']['node_name']}") removed_count = len(flows_data) - len(result_flows) if removed_count > 0: print(f"Removed {removed_count} edges to eliminate cycles in Sankey diagram") return result_flows def get_report_type(date_str): """获取报告期类型""" if not date_str: return '' if isinstance(date_str, str): date = datetime.strptime(date_str, '%Y-%m-%d') else: date = date_str month = date.month year = date.year if month == 3: return f"{year}年一季报" elif month == 6: return f"{year}年中报" elif month == 9: return f"{year}年三季报" elif month == 12: return f"{year}年年报" else: return str(date_str) @app.route('/api/financial/stock-info/', methods=['GET']) def get_stock_info(seccode): """获取股票基本信息和最新财务摘要""" try: # 获取最新的财务数据 query = text(""" SELECT distinct a.SECCODE, a.SECNAME, a.ENDDATE, a.F003N as eps, a.F004N as basic_eps, a.F005N as diluted_eps, a.F006N as deducted_eps, a.F007N as undistributed_profit_ps, a.F008N as bvps, a.F010N as capital_reserve_ps, a.F014N as roe, a.F067N as roe_weighted, a.F016N as roa, a.F078N as gross_margin, a.F017N as net_margin, a.F089N as revenue, a.F101N as net_profit, a.F102N as parent_net_profit, a.F118N as total_assets, a.F121N as total_liabilities, a.F128N as total_equity, a.F052N as revenue_growth, a.F053N as profit_growth, 
a.F054N as equity_growth, a.F056N as asset_growth, a.F122N as share_capital FROM ea_financialindex a WHERE a.SECCODE = :seccode ORDER BY a.ENDDATE DESC LIMIT 1 """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode}).fetchone() if not result: return jsonify({ 'success': False, 'message': f'未找到股票代码 {seccode} 的财务数据' }), 404 # 获取最近的业绩预告 forecast_query = text(""" SELECT distinct F001D as report_date, F003V as forecast_type, F004V as content, F007N as profit_lower, F008N as profit_upper, F009N as change_lower, F010N as change_upper FROM ea_forecast WHERE SECCODE = :seccode AND F006C = 'T' ORDER BY F001D DESC LIMIT 1 """) with engine.connect() as conn: forecast_result = conn.execute(forecast_query, {'seccode': seccode}).fetchone() data = { 'stock_code': result.SECCODE, 'stock_name': result.SECNAME, 'latest_period': format_date(result.ENDDATE), 'report_type': get_report_type(result.ENDDATE), 'key_metrics': { 'eps': format_decimal(result.eps), 'basic_eps': format_decimal(result.basic_eps), 'diluted_eps': format_decimal(result.diluted_eps), 'deducted_eps': format_decimal(result.deducted_eps), 'bvps': format_decimal(result.bvps), 'roe': format_decimal(result.roe), 'roe_weighted': format_decimal(result.roe_weighted), 'roa': format_decimal(result.roa), 'gross_margin': format_decimal(result.gross_margin), 'net_margin': format_decimal(result.net_margin), }, 'financial_summary': { 'revenue': format_decimal(result.revenue), 'net_profit': format_decimal(result.net_profit), 'parent_net_profit': format_decimal(result.parent_net_profit), 'total_assets': format_decimal(result.total_assets), 'total_liabilities': format_decimal(result.total_liabilities), 'total_equity': format_decimal(result.total_equity), 'share_capital': format_decimal(result.share_capital), }, 'growth_rates': { 'revenue_growth': format_decimal(result.revenue_growth), 'profit_growth': format_decimal(result.profit_growth), 'equity_growth': format_decimal(result.equity_growth), 
'asset_growth': format_decimal(result.asset_growth), } } # 添加业绩预告信息 if forecast_result: data['latest_forecast'] = { 'report_date': format_date(forecast_result.report_date), 'forecast_type': forecast_result.forecast_type, 'content': forecast_result.content, 'profit_range': { 'lower': format_decimal(forecast_result.profit_lower), 'upper': format_decimal(forecast_result.profit_upper), }, 'change_range': { 'lower': format_decimal(forecast_result.change_lower), 'upper': format_decimal(forecast_result.change_upper), } } return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/balance-sheet/', methods=['GET']) def get_balance_sheet(seccode): """获取完整的资产负债表数据""" try: limit = request.args.get('limit', 12, type=int) query = text(""" SELECT distinct ENDDATE, DECLAREDATE, -- 流动资产 F006N as cash, -- 货币资金 F007N as trading_financial_assets, -- 交易性金融资产 F008N as notes_receivable, -- 应收票据 F009N as accounts_receivable, -- 应收账款 F010N as prepayments, -- 预付款项 F011N as other_receivables, -- 其他应收款 F013N as interest_receivable, -- 应收利息 F014N as dividends_receivable, -- 应收股利 F015N as inventory, -- 存货 F016N as consumable_biological_assets, -- 消耗性生物资产 F017N as non_current_assets_due_within_one_year, -- 一年内到期的非流动资产 F018N as other_current_assets, -- 其他流动资产 F019N as total_current_assets, -- 流动资产合计 -- 非流动资产 F020N as available_for_sale_financial_assets, -- 可供出售金融资产 F021N as held_to_maturity_investments, -- 持有至到期投资 F022N as long_term_receivables, -- 长期应收款 F023N as long_term_equity_investments, -- 长期股权投资 F024N as investment_property, -- 投资性房地产 F025N as fixed_assets, -- 固定资产 F026N as construction_in_progress, -- 在建工程 F027N as engineering_materials, -- 工程物资 F029N as productive_biological_assets, -- 生产性生物资产 F030N as oil_and_gas_assets, -- 油气资产 F031N as intangible_assets, -- 无形资产 F032N as development_expenditure, -- 开发支出 F033N as goodwill, -- 商誉 F034N as long_term_deferred_expenses, -- 长期待摊费用 F035N as 
deferred_tax_assets, -- 递延所得税资产 F036N as other_non_current_assets, -- 其他非流动资产 F037N as total_non_current_assets, -- 非流动资产合计 F038N as total_assets, -- 资产总计 -- 流动负债 F039N as short_term_borrowings, -- 短期借款 F040N as trading_financial_liabilities, -- 交易性金融负债 F041N as notes_payable, -- 应付票据 F042N as accounts_payable, -- 应付账款 F043N as advance_receipts, -- 预收款项 F044N as employee_compensation_payable, -- 应付职工薪酬 F045N as taxes_payable, -- 应交税费 F046N as interest_payable, -- 应付利息 F047N as dividends_payable, -- 应付股利 F048N as other_payables, -- 其他应付款 F050N as non_current_liabilities_due_within_one_year, -- 一年内到期的非流动负债 F051N as other_current_liabilities, -- 其他流动负债 F052N as total_current_liabilities, -- 流动负债合计 -- 非流动负债 F053N as long_term_borrowings, -- 长期借款 F054N as bonds_payable, -- 应付债券 F055N as long_term_payables, -- 长期应付款 F056N as special_payables, -- 专项应付款 F057N as estimated_liabilities, -- 预计负债 F058N as deferred_tax_liabilities, -- 递延所得税负债 F059N as other_non_current_liabilities, -- 其他非流动负债 F060N as total_non_current_liabilities, -- 非流动负债合计 F061N as total_liabilities, -- 负债合计 -- 所有者权益 F062N as share_capital, -- 股本 F063N as capital_reserve, -- 资本公积 F064N as surplus_reserve, -- 盈余公积 F065N as undistributed_profit, -- 未分配利润 F066N as treasury_stock, -- 库存股 F067N as minority_interests, -- 少数股东权益 F070N as total_equity, -- 所有者权益合计 F071N as total_liabilities_and_equity, -- 负债和所有者权益合计 F073N as parent_company_equity, -- 归属于母公司所有者权益 F074N as other_comprehensive_income, -- 其他综合收益 -- 新会计准则科目 F110N as other_debt_investments, -- 其他债权投资 F111N as other_equity_investments, -- 其他权益工具投资 F112N as other_non_current_financial_assets, -- 其他非流动金融资产 F115N as contract_liabilities, -- 合同负债 F119N as contract_assets, -- 合同资产 F120N as receivables_financing, -- 应收款项融资 F121N as right_of_use_assets, -- 使用权资产 F122N as lease_liabilities -- 租赁负债 FROM ea_asset WHERE SECCODE = :seccode and F002V = '071001' ORDER BY ENDDATE DESC LIMIT :limit """) with engine.connect() as conn: result = conn.execute(query, 
{'seccode': seccode, 'limit': limit}) data = [] for row in result: # 安全计算关键比率,避免 Decimal 与 None 运算错误 def to_float(v): try: return float(v) if v is not None else None except Exception: return None ta = to_float(row.total_assets) tl = to_float(row.total_liabilities) tca = to_float(row.total_current_assets) tcl = to_float(row.total_current_liabilities) inv = to_float(row.inventory) or 0.0 asset_liability_ratio_val = None if ta is not None and ta != 0 and tl is not None: asset_liability_ratio_val = (tl / ta) * 100 current_ratio_val = None if tcl is not None and tcl != 0 and tca is not None: current_ratio_val = tca / tcl quick_ratio_val = None if tcl is not None and tcl != 0 and tca is not None: quick_ratio_val = (tca - inv) / tcl period_data = { 'period': format_date(row.ENDDATE), 'declare_date': format_date(row.DECLAREDATE), 'report_type': get_report_type(row.ENDDATE), # 资产部分 'assets': { 'current_assets': { 'cash': format_decimal(row.cash), 'trading_financial_assets': format_decimal(row.trading_financial_assets), 'notes_receivable': format_decimal(row.notes_receivable), 'accounts_receivable': format_decimal(row.accounts_receivable), 'prepayments': format_decimal(row.prepayments), 'other_receivables': format_decimal(row.other_receivables), 'inventory': format_decimal(row.inventory), 'contract_assets': format_decimal(row.contract_assets), 'other_current_assets': format_decimal(row.other_current_assets), 'total': format_decimal(row.total_current_assets), }, 'non_current_assets': { 'long_term_equity_investments': format_decimal(row.long_term_equity_investments), 'investment_property': format_decimal(row.investment_property), 'fixed_assets': format_decimal(row.fixed_assets), 'construction_in_progress': format_decimal(row.construction_in_progress), 'intangible_assets': format_decimal(row.intangible_assets), 'goodwill': format_decimal(row.goodwill), 'right_of_use_assets': format_decimal(row.right_of_use_assets), 'deferred_tax_assets': format_decimal(row.deferred_tax_assets), 
'other_non_current_assets': format_decimal(row.other_non_current_assets), 'total': format_decimal(row.total_non_current_assets), }, 'total': format_decimal(row.total_assets), }, # 负债部分 'liabilities': { 'current_liabilities': { 'short_term_borrowings': format_decimal(row.short_term_borrowings), 'notes_payable': format_decimal(row.notes_payable), 'accounts_payable': format_decimal(row.accounts_payable), 'advance_receipts': format_decimal(row.advance_receipts), 'contract_liabilities': format_decimal(row.contract_liabilities), 'employee_compensation_payable': format_decimal(row.employee_compensation_payable), 'taxes_payable': format_decimal(row.taxes_payable), 'other_payables': format_decimal(row.other_payables), 'non_current_liabilities_due_within_one_year': format_decimal( row.non_current_liabilities_due_within_one_year), 'total': format_decimal(row.total_current_liabilities), }, 'non_current_liabilities': { 'long_term_borrowings': format_decimal(row.long_term_borrowings), 'bonds_payable': format_decimal(row.bonds_payable), 'lease_liabilities': format_decimal(row.lease_liabilities), 'deferred_tax_liabilities': format_decimal(row.deferred_tax_liabilities), 'other_non_current_liabilities': format_decimal(row.other_non_current_liabilities), 'total': format_decimal(row.total_non_current_liabilities), }, 'total': format_decimal(row.total_liabilities), }, # 股东权益部分 'equity': { 'share_capital': format_decimal(row.share_capital), 'capital_reserve': format_decimal(row.capital_reserve), 'surplus_reserve': format_decimal(row.surplus_reserve), 'undistributed_profit': format_decimal(row.undistributed_profit), 'treasury_stock': format_decimal(row.treasury_stock), 'other_comprehensive_income': format_decimal(row.other_comprehensive_income), 'parent_company_equity': format_decimal(row.parent_company_equity), 'minority_interests': format_decimal(row.minority_interests), 'total': format_decimal(row.total_equity), }, # 关键比率 'key_ratios': { 'asset_liability_ratio': 
format_decimal(asset_liability_ratio_val), 'current_ratio': format_decimal(current_ratio_val), 'quick_ratio': format_decimal(quick_ratio_val), } } data.append(period_data) return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/income-statement/', methods=['GET']) def get_income_statement(seccode): """获取完整的利润表数据""" try: limit = request.args.get('limit', 12, type=int) query = text(""" SELECT distinct ENDDATE, STARTDATE, DECLAREDATE, -- 营业收入部分 F006N as revenue, -- 营业收入 F035N as total_operating_revenue, -- 营业总收入 F051N as other_income, -- 其他收入 -- 营业成本部分 F007N as cost, -- 营业成本 F008N as taxes_and_surcharges, -- 税金及附加 F009N as selling_expenses, -- 销售费用 F010N as admin_expenses, -- 管理费用 F056N as rd_expenses, -- 研发费用 F012N as financial_expenses, -- 财务费用 F062N as interest_expense, -- 利息费用 F063N as interest_income, -- 利息收入 F013N as asset_impairment_loss, -- 资产减值损失(营业总成本) F057N as credit_impairment_loss, -- 信用减值损失(营业总成本) F036N as total_operating_cost, -- 营业总成本 -- 其他收益 F014N as fair_value_change_income, -- 公允价值变动净收益 F015N as investment_income, -- 投资收益 F016N as investment_income_from_associates, -- 对联营企业和合营企业的投资收益 F037N as exchange_income, -- 汇兑收益 F058N as net_exposure_hedging_income, -- 净敞口套期收益 F059N as asset_disposal_income, -- 资产处置收益 -- 利润部分 F018N as operating_profit, -- 营业利润 F019N as subsidy_income, -- 补贴收入 F020N as non_operating_income, -- 营业外收入 F021N as non_operating_expenses, -- 营业外支出 F022N as non_current_asset_disposal_loss, -- 非流动资产处置损失 F024N as total_profit, -- 利润总额 F025N as income_tax_expense, -- 所得税 F027N as net_profit, -- 净利润 F028N as parent_net_profit, -- 归属于母公司所有者的净利润 F029N as minority_profit, -- 少数股东损益 -- 持续经营 F060N as continuing_operations_net_profit, -- 持续经营净利润 F061N as discontinued_operations_net_profit, -- 终止经营净利润 -- 每股收益 F031N as basic_eps, -- 基本每股收益 F032N as diluted_eps, -- 稀释每股收益 -- 综合收益 F038N as other_comprehensive_income_after_tax, -- 其他综合收益的税后净额 F039N 
as total_comprehensive_income, -- 综合收益总额 F040N as parent_company_comprehensive_income, -- 归属于母公司的综合收益 F041N as minority_comprehensive_income -- 归属于少数股东的综合收益 FROM ea_profit WHERE SECCODE = :seccode and F002V = '071001' ORDER BY ENDDATE DESC LIMIT :limit """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'limit': limit}) data = [] for row in result: # 计算一些衍生指标 gross_profit = (row.revenue - row.cost) if row.revenue and row.cost else None gross_margin = (gross_profit / row.revenue * 100) if row.revenue and gross_profit else None operating_margin = ( row.operating_profit / row.revenue * 100) if row.revenue and row.operating_profit else None net_margin = (row.net_profit / row.revenue * 100) if row.revenue and row.net_profit else None # 三费合计 three_expenses = 0 if row.selling_expenses: three_expenses += row.selling_expenses if row.admin_expenses: three_expenses += row.admin_expenses if row.financial_expenses: three_expenses += row.financial_expenses # 四费合计(加研发) four_expenses = three_expenses if row.rd_expenses: four_expenses += row.rd_expenses period_data = { 'period': format_date(row.ENDDATE), 'start_date': format_date(row.STARTDATE), 'declare_date': format_date(row.DECLAREDATE), 'report_type': get_report_type(row.ENDDATE), # 收入部分 'revenue': { 'operating_revenue': format_decimal(row.revenue), 'total_operating_revenue': format_decimal(row.total_operating_revenue), 'other_income': format_decimal(row.other_income), }, # 成本费用部分 'costs': { 'operating_cost': format_decimal(row.cost), 'taxes_and_surcharges': format_decimal(row.taxes_and_surcharges), 'selling_expenses': format_decimal(row.selling_expenses), 'admin_expenses': format_decimal(row.admin_expenses), 'rd_expenses': format_decimal(row.rd_expenses), 'financial_expenses': format_decimal(row.financial_expenses), 'interest_expense': format_decimal(row.interest_expense), 'interest_income': format_decimal(row.interest_income), 'asset_impairment_loss': 
format_decimal(row.asset_impairment_loss), 'credit_impairment_loss': format_decimal(row.credit_impairment_loss), 'total_operating_cost': format_decimal(row.total_operating_cost), 'three_expenses_total': format_decimal(three_expenses), 'four_expenses_total': format_decimal(four_expenses), }, # 其他收益 'other_gains': { 'fair_value_change': format_decimal(row.fair_value_change_income), 'investment_income': format_decimal(row.investment_income), 'investment_income_from_associates': format_decimal(row.investment_income_from_associates), 'exchange_income': format_decimal(row.exchange_income), 'asset_disposal_income': format_decimal(row.asset_disposal_income), }, # 利润 'profit': { 'gross_profit': format_decimal(gross_profit), 'operating_profit': format_decimal(row.operating_profit), 'total_profit': format_decimal(row.total_profit), 'net_profit': format_decimal(row.net_profit), 'parent_net_profit': format_decimal(row.parent_net_profit), 'minority_profit': format_decimal(row.minority_profit), 'continuing_operations_net_profit': format_decimal(row.continuing_operations_net_profit), 'discontinued_operations_net_profit': format_decimal(row.discontinued_operations_net_profit), }, # 非经营项目 'non_operating': { 'subsidy_income': format_decimal(row.subsidy_income), 'non_operating_income': format_decimal(row.non_operating_income), 'non_operating_expenses': format_decimal(row.non_operating_expenses), }, # 每股收益 'per_share': { 'basic_eps': format_decimal(row.basic_eps), 'diluted_eps': format_decimal(row.diluted_eps), }, # 综合收益 'comprehensive_income': { 'other_comprehensive_income': format_decimal(row.other_comprehensive_income_after_tax), 'total_comprehensive_income': format_decimal(row.total_comprehensive_income), 'parent_comprehensive_income': format_decimal(row.parent_company_comprehensive_income), 'minority_comprehensive_income': format_decimal(row.minority_comprehensive_income), }, # 关键比率 'margins': { 'gross_margin': format_decimal(gross_margin), 'operating_margin': 
format_decimal(operating_margin), 'net_margin': format_decimal(net_margin), 'expense_ratio': format_decimal(four_expenses / row.revenue * 100) if row.revenue else None, 'rd_ratio': format_decimal( row.rd_expenses / row.revenue * 100) if row.revenue and row.rd_expenses else None, } } data.append(period_data) return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/cashflow/', methods=['GET']) def get_cashflow(seccode): """获取完整的现金流量表数据""" try: limit = request.args.get('limit', 12, type=int) query = text(""" SELECT distinct ENDDATE, STARTDATE, DECLAREDATE, -- 经营活动现金流 F006N as cash_from_sales, -- 销售商品、提供劳务收到的现金 F007N as tax_refunds, -- 收到的税费返还 F008N as other_operating_cash_received, -- 收到其他与经营活动有关的现金 F009N as total_operating_cash_inflow, -- 经营活动现金流入小计 F010N as cash_paid_for_goods, -- 购买商品、接受劳务支付的现金 F011N as cash_paid_to_employees, -- 支付给职工以及为职工支付的现金 F012N as taxes_paid, -- 支付的各项税费 F013N as other_operating_cash_paid, -- 支付其他与经营活动有关的现金 F014N as total_operating_cash_outflow, -- 经营活动现金流出小计 F015N as net_operating_cash_flow, -- 经营活动产生的现金流量净额 -- 投资活动现金流 F016N as cash_from_investment_recovery, -- 收回投资收到的现金 F017N as cash_from_investment_income, -- 取得投资收益收到的现金 F018N as cash_from_asset_disposal, -- 处置固定资产、无形资产和其他长期资产收回的现金净额 F019N as cash_from_subsidiary_disposal, -- 处置子公司及其他营业单位收到的现金净额 F020N as other_investment_cash_received, -- 收到其他与投资活动有关的现金 F021N as total_investment_cash_inflow, -- 投资活动现金流入小计 F022N as cash_paid_for_assets, -- 购建固定资产、无形资产和其他长期资产支付的现金 F023N as cash_paid_for_investments, -- 投资支付的现金 F024N as cash_paid_for_subsidiaries, -- 取得子公司及其他营业单位支付的现金净额 F025N as other_investment_cash_paid, -- 支付其他与投资活动有关的现金 F026N as total_investment_cash_outflow, -- 投资活动现金流出小计 F027N as net_investment_cash_flow, -- 投资活动产生的现金流量净额 -- 筹资活动现金流 F028N as cash_from_capital, -- 吸收投资收到的现金 F029N as cash_from_borrowings, -- 取得借款收到的现金 F030N as other_financing_cash_received, -- 收到其他与筹资活动有关的现金 F031N as 
total_financing_cash_inflow, -- 筹资活动现金流入小计 F032N as cash_paid_for_debt, -- 偿还债务支付的现金 F033N as cash_paid_for_distribution, -- 分配股利、利润或偿付利息支付的现金 F034N as other_financing_cash_paid, -- 支付其他与筹资活动有关的现金 F035N as total_financing_cash_outflow, -- 筹资活动现金流出小计 F036N as net_financing_cash_flow, -- 筹资活动产生的现金流量净额 -- 汇率变动影响 F037N as exchange_rate_effect, -- 汇率变动对现金及现金等价物的影响 F038N as other_cash_effect, -- 其他原因对现金的影响 -- 现金净增加额 F039N as net_cash_increase, -- 现金及现金等价物净增加额 F040N as beginning_cash_balance, -- 期初现金及现金等价物余额 F041N as ending_cash_balance, -- 期末现金及现金等价物余额 -- 补充资料部分 F044N as net_profit, -- 净利润 F045N as asset_impairment, -- 资产减值准备 F096N as credit_impairment, -- 信用减值损失 F046N as depreciation, -- 固定资产折旧、油气资产折耗、生产性生物资产折旧 F097N as right_of_use_asset_depreciation, -- 使用权资产折旧/摊销 F047N as intangible_amortization, -- 无形资产摊销 F048N as long_term_expense_amortization, -- 长期待摊费用摊销 F049N as loss_on_disposal, -- 处置固定资产、无形资产和其他长期资产的损失 F050N as fixed_asset_scrap_loss, -- 固定资产报废损失 F051N as fair_value_change_loss, -- 公允价值变动损失 F052N as financial_expenses, -- 财务费用 F053N as investment_loss, -- 投资损失 F054N as deferred_tax_asset_decrease, -- 递延所得税资产减少 F055N as deferred_tax_liability_increase, -- 递延所得税负债增加 F056N as inventory_decrease, -- 存货的减少 F057N as operating_receivables_decrease, -- 经营性应收项目的减少 F058N as operating_payables_increase, -- 经营性应付项目的增加 F059N as other, -- 其他 F060N as net_operating_cash_flow_indirect, -- 经营活动产生的现金流量净额(间接法) -- 特殊行业科目(金融) F072N as customer_deposit_increase, -- 客户存款和同业存放款项净增加额 F073N as central_bank_borrowing_increase, -- 向中央银行借款净增加额 F081N as interest_and_commission_received, -- 收取利息、手续费及佣金的现金 F087N as interest_and_commission_paid -- 支付利息、手续费及佣金的现金 FROM ea_cashflow WHERE SECCODE = :seccode and F002V = '071001' ORDER BY ENDDATE DESC LIMIT :limit """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'limit': limit}) data = [] for row in result: # 计算一些衍生指标 free_cash_flow = None if row.net_operating_cash_flow and row.cash_paid_for_assets: 
free_cash_flow = row.net_operating_cash_flow - row.cash_paid_for_assets period_data = { 'period': format_date(row.ENDDATE), 'start_date': format_date(row.STARTDATE), 'declare_date': format_date(row.DECLAREDATE), 'report_type': get_report_type(row.ENDDATE), # 经营活动现金流 'operating_activities': { 'inflow': { 'cash_from_sales': format_decimal(row.cash_from_sales), 'tax_refunds': format_decimal(row.tax_refunds), 'other': format_decimal(row.other_operating_cash_received), 'total': format_decimal(row.total_operating_cash_inflow), }, 'outflow': { 'cash_for_goods': format_decimal(row.cash_paid_for_goods), 'cash_for_employees': format_decimal(row.cash_paid_to_employees), 'taxes_paid': format_decimal(row.taxes_paid), 'other': format_decimal(row.other_operating_cash_paid), 'total': format_decimal(row.total_operating_cash_outflow), }, 'net_flow': format_decimal(row.net_operating_cash_flow), }, # 投资活动现金流 'investment_activities': { 'inflow': { 'investment_recovery': format_decimal(row.cash_from_investment_recovery), 'investment_income': format_decimal(row.cash_from_investment_income), 'asset_disposal': format_decimal(row.cash_from_asset_disposal), 'subsidiary_disposal': format_decimal(row.cash_from_subsidiary_disposal), 'other': format_decimal(row.other_investment_cash_received), 'total': format_decimal(row.total_investment_cash_inflow), }, 'outflow': { 'asset_purchase': format_decimal(row.cash_paid_for_assets), 'investments': format_decimal(row.cash_paid_for_investments), 'subsidiaries': format_decimal(row.cash_paid_for_subsidiaries), 'other': format_decimal(row.other_investment_cash_paid), 'total': format_decimal(row.total_investment_cash_outflow), }, 'net_flow': format_decimal(row.net_investment_cash_flow), }, # 筹资活动现金流 'financing_activities': { 'inflow': { 'capital': format_decimal(row.cash_from_capital), 'borrowings': format_decimal(row.cash_from_borrowings), 'other': format_decimal(row.other_financing_cash_received), 'total': format_decimal(row.total_financing_cash_inflow), 
}, 'outflow': { 'debt_repayment': format_decimal(row.cash_paid_for_debt), 'distribution': format_decimal(row.cash_paid_for_distribution), 'other': format_decimal(row.other_financing_cash_paid), 'total': format_decimal(row.total_financing_cash_outflow), }, 'net_flow': format_decimal(row.net_financing_cash_flow), }, # 现金变动 'cash_changes': { 'exchange_rate_effect': format_decimal(row.exchange_rate_effect), 'other_effect': format_decimal(row.other_cash_effect), 'net_increase': format_decimal(row.net_cash_increase), 'beginning_balance': format_decimal(row.beginning_cash_balance), 'ending_balance': format_decimal(row.ending_cash_balance), }, # 补充资料(间接法) 'indirect_method': { 'net_profit': format_decimal(row.net_profit), 'adjustments': { 'asset_impairment': format_decimal(row.asset_impairment), 'credit_impairment': format_decimal(row.credit_impairment), 'depreciation': format_decimal(row.depreciation), 'intangible_amortization': format_decimal(row.intangible_amortization), 'financial_expenses': format_decimal(row.financial_expenses), 'investment_loss': format_decimal(row.investment_loss), 'inventory_decrease': format_decimal(row.inventory_decrease), 'receivables_decrease': format_decimal(row.operating_receivables_decrease), 'payables_increase': format_decimal(row.operating_payables_increase), }, 'net_operating_cash_flow': format_decimal(row.net_operating_cash_flow_indirect), }, # 关键指标 'key_metrics': { 'free_cash_flow': format_decimal(free_cash_flow), 'cash_flow_to_profit_ratio': format_decimal( row.net_operating_cash_flow / row.net_profit) if row.net_profit and row.net_operating_cash_flow else None, 'capex': format_decimal(row.cash_paid_for_assets), } } data.append(period_data) return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/financial-metrics/', methods=['GET']) def get_financial_metrics(seccode): """获取完整的财务指标数据""" try: limit = request.args.get('limit', 12, 
type=int) query = text(""" SELECT distinct ENDDATE, STARTDATE, -- 每股指标 F003N as eps, -- 每股收益 F004N as basic_eps, -- 基本每股收益 F005N as diluted_eps, -- 稀释每股收益 F006N as deducted_eps, -- 扣除非经常性损益每股收益 F007N as undistributed_profit_ps, -- 每股未分配利润 F008N as bvps, -- 每股净资产 F009N as adjusted_bvps, -- 调整后每股净资产 F010N as capital_reserve_ps, -- 每股资本公积金 F059N as cash_flow_ps, -- 每股现金流量 F060N as operating_cash_flow_ps, -- 每股经营现金流量 -- 盈利能力指标 F011N as operating_profit_margin, -- 营业利润率 F012N as tax_rate, -- 营业税金率 F013N as cost_ratio, -- 营业成本率 F014N as roe, -- 净资产收益率 F066N as roe_deducted, -- 净资产收益率(扣除非经常性损益) F067N as roe_weighted, -- 净资产收益率-加权 F068N as roe_weighted_deducted, -- 净资产收益率-加权(扣除非经常性损益) F015N as investment_return, -- 投资收益率 F016N as roa, -- 总资产报酬率 F017N as net_profit_margin, -- 净利润率 F078N as gross_margin, -- 毛利率 F020N as cost_profit_ratio, -- 成本费用利润率 -- 费用率指标 F018N as admin_expense_ratio, -- 管理费用率 F019N as financial_expense_ratio, -- 财务费用率 F021N as three_expense_ratio, -- 三费比重 F091N as selling_expense, -- 销售费用 F092N as admin_expense, -- 管理费用 F093N as financial_expense, -- 财务费用 F094N as three_expense_total, -- 三费合计 F130N as rd_expense, -- 研发费用 F131N as rd_expense_ratio, -- 研发费用率 F132N as selling_expense_ratio, -- 销售费用率 F133N as four_expense_ratio, -- 四费费用率 -- 运营能力指标 F022N as receivable_turnover, -- 应收账款周转率 F023N as inventory_turnover, -- 存货周转率 F024N as working_capital_turnover, -- 运营资金周转率 F025N as total_asset_turnover, -- 总资产周转率 F026N as fixed_asset_turnover, -- 固定资产周转率 F027N as receivable_days, -- 应收账款周转天数 F028N as inventory_days, -- 存货周转天数 F029N as current_asset_turnover, -- 流动资产周转率 F030N as current_asset_days, -- 流动资产周转天数 F031N as total_asset_days, -- 总资产周转天数 F032N as equity_turnover, -- 股东权益周转率 -- 偿债能力指标 F041N as asset_liability_ratio, -- 资产负债率 F042N as current_ratio, -- 流动比率 F043N as quick_ratio, -- 速动比率 F044N as cash_ratio, -- 现金比率 F045N as interest_coverage, -- 利息保障倍数 F049N as conservative_quick_ratio, -- 保守速动比率 F050N as cash_to_maturity_debt_ratio, -- 现金到期债务比率 F051N as 
tangible_asset_debt_ratio, -- 有形资产净值债务率 -- 成长能力指标 F052N as revenue_growth, -- 营业收入增长率 F053N as net_profit_growth, -- 净利润增长率 F054N as equity_growth, -- 净资产增长率 F055N as fixed_asset_growth, -- 固定资产增长率 F056N as total_asset_growth, -- 总资产增长率 F057N as investment_income_growth, -- 投资收益增长率 F058N as operating_profit_growth, -- 营业利润增长率 F141N as deducted_profit_growth, -- 扣除非经常性损益后的净利润同比变化率 F142N as parent_profit_growth, -- 归属于母公司所有者的净利润同比变化率 F143N as operating_cash_flow_growth, -- 经营活动产生的现金流净额同比变化率 -- 现金流量指标 F061N as operating_cash_to_short_debt, -- 经营净现金比率(短期债务) F062N as operating_cash_to_total_debt, -- 经营净现金比率(全部债务) F063N as operating_cash_to_profit_ratio, -- 经营活动现金净流量与净利润比率 F064N as cash_revenue_ratio, -- 营业收入现金含量 F065N as cash_recovery_rate, -- 全部资产现金回收率 F082N as cash_to_profit_ratio, -- 净利含金量 -- 财务结构指标 F033N as current_asset_ratio, -- 流动资产比率 F034N as cash_ratio_structure, -- 货币资金比率 F036N as inventory_ratio, -- 存货比率 F037N as fixed_asset_ratio, -- 固定资产比率 F038N as liability_structure_ratio, -- 负债结构比 F039N as equity_ratio, -- 产权比率 F040N as net_asset_ratio, -- 净资产比率 F046N as working_capital, -- 营运资金 F047N as non_current_liability_ratio, -- 非流动负债比率 F048N as current_liability_ratio, -- 流动负债比率 -- 非经常性损益 F076N as deducted_net_profit, -- 扣除非经常性损益后的净利润 F077N as non_recurring_items, -- 非经常性损益合计 F083N as non_recurring_ratio, -- 非经常性损益占比 -- 综合指标 F085N as ebit, -- 基本获利能力(EBIT) F086N as receivable_to_asset_ratio, -- 应收账款占比 F087N as inventory_to_asset_ratio -- 存货占比 FROM ea_financialindex WHERE SECCODE = :seccode ORDER BY ENDDATE DESC LIMIT :limit """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'limit': limit}) data = [] for row in result: period_data = { 'period': format_date(row.ENDDATE), 'start_date': format_date(row.STARTDATE), 'report_type': get_report_type(row.ENDDATE), # 每股指标 'per_share_metrics': { 'eps': format_decimal(row.eps), 'basic_eps': format_decimal(row.basic_eps), 'diluted_eps': format_decimal(row.diluted_eps), 'deducted_eps': 
format_decimal(row.deducted_eps), 'bvps': format_decimal(row.bvps), 'adjusted_bvps': format_decimal(row.adjusted_bvps), 'undistributed_profit_ps': format_decimal(row.undistributed_profit_ps), 'capital_reserve_ps': format_decimal(row.capital_reserve_ps), 'cash_flow_ps': format_decimal(row.cash_flow_ps), 'operating_cash_flow_ps': format_decimal(row.operating_cash_flow_ps), }, # 盈利能力 'profitability': { 'roe': format_decimal(row.roe), 'roe_deducted': format_decimal(row.roe_deducted), 'roe_weighted': format_decimal(row.roe_weighted), 'roa': format_decimal(row.roa), 'gross_margin': format_decimal(row.gross_margin), 'net_profit_margin': format_decimal(row.net_profit_margin), 'operating_profit_margin': format_decimal(row.operating_profit_margin), 'cost_profit_ratio': format_decimal(row.cost_profit_ratio), 'ebit': format_decimal(row.ebit), }, # 费用率 'expense_ratios': { 'selling_expense_ratio': format_decimal(row.selling_expense_ratio), 'admin_expense_ratio': format_decimal(row.admin_expense_ratio), 'financial_expense_ratio': format_decimal(row.financial_expense_ratio), 'rd_expense_ratio': format_decimal(row.rd_expense_ratio), 'three_expense_ratio': format_decimal(row.three_expense_ratio), 'four_expense_ratio': format_decimal(row.four_expense_ratio), }, # 运营能力 'operational_efficiency': { 'receivable_turnover': format_decimal(row.receivable_turnover), 'receivable_days': format_decimal(row.receivable_days), 'inventory_turnover': format_decimal(row.inventory_turnover), 'inventory_days': format_decimal(row.inventory_days), 'total_asset_turnover': format_decimal(row.total_asset_turnover), 'total_asset_days': format_decimal(row.total_asset_days), 'fixed_asset_turnover': format_decimal(row.fixed_asset_turnover), 'current_asset_turnover': format_decimal(row.current_asset_turnover), 'working_capital_turnover': format_decimal(row.working_capital_turnover), }, # 偿债能力 'solvency': { 'current_ratio': format_decimal(row.current_ratio), 'quick_ratio': format_decimal(row.quick_ratio), 
'cash_ratio': format_decimal(row.cash_ratio), 'conservative_quick_ratio': format_decimal(row.conservative_quick_ratio), 'asset_liability_ratio': format_decimal(row.asset_liability_ratio), 'interest_coverage': format_decimal(row.interest_coverage), 'cash_to_maturity_debt_ratio': format_decimal(row.cash_to_maturity_debt_ratio), 'tangible_asset_debt_ratio': format_decimal(row.tangible_asset_debt_ratio), }, # 成长能力 'growth': { 'revenue_growth': format_decimal(row.revenue_growth), 'net_profit_growth': format_decimal(row.net_profit_growth), 'deducted_profit_growth': format_decimal(row.deducted_profit_growth), 'parent_profit_growth': format_decimal(row.parent_profit_growth), 'equity_growth': format_decimal(row.equity_growth), 'total_asset_growth': format_decimal(row.total_asset_growth), 'fixed_asset_growth': format_decimal(row.fixed_asset_growth), 'operating_profit_growth': format_decimal(row.operating_profit_growth), 'operating_cash_flow_growth': format_decimal(row.operating_cash_flow_growth), }, # 现金流量 'cash_flow_quality': { 'operating_cash_to_profit_ratio': format_decimal(row.operating_cash_to_profit_ratio), 'cash_to_profit_ratio': format_decimal(row.cash_to_profit_ratio), 'cash_revenue_ratio': format_decimal(row.cash_revenue_ratio), 'cash_recovery_rate': format_decimal(row.cash_recovery_rate), 'operating_cash_to_short_debt': format_decimal(row.operating_cash_to_short_debt), 'operating_cash_to_total_debt': format_decimal(row.operating_cash_to_total_debt), }, # 财务结构 'financial_structure': { 'current_asset_ratio': format_decimal(row.current_asset_ratio), 'fixed_asset_ratio': format_decimal(row.fixed_asset_ratio), 'inventory_ratio': format_decimal(row.inventory_ratio), 'receivable_to_asset_ratio': format_decimal(row.receivable_to_asset_ratio), 'current_liability_ratio': format_decimal(row.current_liability_ratio), 'non_current_liability_ratio': format_decimal(row.non_current_liability_ratio), 'equity_ratio': format_decimal(row.equity_ratio), }, # 非经常性损益 'non_recurring': { 
'deducted_net_profit': format_decimal(row.deducted_net_profit), 'non_recurring_items': format_decimal(row.non_recurring_items), 'non_recurring_ratio': format_decimal(row.non_recurring_ratio), } } data.append(period_data) return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/main-business/', methods=['GET']) def get_main_business(seccode): """获取主营业务构成数据(包括产品和行业分类)""" try: limit = request.args.get('periods', 4, type=int) # 获取最近几期的数据 # 获取最近的报告期 period_query = text(""" SELECT DISTINCT ENDDATE FROM ea_mainproduct WHERE SECCODE = :seccode ORDER BY ENDDATE DESC LIMIT :limit """) with engine.connect() as conn: periods = conn.execute(period_query, {'seccode': seccode, 'limit': limit}).fetchall() # 产品分类数据 product_data = [] for period in periods: query = text(""" SELECT distinct ENDDATE, F002V as category, F003V as content, F005N as revenue, F006N as cost, F007N as profit FROM ea_mainproduct WHERE SECCODE = :seccode AND ENDDATE = :enddate ORDER BY F005N DESC """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'enddate': period[0]}) # Convert result to list to allow multiple iterations rows = list(result) period_products = [] total_revenue = 0 for row in rows: if row.revenue: total_revenue += row.revenue for row in rows: product = { 'category': row.category, 'content': row.content, 'revenue': format_decimal(row.revenue), 'cost': format_decimal(row.cost), 'profit': format_decimal(row.profit), 'profit_margin': format_decimal( (row.profit / row.revenue * 100) if row.revenue and row.profit else None), 'revenue_ratio': format_decimal( (row.revenue / total_revenue * 100) if total_revenue and row.revenue else None) } period_products.append(product) if period_products: product_data.append({ 'period': format_date(period[0]), 'report_type': get_report_type(period[0]), 'total_revenue': format_decimal(total_revenue), 'products': period_products }) # 
行业分类数据(从ea_mainind表) industry_data = [] for period in periods: query = text(""" SELECT distinct ENDDATE, F002V as business_content, F007N as main_revenue, F008N as main_cost, F009N as main_profit, F010N as gross_margin, F012N as revenue_ratio FROM ea_mainind WHERE SECCODE = :seccode AND ENDDATE = :enddate ORDER BY F007N DESC """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'enddate': period[0]}) # Convert result to list to allow multiple iterations rows = list(result) period_industries = [] for row in rows: industry = { 'content': row.business_content, 'revenue': format_decimal(row.main_revenue), 'cost': format_decimal(row.main_cost), 'profit': format_decimal(row.main_profit), 'gross_margin': format_decimal(row.gross_margin), 'revenue_ratio': format_decimal(row.revenue_ratio) } period_industries.append(industry) if period_industries: industry_data.append({ 'period': format_date(period[0]), 'report_type': get_report_type(period[0]), 'industries': period_industries }) return jsonify({ 'success': True, 'data': { 'product_classification': product_data, 'industry_classification': industry_data } }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/forecast/', methods=['GET']) def get_forecast(seccode): """获取业绩预告和预披露时间""" try: # 获取业绩预告 forecast_query = text(""" SELECT distinct DECLAREDATE, F001D as report_date, F002V as forecast_type_code, F003V as forecast_type, F004V as content, F005V as reason, F006C as latest_flag, F007N as profit_lower, F008N as profit_upper, F009N as change_lower, F010N as change_upper, UPDATE_DATE FROM ea_forecast WHERE SECCODE = :seccode ORDER BY F001D DESC, UPDATE_DATE DESC LIMIT 10 """) with engine.connect() as conn: forecast_result = conn.execute(forecast_query, {'seccode': seccode}) forecast_data = [] for row in forecast_result: forecast = { 'declare_date': format_date(row.DECLAREDATE), 'report_date': format_date(row.report_date), 'report_type': 
get_report_type(row.report_date), 'forecast_type': row.forecast_type, 'forecast_type_code': row.forecast_type_code, 'content': row.content, 'reason': row.reason, 'is_latest': row.latest_flag == 'T', 'profit_range': { 'lower': format_decimal(row.profit_lower), 'upper': format_decimal(row.profit_upper), }, 'change_range': { 'lower': format_decimal(row.change_lower), 'upper': format_decimal(row.change_upper), }, 'update_date': format_date(row.UPDATE_DATE) } forecast_data.append(forecast) # 获取预披露时间 pretime_query = text(""" SELECT distinct F001D as report_period, F002D as scheduled_date, F003D as change_date_1, F004D as change_date_2, F005D as change_date_3, F006D as actual_date, F007D as change_date_4, F008D as change_date_5, UPDATE_DATE FROM ea_pretime WHERE SECCODE = :seccode ORDER BY F001D DESC LIMIT 8 """) with engine.connect() as conn: pretime_result = conn.execute(pretime_query, {'seccode': seccode}) pretime_data = [] for row in pretime_result: # 收集所有变更日期 change_dates = [] for date in [row.change_date_1, row.change_date_2, row.change_date_3, row.change_date_4, row.change_date_5]: if date: change_dates.append(format_date(date)) pretime = { 'report_period': format_date(row.report_period), 'report_type': get_report_type(row.report_period), 'scheduled_date': format_date(row.scheduled_date), 'actual_date': format_date(row.actual_date), 'change_dates': change_dates, 'update_date': format_date(row.UPDATE_DATE), 'status': 'completed' if row.actual_date else 'pending' } pretime_data.append(pretime) return jsonify({ 'success': True, 'data': { 'forecasts': forecast_data, 'disclosure_schedule': pretime_data } }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/industry-rank/', methods=['GET']) def get_industry_rank(seccode): """获取行业排名数据""" try: limit = request.args.get('limit', 4, type=int) query = text(""" SELECT distinct F001V as industry_level, F002V as level_description, F003D as report_date, INDNAME as 
industry_name, -- 每股收益 F004N as eps, F005N as eps_industry_avg, F006N as eps_rank, -- 扣除后每股收益 F007N as deducted_eps, F008N as deducted_eps_industry_avg, F009N as deducted_eps_rank, -- 每股净资产 F010N as bvps, F011N as bvps_industry_avg, F012N as bvps_rank, -- 净资产收益率 F013N as roe, F014N as roe_industry_avg, F015N as roe_rank, -- 每股未分配利润 F016N as undistributed_profit_ps, F017N as undistributed_profit_ps_industry_avg, F018N as undistributed_profit_ps_rank, -- 每股经营现金流量 F019N as operating_cash_flow_ps, F020N as operating_cash_flow_ps_industry_avg, F021N as operating_cash_flow_ps_rank, -- 营业收入增长率 F022N as revenue_growth, F023N as revenue_growth_industry_avg, F024N as revenue_growth_rank, -- 净利润增长率 F025N as profit_growth, F026N as profit_growth_industry_avg, F027N as profit_growth_rank, -- 营业利润率 F028N as operating_margin, F029N as operating_margin_industry_avg, F030N as operating_margin_rank, -- 资产负债率 F031N as debt_ratio, F032N as debt_ratio_industry_avg, F033N as debt_ratio_rank, -- 应收账款周转率 F034N as receivable_turnover, F035N as receivable_turnover_industry_avg, F036N as receivable_turnover_rank, UPDATE_DATE FROM ea_finindexrank WHERE SECCODE = :seccode ORDER BY F003D DESC, F001V ASC LIMIT :limit_total """) # 获取多个报告期的数据 with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'limit_total': limit * 4}) # 按报告期和行业级别组织数据 data_by_period = {} for row in result: period = format_date(row.report_date) if period not in data_by_period: data_by_period[period] = [] rank_data = { 'industry_level': row.industry_level, 'level_description': row.level_description, 'industry_name': row.industry_name, 'metrics': { 'eps': { 'value': format_decimal(row.eps), 'industry_avg': format_decimal(row.eps_industry_avg), 'rank': int(row.eps_rank) if row.eps_rank else None }, 'deducted_eps': { 'value': format_decimal(row.deducted_eps), 'industry_avg': format_decimal(row.deducted_eps_industry_avg), 'rank': int(row.deducted_eps_rank) if row.deducted_eps_rank else None }, 'bvps': { 
'value': format_decimal(row.bvps), 'industry_avg': format_decimal(row.bvps_industry_avg), 'rank': int(row.bvps_rank) if row.bvps_rank else None }, 'roe': { 'value': format_decimal(row.roe), 'industry_avg': format_decimal(row.roe_industry_avg), 'rank': int(row.roe_rank) if row.roe_rank else None }, 'operating_cash_flow_ps': { 'value': format_decimal(row.operating_cash_flow_ps), 'industry_avg': format_decimal(row.operating_cash_flow_ps_industry_avg), 'rank': int(row.operating_cash_flow_ps_rank) if row.operating_cash_flow_ps_rank else None }, 'revenue_growth': { 'value': format_decimal(row.revenue_growth), 'industry_avg': format_decimal(row.revenue_growth_industry_avg), 'rank': int(row.revenue_growth_rank) if row.revenue_growth_rank else None }, 'profit_growth': { 'value': format_decimal(row.profit_growth), 'industry_avg': format_decimal(row.profit_growth_industry_avg), 'rank': int(row.profit_growth_rank) if row.profit_growth_rank else None }, 'operating_margin': { 'value': format_decimal(row.operating_margin), 'industry_avg': format_decimal(row.operating_margin_industry_avg), 'rank': int(row.operating_margin_rank) if row.operating_margin_rank else None }, 'debt_ratio': { 'value': format_decimal(row.debt_ratio), 'industry_avg': format_decimal(row.debt_ratio_industry_avg), 'rank': int(row.debt_ratio_rank) if row.debt_ratio_rank else None }, 'receivable_turnover': { 'value': format_decimal(row.receivable_turnover), 'industry_avg': format_decimal(row.receivable_turnover_industry_avg), 'rank': int(row.receivable_turnover_rank) if row.receivable_turnover_rank else None } } } data_by_period[period].append(rank_data) # 转换为列表格式 data = [] for period, ranks in data_by_period.items(): data.append({ 'period': period, 'report_type': get_report_type(period), 'rankings': ranks }) return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/comparison/', methods=['GET']) def 
get_period_comparison(seccode): """获取不同报告期的对比数据""" try: periods = request.args.get('periods', 8, type=int) # 获取多期财务数据进行对比 query = text(""" SELECT distinct fi.ENDDATE, fi.F089N as revenue, fi.F101N as net_profit, fi.F102N as parent_net_profit, fi.F078N as gross_margin, fi.F017N as net_margin, fi.F014N as roe, fi.F016N as roa, fi.F052N as revenue_growth, fi.F053N as profit_growth, fi.F003N as eps, fi.F060N as operating_cash_flow_ps, fi.F042N as current_ratio, fi.F041N as debt_ratio, fi.F105N as operating_cash_flow, fi.F118N as total_assets, fi.F121N as total_liabilities, fi.F128N as total_equity FROM ea_financialindex fi WHERE fi.SECCODE = :seccode ORDER BY fi.ENDDATE DESC LIMIT :periods """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'periods': periods}) data = [] for row in result: period_data = { 'period': format_date(row.ENDDATE), 'report_type': get_report_type(row.ENDDATE), 'performance': { 'revenue': format_decimal(row.revenue), 'net_profit': format_decimal(row.net_profit), 'parent_net_profit': format_decimal(row.parent_net_profit), 'operating_cash_flow': format_decimal(row.operating_cash_flow), }, 'profitability': { 'gross_margin': format_decimal(row.gross_margin), 'net_margin': format_decimal(row.net_margin), 'roe': format_decimal(row.roe), 'roa': format_decimal(row.roa), }, 'growth': { 'revenue_growth': format_decimal(row.revenue_growth), 'profit_growth': format_decimal(row.profit_growth), }, 'per_share': { 'eps': format_decimal(row.eps), 'operating_cash_flow_ps': format_decimal(row.operating_cash_flow_ps), }, 'financial_health': { 'current_ratio': format_decimal(row.current_ratio), 'debt_ratio': format_decimal(row.debt_ratio), 'total_assets': format_decimal(row.total_assets), 'total_liabilities': format_decimal(row.total_liabilities), 'total_equity': format_decimal(row.total_equity), } } data.append(period_data) # 计算同比和环比变化 for i in range(len(data)): if i > 0: # 环比 data[i]['qoq_changes'] = { 'revenue': 
calculate_change(data[i]['performance']['revenue'], data[i - 1]['performance']['revenue']), 'net_profit': calculate_change(data[i]['performance']['net_profit'], data[i - 1]['performance']['net_profit']), } # 同比(找到去年同期) current_period = data[i]['period'] yoy_period = get_yoy_period(current_period) for j in range(len(data)): if data[j]['period'] == yoy_period: data[i]['yoy_changes'] = { 'revenue': calculate_change(data[i]['performance']['revenue'], data[j]['performance']['revenue']), 'net_profit': calculate_change(data[i]['performance']['net_profit'], data[j]['performance']['net_profit']), } break return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 # 辅助函数 def calculate_change(current, previous): """计算变化率""" if previous and current: return format_decimal((current - previous) / abs(previous) * 100) return None def get_yoy_period(date_str): """获取去年同期""" if not date_str: return None try: date = datetime.strptime(date_str, '%Y-%m-%d') yoy_date = date.replace(year=date.year - 1) return yoy_date.strftime('%Y-%m-%d') except: return None @app.route('/api/market/trade/', methods=['GET']) def get_trade_data(seccode): """获取股票交易数据(日K线)""" try: days = request.args.get('days', 60, type=int) end_date = request.args.get('end_date', datetime.now().strftime('%Y-%m-%d')) query = text(""" SELECT TRADEDATE, SECNAME, F002N as pre_close, F003N as open, F004N as volume, F005N as high, F006N as low, F007N as close, F008N as trades_count, F009N as change_amount, F010N as change_percent, F011N as amount, F012N as turnover_rate, F013N as amplitude, F020N as total_shares, F021N as float_shares, F026N as pe_ratio FROM ea_trade WHERE SECCODE = :seccode AND TRADEDATE <= :end_date ORDER BY TRADEDATE DESC LIMIT :days """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'end_date': end_date, 'days': days}) data = [] for row in result: data.append({ 'date': format_date(row.TRADEDATE), 
'stock_name': row.SECNAME, 'open': format_decimal(row.open), 'high': format_decimal(row.high), 'low': format_decimal(row.low), 'close': format_decimal(row.close), 'pre_close': format_decimal(row.pre_close), 'volume': format_decimal(row.volume), 'amount': format_decimal(row.amount), 'change_amount': format_decimal(row.change_amount), 'change_percent': format_decimal(row.change_percent), 'turnover_rate': format_decimal(row.turnover_rate), 'amplitude': format_decimal(row.amplitude), 'trades_count': format_decimal(row.trades_count), 'pe_ratio': format_decimal(row.pe_ratio), 'total_shares': format_decimal(row.total_shares), 'float_shares': format_decimal(row.float_shares), }) # 倒序,让最早的日期在前 data.reverse() # 计算统计数据 if data: prices = [d['close'] for d in data if d['close']] stats = { 'highest': max(prices) if prices else None, 'lowest': min(prices) if prices else None, 'average': sum(prices) / len(prices) if prices else None, 'latest_price': data[-1]['close'] if data else None, 'total_volume': sum([d['volume'] for d in data if d['volume']]) if data else None, 'total_amount': sum([d['amount'] for d in data if d['amount']]) if data else None, } else: stats = {} return jsonify({ 'success': True, 'data': data, 'stats': stats }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/trade/batch', methods=['POST']) def get_batch_trade_data(): """批量获取多只股票的交易数据(日K线) 请求体:{ codes: string[], // 股票代码列表(6位代码) days: number // 获取天数,默认1 } 返回:{ success: true, data: { [seccode]: { data: [], stats: {} } } } """ try: data = request.json codes = data.get('codes', []) days = data.get('days', 1) end_date = data.get('end_date', datetime.now().strftime('%Y-%m-%d')) if not codes: return jsonify({'success': False, 'error': '请提供股票代码列表'}), 400 if len(codes) > 100: return jsonify({'success': False, 'error': '单次最多查询100只股票'}), 400 # 构建批量查询 placeholders = ','.join([f':code{i}' for i in range(len(codes))]) params = {f'code{i}': code for i, code in 
enumerate(codes)} params['end_date'] = end_date params['days'] = days query = text(f""" SELECT SECCODE, TRADEDATE, SECNAME, F002N as pre_close, F003N as open, F004N as volume, F005N as high, F006N as low, F007N as close, F008N as trades_count, F009N as change_amount, F010N as change_percent, F011N as amount, F012N as turnover_rate, F013N as amplitude FROM ea_trade WHERE SECCODE IN ({placeholders}) AND TRADEDATE <= :end_date ORDER BY SECCODE, TRADEDATE DESC """) with engine.connect() as conn: result = conn.execute(query, params) rows = result.fetchall() # 按股票代码分组,每只股票只取最近N天 stock_data = {} stock_counts = {} for row in rows: seccode = row.SECCODE if seccode not in stock_data: stock_data[seccode] = [] stock_counts[seccode] = 0 # 只取指定天数的数据 if stock_counts[seccode] < days: stock_data[seccode].append({ 'date': format_date(row.TRADEDATE), 'stock_name': row.SECNAME, 'open': format_decimal(row.open), 'high': format_decimal(row.high), 'low': format_decimal(row.low), 'close': format_decimal(row.close), 'pre_close': format_decimal(row.pre_close), 'volume': format_decimal(row.volume), 'amount': format_decimal(row.amount), 'change_amount': format_decimal(row.change_amount), 'change_percent': format_decimal(row.change_percent), 'turnover_rate': format_decimal(row.turnover_rate), 'amplitude': format_decimal(row.amplitude), 'trades_count': format_decimal(row.trades_count), }) stock_counts[seccode] += 1 # 倒序每只股票的数据(让最早的日期在前) results = {} for seccode, data_list in stock_data.items(): data_list.reverse() results[seccode] = { 'data': data_list, 'stats': { 'latest_price': data_list[-1]['close'] if data_list else None, 'change_percent': data_list[-1]['change_percent'] if data_list else None, } if data_list else {} } # 为没有数据的股票返回空结果 for code in codes: if code not in results: results[code] = {'data': [], 'stats': {}} return jsonify({ 'success': True, 'data': results }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/funding/', 
methods=['GET']) def get_funding_data(seccode): """获取融资融券数据""" try: days = request.args.get('days', 30, type=int) query = text(""" SELECT TRADEDATE, SECNAME, F001N as financing_balance, F002N as financing_buy, F003N as financing_repay, F004N as securities_balance, F006N as securities_sell, F007N as securities_repay, F008N as securities_balance_amount, F009N as total_balance FROM ea_funding WHERE SECCODE = :seccode ORDER BY TRADEDATE DESC LIMIT :days """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'days': days}) data = [] for row in result: data.append({ 'date': format_date(row.TRADEDATE), 'stock_name': row.SECNAME, 'financing': { 'balance': format_decimal(row.financing_balance), 'buy': format_decimal(row.financing_buy), 'repay': format_decimal(row.financing_repay), 'net': format_decimal( row.financing_buy - row.financing_repay) if row.financing_buy and row.financing_repay else None }, 'securities': { 'balance': format_decimal(row.securities_balance), 'sell': format_decimal(row.securities_sell), 'repay': format_decimal(row.securities_repay), 'balance_amount': format_decimal(row.securities_balance_amount) }, 'total_balance': format_decimal(row.total_balance) }) data.reverse() return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/bigdeal/', methods=['GET']) def get_bigdeal_data(seccode): """获取大宗交易数据""" try: days = request.args.get('days', 30, type=int) query = text(""" SELECT TRADEDATE, SECNAME, F001V as exchange, F002V as buyer_dept, F003V as seller_dept, F004N as price, F005N as volume, F006N as amount, F007N as seq_no FROM ea_bigdeal WHERE SECCODE = :seccode ORDER BY TRADEDATE DESC, F007N LIMIT :days """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'days': days}) data = [] for row in result: data.append({ 'date': format_date(row.TRADEDATE), 'stock_name': row.SECNAME, 'exchange': 
row.exchange, 'buyer_dept': row.buyer_dept, 'seller_dept': row.seller_dept, 'price': format_decimal(row.price), 'volume': format_decimal(row.volume), 'amount': format_decimal(row.amount), 'seq_no': int(row.seq_no) if row.seq_no else None }) # 按日期分组统计 daily_stats = {} for item in data: date = item['date'] if date not in daily_stats: daily_stats[date] = { 'date': date, 'count': 0, 'total_volume': 0, 'total_amount': 0, 'avg_price': 0, 'deals': [] } daily_stats[date]['count'] += 1 daily_stats[date]['total_volume'] += item['volume'] or 0 daily_stats[date]['total_amount'] += item['amount'] or 0 daily_stats[date]['deals'].append(item) # 计算平均价格 for date in daily_stats: if daily_stats[date]['total_volume'] > 0: daily_stats[date]['avg_price'] = daily_stats[date]['total_amount'] / daily_stats[date]['total_volume'] return jsonify({ 'success': True, 'data': data, 'daily_stats': list(daily_stats.values()) }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/unusual/', methods=['GET']) def get_unusual_data(seccode): """获取龙虎榜数据""" try: days = request.args.get('days', 30, type=int) query = text(""" SELECT TRADEDATE, SECNAME, F001V as info_type_code, F002V as info_type, F003C as trade_type, F004N as rank_no, F005V as dept_name, F006N as buy_amount, F007N as sell_amount, F008N as net_amount FROM ea_unusual WHERE SECCODE = :seccode ORDER BY TRADEDATE DESC, F004N LIMIT 100 """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode}) data = [] for row in result: data.append({ 'date': format_date(row.TRADEDATE), 'stock_name': row.SECNAME, 'info_type': row.info_type, 'info_type_code': row.info_type_code, 'trade_type': 'buy' if row.trade_type == 'B' else 'sell' if row.trade_type == 'S' else 'unknown', 'rank': int(row.rank_no) if row.rank_no else None, 'dept_name': row.dept_name, 'buy_amount': format_decimal(row.buy_amount), 'sell_amount': format_decimal(row.sell_amount), 'net_amount': 
format_decimal(row.net_amount) }) # 按日期分组 grouped_data = {} for item in data: date = item['date'] if date not in grouped_data: grouped_data[date] = { 'date': date, 'info_types': set(), 'buyers': [], 'sellers': [], 'total_buy': 0, 'total_sell': 0, 'net_amount': 0 } grouped_data[date]['info_types'].add(item['info_type']) if item['trade_type'] == 'buy': grouped_data[date]['buyers'].append(item) grouped_data[date]['total_buy'] += item['buy_amount'] or 0 elif item['trade_type'] == 'sell': grouped_data[date]['sellers'].append(item) grouped_data[date]['total_sell'] += item['sell_amount'] or 0 grouped_data[date]['net_amount'] = grouped_data[date]['total_buy'] - grouped_data[date]['total_sell'] # 转换set为list for date in grouped_data: grouped_data[date]['info_types'] = list(grouped_data[date]['info_types']) return jsonify({ 'success': True, 'data': data, 'grouped_data': list(grouped_data.values()) }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/pledge/', methods=['GET']) def get_pledge_data(seccode): """获取股权质押数据""" try: query = text(""" SELECT ENDDATE, STARTDATE, SECNAME, F001N as unrestricted_pledge, F002N as restricted_pledge, F003N as total_shares_a, F004N as pledge_count, F005N as pledge_ratio FROM ea_pledgeratio WHERE SECCODE = :seccode ORDER BY ENDDATE DESC LIMIT 12 """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode}) data = [] for row in result: total_pledge = (row.unrestricted_pledge or 0) + (row.restricted_pledge or 0) data.append({ 'end_date': format_date(row.ENDDATE), 'start_date': format_date(row.STARTDATE), 'stock_name': row.SECNAME, 'unrestricted_pledge': format_decimal(row.unrestricted_pledge), 'restricted_pledge': format_decimal(row.restricted_pledge), 'total_pledge': format_decimal(total_pledge), 'total_shares': format_decimal(row.total_shares_a), 'pledge_count': int(row.pledge_count) if row.pledge_count else None, 'pledge_ratio': format_decimal(row.pledge_ratio) }) 
return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/summary/', methods=['GET']) def get_market_summary(seccode): """获取市场数据汇总""" try: # 获取最新交易数据 trade_query = text(""" SELECT * FROM ea_trade WHERE SECCODE = :seccode ORDER BY TRADEDATE DESC LIMIT 1 """) # 获取最新融资融券数据 funding_query = text(""" SELECT * FROM ea_funding WHERE SECCODE = :seccode ORDER BY TRADEDATE DESC LIMIT 1 """) # 获取最新质押数据 pledge_query = text(""" SELECT * FROM ea_pledgeratio WHERE SECCODE = :seccode ORDER BY ENDDATE DESC LIMIT 1 """) with engine.connect() as conn: trade_result = conn.execute(trade_query, {'seccode': seccode}).fetchone() with engine.connect() as conn: funding_result = conn.execute(funding_query, {'seccode': seccode}).fetchone() with engine.connect() as conn: pledge_result = conn.execute(pledge_query, {'seccode': seccode}).fetchone() summary = { 'stock_code': seccode, 'stock_name': trade_result.SECNAME if trade_result else None, 'latest_trade': { 'date': format_date(trade_result.TRADEDATE) if trade_result else None, 'close': format_decimal(trade_result.F007N) if trade_result else None, 'change_percent': format_decimal(trade_result.F010N) if trade_result else None, 'volume': format_decimal(trade_result.F004N) if trade_result else None, 'amount': format_decimal(trade_result.F011N) if trade_result else None, 'pe_ratio': format_decimal(trade_result.F026N) if trade_result else None, 'turnover_rate': format_decimal(trade_result.F012N) if trade_result else None, } if trade_result else None, 'latest_funding': { 'date': format_date(funding_result.TRADEDATE) if funding_result else None, 'financing_balance': format_decimal(funding_result.F001N) if funding_result else None, 'securities_balance': format_decimal(funding_result.F004N) if funding_result else None, 'total_balance': format_decimal(funding_result.F009N) if funding_result else None, } if funding_result else None, 'latest_pledge': { 'date': 
format_date(pledge_result.ENDDATE) if pledge_result else None, 'pledge_ratio': format_decimal(pledge_result.F005N) if pledge_result else None, 'pledge_count': int(pledge_result.F004N) if pledge_result and pledge_result.F004N else None, } if pledge_result else None } return jsonify({ 'success': True, 'data': summary }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/stocks/search', methods=['GET']) def search_stocks(): """搜索股票和指数(支持代码、名称搜索)""" try: query = request.args.get('q', '').strip() limit = request.args.get('limit', 20, type=int) search_type = request.args.get('type', 'all') # all, stock, index if not query: return jsonify({ 'success': False, 'error': '请输入搜索关键词' }), 400 results = [] with engine.connect() as conn: # 搜索指数(优先显示指数,因为通常用户搜索代码时指数更常用) if search_type in ('all', 'index'): index_sql = text(""" SELECT DISTINCT INDEXCODE as stock_code, SECNAME as stock_name, INDEXNAME as full_name, F018V as exchange FROM ea_exchangeindex WHERE ( UPPER(INDEXCODE) LIKE UPPER(:query_pattern) OR UPPER(SECNAME) LIKE UPPER(:query_pattern) OR UPPER(INDEXNAME) LIKE UPPER(:query_pattern) ) ORDER BY CASE WHEN UPPER(INDEXCODE) = UPPER(:exact_query) THEN 1 WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2 WHEN UPPER(INDEXCODE) LIKE UPPER(:prefix_pattern) THEN 3 WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 4 ELSE 5 END, INDEXCODE LIMIT :limit """) index_result = conn.execute(index_sql, { 'query_pattern': f'%{query}%', 'exact_query': query, 'prefix_pattern': f'{query}%', 'limit': limit }).fetchall() for row in index_result: results.append({ 'stock_code': row.stock_code, 'stock_name': row.stock_name, 'full_name': row.full_name, 'exchange': row.exchange, 'isIndex': True, 'security_type': '指数' }) # 搜索股票 if search_type in ('all', 'stock'): stock_sql = text(""" SELECT DISTINCT SECCODE as stock_code, SECNAME as stock_name, F001V as pinyin_abbr, F003V as security_type, F005V as exchange, F011V as listing_status FROM ea_stocklist WHERE 
( UPPER(SECCODE) LIKE UPPER(:query_pattern) OR UPPER(SECNAME) LIKE UPPER(:query_pattern) OR UPPER(F001V) LIKE UPPER(:query_pattern) ) AND (F011V = '正常上市' OR F010V = '013001') AND F003V IN ('A股', 'B股') ORDER BY CASE WHEN UPPER(SECCODE) = UPPER(:exact_query) THEN 1 WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2 WHEN UPPER(F001V) = UPPER(:exact_query) THEN 3 WHEN UPPER(SECCODE) LIKE UPPER(:prefix_pattern) THEN 4 WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 5 WHEN UPPER(F001V) LIKE UPPER(:prefix_pattern) THEN 6 ELSE 7 END, SECCODE LIMIT :limit """) stock_result = conn.execute(stock_sql, { 'query_pattern': f'%{query}%', 'exact_query': query, 'prefix_pattern': f'{query}%', 'limit': limit }).fetchall() for row in stock_result: results.append({ 'stock_code': row.stock_code, 'stock_name': row.stock_name, 'pinyin_abbr': row.pinyin_abbr, 'security_type': row.security_type, 'exchange': row.exchange, 'listing_status': row.listing_status, 'isIndex': False }) # 如果搜索全部,按相关性重新排序(精确匹配优先) if search_type == 'all': def sort_key(item): code = item['stock_code'].upper() name = item['stock_name'].upper() q = query.upper() # 精确匹配代码优先 if code == q: return (0, not item['isIndex'], code) # 指数优先 # 精确匹配名称 if name == q: return (1, not item['isIndex'], code) # 前缀匹配代码 if code.startswith(q): return (2, not item['isIndex'], code) # 前缀匹配名称 if name.startswith(q): return (3, not item['isIndex'], code) return (4, not item['isIndex'], code) results.sort(key=sort_key) # 限制总数 results = results[:limit] return jsonify({ 'success': True, 'data': results, 'count': len(results) }) except Exception as e: app.logger.error(f"搜索股票/指数错误: {e}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/heatmap', methods=['GET']) def get_market_heatmap(): """获取市场热力图数据(基于市值和涨跌幅)""" try: # 获取交易日期参数 trade_date = request.args.get('date') # 前端显示用的limit,但统计数据会基于全部股票 display_limit = request.args.get('limit', 500, type=int) with engine.connect() as conn: # 如果没有指定日期,获取最新交易日 if not trade_date: 
latest_date_result = conn.execute(text(""" SELECT MAX(TRADEDATE) as latest_date FROM ea_trade """)).fetchone() trade_date = latest_date_result.latest_date if latest_date_result else None if not trade_date: return jsonify({ 'success': False, 'error': '无法获取交易数据' }), 404 # 获取全部股票数据用于统计 all_stocks_sql = text(""" SELECT t.SECCODE as stock_code, t.SECNAME as stock_name, t.F010N as change_percent, -- 涨跌幅 t.F007N as close_price, -- 收盘价 t.F021N * t.F007N / 100000000 as market_cap, -- 市值(亿元) t.F011N / 100000000 as amount, -- 成交额(亿元) t.F012N as turnover_rate, -- 换手率 b.F034V as industry, -- 申万行业分类一级名称 b.F026V as province -- 所属省份 FROM ea_trade t LEFT JOIN ea_baseinfo b ON t.SECCODE = b.SECCODE WHERE t.TRADEDATE = :trade_date AND t.F010N IS NOT NULL -- 仅统计当日有涨跌幅数据的股票 ORDER BY market_cap DESC """) all_result = conn.execute(all_stocks_sql, { 'trade_date': trade_date }).fetchall() # 计算统计数据(基于全部股票) total_market_cap = 0 total_amount = 0 rising_count = 0 falling_count = 0 flat_count = 0 all_data = [] for row in all_result: # F010N 已在 SQL 中确保非空 change_percent = float(row.change_percent) market_cap = float(row.market_cap) if row.market_cap else 0 amount = float(row.amount) if row.amount else 0 total_market_cap += market_cap total_amount += amount if change_percent > 0: rising_count += 1 elif change_percent < 0: falling_count += 1 else: flat_count += 1 all_data.append({ 'stock_code': row.stock_code, 'stock_name': row.stock_name, 'change_percent': change_percent, 'close_price': float(row.close_price) if row.close_price else 0, 'market_cap': market_cap, 'amount': amount, 'turnover_rate': float(row.turnover_rate) if row.turnover_rate else 0, 'industry': row.industry, 'province': row.province }) # 只返回前display_limit条用于热力图显示 heatmap_data = all_data[:display_limit] return jsonify({ 'success': True, 'data': heatmap_data, 'trade_date': trade_date.strftime('%Y-%m-%d') if hasattr(trade_date, 'strftime') else str(trade_date), 'count': len(all_data), # 全部股票数量 'display_count': len(heatmap_data), # 
显示的股票数量 'statistics': { 'total_market_cap': round(total_market_cap, 2), # 总市值(亿元) 'total_amount': round(total_amount, 2), # 总成交额(亿元) 'rising_count': rising_count, # 上涨家数 'falling_count': falling_count, # 下跌家数 'flat_count': flat_count # 平盘家数 } }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/statistics', methods=['GET']) def get_market_statistics(): """获取市场统计数据(从ea_blocktrading表)""" try: # 获取交易日期参数 trade_date = request.args.get('date') with engine.connect() as conn: # 如果没有指定日期,获取最新交易日 if not trade_date: latest_date_result = conn.execute(text(""" SELECT MAX(TRADEDATE) as latest_date FROM ea_blocktrading """)).fetchone() trade_date = latest_date_result.latest_date if latest_date_result else None if not trade_date: return jsonify({ 'success': False, 'error': '无法获取统计数据' }), 404 # 获取沪深两市的统计数据 stats_sql = text(""" SELECT EXCHANGECODE, EXCHANGENAME, F001V as indicator_code, F002V as indicator_name, F003N as indicator_value, F004V as unit, TRADEDATE FROM ea_blocktrading WHERE TRADEDATE = :trade_date AND EXCHANGECODE IN ('012001', '012002') -- 只获取上交所和深交所的数据 AND F001V IN ( '250006', '250014', -- 深交所股票总市值、上交所市价总值 '250007', '250015', -- 深交所股票流通市值、上交所流通市值 '250008', -- 深交所股票成交金额 '250010', '250019', -- 深交所股票平均市盈率、上交所平均市盈率 '250050', '250001' -- 上交所上市公司家数、深交所上市公司数 ) """) result = conn.execute(stats_sql, { 'trade_date': trade_date }).fetchall() # 整理数据 statistics = {} for row in result: key = f"{row.EXCHANGECODE}_{row.indicator_code}" statistics[key] = { 'exchange_code': row.EXCHANGECODE, 'exchange_name': row.EXCHANGENAME, 'indicator_code': row.indicator_code, 'indicator_name': row.indicator_name, 'value': float(row.indicator_value) if row.indicator_value else 0, 'unit': row.unit } # 汇总数据 summary = { 'total_market_cap': 0, # 总市值 'total_float_cap': 0, # 流通市值 'total_amount': 0, # 成交额 'sh_pe_ratio': 0, # 上交所市盈率 'sz_pe_ratio': 0, # 深交所市盈率 'sh_companies': 0, # 上交所上市公司数 'sz_companies': 0 # 深交所上市公司数 } # 计算汇总值 if '012001_250014' in 
statistics: # 上交所市价总值 summary['total_market_cap'] += statistics['012001_250014']['value'] if '012002_250006' in statistics: # 深交所股票总市值 summary['total_market_cap'] += statistics['012002_250006']['value'] if '012001_250015' in statistics: # 上交所流通市值 summary['total_float_cap'] += statistics['012001_250015']['value'] if '012002_250007' in statistics: # 深交所股票流通市值 summary['total_float_cap'] += statistics['012002_250007']['value'] # 成交额需要获取上交所的数据 # 获取上交所成交金额 sh_amount_result = conn.execute(text(""" SELECT F003N FROM ea_blocktrading WHERE TRADEDATE = :trade_date AND EXCHANGECODE = '012001' AND F002V LIKE '%成交金额%' LIMIT 1 """), {'trade_date': trade_date}).fetchone() sh_amount = float(sh_amount_result.F003N) if sh_amount_result and sh_amount_result.F003N else 0 sz_amount = statistics['012002_250008']['value'] if '012002_250008' in statistics else 0 summary['total_amount'] = sh_amount + sz_amount if '012001_250019' in statistics: # 上交所平均市盈率 summary['sh_pe_ratio'] = statistics['012001_250019']['value'] if '012002_250010' in statistics: # 深交所股票平均市盈率 summary['sz_pe_ratio'] = statistics['012002_250010']['value'] if '012001_250050' in statistics: # 上交所上市公司家数 summary['sh_companies'] = int(statistics['012001_250050']['value']) if '012002_250001' in statistics: # 深交所上市公司数 summary['sz_companies'] = int(statistics['012002_250001']['value']) # 获取可用的交易日期列表 available_dates_result = conn.execute(text(""" SELECT DISTINCT TRADEDATE FROM ea_blocktrading WHERE EXCHANGECODE IN ('012001', '012002') ORDER BY TRADEDATE DESC LIMIT 30 """)).fetchall() available_dates = [str(row.TRADEDATE) for row in available_dates_result] # 格式化日期为 YYYY-MM-DD formatted_trade_date = trade_date.strftime('%Y-%m-%d') if hasattr(trade_date, 'strftime') else str(trade_date).split(' ')[0][:10] formatted_available_dates = [ d.strftime('%Y-%m-%d') if hasattr(d, 'strftime') else str(d).split(' ')[0][:10] for d in [row.TRADEDATE for row in available_dates_result] ] return jsonify({ 'success': True, 'trade_date': 
formatted_trade_date, 'summary': summary, 'details': list(statistics.values()), 'available_dates': formatted_available_dates }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/concepts/daily-top', methods=['GET']) def get_daily_top_concepts(): """获取每日涨幅靠前的概念板块""" try: # 获取交易日期参数 trade_date = request.args.get('date') limit = request.args.get('limit', 6, type=int) # 构建概念中心API的URL concept_api_url = 'http://222.128.1.157:16801/search' # 准备请求数据 request_data = { 'query': '', 'size': limit, 'page': 1, 'sort_by': 'change_pct' } if trade_date: request_data['trade_date'] = trade_date # 调用概念中心API response = requests.post(concept_api_url, json=request_data, timeout=10) if response.status_code == 200: data = response.json() top_concepts = [] for concept in data.get('results', []): # 处理 stocks 字段:兼容 {name, code} 和 {stock_name, stock_code} 两种格式 raw_stocks = concept.get('stocks', []) formatted_stocks = [] for stock in raw_stocks: # 优先使用 stock_name,其次使用 name stock_name = stock.get('stock_name') or stock.get('name', '') stock_code = stock.get('stock_code') or stock.get('code', '') formatted_stocks.append({ 'stock_name': stock_name, 'stock_code': stock_code, 'name': stock_name, # 兼容旧格式 'code': stock_code # 兼容旧格式 }) # 保持与 /concept-api/search 相同的字段结构,并添加新字段 top_concepts.append({ 'concept_id': concept.get('concept_id'), 'concept': concept.get('concept'), # 原始字段名 'concept_name': concept.get('concept'), # 兼容旧字段名 'description': concept.get('description'), 'stock_count': concept.get('stock_count', 0), 'score': concept.get('score'), 'match_type': concept.get('match_type'), 'price_info': concept.get('price_info', {}), # 完整的价格信息 'change_percent': concept.get('price_info', {}).get('avg_change_pct', 0), # 兼容旧字段 'tags': concept.get('tags', []), # 标签列表 'outbreak_dates': concept.get('outbreak_dates', []), # 爆发日期列表 'hierarchy': concept.get('hierarchy'), # 层级信息 {lv1, lv2, lv3} 'stocks': formatted_stocks, # 返回格式化后的股票列表 'hot_score': 
concept.get('hot_score') }) # 格式化日期为 YYYY-MM-DD price_date = data.get('price_date', '') formatted_date = str(price_date).split(' ')[0][:10] if price_date else '' return jsonify({ 'success': True, 'data': top_concepts, 'trade_date': formatted_date, 'count': len(top_concepts) }) else: return jsonify({ 'success': False, 'error': '获取概念数据失败' }), 500 except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 # ==================== 热点概览 API ==================== @app.route('/api/market/hotspot-overview', methods=['GET']) def get_hotspot_overview(): """ 获取热点概览数据(用于个股中心的热点概览图表) 返回:指数分时数据 + 概念异动标注 数据来源: - 指数分时:ClickHouse index_minute 表 - 概念异动:MySQL concept_anomaly_hybrid 表(来自 realtime_detector.py) """ try: trade_date = request.args.get('date') index_code = request.args.get('index', '000001.SH') # 如果没有指定日期,使用最新交易日 if not trade_date: today = date.today() if today in trading_days_set: trade_date = today.strftime('%Y-%m-%d') else: target_date = get_trading_day_near_date(today) trade_date = target_date.strftime('%Y-%m-%d') if target_date else today.strftime('%Y-%m-%d') # 1. 
获取指数分时数据 client = get_clickhouse_client() target_date_obj = datetime.strptime(trade_date, '%Y-%m-%d').date() index_data = client.execute( """ SELECT timestamp, open, high, low, close, volume FROM index_minute WHERE code = %(code)s AND toDate(timestamp) = %(date)s ORDER BY timestamp """, { 'code': index_code, 'date': target_date_obj } ) # 获取昨收价 code_no_suffix = index_code.split('.')[0] prev_close = None with engine.connect() as conn: prev_result = conn.execute(text(""" SELECT F006N FROM ea_exchangetrade WHERE INDEXCODE = :code AND TRADEDATE < :today ORDER BY TRADEDATE DESC LIMIT 1 """), { 'code': code_no_suffix, 'today': target_date_obj }).fetchone() if prev_result and prev_result[0]: prev_close = float(prev_result[0]) # 格式化指数数据 index_timeline = [] for row in index_data: ts, open_p, high_p, low_p, close_p, vol = row change_pct = None if prev_close and close_p: change_pct = round((float(close_p) - prev_close) / prev_close * 100, 4) index_timeline.append({ 'time': ts.strftime('%H:%M'), 'timestamp': ts.isoformat(), 'price': float(close_p) if close_p else None, 'open': float(open_p) if open_p else None, 'high': float(high_p) if high_p else None, 'low': float(low_p) if low_p else None, 'volume': int(vol) if vol else 0, 'change_pct': change_pct }) # 2. 
获取概念异动数据(优先从 V2 表,fallback 到旧表) alerts = [] use_v2 = False with engine.connect() as conn: # 尝试查询 V2 表(时间片对齐 + 持续确认版本) try: v2_result = conn.execute(text(""" SELECT concept_id, alert_time, trade_date, alert_type, final_score, rule_score, ml_score, trigger_reason, confirm_ratio, alpha, alpha_zscore, amt_zscore, rank_zscore, momentum_3m, momentum_5m, limit_up_ratio, triggered_rules FROM concept_anomaly_v2 WHERE trade_date = :trade_date ORDER BY alert_time """), {'trade_date': trade_date}) v2_rows = v2_result.fetchall() if v2_rows: use_v2 = True for row in v2_rows: triggered_rules = None if row[16]: try: triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16] except: pass alerts.append({ 'concept_id': row[0], 'concept_name': row[0], # 后面会填充 'time': row[1].strftime('%H:%M') if row[1] else None, 'timestamp': row[1].isoformat() if row[1] else None, 'alert_type': row[3], 'final_score': float(row[4]) if row[4] else None, 'rule_score': float(row[5]) if row[5] else None, 'ml_score': float(row[6]) if row[6] else None, 'trigger_reason': row[7], # V2 新增字段 'confirm_ratio': float(row[8]) if row[8] else None, 'alpha': float(row[9]) if row[9] else None, 'alpha_zscore': float(row[10]) if row[10] else None, 'amt_zscore': float(row[11]) if row[11] else None, 'rank_zscore': float(row[12]) if row[12] else None, 'momentum_3m': float(row[13]) if row[13] else None, 'momentum_5m': float(row[14]) if row[14] else None, 'limit_up_ratio': float(row[15]) if row[15] else 0, 'triggered_rules': triggered_rules, # 兼容字段 'importance_score': float(row[4]) / 100 if row[4] else None, 'is_v2': True, }) except Exception as v2_err: app.logger.debug(f"V2 表查询失败,使用旧表: {v2_err}") # Fallback: 查询旧表 if not use_v2: try: alert_result = conn.execute(text(""" SELECT a.concept_id, a.alert_time, a.trade_date, a.alert_type, a.final_score, a.rule_score, a.ml_score, a.trigger_reason, a.alpha, a.alpha_delta, a.amt_ratio, a.amt_delta, a.rank_pct, a.limit_up_ratio, a.stock_count, a.total_amt, 
a.triggered_rules FROM concept_anomaly_hybrid a WHERE a.trade_date = :trade_date ORDER BY a.alert_time """), {'trade_date': trade_date}) for row in alert_result: triggered_rules = None if row[16]: try: triggered_rules = json.loads(row[16]) if isinstance(row[16], str) else row[16] except: pass limit_up_ratio = float(row[13]) if row[13] else 0 stock_count = int(row[14]) if row[14] else 0 limit_up_count = int(limit_up_ratio * stock_count) if stock_count > 0 else 0 alerts.append({ 'concept_id': row[0], 'concept_name': row[0], 'time': row[1].strftime('%H:%M') if row[1] else None, 'timestamp': row[1].isoformat() if row[1] else None, 'alert_type': row[3], 'final_score': float(row[4]) if row[4] else None, 'rule_score': float(row[5]) if row[5] else None, 'ml_score': float(row[6]) if row[6] else None, 'trigger_reason': row[7], 'alpha': float(row[8]) if row[8] else None, 'alpha_delta': float(row[9]) if row[9] else None, 'amt_ratio': float(row[10]) if row[10] else None, 'amt_delta': float(row[11]) if row[11] else None, 'rank_pct': float(row[12]) if row[12] else None, 'limit_up_ratio': limit_up_ratio, 'limit_up_count': limit_up_count, 'stock_count': stock_count, 'total_amt': float(row[15]) if row[15] else None, 'triggered_rules': triggered_rules, 'importance_score': float(row[4]) / 100 if row[4] else None, 'is_v2': False, }) except Exception as old_err: app.logger.debug(f"旧表查询也失败: {old_err}") # 尝试批量获取概念名称 if alerts: concept_ids = list(set(a['concept_id'] for a in alerts)) concept_names = {} # 初始化 concept_names 字典 try: from elasticsearch import Elasticsearch es_client = Elasticsearch(["http://222.128.1.157:19200"]) es_result = es_client.mget( index='concept_library_v3', body={'ids': concept_ids}, _source=['concept'] ) for doc in es_result.get('docs', []): if doc.get('found') and doc.get('_source'): concept_names[doc['_id']] = doc['_source'].get('concept', doc['_id']) # 更新 alerts 中的概念名称 for alert in alerts: if alert['concept_id'] in concept_names: alert['concept_name'] = 
concept_names[alert['concept_id']] except Exception as e: app.logger.warning(f"获取概念名称失败: {e}") # 计算统计信息 day_high = max([d['price'] for d in index_timeline if d['price']], default=None) day_low = min([d['price'] for d in index_timeline if d['price']], default=None) latest_price = index_timeline[-1]['price'] if index_timeline else None latest_change_pct = index_timeline[-1]['change_pct'] if index_timeline else None return jsonify({ 'success': True, 'data': { 'trade_date': trade_date, 'index': { 'code': index_code, 'name': '上证指数' if index_code == '000001.SH' else index_code, 'prev_close': prev_close, 'latest_price': latest_price, 'change_pct': latest_change_pct, 'high': day_high, 'low': day_low, 'timeline': index_timeline }, 'alerts': alerts, 'alert_count': len(alerts), 'alert_summary': { 'surge': len([a for a in alerts if a['alert_type'] == 'surge']), 'surge_up': len([a for a in alerts if a['alert_type'] == 'surge_up']), 'surge_down': len([a for a in alerts if a['alert_type'] == 'surge_down']), 'volume_surge_up': len([a for a in alerts if a['alert_type'] == 'volume_surge_up']), 'shrink_surge_up': len([a for a in alerts if a['alert_type'] == 'shrink_surge_up']), 'volume_oscillation': len([a for a in alerts if a['alert_type'] == 'volume_oscillation']), 'limit_up': len([a for a in alerts if a['alert_type'] == 'limit_up']), 'volume_spike': len([a for a in alerts if a['alert_type'] == 'volume_spike']), 'rank_jump': len([a for a in alerts if a['alert_type'] == 'rank_jump']) } } }) except Exception as e: import traceback error_trace = traceback.format_exc() app.logger.error(f"获取热点概览数据失败: {error_trace}") return jsonify({ 'success': False, 'error': str(e), 'traceback': error_trace # 临时返回完整错误信息用于调试 }), 500 @app.route('/api/concept//stocks', methods=['GET']) def get_concept_stocks(concept_id): """ 获取概念的相关股票列表(带实时涨跌幅) Args: concept_id: 概念 ID 或概念名称(支持两种方式查询) Returns: - stocks: 股票列表 [{code, name, reason, change_pct}, ...] 
""" try: from elasticsearch import Elasticsearch from clickhouse_driver import Client es_client = Elasticsearch(["http://222.128.1.157:19200"]) # 1. 尝试多种方式获取概念数据 source = None concept_name = concept_id # 方式1: 先尝试按 ID 查询 try: es_result = es_client.get(index='concept_library_v3', id=concept_id) if es_result.get('found'): source = es_result.get('_source', {}) concept_name = source.get('concept', concept_id) except: pass # 方式2: 如果按 ID 没找到,尝试按概念名称搜索 if not source: try: search_result = es_client.search( index='concept_library_v3', body={ 'query': { 'term': { 'concept.keyword': concept_id } }, 'size': 1 } ) hits = search_result.get('hits', {}).get('hits', []) if hits: source = hits[0].get('_source', {}) concept_name = source.get('concept', concept_id) except Exception as search_err: app.logger.debug(f"ES 搜索概念失败: {search_err}") if not source: return jsonify({ 'success': False, 'error': f'概念 {concept_id} 不存在' }), 404 raw_stocks = source.get('stocks', []) if not raw_stocks: return jsonify({ 'success': True, 'data': { 'concept_id': concept_id, 'concept_name': concept_name, 'stocks': [] } }) # 提取股票代码和原因 stocks_info = [] stock_codes = [] for s in raw_stocks: if isinstance(s, dict): code = s.get('code', '') if code and len(code) == 6: stocks_info.append({ 'code': code, 'name': s.get('name', ''), 'reason': s.get('reason', '') }) stock_codes.append(code) if not stock_codes: return jsonify({ 'success': True, 'data': { 'concept_id': concept_id, 'concept_name': concept_name, 'stocks': stocks_info } }) # 2. 
获取最新交易日和前一交易日 today = datetime.now().date() trading_day = None prev_trading_day = None with engine.connect() as conn: # 获取最新交易日 result = conn.execute(text(""" SELECT EXCHANGE_DATE FROM trading_days WHERE EXCHANGE_DATE <= :today ORDER BY EXCHANGE_DATE DESC LIMIT 1 """), {"today": today}).fetchone() if result: trading_day = result[0].date() if hasattr(result[0], 'date') else result[0] # 获取前一交易日 if trading_day: result = conn.execute(text(""" SELECT EXCHANGE_DATE FROM trading_days WHERE EXCHANGE_DATE < :date ORDER BY EXCHANGE_DATE DESC LIMIT 1 """), {"date": trading_day}).fetchone() if result: prev_trading_day = result[0].date() if hasattr(result[0], 'date') else result[0] # 3. 从 MySQL ea_trade 获取前一交易日收盘价(F007N) prev_close_map = {} if prev_trading_day and stock_codes: with engine.connect() as conn: placeholders = ','.join([f':code{i}' for i in range(len(stock_codes))]) params = {f'code{i}': code for i, code in enumerate(stock_codes)} params['trade_date'] = prev_trading_day result = conn.execute(text(f""" SELECT SECCODE, F007N FROM ea_trade WHERE SECCODE IN ({placeholders}) AND TRADEDATE = :trade_date AND F007N > 0 """), params).fetchall() prev_close_map = {row[0]: float(row[1]) for row in result if row[1]} # 4. 
从 ClickHouse 获取最新价格 current_price_map = {} if stock_codes: try: ch_client = Client( host='127.0.0.1', port=9000, user='default', password='Zzl33818!', database='stock' ) # 转换为 ClickHouse 格式 ch_codes = [] code_mapping = {} for code in stock_codes: if code.startswith('6'): ch_code = f"{code}.SH" elif code.startswith('0') or code.startswith('3'): ch_code = f"{code}.SZ" else: ch_code = f"{code}.BJ" ch_codes.append(ch_code) code_mapping[ch_code] = code ch_codes_str = "','".join(ch_codes) # 查询当天最新价格 query = f""" SELECT code, close FROM stock_minute WHERE code IN ('{ch_codes_str}') AND toDate(timestamp) = today() ORDER BY timestamp DESC LIMIT 1 BY code """ result = ch_client.execute(query) for row in result: ch_code, close_price = row if ch_code in code_mapping and close_price: original_code = code_mapping[ch_code] current_price_map[original_code] = float(close_price) except Exception as ch_err: app.logger.warning(f"ClickHouse 获取价格失败: {ch_err}") # 5. 计算涨跌幅并合并数据 result_stocks = [] for stock in stocks_info: code = stock['code'] prev_close = prev_close_map.get(code) current_price = current_price_map.get(code) change_pct = None if prev_close and current_price and prev_close > 0: change_pct = round((current_price - prev_close) / prev_close * 100, 2) result_stocks.append({ 'code': code, 'name': stock['name'], 'reason': stock['reason'], 'change_pct': change_pct, 'price': current_price, 'prev_close': prev_close }) # 按涨跌幅排序(涨停优先) result_stocks.sort(key=lambda x: x.get('change_pct') if x.get('change_pct') is not None else -999, reverse=True) return jsonify({ 'success': True, 'data': { 'concept_id': concept_id, 'concept_name': concept_name, 'stock_count': len(result_stocks), 'trading_day': str(trading_day) if trading_day else None, 'stocks': result_stocks } }) except Exception as e: import traceback app.logger.error(f"获取概念股票失败: {traceback.format_exc()}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/concept-alerts', methods=['GET']) def 
get_concept_alerts(): """ 获取概念异动列表(支持分页和筛选) """ try: trade_date = request.args.get('date') alert_type = request.args.get('type') # surge/limit_up/rank_jump concept_type = request.args.get('concept_type') # leaf/lv1/lv2/lv3 limit = request.args.get('limit', 50, type=int) offset = request.args.get('offset', 0, type=int) # 构建查询条件 conditions = [] params = {'limit': limit, 'offset': offset} if trade_date: conditions.append("trade_date = :trade_date") params['trade_date'] = trade_date else: conditions.append("trade_date = CURDATE()") if alert_type: conditions.append("alert_type = :alert_type") params['alert_type'] = alert_type if concept_type: conditions.append("concept_type = :concept_type") params['concept_type'] = concept_type where_clause = " AND ".join(conditions) if conditions else "1=1" with engine.connect() as conn: # 获取总数 count_sql = text(f"SELECT COUNT(*) FROM concept_minute_alert WHERE {where_clause}") total = conn.execute(count_sql, params).scalar() # 获取数据 query_sql = text(f""" SELECT id, concept_id, concept_name, alert_time, alert_type, trade_date, change_pct, prev_change_pct, change_delta, limit_up_count, prev_limit_up_count, limit_up_delta, rank_position, prev_rank_position, rank_delta, index_price, index_change_pct, stock_count, concept_type, extra_info FROM concept_minute_alert WHERE {where_clause} ORDER BY alert_time DESC LIMIT :limit OFFSET :offset """) result = conn.execute(query_sql, params) alerts = [] for row in result: extra_info = None if row[19]: try: extra_info = json.loads(row[19]) if isinstance(row[19], str) else row[19] except: pass alerts.append({ 'id': row[0], 'concept_id': row[1], 'concept_name': row[2], 'alert_time': row[3].isoformat() if row[3] else None, 'alert_type': row[4], 'trade_date': row[5].isoformat() if row[5] else None, 'change_pct': float(row[6]) if row[6] else None, 'prev_change_pct': float(row[7]) if row[7] else None, 'change_delta': float(row[8]) if row[8] else None, 'limit_up_count': row[9], 'prev_limit_up_count': 
row[10], 'limit_up_delta': row[11], 'rank_position': row[12], 'prev_rank_position': row[13], 'rank_delta': row[14], 'index_price': float(row[15]) if row[15] else None, 'index_change_pct': float(row[16]) if row[16] else None, 'stock_count': row[17], 'concept_type': row[18], 'extra_info': extra_info }) return jsonify({ 'success': True, 'data': alerts, 'total': total, 'limit': limit, 'offset': offset }) except Exception as e: import traceback app.logger.error(f"获取概念异动列表失败: {traceback.format_exc()}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/rise-analysis/', methods=['GET']) def get_rise_analysis(seccode): """获取股票涨幅分析数据(从 Elasticsearch 获取)""" try: # 获取日期范围参数 start_date = request.args.get('start_date') end_date = request.args.get('end_date') limit = request.args.get('limit', 100, type=int) # 构建 ES 查询 must_conditions = [ {"term": {"stock_code": seccode}} ] # 添加日期范围筛选 if start_date and end_date: must_conditions.append({ "range": { "trade_date": { "gte": start_date, "lte": end_date, "format": "yyyy-MM-dd" } } }) es_query = { "query": { "bool": { "must": must_conditions } }, "sort": [ {"trade_date": {"order": "desc"}} ], "size": limit, "_source": { "excludes": ["rise_reason_detail_embedding"] # 排除向量字段 } } # 执行 ES 查询 response = es_client.search(index="stock_rise_analysis", body=es_query) # 格式化数据 rise_analysis_data = [] for hit in response['hits']['hits']: source = hit['_source'] # 处理研报引用数据 verification_reports = [] if source.get('has_verification_info') and source.get('verification_info'): v_info = source['verification_info'] processed_results = v_info.get('processed_result', []) for report in processed_results: verification_reports.append({ 'publisher': report.get('publisher', ''), 'report_title': report.get('report_title', ''), 'author': report.get('author', ''), 'declare_date': report.get('declare_date', ''), 'content': report.get('content', ''), 'verification_item': report.get('verification_item', ''), 'match_ratio': 
report.get('match_ratio', 0), 'match_score': report.get('match_score', '') }) rise_analysis_data.append({ 'stock_code': source.get('stock_code', ''), 'stock_name': source.get('stock_name', ''), 'trade_date': source.get('trade_date', ''), 'rise_rate': source.get('rise_rate', 0), 'close_price': source.get('close_price', 0), 'volume': source.get('volume', 0), 'amount': source.get('amount', 0), 'main_business': source.get('main_business', ''), 'rise_reason_brief': source.get('rise_reason_brief', ''), 'rise_reason_detail': source.get('rise_reason_detail', ''), 'announcements': source.get('announcements', ''), 'verification_reports': verification_reports, 'has_verification_info': source.get('has_verification_info', False), 'create_time': source.get('create_time', ''), 'update_time': source.get('update_time', '') }) return jsonify({ 'success': True, 'data': rise_analysis_data, 'count': len(rise_analysis_data), 'total': response['hits']['total']['value'] }) except Exception as e: import traceback print(f"ES查询错误: {traceback.format_exc()}") return jsonify({ 'success': False, 'error': str(e) }), 500 # ============================================ # 公司分析相关接口 # ============================================ @app.route('/api/company/comprehensive-analysis/', methods=['GET']) def get_comprehensive_analysis(company_code): """获取公司综合分析数据""" try: # 获取公司定性分析 qualitative_query = text(""" SELECT one_line_intro, investment_highlights, business_model_desc, company_story, positioning_analysis, unique_value_proposition, business_logic_explanation, revenue_driver_analysis, customer_value_analysis, strategy_description, strategic_initiatives, created_at, updated_at FROM company_analysis WHERE company_code = :company_code """) with engine.connect() as conn: qualitative_result = conn.execute(qualitative_query, {'company_code': company_code}).fetchone() # 获取业务板块分析 segments_query = text(""" SELECT segment_name, segment_description, competitive_position, future_potential, key_customers, 
value_chain_position, created_at, updated_at FROM business_segment_analysis WHERE company_code = :company_code ORDER BY created_at DESC """) with engine.connect() as conn: segments_result = conn.execute(segments_query, {'company_code': company_code}).fetchall() # 获取竞争地位数据 - 最新一期 competitive_query = text(""" SELECT market_position_score, technology_score, brand_score, operation_score, finance_score, innovation_score, risk_score, growth_score, industry_avg_comparison, main_competitors, competitive_advantages, competitive_disadvantages, industry_rank, total_companies, report_period, updated_at FROM company_competitive_position WHERE company_code = :company_code ORDER BY report_period DESC LIMIT 1 """) with engine.connect() as conn: competitive_result = conn.execute(competitive_query, {'company_code': company_code}).fetchone() # 获取业务结构数据 - 最新一期 business_structure_query = text(""" SELECT business_name, parent_business, business_level, revenue, revenue_unit, revenue_ratio, profit, profit_unit, profit_ratio, revenue_growth, profit_growth, gross_margin, customer_count, market_share, report_period FROM company_business_structure WHERE company_code = :company_code AND report_period = (SELECT MAX(report_period) FROM company_business_structure WHERE company_code = :company_code) ORDER BY revenue_ratio DESC """) with engine.connect() as conn: business_structure_result = conn.execute(business_structure_query, {'company_code': company_code}).fetchall() # 构建返回数据 response_data = { 'company_code': company_code, 'qualitative_analysis': None, 'business_segments': [], 'competitive_position': None, 'business_structure': [] } # 处理定性分析数据 if qualitative_result: response_data['qualitative_analysis'] = { 'core_positioning': { 'one_line_intro': qualitative_result.one_line_intro, 'investment_highlights': qualitative_result.investment_highlights, 'business_model_desc': qualitative_result.business_model_desc, 'company_story': qualitative_result.company_story }, 'business_understanding': { 
'positioning_analysis': qualitative_result.positioning_analysis, 'unique_value_proposition': qualitative_result.unique_value_proposition, 'business_logic_explanation': qualitative_result.business_logic_explanation, 'revenue_driver_analysis': qualitative_result.revenue_driver_analysis, 'customer_value_analysis': qualitative_result.customer_value_analysis }, 'strategy': { 'strategy_description': qualitative_result.strategy_description, 'strategic_initiatives': qualitative_result.strategic_initiatives }, 'updated_at': qualitative_result.updated_at.strftime( '%Y-%m-%d %H:%M:%S') if qualitative_result.updated_at else None } # 处理业务板块数据 for segment in segments_result: response_data['business_segments'].append({ 'segment_name': segment.segment_name, 'segment_description': segment.segment_description, 'competitive_position': segment.competitive_position, 'future_potential': segment.future_potential, 'key_customers': segment.key_customers, 'value_chain_position': segment.value_chain_position, 'updated_at': segment.updated_at.strftime('%Y-%m-%d %H:%M:%S') if segment.updated_at else None }) # 处理竞争地位数据 if competitive_result: response_data['competitive_position'] = { 'scores': { 'market_position': competitive_result.market_position_score, 'technology': competitive_result.technology_score, 'brand': competitive_result.brand_score, 'operation': competitive_result.operation_score, 'finance': competitive_result.finance_score, 'innovation': competitive_result.innovation_score, 'risk': competitive_result.risk_score, 'growth': competitive_result.growth_score }, 'analysis': { 'industry_avg_comparison': competitive_result.industry_avg_comparison, 'main_competitors': competitive_result.main_competitors, 'competitive_advantages': competitive_result.competitive_advantages, 'competitive_disadvantages': competitive_result.competitive_disadvantages }, 'ranking': { 'industry_rank': competitive_result.industry_rank, 'total_companies': competitive_result.total_companies, 'rank_percentage': round( 
(competitive_result.industry_rank / competitive_result.total_companies * 100), 2) if competitive_result.industry_rank and competitive_result.total_companies else None }, 'report_period': competitive_result.report_period, 'updated_at': competitive_result.updated_at.strftime( '%Y-%m-%d %H:%M:%S') if competitive_result.updated_at else None } # 处理业务结构数据 for business in business_structure_result: response_data['business_structure'].append({ 'business_name': business.business_name, 'parent_business': business.parent_business, 'business_level': business.business_level, 'revenue': format_decimal(business.revenue), 'revenue_unit': business.revenue_unit, 'profit': format_decimal(business.profit), 'profit_unit': business.profit_unit, 'financial_metrics': { 'revenue': format_decimal(business.revenue), 'revenue_ratio': format_decimal(business.revenue_ratio), 'profit': format_decimal(business.profit), 'profit_ratio': format_decimal(business.profit_ratio), 'gross_margin': format_decimal(business.gross_margin) }, 'growth_metrics': { 'revenue_growth': format_decimal(business.revenue_growth), 'profit_growth': format_decimal(business.profit_growth) }, 'market_metrics': { 'customer_count': business.customer_count, 'market_share': format_decimal(business.market_share) }, 'report_period': business.report_period }) return jsonify({ 'success': True, 'data': response_data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/company/value-chain-analysis/', methods=['GET']) def get_value_chain_analysis(company_code): """获取公司产业链分析数据""" try: # 获取产业链节点数据 nodes_query = text(""" SELECT node_name, node_type, node_level, node_description, importance_score, market_share, dependency_degree, created_at FROM company_value_chain_nodes WHERE company_code = :company_code ORDER BY node_level ASC, importance_score DESC """) with engine.connect() as conn: nodes_result = conn.execute(nodes_query, {'company_code': company_code}).fetchall() # 获取产业链流向数据 
flows_query = text(""" SELECT source_node, source_type, source_level, target_node, target_type, target_level, flow_value, flow_ratio, flow_type, relationship_desc, transaction_volume FROM company_value_chain_flows WHERE company_code = :company_code ORDER BY flow_ratio DESC """) with engine.connect() as conn: flows_result = conn.execute(flows_query, {'company_code': company_code}).fetchall() # 构建节点数据结构 nodes_by_level = {} all_nodes = [] for node in nodes_result: node_data = { 'node_name': node.node_name, 'node_type': node.node_type, 'node_level': node.node_level, 'node_description': node.node_description, 'importance_score': node.importance_score, 'market_share': format_decimal(node.market_share), 'dependency_degree': format_decimal(node.dependency_degree), 'created_at': node.created_at.strftime('%Y-%m-%d %H:%M:%S') if node.created_at else None } all_nodes.append(node_data) # 按层级分组 level_key = f"level_{node.node_level}" if level_key not in nodes_by_level: nodes_by_level[level_key] = [] nodes_by_level[level_key].append(node_data) # 构建流向数据 flows_data = [] for flow in flows_result: flows_data.append({ 'source': { 'node_name': flow.source_node, 'node_type': flow.source_type, 'node_level': flow.source_level }, 'target': { 'node_name': flow.target_node, 'node_type': flow.target_type, 'node_level': flow.target_level }, 'flow_metrics': { 'flow_value': format_decimal(flow.flow_value), 'flow_ratio': format_decimal(flow.flow_ratio), 'flow_type': flow.flow_type }, 'relationship_info': { 'relationship_desc': flow.relationship_desc, 'transaction_volume': flow.transaction_volume } }) # 移除循环边,确保Sankey图数据是DAG(有向无环图) flows_data = remove_cycles_from_sankey_flows(flows_data) # 统计各层级节点数量 level_stats = {} for level_key, nodes in nodes_by_level.items(): level_stats[level_key] = { 'count': len(nodes), 'avg_importance': round(sum(node['importance_score'] or 0 for node in nodes) / len(nodes), 2) if nodes else 0 } response_data = { 'company_code': company_code, 'value_chain_structure': { 
'nodes_by_level': nodes_by_level, 'level_statistics': level_stats, 'total_nodes': len(all_nodes) }, 'value_chain_flows': flows_data, 'analysis_summary': { 'total_flows': len(flows_data), 'upstream_nodes': len([n for n in all_nodes if n['node_level'] < 0]), 'company_nodes': len([n for n in all_nodes if n['node_level'] == 0]), 'downstream_nodes': len([n for n in all_nodes if n['node_level'] > 0]) } } return jsonify({ 'success': True, 'data': response_data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/company/value-chain/related-companies', methods=['GET']) def get_related_companies_by_node(): """ 根据产业链节点名称查询相关公司(结合nodes和flows表) 参数: node_name - 节点名称(如 "中芯国际"、"EDA/IP"等) 返回: 包含该节点的所有公司列表,附带节点层级、类型、关系描述等信息 """ try: node_name = request.args.get('node_name') if not node_name: return jsonify({ 'success': False, 'error': '缺少必需参数 node_name' }), 400 # 查询包含该节点的所有公司及其节点信息 query = text(""" SELECT DISTINCT n.company_code as stock_code, s.SECNAME as stock_name, s.ORGNAME as company_name, n.node_level, n.node_type, n.node_description, n.importance_score, n.market_share, n.dependency_degree FROM company_value_chain_nodes n LEFT JOIN ea_stocklist s ON n.company_code = s.SECCODE WHERE n.node_name = :node_name ORDER BY n.importance_score DESC, n.company_code """) with engine.connect() as conn: nodes_result = conn.execute(query, {'node_name': node_name}).fetchall() # 构建返回数据 companies = [] for row in nodes_result: company_data = { 'stock_code': row.stock_code, 'stock_name': row.stock_name or row.stock_code, 'company_name': row.company_name, 'node_info': { 'node_level': row.node_level, 'node_type': row.node_type, 'node_description': row.node_description, 'importance_score': row.importance_score, 'market_share': format_decimal(row.market_share), 'dependency_degree': format_decimal(row.dependency_degree) }, 'relationships': [] } # 查询该节点在该公司产业链中的流向关系 flows_query = text(""" SELECT source_node, source_type, source_level, target_node, 
target_type, target_level, flow_type, relationship_desc, flow_value, flow_ratio FROM company_value_chain_flows WHERE company_code = :company_code AND (source_node = :node_name OR target_node = :node_name) ORDER BY flow_ratio DESC LIMIT 5 """) with engine.connect() as conn: flows_result = conn.execute(flows_query, { 'company_code': row.stock_code, 'node_name': node_name }).fetchall() # 添加流向关系信息 for flow in flows_result: # 判断节点在流向中的角色 is_source = (flow.source_node == node_name) relationship = { 'role': 'source' if is_source else 'target', 'connected_node': flow.target_node if is_source else flow.source_node, 'connected_type': flow.target_type if is_source else flow.source_type, 'connected_level': flow.target_level if is_source else flow.source_level, 'flow_type': flow.flow_type, 'relationship_desc': flow.relationship_desc, 'flow_ratio': format_decimal(flow.flow_ratio) } company_data['relationships'].append(relationship) companies.append(company_data) return jsonify({ 'success': True, 'data': companies, 'total': len(companies), 'node_name': node_name }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/company/key-factors-timeline/', methods=['GET']) def get_key_factors_timeline(company_code): """获取公司关键因素和时间线数据""" try: # 获取请求参数 report_period = request.args.get('report_period') # 可选的报告期筛选 event_limit = request.args.get('event_limit', 50, type=int) # 时间线事件数量限制 # 获取关键因素类别 categories_query = text(""" SELECT id, category_name, category_desc, display_order FROM company_key_factor_categories WHERE company_code = :company_code ORDER BY display_order ASC, created_at ASC """) with engine.connect() as conn: categories_result = conn.execute(categories_query, {'company_code': company_code}).fetchall() # 获取关键因素详情 factors_query = text(""" SELECT kf.category_id, kf.factor_name, kf.factor_type, kf.factor_value, kf.factor_unit, kf.factor_desc, kf.impact_direction, kf.impact_weight, kf.report_period, kf.year_on_year, kf.data_source, 
kf.created_at, kf.updated_at FROM company_key_factors kf WHERE kf.company_code = :company_code """) params = {'company_code': company_code} # 如果指定了报告期,添加筛选条件 if report_period: factors_query = text(""" SELECT kf.category_id, kf.factor_name, kf.factor_type, kf.factor_value, kf.factor_unit, kf.factor_desc, kf.impact_direction, kf.impact_weight, kf.report_period, kf.year_on_year, kf.data_source, kf.created_at, kf.updated_at FROM company_key_factors kf WHERE kf.company_code = :company_code AND kf.report_period = :report_period ORDER BY kf.impact_weight DESC, kf.updated_at DESC """) params['report_period'] = report_period else: factors_query = text(""" SELECT kf.category_id, kf.factor_name, kf.factor_type, kf.factor_value, kf.factor_unit, kf.factor_desc, kf.impact_direction, kf.impact_weight, kf.report_period, kf.year_on_year, kf.data_source, kf.created_at, kf.updated_at FROM company_key_factors kf WHERE kf.company_code = :company_code ORDER BY kf.report_period DESC, kf.impact_weight DESC, kf.updated_at DESC """) with engine.connect() as conn: factors_result = conn.execute(factors_query, params).fetchall() # 获取发展时间线事件 timeline_query = text(""" SELECT event_date, event_type, event_title, event_desc, impact_score, is_positive, related_products, related_partners, financial_impact, created_at FROM company_timeline_events WHERE company_code = :company_code ORDER BY event_date DESC LIMIT :limit """) with engine.connect() as conn: timeline_result = conn.execute(timeline_query, {'company_code': company_code, 'limit': event_limit}).fetchall() # 构建关键因素数据结构 key_factors_data = {} factors_by_category = {} # 先建立类别索引 categories_map = {} for category in categories_result: categories_map[category.id] = { 'category_name': category.category_name, 'category_desc': category.category_desc, 'display_order': category.display_order, 'factors': [] } # 将因素分组到类别中 for factor in factors_result: factor_data = { 'factor_name': factor.factor_name, 'factor_type': factor.factor_type, 'factor_value': 
factor.factor_value, 'factor_unit': factor.factor_unit, 'factor_desc': factor.factor_desc, 'impact_direction': factor.impact_direction, 'impact_weight': factor.impact_weight, 'report_period': factor.report_period, 'year_on_year': format_decimal(factor.year_on_year), 'data_source': factor.data_source, 'updated_at': factor.updated_at.strftime('%Y-%m-%d %H:%M:%S') if factor.updated_at else None } category_id = factor.category_id if category_id and category_id in categories_map: categories_map[category_id]['factors'].append(factor_data) # 构建时间线数据 timeline_data = [] for event in timeline_result: timeline_data.append({ 'event_date': event.event_date.strftime('%Y-%m-%d') if event.event_date else None, 'event_type': event.event_type, 'event_title': event.event_title, 'event_desc': event.event_desc, 'impact_metrics': { 'impact_score': event.impact_score, 'is_positive': event.is_positive }, 'related_info': { 'related_products': event.related_products, 'related_partners': event.related_partners, 'financial_impact': event.financial_impact }, 'created_at': event.created_at.strftime('%Y-%m-%d %H:%M:%S') if event.created_at else None }) # 统计信息 total_factors = len(factors_result) positive_events = len([e for e in timeline_result if e.is_positive]) negative_events = len(timeline_result) - positive_events response_data = { 'company_code': company_code, 'key_factors': { 'categories': list(categories_map.values()), 'total_factors': total_factors, 'report_period': report_period }, 'development_timeline': { 'events': timeline_data, 'statistics': { 'total_events': len(timeline_data), 'positive_events': positive_events, 'negative_events': negative_events, 'event_types': list(set(event.event_type for event in timeline_result if event.event_type)) } } } return jsonify({ 'success': True, 'data': response_data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 # ============================================ # 模拟盘服务函数 # 
============================================ def get_or_create_simulation_account(user_id): """获取或创建模拟账户""" account = SimulationAccount.query.filter_by(user_id=user_id).first() if not account: account = SimulationAccount( user_id=user_id, account_name=f'模拟账户_{user_id}', initial_capital=1000000.00, available_cash=1000000.00 ) db.session.add(account) db.session.commit() return account def is_trading_time(): """判断是否为交易时间""" now = beijing_now() # 检查是否为工作日 if now.weekday() >= 5: # 周六日 return False # 检查是否为交易时间 current_time = now.time() morning_start = dt_time(9, 30) morning_end = dt_time(11, 30) afternoon_start = dt_time(13, 0) afternoon_end = dt_time(15, 0) if (morning_start <= current_time <= morning_end) or \ (afternoon_start <= current_time <= afternoon_end): return True return False def get_latest_price_from_clickhouse(stock_code): """从ClickHouse获取最新价格(优先分钟数据,备选日线数据)""" try: client = get_clickhouse_client() # 确保stock_code包含后缀 if '.' not in stock_code: if stock_code.startswith('6'): stock_code = f"{stock_code}.SH" # 上海 elif stock_code.startswith(('8', '9', '4')): stock_code = f"{stock_code}.BJ" # 北交所 else: stock_code = f"{stock_code}.SZ" # 深圳 # 1. 首先尝试获取最新的分钟数据(近30天) minute_query = """ SELECT close, timestamp FROM stock_minute WHERE code = %(code)s AND timestamp >= today() - 30 ORDER BY timestamp DESC LIMIT 1 \ """ result = client.execute(minute_query, {'code': stock_code}) if result: return float(result[0][0]), result[0][1] # 2. 如果没有分钟数据,获取最新的日线收盘价 daily_query = """ SELECT close, date FROM stock_daily WHERE code = %(code)s AND date >= today() - 90 ORDER BY date DESC LIMIT 1 \ """ daily_result = client.execute(daily_query, {'code': stock_code}) if daily_result: return float(daily_result[0][0]), daily_result[0][1] # 3. 
如果还是没有,尝试从其他表获取(如果有的话) fallback_query = """ SELECT close_price, trade_date FROM stock_minute_kline WHERE stock_code = %(code6)s AND trade_date >= today() - 30 ORDER BY trade_date DESC, trade_time DESC LIMIT 1 \ """ # 提取6位代码 code6 = stock_code.split('.')[0] fallback_result = client.execute(fallback_query, {'code6': code6}) if fallback_result: return float(fallback_result[0][0]), fallback_result[0][1] print(f"警告: 无法获取股票 {stock_code} 的价格数据") return None, None except Exception as e: print(f"获取最新价格失败 {stock_code}: {e}") return None, None def get_next_minute_price(stock_code, order_time): """获取下单后一分钟内的收盘价作为成交价""" try: client = get_clickhouse_client() # 确保stock_code包含后缀 if '.' not in stock_code: if stock_code.startswith('6'): stock_code = f"{stock_code}.SH" # 上海 elif stock_code.startswith(('8', '9', '4')): stock_code = f"{stock_code}.BJ" # 北交所 else: stock_code = f"{stock_code}.SZ" # 深圳 # 获取下单后一分钟内的数据 query = """ SELECT close, timestamp FROM stock_minute WHERE code = %(code)s AND timestamp \ > %(order_time)s AND timestamp <= %(end_time)s ORDER BY timestamp ASC LIMIT 1 \ """ end_time = order_time + timedelta(minutes=1) result = client.execute(query, { 'code': stock_code, 'order_time': order_time, 'end_time': end_time }) if result: return float(result[0][0]), result[0][1] # 如果一分钟内没有数据,获取最近的数据 query = """ SELECT close, timestamp FROM stock_minute WHERE code = %(code)s AND timestamp \ > %(order_time)s ORDER BY timestamp ASC LIMIT 1 \ """ result = client.execute(query, { 'code': stock_code, 'order_time': order_time }) if result: return float(result[0][0]), result[0][1] # 如果没有后续分钟数据,使用最新可用价格 print(f"没有找到下单后的分钟数据,使用最新价格: {stock_code}") return get_latest_price_from_clickhouse(stock_code) except Exception as e: print(f"获取成交价格失败: {e}") # 出错时也尝试获取最新价格 return get_latest_price_from_clickhouse(stock_code) def validate_and_get_stock_info(stock_input): """验证股票输入并获取标准代码和名称 支持输入格式: - 股票代码:600519 或 600519.SH - 股票名称:贵州茅台 - 拼音首字母:gzmt - 名称(代码):贵州茅台(600519) 返回: (stock_code_with_suffix, 
stock_code_6digit, stock_name) 或 (None, None, None) """ # 先尝试标准化输入 code6, name_from_input = _normalize_stock_input(stock_input) if code6: # 如果能解析出6位代码,查询股票名称 stock_name = name_from_input or _query_stock_name_by_code(code6) if code6.startswith('6'): stock_code_full = f"{code6}.SH" # 上海 elif code6.startswith(('8', '9', '4')): stock_code_full = f"{code6}.BJ" # 北交所 else: stock_code_full = f"{code6}.SZ" # 深圳 return stock_code_full, code6, stock_name # 如果不是标准代码格式,尝试搜索 with engine.connect() as conn: search_sql = text(""" SELECT DISTINCT SECCODE as stock_code, SECNAME as stock_name FROM ea_stocklist WHERE ( UPPER(SECCODE) = UPPER(:exact_match) OR UPPER(SECNAME) = UPPER(:exact_match) OR UPPER(F001V) = UPPER(:exact_match) ) AND F011V = '正常上市' AND F003V IN ('A股', 'B股') LIMIT 1 """) result = conn.execute(search_sql, { 'exact_match': stock_input.upper() }).fetchone() if result: code6 = result.stock_code stock_name = result.stock_name if code6.startswith('6'): stock_code_full = f"{code6}.SH" # 上海 elif code6.startswith(('8', '9', '4')): stock_code_full = f"{code6}.BJ" # 北交所 else: stock_code_full = f"{code6}.SZ" # 深圳 return stock_code_full, code6, stock_name return None, None, None def execute_simulation_order(order): """执行模拟订单(优化版)""" try: # 标准化股票代码 stock_code_full, code6, stock_name = validate_and_get_stock_info(order.stock_code) if not stock_code_full: order.status = 'REJECTED' order.reject_reason = '无效的股票代码' db.session.commit() return False # 更新订单的股票信息 order.stock_code = stock_code_full order.stock_name = stock_name # 获取成交价格(下单后一分钟的收盘价) filled_price, filled_time = get_next_minute_price(stock_code_full, order.order_time) if not filled_price: # 如果无法获取价格,订单保持PENDING状态,等待后台处理 order.status = 'PENDING' db.session.commit() return True # 返回True表示下单成功,但未成交 # 更新订单信息 order.filled_qty = order.order_qty order.filled_price = filled_price order.filled_amount = filled_price * order.order_qty order.filled_time = filled_time or beijing_now() # 计算费用 order.calculate_fees() # 获取账户 account = 
SimulationAccount.query.get(order.account_id) if order.order_type == 'BUY': # 买入操作 total_cost = float(order.filled_amount) + float(order.total_fee) # 检查资金是否充足 if float(account.available_cash) < total_cost: order.status = 'REJECTED' order.reject_reason = '可用资金不足' db.session.commit() return False # 扣除资金 account.available_cash -= Decimal(str(total_cost)) # 更新或创建持仓 position = SimulationPosition.query.filter_by( account_id=account.id, stock_code=order.stock_code ).first() if position: # 更新持仓 total_cost_before = float(position.avg_cost) * position.position_qty total_cost_after = total_cost_before + float(order.filled_amount) total_qty_after = position.position_qty + order.filled_qty position.avg_cost = Decimal(str(total_cost_after / total_qty_after)) position.position_qty = total_qty_after # 今日买入,T+1才可用 position.frozen_qty += order.filled_qty else: # 创建新持仓 position = SimulationPosition( account_id=account.id, stock_code=order.stock_code, stock_name=order.stock_name, position_qty=order.filled_qty, available_qty=0, # T+1 frozen_qty=order.filled_qty, # 今日买入冻结 avg_cost=order.filled_price, current_price=order.filled_price ) db.session.add(position) # 更新持仓市值 position.update_market_value(order.filled_price) else: # SELL # 卖出操作 print(f"🔍 调试:查找持仓,账户ID: {account.id}, 股票代码: {order.stock_code}") # 先尝试用完整格式查找 position = SimulationPosition.query.filter_by( account_id=account.id, stock_code=order.stock_code ).first() # 如果没找到,尝试用6位数字格式查找 if not position and '.' 
in order.stock_code: code6 = order.stock_code.split('.')[0] print(f"🔍 调试:尝试用6位格式查找: {code6}") position = SimulationPosition.query.filter_by( account_id=account.id, stock_code=code6 ).first() print(f"🔍 调试:找到持仓: {position}") if position: print( f"🔍 调试:持仓详情 - 股票代码: {position.stock_code}, 持仓数量: {position.position_qty}, 可用数量: {position.available_qty}") # 检查持仓是否存在 if not position: order.status = 'REJECTED' order.reject_reason = '持仓不存在' db.session.commit() return False # 检查总持仓数量是否足够(包括冻结的) total_holdings = position.position_qty if total_holdings < order.order_qty: order.status = 'REJECTED' order.reject_reason = f'持仓数量不足,当前持仓: {total_holdings} 股,需要: {order.order_qty} 股' db.session.commit() return False # 如果可用数量不足,但总持仓足够,则从冻结数量中解冻 if position.available_qty < order.order_qty: # 计算需要解冻的数量 need_to_unfreeze = order.order_qty - position.available_qty if position.frozen_qty >= need_to_unfreeze: # 解冻部分冻结数量 position.frozen_qty -= need_to_unfreeze position.available_qty += need_to_unfreeze print(f"解冻 {need_to_unfreeze} 股用于卖出") else: order.status = 'REJECTED' order.reject_reason = f'可用数量不足,可用: {position.available_qty} 股,冻结: {position.frozen_qty} 股,需要: {order.order_qty} 股' db.session.commit() return False # 更新持仓 position.position_qty -= order.filled_qty position.available_qty -= order.filled_qty # 增加资金 account.available_cash += Decimal(str(float(order.filled_amount) - float(order.total_fee))) # 如果全部卖出,删除持仓记录 if position.position_qty == 0: db.session.delete(position) # 创建成交记录 transaction = SimulationTransaction( account_id=account.id, order_id=order.id, transaction_no=f"T{int(beijing_now().timestamp() * 1000000)}", stock_code=order.stock_code, stock_name=order.stock_name, transaction_type=order.order_type, transaction_price=order.filled_price, transaction_qty=order.filled_qty, transaction_amount=order.filled_amount, commission=order.commission, stamp_tax=order.stamp_tax, transfer_fee=order.transfer_fee, total_fee=order.total_fee, transaction_time=order.filled_time, 
settlement_date=(order.filled_time + timedelta(days=1)).date() ) db.session.add(transaction) # 更新订单状态 order.status = 'FILLED' # 更新账户总资产 update_account_assets(account) db.session.commit() return True except Exception as e: print(f"执行订单失败: {e}") db.session.rollback() return False def update_account_assets(account): """更新账户资产(轻量级版本,不实时获取价格)""" try: # 只计算已有的持仓市值,不实时获取价格 # 价格更新由后台脚本负责 positions = SimulationPosition.query.filter_by(account_id=account.id).all() total_market_value = sum(position.market_value or Decimal('0') for position in positions) account.position_value = total_market_value account.calculate_total_assets() db.session.commit() except Exception as e: print(f"更新账户资产失败: {e}") db.session.rollback() def update_all_positions_price(): """更新所有持仓的最新价格(定时任务调用)""" try: positions = SimulationPosition.query.all() for position in positions: latest_price, _ = get_latest_price_from_clickhouse(position.stock_code) if latest_price: # 记录昨日收盘价(用于计算今日盈亏) yesterday_close = position.current_price # 更新市值 position.update_market_value(latest_price) # 计算今日盈亏 position.today_profit = (Decimal(str(latest_price)) - yesterday_close) * position.position_qty position.today_profit_rate = ((Decimal( str(latest_price)) - yesterday_close) / yesterday_close * 100) if yesterday_close > 0 else 0 db.session.commit() except Exception as e: print(f"更新持仓价格失败: {e}") db.session.rollback() def process_t1_settlement(): """处理T+1结算(每日收盘后运行)""" try: # 获取所有需要结算的持仓 positions = SimulationPosition.query.filter(SimulationPosition.frozen_qty > 0).all() for position in positions: # 将冻结数量转为可用数量 position.available_qty += position.frozen_qty position.frozen_qty = 0 db.session.commit() except Exception as e: print(f"T+1结算失败: {e}") db.session.rollback() # ============================================ # 模拟盘API接口 # ============================================ @app.route('/api/simulation/account', methods=['GET']) @login_required def get_simulation_account(): """获取模拟账户信息""" try: account = 
get_or_create_simulation_account(current_user.id) # 更新账户资产 update_account_assets(account) return jsonify({ 'success': True, 'data': { 'account_id': account.id, 'account_name': account.account_name, 'initial_capital': float(account.initial_capital), 'available_cash': float(account.available_cash), 'frozen_cash': float(account.frozen_cash), 'position_value': float(account.position_value), 'total_assets': float(account.total_assets), 'total_profit': float(account.total_profit), 'total_profit_rate': float(account.total_profit_rate), 'daily_profit': float(account.daily_profit), 'daily_profit_rate': float(account.daily_profit_rate), 'created_at': account.created_at.isoformat(), 'updated_at': account.updated_at.isoformat() } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/simulation/positions', methods=['GET']) @login_required def get_simulation_positions(): """获取模拟持仓列表(优化版本,使用缓存的价格数据)""" try: account = get_or_create_simulation_account(current_user.id) # 直接获取持仓数据,不实时更新价格(由后台脚本负责) positions = SimulationPosition.query.filter_by(account_id=account.id).all() positions_data = [] for position in positions: positions_data.append({ 'id': position.id, 'stock_code': position.stock_code, 'stock_name': position.stock_name, 'position_qty': position.position_qty, 'available_qty': position.available_qty, 'frozen_qty': position.frozen_qty, 'avg_cost': float(position.avg_cost), 'current_price': float(position.current_price or 0), 'market_value': float(position.market_value or 0), 'profit': float(position.profit or 0), 'profit_rate': float(position.profit_rate or 0), 'today_profit': float(position.today_profit or 0), 'today_profit_rate': float(position.today_profit_rate or 0), 'updated_at': position.updated_at.isoformat() }) return jsonify({ 'success': True, 'data': positions_data }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/simulation/orders', methods=['GET']) @login_required def 
get_simulation_orders(): """获取模拟订单列表""" try: account = get_or_create_simulation_account(current_user.id) # 获取查询参数 status = request.args.get('status') # 订单状态筛选 date_str = request.args.get('date') # 日期筛选 limit = request.args.get('limit', 50, type=int) query = SimulationOrder.query.filter_by(account_id=account.id) if status: query = query.filter_by(status=status) if date_str: try: date = datetime.strptime(date_str, '%Y-%m-%d').date() start_time = datetime.combine(date, dt_time(0, 0, 0)) end_time = datetime.combine(date, dt_time(23, 59, 59)) query = query.filter(SimulationOrder.order_time.between(start_time, end_time)) except ValueError: pass orders = query.order_by(SimulationOrder.order_time.desc()).limit(limit).all() orders_data = [] for order in orders: orders_data.append({ 'id': order.id, 'order_no': order.order_no, 'stock_code': order.stock_code, 'stock_name': order.stock_name, 'order_type': order.order_type, 'price_type': order.price_type, 'order_price': float(order.order_price) if order.order_price else None, 'order_qty': order.order_qty, 'filled_qty': order.filled_qty, 'filled_price': float(order.filled_price) if order.filled_price else None, 'filled_amount': float(order.filled_amount) if order.filled_amount else None, 'commission': float(order.commission), 'stamp_tax': float(order.stamp_tax), 'transfer_fee': float(order.transfer_fee), 'total_fee': float(order.total_fee), 'status': order.status, 'reject_reason': order.reject_reason, 'order_time': order.order_time.isoformat(), 'filled_time': order.filled_time.isoformat() if order.filled_time else None }) return jsonify({ 'success': True, 'data': orders_data }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/simulation/place-order', methods=['POST']) @login_required def place_simulation_order(): """下单""" try: # 移除交易时间检查,允许7x24小时下单 # 非交易时间下的单子会保持PENDING状态,等待行情数据 data = request.get_json() stock_code = data.get('stock_code') order_type = data.get('order_type') # 
BUY/SELL order_qty = data.get('order_qty') price_type = data.get('price_type', 'MARKET') # 目前只支持市价单 # 标准化股票代码格式 if stock_code and '.' not in stock_code: # 如果没有后缀,根据股票代码添加后缀 if stock_code.startswith('6'): stock_code = f"{stock_code}.SH" elif stock_code.startswith('0') or stock_code.startswith('3'): stock_code = f"{stock_code}.SZ" # 参数验证 if not all([stock_code, order_type, order_qty]): return jsonify({'success': False, 'error': '缺少必要参数'}), 400 if order_type not in ['BUY', 'SELL']: return jsonify({'success': False, 'error': '订单类型错误'}), 400 order_qty = int(order_qty) if order_qty <= 0 or order_qty % 100 != 0: return jsonify({'success': False, 'error': '下单数量必须为100的整数倍'}), 400 # 获取账户 account = get_or_create_simulation_account(current_user.id) # 获取股票信息 stock_name = None with engine.connect() as conn: result = conn.execute(text( "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" ), {"code": stock_code.split('.')[0]}).fetchone() if result: stock_name = result[0] # 创建订单 order = SimulationOrder( account_id=account.id, order_no=f"O{int(beijing_now().timestamp() * 1000000)}", stock_code=stock_code, stock_name=stock_name, order_type=order_type, price_type=price_type, order_qty=order_qty, status='PENDING' ) db.session.add(order) db.session.commit() # 执行订单 print(f"🔍 调试:开始执行订单,股票代码: {order.stock_code}, 订单类型: {order.order_type}") success = execute_simulation_order(order) print(f"🔍 调试:订单执行结果: {success}, 订单状态: {order.status}") if success: # 重新查询订单状态,因为可能在execute_simulation_order中被修改 db.session.refresh(order) if order.status == 'FILLED': return jsonify({ 'success': True, 'message': '订单执行成功,已成交', 'data': { 'order_no': order.order_no, 'status': 'FILLED', 'filled_price': float(order.filled_price) if order.filled_price else None, 'filled_qty': order.filled_qty, 'filled_amount': float(order.filled_amount) if order.filled_amount else None, 'total_fee': float(order.total_fee) } }) elif order.status == 'PENDING': return jsonify({ 'success': True, 'message': '订单提交成功,等待行情数据成交', 'data': { 
'order_no': order.order_no, 'status': 'PENDING', 'order_qty': order.order_qty, 'order_price': float(order.order_price) if order.order_price else None } }) else: return jsonify({ 'success': False, 'error': order.reject_reason or '订单状态异常' }), 400 else: return jsonify({ 'success': False, 'error': order.reject_reason or '订单执行失败' }), 400 except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/simulation/cancel-order/', methods=['POST']) @login_required def cancel_simulation_order(order_id): """撤销订单""" try: account = get_or_create_simulation_account(current_user.id) order = SimulationOrder.query.filter_by( id=order_id, account_id=account.id, status='PENDING' ).first() if not order: return jsonify({'success': False, 'error': '订单不存在或无法撤销'}), 404 order.status = 'CANCELLED' order.cancel_time = beijing_now() db.session.commit() return jsonify({ 'success': True, 'message': '订单已撤销' }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/simulation/transactions', methods=['GET']) @login_required def get_simulation_transactions(): """获取成交记录""" try: account = get_or_create_simulation_account(current_user.id) # 获取查询参数 date_str = request.args.get('date') limit = request.args.get('limit', 100, type=int) query = SimulationTransaction.query.filter_by(account_id=account.id) if date_str: try: date = datetime.strptime(date_str, '%Y-%m-%d').date() start_time = datetime.combine(date, dt_time(0, 0, 0)) end_time = datetime.combine(date, dt_time(23, 59, 59)) query = query.filter(SimulationTransaction.transaction_time.between(start_time, end_time)) except ValueError: pass transactions = query.order_by(SimulationTransaction.transaction_time.desc()).limit(limit).all() transactions_data = [] for trans in transactions: transactions_data.append({ 'id': trans.id, 'transaction_no': trans.transaction_no, 'stock_code': trans.stock_code, 'stock_name': trans.stock_name, 
'transaction_type': trans.transaction_type, 'transaction_price': float(trans.transaction_price), 'transaction_qty': trans.transaction_qty, 'transaction_amount': float(trans.transaction_amount), 'commission': float(trans.commission), 'stamp_tax': float(trans.stamp_tax), 'transfer_fee': float(trans.transfer_fee), 'total_fee': float(trans.total_fee), 'transaction_time': trans.transaction_time.isoformat(), 'settlement_date': trans.settlement_date.isoformat() if trans.settlement_date else None }) return jsonify({ 'success': True, 'data': transactions_data }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 def get_simulation_statistics(): """获取模拟交易统计""" try: account = get_or_create_simulation_account(current_user.id) # 获取统计时间范围 days = request.args.get('days', 30, type=int) end_date = beijing_now().date() start_date = end_date - timedelta(days=days) # 查询日统计数据 daily_stats = SimulationDailyStats.query.filter( SimulationDailyStats.account_id == account.id, SimulationDailyStats.stat_date >= start_date, SimulationDailyStats.stat_date <= end_date ).order_by(SimulationDailyStats.stat_date).all() # 查询总体统计 total_transactions = SimulationTransaction.query.filter_by(account_id=account.id).count() win_transactions = SimulationTransaction.query.filter( SimulationTransaction.account_id == account.id, SimulationTransaction.transaction_type == 'SELL' ).all() win_count = 0 total_profit = Decimal('0') for trans in win_transactions: # 查找对应的买入记录计算盈亏 position = SimulationPosition.query.filter_by( account_id=account.id, stock_code=trans.stock_code ).first() if position and trans.transaction_price > position.avg_cost: win_count += 1 profit = (trans.transaction_price - position.avg_cost) * trans.transaction_qty if position else 0 total_profit += profit # 构建日收益曲线 daily_returns = [] for stat in daily_stats: daily_returns.append({ 'date': stat.stat_date.isoformat(), 'daily_profit': float(stat.daily_profit), 'daily_profit_rate': float(stat.daily_profit_rate), 
'total_profit': float(stat.total_profit), 'total_profit_rate': float(stat.total_profit_rate), 'closing_assets': float(stat.closing_assets) }) return jsonify({ 'success': True, 'data': { 'summary': { 'total_transactions': total_transactions, 'win_count': win_count, 'win_rate': (win_count / len(win_transactions) * 100) if win_transactions else 0, 'total_profit': float(total_profit), 'average_profit_per_trade': float(total_profit / len(win_transactions)) if win_transactions else 0 }, 'daily_returns': daily_returns } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/simulation/t1-settlement', methods=['POST']) @login_required def trigger_t1_settlement(): """手动触发T+1结算""" try: # 导入后台处理器的函数 from simulation_background_processor import process_t1_settlement # 执行T+1结算 process_t1_settlement() return jsonify({ 'success': True, 'message': 'T+1结算执行成功' }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/simulation/debug-positions', methods=['GET']) @login_required def debug_positions(): """调试接口:查看持仓数据""" try: account = get_or_create_simulation_account(current_user.id) positions = SimulationPosition.query.filter_by(account_id=account.id).all() positions_data = [] for position in positions: positions_data.append({ 'stock_code': position.stock_code, 'stock_name': position.stock_name, 'position_qty': position.position_qty, 'available_qty': position.available_qty, 'frozen_qty': position.frozen_qty, 'avg_cost': float(position.avg_cost), 'current_price': float(position.current_price or 0) }) return jsonify({ 'success': True, 'data': positions_data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/simulation/debug-transactions', methods=['GET']) @login_required def debug_transactions(): """调试接口:查看成交记录数据""" try: account = get_or_create_simulation_account(current_user.id) transactions = 
SimulationTransaction.query.filter_by(account_id=account.id).all() transactions_data = [] for trans in transactions: transactions_data.append({ 'id': trans.id, 'transaction_no': trans.transaction_no, 'stock_code': trans.stock_code, 'stock_name': trans.stock_name, 'transaction_type': trans.transaction_type, 'transaction_price': float(trans.transaction_price), 'transaction_qty': trans.transaction_qty, 'transaction_amount': float(trans.transaction_amount), 'commission': float(trans.commission), 'stamp_tax': float(trans.stamp_tax), 'transfer_fee': float(trans.transfer_fee), 'total_fee': float(trans.total_fee), 'transaction_time': trans.transaction_time.isoformat(), 'settlement_date': trans.settlement_date.isoformat() if trans.settlement_date else None }) return jsonify({ 'success': True, 'data': transactions_data, 'count': len(transactions_data) }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/simulation/daily-settlement', methods=['POST']) @login_required def trigger_daily_settlement(): """手动触发日结算""" try: # 导入后台处理器的函数 from simulation_background_processor import generate_daily_stats # 执行日结算 generate_daily_stats() return jsonify({ 'success': True, 'message': '日结算执行成功' }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/simulation/reset', methods=['POST']) @login_required def reset_simulation_account(): """重置模拟账户""" try: account = SimulationAccount.query.filter_by(user_id=current_user.id).first() if account: # 删除所有相关数据 SimulationPosition.query.filter_by(account_id=account.id).delete() SimulationOrder.query.filter_by(account_id=account.id).delete() SimulationTransaction.query.filter_by(account_id=account.id).delete() SimulationDailyStats.query.filter_by(account_id=account.id).delete() # 重置账户数据 account.available_cash = account.initial_capital account.frozen_cash = Decimal('0') account.position_value = Decimal('0') account.total_assets = account.initial_capital 
account.total_profit = Decimal('0') account.total_profit_rate = Decimal('0') account.daily_profit = Decimal('0') account.daily_profit_rate = Decimal('0') account.updated_at = beijing_now() db.session.commit() return jsonify({ 'success': True, 'message': '模拟账户已重置' }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 # =========================== # 预测市场 API 路由 # 请将此文件内容插入到 app.py 的 `if __name__ == '__main__':` 之前 # =========================== # --- 积分系统 API --- @app.route('/api/prediction/credit/account', methods=['GET']) @login_required def get_credit_account(): """获取用户积分账户""" try: account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() # 如果账户不存在,自动创建 if not account: account = UserCreditAccount(user_id=current_user.id) db.session.add(account) db.session.commit() return jsonify({ 'success': True, 'data': { 'balance': float(account.balance), 'frozen_balance': float(account.frozen_balance), 'available_balance': float(account.balance - account.frozen_balance), 'total_earned': float(account.total_earned), 'total_spent': float(account.total_spent), 'last_daily_bonus_at': account.last_daily_bonus_at.isoformat() if account.last_daily_bonus_at else None } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/credit/daily-bonus', methods=['POST']) @login_required def claim_daily_bonus(): """领取每日奖励(100积分)""" try: account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() if not account: account = UserCreditAccount(user_id=current_user.id) db.session.add(account) # 检查是否已领取今日奖励 today = beijing_now().date() if account.last_daily_bonus_at and account.last_daily_bonus_at.date() == today: return jsonify({ 'success': False, 'error': '今日奖励已领取' }), 400 # 发放奖励 bonus_amount = 100.0 account.balance += bonus_amount account.total_earned += bonus_amount account.last_daily_bonus_at = beijing_now() # 记录交易 transaction = CreditTransaction( 
user_id=current_user.id, transaction_type='daily_bonus', amount=bonus_amount, balance_after=account.balance, description='每日登录奖励' ) db.session.add(transaction) db.session.commit() return jsonify({ 'success': True, 'message': f'领取成功,获得 {bonus_amount} 积分', 'data': { 'bonus_amount': bonus_amount, 'new_balance': float(account.balance) } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 # --- 预测话题 API --- @app.route('/api/prediction/topics', methods=['POST']) @login_required def create_prediction_topic(): """创建预测话题(消耗100积分)""" try: data = request.get_json() title = data.get('title', '').strip() description = data.get('description', '').strip() category = data.get('category', 'stock') deadline_str = data.get('deadline') # 验证参数 if not title or len(title) < 5: return jsonify({'success': False, 'error': '标题至少5个字符'}), 400 if not deadline_str: return jsonify({'success': False, 'error': '请设置截止时间'}), 400 # 解析截止时间(移除时区信息以匹配数据库格式) deadline = datetime.fromisoformat(deadline_str.replace('Z', '+00:00')) # 移除时区信息,转换为naive datetime if deadline.tzinfo is not None: deadline = deadline.replace(tzinfo=None) if deadline <= beijing_now(): return jsonify({'success': False, 'error': '截止时间必须在未来'}), 400 # 检查积分账户 account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() if not account or account.balance < 100: return jsonify({'success': False, 'error': '积分不足(需要100积分)'}), 400 # 扣除创建费用 create_cost = 100.0 account.balance -= create_cost account.total_spent += create_cost # 创建话题 topic = PredictionTopic( creator_id=current_user.id, title=title, description=description, category=category, deadline=deadline ) db.session.add(topic) # 记录积分交易 transaction = CreditTransaction( user_id=current_user.id, transaction_type='create_topic', amount=-create_cost, balance_after=account.balance, description=f'创建预测话题:{title}' ) db.session.add(transaction) db.session.commit() return jsonify({ 'success': True, 'message': '话题创建成功', 'data': { 
'topic_id': topic.id, 'title': topic.title, 'new_balance': float(account.balance) } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/topics', methods=['GET']) def get_prediction_topics(): """获取预测话题列表""" try: status = request.args.get('status', 'active') category = request.args.get('category') sort_by = request.args.get('sort_by', 'created_at') page = request.args.get('page', 1, type=int) per_page = request.args.get('per_page', 20, type=int) # 构建查询 query = PredictionTopic.query if status: query = query.filter_by(status=status) if category: query = query.filter_by(category=category) # 排序 if sort_by == 'hot': query = query.order_by(desc(PredictionTopic.views_count)) elif sort_by == 'participants': query = query.order_by(desc(PredictionTopic.participants_count)) else: query = query.order_by(desc(PredictionTopic.created_at)) # 分页 pagination = query.paginate(page=page, per_page=per_page, error_out=False) topics = pagination.items # 格式化返回数据 topics_data = [] for topic in topics: # 计算市场倾向 total_shares = topic.yes_total_shares + topic.no_total_shares yes_prob = (topic.yes_total_shares / total_shares * 100) if total_shares > 0 else 50.0 # 处理datetime,确保移除时区信息 deadline = topic.deadline if hasattr(deadline, 'replace') and deadline.tzinfo is not None: deadline = deadline.replace(tzinfo=None) created_at = topic.created_at if hasattr(created_at, 'replace') and created_at.tzinfo is not None: created_at = created_at.replace(tzinfo=None) topics_data.append({ 'id': topic.id, 'title': topic.title, 'description': topic.description, 'category': topic.category, 'status': topic.status, 'yes_price': float(topic.yes_price), 'no_price': float(topic.no_price), 'yes_probability': round(yes_prob, 1), 'total_pool': float(topic.total_pool), 'yes_lord': { 'id': topic.yes_lord.id, 'username': topic.yes_lord.username, 'nickname': topic.yes_lord.nickname or topic.yes_lord.username, 'avatar_url': 
topic.yes_lord.avatar_url } if topic.yes_lord else None, 'no_lord': { 'id': topic.no_lord.id, 'username': topic.no_lord.username, 'nickname': topic.no_lord.nickname or topic.no_lord.username, 'avatar_url': topic.no_lord.avatar_url } if topic.no_lord else None, 'deadline': deadline.isoformat() if deadline else None, 'created_at': created_at.isoformat() if created_at else None, 'views_count': topic.views_count, 'comments_count': topic.comments_count, 'participants_count': topic.participants_count, 'creator': { 'id': topic.creator.id, 'username': topic.creator.username, 'nickname': topic.creator.nickname or topic.creator.username } }) return jsonify({ 'success': True, 'data': topics_data, 'pagination': { 'page': page, 'per_page': per_page, 'total': pagination.total, 'pages': pagination.pages, 'has_next': pagination.has_next, 'has_prev': pagination.has_prev } }) except Exception as e: import traceback print(f"[ERROR] 获取话题列表失败: {str(e)}") print(traceback.format_exc()) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/topics/', methods=['GET']) def get_prediction_topic_detail(topic_id): """获取预测话题详情""" try: # 刷新会话,确保获取最新数据 db.session.expire_all() topic = PredictionTopic.query.get_or_404(topic_id) # 增加浏览量 topic.views_count += 1 db.session.commit() # 计算市场倾向 total_shares = topic.yes_total_shares + topic.no_total_shares yes_prob = (topic.yes_total_shares / total_shares * 100) if total_shares > 0 else 50.0 # 获取 TOP 5 持仓(YES 和 NO 各5个) yes_top_positions = PredictionPosition.query.filter_by( topic_id=topic_id, direction='yes' ).order_by(desc(PredictionPosition.shares)).limit(5).all() no_top_positions = PredictionPosition.query.filter_by( topic_id=topic_id, direction='no' ).order_by(desc(PredictionPosition.shares)).limit(5).all() def format_position(position): return { 'user': { 'id': position.user.id, 'username': position.user.username, 'nickname': position.user.nickname or position.user.username, 'avatar_url': position.user.avatar_url }, 
'shares': position.shares, 'avg_cost': float(position.avg_cost), 'total_invested': float(position.total_invested), 'is_lord': (topic.yes_lord_id == position.user_id and position.direction == 'yes') or (topic.no_lord_id == position.user_id and position.direction == 'no') } return jsonify({ 'success': True, 'data': { 'id': topic.id, 'title': topic.title, 'description': topic.description, 'category': topic.category, 'status': topic.status, 'result': topic.result, 'yes_price': float(topic.yes_price), 'no_price': float(topic.no_price), 'yes_total_shares': topic.yes_total_shares, 'no_total_shares': topic.no_total_shares, 'yes_probability': round(yes_prob, 1), 'no_probability': round(100 - yes_prob, 1), 'total_pool': float(topic.total_pool), 'yes_lord': { 'id': topic.yes_lord.id, 'username': topic.yes_lord.username, 'nickname': topic.yes_lord.nickname or topic.yes_lord.username, 'avatar_url': topic.yes_lord.avatar_url } if topic.yes_lord else None, 'no_lord': { 'id': topic.no_lord.id, 'username': topic.no_lord.username, 'nickname': topic.no_lord.nickname or topic.no_lord.username, 'avatar_url': topic.no_lord.avatar_url } if topic.no_lord else None, 'yes_top_positions': [format_position(p) for p in yes_top_positions], 'no_top_positions': [format_position(p) for p in no_top_positions], 'deadline': topic.deadline.isoformat(), 'settled_at': topic.settled_at.isoformat() if topic.settled_at else None, 'created_at': topic.created_at.isoformat(), 'views_count': topic.views_count, 'comments_count': topic.comments_count, 'participants_count': topic.participants_count, 'creator': { 'id': topic.creator.id, 'username': topic.creator.username, 'nickname': topic.creator.nickname or topic.creator.username, 'avatar_url': topic.creator.avatar_url } } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/topics//settle', methods=['POST']) @login_required def settle_prediction_topic(topic_id): """结算预测话题(仅创建者可操作)""" try: topic = 
PredictionTopic.query.get_or_404(topic_id) # 验证权限 if topic.creator_id != current_user.id: return jsonify({'success': False, 'error': '只有创建者可以结算'}), 403 # 验证状态 if topic.status != 'active': return jsonify({'success': False, 'error': '话题已结算或已取消'}), 400 # 验证截止时间 if beijing_now() < topic.deadline: return jsonify({'success': False, 'error': '未到截止时间'}), 400 # 获取结算结果 data = request.get_json() result = data.get('result') # 'yes', 'no', 'draw' if result not in ['yes', 'no', 'draw']: return jsonify({'success': False, 'error': '无效的结算结果'}), 400 # 更新话题状态 topic.status = 'settled' topic.result = result topic.settled_at = beijing_now() # 获取获胜方的所有持仓 if result == 'draw': # 平局:所有人按投入比例分配奖池 all_positions = PredictionPosition.query.filter_by(topic_id=topic_id).all() total_invested = sum(p.total_invested for p in all_positions) for position in all_positions: if total_invested > 0: share_ratio = position.total_invested / total_invested prize = topic.total_pool * share_ratio # 发放奖金 account = UserCreditAccount.query.filter_by(user_id=position.user_id).first() if account: account.balance += prize account.total_earned += prize # 记录交易 transaction = CreditTransaction( user_id=position.user_id, transaction_type='settle_win', amount=prize, balance_after=account.balance, related_topic_id=topic_id, description=f'预测平局,获得奖池分红:{topic.title}' ) db.session.add(transaction) else: # YES 或 NO 获胜 winning_direction = result winning_positions = PredictionPosition.query.filter_by( topic_id=topic_id, direction=winning_direction ).all() if winning_positions: total_winning_shares = sum(p.shares for p in winning_positions) for position in winning_positions: # 按份额比例分配奖池 share_ratio = position.shares / total_winning_shares prize = topic.total_pool * share_ratio # 发放奖金 account = UserCreditAccount.query.filter_by(user_id=position.user_id).first() if account: account.balance += prize account.total_earned += prize # 记录交易 transaction = CreditTransaction( user_id=position.user_id, transaction_type='settle_win', 
amount=prize, balance_after=account.balance, related_topic_id=topic_id, description=f'预测正确,获得奖金:{topic.title}' ) db.session.add(transaction) db.session.commit() return jsonify({ 'success': True, 'message': f'话题已结算,结果为:{result}', 'data': { 'topic_id': topic.id, 'result': result, 'total_pool': float(topic.total_pool), 'settled_at': topic.settled_at.isoformat() } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 # --- 交易 API --- @app.route('/api/prediction/trade/buy', methods=['POST']) @login_required def buy_prediction_shares(): """买入预测份额""" try: data = request.get_json() topic_id = data.get('topic_id') direction = data.get('direction') # 'yes' or 'no' shares = data.get('shares', 0) # 验证参数 if not topic_id or direction not in ['yes', 'no'] or shares <= 0: return jsonify({'success': False, 'error': '参数错误'}), 400 if shares > 1000: return jsonify({'success': False, 'error': '单次最多买入1000份额'}), 400 # 获取话题 topic = PredictionTopic.query.get_or_404(topic_id) if topic.status != 'active': return jsonify({'success': False, 'error': '话题已结算或已取消'}), 400 if beijing_now() >= topic.deadline: return jsonify({'success': False, 'error': '话题已截止'}), 400 # 获取积分账户 account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() if not account: account = UserCreditAccount(user_id=current_user.id) db.session.add(account) db.session.flush() # 计算价格 current_price = topic.yes_price if direction == 'yes' else topic.no_price # 简化的AMM定价:price = (对应方份额 / 总份额) * 1000 total_shares = topic.yes_total_shares + topic.no_total_shares if total_shares > 0: if direction == 'yes': current_price = (topic.yes_total_shares / total_shares) * 1000 else: current_price = (topic.no_total_shares / total_shares) * 1000 else: current_price = 500.0 # 初始价格 # 买入后价格会上涨,使用平均价格 after_total = total_shares + shares if direction == 'yes': after_yes_shares = topic.yes_total_shares + shares after_price = (after_yes_shares / after_total) * 1000 else: after_no_shares = 
topic.no_total_shares + shares after_price = (after_no_shares / after_total) * 1000 avg_price = (current_price + after_price) / 2 # 计算费用 amount = avg_price * shares tax = amount * 0.02 # 2% 手续费 total_cost = amount + tax # 检查余额 if account.balance < total_cost: return jsonify({'success': False, 'error': '积分不足'}), 400 # 扣除费用 account.balance -= total_cost account.total_spent += total_cost # 更新话题数据 if direction == 'yes': topic.yes_total_shares += shares topic.yes_price = after_price else: topic.no_total_shares += shares topic.no_price = after_price topic.total_pool += tax # 手续费进入奖池 # 更新或创建持仓 position = PredictionPosition.query.filter_by( user_id=current_user.id, topic_id=topic_id, direction=direction ).first() if position: # 更新平均成本 old_cost = position.avg_cost * position.shares new_cost = avg_price * shares position.shares += shares position.avg_cost = (old_cost + new_cost) / position.shares position.total_invested += total_cost else: position = PredictionPosition( user_id=current_user.id, topic_id=topic_id, direction=direction, shares=shares, avg_cost=avg_price, total_invested=total_cost ) db.session.add(position) topic.participants_count += 1 # 更新领主 if direction == 'yes': # 找到YES方持仓最多的用户 top_yes = db.session.query(PredictionPosition).filter_by( topic_id=topic_id, direction='yes' ).order_by(desc(PredictionPosition.shares)).first() if top_yes: topic.yes_lord_id = top_yes.user_id else: # 找到NO方持仓最多的用户 top_no = db.session.query(PredictionPosition).filter_by( topic_id=topic_id, direction='no' ).order_by(desc(PredictionPosition.shares)).first() if top_no: topic.no_lord_id = top_no.user_id # 记录交易 transaction = PredictionTransaction( user_id=current_user.id, topic_id=topic_id, trade_type='buy', direction=direction, shares=shares, price=avg_price, amount=amount, tax=tax, total_cost=total_cost ) db.session.add(transaction) # 记录积分交易 credit_transaction = CreditTransaction( user_id=current_user.id, transaction_type='prediction_buy', amount=-total_cost, 
balance_after=account.balance, related_topic_id=topic_id, related_transaction_id=transaction.id, description=f'买入 {direction.upper()} 份额:{topic.title}' ) db.session.add(credit_transaction) db.session.commit() return jsonify({ 'success': True, 'message': '买入成功', 'data': { 'transaction_id': transaction.id, 'shares': shares, 'price': round(avg_price, 2), 'total_cost': round(total_cost, 2), 'tax': round(tax, 2), 'new_balance': float(account.balance), 'new_position': { 'shares': position.shares, 'avg_cost': float(position.avg_cost) } } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/positions', methods=['GET']) @login_required def get_user_positions(): """获取用户的所有持仓""" try: positions = PredictionPosition.query.filter_by(user_id=current_user.id).all() positions_data = [] for position in positions: topic = position.topic # 计算当前市值(如果话题还在进行中) current_value = 0 profit = 0 profit_rate = 0 if topic.status == 'active': current_price = topic.yes_price if position.direction == 'yes' else topic.no_price current_value = current_price * position.shares profit = current_value - position.total_invested profit_rate = (profit / position.total_invested * 100) if position.total_invested > 0 else 0 positions_data.append({ 'id': position.id, 'topic': { 'id': topic.id, 'title': topic.title, 'status': topic.status, 'result': topic.result, 'deadline': topic.deadline.isoformat() }, 'direction': position.direction, 'shares': position.shares, 'avg_cost': float(position.avg_cost), 'total_invested': float(position.total_invested), 'current_value': round(current_value, 2), 'profit': round(profit, 2), 'profit_rate': round(profit_rate, 2), 'created_at': position.created_at.isoformat(), 'is_lord': (topic.yes_lord_id == current_user.id and position.direction == 'yes') or (topic.no_lord_id == current_user.id and position.direction == 'no') }) return jsonify({ 'success': True, 'data': positions_data, 'count': 
len(positions_data) }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 # --- 评论 API --- @app.route('/api/prediction/topics//comments', methods=['POST']) @login_required def create_topic_comment(topic_id): """发表话题评论""" try: topic = PredictionTopic.query.get_or_404(topic_id) data = request.get_json() content = data.get('content', '').strip() parent_id = data.get('parent_id') if not content or len(content) < 2: return jsonify({'success': False, 'error': '评论内容至少2个字符'}), 400 # 创建评论 comment = TopicComment( topic_id=topic_id, user_id=current_user.id, content=content, parent_id=parent_id ) # 如果是领主评论,自动置顶 is_lord = (topic.yes_lord_id == current_user.id) or (topic.no_lord_id == current_user.id) if is_lord: comment.is_pinned = True db.session.add(comment) # 更新话题评论数 topic.comments_count += 1 db.session.commit() return jsonify({ 'success': True, 'message': '评论成功', 'data': { 'comment_id': comment.id, 'content': comment.content, 'is_pinned': comment.is_pinned, 'created_at': comment.created_at.isoformat() } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/topics//comments', methods=['GET']) def get_topic_comments(topic_id): """获取话题评论列表""" try: topic = PredictionTopic.query.get_or_404(topic_id) page = request.args.get('page', 1, type=int) per_page = request.args.get('per_page', 20, type=int) # 置顶评论在前,然后按时间倒序 query = TopicComment.query.filter_by( topic_id=topic_id, status='active', parent_id=None # 只获取顶级评论 ).order_by( desc(TopicComment.is_pinned), desc(TopicComment.created_at) ) pagination = query.paginate(page=page, per_page=per_page, error_out=False) comments = pagination.items def format_comment(comment): # 获取回复 replies = TopicComment.query.filter_by( parent_id=comment.id, status='active' ).order_by(TopicComment.created_at).limit(5).all() return { 'id': comment.id, 'content': comment.content, 'is_pinned': comment.is_pinned, 'likes_count': comment.likes_count, 
'created_at': comment.created_at.isoformat(), 'user': { 'id': comment.user.id, 'username': comment.user.username, 'nickname': comment.user.nickname or comment.user.username, 'avatar_url': comment.user.avatar_url }, 'is_lord': (topic.yes_lord_id == comment.user_id) or (topic.no_lord_id == comment.user_id), 'replies': [{ 'id': reply.id, 'content': reply.content, 'created_at': reply.created_at.isoformat(), 'user': { 'id': reply.user.id, 'username': reply.user.username, 'nickname': reply.user.nickname or reply.user.username, 'avatar_url': reply.user.avatar_url } } for reply in replies] } comments_data = [format_comment(comment) for comment in comments] return jsonify({ 'success': True, 'data': comments_data, 'pagination': { 'page': page, 'per_page': per_page, 'total': pagination.total, 'pages': pagination.pages } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/comments//like', methods=['POST']) @login_required def like_topic_comment(comment_id): """点赞/取消点赞评论""" try: comment = TopicComment.query.get_or_404(comment_id) # 检查是否已点赞 existing_like = TopicCommentLike.query.filter_by( comment_id=comment_id, user_id=current_user.id ).first() if existing_like: # 取消点赞 db.session.delete(existing_like) comment.likes_count = max(0, comment.likes_count - 1) action = 'unliked' else: # 点赞 like = TopicCommentLike( comment_id=comment_id, user_id=current_user.id ) db.session.add(like) comment.likes_count += 1 action = 'liked' db.session.commit() return jsonify({ 'success': True, 'action': action, 'likes_count': comment.likes_count }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 # ==================== 观点IPO API ==================== @app.route('/api/prediction/comments//invest', methods=['POST']) @login_required def invest_comment(comment_id): """投资评论(观点IPO)""" try: data = request.json shares = data.get('shares', 1) # 获取评论 comment = TopicComment.query.get_or_404(comment_id) 
# 检查评论是否已结算 if comment.is_verified: return jsonify({'success': False, 'error': '该评论已结算,无法继续投资'}), 400 # 检查是否是自己的评论 if comment.user_id == current_user.id: return jsonify({'success': False, 'error': '不能投资自己的评论'}), 400 # 计算投资金额(简化:每份100积分基础价格 + 已有投资额/10) base_price = 100 price_increase = comment.total_investment / 10 if comment.total_investment > 0 else 0 price_per_share = base_price + price_increase amount = int(price_per_share * shares) # 获取用户积分账户 account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() if not account: return jsonify({'success': False, 'error': '账户不存在'}), 404 # 检查余额 if account.balance < amount: return jsonify({'success': False, 'error': '积分不足'}), 400 # 扣减积分 account.balance -= amount # 检查是否已有投资记录 existing_investment = CommentInvestment.query.filter_by( comment_id=comment_id, user_id=current_user.id, status='active' ).first() if existing_investment: # 更新投资记录 total_shares = existing_investment.shares + shares total_amount = existing_investment.amount + amount existing_investment.shares = total_shares existing_investment.amount = total_amount existing_investment.avg_price = total_amount / total_shares else: # 创建新投资记录 investment = CommentInvestment( comment_id=comment_id, user_id=current_user.id, shares=shares, amount=amount, avg_price=price_per_share ) db.session.add(investment) comment.investor_count += 1 # 更新评论统计 comment.total_investment += amount # 记录积分交易 transaction = CreditTransaction( user_id=current_user.id, type='comment_investment', amount=-amount, balance_after=account.balance, description=f'投资评论 #{comment_id}' ) db.session.add(transaction) db.session.commit() return jsonify({ 'success': True, 'data': { 'shares': shares, 'amount': amount, 'price_per_share': price_per_share, 'total_investment': comment.total_investment, 'investor_count': comment.investor_count, 'new_balance': account.balance } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 
@app.route('/api/prediction/comments//investments', methods=['GET']) def get_comment_investments(comment_id): """获取评论的投资列表""" try: investments = CommentInvestment.query.filter_by( comment_id=comment_id, status='active' ).all() result = [] for inv in investments: user = User.query.get(inv.user_id) result.append({ 'id': inv.id, 'user_id': inv.user_id, 'user_name': user.username if user else '未知用户', 'user_avatar': user.avatar if user else None, 'shares': inv.shares, 'amount': inv.amount, 'avg_price': inv.avg_price, 'created_at': inv.created_at.strftime('%Y-%m-%d %H:%M:%S') }) return jsonify({ 'success': True, 'data': result }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/comments//verify', methods=['POST']) @login_required def verify_comment(comment_id): """管理员验证评论预测结果""" try: # 检查管理员权限(简化版:假设 user_id=1 是管理员) if current_user.id != 1: return jsonify({'success': False, 'error': '无权限操作'}), 403 data = request.json result = data.get('result') # 'correct' or 'incorrect' if result not in ['correct', 'incorrect']: return jsonify({'success': False, 'error': '无效的验证结果'}), 400 comment = TopicComment.query.get_or_404(comment_id) # 检查是否已验证 if comment.is_verified: return jsonify({'success': False, 'error': '该评论已验证'}), 400 # 更新验证状态 comment.is_verified = True comment.verification_result = result # 如果预测正确,进行收益分配 if result == 'correct' and comment.total_investment > 0: # 获取所有投资记录 investments = CommentInvestment.query.filter_by( comment_id=comment_id, status='active' ).all() # 计算总收益(总投资额的1.5倍) total_reward = int(comment.total_investment * 1.5) # 按份额比例分配收益 total_shares = sum([inv.shares for inv in investments]) for inv in investments: # 计算该投资者的收益 investor_reward = int((inv.shares / total_shares) * total_reward) # 获取投资者账户 account = UserCreditAccount.query.filter_by(user_id=inv.user_id).first() if account: account.balance += investor_reward # 记录积分交易 transaction = CreditTransaction( user_id=inv.user_id, 
type='comment_investment_profit', amount=investor_reward, balance_after=account.balance, description=f'评论投资收益 #{comment_id}' ) db.session.add(transaction) # 更新投资状态 inv.status = 'settled' # 评论作者也获得奖励(总投资额的20%) author_reward = int(comment.total_investment * 0.2) author_account = UserCreditAccount.query.filter_by(user_id=comment.user_id).first() if author_account: author_account.balance += author_reward transaction = CreditTransaction( user_id=comment.user_id, type='comment_author_bonus', amount=author_reward, balance_after=author_account.balance, description=f'评论作者奖励 #{comment_id}' ) db.session.add(transaction) db.session.commit() return jsonify({ 'success': True, 'data': { 'comment_id': comment_id, 'verification_result': result, 'total_investment': comment.total_investment } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/topics//bid-position', methods=['POST']) @login_required def bid_comment_position(topic_id): """竞拍评论位置(首发权拍卖)""" try: data = request.json position = data.get('position') # 1/2/3 bid_amount = data.get('bid_amount') if position not in [1, 2, 3]: return jsonify({'success': False, 'error': '无效的位置'}), 400 if bid_amount < 500: return jsonify({'success': False, 'error': '最低出价500积分'}), 400 # 获取用户积分账户 account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() if not account or account.balance < bid_amount: return jsonify({'success': False, 'error': '积分不足'}), 400 # 检查该位置的当前最高出价 current_highest = CommentPositionBid.query.filter_by( topic_id=topic_id, position=position, status='pending' ).order_by(CommentPositionBid.bid_amount.desc()).first() if current_highest and bid_amount <= current_highest.bid_amount: return jsonify({ 'success': False, 'error': f'出价必须高于当前最高价 {current_highest.bid_amount}' }), 400 # 扣减积分(冻结) account.balance -= bid_amount account.frozen += bid_amount # 如果有之前的出价,退还积分 user_previous_bid = CommentPositionBid.query.filter_by( topic_id=topic_id, 
position=position, user_id=current_user.id, status='pending' ).first() if user_previous_bid: account.frozen -= user_previous_bid.bid_amount account.balance += user_previous_bid.bid_amount user_previous_bid.status = 'lost' # 创建竞拍记录 topic = PredictionTopic.query.get_or_404(topic_id) bid = CommentPositionBid( topic_id=topic_id, user_id=current_user.id, position=position, bid_amount=bid_amount, expires_at=topic.deadline # 竞拍截止时间与话题截止时间相同 ) db.session.add(bid) # 记录积分交易 transaction = CreditTransaction( user_id=current_user.id, type='position_bid', amount=-bid_amount, balance_after=account.balance, description=f'竞拍评论位置 #{position} (话题#{topic_id})' ) db.session.add(transaction) db.session.commit() return jsonify({ 'success': True, 'data': { 'bid_id': bid.id, 'position': position, 'bid_amount': bid_amount, 'new_balance': account.balance, 'frozen': account.frozen } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/topics//position-bids', methods=['GET']) def get_position_bids(topic_id): """获取话题的位置竞拍列表""" try: result = {} for position in [1, 2, 3]: bids = CommentPositionBid.query.filter_by( topic_id=topic_id, position=position, status='pending' ).order_by(CommentPositionBid.bid_amount.desc()).limit(5).all() position_bids = [] for bid in bids: user = User.query.get(bid.user_id) position_bids.append({ 'id': bid.id, 'user_id': bid.user_id, 'user_name': user.username if user else '未知用户', 'user_avatar': user.avatar if user else None, 'bid_amount': bid.bid_amount, 'created_at': bid.created_at.strftime('%Y-%m-%d %H:%M:%S') }) result[f'position_{position}'] = position_bids return jsonify({ 'success': True, 'data': result }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 # ==================== 时间胶囊 API ==================== @app.route('/api/time-capsule/topics', methods=['POST']) @login_required def create_time_capsule_topic(): """创建时间胶囊话题""" try: data = request.json 
title = data.get('title') description = data.get('description', '') encrypted_content = data.get('encrypted_content') encryption_key = data.get('encryption_key') start_year = data.get('start_year') end_year = data.get('end_year') # 验证 if not title or not encrypted_content or not encryption_key: return jsonify({'success': False, 'error': '缺少必要参数'}), 400 if not start_year or not end_year or end_year <= start_year: return jsonify({'success': False, 'error': '无效的时间范围'}), 400 # 获取用户积分账户 account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() if not account or account.balance < 100: return jsonify({'success': False, 'error': '积分不足,需要100积分'}), 400 # 扣减积分 account.balance -= 100 # 创建话题 topic = TimeCapsuleTopic( user_id=current_user.id, title=title, description=description, encrypted_content=encrypted_content, encryption_key=encryption_key, start_year=start_year, end_year=end_year, total_pool=100 # 创建费用进入奖池 ) db.session.add(topic) db.session.flush() # 获取 topic.id # 自动创建时间段(每年一个时间段) for year in range(start_year, end_year + 1): slot = TimeCapsuleTimeSlot( topic_id=topic.id, year_start=year, year_end=year ) db.session.add(slot) # 记录积分交易 transaction = CreditTransaction( user_id=current_user.id, type='time_capsule_create', amount=-100, balance_after=account.balance, description=f'创建时间胶囊话题 #{topic.id}' ) db.session.add(transaction) db.session.commit() return jsonify({ 'success': True, 'data': { 'topic_id': topic.id, 'title': topic.title, 'new_balance': account.balance } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/time-capsule/topics', methods=['GET']) def get_time_capsule_topics(): """获取时间胶囊话题列表""" try: status = request.args.get('status', 'active') query = TimeCapsuleTopic.query.filter_by(status=status) topics = query.order_by(TimeCapsuleTopic.created_at.desc()).all() result = [] for topic in topics: # 获取用户信息 user = User.query.get(topic.user_id) # 获取时间段统计 slots = 
TimeCapsuleTimeSlot.query.filter_by(topic_id=topic.id).all() total_slots = len(slots) active_slots = len([s for s in slots if s.status == 'active']) result.append({ 'id': topic.id, 'title': topic.title, 'description': topic.description, 'start_year': topic.start_year, 'end_year': topic.end_year, 'total_pool': topic.total_pool, 'total_slots': total_slots, 'active_slots': active_slots, 'is_decrypted': topic.is_decrypted, 'status': topic.status, 'author_id': topic.user_id, 'author_name': user.username if user else '未知用户', 'author_avatar': user.avatar if user else None, 'created_at': topic.created_at.strftime('%Y-%m-%d %H:%M:%S') }) return jsonify({ 'success': True, 'data': result }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/time-capsule/topics/', methods=['GET']) def get_time_capsule_topic(topic_id): """获取时间胶囊话题详情""" try: topic = TimeCapsuleTopic.query.get_or_404(topic_id) user = User.query.get(topic.user_id) # 获取所有时间段 slots = TimeCapsuleTimeSlot.query.filter_by(topic_id=topic_id).order_by(TimeCapsuleTimeSlot.year_start).all() slots_data = [] for slot in slots: holder = User.query.get(slot.current_holder_id) if slot.current_holder_id else None slots_data.append({ 'id': slot.id, 'year_start': slot.year_start, 'year_end': slot.year_end, 'current_price': slot.current_price, 'total_bids': slot.total_bids, 'status': slot.status, 'current_holder_id': slot.current_holder_id, 'current_holder_name': holder.username if holder else None, 'current_holder_avatar': holder.avatar if holder else None }) result = { 'id': topic.id, 'title': topic.title, 'description': topic.description, 'start_year': topic.start_year, 'end_year': topic.end_year, 'total_pool': topic.total_pool, 'is_decrypted': topic.is_decrypted, 'decrypted_content': topic.encrypted_content if topic.is_decrypted else None, 'actual_happened_year': topic.actual_happened_year, 'status': topic.status, 'author_id': topic.user_id, 'author_name': user.username if user else 
'未知用户', 'author_avatar': user.avatar if user else None, 'time_slots': slots_data, 'created_at': topic.created_at.strftime('%Y-%m-%d %H:%M:%S') } return jsonify({ 'success': True, 'data': result }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/time-capsule/slots//bid', methods=['POST']) @login_required def bid_time_slot(slot_id): """竞拍时间段""" try: data = request.json bid_amount = data.get('bid_amount') slot = TimeCapsuleTimeSlot.query.get_or_404(slot_id) # 检查时间段是否还在竞拍 if slot.status != 'active': return jsonify({'success': False, 'error': '该时间段已结束竞拍'}), 400 # 检查出价是否高于当前价格 min_bid = slot.current_price + 50 # 至少比当前价格高50积分 if bid_amount < min_bid: return jsonify({ 'success': False, 'error': f'出价必须至少为 {min_bid} 积分' }), 400 # 获取用户积分账户 account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() if not account or account.balance < bid_amount: return jsonify({'success': False, 'error': '积分不足'}), 400 # 扣减积分 account.balance -= bid_amount # 如果有前任持有者,退还积分 if slot.current_holder_id: prev_holder_account = UserCreditAccount.query.filter_by(user_id=slot.current_holder_id).first() if prev_holder_account: prev_holder_account.balance += slot.current_price # 更新前任的竞拍记录状态 prev_bid = TimeSlotBid.query.filter_by( slot_id=slot_id, user_id=slot.current_holder_id, status='holding' ).first() if prev_bid: prev_bid.status = 'outbid' # 创建竞拍记录 bid = TimeSlotBid( slot_id=slot_id, user_id=current_user.id, bid_amount=bid_amount, status='holding' ) db.session.add(bid) # 更新时间段 slot.current_holder_id = current_user.id slot.current_price = bid_amount slot.total_bids += 1 # 更新话题奖池 topic = TimeCapsuleTopic.query.get(slot.topic_id) price_increase = bid_amount - (slot.current_price if slot.current_holder_id else 100) topic.total_pool += price_increase # 记录积分交易 transaction = CreditTransaction( user_id=current_user.id, type='time_slot_bid', amount=-bid_amount, balance_after=account.balance, description=f'竞拍时间段 {slot.year_start}-{slot.year_end}' 
) db.session.add(transaction) db.session.commit() return jsonify({ 'success': True, 'data': { 'slot_id': slot_id, 'bid_amount': bid_amount, 'new_balance': account.balance, 'total_pool': topic.total_pool } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/time-capsule/topics//decrypt', methods=['POST']) @login_required def decrypt_time_capsule(topic_id): """解密时间胶囊(管理员或作者)""" try: topic = TimeCapsuleTopic.query.get_or_404(topic_id) # 检查权限(管理员或作者) if current_user.id != 1 and current_user.id != topic.user_id: return jsonify({'success': False, 'error': '无权限操作'}), 403 # 检查是否已解密 if topic.is_decrypted: return jsonify({'success': False, 'error': '该话题已解密'}), 400 # 解密(前端会用密钥解密内容) topic.is_decrypted = True db.session.commit() return jsonify({ 'success': True, 'data': { 'encrypted_content': topic.encrypted_content, 'encryption_key': topic.encryption_key } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/time-capsule/topics//settle', methods=['POST']) @login_required def settle_time_capsule(topic_id): """结算时间胶囊话题""" try: # 检查管理员权限 if current_user.id != 1: return jsonify({'success': False, 'error': '无权限操作'}), 403 data = request.json happened_year = data.get('happened_year') topic = TimeCapsuleTopic.query.get_or_404(topic_id) # 检查是否已结算 if topic.status == 'settled': return jsonify({'success': False, 'error': '该话题已结算'}), 400 # 更新话题状态 topic.status = 'settled' topic.actual_happened_year = happened_year # 找到中奖的时间段 winning_slot = TimeCapsuleTimeSlot.query.filter_by( topic_id=topic_id, year_start=happened_year ).first() if winning_slot and winning_slot.current_holder_id: # 中奖者获得全部奖池 winner_account = UserCreditAccount.query.filter_by(user_id=winning_slot.current_holder_id).first() if winner_account: winner_account.balance += topic.total_pool # 记录积分交易 transaction = CreditTransaction( user_id=winning_slot.current_holder_id, type='time_capsule_win', 
amount=topic.total_pool, balance_after=winner_account.balance, description=f'时间胶囊中奖 #{topic_id}' ) db.session.add(transaction) # 更新竞拍记录 winning_bid = TimeSlotBid.query.filter_by( slot_id=winning_slot.id, user_id=winning_slot.current_holder_id, status='holding' ).first() if winning_bid: winning_bid.status = 'won' # 更新时间段状态 winning_slot.status = 'won' # 其他时间段设为过期 other_slots = TimeCapsuleTimeSlot.query.filter( TimeCapsuleTimeSlot.topic_id == topic_id, TimeCapsuleTimeSlot.id != (winning_slot.id if winning_slot else None) ).all() for slot in other_slots: slot.status = 'expired' db.session.commit() return jsonify({ 'success': True, 'data': { 'topic_id': topic_id, 'happened_year': happened_year, 'winner_id': winning_slot.current_holder_id if winning_slot else None, 'prize': topic.total_pool } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 if __name__ == '__main__': # 创建数据库表 with app.app_context(): try: db.create_all() # 安全地初始化订阅套餐 initialize_subscription_plans_safe() except Exception as e: app.logger.error(f"数据库初始化失败: {e}") # 初始化事件轮询机制(WebSocket 推送) initialize_event_polling() # 启动时预热股票缓存(股票名称 + 前收盘价) print("[启动] 正在预热股票缓存...") try: preload_stock_cache() except Exception as e: print(f"[启动] 预热缓存失败(不影响服务启动): {e}") # 使用 socketio.run 替代 app.run 以支持 WebSocket socketio.run(app, host='0.0.0.0', port=5001, debug=False, allow_unsafe_werkzeug=True)