# ===================== Gevent monkey patch (must stay at the very top) =====================
# If we are running under gevent/gunicorn, apply the monkey patch before anything else
# imports sockets, threading or ssl.
import os
import sys

# Decide from an environment variable (or an already-imported gevent) whether to patch.
_USE_GEVENT = os.environ.get('USE_GEVENT', 'false').lower() == 'true'
if _USE_GEVENT or 'gevent' in sys.modules:
    try:
        from gevent import monkey
        monkey.patch_all()
        print("✅ Gevent monkey patch 已应用")
    except ImportError:
        print("⚠️ Gevent 未安装,跳过 monkey patch")
# ===================== End of gevent monkey patch =====================

import csv
import logging
import random
import re
import math
import secrets
import string

import pytz
import requests
from flask_compress import Compress
from functools import wraps
from pathlib import Path
import time
from sqlalchemy import create_engine, text, func, or_, case, event, desc, asc
from flask import Flask, has_request_context, render_template, request, jsonify, redirect, url_for, flash, session, \
    render_template_string, current_app, send_from_directory

# Flask 3.x compatibility shim: older flask-sqlalchemy releases still import _app_ctx_stack.
try:
    from flask import _app_ctx_stack
except ImportError:
    import flask
    from werkzeug.local import LocalStack
    import threading

    # LocalStack subclass that re-exposes the legacy ``__ident_func__`` attribute.
    class CompatLocalStack(LocalStack):
        @property
        def __ident_func__(self):
            # Return the function that identifies the current execution context:
            # prefer greenlet (coroutine) identity, otherwise fall back to the thread id.
            try:
                from greenlet import getcurrent
                return getcurrent
            except ImportError:
                return threading.get_ident

    flask._app_ctx_stack = CompatLocalStack()

from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user
from flask_mail import Mail, Message
from itsdangerous import URLSafeTimedSerializer
from flask_migrate import Migrate
from flask_session import Session  # type: ignore
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.postgresql import JSONB
from werkzeug.utils import secure_filename
from PIL import Image
from datetime import datetime, timedelta, time as dt_time
from werkzeug.security import generate_password_hash, check_password_hash
import json
from clickhouse_driver import Client as Cclient
from queue import Queue, Empty, Full
from threading import Lock, RLock
from contextlib import contextmanager
import jwt
from docx import Document
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.sms.v20210111 import sms_client, models
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException

# NOTE(review): database credentials (here and further down) are hard-coded in source;
# consider loading them from environment variables or a secret store.
engine = create_engine("mysql+pymysql://root:Zzl33818!@127.0.0.1:3306/stock", echo=False, pool_size=20,
                       max_overflow=50)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ===================== ClickHouse connection pool (enhanced) =====================
import threading as _threading
import atexit


class ClickHouseConnectionPool:
    """
    ClickHouse connection pool (enhanced).

    - Reuses connections instead of creating/destroying one per query
    - Acquire timeout and health checks before reuse
    - Maximum connection lifetime (guards against stale/zombie connections)
    - Lazily started background thread that evicts expired idle connections
    - Retries failed queries in ``execute``
    - Shared counters are updated under an RLock
    """

    def __init__(self, host, port, user, password, database,
                 pool_size=10, max_overflow=10,
                 connection_timeout=10, query_timeout=30,
                 health_check_interval=60,
                 max_connection_lifetime=300,  # maximum connection lifetime (seconds)
                 cleanup_interval=60,  # background cleanup interval (seconds)
                 max_retries=3):  # maximum query attempts in execute()
        """
        Configure the pool. No connection is opened here; initialization is deferred
        until the first ``get_connection`` call.

        Args:
            host: ClickHouse host address
            port: ClickHouse port
            user: user name
            password: password
            database: database name
            pool_size: core pool size
            max_overflow: extra connections allowed beyond pool_size
                (total connections = pool_size + max_overflow)
            connection_timeout: timeout in seconds for acquiring a connection
            query_timeout: query timeout in seconds
            health_check_interval: idle seconds after which a pooled connection is
                pinged with ``SELECT 1`` before reuse
            max_connection_lifetime: seconds after which a connection is closed and rebuilt
            cleanup_interval: seconds between background cleanup passes
            max_retries: maximum number of attempts for a failing query
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.pool_size = pool_size
        self.max_overflow = max_overflow
        self.connection_timeout = connection_timeout
        self.query_timeout = query_timeout
        self.health_check_interval = health_check_interval
        self.max_connection_lifetime = max_connection_lifetime
        self.cleanup_interval = cleanup_interval
        self.max_retries = max_retries

        # Idle connections waiting to be reused
        self._pool = Queue(maxsize=pool_size + max_overflow)
        # Number of open connections (idle in the pool + checked out)
        self._active_connections = 0
        self._lock = RLock()
        # id(conn) -> last-used timestamp
        self._last_used = {}
        # id(conn) -> creation timestamp
        self._created_at = {}
        self._closed = False
        self._cleanup_thread = None
        self._cleanup_stop_event = _threading.Event()
        # Core connections are created lazily on first use
        self._initialized = False
        # The cleanup thread is also started lazily (avoids starting a thread before a fork)
        self._cleanup_thread_started = False

        logger.info(f"ClickHouse 连接池配置完成: pool_size={pool_size}, max_overflow={max_overflow}, "
                    f"max_lifetime={max_connection_lifetime}s")

        # Close pooled connections at interpreter exit
        atexit.register(self.close_all)

    def _start_cleanup_thread(self):
        """Start the background cleanup thread once (called on first pool use)."""
        if self._cleanup_thread_started or self._closed:
            return

        with self._lock:
            # Re-check under the lock so only one caller starts the thread
            if self._cleanup_thread_started or self._closed:
                return

            def cleanup_worker():
                # Event.wait returns True once close_all() sets the stop event
                while not self._cleanup_stop_event.wait(self.cleanup_interval):
                    if self._closed:
                        break
                    try:
                        self._cleanup_expired_connections()
                    except Exception as e:
                        logger.error(f"清理连接时出错: {e}")

            self._cleanup_thread = _threading.Thread(target=cleanup_worker, daemon=True, name="CH-Pool-Cleanup")
            self._cleanup_thread.start()
            self._cleanup_thread_started = True
            logger.debug("后台清理线程已启动")

    def _cleanup_expired_connections(self):
        """Drain the idle pool, close expired/unhealthy connections and put the rest back."""
        current_time = time.time()
        cleaned_count = 0
        valid_connections = []

        while True:
            try:
                conn = self._pool.get_nowait()
            except Empty:
                break

            conn_id = id(conn)
            created_at = self._created_at.get(conn_id, current_time)
            last_used = self._last_used.get(conn_id, current_time)
            lifetime = current_time - created_at
            idle_time = current_time - last_used

            if lifetime > self.max_connection_lifetime:
                # Connection has lived too long: close it
                logger.debug(f"连接 {conn_id} 存活时间 {lifetime:.0f}s 超过限制,关闭")
                self._close_connection(conn)
                cleaned_count += 1
            elif idle_time > self.health_check_interval * 3:
                # Idle for a long time: verify it still works before keeping it
                if not self._check_connection_health(conn):
                    self._close_connection(conn)
                    cleaned_count += 1
                else:
                    valid_connections.append(conn)
            else:
                valid_connections.append(conn)

        # Return surviving connections to the pool
        for conn in valid_connections:
            try:
                self._pool.put_nowait(conn)
            except Full:
                self._close_connection(conn)

        if cleaned_count > 0:
            logger.info(f"清理了 {cleaned_count} 个过期连接,当前活跃连接数: {self._active_connections}")

    def _init_pool(self):
        """Initialize the pool on first use; pre-creates at most one connection, never raises."""
        if self._initialized:
            return

        with self._lock:
            if self._initialized:
                return

            # Start the cleanup thread (deferred until first use)
            self._start_cleanup_thread()

            # Pre-create only one connection; the rest are created on demand
            init_count = min(1, self.pool_size)
            for i in range(init_count):
                try:
                    conn = self._create_connection()
                    if conn:
                        self._pool.put(conn)
                        logger.info(f"预创建 ClickHouse 连接 {i+1}/{init_count} 成功")
                except Exception as e:
                    logger.warning(f"预创建 ClickHouse 连接失败 ({i+1}/{init_count}): {e}")
                    # A failed pre-create must not block startup; connections are created on demand later
                    break

            self._initialized = True

    def _create_connection(self):
        """Create a new ClickHouse client, record its timestamps and bump the active count."""
        try:
            client = Cclient(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                database=self.database,
                connect_timeout=self.connection_timeout,
                send_receive_timeout=self.query_timeout,
                sync_request_timeout=self.query_timeout,
                settings={
                    'max_execution_time': self.query_timeout,
                    'connect_timeout': self.connection_timeout,
                }
            )
            conn_id = id(client)
            current_time = time.time()
            self._created_at[conn_id] = current_time
            self._last_used[conn_id] = current_time

            with self._lock:
                self._active_connections += 1

            logger.debug(f"创建新的 ClickHouse 连接: {conn_id}")
            return client
        except Exception as e:
            logger.error(f"创建 ClickHouse 连接失败: {e}")
            raise

    def _check_connection_health(self, conn):
        """
        Return True if ``conn`` may be reused.

        A connection older than ``max_connection_lifetime`` is unhealthy. One idle longer
        than ``health_check_interval`` is pinged with ``SELECT 1``. Any error counts as
        unhealthy; this method never raises.
        """
        try:
            conn_id = id(conn)
            last_used = self._last_used.get(conn_id, 0)
            created_at = self._created_at.get(conn_id, 0)
            current_time = time.time()

            # Too old: mark as unhealthy
            if current_time - created_at > self.max_connection_lifetime:
                logger.debug(f"连接 {conn_id} 超过最大存活时间,标记为不健康")
                return False

            # Idle for a while: ping it
            if current_time - last_used > self.health_check_interval:
                conn.execute("SELECT 1")
                self._last_used[conn_id] = current_time
                logger.debug(f"连接 {conn_id} 健康检查通过")

            return True
        except Exception as e:
            logger.warning(f"连接健康检查失败: {e}")
            return False

    def _close_connection(self, conn):
        """Disconnect ``conn`` (best effort), forget its timestamps and decrement the active count."""
        if conn is None:
            return

        try:
            conn_id = id(conn)
            try:
                conn.disconnect()
            except Exception:
                # Best effort: the connection may already be broken
                pass

            self._last_used.pop(conn_id, None)
            self._created_at.pop(conn_id, None)

            with self._lock:
                self._active_connections = max(0, self._active_connections - 1)

            logger.debug(f"关闭 ClickHouse 连接: {conn_id}")
        except Exception as e:
            logger.warning(f"关闭连接时出错: {e}")

    def get_connection(self, timeout=None):
        """
        Take a connection from the pool, creating one if the pool is empty and below capacity.

        Args:
            timeout: seconds to wait for a connection; defaults to connection_timeout

        Returns:
            A ClickHouse client connection

        Raises:
            RuntimeError: the pool has been closed
            TimeoutError: no connection became available within ``timeout``
        """
        if self._closed:
            raise RuntimeError("连接池已关闭")

        # Lazy initialization
        if not self._initialized:
            self._init_pool()

        timeout = timeout or self.connection_timeout
        start_time = time.time()

        while True:
            elapsed = time.time() - start_time
            if elapsed >= timeout:
                raise TimeoutError(f"获取 ClickHouse 连接超时 (timeout={timeout}s)")

            remaining_timeout = timeout - elapsed

            # First try to reuse an idle pooled connection
            try:
                conn = self._pool.get(block=True, timeout=min(remaining_timeout, 1.0))
            except Empty:
                # No idle connection: create a new one if we are below capacity
                with self._lock:
                    if self._active_connections < self.pool_size + self.max_overflow:
                        try:
                            return self._create_connection()
                        except Exception as e:
                            logger.error(f"创建溢出连接失败: {e}")
                            # Do not raise yet; keep waiting until the timeout
                # Brief pause before retrying
                time.sleep(0.1)
                continue

            if self._check_connection_health(conn):
                self._last_used[id(conn)] = time.time()
                return conn
            # Unhealthy: close it and try again
            self._close_connection(conn)

    def release_connection(self, conn):
        """
        Return a connection to the pool.

        Closes it instead if the pool is closed, the connection is past its maximum
        lifetime, or the pool is already full.

        Args:
            conn: the connection to release
        """
        if conn is None:
            return

        if self._closed:
            self._close_connection(conn)
            return

        conn_id = id(conn)
        created_at = self._created_at.get(conn_id, 0)

        # Too old: close it instead of returning it to the pool
        if time.time() - created_at > self.max_connection_lifetime:
            logger.debug(f"连接 {conn_id} 超过最大存活时间,关闭而不放回池中")
            self._close_connection(conn)
            return

        self._last_used[conn_id] = time.time()

        try:
            self._pool.put(conn, block=False)
            logger.debug(f"连接 {conn_id} 已释放回连接池")
        except Full:
            # Pool is full: close the surplus connection
            logger.debug(f"连接池已满,关闭多余连接: {conn_id}")
            self._close_connection(conn)

    @contextmanager
    def connection(self, timeout=None):
        """
        Context manager that acquires a connection and always releases it.

        If the body raises, the connection is probed with ``SELECT 1``. If the probe
        fails, the connection is closed rather than returned to the pool. The original
        exception is re-raised.

        Usage:
            with pool.connection() as conn:
                result = conn.execute("SELECT * FROM table")
        """
        conn = None
        try:
            conn = self.get_connection(timeout)
            yield conn
        except Exception:
            # On error, check whether the connection itself is broken
            if conn:
                try:
                    conn.execute("SELECT 1")
                except Exception:
                    # Connection is broken: close it so it is not reused
                    self._close_connection(conn)
                    conn = None
            raise
        finally:
            if conn:
                self.release_connection(conn)

    def execute(self, query, params=None, timeout=None):
        """
        Run a query on a pooled connection, retrying with increasing back-off.

        Args:
            query: SQL query
            params: query parameters
            timeout: connection-acquire timeout passed to ``connection()``

        Returns:
            The result of ``conn.execute(query, params)``

        Raises:
            TimeoutError / RuntimeError: immediately, without retrying
            Exception: the last error once all attempts have failed
        """
        last_error = None
        # Always make at least one attempt: with max_retries <= 0 the loop would never run
        # and the final ``raise last_error`` would be ``raise None`` (a TypeError).
        attempts = max(1, self.max_retries)
        for retry in range(attempts):
            try:
                with self.connection(timeout) as conn:
                    return conn.execute(query, params)
            except (TimeoutError, RuntimeError):
                # Pool exhausted or closed: retrying will not help
                raise
            except Exception as e:
                last_error = e
                logger.warning(f"查询执行失败 (重试 {retry + 1}/{attempts}): {e}")
                if retry < attempts - 1:
                    time.sleep(0.5 * (retry + 1))  # increasing back-off

        raise last_error

    def get_pool_status(self):
        """Return a snapshot dict of the pool's configuration and counters."""
        return {
            'pool_size': self.pool_size,
            'max_overflow': self.max_overflow,
            'active_connections': self._active_connections,
            'available_connections': self._pool.qsize(),
            'max_connections': self.pool_size + self.max_overflow,
            'max_connection_lifetime': self.max_connection_lifetime,
            'initialized': self._initialized,
            'closed': self._closed
        }

    def close_all(self):
        """Mark the pool closed, stop the cleanup thread and close all idle pooled connections."""
        self._closed = True
        self._cleanup_stop_event.set()  # stop the cleanup thread

        # Wait briefly for the cleanup thread to finish
        if self._cleanup_thread and self._cleanup_thread.is_alive():
            self._cleanup_thread.join(timeout=2)

        # Close every idle connection; checked-out ones are closed on release
        while not self._pool.empty():
            try:
                conn = self._pool.get_nowait()
                self._close_connection(conn)
            except Empty:
                break

        logger.info("ClickHouse 连接池已关闭所有连接")


# Global ClickHouse pool (created lazily)
clickhouse_pool = None
_pool_lock = Lock()


def _init_clickhouse_pool():
    """Create the global ClickHouse pool on first call and return it."""
    global clickhouse_pool
    if clickhouse_pool is None:
        with _pool_lock:
            # Re-check under the lock so only one pool is ever created
            if clickhouse_pool is None:
                clickhouse_pool = ClickHouseConnectionPool(
                    host='127.0.0.1',
                    port=9000,
                    user='default',
                    password='Zzl33818!',
                    database='stock',
                    pool_size=5,  # core connections
                    max_overflow=20,  # overflow connections; 25 in total
                    connection_timeout=15,  # acquire timeout, 15 s
                    query_timeout=60,  # query timeout, 60 s
                    health_check_interval=30,  # ping connections idle for more than 30 s
                    max_connection_lifetime=300,  # rebuild connections after 5 minutes
                    cleanup_interval=60,  # cleanup pass every 60 s
                    max_retries=3  # up to 3 attempts per query
                )
    return clickhouse_pool

# ===================== End of ClickHouse connection pool =====================


app = Flask(__name__)
Compress(app)
UPLOAD_FOLDER = 'static/uploads/avatars'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16MB max file size

# Configure Flask-Compress
app.config['COMPRESS_ALGORITHM'] = ['gzip', 'br']
app.config['COMPRESS_MIMETYPES'] = [
    'text/html',
    'text/css',
    'text/xml',
    'application/json',
    'application/javascript',
    'application/x-javascript'
]


# ===================== Token store (shared across workers when Redis is available) =====================
class TokenStore:
    """
    Token store.

    Uses Redis (shared across worker processes) when reachable, otherwise an in-process
    dict (only correct with a single worker).
    """

    def __init__(self):
        self._redis_client = None
        self._memory_store = {}
        self._prefix = 'vf_token:'
        self._initialized = False

    def _ensure_initialized(self):
        """Connect lazily so the Redis connection is made after any worker fork."""
        if self._initialized:
            return
        self._initialized = True

        redis_url = os.environ.get('REDIS_URL', 'redis://:VF_Redis_2024@localhost:6379/0')
        try:
            import redis
            self._redis_client = redis.from_url(redis_url)
            self._redis_client.ping()
            logger.info(f"✅ Token 存储: Redis ({redis_url})")
        except Exception as e:
            logger.warning(f"⚠️ Redis 不可用 ({e}),Token 使用内存存储(多 worker 模式下会有问题!)")
            self._redis_client = None

    def get(self, token):
        """Return the stored token data dict, or None if unknown."""
        self._ensure_initialized()
        if self._redis_client:
            try:
                data = self._redis_client.get(f"{self._prefix}{token}")
                if data:
                    return json.loads(data)
                return None
            except Exception as e:
                logger.error(f"Redis get error: {e}")
                return self._memory_store.get(token)
        return self._memory_store.get(token)

    def set(self, token, data, expire_seconds=30*24*3600):
        """Store token data; in Redis a datetime 'expires' value is saved as an ISO string."""
        self._ensure_initialized()
        if self._redis_client:
            try:
                # Convert the datetime to a string so it is JSON-serializable
                store_data = data.copy()
                if 'expires' in store_data and isinstance(store_data['expires'], datetime):
                    store_data['expires'] = store_data['expires'].isoformat()
                self._redis_client.setex(
                    f"{self._prefix}{token}",
                    expire_seconds,
                    json.dumps(store_data)
                )
                return
            except Exception as e:
                logger.error(f"Redis set error: {e}")
        self._memory_store[token] = data

    def delete(self, token):
        """Remove a token."""
        self._ensure_initialized()
        if self._redis_client:
            try:
                self._redis_client.delete(f"{self._prefix}{token}")
                return
            except Exception as e:
                logger.error(f"Redis delete error: {e}")
        self._memory_store.pop(token, None)

    def __contains__(self, token):
        """Support the ``in`` operator."""
        return self.get(token) is not None


# TokenStore replaces the former in-memory dict
user_tokens = TokenStore()

# NOTE(review): secrets below are hard-coded in source; consider moving them to env/secret storage.
app.config['SECRET_KEY'] = 'vf7891574233241'
app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://root:Zzl33818!@127.0.0.1:3306/stock?charset=utf8mb4'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['JSON_AS_ASCII'] = False
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True

# Mail configuration
app.config['MAIL_SERVER'] = 'smtp.exmail.qq.com'
app.config['MAIL_PORT'] = 465
app.config['MAIL_USE_SSL'] = True
app.config['MAIL_USERNAME'] = 'admin@valuefrontier.cn'
app.config['MAIL_PASSWORD'] = 'QYncRu6WUdASvTg4'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = MAX_CONTENT_LENGTH

# Tencent Cloud SMS configuration
SMS_SECRET_ID = 'AKID2we9TacdTAhCjCSYTErHVimeJo9Yr00s'
SMS_SECRET_KEY = 'pMlBWijlkgT9fz5ziEXdWEnAPTJzRfkf'
SMS_SDK_APP_ID = "1400972398"
SMS_SIGN_NAME = "价值前沿科技"
SMS_TEMPLATE_REGISTER = "2386557"  # registration template
SMS_TEMPLATE_LOGIN = "2386540"  # login template
verification_codes = {}

# WeChat mini-program
app.config['WECHAT_APP_ID'] = 'wx0edeaab76d4fa414'
app.config['WECHAT_APP_SECRET'] = '0d0c70084f05a8c1411f6b89da7e815d'
app.config['BASE_URL'] = 'https://api.valuefrontier.cn:5002'
app.config['WECHAT_REDIRECT_URI'] = f"{app.config['BASE_URL']}/api/wechat/callback"
WECHAT_APP_ID = 'wx0edeaab76d4fa414'
WECHAT_APP_SECRET = '0d0c70084f05a8c1411f6b89da7e815d'
JWT_SECRET_KEY = 'vfllmgreat33818!'  # replace with a secure secret
JWT_ALGORITHM = 'HS256'
JWT_EXPIRATION_HOURS = 24 * 7  # token valid for 7 days

# Session configuration: prefer Redis (shared across workers), else fall back to the filesystem
_REDIS_URL = os.environ.get('REDIS_URL', 'redis://:VF_Redis_2024@localhost:6379/0')
_USE_REDIS_SESSION = os.environ.get('USE_REDIS_SESSION', 'true').lower() == 'true'

try:
    if _USE_REDIS_SESSION:
        import redis
        # Verify the Redis connection before committing to it
        _redis_client = redis.from_url(_REDIS_URL)
        _redis_client.ping()
        app.config['SESSION_TYPE'] = 'redis'
        app.config['SESSION_REDIS'] = _redis_client
        app.config['SESSION_KEY_PREFIX'] = 'vf_session:'
        logger.info(f"✅ Session 存储: Redis ({_REDIS_URL})")
    else:
        raise Exception("Redis session disabled by config")
except Exception as e:
    # Redis unavailable (or disabled): fall back to filesystem sessions
    logger.warning(f"⚠️ Redis 不可用 ({e}),使用文件系统 session(多 worker 模式下可能不稳定)")
    app.config['SESSION_TYPE'] = 'filesystem'
    app.config['SESSION_FILE_DIR'] = os.path.join(os.path.dirname(__file__), 'flask_session')
    os.makedirs(app.config['SESSION_FILE_DIR'], exist_ok=True)

app.config['SESSION_PERMANENT'] = True
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(days=7)  # session valid for 7 days
app.config['SESSION_COOKIE_SECURE'] = False  # set to True when serving over HTTPS in production
app.config['SESSION_COOKIE_HTTPONLY'] = True
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'

# Cache directory setup
CACHE_DIR = Path('cache')
CACHE_DIR.mkdir(exist_ok=True)

# Memory management constants
MAX_MEMORY_PERCENT = 75
MEMORY_CHECK_INTERVAL = 300
MAX_CACHE_ITEMS = 50

# Shenwan (SWS) industry classification cache, populated at startup so requests do not hit the DB.
# Shape: {industry_level: {industry_name: [code_prefix1, code_prefix2, ...]}}
SYWG_INDUSTRY_CACHE = {
    2: {},  # level2: first-level industries
    3: {},  # level3: second-level industries
    4: {},  # level4: third-level industries
    5: {}  # level5: fourth-level industries
}

# Initialize extensions
db = SQLAlchemy(app)
mail = Mail(app)
login_manager = LoginManager(app)
login_manager.login_view = 'login'
serializer = URLSafeTimedSerializer(app.config['SECRET_KEY'])
migrate = Migrate(app, db)
DOMAIN = 'https://api.valuefrontier.cn:5002'
JWT_SECRET = 'Llmgreat123'
JWT_EXPIRES_SECONDS = 3600  # valid for 1 hour
Session(app)


def token_required(f):
    """
    Decorator for endpoints that need Bearer-token authentication.

    On success it sets ``request.user`` and ``request.current_user_id`` and calls the view.
    It returns 401 for a missing, unknown or expired token and 404 if the user no longer exists.
    """

    @wraps(f)
    def decorated_function(*args, **kwargs):
        token = None

        # Read the token from the Authorization header
        auth_header = request.headers.get('Authorization')
        if auth_header and auth_header.startswith('Bearer '):
            token = auth_header[7:]

        if not token:
            return jsonify({'message': '缺少认证token'}), 401

        token_data = user_tokens.get(token)
        if not token_data:
            return jsonify({'message': 'Token无效', 'code': 401}), 401

        # 'expires' is an ISO string when loaded from Redis, a datetime when held in memory.
        # NOTE(review): compared against naive datetime.now(); an aware 'expires' would raise TypeError.
        expires = token_data['expires']
        if isinstance(expires, str):
            expires = datetime.fromisoformat(expires)
        if expires < datetime.now():
            user_tokens.delete(token)
            return jsonify({'message': 'Token已过期'}), 401

        # Load the user and attach it to the request
        user = User.query.get(token_data['user_id'])
        if not user:
            return jsonify({'message': '用户不存在'}), 404

        request.user = user
        request.current_user_id = token_data['user_id']

        return f(*args, **kwargs)

    return decorated_function


def beijing_now():
    """Return the current time as a timezone-aware datetime in Asia/Shanghai."""
    beijing_tz = pytz.timezone('Asia/Shanghai')
    return datetime.now(beijing_tz)


# ============================================
# Subscription module (kept in sync with app.py)
# ============================================

class UserSubscription(db.Model):
    """User subscription table, separate from the User table."""
    __tablename__ = 'user_subscriptions'

    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    user_id = db.Column(db.Integer, nullable=False, unique=True, index=True)
    subscription_type = db.Column(db.String(10), nullable=False, default='free')
    subscription_status = db.Column(db.String(20), nullable=False, default='active')
    start_date = db.Column(db.DateTime, nullable=True)
    end_date = db.Column(db.DateTime, nullable=True)
    billing_cycle = db.Column(db.String(10), nullable=True)
    auto_renewal = db.Column(db.Boolean, nullable=False, default=False)
    created_at = db.Column(db.DateTime, default=beijing_now)
    updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now)

    def is_active(self):
        """True if the status is 'active' and the subscription is free, open-ended or not yet past end_date."""
        if self.subscription_status != 'active':
            return False
        if self.subscription_type == 'free':
            return True
        if self.end_date:
            # Treat a naive end_date from the DB as Beijing time so it compares with beijing_now()
            beijing_tz = pytz.timezone('Asia/Shanghai')
            end_date_aware = self.end_date if self.end_date.tzinfo else beijing_tz.localize(self.end_date)
            return beijing_now() <= end_date_aware
        return True

    def days_left(self):
        """Whole days until end_date; 0 if inactive, 999 for free or open-ended subscriptions."""
        if not self.is_active():
            return 0
        if self.subscription_type == 'free':
            return 999
        if not self.end_date:
            return 999
        try:
            now = beijing_now()
            # Treat a naive end_date from the DB as Beijing time
            beijing_tz = pytz.timezone('Asia/Shanghai')
            end_date_aware = self.end_date if self.end_date.tzinfo else beijing_tz.localize(self.end_date)
            delta = end_date_aware - now
            return max(0, delta.days)
        except Exception:
            return 0

    def to_dict(self):
        """Serialize the subscription for API responses."""
        return {
            'type': self.subscription_type,
            'status': self.subscription_status,
            'is_active': self.is_active(),
            'start_date': self.start_date.isoformat() if self.start_date else None,
            'end_date': self.end_date.isoformat() if self.end_date else None,
            'days_left': self.days_left(),
            'billing_cycle': self.billing_cycle,
            'auto_renewal': self.auto_renewal
        }


# ============================================
# Subscription-level helpers
# ============================================

def get_user_subscription_safe(user_id):
    """
    Fetch a user's subscription without raising.

    If the user has no subscription row, a default free subscription is created and
    committed. On any database error the session is rolled back and an unsaved,
    temporary free subscription is returned.

    :param user_id: user ID
    :return: a UserSubscription instance
    """
    try:
        subscription = UserSubscription.query.filter_by(user_id=user_id).first()
        if not subscription:
            # No subscription row yet: create a default free subscription
            subscription = UserSubscription(
                user_id=user_id,
                subscription_type='free',
                subscription_status='active'
            )
            db.session.add(subscription)
            db.session.commit()
        return subscription
    except Exception as e:
        logger.error(f"获取用户订阅信息失败: {e}")
        # Without a rollback, a failed query/commit leaves the scoped session unusable
        # for every later query in the same request.
        try:
            db.session.rollback()
        except Exception:
            logger.warning("Session rollback after subscription lookup failure also failed", exc_info=True)
        # Return a temporary free subscription (not saved to the database)
        temp_sub = UserSubscription(
            user_id=user_id,
            subscription_type='free',
            subscription_status='active'
        )
        return temp_sub


def _get_current_subscription_info():
    """
    Return the current user's subscription info as a dict.

    The user ID is read from ``request.current_user_id`` (set by ``token_required`` in the
    mini-program flow). Anonymous users, or any error, yield an active free subscription.
    """
    try:
        user_id = getattr(request, 'current_user_id', None)
        if not user_id:
            return {
                'type': 'free',
                'status': 'active',
                'is_active': True
            }
        sub = get_user_subscription_safe(user_id)
        return {
            'type': sub.subscription_type,
            'status': sub.subscription_status,
            'is_active': sub.is_active(),
            'start_date': sub.start_date.isoformat() if sub.start_date else None,
            'end_date': sub.end_date.isoformat() if sub.end_date else None,
            'days_left': sub.days_left()
        }
    except Exception as e:
        logger.error(f"获取订阅信息异常: {e}")
        return {
            'type': 'free',
            'status': 'active',
            'is_active': True
        }


def _subscription_level(sub_type):
    """Map a subscription type to a numeric level: free=0, pro=1, max=2 (unknown -> 0)."""
    mapping = {'free': 0, 'pro': 1, 'max': 2}
    return mapping.get((sub_type or 'free').lower(), 0)


def _has_required_level(required: str) -> bool:
    """Return True if the current user's active subscription is at least ``required``."""
    info = _get_current_subscription_info()
    if not info.get('is_active', True):
        return False
    return _subscription_level(info.get('type')) >= _subscription_level(required)


# ============================================
# Permission decorators
# ============================================

def subscription_required(level='pro'):
    """
    Subscription-level decorator (mini-program).

    Usage:
        @subscription_required('pro')  # Pro or Max users
        @subscription_required('max')  # Max users only

    Note: it reads ``request.current_user_id``, so it needs an authentication decorator
    that sets it (presumably ``token_required``; the decorator name in the original note
    was lost). Returns 403 with SUBSCRIPTION_EXPIRED or SUBSCRIPTION_REQUIRED on failure.
    """

    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            if not _has_required_level(level):
                current_info = _get_current_subscription_info()
                current_type = current_info.get('type', 'free')
                is_active = current_info.get('is_active', False)

                if not is_active:
                    return jsonify({
                        'success': False,
                        'error': '您的订阅已过期,请续费后继续使用',
                        'error_code': 'SUBSCRIPTION_EXPIRED',
                        'current_subscription': current_type,
                        'required_subscription': level
                    }), 403

                return jsonify({
                    'success': False,
                    'error': f'此功能需要 {level.upper()} 或更高等级会员',
                    'error_code': 'SUBSCRIPTION_REQUIRED',
                    'current_subscription': current_type,
                    'required_subscription': level
                }), 403
            return f(*args, **kwargs)

        return decorated_function

    return decorator


def pro_or_max_required(f):
    """
    Shortcut decorator requiring a Pro or Max user (mini-program).

    Like ``@subscription_required('pro')``, but it always answers with the
    MINIPROGRAM_PRO_REQUIRED 403 payload.
    """

    @wraps(f)
    def decorated_function(*args, **kwargs):
        if not _has_required_level('pro'):
            current_info = _get_current_subscription_info()
            current_type = current_info.get('type', 'free')

            return jsonify({
                'success': False,
                'error': '小程序功能仅对 Pro 和 Max 会员开放',
                'error_code': 'MINIPROGRAM_PRO_REQUIRED',
                'current_subscription': current_type,
                'required_subscription': 'pro',
                'message': '请升级到 Pro 或 Max 会员以使用小程序完整功能'
            }), 403
        return f(*args, **kwargs)

    return decorated_function


class User(UserMixin, db.Model):
    """User model."""
    id = db.Column(db.Integer, primary_key=True)

    # Basic account info (required at registration)
    username = db.Column(db.String(80), unique=True, nullable=False)  # user name
    email = db.Column(db.String(120), unique=True, nullable=False)  # email
    password_hash = db.Column(db.String(128), nullable=False)  # password hash
    email_confirmed = db.Column(db.Boolean, default=False)  # whether the email is verified
    wechat_union_id = db.Column(db.String(100), unique=True)  # WeChat UnionID
    wechat_open_id = db.Column(db.String(100))  # WeChat OpenID

    # Account status
    created_at = db.Column(db.DateTime, default=beijing_now)  # registration time
    last_seen = db.Column(db.DateTime, default=beijing_now)  # last active time
    status = db.Column(db.String(20), default='active')  # account status: active/banned/deleted

    # Profile (optional, completed later in the user center)
    nickname = db.Column(db.String(30))  # community nickname
    avatar_url = db.Column(db.String(200))  # avatar URL
    banner_url = db.Column(db.String(200))  # profile banner image
    bio = db.Column(db.String(200))  # short bio
    gender = db.Column(db.String(10))  # gender
    birth_date = db.Column(db.Date)  # birthday
    location = db.Column(db.String(100))  # location

    # Contact info (optional)
    phone = db.Column(db.String(20))  # phone number
    wechat_id = db.Column(db.String(80))  # WeChat ID

    # Real-name verification (optional)
    real_name = db.Column(db.String(30))  # real name
    # National ID number. The original note says "stored encrypted"; this model does not
    # encrypt it itself. NOTE(review): confirm encryption happens before assignment.
    id_number = db.Column(db.String(18))
    is_verified = db.Column(db.Boolean, default=False)  # whether real-name verified
    verify_time = db.Column(db.DateTime)  # verification time

    # Investment profile (optional)
    trading_experience = db.Column(db.String(200))  # years of trading experience
    investment_style = db.Column(db.String(50))  # investment style
    risk_preference = db.Column(db.String(20))  # risk preference
    investment_amount = db.Column(db.String(20))  # investment size
    preferred_markets = db.Column(db.String(200), default='[]')  # preferred markets (JSON list)

    # Community stats (updated by the system)
    user_level = db.Column(db.Integer, default=1)  # user level
    reputation_score = db.Column(db.Integer, default=0)  # reputation points
    contribution_point = db.Column(db.Integer, default=0)  # contribution points
    post_count = db.Column(db.Integer, default=0)  # number of posts
    comment_count = db.Column(db.Integer, default=0)  # number of comments
    follower_count = db.Column(db.Integer, default=0)  # followers
    following_count = db.Column(db.Integer, default=0)  # following

    # Creator info (optional)
    is_creator = db.Column(db.Boolean, default=False)  # whether the user is a creator
    creator_type = db.Column(db.String(20))  # creator type
    creator_tags = db.Column(db.String(200), default='[]')  # creator tags (JSON list)

    # Settings
    email_notifications = db.Column(db.Boolean, default=True)  # email notifications
    sms_notifications = db.Column(db.Boolean, default=False)  # SMS notifications
    wechat_notifications = db.Column(db.Boolean, default=False)  # WeChat notifications
    notification_preferences = db.Column(db.String(500), default='{}')  # notification preferences (JSON object)
    privacy_level = db.Column(db.String(20), default='public')  # privacy level
    theme_preference = db.Column(db.String(20), default='light')  # UI theme
    blocked_keywords = db.Column(db.String(500), default='[]')  # blocked keywords (JSON list)

    # Phone verification
    phone_confirmed = db.Column(db.Boolean, default=False)  # whether the phone is verified
    phone_confirm_time = db.Column(db.DateTime)  # phone verification time

    def __init__(self, username, email=None, password=None, phone=None):
        """Create a user from the basic registration fields; the password is hashed."""
        self.username = username
        if email:
            self.email = email
        if phone:
            self.phone = phone
        if password:
            self.set_password(password)
        self.created_at = beijing_now()
        self.last_seen = beijing_now()

    def set_password(self, password):
        """Hash and store the password."""
        self.password_hash = generate_password_hash(password)

    def check_password(self, password):
        """Return True if ``password`` matches the stored hash."""
        return check_password_hash(self.password_hash, password)

    def update_last_seen(self):
        """Set last_seen to the current Beijing time."""
        self.last_seen = beijing_now()

    # JSON-backed field getters/setters. Malformed stored JSON yields an empty default.
    def get_preferred_markets(self):
        """Return the preferred-markets list."""
        if self.preferred_markets:
            try:
                return json.loads(self.preferred_markets)
            except (TypeError, ValueError):
                return []
        return []

    def get_blocked_keywords(self):
        """Return the blocked-keywords list."""
        if self.blocked_keywords:
            try:
                return json.loads(self.blocked_keywords)
            except (TypeError, ValueError):
                return []
        return []

    def get_notification_preferences(self):
        """Return the notification-preferences dict."""
        if self.notification_preferences:
            try:
                return json.loads(self.notification_preferences)
            except (TypeError, ValueError):
                return {}
        return {}

    def get_creator_tags(self):
        """Return the creator-tags list."""
        if self.creator_tags:
            try:
                return json.loads(self.creator_tags)
            except (TypeError, ValueError):
                return []
        return []

    def set_preferred_markets(self, markets):
        """Store the preferred markets as JSON."""
        self.preferred_markets = json.dumps(markets)

    def set_blocked_keywords(self, keywords):
        """Store the blocked keywords as JSON."""
        self.blocked_keywords = json.dumps(keywords)

    def set_notification_preferences(self, preferences):
        """Store the notification preferences as JSON."""
        self.notification_preferences = json.dumps(preferences)

    def set_creator_tags(self, tags):
        """Store the creator tags as JSON."""
        self.creator_tags = json.dumps(tags)

    def to_dict(self):
        """Return the public dict representation of the user."""
        return {
            'id': self.id,
            'username': self.username,
            'email': self.email,
            'nickname': self.nickname,
            'avatar_url': get_full_avatar_url(self.avatar_url),  # resolved to a full URL
            'bio': self.bio,
            'is_verified': self.is_verified,
            'user_level': self.user_level,
            'reputation_score': self.reputation_score,
            'is_creator': self.is_creator
        }

    def __repr__(self):
        # The original returned an empty f-string, which made users indistinguishable in logs
        return f'<User {self.username}>'


class Notification(db.Model):
    """Notification model."""
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'))
    type = db.Column(db.String(50))  # notification type
    content = db.Column(db.Text)  # notification body
    link = db.Column(db.String(200))  # related link
    is_read = db.Column(db.Boolean, default=False)  # whether it has been read
    created_at = db.Column(db.DateTime, default=beijing_now)

    def __init__(self, user_id, type, content, link=None):
        self.user_id = user_id
        self.type = type
        self.content = content
        self.link = link


class Event(db.Model):
    """Event model."""
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(200), nullable=False)
    description = db.Column(db.Text)

    # Type and status
    event_type = db.Column(db.String(50))
    status = db.Column(db.String(20), default='active')

    # Timestamps
    start_time = db.Column(db.DateTime, default=beijing_now)
    end_time = db.Column(db.DateTime)
    created_at = db.Column(db.DateTime, default=beijing_now)
    updated_at = db.Column(db.DateTime, default=beijing_now)

    # Popularity and statistics
    hot_score = db.Column(db.Float, default=0)
    view_count = db.Column(db.Integer, default=0)
    trending_score = db.Column(db.Float, default=0)
    post_count = db.Column(db.Integer, default=0)
    follower_count = db.Column(db.Integer, default=0)

    # Related info
    related_industries = db.Column(db.String(20))  # Shenwan industry code, e.g. "S640701"
    keywords = db.Column(db.JSON)
    files = db.Column(db.JSON)
    importance = db.Column(db.String(20))
    related_avg_chg = db.Column(db.Float, default=0)
    related_max_chg = db.Column(db.Float, default=0)
    related_week_chg = db.Column(db.Float, default=0)

    # Newer fields
    invest_score = db.Column(db.Integer)
    # Expectation-surprise score
    expectation_surprise_score = db.Column(db.Integer)

    # Creator
    creator_id = db.Column(db.Integer, db.ForeignKey('user.id'))
    creator = db.relationship('User', backref='created_events')

    # Relationships
    posts = db.relationship('Post', backref='event', lazy='dynamic')
    followers = db.relationship('EventFollow', backref='event', lazy='dynamic')
    related_stocks = db.relationship('RelatedStock', backref='event', lazy='dynamic')
    historical_events = db.relationship('HistoricalEvent', backref='event', lazy='dynamic')
    related_data = db.relationship('RelatedData', backref='event', lazy='dynamic')
    related_concepts = db.relationship('RelatedConcepts', backref='event', lazy='dynamic')
    ind_type = db.Column(db.String(255))

    @property
    def keywords_list(self):
        """Return the keywords as a list, whether stored as a list or a JSON string; [] on bad data."""
        if not self.keywords:
            return []

        if isinstance(self.keywords, list):
            return self.keywords

        try:
            # Stored as a string: parse it as JSON
            if isinstance(self.keywords, str):
                decoded = json.loads(self.keywords)
                # Undo literal \uXXXX escapes that survived in individual keywords
                if isinstance(decoded, list):
                    return [
                        keyword.encode('utf-8').decode('unicode_escape')
                        if isinstance(keyword, str) and '\\u' in keyword
                        else keyword
                        for keyword in decoded
                    ]
                return []

            # Any other container (e.g. a dict): convert to a list
            return list(self.keywords)
        except (json.JSONDecodeError, AttributeError, TypeError):
            return []

    def set_keywords(self, keywords):
        """Store keywords as a JSON list; a non-JSON string becomes a single keyword."""
        if isinstance(keywords, list):
            self.keywords = json.dumps(keywords, ensure_ascii=False)
        elif isinstance(keywords, str):
            try:
                # Try to interpret the string as a JSON list
                parsed = json.loads(keywords)
                if isinstance(parsed, list):
                    self.keywords = json.dumps(parsed, ensure_ascii=False)
                else:
                    self.keywords = json.dumps([keywords], ensure_ascii=False)
            except json.JSONDecodeError:
                # Not valid JSON: treat it as a single keyword
                self.keywords = json.dumps([keywords], ensure_ascii=False)


class RelatedStock(db.Model):
    """Stock related to an event."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
    stock_code = db.Column(db.String(20))  # stock code
    stock_name = db.Column(db.String(100))  # stock name
    sector = db.Column(db.String(100))  # relation type
    relation_desc = db.Column(db.String(1024))  # description of why it is related
    created_at = db.Column(db.DateTime, default=beijing_now)
    updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now)
    correlation = db.Column(db.Float())
    momentum = db.Column(db.String(1024))  # momentum
    # Newer fields
    retrieved_sources = db.Column(db.JSON)  # research-report retrieval source data
    retrieved_update_time = db.Column(db.DateTime)  # when the retrieved data was last updated


class RelatedData(db.Model):
    """Data item related to an event."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
    title = db.Column(db.String(200))  # data title
    data_type = db.Column(db.String(50))  # data type
    data_content = db.Column(db.JSON)  # data content (JSON)
    description = db.Column(db.Text)  # data description
    created_at = db.Column(db.DateTime, default=beijing_now)


class RelatedConcepts(db.Model):
    """Concept related to an event."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
    concept_code = db.Column(db.String(20))  # concept code
    concept = db.Column(db.String(100))  # concept name
    reason = db.Column(db.Text)  # reason for the relation
    image_paths = db.Column(db.JSON)  # image paths (JSON; items may be strings or {'path': ...} dicts)
    created_at = db.Column(db.DateTime, default=beijing_now)

    @property
    def image_paths_list(self):
        """Return image paths as a list, taking each dict item's 'path'; [] on error."""
        if not self.image_paths:
            return []

        try:
            # A string value is parsed as JSON first
            if isinstance(self.image_paths, str):
                paths = json.loads(self.image_paths)
            else:
                paths = self.image_paths

            # Make sure we have a list
            if not isinstance(paths, list):
                paths = [paths]

            # Extract the 'path' field from dict items
            return [item['path'] if isinstance(item, dict) and 'path' in item else item
                    for item in paths]
        except Exception as e:
            logger.error(f"Error processing image paths: {e}")
            return []

    def get_first_image_path(self):
        """Return the first image path, or None if there are none."""
        paths = self.image_paths_list
        if not paths:
            return None
        return paths[0]


class EventHotHistory(db.Model):
    """History of event popularity scores."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
    score = db.Column(db.Float)  # total score
    interaction_score = db.Column(db.Float)  # interaction score
    follow_score = db.Column(db.Float)  # follow score
    view_score = db.Column(db.Float)  # view score
    recent_activity_score = db.Column(db.Float)  # recent-activity score
    time_decay = db.Column(db.Float)  # time-decay factor
    created_at = db.Column(db.DateTime, default=beijing_now)

    event = db.relationship('Event', backref='hot_history')


class Post(db.Model):
    """Post model."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)

    # Content
    title = db.Column(db.String(200))  # Title (optional)
content = db.Column(db.Text, nullable=False) # 内容 content_type = db.Column(db.String(20), default='text') # 内容类型:text/rich_text/link # 时间 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 统计 likes_count = db.Column(db.Integer, default=0) comments_count = db.Column(db.Integer, default=0) view_count = db.Column(db.Integer, default=0) # 状态 status = db.Column(db.String(20), default='active') # active/hidden/deleted is_top = db.Column(db.Boolean, default=False) # 是否置顶 # 关系 user = db.relationship('User', backref='posts') likes = db.relationship('PostLike', backref='post', lazy='dynamic') comments = db.relationship('Comment', backref='post', lazy='dynamic') # 辅助模型 class EventFollow(db.Model): """事件关注""" id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) created_at = db.Column(db.DateTime, default=beijing_now) user = db.relationship('User', backref='event_follows') __table_args__ = (db.UniqueConstraint('user_id', 'event_id'),) class PostLike(db.Model): """帖子点赞""" id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) created_at = db.Column(db.DateTime, default=beijing_now) user = db.relationship('User', backref='post_likes') __table_args__ = (db.UniqueConstraint('user_id', 'post_id'),) class Comment(db.Model): """评论""" id = db.Column(db.Integer, primary_key=True) post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) content = db.Column(db.Text, nullable=False) parent_id = db.Column(db.Integer, db.ForeignKey('comment.id')) # 父评论ID,用于回复 created_at = db.Column(db.DateTime, default=beijing_now) status = 
db.Column(db.String(20), default='active') user = db.relationship('User', backref='comments') replies = db.relationship('Comment', backref=db.backref('parent', remote_side=[id])) class Feedback(db.Model): """用户反馈""" __tablename__ = 'user_feedback' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=True) # 可匿名反馈 feedback_type = db.Column(db.String(50), default='general') # 反馈类型: general/bug/suggestion/complaint title = db.Column(db.String(200)) # 反馈标题 content = db.Column(db.Text, nullable=False) # 反馈内容 contact = db.Column(db.String(100)) # 联系方式(可选) images = db.Column(db.JSON) # 附带图片URL列表 status = db.Column(db.String(20), default='pending') # pending/processing/resolved/closed admin_reply = db.Column(db.Text) # 管理员回复 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) user = db.relationship('User', backref='feedbacks') class StockBasicInfo(db.Model): __tablename__ = 'ea_stocklist' SECCODE = db.Column(db.String(10), primary_key=True) SECNAME = db.Column(db.String(40)) ORGNAME = db.Column(db.String(100)) F001V = db.Column(db.String(100)) # Pinyin abbreviation F003V = db.Column(db.String(50)) # Security category F005V = db.Column(db.String(50)) # Trading market F006D = db.Column(db.DateTime) # Listing date F011V = db.Column(db.String(50)) # Listing status class CompanyInfo(db.Model): __tablename__ = 'ea_baseinfo' SECCODE = db.Column(db.String(10), primary_key=True) SECNAME = db.Column(db.String(40)) ORGNAME = db.Column(db.String(100)) F001V = db.Column(db.String(100)) # English name F003V = db.Column(db.String(40)) # Legal representative F015V = db.Column(db.String(500)) # Main business F016V = db.Column(db.String(4000)) # Business scope F017V = db.Column(db.String(2000)) # Company introduction F030V = db.Column(db.String(60)) # CSRC industry first level F032V = db.Column(db.String(60)) # CSRC industry second level class 
TradeData(db.Model): __tablename__ = 'ea_trade' SECCODE = db.Column(db.String(10), primary_key=True) SECNAME = db.Column(db.String(40)) TRADEDATE = db.Column(db.Date, primary_key=True) F002N = db.Column(db.Numeric(18, 4)) # Previous close F003N = db.Column(db.Numeric(18, 4)) # Open price F004N = db.Column(db.Numeric(18, 4)) # Trading volume F005N = db.Column(db.Numeric(18, 4)) # High price F006N = db.Column(db.Numeric(18, 4)) # Low price F007N = db.Column(db.Numeric(18, 4)) # Close price F009N = db.Column(db.Numeric(18, 4)) # Change F010N = db.Column(db.Numeric(18, 4)) # Change percentage F011N = db.Column(db.Numeric(18, 4)) # Trading amount class SectorInfo(db.Model): __tablename__ = 'ea_sector' SECCODE = db.Column(db.String(10), primary_key=True) SECNAME = db.Column(db.String(40)) F001V = db.Column(db.String(50), primary_key=True) # Classification standard code F002V = db.Column(db.String(50)) # Classification standard F003V = db.Column(db.String(50)) # Sector code F004V = db.Column(db.String(50)) # Sector level 1 name F005V = db.Column(db.String(50)) # Sector level 2 name F006V = db.Column(db.String(50)) # Sector level 3 name F007V = db.Column(db.String(50)) # Sector level 4 name def init_sywg_industry_cache(): """ 初始化申银万国行业分类缓存 在程序启动时调用,将所有行业分类数据加载到内存中 """ global SYWG_INDUSTRY_CACHE try: app.logger.info('开始初始化申银万国行业分类缓存...') # 定义层级映射关系 level_column_map = { 2: 'f004v', # level2 对应一级行业 3: 'f005v', # level3 对应二级行业 4: 'f006v', # level4 对应三级行业 5: 'f007v' # level5 对应四级行业 } # 定义代码前缀长度映射 prefix_length_map = { 2: 3, # S + 2位 3: 5, # S + 2位 + 2位 4: 7, # S + 2位 + 2位 + 2位 5: 9 # 完整代码 } # 遍历所有层级 for level, column_name in level_column_map.items(): # 查询该层级的所有行业及其代码 query_sql = f""" SELECT DISTINCT {column_name} as industry_name, f003v as code FROM ea_sector WHERE f002v = '申银万国行业分类' AND {column_name} IS NOT NULL AND {column_name} != '' """ result = db.session.execute(text(query_sql)) rows = result.fetchall() # 构建该层级的缓存 industry_dict = {} for row in rows: industry_name = row[0] 
code = row[1] if industry_name and code: # 获取代码前缀 prefix_length = prefix_length_map[level] code_prefix = code[:prefix_length] # 将前缀添加到对应行业的列表中 if industry_name not in industry_dict: industry_dict[industry_name] = set() industry_dict[industry_name].add(code_prefix) # 将set转换为list并存储到缓存中 for industry_name, prefixes in industry_dict.items(): SYWG_INDUSTRY_CACHE[level][industry_name] = list(prefixes) app.logger.info(f'Level {level} 缓存完成,共 {len(industry_dict)} 个行业') # 统计总数 total_count = sum(len(industries) for industries in SYWG_INDUSTRY_CACHE.values()) app.logger.info(f'申银万国行业分类缓存初始化完成,共缓存 {total_count} 个行业分类') except Exception as e: app.logger.error(f'初始化申银万国行业分类缓存失败: {str(e)}') import traceback app.logger.error(traceback.format_exc()) def send_async_email(msg): """异步发送邮件""" try: mail.send(msg) except Exception as e: app.logger.error(f"Error sending async email: {str(e)}") def verify_sms_code(phone_number, code): """验证短信验证码""" stored_code = session.get('sms_verification_code') stored_phone = session.get('sms_verification_phone') expiration = session.get('sms_verification_expiration') if not all([stored_code, stored_phone, expiration]): return False, "请先获取验证码" if stored_phone != phone_number: return False, "手机号与验证码不匹配" if beijing_now().timestamp() > expiration: return False, "验证码已过期" if code != stored_code: return False, "验证码错误" return True, "验证成功" def allowed_file(filename): return '.' 
in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS # ============================================ # 订阅相关 API 接口(小程序专用) # ============================================ @app.route('/api/subscription/info', methods=['GET']) @token_required def get_subscription_info(): """ 获取当前用户的订阅信息 - 小程序专用接口 返回用户当前订阅类型、状态、剩余天数等信息 """ try: info = _get_current_subscription_info() return jsonify({ 'success': True, 'data': info }) except Exception as e: print(f"获取订阅信息错误: {e}") return jsonify({ 'success': True, 'data': { 'type': 'free', 'status': 'active', 'is_active': True, 'days_left': 0 } }) @app.route('/api/subscription/check', methods=['GET']) @token_required def check_subscription_access(): """ 检查当前用户是否有权限使用小程序功能 返回:是否为 Pro/Max 用户 """ try: has_access = _has_required_level('pro') info = _get_current_subscription_info() return jsonify({ 'success': True, 'data': { 'has_access': has_access, 'subscription_type': info.get('type', 'free'), 'is_active': info.get('is_active', False), 'message': '您可以使用小程序功能' if has_access else '小程序功能仅对 Pro 和 Max 会员开放' } }) except Exception as e: print(f"检查订阅权限错误: {e}") return jsonify({ 'success': False, 'error': str(e) }), 500 # ============================================ # 现有接口示例(应用权限控制) # ============================================ # 更新视图函数 @app.route('/settings/profile', methods=['POST']) @token_required def update_profile(): """更新个人资料""" try: user = request.user form = request.form # 基本信息更新 user.nickname = form.get('nickname') user.bio = form.get('bio') user.gender = form.get('gender') user.birth_date = datetime.strptime(form.get('birth_date'), '%Y-%m-%d') if form.get('birth_date') else None user.phone = form.get('phone') user.location = form.get('location') user.wechat_id = form.get('wechat_id') # 处理头像上传 if 'avatar' in request.files: file = request.files['avatar'] if file and allowed_file(file.filename): # 生成安全的文件名 filename = secure_filename(f"{user.id}_{int(datetime.now().timestamp())}_{file.filename}") filepath = 
os.path.join(app.config['UPLOAD_FOLDER'], filename) # 确保上传目录存在 os.makedirs(os.path.dirname(filepath), exist_ok=True) # 保存并处理图片 image = Image.open(file) image.thumbnail((300, 300)) # 调整图片大小 image.save(filepath) # 更新用户头像URL user.avatar_url = f'{DOMAIN}/static/uploads/avatars/{filename}' db.session.commit() return jsonify({'success': True, 'message': '个人资料已更新'}) except Exception as e: db.session.rollback() app.logger.error(f"Error updating profile: {str(e)}") return jsonify({'success': False, 'message': '更新失败,请重试'}) # 投资偏好设置 @app.route('/settings/investment_preferences', methods=['POST']) @token_required def update_investment_preferences(): """更新投资偏好""" try: user = request.user form = request.form user.trading_experience = form.get('trading_experience') user.investment_style = form.get('investment_style') user.risk_preference = form.get('risk_preference') user.investment_amount = form.get('investment_amount') user.preferred_markets = json.dumps(request.form.getlist('preferred_markets')) db.session.commit() return jsonify({'success': True, 'message': '投资偏好已更新'}) except Exception as e: db.session.rollback() app.logger.error(f"Error updating investment preferences: {str(e)}") return jsonify({'success': False, 'message': '更新失败,请重试'}) def get_clickhouse_client(): """ 获取 ClickHouse 客户端(使用连接池,懒加载) 返回连接池对象,支持两种使用方式: 方式1(推荐)- 直接调用 execute: client = get_clickhouse_client() result = client.execute("SELECT * FROM table", {'param': value}) 方式2 - 使用上下文管理器: client = get_clickhouse_client() with client.connection() as conn: result = conn.execute("SELECT * FROM table") """ return _init_clickhouse_pool() @app.route('/api/system/clickhouse-pool-status', methods=['GET']) def api_clickhouse_pool_status(): """获取 ClickHouse 连接池状态(仅供监控使用)""" try: pool = _init_clickhouse_pool() status = pool.get_pool_status() return jsonify({ 'code': 200, 'message': 'success', 'data': status }) except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 
@app.route('/api/stock//kline') def get_stock_kline(stock_code): """获取股票K线数据 - 仅限 Pro/Max 会员(小程序功能)""" chart_type = request.args.get('chart_type', 'daily') # 默认改为daily event_time = request.args.get('event_time') try: event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now() except ValueError: return jsonify({'error': 'Invalid event_time format'}), 400 # 获取股票名称 try: with engine.connect() as conn: result = conn.execute(text( "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" ), {"code": stock_code.split('.')[0]}).fetchone() stock_name = result[0] if result else 'Unknown' except Exception as e: print(f"Error getting stock name: {e}") stock_name = 'Unknown' if chart_type == 'daily': return get_daily_kline(stock_code, event_datetime, stock_name) elif chart_type == 'minute': return get_minute_kline(stock_code, event_datetime, stock_name) else: return jsonify({ 'error': 'Invalid chart type', 'message': 'Supported types: daily, minute', 'code': stock_code, 'name': stock_name }), 400 def get_daily_kline(stock_code, event_datetime, stock_name): """处理日K线数据""" stock_code = stock_code.split('.')[0] print(f"Debug: stock_code={stock_code}, event_datetime={event_datetime}, stock_name={stock_name}") try: with engine.connect() as conn: # 获取事件日期前后的数据 kline_sql = """ WITH date_range AS (SELECT TRADEDATE \ FROM ea_trade \ WHERE SECCODE = :stock_code \ AND TRADEDATE BETWEEN DATE_SUB(:trade_date, INTERVAL 60 DAY) \ AND :trade_date \ GROUP BY TRADEDATE \ ORDER BY TRADEDATE) SELECT t.TRADEDATE, CAST(t.F003N AS FLOAT) as open, CAST(t.F007N AS FLOAT) as close, CAST(t.F005N AS FLOAT) as high, CAST(t.F006N AS FLOAT) as low, CAST(t.F004N AS FLOAT) as volume FROM ea_trade t JOIN date_range d \ ON t.TRADEDATE = d.TRADEDATE WHERE t.SECCODE = :stock_code ORDER BY t.TRADEDATE \ """ result = conn.execute(text(kline_sql), { "stock_code": stock_code, "trade_date": event_datetime.date() }).fetchall() print(f"Debug: Query result count: {len(result)}") if not result: 
print("Debug: No data found, trying fallback query...") # 如果没有数据,尝试获取最近的交易数据 fallback_sql = """ SELECT TRADEDATE, CAST(F003N AS FLOAT) as open, CAST(F007N AS FLOAT) as close, CAST(F005N AS FLOAT) as high, CAST(F006N AS FLOAT) as low, CAST(F004N AS FLOAT) as volume FROM ea_trade WHERE SECCODE = :stock_code AND TRADEDATE <= :trade_date AND F003N IS NOT NULL AND F007N IS NOT NULL AND F005N IS NOT NULL AND F006N IS NOT NULL AND F004N IS NOT NULL ORDER BY TRADEDATE LIMIT 100 \ """ result = conn.execute(text(fallback_sql), { "stock_code": stock_code, "trade_date": event_datetime.date() }).fetchall() print(f"Debug: Fallback query result count: {len(result)}") if not result: return jsonify({ 'error': 'No data available', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'daily' }) kline_data = [] for row in result: try: kline_data.append({ 'time': row.TRADEDATE.strftime('%Y-%m-%d'), 'open': float(row.open) if row.open else 0, 'high': float(row.high) if row.high else 0, 'low': float(row.low) if row.low else 0, 'close': float(row.close) if row.close else 0, 'volume': float(row.volume) if row.volume else 0 }) except (ValueError, TypeError) as e: print(f"Debug: Error processing row: {e}") continue print(f"Debug: Final kline_data count: {len(kline_data)}") return jsonify({ 'code': stock_code, 'name': stock_name, 'data': kline_data, 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'event_time': event_datetime.isoformat(), 'type': 'daily', 'is_history': True, 'data_count': len(kline_data) }) except Exception as e: print(f"Error in get_daily_kline: {e}") return jsonify({ 'error': f'Database error: {str(e)}', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'daily' }), 500 def get_minute_kline(stock_code, event_datetime, stock_name): """处理分钟K线数据 - 包含零轴(昨日收盘价)""" client = get_clickhouse_client() stock_code_short = stock_code.split('.')[0] 
# 获取不带后缀的股票代码 def get_trading_days(): trading_days = set() with open('tdays.csv', 'r') as f: reader = csv.DictReader(f) for row in reader: trading_days.add(datetime.strptime(row['DateTime'], '%Y/%m/%d').date()) return trading_days trading_days = get_trading_days() def find_next_trading_day(current_date): """找到下一个交易日""" while current_date <= max(trading_days): current_date += timedelta(days=1) if current_date in trading_days: return current_date return None def find_prev_trading_day(current_date): """找到前一个交易日""" while current_date >= min(trading_days): current_date -= timedelta(days=1) if current_date in trading_days: return current_date return None def get_prev_close(stock_code_short, target_date): """获取前一交易日的收盘价作为零轴基准""" prev_date = find_prev_trading_day(target_date) if not prev_date: return None try: with engine.connect() as conn: # 查询前一交易日的收盘价 sql = """ SELECT CAST(F007N AS FLOAT) as close FROM ea_trade WHERE SECCODE = :stock_code AND TRADEDATE = :prev_date AND F007N IS NOT NULL LIMIT 1 \ """ result = conn.execute(text(sql), { "stock_code": stock_code_short, "prev_date": prev_date }).fetchone() if result: return float(result.close) else: # 如果指定日期没有数据,尝试获取最近的收盘价 fallback_sql = """ SELECT CAST(F007N AS FLOAT) as close, TRADEDATE FROM ea_trade WHERE SECCODE = :stock_code AND TRADEDATE \ < :target_date AND F007N IS NOT NULL ORDER BY TRADEDATE DESC LIMIT 1 \ """ result = conn.execute(text(fallback_sql), { "stock_code": stock_code_short, "target_date": target_date }).fetchone() if result: print(f"Using close price from {result.TRADEDATE} as zero axis") return float(result.close) except Exception as e: print(f"Error getting previous close: {e}") return None target_date = event_datetime.date() is_after_market = event_datetime.time() > dt_time(15, 0) # 核心逻辑:先判断当前日期是否是交易日,以及是否已收盘 if target_date in trading_days and is_after_market: # 如果是交易日且已收盘,查找下一个交易日 next_trade_date = find_next_trading_day(target_date) if next_trade_date: target_date = next_trade_date elif target_date 
not in trading_days: # 如果不是交易日,先尝试找下一个交易日 next_trade_date = find_next_trading_day(target_date) if next_trade_date: target_date = next_trade_date else: # 如果找不到下一个交易日,找最近的历史交易日 target_date = find_prev_trading_day(target_date) if not target_date: return jsonify({ 'error': 'No data available', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'minute' }) # 获取前一交易日收盘价作为零轴 zero_axis = get_prev_close(stock_code_short, target_date) # 获取目标日期的完整交易时段数据 data = client.execute(""" SELECT timestamp, open, high, low, close, volume, amt FROM stock_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': stock_code, 'start': datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)) }) kline_data = [] for row in data: point = { 'time': row[0].strftime('%H:%M'), 'open': float(row[1]), 'high': float(row[2]), 'low': float(row[3]), 'close': float(row[4]), 'volume': float(row[5]), 'amount': float(row[6]) } # 如果有零轴数据,计算涨跌幅和涨跌额 if zero_axis: point['prev_close'] = zero_axis point['change'] = point['close'] - zero_axis # 涨跌额 point['change_pct'] = ((point['close'] - zero_axis) / zero_axis * 100) if zero_axis != 0 else 0 # 涨跌幅百分比 kline_data.append(point) response_data = { 'code': stock_code, 'name': stock_name, 'data': kline_data, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'minute', 'is_history': target_date < event_datetime.date() } # 添加零轴信息到响应中 if zero_axis: response_data['zero_axis'] = zero_axis response_data['prev_close'] = zero_axis # 计算当日整体涨跌幅(如果有数据) if kline_data: last_close = kline_data[-1]['close'] response_data['day_change'] = last_close - zero_axis response_data['day_change_pct'] = ((last_close - zero_axis) / zero_axis * 100) if zero_axis != 0 else 0 return jsonify(response_data) class HistoricalEvent(db.Model): """历史事件模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, 
db.ForeignKey('event.id')) title = db.Column(db.String(200)) content = db.Column(db.Text) event_date = db.Column(db.DateTime) relevance = db.Column(db.Integer) # 相关性 importance = db.Column(db.Integer) # 重要程度 related_stock = db.Column(db.JSON) # 保留JSON字段 created_at = db.Column(db.DateTime, default=beijing_now) # 新增关系 stocks = db.relationship('HistoricalEventStock', backref='historical_event', lazy='dynamic', cascade='all, delete-orphan') class HistoricalEventStock(db.Model): """历史事件相关股票模型""" __tablename__ = 'historical_event_stocks' id = db.Column(db.Integer, primary_key=True) historical_event_id = db.Column(db.Integer, db.ForeignKey('historical_event.id'), nullable=False) stock_code = db.Column(db.String(20), nullable=False) stock_name = db.Column(db.String(50)) relation_desc = db.Column(db.Text) correlation = db.Column(db.Float, default=0.5) sector = db.Column(db.String(100)) created_at = db.Column(db.DateTime, default=beijing_now) __table_args__ = ( db.UniqueConstraint('historical_event_id', 'stock_code', name='unique_event_stock'), ) @app.route('/event/follow/', methods=['POST']) @token_required def follow_event(event_id): """关注/取消关注事件""" event = Event.query.get_or_404(event_id) follow = EventFollow.query.filter_by( user_id=request.user.id, event_id=event_id ).first() try: if follow: db.session.delete(follow) event.follower_count -= 1 message = '已取消关注' else: follow = EventFollow(user_id=request.user.id, event_id=event_id) db.session.add(follow) event.follower_count += 1 message = '已关注' db.session.commit() return jsonify({'success': True, 'message': message}) except Exception as e: db.session.rollback() return jsonify({'success': False, 'message': '操作失败,请重试'}) # 帖子相关路由 @app.route('/post/create/', methods=['GET', 'POST']) @token_required def create_post(event_id): """创建新帖子""" event = Event.query.get_or_404(event_id) if request.method == 'POST': try: post = Post( event_id=event_id, user_id=request.user.id, title=request.form.get('title'), 
content=request.form['content'], content_type=request.form.get('content_type', 'text') ) db.session.add(post) event.post_count += 1 db.session.commit() # 检查是否是 API 请求(通过 Accept header 或 Content-Type 判断) if request.headers.get('Accept') == 'application/json' or \ request.headers.get('Content-Type', '').startswith('application/json'): return jsonify({ 'success': True, 'message': '发布成功', 'data': { 'post_id': post.id, 'event_id': event_id, 'redirect_url': url_for('event_detail', event_id=event_id) } }) else: # 传统表单提交,添加成功消息并重定向 flash('发布成功', 'success') return redirect(url_for('event_detail', event_id=event_id)) except Exception as e: db.session.rollback() if request.headers.get('Accept') == 'application/json' or \ request.headers.get('Content-Type', '').startswith('application/json'): return jsonify({ 'success': False, 'message': '发布失败,请重试' }), 400 else: flash('发布失败,请重试', 'error') app.logger.error(f"Error creating post: {str(e)}") return render_template('projects/create_post.html', event=event) # 点赞相关路由 @app.route('/post/like/', methods=['POST']) @token_required def like_post(post_id): """点赞/取消点赞帖子""" post = Post.query.get_or_404(post_id) like = PostLike.query.filter_by( user_id=request.user.id, post_id=post_id ).first() try: if like: # 取消点赞 db.session.delete(like) post.likes_count -= 1 message = '已取消点赞' else: # 添加点赞 like = PostLike(user_id=request.user.id, post_id=post_id) db.session.add(like) post.likes_count += 1 message = '已点赞' db.session.commit() return jsonify({ 'success': True, 'message': message, 'likes_count': post.likes_count }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'message': '操作失败,请重试'}) def update_user_activity(): """更新用户活跃度""" with app.app_context(): try: # 获取过去7天内的用户活动数据 seven_days_ago = beijing_now() - timedelta(days=7) # 统计用户发帖、评论、点赞等活动 active_users = db.session.query( User.id, db.func.count(Post.id).label('post_count'), db.func.count(Comment.id).label('comment_count'), 
db.func.count(PostLike.id).label('like_count') ).outerjoin(Post, User.id == Post.user_id) \ .outerjoin(Comment, User.id == Comment.user_id) \ .outerjoin(PostLike, User.id == PostLike.user_id) \ .filter( db.or_( Post.created_at >= seven_days_ago, Comment.created_at >= seven_days_ago, PostLike.created_at >= seven_days_ago ) ).group_by(User.id).all() # 更新用户活跃度分数 for user_id, post_count, comment_count, like_count in active_users: activity_score = post_count * 2 + comment_count * 1 + like_count * 0.5 User.query.filter_by(id=user_id).update({ 'activity_score': activity_score, 'last_active': beijing_now() }) db.session.commit() current_app.logger.info("Successfully updated user activity scores") except Exception as e: db.session.rollback() current_app.logger.error(f"Error updating user activity: {str(e)}") @app.route('/post/comment/', methods=['POST']) @token_required def add_comment(post_id): """添加评论""" post = Post.query.get_or_404(post_id) try: content = request.form.get('content') parent_id = request.form.get('parent_id', type=int) if not content: return jsonify({'success': False, 'message': '评论内容不能为空'}) comment = Comment( post_id=post_id, user_id=request.user.id, content=content, parent_id=parent_id ) db.session.add(comment) post.comments_count += 1 db.session.commit() return jsonify({ 'success': True, 'message': '评论成功', 'comment': { 'id': comment.id, 'content': comment.content, 'user_name': request.user.username, 'user_avatar': get_full_avatar_url(request.user.avatar_url), # 修改这里 'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S') } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'message': '评论失败,请重试'}) @app.route('/post/comments/') def get_comments(post_id): """获取帖子评论列表""" page = request.args.get('page', 1, type=int) # 获取顶层评论 comments = Comment.query.filter_by( post_id=post_id, parent_id=None, status='active' ).order_by( Comment.created_at.desc() ).paginate(page=page, per_page=20) # 同时获取每个顶层评论的部分回复 comments_data = [] for 
comment in comments.items: replies = Comment.query.filter_by( parent_id=comment.id, status='active' ).order_by( Comment.created_at.asc() ).limit(3).all() comments_data.append({ 'id': comment.id, 'content': comment.content, 'user': { 'id': comment.user.id, 'username': comment.user.username, 'avatar_url': get_full_avatar_url(comment.user.avatar_url), # 修改这里 }, 'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'), 'replies': [{ 'id': reply.id, 'content': reply.content, 'user': { 'id': reply.user.id, 'username': reply.user.username, 'avatar_url': get_full_avatar_url(reply.user.avatar_url), # 修改这里 }, 'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S') } for reply in replies] }) return jsonify({ 'comments': comments_data, 'total': comments.total, 'pages': comments.pages, 'current_page': comments.page }) beijing_tz = pytz.timezone('Asia/Shanghai') def update_hot_scores(): """ 更新所有事件的热度分数 在Flask应用上下文中执行数据库操作 """ with app.app_context(): try: # 获取所有活跃事件 events = Event.query.filter_by(status='active').all() current_time = beijing_now() for event in events: # 确保created_at有时区信息,解决naive和aware datetime比较问题 created_at = beijing_tz.localize( event.created_at) if event.created_at.tzinfo is None else event.created_at # 使用处理后的created_at计算hours_passed hours_passed = (current_time - created_at).total_seconds() / 3600 # 基础分数 - 帖子数和评论数 posts = Post.query.filter_by(event_id=event.id).all() post_count = len(posts) comment_count = sum(post.comments_count for post in posts) # 获取24小时内的新增帖子数 recent_posts = Post.query.filter( Post.event_id == event.id, Post.created_at >= current_time - timedelta(hours=24) ).count() # 获取点赞数 like_count = db.session.query(func.sum(Post.likes_count)).filter( Post.event_id == event.id ).scalar() or 0 # 基础互动分数 = 帖子数 * 2 + 评论数 * 1 + 点赞数 * 0.5 interaction_score = (post_count * 2) + (comment_count * 1) + (like_count * 0.5) # 关注度分数 = 关注人数 * 3 follow_score = event.follower_count * 3 # 浏览量分数 = log(浏览量) if event.view_count > 0: view_score = 
math.log(event.view_count) * 2 else: view_score = 0 # 时间衰减因子 - 使用上面已经计算好的hours_passed time_decay = math.exp(-hours_passed / 72) # 3天后衰减为原始分数的1/e # 最近活跃度权重 recent_activity_weight = (recent_posts * 5) # 24小时内的新帖权重高 # 总分 = (互动分数 + 关注度分数 + 浏览量分数 + 最近活跃度) * 时间衰减 total_score = (interaction_score + follow_score + view_score + recent_activity_weight) * time_decay # 更新热度分数 event.hot_score = round(total_score, 2) # 分数的对数值作为事件的trending_score (用于趋势排序) if total_score > 0: event.trending_score = math.log(total_score) * time_decay else: event.trending_score = 0 # 记录热度历史 history = EventHotHistory( event_id=event.id, score=event.hot_score, interaction_score=interaction_score, follow_score=follow_score, view_score=view_score, recent_activity_score=recent_activity_weight, time_decay=time_decay ) db.session.add(history) db.session.commit() app.logger.info("Successfully updated event hot scores") except Exception as e: db.session.rollback() app.logger.error(f"Error updating hot scores: {str(e)}") raise # 添加热度历史记录模型 def calculate_hot_score(event): """计算事件热度分数""" current_time = beijing_now() time_diff = (current_time - event.created_at).total_seconds() / 3600 # 转换为小时 # 基础分数 = 浏览量 * 0.1 + 帖子数 * 0.5 + 关注数 * 1 base_score = ( event.view_count * 0.1 + event.post_count * 0.5 + event.follower_count * 1 ) # 时间衰减因子,72小时(3天)内的事件获得较高权重 time_factor = max(1 - (time_diff / 72), 0.1) return base_score * time_factor @app.route('/api/sector/hierarchy', methods=['GET']) def api_sector_hierarchy(): """行业层级关系接口:展示多个行业分类体系的层级结构""" try: # 定义需要返回的行业分类体系 classification_systems = [ '申银万国行业分类' ] result = [] # 改为数组 for classification in classification_systems: # 查询特定分类标准的数据 sectors = SectorInfo.query.filter_by(F002V=classification).all() if not sectors: continue # 构建该分类体系的层级结构 hierarchy = {} for sector in sectors: level1 = sector.F004V # 一级行业 level2 = sector.F005V # 二级行业 level3 = sector.F006V # 三级行业 level4 = sector.F007V # 四级行业 # 统计股票数量 stock_code = sector.SECCODE # 初始化一级行业 if level1 not in hierarchy: 
hierarchy[level1] = { 'level2_sectors': {}, 'stocks': set(), 'stocks_count': 0 } # 添加股票到一级行业 if stock_code: hierarchy[level1]['stocks'].add(stock_code) # 处理二级行业 if level2: if level2 not in hierarchy[level1]['level2_sectors']: hierarchy[level1]['level2_sectors'][level2] = { 'level3_sectors': {}, 'stocks': set(), 'stocks_count': 0 } # 添加股票到二级行业 if stock_code: hierarchy[level1]['level2_sectors'][level2]['stocks'].add(stock_code) # 处理三级行业 if level3: if level3 not in hierarchy[level1]['level2_sectors'][level2]['level3_sectors']: hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3] = { 'level4_sectors': [], 'stocks': set(), 'stocks_count': 0 } # 添加股票到三级行业 if stock_code: hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3]['stocks'].add( stock_code) # 处理四级行业 if level4 and level4 not in \ hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3]['level4_sectors']: hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3][ 'level4_sectors'].append(level4) # 计算股票数量并清理set对象 formatted_hierarchy = [] for level1, level1_data in hierarchy.items(): level1_item = { 'level1_sector': level1, 'stocks_count': len(level1_data['stocks']), 'level2_sectors': [] } for level2, level2_data in level1_data['level2_sectors'].items(): level2_item = { 'level2_sector': level2, 'stocks_count': len(level2_data['stocks']), 'level3_sectors': [] } for level3, level3_data in level2_data['level3_sectors'].items(): level3_item = { 'level3_sector': level3, 'stocks_count': len(level3_data['stocks']), 'level4_sectors': level3_data['level4_sectors'] } level2_item['level3_sectors'].append(level3_item) # 按股票数量排序 level2_item['level3_sectors'].sort(key=lambda x: x['stocks_count'], reverse=True) level1_item['level2_sectors'].append(level2_item) # 按股票数量排序 level1_item['level2_sectors'].sort(key=lambda x: x['stocks_count'], reverse=True) formatted_hierarchy.append(level1_item) # 按股票数量排序 formatted_hierarchy.sort(key=lambda x: x['stocks_count'], reverse=True) # 
将该分类体系添加到结果数组中 result.append({ 'classification_name': classification, 'total_level1_count': len(formatted_hierarchy), 'total_stocks_count': sum(item['stocks_count'] for item in formatted_hierarchy), 'hierarchy': formatted_hierarchy }) # 按总股票数量排序 result.sort(key=lambda x: x['total_stocks_count'], reverse=True) return jsonify({ "code": 200, "message": "success", "data": result }) except Exception as e: return jsonify({ "code": 500, "message": str(e), "data": None }), 500 @app.route('/api/sector/banner', methods=['GET']) def api_sector_banner(): """行业分类 banner 接口:返回一级分类和对应二级行业列表""" try: # 原始映射 sector_map = { '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期', '钢铁': '大周期', '基础化工': '大周期', '建筑材料': '大周期', '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期', '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期', '太阳能发电': '大周期', '建筑装饰': '大周期', '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费', '食品饮料': '大消费', '医药生物': '大消费', '纺织服饰': '大消费', '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费', '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费', '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产', '多元金融': '大金融地产', '综合金融': '大金融地产', '房地产': '大金融地产', '非银金融': '大金融地产', '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块', '交通运输': '公共产业板块', '电力公用事业': '公共产业板块', '建筑': '公共产业板块', '环保': '公共产业板块', '综合': '公共产业板块', '公用事业': '公共产业板块' } # 重组结构为 一级 → [二级...] 
result_dict = {} for sub_sector, primary_sector in sector_map.items(): result_dict.setdefault(primary_sector, []).append(sub_sector) # 格式化成列表 result_list = [ {"primary_sector": primary, "sub_sectors": subs} for primary, subs in result_dict.items() ] return jsonify({ "code": 200, "message": "success", "data": result_list }) except Exception as e: return jsonify({ "code": 500, "message": str(e), "data": None }), 500 def get_limit_rate(stock_code): """ 根据股票代码获取涨跌停限制比例 Args: stock_code: 股票代码 Returns: float: 涨跌停限制比例 """ if not stock_code: return 10.0 # 去掉市场后缀 clean_code = stock_code.replace('.SH', '').replace('.SZ', '').replace('.BJ', '') # ST股票 (5%涨跌停) if 'ST' in stock_code.upper(): return 5.0 # 科创板 (688开头, 20%涨跌停) if clean_code.startswith('688'): return 20.0 # 创业板注册制 (30开头, 20%涨跌停) if clean_code.startswith('30'): return 20.0 # 北交所 (43、83、87开头, 30%涨跌停) if clean_code.startswith(('43', '83', '87')): return 30.0 # 主板、中小板默认 (10%涨跌停) return 10.0 @app.route('/api/events', methods=['GET']) def api_get_events(): """ 获取事件列表API - 优化版本(保持完全兼容) 仅限 Pro/Max 会员访问(小程序功能) 优化策略: 1. 使用ind_type字段简化内部逻辑 2. 批量获取股票行情,包括周涨跌计算 3. 
保持原有返回数据结构不变 """ try: # ==================== 参数解析 ==================== # 分页参数 page = max(1, request.args.get('page', 1, type=int)) per_page = min(100, max(1, request.args.get('per_page', 10, type=int))) # 基础筛选参数 event_type = request.args.get('type', 'all') event_status = request.args.get('status', 'active') importance = request.args.get('importance', 'all') # 日期筛选参数 start_date = request.args.get('start_date') end_date = request.args.get('end_date') date_range = request.args.get('date_range') recent_days = request.args.get('recent_days', type=int) time_filter = request.args.get('time_filter') # 时间快速筛选参数 # 行业筛选参数(重新设计) ind_type = request.args.get('ind_type', 'all') stock_sector = request.args.get('stock_sector', 'all') secondary_sector = request.args.get('secondary_sector', 'all') # 新的行业层级筛选参数 industry_level = request.args.get('industry_level', type=int) # 筛选层级:1-4 industry_classification = request.args.get('industry_classification') # 行业名称 # 如果使用旧参数,映射到ind_type if ind_type == 'all' and stock_sector != 'all': ind_type = stock_sector # 标签筛选参数 tag = request.args.get('tag') tags = request.args.get('tags') keywords = request.args.get('keywords') # 搜索参数 search_query = request.args.get('q') search_type = request.args.get('search_type', 'topic') search_fields = request.args.get('search_fields', 'title,description').split(',') # 排序参数 sort_by = request.args.get('sort', 'new') return_type = request.args.get('return_type', 'avg') order = request.args.get('order', 'desc') # 收益率筛选参数 min_avg_return = request.args.get('min_avg_return', type=float) max_avg_return = request.args.get('max_avg_return', type=float) min_max_return = request.args.get('min_max_return', type=float) max_max_return = request.args.get('max_max_return', type=float) min_week_return = request.args.get('min_week_return', type=float) max_week_return = request.args.get('max_week_return', type=float) # 其他筛选参数 min_hot_score = request.args.get('min_hot_score', type=float) max_hot_score = 
request.args.get('max_hot_score', type=float) min_view_count = request.args.get('min_view_count', type=int) creator_id = request.args.get('creator_id', type=int) # 返回格式参数 include_creator = request.args.get('include_creator', 'true').lower() == 'true' include_stats = request.args.get('include_stats', 'true').lower() == 'true' include_related_data = request.args.get('include_related_data', 'false').lower() == 'true' # ==================== 构建查询 ==================== query = Event.query # 状态筛选 if event_status != 'all': query = query.filter_by(status=event_status) # 事件类型筛选 if event_type != 'all': query = query.filter_by(event_type=event_type) # 重要性筛选(支持多选,逗号分隔,如 importance=S,A,B) if importance != 'all': importance_list = [i.strip().upper() for i in importance.split(',') if i.strip()] if len(importance_list) == 1: query = query.filter_by(importance=importance_list[0]) elif len(importance_list) > 1: query = query.filter(Event.importance.in_(importance_list)) # 行业类型筛选(使用ind_type字段) if ind_type != 'all': query = query.filter_by(ind_type=ind_type) # 创建者筛选 if creator_id: query = query.filter_by(creator_id=creator_id) # ==================== 日期筛选 ==================== # 时间快速筛选(优先级最高) time_filter_applied = False if time_filter: now = datetime.now() today = now.date() if time_filter == 'latest': # 最新:最近100条,不设时间筛选,但限制数量(在后面排序后处理) time_filter_applied = True # 特殊处理:latest模式下per_page强制为100,忽略分页 per_page = 100 page = 1 elif time_filter == 'intraday': # 盘中:从今天早上9:30到当前时间 start_time = datetime.combine(today, datetime.strptime('09:30', '%H:%M').time()) query = query.filter(Event.created_at >= start_time) query = query.filter(Event.created_at <= now) time_filter_applied = True elif time_filter == 'morning': # 早盘:从今天早上9:30到11:30(上午盘交易时段) start_time = datetime.combine(today, datetime.strptime('09:30', '%H:%M').time()) end_time = datetime.combine(today, datetime.strptime('11:30', '%H:%M').time()) query = query.filter(Event.created_at >= start_time) query = query.filter(Event.created_at <= 
end_time) time_filter_applied = True elif time_filter == 'afternoon': # 午盘:从今天上午11:30至今(下午盘交易时段开始后) start_time = datetime.combine(today, datetime.strptime('11:30', '%H:%M').time()) query = query.filter(Event.created_at >= start_time) query = query.filter(Event.created_at <= now) time_filter_applied = True elif time_filter == 'today': # 今日全天:从昨天15:00到现在 yesterday = today - timedelta(days=1) start_time = datetime.combine(yesterday, datetime.strptime('15:00', '%H:%M').time()) query = query.filter(Event.created_at >= start_time) query = query.filter(Event.created_at <= now) time_filter_applied = True elif time_filter == 'yesterday': # 昨日:上一个交易日的完整数据 # 计算上一个交易日(跳过周末) def get_previous_trading_day(d): d = d - timedelta(days=1) while d.weekday() >= 5: # 5=周六, 6=周日 d = d - timedelta(days=1) return d last_trading_day = get_previous_trading_day(today) day_before_last = get_previous_trading_day(last_trading_day) # 从上上个交易日15:00到上个交易日15:00 start_time = datetime.combine(day_before_last, datetime.strptime('15:00', '%H:%M').time()) end_time = datetime.combine(last_trading_day, datetime.strptime('15:00', '%H:%M').time()) query = query.filter(Event.created_at >= start_time) query = query.filter(Event.created_at <= end_time) time_filter_applied = True elif time_filter == 'week': # 近一周:自然日7天内 start_time = datetime.combine(today - timedelta(days=7), datetime.min.time()) query = query.filter(Event.created_at >= start_time) time_filter_applied = True elif time_filter == 'month': # 近一月:自然日30天内 start_time = datetime.combine(today - timedelta(days=30), datetime.min.time()) query = query.filter(Event.created_at >= start_time) time_filter_applied = True # 如果没有使用time_filter,则使用其他日期筛选方式 if not time_filter_applied: if recent_days: cutoff_date = datetime.now() - timedelta(days=recent_days) query = query.filter(Event.created_at >= cutoff_date) else: # 处理日期范围字符串 if date_range and ' 至 ' in date_range: try: start_date_str, end_date_str = date_range.split(' 至 ') start_date = start_date_str.strip() 
end_date = end_date_str.strip() except ValueError: pass # 开始日期 if start_date: try: if len(start_date) == 10: start_datetime = datetime.strptime(start_date, '%Y-%m-%d') else: start_datetime = datetime.strptime(start_date, '%Y-%m-%d %H:%M:%S') query = query.filter(Event.created_at >= start_datetime) except ValueError: pass # 结束日期 if end_date: try: if len(end_date) == 10: end_datetime = datetime.strptime(end_date, '%Y-%m-%d') end_datetime = end_datetime.replace(hour=23, minute=59, second=59) else: end_datetime = datetime.strptime(end_date, '%Y-%m-%d %H:%M:%S') query = query.filter(Event.created_at <= end_datetime) except ValueError: pass # ==================== 行业层级筛选(申银万国行业分类) ==================== if industry_level and industry_classification: # 排除行业分类体系名称本身,这些不是具体的行业 classification_systems = [ '申银万国行业分类', '中上协行业分类', '巨潮行业分类', '新财富行业分类', '证监会行业分类', '证监会行业分类(2001)' ] if industry_classification not in classification_systems: # 使用内存缓存获取行业代码前缀(性能优化:避免每次请求都查询数据库) # 前端发送的level值: # level=2 -> 一级行业 # level=3 -> 二级行业 # level=4 -> 三级行业 # level=5 -> 四级行业 if industry_level in SYWG_INDUSTRY_CACHE: # 直接从缓存中获取代码前缀列表 code_prefixes = SYWG_INDUSTRY_CACHE[industry_level].get(industry_classification, []) if code_prefixes: # 构建查询条件:查找related_industries以这些前缀开头的事件 # related_industries 现在是 varchar 格式,如 "S640701" conditions = [] for prefix in code_prefixes: conditions.append(Event.related_industries.like(f"{prefix}%")) if conditions: query = query.filter(or_(*conditions)) else: # 没有找到匹配的行业代码,返回空结果 query = query.filter(Event.id == -1) else: # 无效的层级参数 app.logger.warning(f"Invalid industry_level: {industry_level}") else: # industry_classification 是分类体系名称,不进行筛选 app.logger.info( f"Skipping filter: industry_classification '{industry_classification}' is a classification system name") # ==================== 细分行业筛选(保留向后兼容) ==================== elif secondary_sector != 'all': # 直接按行业名称查询(最后一级行业 - level5/f007v) sector_code_query = db.session.query(text("DISTINCT f003v")).select_from( text("ea_sector") 
).filter( text("f002v = '申银万国行业分类' AND f007v = :sector_name") ).params(sector_name=secondary_sector) sector_result = sector_code_query.first() if sector_result and sector_result[0]: industry_code_to_search = sector_result[0] # related_industries 现在是 varchar 格式,直接匹配 query = query.filter(Event.related_industries == industry_code_to_search) else: # 如果没有找到对应的行业代码,返回空结果 query = query.filter(Event.id == -1) # ==================== 概念/标签筛选 ==================== # 单个标签筛选 if tag: if isinstance(db.engine.dialect, MySQLDialect): query = query.filter(text("JSON_CONTAINS(keywords, :tag, '$')")) query = query.params(tag=json.dumps(tag)) else: query = query.filter(Event.keywords.cast(JSONB).contains([tag])) # 多个标签筛选 (AND逻辑) if tags: tag_list = [t.strip() for t in tags.split(',') if t.strip()] for single_tag in tag_list: if isinstance(db.engine.dialect, MySQLDialect): query = query.filter(text("JSON_CONTAINS(keywords, :tag, '$')")) query = query.params(tag=json.dumps(single_tag)) else: query = query.filter(Event.keywords.cast(JSONB).contains([single_tag])) # 关键词筛选 (OR逻辑) if keywords: keyword_list = [k.strip() for k in keywords.split(',') if k.strip()] keyword_filters = [] for keyword in keyword_list: if isinstance(db.engine.dialect, MySQLDialect): keyword_filters.append(text("JSON_CONTAINS(keywords, :keyword, '$')")) else: keyword_filters.append(Event.keywords.cast(JSONB).contains([keyword])) if keyword_filters: query = query.filter(or_(*keyword_filters)) # ==================== 搜索功能 ==================== if search_query: search_terms = search_query.strip().split() if search_type == 'stock': # 股票搜索 query = query.join(RelatedStock).filter( or_( RelatedStock.stock_code.ilike(f'%{search_query}%'), RelatedStock.stock_name.ilike(f'%{search_query}%') ) ).distinct() elif search_type == 'all': # 全局搜索 search_filters = [] # 文本字段搜索 for term in search_terms: term_filters = [] if 'title' in search_fields: term_filters.append(Event.title.ilike(f'%{term}%')) if 'description' in search_fields: 
term_filters.append(Event.description.ilike(f'%{term}%')) if 'keywords' in search_fields: if isinstance(db.engine.dialect, MySQLDialect): term_filters.append(text("JSON_CONTAINS(keywords, :term, '$')")) else: term_filters.append(Event.keywords.cast(JSONB).contains([term])) if term_filters: search_filters.append(or_(*term_filters)) # 股票搜索 stock_subquery = db.session.query(RelatedStock.event_id).filter( or_( RelatedStock.stock_code.ilike(f'%{search_query}%'), RelatedStock.stock_name.ilike(f'%{search_query}%') ) ).subquery() search_filters.append(Event.id.in_(stock_subquery)) if search_filters: query = query.filter(or_(*search_filters)) else: # 话题搜索 (默认) for term in search_terms: term_filters = [] if 'title' in search_fields: term_filters.append(Event.title.ilike(f'%{term}%')) if 'description' in search_fields: term_filters.append(Event.description.ilike(f'%{term}%')) if 'keywords' in search_fields: if isinstance(db.engine.dialect, MySQLDialect): term_filters.append(text("JSON_CONTAINS(keywords, :term, '$')")) else: term_filters.append(Event.keywords.cast(JSONB).contains([term])) if term_filters: query = query.filter(or_(*term_filters)) # ==================== 收益率筛选 ==================== if min_avg_return is not None: query = query.filter(Event.related_avg_chg >= min_avg_return) if max_avg_return is not None: query = query.filter(Event.related_avg_chg <= max_avg_return) if min_max_return is not None: query = query.filter(Event.related_max_chg >= min_max_return) if max_max_return is not None: query = query.filter(Event.related_max_chg <= max_max_return) if min_week_return is not None: query = query.filter(Event.related_week_chg >= min_week_return) if max_week_return is not None: query = query.filter(Event.related_week_chg <= max_week_return) # ==================== 其他数值筛选 ==================== if min_hot_score is not None: query = query.filter(Event.hot_score >= min_hot_score) if max_hot_score is not None: query = query.filter(Event.hot_score <= max_hot_score) if 
min_view_count is not None: query = query.filter(Event.view_count >= min_view_count) # ==================== 排序逻辑 ==================== order_func = desc if order.lower() == 'desc' else asc if sort_by == 'hot': query = query.order_by(order_func(Event.hot_score), desc(Event.created_at)) elif sort_by == 'new': query = query.order_by(order_func(Event.created_at)) elif sort_by == 'returns': if return_type == 'avg': query = query.order_by(order_func(Event.related_avg_chg), desc(Event.created_at)) elif return_type == 'max': query = query.order_by(order_func(Event.related_max_chg), desc(Event.created_at)) elif return_type == 'week': query = query.order_by(order_func(Event.related_week_chg), desc(Event.created_at)) elif sort_by == 'importance': importance_order = case( (Event.importance == 'S', 1), (Event.importance == 'A', 2), (Event.importance == 'B', 3), (Event.importance == 'C', 4), else_=5 ) if order.lower() == 'desc': query = query.order_by(importance_order, desc(Event.created_at)) else: query = query.order_by(desc(importance_order), desc(Event.created_at)) elif sort_by == 'view_count': query = query.order_by(order_func(Event.view_count), desc(Event.created_at)) elif sort_by == 'follow' and hasattr(request, 'user') and request.user.is_authenticated: # 关注的事件排序 query = query.join(EventFollow).filter( EventFollow.user_id == request.user.id ).order_by(order_func(Event.created_at)) else: # 兜底排序:始终按时间倒序 query = query.order_by(desc(Event.created_at)) # ==================== 分页查询 ==================== paginated = query.paginate(page=page, per_page=per_page, error_out=False) # ==================== 构建响应数据 ==================== events_data = [] for event in paginated.items: # 构建事件数据(保持原有结构,个股信息和统计置空) event_dict = { 'id': event.id, 'title': event.title, 'description': event.description, 'event_type': event.event_type, 'importance': event.importance, 'status': event.status, 'created_at': event.created_at.isoformat() if event.created_at else None, 'updated_at': 
event.updated_at.isoformat() if event.updated_at else None, 'start_time': event.start_time.isoformat() if event.start_time else None, 'end_time': event.end_time.isoformat() if event.end_time else None, # 个股信息(置空) 'related_stocks': [], # 股票统计(置空或使用数据库字段) 'stocks_stats': { 'stocks_count': 10, 'valid_stocks_count': 0, # 使用数据库字段的涨跌幅 'avg_week_change': round(event.related_week_chg, 2) if event.related_week_chg else 0, 'max_week_change': round(event.related_max_chg, 2) if event.related_max_chg else 0, 'avg_daily_change': round(event.related_avg_chg, 2) if event.related_avg_chg else 0, 'max_daily_change': round(event.related_max_chg, 2) if event.related_max_chg else 0 } } # 统计信息(可选) if include_stats: event_dict.update({ 'hot_score': event.hot_score, 'view_count': event.view_count, 'post_count': event.post_count, 'follower_count': event.follower_count, 'related_avg_chg': event.related_avg_chg, 'related_max_chg': event.related_max_chg, 'related_week_chg': event.related_week_chg, 'invest_score': event.invest_score, 'trending_score': event.trending_score, }) # 创建者信息(可选) if include_creator: event_dict['creator'] = { 'id': event.creator.id if event.creator else None, 'username': event.creator.username if event.creator else 'Anonymous', 'avatar_url': get_full_avatar_url(event.creator.avatar_url) if event.creator else None, 'is_creator': event.creator.is_creator if event.creator else False, 'creator_type': event.creator.creator_type if event.creator else None } # 关联数据 event_dict['keywords'] = event.keywords if isinstance(event.keywords, list) else [] event_dict['related_industries'] = event.related_industries # 包含统计信息(可选,置空) if include_stats: event_dict['stats'] = { 'related_stocks_count': 10, 'historical_events_count': 0, 'related_data_count': 0, 'related_concepts_count': 0 } # 包含关联数据(可选,已置空) if include_related_data: event_dict['related_stocks'] = [] events_data.append(event_dict) # ==================== 构建筛选信息 ==================== applied_filters = {} # 记录已应用的筛选条件 if event_type 
!= 'all': applied_filters['type'] = event_type if importance != 'all': applied_filters['importance'] = importance if time_filter: applied_filters['time_filter'] = time_filter if start_date: applied_filters['start_date'] = start_date if end_date: applied_filters['end_date'] = end_date if industry_level and industry_classification: applied_filters['industry_level'] = industry_level applied_filters['industry_classification'] = industry_classification if tag: applied_filters['tag'] = tag if tags: applied_filters['tags'] = tags if search_query: applied_filters['search_query'] = search_query applied_filters['search_type'] = search_type # ==================== 返回结果(保持完全兼容,统计数据置空) ==================== return jsonify({ 'success': True, 'data': { 'events': events_data, 'pagination': { 'page': paginated.page, 'per_page': paginated.per_page, 'total': paginated.total, 'pages': paginated.pages, 'has_prev': paginated.has_prev, 'has_next': paginated.has_next }, 'filters': { 'applied_filters': { 'type': event_type, 'status': event_status, 'importance': importance, 'stock_sector': stock_sector, # 保持兼容 'secondary_sector': secondary_sector, # 保持兼容 'sort': sort_by, 'order': order } }, # 整体股票涨跌幅分布统计(置空) 'overall_stats': { 'total_stocks': 0, 'change_distribution': { 'limit_down': 0, 'down_over_5': 0, 'down_5_to_1': 0, 'down_within_1': 0, 'flat': 0, 'up_within_1': 0, 'up_1_to_5': 0, 'up_over_5': 0, 'limit_up': 0 }, 'change_distribution_percentages': { 'limit_down': 0, 'down_over_5': 0, 'down_5_to_1': 0, 'down_within_1': 0, 'flat': 0, 'up_within_1': 0, 'up_1_to_5': 0, 'up_over_5': 0, 'limit_up': 0 } } } }) except Exception as e: app.logger.error(f"获取事件列表出错: {str(e)}", exc_info=True) return jsonify({ 'success': False, 'error': str(e), 'error_type': type(e).__name__ }), 500 def get_filter_counts(base_query): """ 获取各个筛选条件的计数信息 可用于前端显示筛选选项的可用数量 """ try: counts = {} # 重要性计数 importance_counts = db.session.query( Event.importance, func.count(Event.id).label('count') ).filter( 
Event.id.in_(base_query.with_entities(Event.id).subquery()) ).group_by(Event.importance).all() counts['importance'] = {item.importance or 'unknown': item.count for item in importance_counts} # 事件类型计数 type_counts = db.session.query( Event.event_type, func.count(Event.id).label('count') ).filter( Event.id.in_(base_query.with_entities(Event.id).subquery()) ).group_by(Event.event_type).all() counts['event_type'] = {item.event_type or 'unknown': item.count for item in type_counts} return counts except Exception: return {} def get_event_class(count): """根据事件数量返回对应的样式类""" if count >= 10: return 'bg-gradient-danger' elif count >= 7: return 'bg-gradient-warning' elif count >= 4: return 'bg-gradient-info' else: return 'bg-gradient-success' @app.route('/api/calendar-event-counts') def get_calendar_event_counts(): """获取整月的事件数量统计,仅统计type为event的事件""" try: # 获取当前月份的开始和结束日期 today = datetime.now() start_date = today.replace(day=1) if today.month == 12: end_date = today.replace(year=today.year + 1, month=1, day=1) else: end_date = today.replace(month=today.month + 1, day=1) # 修改查询以仅统计type为event的事件数量 query = """ SELECT DATE(calendar_time) as date, COUNT(*) as count FROM future_events WHERE calendar_time BETWEEN :start_date AND :end_date AND type = 'event' GROUP BY DATE(calendar_time) """ result = db.session.execute(text(query), { 'start_date': start_date, 'end_date': end_date }) # 格式化结果为日历事件格式 events = [{ 'title': f'{day.count} 个事件', 'start': day.date.isoformat() if day.date else None, 'className': get_event_class(day.count) } for day in result] return jsonify(events) except Exception as e: app.logger.error(f"获取日历事件统计出错: {str(e)}", exc_info=True) return jsonify({'error': str(e), 'error_type': type(e).__name__}), 500 def get_full_avatar_url(avatar_url): """ 统一处理头像URL,确保返回完整的可访问URL Args: avatar_url: 头像URL字符串 Returns: 完整的头像URL,如果没有头像则返回默认头像URL """ if not avatar_url: # 返回默认头像 return f"{DOMAIN}/static/assets/img/default-avatar.png" # 如果已经是完整URL(http或https开头),直接返回 if 
avatar_url.startswith(('http://', 'https://')): return avatar_url # 如果是相对路径,拼接域名 if avatar_url.startswith('/'): return f"{DOMAIN}{avatar_url}" else: return f"{DOMAIN}/{avatar_url}" # 修改User模型的to_dict方法 def to_dict(self): """转换为字典格式,方便API返回""" return { 'id': self.id, 'username': self.username, 'email': self.email, 'nickname': self.nickname, 'avatar_url': get_full_avatar_url(self.avatar_url), # 使用统一处理函数 'bio': self.bio, 'location': self.location, 'is_verified': self.is_verified, 'user_level': self.user_level, 'reputation_score': self.reputation_score, 'post_count': self.post_count, 'follower_count': self.follower_count, 'following_count': self.following_count, 'created_at': self.created_at.isoformat() if self.created_at else None, 'last_seen': self.last_seen.isoformat() if self.last_seen else None } # ==================== 标准化API接口 ==================== # 1. 首页接口 @app.route('/api/home', methods=['GET']) def api_home(): try: seven_days_ago = datetime.now() - timedelta(days=7) hot_events = Event.query.filter( Event.status == 'active', Event.created_at >= seven_days_ago ).order_by(Event.hot_score.desc()).limit(10).all() events_data = [] for event in hot_events: related_stocks = RelatedStock.query.filter_by(event_id=event.id).all() # 计算相关性统计数据 correlations = [float(stock.correlation or 0) for stock in related_stocks] avg_correlation = sum(correlations) / len(correlations) if correlations else 0 max_correlation = max(correlations) if correlations else 0 stocks_data = [] total_week_change = 0 max_week_change = 0 total_daily_change = 0 # 新增:日涨跌幅总和 max_daily_change = 0 # 新增:最大日涨跌幅 valid_stocks_count = 0 for stock in related_stocks: stock_code = stock.stock_code.split('.')[0] # 获取最新交易日数据 latest_trade = db.session.execute(text(""" SELECT * FROM ea_trade WHERE SECCODE = :stock_code ORDER BY TRADEDATE DESC LIMIT 1 """), {"stock_code": stock_code}).first() week_change = 0 daily_change = 0 # 新增:日涨跌幅 if latest_trade and latest_trade.F007N: latest_price = float(latest_trade.F007N or 
0) latest_date = latest_trade.TRADEDATE daily_change = float(latest_trade.F010N or 0) # F010N是日涨跌幅字段 # 更新日涨跌幅统计 total_daily_change += daily_change max_daily_change = max(max_daily_change, daily_change) # 获取最近5条交易记录 week_ago_trades = db.session.execute(text(""" SELECT * FROM ea_trade WHERE SECCODE = :stock_code AND TRADEDATE < :latest_date ORDER BY TRADEDATE DESC LIMIT 5 """), { "stock_code": stock_code, "latest_date": latest_date }).fetchall() if week_ago_trades and week_ago_trades[-1].F007N: week_ago_price = float(week_ago_trades[-1].F007N or 0) if week_ago_price > 0: week_change = (latest_price - week_ago_price) / week_ago_price * 100 total_week_change += week_change max_week_change = max(max_week_change, week_change) valid_stocks_count += 1 stocks_data.append({ "stock_code": stock.stock_code, "stock_name": stock.stock_name, "correlation": float(stock.correlation or 0), "sector": stock.sector, "week_change": round(week_change, 2), "daily_change": round(daily_change, 2), # 新增:个股日涨跌幅 "latest_trade_date": latest_trade.TRADEDATE.strftime("%Y-%m-%d") if latest_trade else None }) # 计算平均值 avg_week_change = total_week_change / valid_stocks_count if valid_stocks_count > 0 else 0 avg_daily_change = total_daily_change / valid_stocks_count if valid_stocks_count > 0 else 0 events_data.append({ "id": event.id, "title": event.title, "description": event.description, 'event_type': event.event_type, 'importance': event.importance, # 添加重要性 'status': event.status, "created_at": event.created_at.strftime("%Y-%m-%d %H:%M:%S"), 'updated_at': event.updated_at.isoformat() if event.updated_at else None, 'start_time': event.start_time.isoformat() if event.start_time else None, 'end_time': event.end_time.isoformat() if event.end_time else None, 'view_count': event.view_count, # 添加浏览量 'post_count': event.post_count, # 添加帖子数 'follower_count': event.follower_count, # 添加关注者数 "related_stocks": stocks_data, "stocks_stats": { "avg_correlation": round(avg_correlation, 2), "max_correlation": 
round(max_correlation, 2), "stocks_count": len(related_stocks), "valid_stocks_count": valid_stocks_count, # 周涨跌统计 "avg_week_change": round(avg_week_change, 2), "max_week_change": round(max_week_change, 2), # 日涨跌统计 "avg_daily_change": round(avg_daily_change, 2), "max_daily_change": round(max_daily_change, 2) } }) return jsonify({ "code": 200, "message": "success", "data": { "events": events_data } }) except Exception as e: print(f"Error in api_home: {str(e)}") return jsonify({ "code": 500, "message": str(e), "data": None }), 500 @app.route('/api/auth/logout', methods=['POST']) def logout_with_token(): """使用token登出""" # 从请求头获取token auth_header = request.headers.get('Authorization') if auth_header and auth_header.startswith('Bearer '): token = auth_header[7:] else: data = request.get_json() token = data.get('token') if data else None if token and token in user_tokens: user_tokens.delete(token) # 清除session session.clear() return jsonify({'message': '登出成功'}), 200 def send_sms_code(phone, code, template_id): """发送短信验证码""" try: cred = credential.Credential(SMS_SECRET_ID, SMS_SECRET_KEY) httpProfile = HttpProfile() httpProfile.endpoint = "sms.tencentcloudapi.com" clientProfile = ClientProfile() clientProfile.httpProfile = httpProfile client = sms_client.SmsClient(cred, "ap-beijing", clientProfile) req = models.SendSmsRequest() params = { "PhoneNumberSet": [phone], "SmsSdkAppId": SMS_SDK_APP_ID, "TemplateId": template_id, "SignName": SMS_SIGN_NAME, "TemplateParamSet": [code] if template_id == SMS_TEMPLATE_REGISTER else [code, "5"] } req.from_json_string(json.dumps(params)) resp = client.SendSms(req) return True except TencentCloudSDKException as err: print(f"SMS Error: {err}") return False def generate_verification_code(): """生成6位数字验证码""" return ''.join(random.choices(string.digits, k=6)) @app.route('/api/auth/send-sms', methods=['POST']) def send_sms_verification(): """发送手机验证码(统一接口,自动判断场景)""" data = request.get_json() phone = data.get('phone') if not phone: return 
jsonify({'error': '手机号不能为空'}), 400 # 检查手机号是否已注册 user_exists = User.query.filter_by(phone=phone).first() is not None # 生成验证码 code = generate_verification_code() # 根据用户是否存在自动选择模板 template_id = SMS_TEMPLATE_LOGIN if user_exists else SMS_TEMPLATE_REGISTER # 发送短信 if send_sms_code(phone, code, template_id): # 统一存储验证码(5分钟有效) verification_codes[phone] = { 'code': code, 'expires': time.time() + 300 } # 简单返回成功,不暴露用户是否存在的信息 return jsonify({ 'message': '验证码已发送', 'expires_in': 300 # 告诉前端验证码有效期(秒) }), 200 else: return jsonify({'error': '验证码发送失败'}), 500 def generate_token(length=32): """生成随机token""" characters = string.ascii_letters + string.digits return ''.join(secrets.choice(characters) for _ in range(length)) @app.route('/api/auth/login/phone', methods=['POST']) def login_with_phone(): """统一的手机号登录/注册接口""" data = request.get_json() phone = data.get('phone') code = data.get('code') username = data.get('username') # 可选,新用户可以提供 password = data.get('password') # 可选,新用户可以提供 if not all([phone, code]): return jsonify({ 'code': 400, 'message': '手机号和验证码不能为空' }), 400 # 验证验证码 stored_code = verification_codes.get(phone) if not stored_code or stored_code['expires'] < time.time(): return jsonify({ 'code': 400, 'message': '验证码已过期' }), 400 if stored_code['code'] != code: return jsonify({ 'code': 400, 'message': '验证码错误' }), 400 try: # 查找用户 user = User.query.filter_by(phone=phone).first() is_new_user = False # 如果用户不存在,自动注册 if not user: is_new_user = True # 如果提供了用户名,检查是否已存在 if username: if User.query.filter_by(username=username).first(): return jsonify({ 'code': 400, 'message': '用户名已被使用,请换一个' }), 400 else: # 自动生成用户名 base_username = f"user_{phone[-4:]}" username = base_username counter = 1 while User.query.filter_by(username=username).first(): username = f"{base_username}_{counter}" counter += 1 # 创建新用户 user = User(username=username, phone=phone) user.email = f"{username}@valuefrontier.temp" # 如果提供了密码就使用,否则生成随机密码 if password: user.set_password(password) else: random_password = generate_token(16) 
user.set_password(random_password) user.phone_confirmed = True db.session.add(user) db.session.commit() # 生成token token = generate_token(32) # 存储token映射(30天有效期) user_tokens.set(token, { 'user_id': user.id, 'expires': datetime.now() + timedelta(days=30) }) # 清除验证码 del verification_codes[phone] # 设置session(保持与原有逻辑兼容) session.permanent = True session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True # 返回响应 response_data = { 'code': 0, 'message': '欢迎回来' if not is_new_user else '注册成功,欢迎加入', 'token': token, 'is_new_user': is_new_user, # 告诉前端是否是新用户 'user': { 'id': user.id, 'username': user.username, 'phone': user.phone, 'need_complete_profile': is_new_user # 提示新用户完善资料 } } return jsonify(response_data), 200 except Exception as e: db.session.rollback() print(f"Login/Register error: {e}") return jsonify({ 'code': 500, 'message': '操作失败,请重试' }), 500 @app.route('/api/auth/verify-token', methods=['POST']) def verify_token(): """验证token有效性(可选接口)""" data = request.get_json() token = data.get('token') if not token: return jsonify({'valid': False, 'message': 'Token不能为空'}), 400 token_data = user_tokens.get(token) if not token_data: return jsonify({'valid': False, 'message': 'Token无效', 'code': 401}), 401 # 检查是否过期(expires 可能是字符串或 datetime) expires = token_data['expires'] if isinstance(expires, str): expires = datetime.fromisoformat(expires) if expires < datetime.now(): user_tokens.delete(token) return jsonify({'valid': False, 'message': 'Token已过期'}), 401 # 获取用户信息 user = User.query.get(token_data['user_id']) if not user: return jsonify({'valid': False, 'message': '用户不存在'}), 404 return jsonify({ 'valid': True, 'user': { 'id': user.id, 'username': user.username, 'phone': user.phone } }), 200 def generate_jwt_token(user_id): """ 生成JWT Token - 与原系统保持一致 Args: user_id: 用户ID Returns: str: JWT token字符串 """ payload = { 'user_id': user_id, 'exp': datetime.utcnow() + timedelta(hours=JWT_EXPIRATION_HOURS), 'iat': datetime.utcnow() } token = jwt.encode(payload, 
JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) return token @app.route('/api/auth/login/wechat', methods=['POST']) def api_login_wechat(): try: # 1. 获取请求数据 data = request.get_json() code = data.get('code') if data else None if not code: return jsonify({ 'code': 400, 'message': '缺少必要的参数', 'data': None }), 400 # 2. 验证code格式 if not isinstance(code, str) or len(code) < 10: return jsonify({ 'code': 400, 'message': 'code格式无效', 'data': None }), 400 logger.info(f"开始处理微信登录,code长度: {len(code)}") # 3. 调用微信接口获取用户信息 wx_api_url = 'https://api.weixin.qq.com/sns/jscode2session' params = { 'appid': WECHAT_APP_ID, 'secret': WECHAT_APP_SECRET, 'js_code': code, 'grant_type': 'authorization_code' } try: response = requests.get(wx_api_url, params=params, timeout=10) response.raise_for_status() wx_data = response.json() # 检查微信API返回的错误 if 'errcode' in wx_data and wx_data['errcode'] != 0: error_messages = { -1: '系统繁忙,请稍后重试', 40029: 'code无效或已过期', 45011: '频率限制,请稍后再试', 40013: 'AppID错误', 40125: 'AppSecret错误', 40226: '高风险用户,登录被拦截' } error_msg = error_messages.get( wx_data['errcode'], f"微信接口错误: {wx_data.get('errmsg', '未知错误')}" ) logger.error(f"WeChat API error {wx_data['errcode']}: {error_msg}") return jsonify({ 'code': 400, 'message': error_msg, 'data': None }), 400 # 验证必需字段 if 'openid' not in wx_data or 'session_key' not in wx_data: logger.error("响应缺少必需字段") return jsonify({ 'code': 500, 'message': '微信响应格式错误', 'data': None }), 500 openid = wx_data['openid'] session_key = wx_data['session_key'] unionid = wx_data.get('unionid') # 可能为None logger.info(f"成功获取微信用户信息 - OpenID: {openid[:8]}...") if unionid: logger.info(f"获取到UnionID: {unionid[:8]}...") except requests.exceptions.Timeout: logger.error("请求微信API超时") return jsonify({ 'code': 500, 'message': '请求超时,请重试', 'data': None }), 500 except requests.exceptions.RequestException as e: logger.error(f"网络请求失败: {str(e)}") return jsonify({ 'code': 500, 'message': '网络错误', 'data': None }), 500 # 4. 
查找或创建用户 - 核心逻辑 user = None is_new_user = False logger.info(f"开始查找用户 - UnionID: {unionid}, OpenID: {openid[:8]}...") if unionid: # 情况1: 有unionid,优先通过unionid查找 user = User.query.filter_by(wechat_union_id=unionid).first() if user: logger.info(f"通过UnionID找到用户: {user.username}") # 更新openid(可能用户从不同小程序登录) if user.wechat_open_id != openid: user.wechat_open_id = openid logger.info(f"更新用户OpenID: {openid[:8]}...") else: # unionid没找到,再尝试用openid查找(处理历史数据) user = User.query.filter_by(wechat_open_id=openid).first() if user: logger.info(f"通过OpenID找到用户: {user.username}") # 补充unionid user.wechat_union_id = unionid logger.info(f"为用户补充UnionID: {unionid[:8]}...") else: # 情况2: 没有unionid,只能通过openid查找 logger.warning("未获取到UnionID(小程序可能未绑定开放平台)") user = User.query.filter_by(wechat_open_id=openid).first() if user: logger.info(f"通过OpenID找到用户: {user.username}") # 5. 创建新用户 if not user: is_new_user = True # 生成唯一用户名 timestamp = int(time.time()) username = f"wx_{timestamp}_{openid[-6:]}" # 确保用户名唯一 counter = 0 base_username = username while User.query.filter_by(username=username).first(): counter += 1 username = f"{base_username}_{counter}" # 创建用户对象(使用你的User模型) user = User( username=username, email=f"{username}@wechat.local", # 占位邮箱 password="wechat_login_no_password" # 微信登录不需要密码 ) # 设置微信相关字段 user.wechat_open_id = openid user.wechat_union_id = unionid user.status = 'active' user.email_confirmed = False # 设置默认值 user.nickname = f"微信用户{openid[-4:]}" user.bio = "" # 空的个人简介 user.avatar_url = None # 稍后会处理 user.is_creator = False user.is_verified = False user.user_level = 1 user.reputation_score = 0 user.contribution_point = 0 user.post_count = 0 user.comment_count = 0 user.follower_count = 0 user.following_count = 0 # 设置默认偏好 user.email_notifications = True user.privacy_level = 'public' user.theme_preference = 'light' db.session.add(user) logger.info(f"创建新用户: {username}") else: # 更新最后登录时间 user.update_last_seen() logger.info(f"用户登录: {user.username}") # 6. 
提交数据库更改 try: db.session.commit() except Exception as e: db.session.rollback() logger.error(f"保存用户信息失败: {str(e)}") return jsonify({ 'code': 500, 'message': '保存用户信息失败', 'data': None }), 500 # 7. 生成JWT token(使用原系统的生成方法) token = generate_token(32) # 使用相同的随机字符串生成器 # 存储token映射(与手机登录保持一致) user_tokens.set(token, { 'user_id': user.id, 'expires': datetime.now() + timedelta(days=30) # 30天有效期 }) # 设置session(可选,保持与手机登录一致) session.permanent = True session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True # 9. 构造返回数据 - 完全匹配要求的格式 # 手机绑定状态 phone_bindcd = bool(user.phone and user.phone_confirmed) response_data = { 'code': 200, 'data': { 'token': token, # 现在这个token能被token_required识别了 'user': { 'avatar_url': get_full_avatar_url(user.avatar_url), 'bio': user.bio or "", 'email': user.email, 'id': user.id, 'is_creator': user.is_creator, 'is_verified': user.is_verified, 'nickname': user.nickname or user.username, 'reputation_score': user.reputation_score, 'user_level': user.user_level, 'username': user.username, # 手机绑定状态 'phone_bindcd': phone_bindcd, 'phone': user.phone if phone_bindcd else None } }, 'message': '登录成功' } # 10. 
记录日志 logger.info( f"微信登录成功 - 用户ID: {user.id}, " f"用户名: {user.username}, " f"新用户: {is_new_user}, " f"有UnionID: {unionid is not None}" ) return jsonify(response_data), 200 except Exception as e: # 捕获所有未处理的异常 logger.error(f"微信登录处理异常: {str(e)}", exc_info=True) db.session.rollback() return jsonify({ 'code': 500, 'message': '服务器内部错误', 'data': None }), 500 # ============================================ # 微信获取手机号接口 # ============================================ # 微信 access_token 缓存 Key _WECHAT_ACCESS_TOKEN_KEY = 'vf_wechat_access_token' def get_wechat_access_token(force_refresh=False): """ 获取微信小程序 access_token(使用 Redis 缓存,支持多 worker 共享) access_token 有效期为 7200 秒,提前 5 分钟刷新 Args: force_refresh: 是否强制刷新(当 token 失效时使用) """ import time redis_client = None try: import redis redis_client = redis.from_url(os.environ.get('REDIS_URL', 'redis://:VF_Redis_2024@localhost:6379/0')) # 如果不是强制刷新,尝试从 Redis 获取缓存的 token if not force_refresh: cached = redis_client.get(_WECHAT_ACCESS_TOKEN_KEY) if cached: logger.debug("从 Redis 获取微信 access_token") return cached.decode('utf-8') else: # 强制刷新时,先删除旧的缓存 redis_client.delete(_WECHAT_ACCESS_TOKEN_KEY) logger.info("强制刷新:已删除旧的 access_token 缓存") except Exception as e: logger.warning(f"Redis 操作失败: {e},将直接请求微信接口") redis_client = None # 请求新的 access_token url = 'https://api.weixin.qq.com/cgi-bin/token' params = { 'grant_type': 'client_credential', 'appid': WECHAT_APP_ID, 'secret': WECHAT_APP_SECRET } try: response = requests.get(url, params=params, timeout=10) result = response.json() if 'access_token' in result: access_token = result['access_token'] expires_in = result.get('expires_in', 7200) # 存入 Redis(提前 5 分钟过期,确保不会使用即将过期的 token) if redis_client: try: redis_client.setex(_WECHAT_ACCESS_TOKEN_KEY, expires_in - 300, access_token) logger.info(f"微信 access_token 已缓存到 Redis,有效期: {expires_in - 300}秒") except Exception as e: logger.warning(f"Redis 缓存 access_token 失败: {e}") return access_token else: logger.error(f"获取微信 access_token 失败: {result}") return None except 
Exception as e: logger.error(f"获取微信 access_token 异常: {e}") return None @app.route('/api/auth/bindphone/wechat', methods=['POST']) @token_required def api_bindphone_wechat(): """ 微信小程序绑定手机号接口 前端调用 wx.getPhoneNumber 获取 code,传给后端 后端用 code 调用微信接口获取手机号,绑定到当前用户 请求参数: { "code": "微信返回的动态令牌code" } 返回: { "code": 200, "message": "success", "data": { "phone": "138xxxx1234", "bindcd": true } } """ try: user = request.user # 兼容 JSON 和 form-data 两种方式 code = None data = request.get_json(force=True, silent=True) if data: code = data.get('code') if not code: code = request.form.get('code') or request.args.get('code') if not code: return jsonify({ 'code': 400, 'message': '缺少必要的参数 code', 'data': None }), 400 # 调用微信接口获取手机号(支持 token 失效自动重试) def call_wechat_phone_api(force_refresh=False): access_token = get_wechat_access_token(force_refresh=force_refresh) if not access_token: return None, {'errcode': -1, 'errmsg': '获取 access_token 失败'} wx_phone_url = f'https://api.weixin.qq.com/wxa/business/getuserphonenumber?access_token={access_token}' payload = {'code': code} try: response = requests.post(wx_phone_url, json=payload, timeout=10) return access_token, response.json() except Exception as e: logger.error(f"调用微信获取手机号接口异常: {e}") return access_token, {'errcode': -1, 'errmsg': str(e)} # 第一次尝试 access_token, result = call_wechat_phone_api(force_refresh=False) logger.info(f"微信获取手机号响应: {result}") # 如果 token 失效(40001, 42001),强制刷新后重试一次 errcode = result.get('errcode') if errcode in (40001, 42001, 40014): logger.warning(f"access_token 失效 (errcode={errcode}),强制刷新后重试") access_token, result = call_wechat_phone_api(force_refresh=True) logger.info(f"重试后微信获取手机号响应: {result}") errcode = result.get('errcode') # 解析微信返回结果 if errcode != 0: error_msg = result.get('errmsg', '未知错误') logger.error(f"微信获取手机号失败: errcode={errcode}, errmsg={error_msg}") # 常见错误码处理 if errcode == 40029: return jsonify({'code': 400, 'message': 'code无效或已过期,请重新获取', 'data': None}), 400 elif errcode == 40013: return jsonify({'code': 400, 
'message': 'AppID无效', 'data': None}), 400 elif errcode == -1: return jsonify({'code': 500, 'message': '微信服务繁忙,请稍后重试', 'data': None}), 500 elif errcode in (40001, 42001, 40014): return jsonify({'code': 500, 'message': '微信凭证失效,请稍后重试', 'data': None}), 500 else: return jsonify({'code': 400, 'message': f'获取手机号失败: {error_msg}', 'data': None}), 400 phone_info = result.get('phone_info', {}) phone_number = phone_info.get('phoneNumber') or phone_info.get('purePhoneNumber') if not phone_number: return jsonify({ 'code': 400, 'message': '未获取到手机号', 'data': None }), 400 # 4. 检查当前用户是否已绑定此手机号 if user.phone == phone_number and user.phone_confirmed: # 已经绑定过了,直接返回成功 return jsonify({ 'code': 200, 'message': '手机号已绑定', 'data': { 'phone': phone_number, 'bindcd': True } }) # 5. 检查手机号是否已被其他用户绑定(只查确认绑定的) existing_user = User.query.filter( User.phone == phone_number, User.phone_confirmed == True, User.id != user.id ).first() if existing_user: logger.warning(f"手机号 {phone_number[:3]}****{phone_number[-4:]} 已被用户 {existing_user.id} 绑定,当前用户 {user.id} 尝试绑定失败") return jsonify({ 'code': 400, 'message': '该手机号已被其他账号绑定', 'data': None }), 400 # 6. 
更新用户手机号 user.phone = phone_number user.phone_confirmed = True user.phone_confirm_time = beijing_now() db.session.commit() logger.info(f"用户 {user.id} 成功绑定手机号: {phone_number[:3]}****{phone_number[-4:]}") return jsonify({ 'code': 200, 'message': '绑定成功', 'data': { 'phone': phone_number, 'bindcd': True } }) except Exception as e: logger.error(f"绑定手机号异常: {e}", exc_info=True) db.session.rollback() return jsonify({ 'code': 500, 'message': '服务器内部错误', 'data': None }), 500 @app.route('/api/auth/login/email', methods=['POST']) def api_login_email(): """邮箱登录接口""" try: data = request.get_json() email = data.get('email') password = data.get('password') if not email or not password: return jsonify({ 'code': 400, 'message': '邮箱和密码不能为空', 'data': None }), 400 user = User.query.filter_by(email=email).first() if not user or not user.check_password(password): return jsonify({ 'code': 400, 'message': '邮箱或密码错误', 'data': None }), 400 token = generate_jwt_token(user.id) login_user(user) user.update_last_seen() db.session.commit() return jsonify({ 'code': 200, 'message': '登录成功', 'data': { 'token': token, 'user_id': user.id, 'username': user.username, 'email': user.email, 'is_verified': user.is_verified, 'user_level': user.user_level } }) except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 # 5. 
def _to_float_or_none(value):
    """Convert a numeric DB value to float, mapping only ``None`` to ``None``.

    Unlike ``float(x) if x else None``, a legitimate 0 (flat change, zero
    volume) is kept as ``0.0``.
    """
    return float(value) if value is not None else None


# 5. Event detail: related-stocks endpoint
@app.route('/api/event/<int:event_id>/related-stocks-detail', methods=['GET'])
def api_event_related_stocks(event_id):
    """Return an event's related stocks with intraday price data (Pro/Max members only).

    NOTE(review): the docstring restricts access to Pro/Max members, but no
    membership check is applied in this function; confirm where it is enforced.

    Prices are measured over the event's trading window. The window runs from
    the event time (or 09:30) to 15:00 on the first trading day at or after the
    event, using the ``trading_days`` table. Minute bars come from ClickHouse
    ``stock_minute`` in two batched queries. Stocks without minute data fall
    back to their latest ``TradeData`` row.

    Args:
        event_id: Primary key of the event.

    Returns:
        JSON ``{'code': 200, 'message': 'success', 'data': {...}}``; on error
        ``code`` is 500 and ``message`` carries the exception text.
    """
    try:
        from datetime import datetime, time as dt_time
        from sqlalchemy import text

        event = Event.query.get_or_404(event_id)
        related_stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all()

        # ClickHouse client for the minute-level queries.
        client = get_clickhouse_client()

        # Anchor on the event start time, falling back to the creation time.
        event_time = event.start_time if event.start_time else event.created_at

        def as_date(value):
            """Normalise a trading_days value to ``date`` (works for DATE and DATETIME columns)."""
            return value.date() if isinstance(value, datetime) else value

        def get_trading_day_and_times(event_datetime):
            """Return ``(trading_day, start_time, end_time)`` for the event.

            The rules are:

            * Trading day, before 09:30: that day, full session.
            * Trading day, during the session: that day, from the event time
              to 15:00.
            * After 15:00, or a non-trading day: the next trading day, full
              session.

            ``trading_day`` is None when ``trading_days`` has no later date.
            """
            event_date = event_datetime.date()
            event_time_only = event_datetime.time()
            market_open = dt_time(9, 30)
            market_close = dt_time(15, 0)

            with engine.connect() as conn:
                def next_trading_day():
                    row = conn.execute(text("""
                        SELECT EXCHANGE_DATE
                        FROM trading_days
                        WHERE EXCHANGE_DATE > :date
                        ORDER BY EXCHANGE_DATE
                        LIMIT 1
                    """), {"date": event_date}).fetchone()
                    return as_date(row[0]) if row else None

                is_trading_day = conn.execute(text("""
                    SELECT 1
                    FROM trading_days
                    WHERE EXCHANGE_DATE = :date
                """), {"date": event_date}).fetchone() is not None

                if not is_trading_day or event_time_only > market_close:
                    return next_trading_day(), market_open, market_close
                if event_time_only < market_open:
                    return event_date, market_open, market_close
                return event_date, event_time_only, market_close

        def build_response(stocks_data):
            """Wrap ``stocks_data`` in the success envelope with the event metadata."""
            return {
                'code': 200,
                'message': 'success',
                'data': {
                    'event_id': event_id,
                    'event_title': event.title,
                    'event_desc': event.description,
                    'event_type': event.event_type,
                    'event_importance': event.importance,
                    'event_status': event.status,
                    'event_created_at': event.created_at.strftime("%Y-%m-%d %H:%M:%S"),
                    'event_start_time': event.start_time.isoformat() if event.start_time else None,
                    'event_end_time': event.end_time.isoformat() if event.end_time else None,
                    'keywords': event.keywords,
                    'view_count': event.view_count,
                    'post_count': event.post_count,
                    'follower_count': event.follower_count,
                    'related_stocks': stocks_data,
                    'total_count': len(stocks_data)
                }
            }

        trading_day, start_time, end_time = get_trading_day_and_times(event_time)

        if not trading_day:
            # No trading day on or after the event: return the event without stock data.
            return jsonify(build_response([]))

        start_datetime = datetime.combine(trading_day, start_time)
        end_datetime = datetime.combine(trading_day, end_time)
        logger.info("事件时间: %s, 交易日: %s, 时间范围: %s - %s",
                    event_time, trading_day, start_datetime, end_datetime)

        # 1. All stock codes of the event.
        stock_codes = [stock.stock_code for stock in related_stocks]

        # 2. Basic info in bulk. Codes such as "600000.SH" may be stored
        #    without the exchange suffix, so unresolved ones are retried by base code.
        stock_info_map = {}
        if stock_codes:
            for info in StockBasicInfo.query.filter(StockBasicInfo.SECCODE.in_(stock_codes)).all():
                stock_info_map[info.SECCODE] = info

            missing_by_base = {}
            for code in stock_codes:
                if '.' in code and code not in stock_info_map:
                    missing_by_base.setdefault(code.split('.')[0], []).append(code)

            if missing_by_base:
                base_infos = StockBasicInfo.query.filter(
                    StockBasicInfo.SECCODE.in_(list(missing_by_base))
                ).all()
                for info in base_infos:
                    for code in missing_by_base.get(info.SECCODE, []):
                        stock_info_map.setdefault(code, info)

        # 3. ClickHouse batch queries: window price summary and minute bars.
        price_data_map = {}    # code -> price/change summary over the window
        minute_chart_map = {}  # code -> list of minute bars
        try:
            if stock_codes:
                codes_param = tuple(stock_codes)
                window_params = {'codes': codes_param, 'start': start_datetime, 'end': end_datetime}
                logger.info("批量查询 %d 只股票的价格数据...", len(stock_codes))
                logger.debug("  股票代码: %s", stock_codes)
                logger.debug("  时间范围: %s 至 %s", start_datetime, end_datetime)

                # Diagnostic scan of each code's full data range. It is an extra
                # aggregate over stock_minute, so run it only when debugging.
                if logger.isEnabledFor(logging.DEBUG):
                    diag_result = client.execute("""
                        SELECT code, MIN(timestamp) as min_ts, MAX(timestamp) as max_ts, COUNT(*) as cnt
                        FROM stock_minute
                        WHERE code IN %(codes)s
                        GROUP BY code
                    """, {'codes': codes_param})
                    logger.debug("  诊断 - 各股票数据情况: %s", diag_result)

                # 3.1 First/last close, change % and OHLCV per code. argMin/argMax
                #     pick values at the earliest/latest timestamp (no CTE needed);
                #     the window's open is the first bar's open, hence argMin(open).
                price_rows = client.execute("""
                    SELECT code,
                           argMin(close, timestamp) as first_price,
                           argMax(close, timestamp) as last_price,
                           (argMax(close, timestamp) - argMin(close, timestamp)) / argMin(close, timestamp) * 100 as change_pct,
                           argMin(open, timestamp) as open_price,
                           max(high) as high_price,
                           min(low) as low_price,
                           sum(volume) as volume,
                           sum(amt) as amount
                    FROM stock_minute
                    WHERE code IN %(codes)s
                      AND timestamp >= %(start)s
                      AND timestamp <= %(end)s
                    GROUP BY code
                """, window_params)
                logger.info("批量查询返回 %d 条价格数据", len(price_rows))

                for row in price_rows:
                    first_price = _to_float_or_none(row[1])
                    last_price = _to_float_or_none(row[2])
                    change_pct = _to_float_or_none(row[3])
                    # A zero first price can make the SQL division inf/nan,
                    # which would serialise to invalid JSON.
                    if change_pct is not None and not math.isfinite(change_pct):
                        change_pct = None

                    change_amount = None
                    if last_price is not None and first_price is not None:
                        change_amount = last_price - first_price

                    price_data_map[row[0]] = {
                        'latest_price': last_price,
                        'first_price': first_price,
                        'change_pct': change_pct,
                        'change_amount': change_amount,
                        'open_price': _to_float_or_none(row[4]),
                        'high_price': _to_float_or_none(row[5]),
                        'low_price': _to_float_or_none(row[6]),
                        'volume': int(row[7]) if row[7] is not None else None,
                        'amount': _to_float_or_none(row[8]),
                    }

                # 3.2 Minute bars over the same window, grouped per code.
                logger.info("批量查询分时图数据...")
                minute_rows = client.execute("""
                    SELECT code, timestamp, open, high, low, close, volume, amt
                    FROM stock_minute
                    WHERE code IN %(codes)s
                      AND timestamp >= %(start)s
                      AND timestamp <= %(end)s
                    ORDER BY code, timestamp
                """, window_params)
                logger.info("批量查询返回 %d 条分时数据", len(minute_rows))

                for row in minute_rows:
                    minute_chart_map.setdefault(row[0], []).append({
                        'time': row[1].strftime('%H:%M'),
                        'open': _to_float_or_none(row[2]),
                        'high': _to_float_or_none(row[3]),
                        'low': _to_float_or_none(row[4]),
                        'close': _to_float_or_none(row[5]),
                        'volume': _to_float_or_none(row[6]),
                        'amount': _to_float_or_none(row[7])
                    })
        except Exception as e:
            # Best effort: stocks missing from price_data_map use the TradeData fallback below.
            logger.error("批量查询 ClickHouse 失败: %s", e)

        # 4. Assemble each stock's entry from the batch results.
        stocks_data = []
        for stock in related_stocks:
            logger.debug("正在组装股票 %s 的数据...", stock.stock_code)

            stock_info = stock_info_map.get(stock.stock_code)
            price_info = price_data_map.get(stock.stock_code)

            latest_price = first_price = change_pct = change_amount = None
            open_price = high_price = low_price = volume = amount = None
            trade_date = trading_day

            if price_info:
                latest_price = price_info['latest_price']
                first_price = price_info['first_price']
                change_pct = price_info['change_pct']
                change_amount = price_info['change_amount']
                open_price = price_info['open_price']
                high_price = price_info['high_price']
                low_price = price_info['low_price']
                volume = price_info['volume']
                amount = price_info['amount']
            else:
                # No minute data in the window: fall back to the latest daily TradeData row.
                logger.info("股票 %s 批量查询无数据,使用降级方案...", stock.stock_code)
                try:
                    latest_trade = None
                    for code in (stock.stock_code, stock.stock_code.split('.')[0]):
                        latest_trade = TradeData.query.filter_by(SECCODE=code) \
                            .order_by(TradeData.TRADEDATE.desc()).first()
                        if latest_trade:
                            break

                    if latest_trade:
                        latest_price = _to_float_or_none(latest_trade.F007N)
                        open_price = _to_float_or_none(latest_trade.F003N)
                        high_price = _to_float_or_none(latest_trade.F005N)
                        low_price = _to_float_or_none(latest_trade.F006N)
                        first_price = _to_float_or_none(latest_trade.F002N)
                        volume = _to_float_or_none(latest_trade.F004N)
                        amount = _to_float_or_none(latest_trade.F011N)
                        trade_date = latest_trade.TRADEDATE
                        change_pct = _to_float_or_none(latest_trade.F010N)
                        change_amount = _to_float_or_none(latest_trade.F009N)
                except Exception as fallback_error:
                    logger.warning("降级查询也失败 %s: %s", stock.stock_code, fallback_error)

            stocks_data.append({
                'id': stock.id,
                'stock_code': stock.stock_code,
                'stock_name': stock.stock_name,
                'sector': stock.sector,
                'relation_desc': stock.relation_desc,
                'correlation': stock.correlation,
                'momentum': stock.momentum,
                'listing_date': stock_info.F006D.isoformat() if stock_info and stock_info.F006D else None,
                'market': stock_info.F005V if stock_info else None,

                # Trade data: omitted entirely when no latest price is known.
                'trade_data': {
                    'latest_price': latest_price,
                    'first_price': first_price,  # price at the start of the event window
                    'open_price': open_price,
                    'high_price': high_price,
                    'low_price': low_price,
                    'change_amount': round(change_amount, 2) if change_amount is not None else None,
                    'change_pct': round(change_pct, 2) if change_pct is not None else None,
                    'volume': volume,
                    'amount': amount,
                    'trade_date': trade_date.isoformat() if trade_date else None,
                    'event_start_time': start_datetime.isoformat() if start_datetime else None,
                    'event_end_time': end_datetime.isoformat() if end_datetime else None,
                } if latest_price is not None else None,

                'minute_chart_data': minute_chart_map.get(stock.stock_code, []),

                'charts': {
                    'minute_chart_url': f"/api/stock/{stock.stock_code}/minute-chart",
                    'daily_chart_url': f"/api/stock/{stock.stock_code}/kline",
                }
            })

        return jsonify(build_response(stocks_data))

    except Exception as e:
        logger.error(f"api_event_related_stocks failed: {e}", exc_info=True)
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500


@app.route('/api/stock/<stock_code>/minute-chart', methods=['GET'])
def get_minute_chart_data(stock_code):
    """Return one trading day of minute bars for ``stock_code`` (Pro/Max members only).

    The function is also called directly by ``api_stock_detail``, so it
    returns a plain list (Flask serialises list return values as JSON) and
    never raises: errors are logged and yield ``[]``.

    It uses today's 09:30-15:00 bars when present. Otherwise it uses every bar
    on the calendar date of the stock's most recent bar.

    NOTE(review): no membership check is applied here despite the docstring.
    """
    try:
        from datetime import datetime, time as dt_time

        client = get_clickhouse_client()
        today = datetime.now().date()

        data = client.execute("""
            SELECT timestamp, open, high, low, close, volume, amt
            FROM stock_minute
            WHERE code = %(code)s
              AND timestamp >= %(start)s
              AND timestamp <= %(end)s
            ORDER BY timestamp
        """, {
            'code': stock_code,
            'start': datetime.combine(today, dt_time(9, 30)),
            'end': datetime.combine(today, dt_time(15, 0))
        })

        if not data:
            # No bars today: take the latest day with data. Filtering on the calendar
            # date (rather than "MAX - 1 day") avoids pulling in the previous
            # day's 15:00 bar.
            data = client.execute("""
                SELECT timestamp, open, high, low, close, volume, amt
                FROM stock_minute
                WHERE code = %(code)s
                  AND toDate(timestamp) = (
                      SELECT toDate(MAX(timestamp))
                      FROM stock_minute
                      WHERE code = %(code)s
                  )
                ORDER BY timestamp
            """, {'code': stock_code})

        return [
            {
                'time': row[0].strftime('%H:%M'),
                'open': _to_float_or_none(row[1]),
                'high': _to_float_or_none(row[2]),
                'low': _to_float_or_none(row[3]),
                'close': _to_float_or_none(row[4]),
                'volume': _to_float_or_none(row[5]),
                'amount': _to_float_or_none(row[6])
            }
            for row in data
        ]
    except Exception as e:
        logger.error(f"Error getting minute chart data: {e}")
        return []


@app.route('/api/event/<int:event_id>/stock/<stock_code>/detail', methods=['GET'])
def api_stock_detail(event_id, stock_code):
    """Return a stock's detail in the context of an event (Pro/Max members only).

    NOTE(review): no membership check is applied here despite the docstring.

    Query parameters:
        include_minute_data ('true'/'false', default 'true'): attach minute bars.
        include_full_sources ('true'/'false', default 'false'): return every
            research-report source instead of a preview of at most 5.

    Returns:
        200 with event, basic, company, latest-trade and relation info; 404
        when the stock is unknown; 500 with the exception text on failure.
    """
    try:
        event = Event.query.get_or_404(event_id)

        include_minute_data = request.args.get('include_minute_data', 'true').lower() == 'true'
        include_full_sources = request.args.get('include_full_sources', 'false').lower() == 'true'

        base_code = stock_code.split('.')[0]  # code without the exchange suffix

        # Basic info: exact code, then prefix match on the full code, then on the base code.
        basic_info = StockBasicInfo.query.filter_by(SECCODE=stock_code).first()
        if not basic_info:
            basic_info = StockBasicInfo.query.filter(
                StockBasicInfo.SECCODE.ilike(f"{stock_code}%")
            ).first()
        if not basic_info:
            basic_info = StockBasicInfo.query.filter(
                StockBasicInfo.SECCODE.ilike(f"{base_code}%")
            ).first()

        if not basic_info:
            return jsonify({
                'code': 404,
                'stock_code': stock_code,
                'message': '股票不存在',
                'data': None
            }), 404

        company_info = CompanyInfo.query.filter_by(SECCODE=stock_code).first()
        if not company_info:
            company_info = CompanyInfo.query.filter_by(SECCODE=base_code).first()

        latest_trade = TradeData.query.filter_by(SECCODE=stock_code) \
            .order_by(TradeData.TRADEDATE.desc()).first()
        if not latest_trade:
            latest_trade = TradeData.query.filter_by(SECCODE=base_code) \
                .order_by(TradeData.TRADEDATE.desc()).first()

        minute_chart_data = get_minute_chart_data(stock_code) if include_minute_data else []

        # The event's relation to this stock, matching with or without suffix.
        related_stock = RelatedStock.query.filter_by(
            event_id=event_id
        ).filter(
            db.or_(
                RelatedStock.stock_code == stock_code,
                RelatedStock.stock_code == base_code,
                RelatedStock.stock_code.like(f"{base_code}.%")
            )
        ).first()

        related_desc = None
        if related_stock:
            retrieved_sources_data = None
            sources_summary = None

            if related_stock.retrieved_sources:
                try:
                    # Sources are stored either as a list or as a JSON string.
                    raw_sources = related_stock.retrieved_sources
                    sources = raw_sources if isinstance(raw_sources, list) else json.loads(raw_sources)

                    # Distribution of match scores.
                    match_scores = {}
                    for source in sources:
                        score = source.get('match_score', '未知')
                        match_scores[score] = match_scores.get(score, 0) + 1

                    sources_summary = {
                        'total_count': len(sources),
                        'has_sources': True,
                        'match_scores': match_scores
                    }

                    if include_full_sources:
                        retrieved_sources_data = sources
                    else:
                        # Preview: up to 3 '好' (good) plus 2 '中' (medium) matches,
                        # otherwise the first 5 sources.
                        preview_sources = (
                            [s for s in sources if s.get('match_score') == '好'][:3]
                            + [s for s in sources if s.get('match_score') == '中'][:2]
                        )
                        if not preview_sources:
                            preview_sources = sources[:5]

                        retrieved_sources_data = []
                        for source in preview_sources:
                            sentences = source.get('sentences', '')
                            retrieved_sources_data.append({
                                'report_title': source.get('report_title', ''),
                                'author': source.get('author', ''),
                                # Excerpts are capped at 200 characters.
                                'sentences': sentences[:200] + '...' if len(sentences) > 200 else sentences,
                                'match_score': source.get('match_score', ''),
                                'declare_date': source.get('declare_date', '')
                            })

                except Exception as e:
                    logger.warning(f"Error processing retrieved_sources for stock {stock_code}: {e}")
                    sources_summary = {'has_sources': False, 'error': str(e)}
            else:
                sources_summary = {'has_sources': False, 'total_count': 0}

            related_desc = {
                'event_id': related_stock.event_id,
                'relation_desc': related_stock.relation_desc,
                'sector': related_stock.sector,
                'correlation': _to_float_or_none(related_stock.correlation),
                'momentum': related_stock.momentum,
                'retrieved_sources': retrieved_sources_data,
                'sources_summary': sources_summary,
                'retrieved_update_time': related_stock.retrieved_update_time.isoformat()
                if related_stock.retrieved_update_time else None,
                'sources_detail_url': f"/api/event/{event_id}/stock/{stock_code}/sources"
                if sources_summary.get('has_sources') else None
            }

        response_data = {
            'code': 200,
            'message': 'success',
            'data': {
                'event_info': {
                    'event_id': event.id,
                    'event_title': event.title,
                    'event_description': event.description,
                    'event_start_time': event.start_time.isoformat() if event.start_time else None,
                    'event_created_at': event.created_at.strftime("%Y-%m-%d %H:%M:%S") if event.created_at else None
                },
                'basic_info': {
                    'stock_code': basic_info.SECCODE,
                    'stock_name': basic_info.SECNAME,
                    'org_name': basic_info.ORGNAME,
                    'pinyin': basic_info.F001V,
                    'category': basic_info.F003V,
                    'market': basic_info.F005V,
                    'listing_date': basic_info.F006D.isoformat() if basic_info.F006D else None,
                    'status': basic_info.F011V
                },
                'company_info': {
                    'english_name': company_info.F001V if company_info else None,
                    'legal_representative': company_info.F003V if company_info else None,
                    'main_business': company_info.F015V if company_info else None,
                    'business_scope': company_info.F016V if company_info else None,
                    'company_intro': company_info.F017V if company_info else None,
                    'csrc_industry_l1': company_info.F030V if company_info else None,
                    'csrc_industry_l2': company_info.F032V if company_info else None
                },
                'latest_trade': {
                    'trade_date': latest_trade.TRADEDATE.isoformat(),
                    'close_price': _to_float_or_none(latest_trade.F007N),
                    'change': _to_float_or_none(latest_trade.F009N),
                    'change_pct': _to_float_or_none(latest_trade.F010N),
                    'volume': _to_float_or_none(latest_trade.F004N),
                    'amount': _to_float_or_none(latest_trade.F011N)
                } if latest_trade else None,
                'minute_chart_data': minute_chart_data,
                'related_desc': related_desc
            }
        }

        response = jsonify(response_data)
        response.headers['Content-Type'] = 'application/json; charset=utf-8'
        return response

    except Exception as e:
        logger.error(f"api_stock_detail failed: {e}", exc_info=True)
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500


def get_stock_minute_chart_data(stock_code):
    """Return 09:30-15:00 minute bars for the latest trading day on or before today.

    Trading days are read from ``tdays.csv`` (column ``DateTime``, format
    ``YYYY/MM/DD``), resolved relative to the process working directory.

    Returns:
        list[dict]: Bars with ``time``/``open``/``high``/``low``/``close``/
        ``volume``/``amount``; ``[]`` when no trading day is known or on error.
    """
    try:
        from datetime import datetime, time as dt_time
        import csv

        client = get_clickhouse_client()

        # utf-8-sig also tolerates a BOM in front of the 'DateTime' header.
        with open('tdays.csv', 'r', encoding='utf-8-sig', newline='') as f:
            trading_days = {
                datetime.strptime(row['DateTime'], '%Y/%m/%d').date()
                for row in csv.DictReader(f)
            }

        today = datetime.now().date()
        target_date = max((day for day in trading_days if day <= today), default=None)
        if not target_date:
            return []

        data = client.execute("""
            SELECT timestamp, open, high, low, close, volume, amt
            FROM stock_minute
            WHERE code = %(code)s
              AND timestamp BETWEEN %(start)s AND %(end)s
            ORDER BY timestamp
        """, {
            'code': stock_code,
            'start': datetime.combine(target_date, dt_time(9, 30)),
            'end': datetime.combine(target_date, dt_time(15, 0))
        })

        return [
            {
                'time': row[0].strftime('%H:%M'),
                'open': _to_float_or_none(row[1]),
                'high': _to_float_or_none(row[2]),
                'low': _to_float_or_none(row[3]),
                'close': _to_float_or_none(row[4]),
                'volume': _to_float_or_none(row[5]),
                'amount': _to_float_or_none(row[6])
            }
            for row in data
        ]

    except Exception as e:
        logger.error(f"Error getting minute chart data: {e}")
        return []


# 7. Event detail: related-concepts endpoint
@app.route('/api/event/<int:event_id>/related-concepts', methods=['GET'])
def api_event_related_concepts(event_id):
    """Return the concepts linked to an event with absolute image URLs.

    Image URLs are built as ``<request host URL>data/concepts/<path>`` from
    each concept's ``image_paths_list``.

    Returns:
        JSON envelope ``{'code', 'message', 'data'}``; 500 with the exception
        text on failure.
    """
    try:
        event = Event.query.get_or_404(event_id)
        base_url = request.host_url

        concepts_data = []
        for concept in event.related_concepts.all():
            image_paths = concept.image_paths_list
            image_urls = [base_url + 'data/concepts/' + path for path in image_paths]
            concepts_data.append({
                'id': concept.id,
                'concept_code': concept.concept_code,
                'concept': concept.concept,
                'reason': concept.reason,
                'image_paths': image_paths,
                'image_urls': image_urls,
                'first_image': image_urls[0] if image_urls else None
            })

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'event_id': event_id,
                'event_title': event.title,
                'related_concepts': concepts_data,
                'total_count': len(concepts_data)
            }
        })
    except Exception as e:
        logger.error(f"api_event_related_concepts failed: {e}", exc_info=True)
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
# Event detail: historical events endpoint
from werkzeug.exceptions import HTTPException


def _latest_trade_for_prefix(base_code, cache):
    """Return the newest TradeData row whose SECCODE starts with ``base_code``.

    Results are memoised in ``cache``, including ``None`` for "no row". A stock
    code shared by several events is therefore only queried once per request.
    """
    if base_code not in cache:
        cache[base_code] = TradeData.query.filter(
            TradeData.SECCODE.startswith(base_code)
        ).order_by(TradeData.TRADEDATE.desc()).first()
    return cache[base_code]


@app.route('/api/event/<int:event_id>/historical-events', methods=['GET'])
def api_event_historical_events(event_id):
    """Return an event's historical events together with related-stock moves.

    Historical events are ordered by importance, then event date (both
    descending). For each related stock, ``daily_change`` is ``F010N`` from
    its most recent TradeData row. For each historical event, and for the
    requested event itself, ``related_avg_chg`` and ``related_max_chg`` are
    the mean and max of those values, rounded to 2 decimals (``None`` when no
    stock has trade data).

    An unknown event yields HTTP 404. Other errors yield a JSON ``code`` 500.
    """
    try:
        event = Event.query.get_or_404(event_id)
        historical_events = event.historical_events.order_by(
            HistoricalEvent.importance.desc(),
            HistoricalEvent.event_date.desc()
        ).all()

        trade_cache = {}
        events_data = []
        for hist_event in historical_events:
            related_stocks = []
            valid_changes = []
            for stock in hist_event.stocks.all():
                trade_data = _latest_trade_for_prefix(stock.stock_code.split('.')[0], trade_cache)
                if trade_data and trade_data.F010N is not None:
                    daily_change = float(trade_data.F010N)
                    valid_changes.append(daily_change)
                else:
                    daily_change = None

                related_stocks.append({
                    'stock_code': stock.stock_code,
                    'stock_name': stock.stock_name,
                    'relation_desc': stock.relation_desc,
                    'correlation': stock.correlation,
                    'sector': stock.sector,
                    'daily_change': daily_change,
                    'has_trade_data': bool(trade_data),
                })

            avg_change = sum(valid_changes) / len(valid_changes) if valid_changes else None
            max_change = max(valid_changes) if valid_changes else None

            events_data.append({
                'id': hist_event.id,
                'title': hist_event.title,
                'content': hist_event.content,
                'event_date': hist_event.event_date.isoformat() if hist_event.event_date else None,
                'relevance': hist_event.relevance,
                'importance': hist_event.importance,
                'related_stocks': related_stocks,
                'related_avg_chg': round(avg_change, 2) if avg_change is not None else None,
                'related_max_chg': round(max_change, 2) if max_change is not None else None,
            })

        # Same statistics for the requested event's own related stocks.
        current_valid_changes = []
        for stock in event.related_stocks:
            trade_data = _latest_trade_for_prefix(stock.stock_code.split('.')[0], trade_cache)
            if trade_data and trade_data.F010N is not None:
                current_valid_changes.append(float(trade_data.F010N))

        current_avg_change = None
        current_max_change = None
        if current_valid_changes:
            current_avg_change = sum(current_valid_changes) / len(current_valid_changes)
            current_max_change = max(current_valid_changes)

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'event_id': event_id,
                'event_title': event.title,
                'invest_score': event.invest_score,
                'related_avg_chg': round(current_avg_change, 2) if current_avg_change is not None else None,
                'related_max_chg': round(current_max_change, 2) if current_max_change is not None else None,
                'historical_events': events_data,
                'total_count': len(events_data),
            },
        })
    except HTTPException:
        # Let get_or_404's NotFound reach Flask instead of becoming a 500.
        raise
    except Exception as e:
        logger.exception("Error in api_event_historical_events: %s", e)
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500


@app.route('/api/event/<int:event_id>/comments', methods=['GET'])
def get_event_comments(event_id):
    """Return an event's active posts and paginated top-level comments with nested replies.

    Query parameters:
        page (int, default 1): page number; values below 1 become 1.
        per_page (int, default 20): comments per page; values outside 1..100 become 20.
        sort (str, default ``time_desc``): ``time_asc`` sorts oldest first;
            ``hot`` currently sorts like ``time_desc``.
        include_posts (``true``/``false``, default ``true``): whether ``posts``
            carries the post details.
        reply_limit (int, default 3): replies embedded per comment; values
            outside 0..50 become 3.

    Response on success: ``{"success": true, "data": {...}}``. The ``data``
    object holds:

    - ``event``: an event summary;
    - ``posts``;
    - ``total``, ``current_page``, ``total_pages``;
    - ``comments``;
    - ``posts_count``, only when the event has posts.

    Each comment has ``comment_id``, ``content``, ``created_at``, ``post_id``,
    ``post_title``, ``post_content_preview``, ``user``, ``reply_count``,
    ``has_more_replies`` and ``list``. ``list`` holds the first
    ``reply_limit`` active replies, oldest first, each with ``reply_to`` set
    to the comment author.

    An unknown event yields HTTP 404. Other errors yield ``success`` false
    with HTTP 500.
    """
    try:
        page = request.args.get('page', 1, type=int)
        per_page = request.args.get('per_page', 20, type=int)
        sort = request.args.get('sort', 'time_desc')
        include_posts = request.args.get('include_posts', 'true').lower() == 'true'
        reply_limit = request.args.get('reply_limit', 3, type=int)

        if page < 1:
            page = 1
        if per_page < 1 or per_page > 100:
            per_page = 20
        if reply_limit < 0 or reply_limit > 50:
            reply_limit = 3

        event = Event.query.get_or_404(event_id)

        posts = Post.query.filter_by(event_id=event_id, status='active') \
            .order_by(Post.is_top.desc(), Post.created_at.desc()) \
            .all()

        posts_data = []
        if include_posts:
            for post in posts:
                posts_data.append({
                    'post_id': post.id,
                    'title': post.title,
                    'content': post.content,
                    'content_type': post.content_type,
                    'created_at': post.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                    'updated_at': post.updated_at.strftime('%Y-%m-%d %H:%M:%S') if post.updated_at else None,
                    'likes_count': post.likes_count,
                    'comments_count': post.comments_count,
                    'view_count': post.view_count,
                    'is_top': post.is_top,
                    'user': {
                        'user_id': post.user.id,
                        'username': post.user.username,
                        'nickname': post.user.nickname or post.user.username,
                        'avatar_url': get_full_avatar_url(post.user.avatar_url),
                        'user_level': post.user.user_level,
                        'is_verified': post.user.is_verified,
                        'is_creator': post.user.is_creator,
                    },
                })

        post_ids = [p.id for p in posts]
        if not post_ids:
            return jsonify({
                'success': True,
                'data': {
                    'event': {
                        'id': event.id,
                        'title': event.title,
                        'description': event.description,
                        'event_type': event.event_type,
                        'importance': event.importance,
                        'status': event.status,
                    },
                    'posts': posts_data,
                    'total': 0,
                    'current_page': page,
                    'total_pages': 0,
                    'comments': [],
                },
            })

        base_query = Comment.query.filter(
            Comment.post_id.in_(post_ids),
            Comment.parent_id.is_(None),  # top-level comments only
            Comment.status == 'active'
        )
        if sort == 'time_asc':
            base_query = base_query.order_by(Comment.created_at.asc())
        else:
            # 'hot' has no dedicated ranking yet and falls back to newest first.
            base_query = base_query.order_by(Comment.created_at.desc())

        pagination = base_query.paginate(page=page, per_page=per_page, error_out=False)

        # One grouped query for all reply counts on this page instead of one COUNT per comment.
        comment_ids = [c.id for c in pagination.items]
        reply_counts = {}
        if comment_ids:
            reply_counts = dict(
                db.session.query(Comment.parent_id, func.count(Comment.id))
                .filter(Comment.parent_id.in_(comment_ids), Comment.status == 'active')
                .group_by(Comment.parent_id)
                .all()
            )

        comments_data = []
        for comment in pagination.items:
            reply_count = reply_counts.get(comment.id, 0)

            if reply_limit > 0:
                replies = Comment.query.filter_by(parent_id=comment.id, status='active') \
                    .order_by(Comment.created_at.asc()) \
                    .limit(reply_limit) \
                    .all()
            else:
                replies = []

            # Every embedded reply answers this top-level comment, so reply_to is its author.
            reply_to_user = {
                'user_id': comment.user.id,
                'nickname': comment.user.nickname or comment.user.username,
            }
            replies_list = [{
                'comment_id': reply.id,
                'content': reply.content,
                'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'user': {
                    'user_id': reply.user.id,
                    'username': reply.user.username,
                    'nickname': reply.user.nickname or reply.user.username,
                    'avatar_url': get_full_avatar_url(reply.user.avatar_url),
                    'user_level': reply.user.user_level,
                    'is_verified': reply.user.is_verified,
                },
                'reply_to': reply_to_user,
            } for reply in replies]

            post = comment.post
            post_content = post.content if post else None
            # Guard against NULL content (len(None) used to raise).
            if post_content and len(post_content) > 100:
                content_preview = post_content[:100] + '...'
            else:
                content_preview = post_content

            comments_data.append({
                'comment_id': comment.id,
                'content': comment.content,
                'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'post_id': comment.post_id,
                'post_title': post.title if post else None,
                'post_content_preview': content_preview,
                'user': {
                    'user_id': comment.user.id,
                    'username': comment.user.username,
                    'nickname': comment.user.nickname or comment.user.username,
                    'avatar_url': get_full_avatar_url(comment.user.avatar_url),
                    'user_level': comment.user.user_level,
                    'is_verified': comment.user.is_verified,
                },
                'reply_count': reply_count,
                'has_more_replies': reply_count > len(replies_list),
                'list': replies_list,
            })

        return jsonify({
            'success': True,
            'data': {
                'event': {
                    'id': event.id,
                    'title': event.title,
                    'description': event.description,
                    'event_type': event.event_type,
                    'importance': event.importance,
                    'status': event.status,
                    'created_at': event.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                    'hot_score': event.hot_score,
                    'view_count': event.view_count,
                    'post_count': event.post_count,
                    'follower_count': event.follower_count,
                },
                'posts': posts_data,
                'posts_count': len(posts_data),
                'total': pagination.total,
                'current_page': pagination.page,
                'total_pages': pagination.pages,
                'comments': comments_data,
            },
        })
    except HTTPException:
        raise
    except Exception as e:
        return jsonify({
            'success': False,
            'message': '获取评论列表失败',
            'error': str(e)
        }), 500


@app.route('/api/comment/<int:comment_id>/replies', methods=['GET'])
def get_comment_replies(comment_id):
    """Return a comment and a page of its active direct replies.

    Query parameters:
        page (int, default 1): values below 1 become 1.
        per_page (int, default 20): values outside 1..100 become 20.
        sort (str, default ``time_desc``): ``time_asc`` sorts oldest first;
            anything else sorts newest first.

    Response ``data`` shape::

        {"comment": {"comment_id", "content", "created_at",
                     "user": {"user_id", "nickname", "avatar_url"}},
         "replies": {"total", "current_page", "total_pages",
                     "items": [{"reply_id", "content", "created_at",
                                "user": {"user_id", "nickname", "avatar_url"},
                                "reply_to": {"id", "nickname"}}]}}

    An unknown comment id yields HTTP 404. A non-active comment yields a JSON
    ``code`` 404. Other errors yield a JSON ``code`` 500.
    """
    try:
        page = request.args.get('page', 1, type=int)
        per_page = request.args.get('per_page', 20, type=int)
        sort = request.args.get('sort', 'time_desc')

        if page < 1:
            page = 1
        if per_page < 1 or per_page > 100:
            per_page = 20

        comment = Comment.query.get_or_404(comment_id)
        if comment.status != 'active':
            return jsonify({
                'code': 404,
                'message': '评论不存在或已被删除',
                'data': None
            }), 404

        comment_data = {
            'comment_id': comment.id,
            'content': comment.content,
            'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'),
            'user': {
                'user_id': comment.user.id,
                'nickname': comment.user.nickname or comment.user.username,
                'avatar_url': get_full_avatar_url(comment.user.avatar_url),
            },
        }

        replies_query = Comment.query.filter_by(parent_id=comment_id, status='active')
        if sort == 'time_asc':
            replies_query = replies_query.order_by(Comment.created_at.asc())
        else:
            replies_query = replies_query.order_by(Comment.created_at.desc())

        pagination = replies_query.paginate(page=page, per_page=per_page, error_out=False)

        # Every reply here has parent_id == comment_id, so the replied-to user is
        # always this comment's author; no per-reply parent lookup is needed.
        reply_to_user = {
            'id': comment.user.id,
            'nickname': comment.user.nickname or comment.user.username,
        }
        replies_data = [{
            'reply_id': reply.id,
            'content': reply.content,
            'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S'),
            'user': {
                'user_id': reply.user.id,
                'nickname': reply.user.nickname or reply.user.username,
                'avatar_url': get_full_avatar_url(reply.user.avatar_url),
            },
            'reply_to': reply_to_user,
        } for reply in pagination.items]

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'comment': comment_data,
                'replies': {
                    'total': pagination.total,
                    'current_page': pagination.page,
                    'total_pages': pagination.pages,
                    'items': replies_data,
                },
            },
        })
    except HTTPException:
        raise
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500


# Utility: parse a JSON-or-CSV field
def parse_json_field(field_value):
    """Parse a field holding a JSON array, a comma-separated string, or a decoded value.

    Behaviour by input:

    - Falsy values return ``[]``.
    - Strings starting with ``[`` (after leading whitespace) are decoded as
      JSON; invalid JSON returns ``[]``.
    - Other strings are split on ``,``.
    - Non-string values are returned unchanged.
    """
    if not field_value:
        return []
    if not isinstance(field_value, str):
        return field_value
    if field_value.lstrip().startswith('['):
        try:
            return json.loads(field_value)
        except ValueError:
            return []
    return field_value.split(',')


# Utility: read a future_events column with new-field/old-field fallback
def get_future_event_field(row, new_field, old_field):
    """Return ``row.<new_field>`` when it is non-blank, otherwise ``row.<old_field>``.

    Missing attributes count as ``None``. A value is blank when it is ``None``
    or its ``str()`` is empty after stripping.
    """
    new_value = getattr(row, new_field, None)
    if new_value is not None and str(new_value).strip():
        return new_value
    return getattr(row, old_field, None)


def _ratio_to_percent(raw_score):
    """Convert a relevance score to an integer percentage.

    Values <= 1 are treated as ratios (0.9286 -> 93). Larger values are
    assumed to already be percentages and are only rounded. Falsy,
    non-numeric or non-finite scores yield 0.
    """
    try:
        value = float(raw_score)
    except (TypeError, ValueError):
        return 0
    if not value or not math.isfinite(value):
        return 0
    return int(round(value * 100)) if value <= 1 else int(round(value))


# Utility: parse the best_matches column (stocks with research-report citations)
def parse_best_matches(best_matches_value):
    """Normalise the ``best_matches`` column into a uniform list of stock dicts.

    Two item shapes are accepted:

    - New dicts carrying ``stock_code``, ``company_name``,
      ``original_description``, ``best_report_*``, ``total_reports`` and
      ``high_score_reports``.
    - Legacy ``[code, name, description?, score?]`` lists or tuples.

    Each result has ``code``, ``name``, ``description``, ``score`` and
    ``report``. For dict items, ``score`` is ``best_report_match_ratio`` as an
    integer percentage, and ``report`` is ``None`` when there is no
    ``best_report_title``. Other item shapes are skipped. Invalid JSON or a
    non-list payload returns ``[]``.
    """
    if not best_matches_value:
        return []

    if isinstance(best_matches_value, str):
        try:
            data = json.loads(best_matches_value)
        except (ValueError, RecursionError) as e:
            logger.warning("parse_best_matches error: %s", e)
            return []
    else:
        data = best_matches_value

    if not isinstance(data, list):
        return []

    result = []
    for item in data:
        if isinstance(item, dict):
            result.append({
                'code': item.get('stock_code', ''),
                'name': item.get('company_name', ''),
                'description': item.get('original_description', ''),
                # A malformed score on one item no longer discards the whole list.
                'score': _ratio_to_percent(item.get('best_report_match_ratio', 0)),
                'report': {
                    'title': item.get('best_report_title', ''),
                    'author': item.get('best_report_author', ''),
                    'sentences': item.get('best_report_sentences', ''),
                    'match_score': item.get('best_report_match_score', ''),
                    'match_ratio': item.get('best_report_match_ratio', 0),
                    'declare_date': item.get('best_report_declare_date', ''),
                    'total_reports': item.get('total_reports', 0),
                    'high_score_reports': item.get('high_score_reports', 0),
                } if item.get('best_report_title') else None,
            })
        elif isinstance(item, (list, tuple)) and len(item) >= 2:
            result.append({
                'code': item[0],
                'name': item[1],
                'description': item[2] if len(item) > 2 else '',
                'score': item[3] if len(item) > 3 else 0,
                'report': None,
            })
    return result


# Utility: turn stored escape sequences back into real characters, keeping Markdown
def unescape_markdown_text(text):
    """Replace literal ``\\n``, ``\\r`` and ``\\t`` sequences with real characters, then strip.

    Example: ``'\\n\\n#### Title'`` becomes ``'#### Title'`` (the leading
    newlines are removed by the final strip). Falsy input is returned
    unchanged.
    """
    if not text:
        return text
    # These are two-character literal sequences (backslash + letter), not escapes.
    text = text.replace('\\n', '\n')
    text = text.replace('\\r', '\r')
    text = text.replace('\\t', '\t')
    return text.strip()


# Utility: strip Markdown to plain text
def clean_markdown_text(text):
    """Remove common Markdown syntax and tidy whitespace.

    Args:
        text: Text that may contain Markdown markup.

    Returns:
        Plain text with paragraph breaks kept as single blank lines. Falsy
        input is returned unchanged.
    """
    if not text:
        return text

    # 1. Heading markers (#, ##, ... ######).
    text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE)
    # 2. Bold: **text** / __text__.
    text = re.sub(r'\*\*(.+?)\*\*', r'\1', text)
    text = re.sub(r'__(.+?)__', r'\1', text)
    # 3. Italic: *text* / _text_.
    text = re.sub(r'\*(.+?)\*', r'\1', text)
    text = re.sub(r'_(.+?)_', r'\1', text)
    # 4. List markers (-, *, +, "1."). [^\S\n]* limits the indent match to the
    #    current line; \s* used to swallow preceding blank lines.
    text = re.sub(r'^[^\S\n]*[-*+]\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^[^\S\n]*\d+\.\s+', '', text, flags=re.MULTILINE)
    # 5. Block-quote markers.
    text = re.sub(r'^>\s+', '', text, flags=re.MULTILINE)
    # 6. Fenced code blocks are removed entirely (content included); inline code keeps its text.
    text = re.sub(r'```[\s\S]*?```', '', text)
    text = re.sub(r'`(.+?)`', r'\1', text)
    # 7. Links: [text](url) -> text.
    text = re.sub(r'\[(.+?)\]\(.+?\)', r'\1', text)
    # 8. Trim spaces/tabs at the start and end of every line. [^\S\n] means
    #    "whitespace except newline", so paragraph breaks survive; the former
    #    \s pattern also consumed the newlines between paragraphs.
    text = re.sub(r'^[^\S\n]+|[^\S\n]+$', '', text, flags=re.MULTILINE)
    # 9. Collapse 3+ consecutive newlines into one blank line. This runs after
    #    step 8 so that lines containing only spaces count as empty.
    text = re.sub(r'\n{3,}', '\n\n', text)
    # 10. Collapse runs of spaces.
    text = re.sub(r' {2,}', ' ', text)
    return text.strip()


# 10.
# Investment calendar: events endpoint (enhanced)
def _future_events_where(start_date, end_date, importance, search_query):
    """Build the shared ``AND ...`` filter fragment and bind params for future_events.

    The list query and the count query both use this, so their filters
    always agree.

    Raises:
        ValueError: if ``start_date`` / ``end_date`` are not ISO datetimes.
    """
    clauses = []
    params = {}
    if start_date:
        clauses.append("calendar_time >= :start_date")
        params['start_date'] = datetime.fromisoformat(start_date)
    if end_date:
        clauses.append("calendar_time <= :end_date")
        params['end_date'] = datetime.fromisoformat(end_date)

    # Importance accepts several comma-separated levels, e.g. importance=S,A,B.
    if importance != 'all':
        levels = [i.strip().upper() for i in importance.split(',') if i.strip()]
        if len(levels) == 1:
            clauses.append("star = :importance")
            params['importance'] = levels[0]
        elif len(levels) > 1:
            names = [f'imp_{i}' for i in range(len(levels))]
            clauses.append(f"star IN ({', '.join(':' + n for n in names)})")
            params.update(zip(names, levels))

    # Substring search over the title and the JSON columns (cast to text by MySQL).
    if search_query:
        clauses.append(
            "(title LIKE :search_pattern"
            " OR CAST(related_stocks AS CHAR) LIKE :search_pattern"
            " OR CAST(concepts AS CHAR) LIKE :search_pattern)"
        )
        params['search_pattern'] = f'%{search_query}%'

    return ''.join(f" AND {clause}" for clause in clauses), params


def _recent_trade_rows(clean_code, cache):
    """Return up to 7 newest ea_trade rows for a SECCODE prefix, memoised in ``cache``.

    Each row exposes ``close_price`` (F007N), ``change_pct`` (F010N) and
    ``TRADEDATE``.
    """
    if clean_code not in cache:
        cache[clean_code] = db.session.execute(text(
            "SELECT F007N AS close_price, F010N AS change_pct, TRADEDATE "
            "FROM ea_trade WHERE SECCODE LIKE :stock_code_pattern "
            "ORDER BY TRADEDATE DESC LIMIT 7"
        ), {'stock_code_pattern': f'{clean_code}%'}).fetchall()
    return cache[clean_code]


def _daily_and_week_change(trade_rows):
    """Return ``(daily_chg, week_chg)`` from trade rows ordered newest first.

    ``daily_chg`` is the newest row's ``change_pct``. ``week_chg`` is the
    percentage move from the 5th newest close to the newest close. Each
    value is 0 when the data is missing.
    """
    daily_chg = 0
    week_chg = 0
    if trade_rows:
        daily_chg = float(trade_rows[0].change_pct or 0)
        if len(trade_rows) >= 5:
            current_price = float(trade_rows[0].close_price or 0)
            week_ago_price = float(trade_rows[4].close_price or 0)
            if week_ago_price > 0:
                week_chg = ((current_price - week_ago_price) / week_ago_price) * 100
    return daily_chg, week_chg


def _parse_legacy_related_stocks(raw_related, data_id):
    """Parse the legacy ``related_stocks`` column into stock dicts.

    The column holds ``[code, name, description?, score?]`` lists as JSON or
    as a Python literal. ``ast.literal_eval`` only evaluates literals.
    ``score`` is converted to an integer percentage. On a parse error, the
    items parsed so far are returned and the error is logged.
    """
    import ast

    parsed = []
    if not raw_related:
        return parsed
    try:
        if isinstance(raw_related, str):
            try:
                stock_data = json.loads(raw_related)
            except ValueError:
                stock_data = ast.literal_eval(raw_related)
        else:
            stock_data = raw_related

        for stock_info in stock_data or []:
            if isinstance(stock_info, list) and len(stock_info) >= 2:
                raw_score = stock_info[3] if len(stock_info) > 3 else 0
                int_score = int(round(raw_score * 100)) if raw_score and raw_score <= 1 else int(round(raw_score)) if raw_score else 0
                parsed.append({
                    'code': stock_info[0],
                    'name': stock_info[1],
                    'description': stock_info[2] if len(stock_info) > 2 else '',
                    'score': int_score,
                    'report': None,
                })
    except Exception as e:
        logger.warning("Error parsing related_stocks for event %s: %s", data_id, e)
    return parsed


def _serialize_calendar_event(event, search_query, trade_cache):
    """Convert one future_events row into the list-endpoint payload dict."""
    # former/forecast prefer second_modified_text / `second_modified_text.1` when non-blank.
    former_value = get_future_event_field(event, 'second_modified_text', 'former')
    forecast_new = getattr(event, 'second_modified_text_1', None)
    forecast_value = forecast_new if (forecast_new and str(forecast_new).strip()) else getattr(event, 'forecast', None)

    # Related stocks prefer best_matches (with report citations) over related_stocks.
    best_matches = getattr(event, 'best_matches', None)
    if best_matches and str(best_matches).strip():
        parsed_stocks = parse_best_matches(best_matches)
    else:
        parsed_stocks = _parse_legacy_related_stocks(event.related_stocks, event.data_id)

    related_stocks_list = []
    related_avg_chg = 0
    related_max_chg = 0
    related_week_chg = 0
    if parsed_stocks:
        try:
            daily_changes = []
            week_changes = []
            for stock_info in parsed_stocks:
                stock_code = stock_info.get('code', '')
                if not stock_code:
                    continue
                clean_code = stock_code.replace('.SZ', '').replace('.SH', '').replace('.BJ', '')
                daily_chg, week_chg = _daily_and_week_change(_recent_trade_rows(clean_code, trade_cache))
                daily_changes.append(daily_chg)
                week_changes.append(week_chg)
                related_stocks_list.append({
                    'code': stock_code,
                    'name': stock_info.get('name', ''),
                    'description': stock_info.get('description', ''),
                    'score': stock_info.get('score', 0),
                    'daily_chg': daily_chg,
                    'week_chg': week_chg,
                    'report': stock_info.get('report', None),
                })

            if daily_changes:
                related_avg_chg = round(sum(daily_changes) / len(daily_changes), 4)
                related_max_chg = round(max(daily_changes), 4)
            if week_changes:
                related_week_chg = round(sum(week_changes) / len(week_changes), 4)
        except Exception as e:
            logger.warning("Error processing related stocks for event %s: %s", event.data_id, e)

    related_concepts = extract_concepts_from_concepts_field(event.concepts)
    star_rating = event.star

    # Report where the search term matched: title first, then stocks, then concepts.
    highlight_match = False
    if search_query:
        needle = search_query.lower()
        if needle in (event.title or '').lower():
            highlight_match = 'title'
        elif any(needle in str(stock).lower() for stock in related_stocks_list):
            highlight_match = 'stocks'
        elif needle in str(related_concepts).lower():
            highlight_match = 'concepts'

    cleaned_former = unescape_markdown_text(former_value)
    cleaned_forecast = unescape_markdown_text(forecast_value)

    event_dict = {
        'id': event.data_id,
        'title': event.title,
        'description': f"前值: {cleaned_former}, 预测: {cleaned_forecast}" if cleaned_former or cleaned_forecast else "",
        'start_time': event.calendar_time.isoformat() if event.calendar_time else None,
        'end_time': None,  # future_events has no end time column
        'category': {
            'event_type': event.type,
            'importance': event.star,
            'star_rating': star_rating,
        },
        'star_rating': star_rating,
        'related_concepts': related_concepts,
        'related_stocks': related_stocks_list,
        'related_avg_chg': round(related_avg_chg, 2),
        'related_max_chg': round(related_max_chg, 2),
        'related_week_chg': round(related_week_chg, 2),
        'former': cleaned_former,
        'forecast': cleaned_forecast,
    }
    if search_query and highlight_match:
        event_dict['search_match'] = highlight_match
    return event_dict


@app.route('/api/calendar/events', methods=['GET'])
def api_calendar_events():
    """List future_events rows for the investment calendar, paginated.

    Query parameters:
        start / end: ISO datetimes bounding ``calendar_time`` (inclusive).
        importance: ``all`` (default), or comma-separated star levels such as ``S,A``.
        q: substring matched against title, related_stocks and concepts.
        page (default 1, minimum 1) and per_page (default 10, clamped to 1..100).
            Non-integer values fall back to these defaults.
        category: accepted for compatibility; not used for filtering.

    Each event includes its related stocks with the latest daily change, and
    the change versus the 5th newest close, plus aggregate changes.
    ``total_count`` uses exactly the same filters as the page query.
    """
    try:
        start_date = request.args.get('start')
        end_date = request.args.get('end')
        importance = request.args.get('importance', 'all')
        search_query = request.args.get('q', '').strip()
        page = max(request.args.get('page', 1, type=int), 1)
        # Each event triggers per-stock queries, so the page size is bounded.
        per_page = min(max(request.args.get('per_page', 10, type=int), 1), 100)
        offset = (page - 1) * per_page

        where_sql, filter_params = _future_events_where(start_date, end_date, importance, search_query)

        events = db.session.execute(text(
            "SELECT data_id, calendar_time, type, star, title, former, forecast, "
            "related_stocks, concepts, second_modified_text, "
            "`second_modified_text.1` AS second_modified_text_1, best_matches "
            "FROM future_events WHERE 1 = 1" + where_sql +
            " ORDER BY calendar_time LIMIT :limit OFFSET :offset"
        ), {**filter_params, 'limit': per_page, 'offset': offset}).fetchall()

        # scalar() avoids Row.count, which can resolve to the tuple method
        # rather than the "count" column on SQLAlchemy 1.4+.
        total_count = db.session.execute(
            text("SELECT COUNT(*) AS count FROM future_events WHERE 1 = 1" + where_sql),
            filter_params
        ).scalar() or 0

        trade_cache = {}
        events_data = [_serialize_calendar_event(event, search_query, trade_cache) for event in events]

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'events': events_data,
                'total_count': total_count,
                'page': page,
                'per_page': per_page,
                'total_pages': (total_count + per_page - 1) // per_page,
                'search_query': search_query,
            },
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500


# 11. Investment calendar: data endpoint
@app.route('/api/calendar/data', methods=['GET'])
def api_calendar_data():
    """Return RelatedData rows plus future_events rows of type ``data``, merged and paginated.

    Query parameters:
        start / end: ISO datetimes filtering both sources.
        type (default ``all``): data_type filter applied to both sources.
        page (default 1) and page_size (default 20; values outside 1..100 become 20).

    Both sources are merged, sorted by ``created_at`` newest first, and
    sliced in memory. A ``ValueError`` (bad integers or bad dates) yields a
    JSON ``code`` 400.
    """
    try:
        start_date = request.args.get('start')
        end_date = request.args.get('end')
        data_type = request.args.get('type', 'all')

        page = int(request.args.get('page', 1))
        page_size = int(request.args.get('page_size', 20))
        if page < 1:
            page = 1
        if page_size < 1 or page_size > 100:
            page_size = 20

        query1 = RelatedData.query
        if start_date:
            query1 = query1.filter(RelatedData.created_at >= datetime.fromisoformat(start_date))
        if end_date:
            query1 = query1.filter(RelatedData.created_at <= datetime.fromisoformat(end_date))
        if data_type != 'all':
            query1 = query1.filter_by(data_type=data_type)
        data_list1 = query1.order_by(RelatedData.created_at.desc()).all()

        query2_sql = (
            "SELECT data_id AS id, title, type AS data_type, former, forecast, star, "
            "calendar_time AS created_at "
            "FROM future_events WHERE type = 'data'"
        )
        params = {}
        if start_date:
            query2_sql += " AND calendar_time >= :start_date"
            params['start_date'] = start_date
        if end_date:
            query2_sql += " AND calendar_time <= :end_date"
            params['end_date'] = end_date
        if data_type != 'all':
            query2_sql += " AND type = :data_type"
            params['data_type'] = data_type
        query2_sql += " ORDER BY calendar_time DESC"

        # Materialise once: a result cursor can only be iterated a single time,
        # so counting it after the loop below always gave 0.
        future_rows = db.session.execute(text(query2_sql), params).fetchall()

        result_data = [{
            'id': data.id,
            'title': data.title,
            'data_type': data.data_type,
            'data_content': data.data_content,
            'description': data.description,
            'created_at': data.created_at.isoformat() if data.created_at else None,
            'event_id': data.event_id,
            'source': 'related_data',
            'former': None,
            'forecast': None,
            'fact': None,
            'star': None,
        } for data in data_list1]

        result_data.extend({
            'id': row.id,
            'title': row.title,
            'data_type': row.data_type,
            'data_content': None,
            'description': None,
            'created_at': row.created_at.isoformat() if row.created_at else None,
            'event_id': None,
            'source': 'future_events',
            'former': row.former,
            'forecast': row.forecast,
            'fact': None,
            'star': row.star,
        } for row in future_rows)

        # ISO strings sort chronologically; missing timestamps go last.
        result_data.sort(key=lambda x: x['created_at'] or '1900-01-01', reverse=True)

        total_count = len(result_data)
        total_pages = (total_count + page_size - 1) // page_size
        start_index = (page - 1) * page_size
        current_page_data = result_data[start_index:start_index + page_size]

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'data_list': current_page_data,
                'pagination': {
                    'current_page': page,
                    'page_size': page_size,
                    'total_count': total_count,
                    'total_pages': total_pages,
                    'has_next': page < total_pages,
                    'has_prev': page > 1,
                },
                # Kept for backward compatibility.
                'total_count': total_count,
                'related_data_count': len(data_list1),
                'future_events_count': len(future_rows),
            },
        })
    except ValueError as ve:
        return jsonify({
            'code': 400,
            'message': f'分页参数格式错误: {str(ve)}',
            'data': None
        }), 400
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500


# 12.
# Investment calendar: detail endpoint
def extract_concepts_from_concepts_field(concepts_text):
    """Parse the ``concepts`` column into ``{'name', 'reason', 'score'}`` dicts.

    Accepted inputs:

    - a JSON string;
    - a Python-literal string (parsed with ``ast.literal_eval``, which only
      evaluates literals);
    - an already-decoded sequence.

    Only list entries with at least three items are kept; their first three
    items become name, reason and score. Empty input, or any parse error
    (which is logged), returns ``[]``.
    """
    if not concepts_text:
        return []
    try:
        import ast

        if isinstance(concepts_text, str):
            try:
                concepts_data = json.loads(concepts_text)
            except ValueError:
                concepts_data = ast.literal_eval(concepts_text)
        else:
            concepts_data = concepts_text

        return [
            {'name': info[0], 'reason': info[1], 'score': info[2]}
            for info in concepts_data
            if isinstance(info, list) and len(info) >= 3
        ]
    except Exception as e:
        logger.warning("Error extracting concepts: %s", e)
        return []


# Buckets reported in sector_stats; '全部股票' counts every processed stock.
_CALENDAR_SECTOR_BUCKETS = ('全部股票', '大周期', '大消费', 'TMT板块', '大金融地产', '公共产业板块', '其他')

# Shenwan level-1 industry name -> primary sector bucket (unmapped -> '其他').
# Built once at import instead of on every request.
_SW_L1_TO_PRIMARY_SECTOR = {
    industry: bucket
    for bucket, industries in (
        ('大周期', ('石油石化', '煤炭', '有色金属', '钢铁', '基础化工', '建筑材料', '机械设备',
                  '电力设备及新能源', '国防军工', '电力设备', '电网设备', '风力发电', '太阳能发电',
                  '建筑装饰', '建筑', '交通运输', '采掘', '公用事业')),
        ('大消费', ('汽车', '家用电器', '酒类', '食品饮料', '医药生物', '纺织服饰', '农林牧渔',
                  '商贸零售', '轻工制造', '消费者服务', '美容护理', '社会服务', '纺织服装',
                  '商业贸易', '休闲服务')),
        ('大金融地产', ('银行', '证券', '保险', '多元金融', '综合金融', '房地产', '非银金融')),
        ('TMT板块', ('计算机', '电子', '传媒', '通信')),
        ('公共产业板块', ('环保', '综合')),
    )
    for industry in industries
}


@app.route('/api/calendar/detail/<item_id>', methods=['GET'])
def api_future_event_detail(item_id):
    """Return one future_events row with concepts, related stocks and sector stats.

    Field sources:

    - ``former`` / ``forecast`` prefer ``second_modified_text`` /
      ``second_modified_text.1`` when those are non-blank.
    - Related stocks come from ``best_matches`` when present, otherwise from
      the legacy ``related_stocks`` column. Legacy scores are converted to
      integer percentages, as in the list endpoint.
    - Each stock gets its Shenwan level-1 industry, the mapped primary
      sector, the latest daily change, and the change versus the 5th newest
      close.

    A missing row yields a JSON ``code`` 404.

    NOTE(review): the original docstring said this is limited to Pro/Max
    members, but no membership check is performed here. Confirm whether one
    is applied elsewhere.
    """
    try:
        query = (
            "SELECT data_id, calendar_time, type, star, title, former, forecast, "
            "related_stocks, concepts, second_modified_text, "
            "`second_modified_text.1` AS second_modified_text_1, best_matches "
            "FROM future_events WHERE data_id = :item_id"
        )
        event = db.session.execute(text(query), {'item_id': item_id}).fetchone()
        if not event:
            return jsonify({
                'code': 404,
                'message': 'Event not found',
                'data': None
            }), 404

        former_value = get_future_event_field(event, 'second_modified_text', 'former')
        forecast_new = getattr(event, 'second_modified_text_1', None)
        forecast_value = forecast_new if (forecast_new and str(forecast_new).strip()) else getattr(event, 'forecast', None)

        extracted_concepts = extract_concepts_from_concepts_field(event.concepts)

        related_stocks_list = []
        sector_stats = dict.fromkeys(_CALENDAR_SECTOR_BUCKETS, 0)
        related_avg_chg = 0
        related_max_chg = 0
        related_week_chg = 0

        best_matches = getattr(event, 'best_matches', None)
        if best_matches and str(best_matches).strip():
            parsed_stocks = parse_best_matches(best_matches)
        else:
            parsed_stocks = []
            if event.related_stocks:
                try:
                    import ast

                    if isinstance(event.related_stocks, str):
                        try:
                            stock_data = json.loads(event.related_stocks)
                        except ValueError:
                            stock_data = ast.literal_eval(event.related_stocks)
                    else:
                        stock_data = event.related_stocks

                    for stock_info in stock_data or []:
                        if isinstance(stock_info, list) and len(stock_info) >= 2:
                            raw_score = stock_info[3] if len(stock_info) > 3 else 0
                            # Same ratio -> integer-percent conversion as the list endpoint.
                            int_score = int(round(raw_score * 100)) if raw_score and raw_score <= 1 else int(round(raw_score)) if raw_score else 0
                            parsed_stocks.append({
                                'code': stock_info[0],
                                'name': stock_info[1],
                                'description': stock_info[2] if len(stock_info) > 2 else '',
                                'score': int_score,
                                'report': None,
                            })
                except Exception as e:
                    logger.warning("Error parsing related_stocks for event %s: %s", event.data_id, e)

        if parsed_stocks:
            try:
                daily_changes = []
                week_changes = []
                for stock_info in parsed_stocks:
                    stock_code = stock_info.get('code', '')
                    if not stock_code:
                        continue
                    stock_name = stock_info.get('name', '')
                    clean_code = stock_code.replace('.SZ', '').replace('.SH', '').replace('.BJ', '')
                    logger.debug("Processing stock: %s - %s", clean_code, stock_name)

                    sector_row = db.session.execute(text(
                        "SELECT F004V AS sw_primary_sector FROM ea_sector "
                        "WHERE SECCODE LIKE :stock_code_pattern AND F002V = '申银万国行业分类' "
                        "LIMIT 1"
                    ), {'stock_code_pattern': f'{clean_code}%'}).fetchone()
                    sw_primary_sector = sector_row.sw_primary_sector if sector_row else None
                    primary_sector = _SW_L1_TO_PRIMARY_SECTOR.get(sw_primary_sector, '其他') if sw_primary_sector else '其他'

                    trade_data = db.session.execute(text(
                        "SELECT F007N AS close_price, F010N AS change_pct, TRADEDATE "
                        "FROM ea_trade WHERE SECCODE LIKE :stock_code_pattern "
                        "ORDER BY TRADEDATE DESC LIMIT 7"
                    ), {'stock_code_pattern': f'{clean_code}%'}).fetchall()

                    daily_chg = 0
                    week_chg = 0
                    if trade_data:
                        daily_chg = float(trade_data[0].change_pct or 0)
                        # Change versus the 5th newest close.
                        if len(trade_data) >= 5:
                            current_price = float(trade_data[0].close_price or 0)
                            week_ago_price = float(trade_data[4].close_price or 0)
                            if week_ago_price > 0:
                                week_chg = ((current_price - week_ago_price) / week_ago_price) * 100
                    logger.debug("Stock %s: SW=%s primary=%s trade_rows=%d daily_chg=%s",
                                 clean_code, sw_primary_sector, primary_sector, len(trade_data), daily_chg)

                    sector_stats['全部股票'] += 1
                    sector_stats[primary_sector] += 1
                    daily_changes.append(daily_chg)
                    week_changes.append(week_chg)

                    related_stocks_list.append({
                        'code': stock_code,
                        'name': stock_name,
                        'description': stock_info.get('description', ''),
                        'score': stock_info.get('score', 0),
                        'sw_primary_sector': sw_primary_sector,
                        'primary_sector': primary_sector,
                        'daily_change': daily_chg,
                        'week_change': week_chg,
                        'report': stock_info.get('report', None),
                    })

                if daily_changes:
                    related_avg_chg = sum(daily_changes) / len(daily_changes)
                    related_max_chg = max(daily_changes)
                if week_changes:
                    related_week_chg = sum(week_changes) / len(week_changes)
            except Exception as e:
                logger.exception("Error processing related stocks: %s", e)

        detail_data = {
            'id': event.data_id,
            'title': event.title,
            'type': event.type,
            'star': event.star,
            'calendar_time': event.calendar_time.isoformat() if event.calendar_time else None,
            'former': former_value,
            'forecast': forecast_value,
            'concepts': event.concepts,
            'extracted_concepts': extracted_concepts,
            'related_stocks': related_stocks_list,
            'sector_stats': sector_stats,
            'related_avg_chg': round(related_avg_chg, 2),
            'related_max_chg': round(related_max_chg, 2),
            'related_week_chg': round(related_week_chg, 2),
        }

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'type': 'future_event',
                'detail': detail_data,
            },
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500


# 13-15.
# Filter-popup endpoint (pre-existing; response format optimized)
@app.route('/api/filter/options', methods=['GET'])
def api_filter_options():
    """Return the option lists displayed in the client's filter popup.

    The response ``data`` object contains:

    * ``sort_options`` - static sort keys with display name and description;
    * ``industry_options`` - each distinct non-null ``ea_sector.f002v`` value
      with the number of ``ea_sector`` rows carrying it, ordered by name;
    * ``importance_options`` - static importance grades S/A/B/C.

    Any exception yields a JSON body with ``code`` 500 and the exception
    text, returned with HTTP status 500.
    """
    # Static sort choices offered by the client.
    sort_choices = [
        {'key': 'new', 'name': '最新', 'desc': '按创建时间排序'},
        {'key': 'hot', 'name': '热门', 'desc': '按热度分数排序'},
        {'key': 'returns', 'name': '收益率', 'desc': '按收益率排序'},
        {'key': 'importance', 'name': '重要性', 'desc': '按重要性等级排序'},
        {'key': 'view_count', 'name': '浏览量', 'desc': '按浏览次数排序'},
    ]
    # Static importance grades.
    grade_choices = [
        {'key': 'S', 'name': 'S级', 'desc': '重大事件'},
        {'key': 'A', 'name': 'A级', 'desc': '重要事件'},
        {'key': 'B', 'name': 'B级', 'desc': '普通事件'},
        {'key': 'C', 'name': 'C级', 'desc': '参考事件'},
    ]

    try:
        rows = db.session.execute(text("""
            SELECT DISTINCT f002v as classification_name, COUNT(*) as count
            FROM ea_sector
            WHERE f002v IS NOT NULL
            GROUP BY f002v
            ORDER BY f002v
        """)).fetchall()
        # Each row is (classification_name, count) in SELECT order.
        industries = [{'name': label, 'count': hits} for label, hits in rows]

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'sort_options': sort_choices,
                'industry_options': industries,
                'importance_options': grade_choices,
            },
        })
    except Exception as exc:
        return jsonify({'code': 500, 'message': str(exc), 'data': None}), 500


# 16-17.
# Membership-benefits endpoint
@app.route('/api/membership/status', methods=['GET'])
@token_required
def api_membership_status():
    """Report the authenticated user's subscription state and derived benefits.

    The user comes from ``request.user`` (presumably attached by
    ``token_required``). The subscription is read through
    ``get_user_subscription_safe``. ``is_member`` is truthy only for an active
    ``pro`` or ``max`` subscription. ``member_expire_date``, ``is_member`` and
    ``user_level`` are kept for older clients. Any exception becomes a
    500 JSON response carrying the exception text.
    """
    try:
        current = request.user
        sub = get_user_subscription_safe(current.id)
        tier = sub.subscription_type  # free / pro / max
        active = sub.is_active()
        # Legacy membership flag: a paying tier that is still active.
        member = active and tier in ('pro', 'max')
        is_max_tier = tier == 'max'

        has_bound_phone = bool(current.phone and current.phone_confirmed)

        end = sub.end_date
        # Pre-formatted expiry date so the client can display it directly.
        expire_label = end.strftime('%Y-%m-%d') if end else None

        benefits = {
            'unlimited_access': member,
            'priority_support': is_max_tier,
            'advanced_analytics': member,
            'custom_alerts': member,
            'max_exclusive': is_max_tier,  # features reserved for the max tier
        }

        payload = {
            'user_id': current.id,
            # Subscription details
            'subscription_type': tier,
            'subscription_status': sub.subscription_status,  # active/expired/cancelled
            'is_member': member,
            'is_active': active,
            'start_date': sub.start_date.isoformat() if sub.start_date else None,
            'end_date': end.isoformat() if end else None,
            'member_expire_date': expire_label,  # YYYY-MM-DD
            'days_left': sub.days_left(),
            'auto_renewal': sub.auto_renewal,
            # Phone binding
            'phone_bindcd': has_bound_phone,
            'phone': current.phone if has_bound_phone else None,
            # Legacy field
            'user_level': current.user_level,
            'benefits': benefits,
        }
        return jsonify({'code': 200, 'message': 'success', 'data': payload})
    except Exception as exc:
        return jsonify({'code': 500, 'message': str(exc), 'data': None}), 500


# 18-19.
# Personal-center endpoint: user profile
@app.route('/api/user/profile', methods=['GET'])
@token_required
def api_user_profile():
    """Return the authenticated user's profile, account status and statistics.

    The user comes from ``request.user`` (presumably attached by
    ``token_required``). Comment statistics are count queries:

    * ``comments_made`` - comments written by the user;
    * ``comments_received`` - comments on posts authored by the user;
    * ``replies_received`` - comments whose ``parent_id`` is one of the
      user's comments.

    ``total_comments`` is the sum of the three. Any exception becomes a
    500 JSON response carrying the exception text.
    """
    try:
        user = request.user

        likes_count = PostLike.query.filter_by(user_id=user.id).count()
        follows_count = EventFollow.query.filter_by(user_id=user.id).count()

        comments_made = Comment.query.filter_by(user_id=user.id).count()
        comments_received = db.session.query(Comment) \
            .join(Post, Comment.post_id == Post.id) \
            .filter(Post.user_id == user.id).count()
        replies_received = Comment.query.filter(
            Comment.parent_id.in_(
                db.session.query(Comment.id).filter_by(user_id=user.id)
            )
        ).count()

        # Total = comments made + comments and replies received
        total_comments = comments_made + comments_received + replies_received

        # A phone counts as bound only when present AND confirmed
        phone_bindcd = bool(user.phone and user.phone_confirmed)

        profile_data = {
            'basic_info': {
                'user_id': user.id,
                'username': user.username,
                'email': user.email,
                'phone': user.phone if phone_bindcd else None,
                'phone_bindcd': phone_bindcd,
                'nickname': user.nickname,
                'avatar_url': get_full_avatar_url(user.avatar_url),
                'bio': user.bio,
                'gender': user.gender,
                'birth_date': user.birth_date.isoformat() if user.birth_date else None,
                'location': user.location
            },
            'account_status': {
                'email_confirmed': user.email_confirmed,
                'phone_confirmed': user.phone_confirmed,
                'is_verified': user.is_verified,
                'verify_time': user.verify_time.isoformat() if user.verify_time else None,
                'created_at': user.created_at.isoformat() if user.created_at else None,
                'last_seen': user.last_seen.isoformat() if user.last_seen else None
            },
            'statistics': {
                'likes_count': likes_count,
                'follows_count': follows_count,
                'total_comments': total_comments,
                'comments_detail': {
                    'comments_made': comments_made,
                    'comments_received': comments_received,
                    'replies_received': replies_received
                }
            },
            'investment_preferences': {
                'trading_experience': user.trading_experience,
                'investment_style': user.investment_style,
                'risk_preference': user.risk_preference,
                'investment_amount': user.investment_amount,
                'preferred_markets': user.preferred_markets
            },
            'community_stats': {
                'user_level': user.user_level,
                'reputation_score': user.reputation_score,
                'contribution_point': user.contribution_point,
                'post_count': user.post_count,
                'comment_count': user.comment_count,
                'follower_count': user.follower_count,
                'following_count': user.following_count
            },
            'settings': {
                'email_notifications': user.email_notifications,
                'sms_notifications': user.sms_notifications,
                'privacy_level': user.privacy_level,
                'theme_preference': user.theme_preference
            }
        }

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': profile_data
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500


def _clean_text_field(value):
    """Return ``value`` as a stripped string; ``None`` becomes ``''``.

    Tolerates JSON ``null`` and non-string scalars (e.g. numbers) that would
    otherwise fail on ``.strip()``.
    """
    if value is None:
        return ''
    return str(value).strip()


@app.route('/api/user/feedback', methods=['POST'])
@token_required
def api_user_feedback():
    """Create a feedback record for the authenticated user.

    The body is parsed as JSON regardless of Content-Type. When it is missing,
    unparsable, empty, or not a JSON object, form data is used instead.

    Fields:

    * ``content`` - required, non-blank;
    * ``type`` - defaults to ``'general'``;
    * ``title`` - optional;
    * ``contact`` - optional; falls back to the user's phone, then email;
    * ``images`` - a list, or a single string that is wrapped into a list.

    Returns 400 when ``content`` is blank. Other failures roll back the
    session and return 500.
    """
    try:
        user = request.user

        data = request.get_json(force=True, silent=True)
        # A JSON list/scalar has no ``.get``; treat it like an absent body.
        if not isinstance(data, dict) or not data:
            data = request.form.to_dict()

        content = _clean_text_field(data.get('content'))
        if not content:
            return jsonify({
                'code': 400,
                'message': '反馈内容不能为空',
                'data': None
            }), 400

        feedback_type = data.get('type', 'general')
        title = _clean_text_field(data.get('title'))
        contact = _clean_text_field(data.get('contact'))
        images = data.get('images', [])

        # Normalize images to a list: a single string is wrapped, and any
        # other non-list value (number, object, null) is discarded.
        if isinstance(images, str):
            images = [images] if images else []
        elif not isinstance(images, list):
            images = []

        feedback = Feedback(
            user_id=user.id,
            feedback_type=feedback_type,
            title=title if title else None,
            content=content,
            contact=contact if contact else user.phone or user.email,
            images=images if images else None,
            status='pending'
        )
        db.session.add(feedback)
        db.session.commit()

        return jsonify({
            'code': 200,
            'message': '反馈提交成功,感谢您的宝贵意见!',
            'data': {
                'feedback_id': feedback.id,
                'status': feedback.status,
                'created_at': feedback.created_at.isoformat() if feedback.created_at else None
            }
        })
    except Exception as e:
        db.session.rollback()
        return jsonify({
            'code': 500,
            'message': f'反馈提交失败: {str(e)}',
            'data': None
        }), 500


@app.route('/api/user/feedback', methods=['GET'])
@token_required
def api_user_feedback_list():
    """List the authenticated user's feedback, newest first, paginated.

    Query args: ``page`` (default 1, minimum 1) and ``per_page`` (default 10,
    clamped to 1..50).
    """
    try:
        user = request.user
        page = max(1, request.args.get('page', 1, type=int))
        # Cap the page size, as /api/user/activities does, so a single
        # request cannot pull an unbounded number of rows.
        per_page = min(50, max(1, request.args.get('per_page', 10, type=int)))

        feedbacks = Feedback.query.filter_by(user_id=user.id) \
            .order_by(Feedback.created_at.desc()) \
            .paginate(page=page, per_page=per_page, error_out=False)

        feedback_list = [{
            'id': fb.id,
            'type': fb.feedback_type,
            'title': fb.title,
            'content': fb.content,
            'status': fb.status,
            'admin_reply': fb.admin_reply,
            'created_at': fb.created_at.isoformat() if fb.created_at else None,
            'updated_at': fb.updated_at.isoformat() if fb.updated_at else None
        } for fb in feedbacks.items]

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'feedbacks': feedback_list,
                'pagination': {
                    'page': page,
                    'per_page': per_page,
                    'total': feedbacks.total,
                    'pages': feedbacks.pages
                }
            }
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500


# Module-level cache of agreement documents, filled by load_agreements_from_docx
_agreements_cache = {}
_cache_loaded = False


def load_agreements_from_docx():
    """Load the agreement .docx files into ``_agreements_cache`` once.

    Reads ``about_us.docx``, ``service_terms.docx`` and ``privacy_policy.docx``
    from this module's directory. Non-empty paragraphs are joined with blank
    lines. A missing or unreadable file gets a placeholder entry with an
    ``error`` key instead of raising. Later calls return the cached dict
    until ``_cache_loaded`` is reset.

    Returns:
        dict mapping agreement type to a dict with ``title``, ``content``,
        ``last_updated`` (file mtime or ``None``), ``version`` and
        ``file_path``.
    """
    global _agreements_cache, _cache_loaded

    if _cache_loaded:
        return _agreements_cache

    try:
        # Agreement type -> file name
        docx_files = {
            'about_us': 'about_us.docx',
            'service_terms': 'service_terms.docx',
            'privacy_policy': 'privacy_policy.docx'
        }

        # Agreement type -> display title
        titles = {
            'about_us': '关于我们',
            'service_terms': '服务条款',
            'privacy_policy': '隐私政策'
        }

        for agreement_type, filename in docx_files.items():
            file_path = os.path.join(os.path.dirname(__file__), filename)

            if os.path.exists(file_path):
                try:
                    doc = Document(file_path)

                    # Keep non-blank paragraphs only
                    content_paragraphs = []
                    for paragraph in doc.paragraphs:
                        if paragraph.text.strip():
                            content_paragraphs.append(paragraph.text.strip())

                    content = '\n\n'.join(content_paragraphs)

                    # File mtime doubles as a version marker
                    file_stat = os.stat(file_path)
                    last_modified = file_stat.st_mtime

                    _agreements_cache[agreement_type] = {
                        'title': titles.get(agreement_type, agreement_type),
                        'content': content,
                        'last_updated': last_modified,
                        'version': '1.0',
                        'file_path': filename
                    }

                    print(f"Successfully loaded {agreement_type} from {filename}")

                except Exception as e:
                    print(f"Error reading {filename}: {str(e)}")
                    # Unreadable file: cache a placeholder instead of failing
                    _agreements_cache[agreement_type] = {
                        'title': titles.get(agreement_type, agreement_type),
                        'content': f"协议内容正在加载中,请稍后再试。(文件:{filename})",
                        'last_updated': None,
                        'version': '1.0',
                        'file_path': filename,
                        'error': str(e)
                    }
            else:
                print(f"File not found: {filename}")
                # Missing file: cache a placeholder instead of failing
                _agreements_cache[agreement_type] = {
                    'title': titles.get(agreement_type, agreement_type),
                    'content': f"协议文件未找到,请联系管理员。(文件:{filename})",
                    'last_updated': None,
                    'version': '1.0',
                    'file_path': filename,
                    'error': 'File not found'
                }

        _cache_loaded = True
        print(f"Agreements cache loaded successfully. Total: {len(_agreements_cache)} agreements")

    except Exception as e:
        print(f"Error loading agreements: {str(e)}")
        _cache_loaded = False

    return _agreements_cache


@app.route('/api/agreements', methods=['GET'])
def api_agreements():
    """Serve platform agreements read from .docx files.

    Query args:
        type: ``about_us`` / ``service_terms`` / ``privacy_policy``. When it
            names a cached agreement, only that agreement is returned;
            otherwise all agreements are returned.
        reload: ``'true'`` clears the cache and re-reads the files.
    """
    # Declared at the top: this function both rebinds and reads the flag.
    global _cache_loaded
    try:
        agreement_type = request.args.get('type')
        force_reload = request.args.get('reload', 'false').lower() == 'true'

        if force_reload:
            _cache_loaded = False
            _agreements_cache.clear()

        agreements_data = load_agreements_from_docx()

        if not agreements_data:
            return jsonify({
                'code': 500,
                'message': 'Failed to load agreements',
                'data': None
            }), 500

        if agreement_type and agreement_type in agreements_data:
            return jsonify({
                'code': 200,
                'message': 'success',
                'data': {
                    'agreement_type': agreement_type,
                    **agreements_data[agreement_type]
                }
            })

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'agreements': agreements_data,
                'available_types': list(agreements_data.keys()),
                'cache_loaded': _cache_loaded,
                'total_agreements': len(agreements_data)
            }
        })

    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500


# 20. Personal center - "my activity" endpoint
_ACTIVITY_TYPES = ('follows', 'likes', 'comments', 'commented')


@app.route('/api/user/activities', methods=['GET'])
@token_required
def api_user_activities():
    """List the authenticated user's activity, paginated.

    Query args:
        type: ``follows`` (events the user follows, default), ``likes``
            (posts the user liked), ``comments`` (the user's own comments)
            or ``commented`` (others' comments on the user's posts). Any
            other value returns 400.
        page: page number, minimum 1.
        per_page: page size, clamped to 1..50 (default 20).

    Unexpected failures are logged with traceback and return a generic 500.
    """
    try:
        user = request.user
        activity_type = request.args.get('type', 'follows')
        page = max(1, request.args.get('page', 1, type=int))
        per_page = max(1, min(50, request.args.get('per_page', 20, type=int)))

        # Reject unknown types up front; otherwise no branch below runs and
        # activities/total/pages would be unbound.
        if activity_type not in _ACTIVITY_TYPES:
            return jsonify({
                'code': 400,
                'message': '不支持的活动类型,可选值: follows, likes, comments, commented',
                'data': None
            }), 400

        if activity_type == 'follows':
            follows = EventFollow.query.filter_by(user_id=user.id) \
                .order_by(EventFollow.created_at.desc()) \
                .paginate(page=page, per_page=per_page, error_out=False)

            activities = []
            for follow in follows.items:
                event = follow.event

                # Up to 5 related stocks, each with its latest daily change
                related_stocks_data = []
                for stock in event.related_stocks.limit(5):
                    # Strip an exchange suffix such as .SZ / .SH
                    base_stock_code = stock.stock_code.split('.')[0]

                    trade_data = TradeData.query.filter(
                        TradeData.SECCODE.startswith(base_stock_code)
                    ).order_by(TradeData.TRADEDATE.desc()).first()

                    daily_change = None
                    if trade_data and trade_data.F010N is not None:
                        daily_change = float(trade_data.F010N)

                    related_stocks_data.append({
                        'stock_code': stock.stock_code,
                        'stock_name': stock.stock_name,
                        'correlation': stock.correlation,
                        'daily_change': daily_change,
                        'daily_change_formatted': f"{daily_change:.2f}%" if daily_change is not None else "暂无数据"
                    })

                activities.append({
                    'event_id': follow.event_id,
                    'event_title': event.title,
                    'event_description': event.description,
                    'follow_time': follow.created_at.isoformat() if follow.created_at else None,
                    'event_hot_score': event.hot_score,
                    'importance': event.importance,
                    'related_avg_chg': event.related_avg_chg,  # average change
                    'related_max_chg': event.related_max_chg,  # max change
                    'related_week_chg': event.related_week_chg,  # weekly change
                    'related_stocks': related_stocks_data,
                    'created_at': event.created_at.isoformat() if event.created_at else None,
                    # Description preview, at most 200 characters
                    'preview': event.description[:200] if event.description else None,
                    'comment_count': event.post_count,  # taken from the event's post_count
                    'view_count': event.view_count,  # view count
                    'follower_count': event.follower_count
                })

            total = follows.total
            pages = follows.pages

        elif activity_type == 'likes':
            likes = PostLike.query.filter_by(user_id=user.id) \
                .order_by(PostLike.created_at.desc()) \
                .paginate(page=page, per_page=per_page, error_out=False)

            activities = [{
                'like_id': like.id,
                'post_id': like.post_id,
                'post_content': like.post.content,
                'like_time': like.created_at.isoformat() if like.created_at else None,
                'author': {
                    'nickname': like.post.user.nickname or like.post.user.username,
                    'avatar_url': get_full_avatar_url(like.post.user.avatar_url),
                }
            } for like in likes.items]

            total = likes.total
            pages = likes.pages

        elif activity_type == 'comments':
            # The user's own comments, with the post and event they belong to
            comments = Comment.query.filter_by(user_id=user.id) \
                .join(Post, Comment.post_id == Post.id) \
                .join(Event, Post.event_id == Event.id) \
                .order_by(Comment.created_at.desc()) \
                .paginate(page=page, per_page=per_page, error_out=False)

            activities = []
            for comment in comments.items:
                # comment -> post -> event
                post = comment.post
                event = post.event if post else None

                activities.append({
                    'comment_id': comment.id,
                    'post_id': comment.post_id,
                    'content': comment.content,
                    'created_at': comment.created_at.isoformat() if comment.created_at else None,
                    'post_title': post.title if post and post.title else None,
                    'post_content': post.content if post else None,
                    # The commenter (the current user)
                    'commenter': {
                        'id': comment.user.id,
                        'username': comment.user.username,
                        'nickname': comment.user.nickname or comment.user.username,
                        'avatar_url': get_full_avatar_url(comment.user.avatar_url),
                        'user_level': comment.user.user_level,
                        'is_verified': comment.user.is_verified
                    },
                    'event': {
                        'id': event.id if event else None,
                        'title': event.title if event else None,
                        'description': event.description if event else None,
                        'importance': event.importance if event else None,
                        'event_type': event.event_type if event else None,
                        'hot_score': event.hot_score if event else None,
                        'view_count': event.view_count if event else None,
                        'related_avg_chg': event.related_avg_chg if event else None,
                        'created_at': event.created_at.isoformat() if event and event.created_at else None
                    },
                    'post_author': {
                        'id': post.user.id if post else None,
                        'username': post.user.username if post else None,
                        'nickname': post.user.nickname or post.user.username if post else None,
                        'avatar_url': get_full_avatar_url(post.user.avatar_url) if post else None,
                    }
                })

            total = comments.total
            pages = comments.pages

        else:  # 'commented': other users' comments on my posts
            my_posts = Post.query.filter_by(user_id=user.id).subquery()
            comments = Comment.query.join(my_posts, Comment.post_id == my_posts.c.id) \
                .filter(Comment.user_id != user.id) \
                .order_by(Comment.created_at.desc()) \
                .paginate(page=page, per_page=per_page, error_out=False)

            activities = []
            for comment in comments.items:
                post_event = comment.post.event
                activities.append({
                    'comment_id': comment.id,
                    'comment_content': comment.content,
                    'comment_time': comment.created_at.isoformat() if comment.created_at else None,
                    'commenter_nickname': comment.user.nickname or comment.user.username,
                    'commenter_avatar': get_full_avatar_url(comment.user.avatar_url),
                    'post_content': comment.post.content,
                    # Guard: a post without an event would otherwise fail the whole page
                    'event_title': post_event.title if post_event else None,
                    'event_id': comment.post.event_id
                })

            total = comments.total
            pages = comments.pages

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'activities': activities,
                'total': total,
                'pages': pages,
                'current_page': page
            }
        })

    except Exception as e:
        logger.exception("Error in api_user_activities: %s", e)
        return jsonify({
            'code': 500,
            'message': '服务器内部错误',
            'data': None
        }), 500


# Generic error handlers
@app.errorhandler(404)
def api_not_found(error):
    """Return a JSON 404 for /api/ paths; otherwise return the original error."""
    if request.path.startswith('/api/'):
        return jsonify({
            'code': 404,
            'message': '接口不存在',
            'data': None
        }), 404
    return error


@app.errorhandler(405)
def api_method_not_allowed(error):
    """Return a JSON 405 for /api/ paths; otherwise return the original error."""
    if request.path.startswith('/api/'):
        return jsonify({
            'code': 405,
            'message': '请求方法不允许',
            'data': None
        }), 405
    return error


# Cache initialization on startup (works under Gunicorn and direct runs)
_sywg_cache_initialized = False


def ensure_sywg_cache_initialized():
    """Lazily initialize the Shenwan industry-classification cache.

    The flag is set only after ``init_sywg_industry_cache()`` returns, so a
    failed initialization is retried on the next call.
    """
    global _sywg_cache_initialized
    if not _sywg_cache_initialized:
        init_sywg_industry_cache()
        _sywg_cache_initialized = True


@app.before_request
def before_request_init():
    """Trigger the lazy cache initialization before handling a request."""
    global _sywg_cache_initialized
    if not _sywg_cache_initialized:
        ensure_sywg_cache_initialized()


# ===================== Application startup configuration =====================
# Production: run under Gunicorn (commands below).
#
# [Option 1] Gunicorn + gevent (recommended):
#   USE_GEVENT=true gunicorn -w 4 -k gevent --worker-connections 1000 \
#     -b 0.0.0.0:5002 --timeout 120 --graceful-timeout 30 \
#     --certfile=/etc/letsencrypt/live/api.valuefrontier.cn/fullchain.pem \
#     --keyfile=/etc/letsencrypt/live/api.valuefrontier.cn/privkey.pem \
#     app_vx:app
#
# [Option 2] Gunicorn with plain worker processes (no gevent):
#   gunicorn -w 4 -b 0.0.0.0:5002 --timeout 120 --graceful-timeout 30 \
#     --certfile=/etc/letsencrypt/live/api.valuefrontier.cn/fullchain.pem \
#     --keyfile=/etc/letsencrypt/live/api.valuefrontier.cn/privkey.pem \
#     app_vx:app
#
# [Option 3] Direct run, local debugging only:
#   python app_vx.py
#
# ===================== End of startup configuration =====================

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='启动 Flask 应用')
    parser.add_argument('--debug', action='store_true', help='启用调试模式(仅限开发环境)')
    parser.add_argument('--port', type=int, default=5002, help='监听端口(默认 5002)')
    parser.add_argument('--no-ssl', action='store_true', help='禁用 SSL(仅限开发环境)')
    args = parser.parse_args()

    # Direct run: initialize the cache eagerly instead of on the first request
    with app.app_context():
        init_sywg_industry_cache()
    _sywg_cache_initialized = True

    # debug=True must never be used in production
    if args.debug:
        logger.warning("⚠️ 调试模式已启用,仅限开发环境使用!")

    # Use TLS only when both certificate files exist
    ssl_context = None
    if not args.no_ssl:
        cert_file = '/etc/nginx/ssl/api.valuefrontier.cn/fullchain.pem'
        key_file = '/etc/nginx/ssl/api.valuefrontier.cn/privkey.pem'
        if os.path.exists(cert_file) and os.path.exists(key_file):
            ssl_context = (cert_file, key_file)
        else:
            logger.warning("⚠️ SSL 证书文件不存在,将使用 HTTP 模式")

    logger.info(f"🚀 启动 Flask 应用: port={args.port}, debug={args.debug}, ssl={'enabled' if ssl_context else 'disabled'}")

    app.run(
        host='0.0.0.0',
        port=args.port,
        debug=args.debug,
        ssl_context=ssl_context,
        threaded=True  # handle requests on multiple threads
    )