import csv import logging import random import re import math import os import secrets import string import pytz import requests from flask_compress import Compress from functools import wraps from pathlib import Path import time from sqlalchemy import create_engine, text, func, or_, case, event, desc, asc from flask import Flask, has_request_context, render_template, request, jsonify, redirect, url_for, flash, session, \ render_template_string, current_app, send_from_directory # Flask 3.x 兼容性补丁:flask-sqlalchemy 旧版本需要 _app_ctx_stack try: from flask import _app_ctx_stack except ImportError: import flask from werkzeug.local import LocalStack import threading # 创建一个兼容的 LocalStack 子类 class CompatLocalStack(LocalStack): @property def __ident_func__(self): # 返回当前线程的标识函数 # 优先使用 greenlet(协程),否则使用 threading try: from greenlet import getcurrent return getcurrent except ImportError: return threading.get_ident flask._app_ctx_stack = CompatLocalStack() from flask_sqlalchemy import SQLAlchemy from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user from flask_mail import Mail, Message from itsdangerous import URLSafeTimedSerializer from flask_migrate import Migrate from flask_session import Session # type: ignore from sqlalchemy.dialects.mysql.base import MySQLDialect from sqlalchemy.dialects.postgresql import JSONB from werkzeug.utils import secure_filename from PIL import Image from datetime import datetime, timedelta, time as dt_time from werkzeug.security import generate_password_hash, check_password_hash import json from clickhouse_driver import Client as Cclient import jwt from docx import Document from tencentcloud.common import credential from tencentcloud.common.profile.client_profile import ClientProfile from tencentcloud.common.profile.http_profile import HttpProfile from tencentcloud.sms.v20210111 import sms_client, models from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException engine = create_engine("mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock", echo=False, pool_size=20, max_overflow=50) engine_med = create_engine("mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/med", echo=False) engine_2 = create_engine("mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/valuefrontier", echo=False) logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) app = Flask(__name__) Compress(app) UPLOAD_FOLDER = 'static/uploads/avatars' ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'} MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB max file size # Configure Flask-Compress app.config['COMPRESS_ALGORITHM'] = ['gzip', 'br'] app.config['COMPRESS_MIMETYPES'] = [ 'text/html', 'text/css', 'text/xml', 'application/json', 'application/javascript', 'application/x-javascript' ] user_tokens = {} app.config['SECRET_KEY'] = 'vf7891574233241' app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock?charset=utf8mb4' app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False app.config['JSON_AS_ASCII'] = False app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True # 邮件配置 app.config['MAIL_SERVER'] = 'smtp.exmail.qq.com' app.config['MAIL_PORT'] = 465 app.config['MAIL_USE_SSL'] = True app.config['MAIL_USERNAME'] = 'admin@valuefrontier.cn' app.config['MAIL_PASSWORD'] = 'QYncRu6WUdASvTg4' app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER app.config['MAX_CONTENT_LENGTH'] = MAX_CONTENT_LENGTH # 腾讯云短信配置 SMS_SECRET_ID = 'AKID2we9TacdTAhCjCSYTErHVimeJo9Yr00s' SMS_SECRET_KEY = 'pMlBWijlkgT9fz5ziEXdWEnAPTJzRfkf' SMS_SDK_APP_ID = "1400972398" SMS_SIGN_NAME = "价值前沿科技" SMS_TEMPLATE_REGISTER = "2386557" # 注册模板 SMS_TEMPLATE_LOGIN = "2386540" # 登录模板 verification_codes = {} # 微信小程序 app.config['WECHAT_APP_ID'] = 'wx0edeaab76d4fa414' app.config['WECHAT_APP_SECRET'] = '0d0c70084f05a8c1411f6b89da7e815d' app.config['BASE_URL'] = 'http://43.143.189.195:5002' app.config['WECHAT_REDIRECT_URI'] = f"{app.config['BASE_URL']}/api/wechat/callback" WECHAT_APP_ID = 'wx0edeaab76d4fa414' WECHAT_APP_SECRET = '0d0c70084f05a8c1411f6b89da7e815d' JWT_SECRET_KEY = 'vfllmgreat33818!' # 请修改为安全的密钥 JWT_ALGORITHM = 'HS256' JWT_EXPIRATION_HOURS = 24 * 7 # Token有效期7天 # Session 配置 - 使用文件系统存储(替代 Redis) app.config['SESSION_TYPE'] = 'filesystem' app.config['SESSION_FILE_DIR'] = os.path.join(os.path.dirname(__file__), 'flask_session') app.config['SESSION_PERMANENT'] = True app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(days=7) # Session 有效期 7 天 # 确保 session 目录存在 os.makedirs(app.config['SESSION_FILE_DIR'], exist_ok=True) # Cache directory setup CACHE_DIR = Path('cache') CACHE_DIR.mkdir(exist_ok=True) # Memory management constants MAX_MEMORY_PERCENT = 75 MEMORY_CHECK_INTERVAL = 300 MAX_CACHE_ITEMS = 50 # 申银万国行业分类缓存(启动时初始化,避免每次请求都查询数据库) # 结构: {industry_level: {industry_name: [code_prefix1, code_prefix2, ...]}} SYWG_INDUSTRY_CACHE = { 2: {}, # level2: 一级行业 3: {}, # level3: 二级行业 4: {}, # level4: 三级行业 5: {} # level5: 四级行业 } # 初始化扩展 db = SQLAlchemy(app) mail = Mail(app) login_manager = LoginManager(app) login_manager.login_view = 'login' serializer = URLSafeTimedSerializer(app.config['SECRET_KEY']) migrate = Migrate(app, db) DOMAIN = 'http://43.143.189.195:5002' JWT_SECRET = 'Llmgreat123' JWT_EXPIRES_SECONDS = 3600 # 1小时有效期 Session(app) def token_required(f): """装饰器:需要token认证的接口""" from functools import wraps @wraps(f) def decorated_function(*args, **kwargs): token = None # 从请求头获取token auth_header = request.headers.get('Authorization') if auth_header and auth_header.startswith('Bearer '): token = auth_header[7:] if not token: return jsonify({'message': '缺少认证token'}), 401 token_data = user_tokens.get(token) if not token_data: return jsonify({'message': 'Token无效', 'code': 401}), 401 # 检查是否过期 if token_data['expires'] < datetime.now(): del user_tokens[token] return jsonify({'message': 'Token已过期'}), 401 # 获取用户对象并添加到请求上下文 user = User.query.get(token_data['user_id']) if not user: return jsonify({'message': '用户不存在'}), 404 # 将用户对象添加到request request.user = user request.current_user_id = token_data['user_id'] return f(*args, **kwargs) return decorated_function def beijing_now(): # 使用 pytz 处理时区 beijing_tz = pytz.timezone('Asia/Shanghai') return datetime.now(beijing_tz) # ============================================ # 订阅功能模块(与 app.py 保持一致) # ============================================ class UserSubscription(db.Model): """用户订阅表 - 独立于现有User表""" __tablename__ = 'user_subscriptions' id = db.Column(db.Integer, primary_key=True, autoincrement=True) user_id = db.Column(db.Integer, nullable=False, unique=True, index=True) subscription_type = db.Column(db.String(10), nullable=False, default='free') subscription_status = db.Column(db.String(20), nullable=False, default='active') start_date = db.Column(db.DateTime, nullable=True) end_date = db.Column(db.DateTime, nullable=True) billing_cycle = db.Column(db.String(10), nullable=True) auto_renewal = db.Column(db.Boolean, nullable=False, default=False) created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) def is_active(self): if self.subscription_status != 'active': return False if self.subscription_type == 'free': return True if self.end_date: # 将数据库的 naive datetime 转换为带时区的 aware datetime beijing_tz = pytz.timezone('Asia/Shanghai') end_date_aware = self.end_date if self.end_date.tzinfo else beijing_tz.localize(self.end_date) return beijing_now() <= end_date_aware return True def days_left(self): if not self.is_active(): return 0 if self.subscription_type == 'free': return 999 if not self.end_date: return 999 try: now = beijing_now() # 将数据库的 naive datetime 转换为带时区的 aware datetime beijing_tz = pytz.timezone('Asia/Shanghai') end_date_aware = self.end_date if self.end_date.tzinfo else beijing_tz.localize(self.end_date) delta = end_date_aware - now return max(0, delta.days) except Exception as e: return 0 def to_dict(self): return { 'type': self.subscription_type, 'status': self.subscription_status, 'is_active': self.is_active(), 'start_date': self.start_date.isoformat() if self.start_date else None, 'end_date': self.end_date.isoformat() if self.end_date else None, 'days_left': self.days_left(), 'billing_cycle': self.billing_cycle, 'auto_renewal': self.auto_renewal } # ============================================ # 订阅等级工具函数 # ============================================ def get_user_subscription_safe(user_id): """ 安全地获取用户订阅信息 :param user_id: 用户ID :return: UserSubscription 对象或默认免费订阅 """ try: subscription = UserSubscription.query.filter_by(user_id=user_id).first() if not subscription: # 如果用户没有订阅记录,创建默认免费订阅 subscription = UserSubscription( user_id=user_id, subscription_type='free', subscription_status='active' ) db.session.add(subscription) db.session.commit() return subscription except Exception as e: print(f"获取用户订阅信息失败: {e}") # 返回一个临时的免费订阅对象(不保存到数据库) temp_sub = UserSubscription( user_id=user_id, subscription_type='free', subscription_status='active' ) return temp_sub def _get_current_subscription_info(): """ 获取当前登录用户订阅信息的字典形式,未登录或异常时视为免费用户。 小程序场景下从 request.current_user_id 获取用户ID """ try: user_id = getattr(request, 'current_user_id', None) if not user_id: return { 'type': 'free', 'status': 'active', 'is_active': True } sub = get_user_subscription_safe(user_id) return { 'type': sub.subscription_type, 'status': sub.subscription_status, 'is_active': sub.is_active(), 'start_date': sub.start_date.isoformat() if sub.start_date else None, 'end_date': sub.end_date.isoformat() if sub.end_date else None, 'days_left': sub.days_left() } except Exception as e: print(f"获取订阅信息异常: {e}") return { 'type': 'free', 'status': 'active', 'is_active': True } def _subscription_level(sub_type): """将订阅类型映射到等级数值,free=0, pro=1, max=2。""" mapping = {'free': 0, 'pro': 1, 'max': 2} return mapping.get((sub_type or 'free').lower(), 0) def _has_required_level(required: str) -> bool: """判断当前用户是否达到所需订阅级别。""" info = _get_current_subscription_info() if not info.get('is_active', True): return False return _subscription_level(info.get('type')) >= _subscription_level(required) # ============================================ # 权限装饰器 # ============================================ def subscription_required(level='pro'): """ 订阅等级装饰器 - 小程序专用 用法: @subscription_required('pro') # 需要 Pro 或 Max 用户 @subscription_required('max') # 仅限 Max 用户 注意:此装饰器需要配合 使用 """ from functools import wraps def decorator(f): @wraps(f) def decorated_function(*args, **kwargs): if not _has_required_level(level): current_info = _get_current_subscription_info() current_type = current_info.get('type', 'free') is_active = current_info.get('is_active', False) if not is_active: return jsonify({ 'success': False, 'error': '您的订阅已过期,请续费后继续使用', 'error_code': 'SUBSCRIPTION_EXPIRED', 'current_subscription': current_type, 'required_subscription': level }), 403 return jsonify({ 'success': False, 'error': f'此功能需要 {level.upper()} 或更高等级会员', 'error_code': 'SUBSCRIPTION_REQUIRED', 'current_subscription': current_type, 'required_subscription': level }), 403 return f(*args, **kwargs) return decorated_function return decorator def pro_or_max_required(f): """ 快捷装饰器:要求 Pro 或 Max 用户(小程序专用场景) 等同于 @subscription_required('pro') """ from functools import wraps @wraps(f) def decorated_function(*args, **kwargs): if not _has_required_level('pro'): current_info = _get_current_subscription_info() current_type = current_info.get('type', 'free') return jsonify({ 'success': False, 'error': '小程序功能仅对 Pro 和 Max 会员开放', 'error_code': 'MINIPROGRAM_PRO_REQUIRED', 'current_subscription': current_type, 'required_subscription': 'pro', 'message': '请升级到 Pro 或 Max 会员以使用小程序完整功能' }), 403 return f(*args, **kwargs) return decorated_function class User(UserMixin, db.Model): """用户模型""" id = db.Column(db.Integer, primary_key=True) # 基础账号信息(注册时必填) username = db.Column(db.String(80), unique=True, nullable=False) # 用户名 email = db.Column(db.String(120), unique=True, nullable=False) # 邮箱 password_hash = db.Column(db.String(128), nullable=False) # 密码哈希 email_confirmed = db.Column(db.Boolean, default=False) # 邮箱是否验证 wechat_union_id = db.Column(db.String(100), unique=True) # 微信 UnionID wechat_open_id = db.Column(db.String(100)) # 微信 OpenID # 账号状态 created_at = db.Column(db.DateTime, default=beijing_now) # 注册时间 last_seen = db.Column(db.DateTime, default=beijing_now) # 最后活跃时间 status = db.Column(db.String(20), default='active') # 账号状态 active/banned/deleted # 个人资料(可选,后续在个人中心完善) nickname = db.Column(db.String(30)) # 社区昵称 avatar_url = db.Column(db.String(200)) # 头像URL banner_url = db.Column(db.String(200)) # 个人主页背景图 bio = db.Column(db.String(200)) # 个人简介 gender = db.Column(db.String(10)) # 性别 birth_date = db.Column(db.Date) # 生日 location = db.Column(db.String(100)) # 所在地 # 联系方式(可选) phone = db.Column(db.String(20)) # 手机号 wechat_id = db.Column(db.String(80)) # 微信号 # 实名认证信息(可选) real_name = db.Column(db.String(30)) # 真实姓名 id_number = db.Column(db.String(18)) # 身份证号(加密存储) is_verified = db.Column(db.Boolean, default=False) # 是否实名认证 verify_time = db.Column(db.DateTime) # 实名认证时间 # 投资相关信息(可选) trading_experience = db.Column(db.String(200)) # 炒股年限 investment_style = db.Column(db.String(50)) # 投资风格 risk_preference = db.Column(db.String(20)) # 风险偏好 investment_amount = db.Column(db.String(20)) # 投资规模 preferred_markets = db.Column(db.String(200), default='[]') # 偏好市场 JSON # 社区信息(系统自动更新) user_level = db.Column(db.Integer, default=1) # 用户等级 reputation_score = db.Column(db.Integer, default=0) # 信用积分 contribution_point = db.Column(db.Integer, default=0) # 贡献点数 post_count = db.Column(db.Integer, default=0) # 发帖数 comment_count = db.Column(db.Integer, default=0) # 评论数 follower_count = db.Column(db.Integer, default=0) # 粉丝数 following_count = db.Column(db.Integer, default=0) # 关注数 # 创作者信息(可选) is_creator = db.Column(db.Boolean, default=False) # 是否创作者 creator_type = db.Column(db.String(20)) # 创作者类型 creator_tags = db.Column(db.String(200), default='[]') # 创作者标签 JSON # 系统设置 email_notifications = db.Column(db.Boolean, default=True) # 邮件通知 sms_notifications = db.Column(db.Boolean, default=False) # 短信通知 wechat_notifications = db.Column(db.Boolean, default=False) # 微信通知 notification_preferences = db.Column(db.String(500), default='{}') # 通知偏好 JSON privacy_level = db.Column(db.String(20), default='public') # 隐私级别 theme_preference = db.Column(db.String(20), default='light') # 主题偏好 blocked_keywords = db.Column(db.String(500), default='[]') # 屏蔽关键词 JSON # 手机号验证 phone_confirmed = db.Column(db.Boolean, default=False) # 手机是否验证 phone_confirm_time = db.Column(db.DateTime) # 手机验证时间 def __init__(self, username, email=None, password=None, phone=None): """初始化用户,只需要基本信息""" self.username = username if email: self.email = email if phone: self.phone = phone if password: self.set_password(password) self.created_at = beijing_now() self.last_seen = beijing_now() def set_password(self, password): """设置密码""" self.password_hash = generate_password_hash(password) def check_password(self, password): """验证密码""" return check_password_hash(self.password_hash, password) def update_last_seen(self): """更新最后活跃时间""" self.last_seen = beijing_now() # JSON 字段的getter和setter def get_preferred_markets(self): """获取偏好市场列表""" if self.preferred_markets: try: return json.loads(self.preferred_markets) except: return [] return [] def get_blocked_keywords(self): """获取屏蔽关键词列表""" if self.blocked_keywords: try: return json.loads(self.blocked_keywords) except: return [] return [] def get_notification_preferences(self): """获取通知偏好设置""" if self.notification_preferences: try: return json.loads(self.notification_preferences) except: return {} return {} def get_creator_tags(self): """获取创作者标签""" if self.creator_tags: try: return json.loads(self.creator_tags) except: return [] return [] def set_preferred_markets(self, markets): """设置偏好市场""" self.preferred_markets = json.dumps(markets) def set_blocked_keywords(self, keywords): """设置屏蔽关键词""" self.blocked_keywords = json.dumps(keywords) def set_notification_preferences(self, preferences): """设置通知偏好""" self.notification_preferences = json.dumps(preferences) def set_creator_tags(self, tags): """设置创作者标签""" self.creator_tags = json.dumps(tags) def to_dict(self): """返回用户的字典表示""" return { 'id': self.id, 'username': self.username, 'email': self.email, 'nickname': self.nickname, 'avatar_url': get_full_avatar_url(self.avatar_url), # 修改这里 'bio': self.bio, 'is_verified': self.is_verified, 'user_level': self.user_level, 'reputation_score': self.reputation_score, 'is_creator': self.is_creator } def __repr__(self): return f'' class Notification(db.Model): """通知模型""" id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id')) type = db.Column(db.String(50)) # 通知类型 content = db.Column(db.Text) # 通知内容 link = db.Column(db.String(200)) # 相关链接 is_read = db.Column(db.Boolean, default=False) # 是否已读 created_at = db.Column(db.DateTime, default=beijing_now) def __init__(self, user_id, type, content, link=None): self.user_id = user_id self.type = type self.content = content self.link = link class Event(db.Model): """事件模型""" id = db.Column(db.Integer, primary_key=True) title = db.Column(db.String(200), nullable=False) description = db.Column(db.Text) # 事件类型与状态 event_type = db.Column(db.String(50)) status = db.Column(db.String(20), default='active') # 时间相关 start_time = db.Column(db.DateTime, default=beijing_now) end_time = db.Column(db.DateTime) created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now) # 热度与统计 hot_score = db.Column(db.Float, default=0) view_count = db.Column(db.Integer, default=0) trending_score = db.Column(db.Float, default=0) post_count = db.Column(db.Integer, default=0) follower_count = db.Column(db.Integer, default=0) # 关联信息 related_industries = db.Column(db.JSON) keywords = db.Column(db.JSON) files = db.Column(db.JSON) importance = db.Column(db.String(20)) related_avg_chg = db.Column(db.Float, default=0) related_max_chg = db.Column(db.Float, default=0) related_week_chg = db.Column(db.Float, default=0) # 新增字段 invest_score = db.Column(db.Integer) # 超预期得分 expectation_surprise_score = db.Column(db.Integer) # 创建者信息 creator_id = db.Column(db.Integer, db.ForeignKey('user.id')) creator = db.relationship('User', backref='created_events') # 关系 posts = db.relationship('Post', backref='event', lazy='dynamic') followers = db.relationship('EventFollow', backref='event', lazy='dynamic') related_stocks = db.relationship('RelatedStock', backref='event', lazy='dynamic') historical_events = db.relationship('HistoricalEvent', backref='event', lazy='dynamic') related_data = db.relationship('RelatedData', backref='event', lazy='dynamic') related_concepts = db.relationship('RelatedConcepts', backref='event', lazy='dynamic') ind_type = db.Column(db.String(255)) @property def keywords_list(self): """返回解析后的关键词列表""" if not self.keywords: return [] if isinstance(self.keywords, list): return self.keywords try: # 如果是字符串,尝试解析JSON if isinstance(self.keywords, str): decoded = json.loads(self.keywords) # 处理Unicode编码的情况 if isinstance(decoded, list): return [ keyword.encode('utf-8').decode('unicode_escape') if isinstance(keyword, str) and '\\u' in keyword else keyword for keyword in decoded ] return [] # 如果已经是字典或其他格式,尝试转换为列表 return list(self.keywords) except (json.JSONDecodeError, AttributeError, TypeError): return [] def set_keywords(self, keywords): """设置关键词列表""" if isinstance(keywords, list): self.keywords = json.dumps(keywords, ensure_ascii=False) elif isinstance(keywords, str): try: # 尝试解析JSON字符串 parsed = json.loads(keywords) if isinstance(parsed, list): self.keywords = json.dumps(parsed, ensure_ascii=False) else: self.keywords = json.dumps([keywords], ensure_ascii=False) except json.JSONDecodeError: # 如果不是有效的JSON,将其作为单个关键词 self.keywords = json.dumps([keywords], ensure_ascii=False) class RelatedStock(db.Model): """相关标的模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) stock_code = db.Column(db.String(20)) # 股票代码 stock_name = db.Column(db.String(100)) # 股票名称 sector = db.Column(db.String(100)) # 关联类型 relation_desc = db.Column(db.String(1024)) # 关联原因描述 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) correlation = db.Column(db.Float()) momentum = db.Column(db.String(1024)) # 动量 # 新增字段 retrieved_sources = db.Column(db.JSON) # 研报检索源数据 retrieved_update_time = db.Column(db.DateTime) # 检索数据更新时间 class RelatedData(db.Model): """关联数据模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) title = db.Column(db.String(200)) # 数据标题 data_type = db.Column(db.String(50)) # 数据类型 data_content = db.Column(db.JSON) # 数据内容(JSON格式) description = db.Column(db.Text) # 数据描述 created_at = db.Column(db.DateTime, default=beijing_now) class RelatedConcepts(db.Model): """关联数据模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) concept_code = db.Column(db.String(20)) # 数据标题 concept = db.Column(db.String(100)) # 数据类型 reason = db.Column(db.Text) # 数据描述 image_paths = db.Column(db.JSON) # 数据内容(JSON格式) created_at = db.Column(db.DateTime, default=beijing_now) @property def image_paths_list(self): """返回解析后的图片路径列表""" if not self.image_paths: return [] try: # 如果是字符串,先解析成JSON if isinstance(self.image_paths, str): paths = json.loads(self.image_paths) else: paths = self.image_paths # 确保paths是列表 if not isinstance(paths, list): paths = [paths] # 从每个对象中提取path字段 return [item['path'] if isinstance(item, dict) and 'path' in item else item for item in paths] except Exception as e: print(f"Error processing image paths: {e}") return [] def get_first_image_path(self): """获取第一张图片的完整路径""" paths = self.image_paths_list if not paths: return None # 获取第一个路径 first_path = paths[0] # 返回完整路径 return first_path class EventHotHistory(db.Model): """事件热度历史记录""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) score = db.Column(db.Float) # 总分 interaction_score = db.Column(db.Float) # 互动分数 follow_score = db.Column(db.Float) # 关注度分数 view_score = db.Column(db.Float) # 浏览量分数 recent_activity_score = db.Column(db.Float) # 最近活跃度分数 time_decay = db.Column(db.Float) # 时间衰减因子 created_at = db.Column(db.DateTime, default=beijing_now) event = db.relationship('Event', backref='hot_history') class Post(db.Model): """帖子模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 内容 title = db.Column(db.String(200)) # 标题(可选) content = db.Column(db.Text, nullable=False) # 内容 content_type = db.Column(db.String(20), default='text') # 内容类型:text/rich_text/link # 时间 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 统计 likes_count = db.Column(db.Integer, default=0) comments_count = db.Column(db.Integer, default=0) view_count = db.Column(db.Integer, default=0) # 状态 status = db.Column(db.String(20), default='active') # active/hidden/deleted is_top = db.Column(db.Boolean, default=False) # 是否置顶 # 关系 user = db.relationship('User', backref='posts') likes = db.relationship('PostLike', backref='post', lazy='dynamic') comments = db.relationship('Comment', backref='post', lazy='dynamic') # 辅助模型 class EventFollow(db.Model): """事件关注""" id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) created_at = db.Column(db.DateTime, default=beijing_now) user = db.relationship('User', backref='event_follows') __table_args__ = (db.UniqueConstraint('user_id', 'event_id'),) class PostLike(db.Model): """帖子点赞""" id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) created_at = db.Column(db.DateTime, default=beijing_now) user = db.relationship('User', backref='post_likes') __table_args__ = (db.UniqueConstraint('user_id', 'post_id'),) class Comment(db.Model): """评论""" id = db.Column(db.Integer, primary_key=True) post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) content = db.Column(db.Text, nullable=False) parent_id = db.Column(db.Integer, db.ForeignKey('comment.id')) # 父评论ID,用于回复 created_at = db.Column(db.DateTime, default=beijing_now) status = db.Column(db.String(20), default='active') user = db.relationship('User', backref='comments') replies = db.relationship('Comment', backref=db.backref('parent', remote_side=[id])) class StockBasicInfo(db.Model): __tablename__ = 'ea_stocklist' SECCODE = db.Column(db.String(10), primary_key=True) SECNAME = db.Column(db.String(40)) ORGNAME = db.Column(db.String(100)) F001V = db.Column(db.String(100)) # Pinyin abbreviation F003V = db.Column(db.String(50)) # Security category F005V = db.Column(db.String(50)) # Trading market F006D = db.Column(db.DateTime) # Listing date F011V = db.Column(db.String(50)) # Listing status class CompanyInfo(db.Model): __tablename__ = 'ea_baseinfo' SECCODE = db.Column(db.String(10), primary_key=True) SECNAME = db.Column(db.String(40)) ORGNAME = db.Column(db.String(100)) F001V = db.Column(db.String(100)) # English name F003V = db.Column(db.String(40)) # Legal representative F015V = db.Column(db.String(500)) # Main business F016V = db.Column(db.String(4000)) # Business scope F017V = db.Column(db.String(2000)) # Company introduction F030V = db.Column(db.String(60)) # CSRC industry first level F032V = db.Column(db.String(60)) # CSRC industry second level class TradeData(db.Model): __tablename__ = 'ea_trade' SECCODE = db.Column(db.String(10), primary_key=True) SECNAME = db.Column(db.String(40)) TRADEDATE = db.Column(db.Date, primary_key=True) F002N = db.Column(db.Numeric(18, 4)) # Previous close F003N = db.Column(db.Numeric(18, 4)) # Open price F004N = db.Column(db.Numeric(18, 4)) # Trading volume F005N = db.Column(db.Numeric(18, 4)) # High price F006N = db.Column(db.Numeric(18, 4)) # Low price F007N = db.Column(db.Numeric(18, 4)) # Close price F009N = db.Column(db.Numeric(18, 4)) # Change F010N = db.Column(db.Numeric(18, 4)) # Change percentage F011N = db.Column(db.Numeric(18, 4)) # Trading amount class SectorInfo(db.Model): __tablename__ = 'ea_sector' SECCODE = db.Column(db.String(10), primary_key=True) SECNAME = db.Column(db.String(40)) F001V = db.Column(db.String(50), primary_key=True) # Classification standard code F002V = db.Column(db.String(50)) # Classification standard F003V = db.Column(db.String(50)) # Sector code F004V = db.Column(db.String(50)) # Sector level 1 name F005V = db.Column(db.String(50)) # Sector level 2 name F006V = db.Column(db.String(50)) # Sector level 3 name F007V = db.Column(db.String(50)) # Sector level 4 name def init_sywg_industry_cache(): """ 初始化申银万国行业分类缓存 在程序启动时调用,将所有行业分类数据加载到内存中 """ global SYWG_INDUSTRY_CACHE try: app.logger.info('开始初始化申银万国行业分类缓存...') # 定义层级映射关系 level_column_map = { 2: 'f004v', # level2 对应一级行业 3: 'f005v', # level3 对应二级行业 4: 'f006v', # level4 对应三级行业 5: 'f007v' # level5 对应四级行业 } # 定义代码前缀长度映射 prefix_length_map = { 2: 3, # S + 2位 3: 5, # S + 2位 + 2位 4: 7, # S + 2位 + 2位 + 2位 5: 9 # 完整代码 } # 遍历所有层级 for level, column_name in level_column_map.items(): # 查询该层级的所有行业及其代码 query_sql = f""" SELECT DISTINCT {column_name} as industry_name, f003v as code FROM ea_sector WHERE f002v = '申银万国行业分类' AND {column_name} IS NOT NULL AND {column_name} != '' """ result = db.session.execute(text(query_sql)) rows = result.fetchall() # 构建该层级的缓存 industry_dict = {} for row in rows: industry_name = row[0] code = row[1] if industry_name and code: # 获取代码前缀 prefix_length = prefix_length_map[level] code_prefix = code[:prefix_length] # 将前缀添加到对应行业的列表中 if industry_name not in industry_dict: industry_dict[industry_name] = set() industry_dict[industry_name].add(code_prefix) # 将set转换为list并存储到缓存中 for industry_name, prefixes in industry_dict.items(): SYWG_INDUSTRY_CACHE[level][industry_name] = list(prefixes) app.logger.info(f'Level {level} 缓存完成,共 {len(industry_dict)} 个行业') # 统计总数 total_count = sum(len(industries) for industries in SYWG_INDUSTRY_CACHE.values()) app.logger.info(f'申银万国行业分类缓存初始化完成,共缓存 {total_count} 个行业分类') except Exception as e: app.logger.error(f'初始化申银万国行业分类缓存失败: {str(e)}') import traceback app.logger.error(traceback.format_exc()) def send_async_email(msg): """异步发送邮件""" try: mail.send(msg) except Exception as e: app.logger.error(f"Error sending async email: {str(e)}") def verify_sms_code(phone_number, code): """验证短信验证码""" stored_code = session.get('sms_verification_code') stored_phone = session.get('sms_verification_phone') expiration = session.get('sms_verification_expiration') if not all([stored_code, stored_phone, expiration]): return False, "请先获取验证码" if stored_phone != phone_number: return False, "手机号与验证码不匹配" if beijing_now().timestamp() > expiration: return False, "验证码已过期" if code != stored_code: return False, "验证码错误" return True, "验证成功" def allowed_file(filename): return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS # ============================================ # 订阅相关 API 接口(小程序专用) # ============================================ @app.route('/api/subscription/info', methods=['GET']) @token_required def get_subscription_info(): """ 获取当前用户的订阅信息 - 小程序专用接口 返回用户当前订阅类型、状态、剩余天数等信息 """ try: info = _get_current_subscription_info() return jsonify({ 'success': True, 'data': info }) except Exception as e: print(f"获取订阅信息错误: {e}") return jsonify({ 'success': True, 'data': { 'type': 'free', 'status': 'active', 'is_active': True, 'days_left': 0 } }) @app.route('/api/subscription/check', methods=['GET']) @token_required def check_subscription_access(): """ 检查当前用户是否有权限使用小程序功能 返回:是否为 Pro/Max 用户 """ try: has_access = _has_required_level('pro') info = _get_current_subscription_info() return jsonify({ 'success': True, 'data': { 'has_access': has_access, 'subscription_type': info.get('type', 'free'), 'is_active': info.get('is_active', False), 'message': '您可以使用小程序功能' if has_access else '小程序功能仅对 Pro 和 Max 会员开放' } }) except Exception as e: print(f"检查订阅权限错误: {e}") return jsonify({ 'success': False, 'error': str(e) }), 500 # ============================================ # 现有接口示例(应用权限控制) # ============================================ # 更新视图函数 @app.route('/settings/profile', methods=['POST']) @token_required def update_profile(): """更新个人资料""" try: user = request.user form = request.form # 基本信息更新 user.nickname = form.get('nickname') user.bio = form.get('bio') user.gender = form.get('gender') user.birth_date = datetime.strptime(form.get('birth_date'), '%Y-%m-%d') if form.get('birth_date') else None user.phone = form.get('phone') user.location = form.get('location') user.wechat_id = form.get('wechat_id') # 处理头像上传 if 'avatar' in request.files: file = request.files['avatar'] if file and allowed_file(file.filename): # 生成安全的文件名 filename = secure_filename(f"{user.id}_{int(datetime.now().timestamp())}_{file.filename}") filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename) # 确保上传目录存在 os.makedirs(os.path.dirname(filepath), exist_ok=True) # 保存并处理图片 image = Image.open(file) image.thumbnail((300, 300)) # 调整图片大小 image.save(filepath) # 更新用户头像URL user.avatar_url = f'{DOMAIN}/static/uploads/avatars/{filename}' db.session.commit() return jsonify({'success': True, 'message': '个人资料已更新'}) except Exception as e: db.session.rollback() app.logger.error(f"Error updating profile: {str(e)}") return jsonify({'success': False, 'message': '更新失败,请重试'}) # 投资偏好设置 @app.route('/settings/investment_preferences', methods=['POST']) @token_required def update_investment_preferences(): """更新投资偏好""" try: user = request.user form = request.form user.trading_experience = form.get('trading_experience') user.investment_style = form.get('investment_style') user.risk_preference = form.get('risk_preference') user.investment_amount = form.get('investment_amount') user.preferred_markets = json.dumps(request.form.getlist('preferred_markets')) db.session.commit() return jsonify({'success': True, 'message': '投资偏好已更新'}) except Exception as e: db.session.rollback() app.logger.error(f"Error updating investment preferences: {str(e)}") return jsonify({'success': False, 'message': '更新失败,请重试'}) def get_clickhouse_client(): return Cclient( host='222.128.1.157', port=18000, user='default', password='Zzl33818!', database='stock' ) @app.route('/api/stock//kline') def get_stock_kline(stock_code): """获取股票K线数据 - 仅限 Pro/Max 会员(小程序功能)""" chart_type = request.args.get('chart_type', 'daily') # 默认改为daily event_time = request.args.get('event_time') try: event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now() except ValueError: return jsonify({'error': 'Invalid event_time format'}), 400 # 获取股票名称 try: with engine.connect() as conn: result = conn.execute(text( "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" ), {"code": stock_code.split('.')[0]}).fetchone() stock_name = result[0] if result else 'Unknown' except Exception as e: print(f"Error getting stock name: {e}") stock_name = 'Unknown' if chart_type == 'daily': return get_daily_kline(stock_code, event_datetime, stock_name) elif chart_type == 'minute': return get_minute_kline(stock_code, event_datetime, stock_name) else: return jsonify({ 'error': 'Invalid chart type', 'message': 'Supported types: daily, minute', 'code': stock_code, 'name': stock_name }), 400 def get_daily_kline(stock_code, event_datetime, stock_name): """处理日K线数据""" stock_code = stock_code.split('.')[0] print(f"Debug: stock_code={stock_code}, event_datetime={event_datetime}, stock_name={stock_name}") try: with engine.connect() as conn: # 获取事件日期前后的数据 kline_sql = """ WITH date_range AS (SELECT TRADEDATE \ FROM ea_trade \ WHERE SECCODE = :stock_code \ AND TRADEDATE BETWEEN DATE_SUB(:trade_date, INTERVAL 60 DAY) \ AND :trade_date \ GROUP BY TRADEDATE \ ORDER BY TRADEDATE) SELECT t.TRADEDATE, CAST(t.F003N AS FLOAT) as open, CAST(t.F007N AS FLOAT) as close, CAST(t.F005N AS FLOAT) as high, CAST(t.F006N AS FLOAT) as low, CAST(t.F004N AS FLOAT) as volume FROM ea_trade t JOIN date_range d \ ON t.TRADEDATE = d.TRADEDATE WHERE t.SECCODE = :stock_code ORDER BY t.TRADEDATE \ """ result = conn.execute(text(kline_sql), { "stock_code": stock_code, "trade_date": event_datetime.date() }).fetchall() print(f"Debug: Query result count: {len(result)}") if not result: print("Debug: No data found, trying fallback query...") # 如果没有数据,尝试获取最近的交易数据 fallback_sql = """ SELECT TRADEDATE, CAST(F003N AS FLOAT) as open, CAST(F007N AS FLOAT) as close, CAST(F005N AS FLOAT) as high, CAST(F006N AS FLOAT) as low, CAST(F004N AS FLOAT) as volume FROM ea_trade WHERE SECCODE = :stock_code AND TRADEDATE <= :trade_date AND F003N IS NOT NULL AND F007N IS NOT NULL AND F005N IS NOT NULL AND F006N IS NOT NULL AND F004N IS NOT NULL ORDER BY TRADEDATE LIMIT 100 \ """ result = conn.execute(text(fallback_sql), { "stock_code": stock_code, "trade_date": event_datetime.date() }).fetchall() print(f"Debug: Fallback query result count: {len(result)}") if not result: return jsonify({ 'error': 'No data available', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'daily' }) kline_data = [] for row in result: try: kline_data.append({ 'time': row.TRADEDATE.strftime('%Y-%m-%d'), 'open': float(row.open) if row.open else 0, 'high': float(row.high) if row.high else 0, 'low': float(row.low) if row.low else 0, 'close': float(row.close) if row.close else 0, 'volume': float(row.volume) if row.volume else 0 }) except (ValueError, TypeError) as e: print(f"Debug: Error processing row: {e}") continue print(f"Debug: Final kline_data count: {len(kline_data)}") return jsonify({ 'code': stock_code, 'name': stock_name, 'data': kline_data, 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'daily', 'is_history': True, 'data_count': len(kline_data) }) except Exception as e: print(f"Error in get_daily_kline: {e}") return jsonify({ 'error': f'Database error: {str(e)}', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'daily' }), 500 def get_minute_kline(stock_code, event_datetime, stock_name): """处理分钟K线数据 - 包含零轴(昨日收盘价)""" client = get_clickhouse_client() stock_code_short = stock_code.split('.')[0] # 获取不带后缀的股票代码 def get_trading_days(): trading_days = set() with open('tdays.csv', 'r') as f: reader = csv.DictReader(f) for row in reader: trading_days.add(datetime.strptime(row['DateTime'], '%Y/%m/%d').date()) return trading_days trading_days = get_trading_days() def find_next_trading_day(current_date): """找到下一个交易日""" while current_date <= max(trading_days): current_date += timedelta(days=1) if current_date in trading_days: return current_date return None def find_prev_trading_day(current_date): """找到前一个交易日""" while current_date >= min(trading_days): current_date -= timedelta(days=1) if current_date in trading_days: return current_date return None def get_prev_close(stock_code_short, target_date): """获取前一交易日的收盘价作为零轴基准""" prev_date = find_prev_trading_day(target_date) if not prev_date: return None try: with engine.connect() as conn: # 查询前一交易日的收盘价 sql = """ SELECT CAST(F007N AS FLOAT) as close FROM ea_trade WHERE SECCODE = :stock_code AND TRADEDATE = :prev_date AND F007N IS NOT NULL LIMIT 1 \ """ result = conn.execute(text(sql), { "stock_code": stock_code_short, "prev_date": prev_date }).fetchone() if result: return float(result.close) else: # 如果指定日期没有数据,尝试获取最近的收盘价 fallback_sql = """ SELECT CAST(F007N AS FLOAT) as close, TRADEDATE FROM ea_trade WHERE SECCODE = :stock_code AND TRADEDATE \ < :target_date AND F007N IS NOT NULL ORDER BY TRADEDATE DESC LIMIT 1 \ """ result = conn.execute(text(fallback_sql), { "stock_code": stock_code_short, "target_date": target_date }).fetchone() if result: print(f"Using close price from {result.TRADEDATE} as zero axis") return float(result.close) except Exception as e: print(f"Error getting previous close: {e}") return None target_date = event_datetime.date() is_after_market = event_datetime.time() > dt_time(15, 0) # 核心逻辑:先判断当前日期是否是交易日,以及是否已收盘 if target_date in trading_days and is_after_market: # 如果是交易日且已收盘,查找下一个交易日 next_trade_date = find_next_trading_day(target_date) if next_trade_date: target_date = next_trade_date elif target_date not in trading_days: # 如果不是交易日,先尝试找下一个交易日 next_trade_date = find_next_trading_day(target_date) if next_trade_date: target_date = next_trade_date else: # 如果找不到下一个交易日,找最近的历史交易日 target_date = find_prev_trading_day(target_date) if not target_date: return jsonify({ 'error': 'No data available', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'minute' }) # 获取前一交易日收盘价作为零轴 zero_axis = get_prev_close(stock_code_short, target_date) # 获取目标日期的完整交易时段数据 data = client.execute(""" SELECT timestamp, open, high, low, close, volume, amt FROM stock_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': stock_code, 'start': datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)) }) kline_data = [] for row in data: point = { 'time': row[0].strftime('%H:%M'), 'open': float(row[1]), 'high': float(row[2]), 'low': float(row[3]), 'close': float(row[4]), 'volume': float(row[5]), 'amount': float(row[6]) } # 如果有零轴数据,计算涨跌幅和涨跌额 if zero_axis: point['prev_close'] = zero_axis point['change'] = point['close'] - zero_axis # 涨跌额 point['change_pct'] = ((point['close'] - zero_axis) / zero_axis * 100) if zero_axis != 0 else 0 # 涨跌幅百分比 kline_data.append(point) response_data = { 'code': stock_code, 'name': stock_name, 'data': kline_data, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'minute', 'is_history': target_date < event_datetime.date() } # 添加零轴信息到响应中 if zero_axis: response_data['zero_axis'] = zero_axis response_data['prev_close'] = zero_axis # 计算当日整体涨跌幅(如果有数据) if kline_data: last_close = kline_data[-1]['close'] response_data['day_change'] = last_close - zero_axis response_data['day_change_pct'] = ((last_close - zero_axis) / zero_axis * 100) if zero_axis != 0 else 0 return jsonify(response_data) class HistoricalEvent(db.Model): """历史事件模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) title = db.Column(db.String(200)) content = db.Column(db.Text) event_date = db.Column(db.DateTime) relevance = db.Column(db.Integer) # 相关性 importance = db.Column(db.Integer) # 重要程度 related_stock = db.Column(db.JSON) # 保留JSON字段 created_at = db.Column(db.DateTime, default=beijing_now) # 新增关系 stocks = db.relationship('HistoricalEventStock', backref='historical_event', lazy='dynamic', cascade='all, delete-orphan') class HistoricalEventStock(db.Model): """历史事件相关股票模型""" __tablename__ = 'historical_event_stocks' id = db.Column(db.Integer, primary_key=True) historical_event_id = db.Column(db.Integer, db.ForeignKey('historical_event.id'), nullable=False) stock_code = db.Column(db.String(20), nullable=False) stock_name = db.Column(db.String(50)) relation_desc = db.Column(db.Text) correlation = db.Column(db.Float, default=0.5) sector = db.Column(db.String(100)) created_at = db.Column(db.DateTime, default=beijing_now) __table_args__ = ( db.UniqueConstraint('historical_event_id', 'stock_code', name='unique_event_stock'), ) @app.route('/event/follow/', methods=['POST']) @token_required def follow_event(event_id): """关注/取消关注事件""" event = Event.query.get_or_404(event_id) follow = EventFollow.query.filter_by( user_id=request.user.id, event_id=event_id ).first() try: if follow: db.session.delete(follow) event.follower_count -= 1 message = '已取消关注' else: follow = EventFollow(user_id=request.user.id, event_id=event_id) db.session.add(follow) event.follower_count += 1 message = '已关注' db.session.commit() return jsonify({'success': True, 'message': message}) except Exception as e: db.session.rollback() return jsonify({'success': False, 'message': '操作失败,请重试'}) # 帖子相关路由 @app.route('/post/create/', methods=['GET', 'POST']) @token_required def create_post(event_id): """创建新帖子""" event = Event.query.get_or_404(event_id) if request.method == 'POST': try: post = Post( event_id=event_id, user_id=request.user.id, title=request.form.get('title'), content=request.form['content'], content_type=request.form.get('content_type', 'text') ) db.session.add(post) event.post_count += 1 db.session.commit() # 检查是否是 API 请求(通过 Accept header 或 Content-Type 判断) if request.headers.get('Accept') == 'application/json' or \ request.headers.get('Content-Type', '').startswith('application/json'): return jsonify({ 'success': True, 'message': '发布成功', 'data': { 'post_id': post.id, 'event_id': event_id, 'redirect_url': url_for('event_detail', event_id=event_id) } }) else: # 传统表单提交,添加成功消息并重定向 flash('发布成功', 'success') return redirect(url_for('event_detail', event_id=event_id)) except Exception as e: db.session.rollback() if request.headers.get('Accept') == 'application/json' or \ request.headers.get('Content-Type', '').startswith('application/json'): return jsonify({ 'success': False, 'message': '发布失败,请重试' }), 400 else: flash('发布失败,请重试', 'error') app.logger.error(f"Error creating post: {str(e)}") return render_template('projects/create_post.html', event=event) # 点赞相关路由 @app.route('/post/like/', methods=['POST']) @token_required def like_post(post_id): """点赞/取消点赞帖子""" post = Post.query.get_or_404(post_id) like = PostLike.query.filter_by( user_id=request.user.id, post_id=post_id ).first() try: if like: # 取消点赞 db.session.delete(like) post.likes_count -= 1 message = '已取消点赞' else: # 添加点赞 like = PostLike(user_id=request.user.id, post_id=post_id) db.session.add(like) post.likes_count += 1 message = '已点赞' db.session.commit() return jsonify({ 'success': True, 'message': message, 'likes_count': post.likes_count }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'message': '操作失败,请重试'}) def update_user_activity(): """更新用户活跃度""" with app.app_context(): try: # 获取过去7天内的用户活动数据 seven_days_ago = beijing_now() - timedelta(days=7) # 统计用户发帖、评论、点赞等活动 active_users = db.session.query( User.id, db.func.count(Post.id).label('post_count'), db.func.count(Comment.id).label('comment_count'), db.func.count(PostLike.id).label('like_count') ).outerjoin(Post, User.id == Post.user_id) \ .outerjoin(Comment, User.id == Comment.user_id) \ .outerjoin(PostLike, User.id == PostLike.user_id) \ .filter( db.or_( Post.created_at >= seven_days_ago, Comment.created_at >= seven_days_ago, PostLike.created_at >= seven_days_ago ) ).group_by(User.id).all() # 更新用户活跃度分数 for user_id, post_count, comment_count, like_count in active_users: activity_score = post_count * 2 + comment_count * 1 + like_count * 0.5 User.query.filter_by(id=user_id).update({ 'activity_score': activity_score, 'last_active': beijing_now() }) db.session.commit() current_app.logger.info("Successfully updated user activity scores") except Exception as e: db.session.rollback() current_app.logger.error(f"Error updating user activity: {str(e)}") @app.route('/post/comment/', methods=['POST']) @token_required def add_comment(post_id): """添加评论""" post = Post.query.get_or_404(post_id) try: content = request.form.get('content') parent_id = request.form.get('parent_id', type=int) if not content: return jsonify({'success': False, 'message': '评论内容不能为空'}) comment = Comment( post_id=post_id, user_id=request.user.id, content=content, parent_id=parent_id ) db.session.add(comment) post.comments_count += 1 db.session.commit() return jsonify({ 'success': True, 'message': '评论成功', 'comment': { 'id': comment.id, 'content': comment.content, 'user_name': request.user.username, 'user_avatar': get_full_avatar_url(request.user.avatar_url), # 修改这里 'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S') } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'message': '评论失败,请重试'}) @app.route('/post/comments/') def get_comments(post_id): """获取帖子评论列表""" page = request.args.get('page', 1, type=int) # 获取顶层评论 comments = Comment.query.filter_by( post_id=post_id, parent_id=None, status='active' ).order_by( Comment.created_at.desc() ).paginate(page=page, per_page=20) # 同时获取每个顶层评论的部分回复 comments_data = [] for comment in comments.items: replies = Comment.query.filter_by( parent_id=comment.id, status='active' ).order_by( Comment.created_at.asc() ).limit(3).all() comments_data.append({ 'id': comment.id, 'content': comment.content, 'user': { 'id': comment.user.id, 'username': comment.user.username, 'avatar_url': get_full_avatar_url(comment.user.avatar_url), # 修改这里 }, 'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'), 'replies': [{ 'id': reply.id, 'content': reply.content, 'user': { 'id': reply.user.id, 'username': reply.user.username, 'avatar_url': get_full_avatar_url(reply.user.avatar_url), # 修改这里 }, 'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S') } for reply in replies] }) return jsonify({ 'comments': comments_data, 'total': comments.total, 'pages': comments.pages, 'current_page': comments.page }) beijing_tz = pytz.timezone('Asia/Shanghai') def update_hot_scores(): """ 更新所有事件的热度分数 在Flask应用上下文中执行数据库操作 """ with app.app_context(): try: # 获取所有活跃事件 events = Event.query.filter_by(status='active').all() current_time = beijing_now() for event in events: # 确保created_at有时区信息,解决naive和aware datetime比较问题 created_at = beijing_tz.localize( event.created_at) if event.created_at.tzinfo is None else event.created_at # 使用处理后的created_at计算hours_passed hours_passed = (current_time - created_at).total_seconds() / 3600 # 基础分数 - 帖子数和评论数 posts = Post.query.filter_by(event_id=event.id).all() post_count = len(posts) comment_count = sum(post.comments_count for post in posts) # 获取24小时内的新增帖子数 recent_posts = Post.query.filter( Post.event_id == event.id, Post.created_at >= current_time - timedelta(hours=24) ).count() # 获取点赞数 like_count = db.session.query(func.sum(Post.likes_count)).filter( Post.event_id == event.id ).scalar() or 0 # 基础互动分数 = 帖子数 * 2 + 评论数 * 1 + 点赞数 * 0.5 interaction_score = (post_count * 2) + (comment_count * 1) + (like_count * 0.5) # 关注度分数 = 关注人数 * 3 follow_score = event.follower_count * 3 # 浏览量分数 = log(浏览量) if event.view_count > 0: view_score = math.log(event.view_count) * 2 else: view_score = 0 # 时间衰减因子 - 使用上面已经计算好的hours_passed time_decay = math.exp(-hours_passed / 72) # 3天后衰减为原始分数的1/e # 最近活跃度权重 recent_activity_weight = (recent_posts * 5) # 24小时内的新帖权重高 # 总分 = (互动分数 + 关注度分数 + 浏览量分数 + 最近活跃度) * 时间衰减 total_score = (interaction_score + follow_score + view_score + recent_activity_weight) * time_decay # 更新热度分数 event.hot_score = round(total_score, 2) # 分数的对数值作为事件的trending_score (用于趋势排序) if total_score > 0: event.trending_score = math.log(total_score) * time_decay else: event.trending_score = 0 # 记录热度历史 history = EventHotHistory( event_id=event.id, score=event.hot_score, interaction_score=interaction_score, follow_score=follow_score, view_score=view_score, recent_activity_score=recent_activity_weight, time_decay=time_decay ) db.session.add(history) db.session.commit() app.logger.info("Successfully updated event hot scores") except Exception as e: db.session.rollback() app.logger.error(f"Error updating hot scores: {str(e)}") raise # 添加热度历史记录模型 def calculate_hot_score(event): """计算事件热度分数""" current_time = beijing_now() time_diff = (current_time - event.created_at).total_seconds() / 3600 # 转换为小时 # 基础分数 = 浏览量 * 0.1 + 帖子数 * 0.5 + 关注数 * 1 base_score = ( event.view_count * 0.1 + event.post_count * 0.5 + event.follower_count * 1 ) # 时间衰减因子,72小时(3天)内的事件获得较高权重 time_factor = max(1 - (time_diff / 72), 0.1) return base_score * time_factor @app.route('/api/sector/hierarchy', methods=['GET']) def api_sector_hierarchy(): """行业层级关系接口:展示多个行业分类体系的层级结构""" try: # 定义需要返回的行业分类体系 classification_systems = [ '申银万国行业分类' ] result = [] # 改为数组 for classification in classification_systems: # 查询特定分类标准的数据 sectors = SectorInfo.query.filter_by(F002V=classification).all() if not sectors: continue # 构建该分类体系的层级结构 hierarchy = {} for sector in sectors: level1 = sector.F004V # 一级行业 level2 = sector.F005V # 二级行业 level3 = sector.F006V # 三级行业 level4 = sector.F007V # 四级行业 # 统计股票数量 stock_code = sector.SECCODE # 初始化一级行业 if level1 not in hierarchy: hierarchy[level1] = { 'level2_sectors': {}, 'stocks': set(), 'stocks_count': 0 } # 添加股票到一级行业 if stock_code: hierarchy[level1]['stocks'].add(stock_code) # 处理二级行业 if level2: if level2 not in hierarchy[level1]['level2_sectors']: hierarchy[level1]['level2_sectors'][level2] = { 'level3_sectors': {}, 'stocks': set(), 'stocks_count': 0 } # 添加股票到二级行业 if stock_code: hierarchy[level1]['level2_sectors'][level2]['stocks'].add(stock_code) # 处理三级行业 if level3: if level3 not in hierarchy[level1]['level2_sectors'][level2]['level3_sectors']: hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3] = { 'level4_sectors': [], 'stocks': set(), 'stocks_count': 0 } # 添加股票到三级行业 if stock_code: hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3]['stocks'].add( stock_code) # 处理四级行业 if level4 and level4 not in \ hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3]['level4_sectors']: hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3][ 'level4_sectors'].append(level4) # 计算股票数量并清理set对象 formatted_hierarchy = [] for level1, level1_data in hierarchy.items(): level1_item = { 'level1_sector': level1, 'stocks_count': len(level1_data['stocks']), 'level2_sectors': [] } for level2, level2_data in level1_data['level2_sectors'].items(): level2_item = { 'level2_sector': level2, 'stocks_count': len(level2_data['stocks']), 'level3_sectors': [] } for level3, level3_data in level2_data['level3_sectors'].items(): level3_item = { 'level3_sector': level3, 'stocks_count': len(level3_data['stocks']), 'level4_sectors': level3_data['level4_sectors'] } level2_item['level3_sectors'].append(level3_item) # 按股票数量排序 level2_item['level3_sectors'].sort(key=lambda x: x['stocks_count'], reverse=True) level1_item['level2_sectors'].append(level2_item) # 按股票数量排序 level1_item['level2_sectors'].sort(key=lambda x: x['stocks_count'], reverse=True) formatted_hierarchy.append(level1_item) # 按股票数量排序 formatted_hierarchy.sort(key=lambda x: x['stocks_count'], reverse=True) # 将该分类体系添加到结果数组中 result.append({ 'classification_name': classification, 'total_level1_count': len(formatted_hierarchy), 'total_stocks_count': sum(item['stocks_count'] for item in formatted_hierarchy), 'hierarchy': formatted_hierarchy }) # 按总股票数量排序 result.sort(key=lambda x: x['total_stocks_count'], reverse=True) return jsonify({ "code": 200, "message": "success", "data": result }) except Exception as e: return jsonify({ "code": 500, "message": str(e), "data": None }), 500 @app.route('/api/sector/banner', methods=['GET']) def api_sector_banner(): """行业分类 banner 接口:返回一级分类和对应二级行业列表""" try: # 原始映射 sector_map = { '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期', '钢铁': '大周期', '基础化工': '大周期', '建筑材料': '大周期', '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期', '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期', '太阳能发电': '大周期', '建筑装饰': '大周期', '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费', '食品饮料': '大消费', '医药生物': '大消费', '纺织服饰': '大消费', '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费', '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费', '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产', '多元金融': '大金融地产', '综合金融': '大金融地产', '房地产': '大金融地产', '非银金融': '大金融地产', '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块', '交通运输': '公共产业板块', '电力公用事业': '公共产业板块', '建筑': '公共产业板块', '环保': '公共产业板块', '综合': '公共产业板块', '公用事业': '公共产业板块' } # 重组结构为 一级 → [二级...] result_dict = {} for sub_sector, primary_sector in sector_map.items(): result_dict.setdefault(primary_sector, []).append(sub_sector) # 格式化成列表 result_list = [ {"primary_sector": primary, "sub_sectors": subs} for primary, subs in result_dict.items() ] return jsonify({ "code": 200, "message": "success", "data": result_list }) except Exception as e: return jsonify({ "code": 500, "message": str(e), "data": None }), 500 def get_limit_rate(stock_code): """ 根据股票代码获取涨跌停限制比例 Args: stock_code: 股票代码 Returns: float: 涨跌停限制比例 """ if not stock_code: return 10.0 # 去掉市场后缀 clean_code = stock_code.replace('.SH', '').replace('.SZ', '').replace('.BJ', '') # ST股票 (5%涨跌停) if 'ST' in stock_code.upper(): return 5.0 # 科创板 (688开头, 20%涨跌停) if clean_code.startswith('688'): return 20.0 # 创业板注册制 (30开头, 20%涨跌停) if clean_code.startswith('30'): return 20.0 # 北交所 (43、83、87开头, 30%涨跌停) if clean_code.startswith(('43', '83', '87')): return 30.0 # 主板、中小板默认 (10%涨跌停) return 10.0 @app.route('/api/events', methods=['GET']) def api_get_events(): """ 获取事件列表API - 优化版本(保持完全兼容) 仅限 Pro/Max 会员访问(小程序功能) 优化策略: 1. 使用ind_type字段简化内部逻辑 2. 批量获取股票行情,包括周涨跌计算 3. 保持原有返回数据结构不变 """ try: # ==================== 参数解析 ==================== # 分页参数 page = max(1, request.args.get('page', 1, type=int)) per_page = min(100, max(1, request.args.get('per_page', 10, type=int))) # 基础筛选参数 event_type = request.args.get('type', 'all') event_status = request.args.get('status', 'active') importance = request.args.get('importance', 'all') # 日期筛选参数 start_date = request.args.get('start_date') end_date = request.args.get('end_date') date_range = request.args.get('date_range') recent_days = request.args.get('recent_days', type=int) # 行业筛选参数(重新设计) ind_type = request.args.get('ind_type', 'all') stock_sector = request.args.get('stock_sector', 'all') secondary_sector = request.args.get('secondary_sector', 'all') # 新的行业层级筛选参数 industry_level = request.args.get('industry_level', type=int) # 筛选层级:1-4 industry_classification = request.args.get('industry_classification') # 行业名称 # 如果使用旧参数,映射到ind_type if ind_type == 'all' and stock_sector != 'all': ind_type = stock_sector # 标签筛选参数 tag = request.args.get('tag') tags = request.args.get('tags') keywords = request.args.get('keywords') # 搜索参数 search_query = request.args.get('q') search_type = request.args.get('search_type', 'topic') search_fields = request.args.get('search_fields', 'title,description').split(',') # 排序参数 sort_by = request.args.get('sort', 'new') return_type = request.args.get('return_type', 'avg') order = request.args.get('order', 'desc') # 收益率筛选参数 min_avg_return = request.args.get('min_avg_return', type=float) max_avg_return = request.args.get('max_avg_return', type=float) min_max_return = request.args.get('min_max_return', type=float) max_max_return = request.args.get('max_max_return', type=float) min_week_return = request.args.get('min_week_return', type=float) max_week_return = request.args.get('max_week_return', type=float) # 其他筛选参数 min_hot_score = request.args.get('min_hot_score', type=float) max_hot_score = request.args.get('max_hot_score', type=float) min_view_count = request.args.get('min_view_count', type=int) creator_id = request.args.get('creator_id', type=int) # 返回格式参数 include_creator = request.args.get('include_creator', 'true').lower() == 'true' include_stats = request.args.get('include_stats', 'true').lower() == 'true' include_related_data = request.args.get('include_related_data', 'false').lower() == 'true' # ==================== 构建查询 ==================== query = Event.query # 状态筛选 if event_status != 'all': query = query.filter_by(status=event_status) # 事件类型筛选 if event_type != 'all': query = query.filter_by(event_type=event_type) # 重要性筛选 if importance != 'all': query = query.filter_by(importance=importance) # 行业类型筛选(使用ind_type字段) if ind_type != 'all': query = query.filter_by(ind_type=ind_type) # 创建者筛选 if creator_id: query = query.filter_by(creator_id=creator_id) # ==================== 日期筛选 ==================== if recent_days: cutoff_date = datetime.now() - timedelta(days=recent_days) query = query.filter(Event.created_at >= cutoff_date) else: # 处理日期范围字符串 if date_range and ' 至 ' in date_range: try: start_date_str, end_date_str = date_range.split(' 至 ') start_date = start_date_str.strip() end_date = end_date_str.strip() except ValueError: pass # 开始日期 if start_date: try: if len(start_date) == 10: start_datetime = datetime.strptime(start_date, '%Y-%m-%d') else: start_datetime = datetime.strptime(start_date, '%Y-%m-%d %H:%M:%S') query = query.filter(Event.created_at >= start_datetime) except ValueError: pass # 结束日期 if end_date: try: if len(end_date) == 10: end_datetime = datetime.strptime(end_date, '%Y-%m-%d') end_datetime = end_datetime.replace(hour=23, minute=59, second=59) else: end_datetime = datetime.strptime(end_date, '%Y-%m-%d %H:%M:%S') query = query.filter(Event.created_at <= end_datetime) except ValueError: pass # ==================== 行业层级筛选(申银万国行业分类) ==================== if industry_level and industry_classification: # 排除行业分类体系名称本身,这些不是具体的行业 classification_systems = [ '申银万国行业分类', '中上协行业分类', '巨潮行业分类', '新财富行业分类', '证监会行业分类', '证监会行业分类(2001)' ] if industry_classification not in classification_systems: # 使用内存缓存获取行业代码前缀(性能优化:避免每次请求都查询数据库) # 前端发送的level值: # level=2 -> 一级行业 # level=3 -> 二级行业 # level=4 -> 三级行业 # level=5 -> 四级行业 if industry_level in SYWG_INDUSTRY_CACHE: # 直接从缓存中获取代码前缀列表 code_prefixes = SYWG_INDUSTRY_CACHE[industry_level].get(industry_classification, []) if code_prefixes: # 构建查询条件:查找related_industries中包含这些前缀的事件 if isinstance(db.engine.dialect, MySQLDialect): # MySQL JSON查询 conditions = [] for prefix in code_prefixes: conditions.append( text(""" JSON_SEARCH( related_industries, 'one', CONCAT(:prefix, '%'), NULL, '$[*]."申银万国行业分类"' ) IS NOT NULL """).params(prefix=prefix) ) if conditions: query = query.filter(or_(*conditions)) else: # 其他数据库 pattern_conditions = [] for prefix in code_prefixes: pattern_conditions.append( text("related_industries::text LIKE :pattern").params( pattern=f'%"申银万国行业分类": "{prefix}%' ) ) if pattern_conditions: query = query.filter(or_(*pattern_conditions)) else: # 没有找到匹配的行业代码,返回空结果 query = query.filter(Event.id == -1) else: # 无效的层级参数 app.logger.warning(f"Invalid industry_level: {industry_level}") else: # industry_classification 是分类体系名称,不进行筛选 app.logger.info( f"Skipping filter: industry_classification '{industry_classification}' is a classification system name") # ==================== 细分行业筛选(保留向后兼容) ==================== elif secondary_sector != 'all': # 直接按行业名称查询(最后一级行业 - level5/f007v) sector_code_query = db.session.query(text("DISTINCT f003v")).select_from( text("ea_sector") ).filter( text("f002v = '申银万国行业分类' AND f007v = :sector_name") ).params(sector_name=secondary_sector) sector_result = sector_code_query.first() if sector_result and sector_result[0]: industry_code_to_search = sector_result[0] # 在related_industries JSON中查找包含该代码的事件 if isinstance(db.engine.dialect, MySQLDialect): query = query.filter( text(""" JSON_SEARCH( related_industries, 'one', :industry_code, NULL, '$[*]."申银万国行业分类"' ) IS NOT NULL """) ).params(industry_code=industry_code_to_search) else: query = query.filter( text(""" related_industries::text LIKE :pattern """) ).params(pattern=f'%"申银万国行业分类": "{industry_code_to_search}"%') else: # 如果没有找到对应的行业代码,返回空结果 query = query.filter(Event.id == -1) # ==================== 概念/标签筛选 ==================== # 单个标签筛选 if tag: if isinstance(db.engine.dialect, MySQLDialect): query = query.filter(text("JSON_CONTAINS(keywords, :tag, '$')")) query = query.params(tag=json.dumps(tag)) else: query = query.filter(Event.keywords.cast(JSONB).contains([tag])) # 多个标签筛选 (AND逻辑) if tags: tag_list = [t.strip() for t in tags.split(',') if t.strip()] for single_tag in tag_list: if isinstance(db.engine.dialect, MySQLDialect): query = query.filter(text("JSON_CONTAINS(keywords, :tag, '$')")) query = query.params(tag=json.dumps(single_tag)) else: query = query.filter(Event.keywords.cast(JSONB).contains([single_tag])) # 关键词筛选 (OR逻辑) if keywords: keyword_list = [k.strip() for k in keywords.split(',') if k.strip()] keyword_filters = [] for keyword in keyword_list: if isinstance(db.engine.dialect, MySQLDialect): keyword_filters.append(text("JSON_CONTAINS(keywords, :keyword, '$')")) else: keyword_filters.append(Event.keywords.cast(JSONB).contains([keyword])) if keyword_filters: query = query.filter(or_(*keyword_filters)) # ==================== 搜索功能 ==================== if search_query: search_terms = search_query.strip().split() if search_type == 'stock': # 股票搜索 query = query.join(RelatedStock).filter( or_( RelatedStock.stock_code.ilike(f'%{search_query}%'), RelatedStock.stock_name.ilike(f'%{search_query}%') ) ).distinct() elif search_type == 'all': # 全局搜索 search_filters = [] # 文本字段搜索 for term in search_terms: term_filters = [] if 'title' in search_fields: term_filters.append(Event.title.ilike(f'%{term}%')) if 'description' in search_fields: term_filters.append(Event.description.ilike(f'%{term}%')) if 'keywords' in search_fields: if isinstance(db.engine.dialect, MySQLDialect): term_filters.append(text("JSON_CONTAINS(keywords, :term, '$')")) else: term_filters.append(Event.keywords.cast(JSONB).contains([term])) if term_filters: search_filters.append(or_(*term_filters)) # 股票搜索 stock_subquery = db.session.query(RelatedStock.event_id).filter( or_( RelatedStock.stock_code.ilike(f'%{search_query}%'), RelatedStock.stock_name.ilike(f'%{search_query}%') ) ).subquery() search_filters.append(Event.id.in_(stock_subquery)) if search_filters: query = query.filter(or_(*search_filters)) else: # 话题搜索 (默认) for term in search_terms: term_filters = [] if 'title' in search_fields: term_filters.append(Event.title.ilike(f'%{term}%')) if 'description' in search_fields: term_filters.append(Event.description.ilike(f'%{term}%')) if 'keywords' in search_fields: if isinstance(db.engine.dialect, MySQLDialect): term_filters.append(text("JSON_CONTAINS(keywords, :term, '$')")) else: term_filters.append(Event.keywords.cast(JSONB).contains([term])) if term_filters: query = query.filter(or_(*term_filters)) # ==================== 收益率筛选 ==================== if min_avg_return is not None: query = query.filter(Event.related_avg_chg >= min_avg_return) if max_avg_return is not None: query = query.filter(Event.related_avg_chg <= max_avg_return) if min_max_return is not None: query = query.filter(Event.related_max_chg >= min_max_return) if max_max_return is not None: query = query.filter(Event.related_max_chg <= max_max_return) if min_week_return is not None: query = query.filter(Event.related_week_chg >= min_week_return) if max_week_return is not None: query = query.filter(Event.related_week_chg <= max_week_return) # ==================== 其他数值筛选 ==================== if min_hot_score is not None: query = query.filter(Event.hot_score >= min_hot_score) if max_hot_score is not None: query = query.filter(Event.hot_score <= max_hot_score) if min_view_count is not None: query = query.filter(Event.view_count >= min_view_count) # ==================== 排序逻辑 ==================== order_func = desc if order.lower() == 'desc' else asc if sort_by == 'hot': query = query.order_by(order_func(Event.hot_score)) elif sort_by == 'new': query = query.order_by(order_func(Event.created_at)) elif sort_by == 'returns': if return_type == 'avg': query = query.order_by(order_func(Event.related_avg_chg)) elif return_type == 'max': query = query.order_by(order_func(Event.related_max_chg)) elif return_type == 'week': query = query.order_by(order_func(Event.related_week_chg)) elif sort_by == 'importance': importance_order = case( (Event.importance == 'S', 1), (Event.importance == 'A', 2), (Event.importance == 'B', 3), (Event.importance == 'C', 4), else_=5 ) if order.lower() == 'desc': query = query.order_by(importance_order) else: query = query.order_by(desc(importance_order)) elif sort_by == 'view_count': query = query.order_by(order_func(Event.view_count)) elif sort_by == 'follow' and hasattr(request, 'user') and request.user.is_authenticated: # 关注的事件排序 query = query.join(EventFollow).filter( EventFollow.user_id == request.user.id ).order_by(order_func(Event.created_at)) # ==================== 分页查询 ==================== paginated = query.paginate(page=page, per_page=per_page, error_out=False) # ==================== 构建响应数据 ==================== events_data = [] for event in paginated.items: # 构建事件数据(保持原有结构,个股信息和统计置空) event_dict = { 'id': event.id, 'title': event.title, 'description': event.description, 'event_type': event.event_type, 'importance': event.importance, 'status': event.status, 'created_at': event.created_at.isoformat() if event.created_at else None, 'updated_at': event.updated_at.isoformat() if event.updated_at else None, 'start_time': event.start_time.isoformat() if event.start_time else None, 'end_time': event.end_time.isoformat() if event.end_time else None, # 个股信息(置空) 'related_stocks': [], # 股票统计(置空或使用数据库字段) 'stocks_stats': { 'stocks_count': 10, 'valid_stocks_count': 0, # 使用数据库字段的涨跌幅 'avg_week_change': round(event.related_week_chg, 2) if event.related_week_chg else 0, 'max_week_change': round(event.related_max_chg, 2) if event.related_max_chg else 0, 'avg_daily_change': round(event.related_avg_chg, 2) if event.related_avg_chg else 0, 'max_daily_change': round(event.related_max_chg, 2) if event.related_max_chg else 0 } } # 统计信息(可选) if include_stats: event_dict.update({ 'hot_score': event.hot_score, 'view_count': event.view_count, 'post_count': event.post_count, 'follower_count': event.follower_count, 'related_avg_chg': event.related_avg_chg, 'related_max_chg': event.related_max_chg, 'related_week_chg': event.related_week_chg, 'invest_score': event.invest_score, 'trending_score': event.trending_score, }) # 创建者信息(可选) if include_creator: event_dict['creator'] = { 'id': event.creator.id if event.creator else None, 'username': event.creator.username if event.creator else 'Anonymous', 'avatar_url': get_full_avatar_url(event.creator.avatar_url) if event.creator else None, 'is_creator': event.creator.is_creator if event.creator else False, 'creator_type': event.creator.creator_type if event.creator else None } # 关联数据 event_dict['keywords'] = event.keywords if isinstance(event.keywords, list) else [] event_dict['related_industries'] = event.related_industries # 包含统计信息(可选,置空) if include_stats: event_dict['stats'] = { 'related_stocks_count': 10, 'historical_events_count': 0, 'related_data_count': 0, 'related_concepts_count': 0 } # 包含关联数据(可选,已置空) if include_related_data: event_dict['related_stocks'] = [] events_data.append(event_dict) # ==================== 构建筛选信息 ==================== applied_filters = {} # 记录已应用的筛选条件 if event_type != 'all': applied_filters['type'] = event_type if importance != 'all': applied_filters['importance'] = importance if start_date: applied_filters['start_date'] = start_date if end_date: applied_filters['end_date'] = end_date if industry_level and industry_classification: applied_filters['industry_level'] = industry_level applied_filters['industry_classification'] = industry_classification if tag: applied_filters['tag'] = tag if tags: applied_filters['tags'] = tags if search_query: applied_filters['search_query'] = search_query applied_filters['search_type'] = search_type # ==================== 返回结果(保持完全兼容,统计数据置空) ==================== return jsonify({ 'success': True, 'data': { 'events': events_data, 'pagination': { 'page': paginated.page, 'per_page': paginated.per_page, 'total': paginated.total, 'pages': paginated.pages, 'has_prev': paginated.has_prev, 'has_next': paginated.has_next }, 'filters': { 'applied_filters': { 'type': event_type, 'status': event_status, 'importance': importance, 'stock_sector': stock_sector, # 保持兼容 'secondary_sector': secondary_sector, # 保持兼容 'sort': sort_by, 'order': order } }, # 整体股票涨跌幅分布统计(置空) 'overall_stats': { 'total_stocks': 0, 'change_distribution': { 'limit_down': 0, 'down_over_5': 0, 'down_5_to_1': 0, 'down_within_1': 0, 'flat': 0, 'up_within_1': 0, 'up_1_to_5': 0, 'up_over_5': 0, 'limit_up': 0 }, 'change_distribution_percentages': { 'limit_down': 0, 'down_over_5': 0, 'down_5_to_1': 0, 'down_within_1': 0, 'flat': 0, 'up_within_1': 0, 'up_1_to_5': 0, 'up_over_5': 0, 'limit_up': 0 } } } }) except Exception as e: app.logger.error(f"获取事件列表出错: {str(e)}", exc_info=True) return jsonify({ 'success': False, 'error': str(e), 'error_type': type(e).__name__ }), 500 def get_filter_counts(base_query): """ 获取各个筛选条件的计数信息 可用于前端显示筛选选项的可用数量 """ try: counts = {} # 重要性计数 importance_counts = db.session.query( Event.importance, func.count(Event.id).label('count') ).filter( Event.id.in_(base_query.with_entities(Event.id).subquery()) ).group_by(Event.importance).all() counts['importance'] = {item.importance or 'unknown': item.count for item in importance_counts} # 事件类型计数 type_counts = db.session.query( Event.event_type, func.count(Event.id).label('count') ).filter( Event.id.in_(base_query.with_entities(Event.id).subquery()) ).group_by(Event.event_type).all() counts['event_type'] = {item.event_type or 'unknown': item.count for item in type_counts} return counts except Exception: return {} def get_event_class(count): """根据事件数量返回对应的样式类""" if count >= 10: return 'bg-gradient-danger' elif count >= 7: return 'bg-gradient-warning' elif count >= 4: return 'bg-gradient-info' else: return 'bg-gradient-success' @app.route('/api/calendar-event-counts') def get_calendar_event_counts(): """获取整月的事件数量统计,仅统计type为event的事件""" try: # 获取当前月份的开始和结束日期 today = datetime.now() start_date = today.replace(day=1) if today.month == 12: end_date = today.replace(year=today.year + 1, month=1, day=1) else: end_date = today.replace(month=today.month + 1, day=1) # 修改查询以仅统计type为event的事件数量 query = """ SELECT DATE (calendar_time) as date, COUNT (*) as count FROM future_events WHERE calendar_time BETWEEN :start_date \ AND :end_date AND type = 'event' GROUP BY DATE (calendar_time) \ """ result = db.session.execute(text(query), { 'start_date': start_date, 'end_date': end_date }) # 格式化结果为日历事件格式 events = [{ 'title': f'{day.count} 个事件', 'start': day.date.isoformat() if day.date else None, 'className': get_event_class(day.count) } for day in result] return jsonify(events) except Exception as e: return jsonify({'error': str(e)}), 500 def get_full_avatar_url(avatar_url): """ 统一处理头像URL,确保返回完整的可访问URL Args: avatar_url: 头像URL字符串 Returns: 完整的头像URL,如果没有头像则返回默认头像URL """ if not avatar_url: # 返回默认头像 return f"{DOMAIN}/static/assets/img/default-avatar.png" # 如果已经是完整URL(http或https开头),直接返回 if avatar_url.startswith(('http://', 'https://')): return avatar_url # 如果是相对路径,拼接域名 if avatar_url.startswith('/'): return f"{DOMAIN}{avatar_url}" else: return f"{DOMAIN}/{avatar_url}" # 修改User模型的to_dict方法 def to_dict(self): """转换为字典格式,方便API返回""" return { 'id': self.id, 'username': self.username, 'email': self.email, 'nickname': self.nickname, 'avatar_url': get_full_avatar_url(self.avatar_url), # 使用统一处理函数 'bio': self.bio, 'location': self.location, 'is_verified': self.is_verified, 'user_level': self.user_level, 'reputation_score': self.reputation_score, 'post_count': self.post_count, 'follower_count': self.follower_count, 'following_count': self.following_count, 'created_at': self.created_at.isoformat() if self.created_at else None, 'last_seen': self.last_seen.isoformat() if self.last_seen else None } # ==================== 标准化API接口 ==================== # 1. 首页接口 @app.route('/api/home', methods=['GET']) def api_home(): try: seven_days_ago = datetime.now() - timedelta(days=7) hot_events = Event.query.filter( Event.status == 'active', Event.created_at >= seven_days_ago ).order_by(Event.hot_score.desc()).limit(10).all() events_data = [] for event in hot_events: related_stocks = RelatedStock.query.filter_by(event_id=event.id).all() # 计算相关性统计数据 correlations = [float(stock.correlation or 0) for stock in related_stocks] avg_correlation = sum(correlations) / len(correlations) if correlations else 0 max_correlation = max(correlations) if correlations else 0 stocks_data = [] total_week_change = 0 max_week_change = 0 total_daily_change = 0 # 新增:日涨跌幅总和 max_daily_change = 0 # 新增:最大日涨跌幅 valid_stocks_count = 0 for stock in related_stocks: stock_code = stock.stock_code.split('.')[0] # 获取最新交易日数据 latest_trade = db.session.execute(text(""" SELECT * FROM ea_trade WHERE SECCODE = :stock_code ORDER BY TRADEDATE DESC LIMIT 1 """), {"stock_code": stock_code}).first() week_change = 0 daily_change = 0 # 新增:日涨跌幅 if latest_trade and latest_trade.F007N: latest_price = float(latest_trade.F007N or 0) latest_date = latest_trade.TRADEDATE daily_change = float(latest_trade.F010N or 0) # F010N是日涨跌幅字段 # 更新日涨跌幅统计 total_daily_change += daily_change max_daily_change = max(max_daily_change, daily_change) # 获取最近5条交易记录 week_ago_trades = db.session.execute(text(""" SELECT * FROM ea_trade WHERE SECCODE = :stock_code AND TRADEDATE < :latest_date ORDER BY TRADEDATE DESC LIMIT 5 """), { "stock_code": stock_code, "latest_date": latest_date }).fetchall() if week_ago_trades and week_ago_trades[-1].F007N: week_ago_price = float(week_ago_trades[-1].F007N or 0) if week_ago_price > 0: week_change = (latest_price - week_ago_price) / week_ago_price * 100 total_week_change += week_change max_week_change = max(max_week_change, week_change) valid_stocks_count += 1 stocks_data.append({ "stock_code": stock.stock_code, "stock_name": stock.stock_name, "correlation": float(stock.correlation or 0), "sector": stock.sector, "week_change": round(week_change, 2), "daily_change": round(daily_change, 2), # 新增:个股日涨跌幅 "latest_trade_date": latest_trade.TRADEDATE.strftime("%Y-%m-%d") if latest_trade else None }) # 计算平均值 avg_week_change = total_week_change / valid_stocks_count if valid_stocks_count > 0 else 0 avg_daily_change = total_daily_change / valid_stocks_count if valid_stocks_count > 0 else 0 events_data.append({ "id": event.id, "title": event.title, "description": event.description, 'event_type': event.event_type, 'importance': event.importance, # 添加重要性 'status': event.status, "created_at": event.created_at.strftime("%Y-%m-%d %H:%M:%S"), 'updated_at': event.updated_at.isoformat() if event.updated_at else None, 'start_time': event.start_time.isoformat() if event.start_time else None, 'end_time': event.end_time.isoformat() if event.end_time else None, 'view_count': event.view_count, # 添加浏览量 'post_count': event.post_count, # 添加帖子数 'follower_count': event.follower_count, # 添加关注者数 "related_stocks": stocks_data, "stocks_stats": { "avg_correlation": round(avg_correlation, 2), "max_correlation": round(max_correlation, 2), "stocks_count": len(related_stocks), "valid_stocks_count": valid_stocks_count, # 周涨跌统计 "avg_week_change": round(avg_week_change, 2), "max_week_change": round(max_week_change, 2), # 日涨跌统计 "avg_daily_change": round(avg_daily_change, 2), "max_daily_change": round(max_daily_change, 2) } }) return jsonify({ "code": 200, "message": "success", "data": { "events": events_data } }) except Exception as e: print(f"Error in api_home: {str(e)}") return jsonify({ "code": 500, "message": str(e), "data": None }), 500 @app.route('/api/auth/logout', methods=['POST']) def logout_with_token(): """使用token登出""" # 从请求头获取token auth_header = request.headers.get('Authorization') if auth_header and auth_header.startswith('Bearer '): token = auth_header[7:] else: data = request.get_json() token = data.get('token') if data else None if token and token in user_tokens: del user_tokens[token] # 清除session session.clear() return jsonify({'message': '登出成功'}), 200 def send_sms_code(phone, code, template_id): """发送短信验证码""" try: cred = credential.Credential(SMS_SECRET_ID, SMS_SECRET_KEY) httpProfile = HttpProfile() httpProfile.endpoint = "sms.tencentcloudapi.com" clientProfile = ClientProfile() clientProfile.httpProfile = httpProfile client = sms_client.SmsClient(cred, "ap-beijing", clientProfile) req = models.SendSmsRequest() params = { "PhoneNumberSet": [phone], "SmsSdkAppId": SMS_SDK_APP_ID, "TemplateId": template_id, "SignName": SMS_SIGN_NAME, "TemplateParamSet": [code] if template_id == SMS_TEMPLATE_REGISTER else [code, "5"] } req.from_json_string(json.dumps(params)) resp = client.SendSms(req) return True except TencentCloudSDKException as err: print(f"SMS Error: {err}") return False def generate_verification_code(): """生成6位数字验证码""" return ''.join(random.choices(string.digits, k=6)) @app.route('/api/auth/send-sms', methods=['POST']) def send_sms_verification(): """发送手机验证码(统一接口,自动判断场景)""" data = request.get_json() phone = data.get('phone') if not phone: return jsonify({'error': '手机号不能为空'}), 400 # 检查手机号是否已注册 user_exists = User.query.filter_by(phone=phone).first() is not None # 生成验证码 code = generate_verification_code() # 根据用户是否存在自动选择模板 template_id = SMS_TEMPLATE_LOGIN if user_exists else SMS_TEMPLATE_REGISTER # 发送短信 if send_sms_code(phone, code, template_id): # 统一存储验证码(5分钟有效) verification_codes[phone] = { 'code': code, 'expires': time.time() + 300 } # 简单返回成功,不暴露用户是否存在的信息 return jsonify({ 'message': '验证码已发送', 'expires_in': 300 # 告诉前端验证码有效期(秒) }), 200 else: return jsonify({'error': '验证码发送失败'}), 500 def generate_token(length=32): """生成随机token""" characters = string.ascii_letters + string.digits return ''.join(secrets.choice(characters) for _ in range(length)) @app.route('/api/auth/login/phone', methods=['POST']) def login_with_phone(): """统一的手机号登录/注册接口""" data = request.get_json() phone = data.get('phone') code = data.get('code') username = data.get('username') # 可选,新用户可以提供 password = data.get('password') # 可选,新用户可以提供 if not all([phone, code]): return jsonify({ 'code': 400, 'message': '手机号和验证码不能为空' }), 400 # 验证验证码 stored_code = verification_codes.get(phone) if not stored_code or stored_code['expires'] < time.time(): return jsonify({ 'code': 400, 'message': '验证码已过期' }), 400 if stored_code['code'] != code: return jsonify({ 'code': 400, 'message': '验证码错误' }), 400 try: # 查找用户 user = User.query.filter_by(phone=phone).first() is_new_user = False # 如果用户不存在,自动注册 if not user: is_new_user = True # 如果提供了用户名,检查是否已存在 if username: if User.query.filter_by(username=username).first(): return jsonify({ 'code': 400, 'message': '用户名已被使用,请换一个' }), 400 else: # 自动生成用户名 base_username = f"user_{phone[-4:]}" username = base_username counter = 1 while User.query.filter_by(username=username).first(): username = f"{base_username}_{counter}" counter += 1 # 创建新用户 user = User(username=username, phone=phone) user.email = f"{username}@valuefrontier.temp" # 如果提供了密码就使用,否则生成随机密码 if password: user.set_password(password) else: random_password = generate_token(16) user.set_password(random_password) user.phone_confirmed = True db.session.add(user) db.session.commit() # 生成token token = generate_token(32) # 存储token映射(30天有效期) user_tokens[token] = { 'user_id': user.id, 'expires': datetime.now() + timedelta(days=30) } # 清除验证码 del verification_codes[phone] # 设置session(保持与原有逻辑兼容) session.permanent = True session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True # 返回响应 response_data = { 'code': 0, 'message': '欢迎回来' if not is_new_user else '注册成功,欢迎加入', 'token': token, 'is_new_user': is_new_user, # 告诉前端是否是新用户 'user': { 'id': user.id, 'username': user.username, 'phone': user.phone, 'need_complete_profile': is_new_user # 提示新用户完善资料 } } return jsonify(response_data), 200 except Exception as e: db.session.rollback() print(f"Login/Register error: {e}") return jsonify({ 'code': 500, 'message': '操作失败,请重试' }), 500 @app.route('/api/auth/verify-token', methods=['POST']) def verify_token(): """验证token有效性(可选接口)""" data = request.get_json() token = data.get('token') if not token: return jsonify({'valid': False, 'message': 'Token不能为空'}), 400 token_data = user_tokens.get(token) if not token_data: return jsonify({'valid': False, 'message': 'Token无效', 'code': 401}), 401 # 检查是否过期 if token_data['expires'] < datetime.now(): del user_tokens[token] return jsonify({'valid': False, 'message': 'Token已过期'}), 401 # 获取用户信息 user = User.query.get(token_data['user_id']) if not user: return jsonify({'valid': False, 'message': '用户不存在'}), 404 return jsonify({ 'valid': True, 'user': { 'id': user.id, 'username': user.username, 'phone': user.phone } }), 200 def generate_jwt_token(user_id): """ 生成JWT Token - 与原系统保持一致 Args: user_id: 用户ID Returns: str: JWT token字符串 """ payload = { 'user_id': user_id, 'exp': datetime.utcnow() + timedelta(hours=JWT_EXPIRATION_HOURS), 'iat': datetime.utcnow() } token = jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) return token @app.route('/api/auth/login/wechat', methods=['POST']) def api_login_wechat(): try: # 1. 获取请求数据 data = request.get_json() code = data.get('code') if data else None if not code: return jsonify({ 'code': 400, 'message': '缺少必要的参数', 'data': None }), 400 # 2. 验证code格式 if not isinstance(code, str) or len(code) < 10: return jsonify({ 'code': 400, 'message': 'code格式无效', 'data': None }), 400 logger.info(f"开始处理微信登录,code长度: {len(code)}") # 3. 调用微信接口获取用户信息 wx_api_url = 'https://api.weixin.qq.com/sns/jscode2session' params = { 'appid': WECHAT_APP_ID, 'secret': WECHAT_APP_SECRET, 'js_code': code, 'grant_type': 'authorization_code' } try: response = requests.get(wx_api_url, params=params, timeout=10) response.raise_for_status() wx_data = response.json() # 检查微信API返回的错误 if 'errcode' in wx_data and wx_data['errcode'] != 0: error_messages = { -1: '系统繁忙,请稍后重试', 40029: 'code无效或已过期', 45011: '频率限制,请稍后再试', 40013: 'AppID错误', 40125: 'AppSecret错误', 40226: '高风险用户,登录被拦截' } error_msg = error_messages.get( wx_data['errcode'], f"微信接口错误: {wx_data.get('errmsg', '未知错误')}" ) logger.error(f"WeChat API error {wx_data['errcode']}: {error_msg}") return jsonify({ 'code': 400, 'message': error_msg, 'data': None }), 400 # 验证必需字段 if 'openid' not in wx_data or 'session_key' not in wx_data: logger.error("响应缺少必需字段") return jsonify({ 'code': 500, 'message': '微信响应格式错误', 'data': None }), 500 openid = wx_data['openid'] session_key = wx_data['session_key'] unionid = wx_data.get('unionid') # 可能为None logger.info(f"成功获取微信用户信息 - OpenID: {openid[:8]}...") if unionid: logger.info(f"获取到UnionID: {unionid[:8]}...") except requests.exceptions.Timeout: logger.error("请求微信API超时") return jsonify({ 'code': 500, 'message': '请求超时,请重试', 'data': None }), 500 except requests.exceptions.RequestException as e: logger.error(f"网络请求失败: {str(e)}") return jsonify({ 'code': 500, 'message': '网络错误', 'data': None }), 500 # 4. 查找或创建用户 - 核心逻辑 user = None is_new_user = False logger.info(f"开始查找用户 - UnionID: {unionid}, OpenID: {openid[:8]}...") if unionid: # 情况1: 有unionid,优先通过unionid查找 user = User.query.filter_by(wechat_union_id=unionid).first() if user: logger.info(f"通过UnionID找到用户: {user.username}") # 更新openid(可能用户从不同小程序登录) if user.wechat_open_id != openid: user.wechat_open_id = openid logger.info(f"更新用户OpenID: {openid[:8]}...") else: # unionid没找到,再尝试用openid查找(处理历史数据) user = User.query.filter_by(wechat_open_id=openid).first() if user: logger.info(f"通过OpenID找到用户: {user.username}") # 补充unionid user.wechat_union_id = unionid logger.info(f"为用户补充UnionID: {unionid[:8]}...") else: # 情况2: 没有unionid,只能通过openid查找 logger.warning("未获取到UnionID(小程序可能未绑定开放平台)") user = User.query.filter_by(wechat_open_id=openid).first() if user: logger.info(f"通过OpenID找到用户: {user.username}") # 5. 创建新用户 if not user: is_new_user = True # 生成唯一用户名 timestamp = int(time.time()) username = f"wx_{timestamp}_{openid[-6:]}" # 确保用户名唯一 counter = 0 base_username = username while User.query.filter_by(username=username).first(): counter += 1 username = f"{base_username}_{counter}" # 创建用户对象(使用你的User模型) user = User( username=username, email=f"{username}@wechat.local", # 占位邮箱 password="wechat_login_no_password" # 微信登录不需要密码 ) # 设置微信相关字段 user.wechat_open_id = openid user.wechat_union_id = unionid user.status = 'active' user.email_confirmed = False # 设置默认值 user.nickname = f"微信用户{openid[-4:]}" user.bio = "" # 空的个人简介 user.avatar_url = None # 稍后会处理 user.is_creator = False user.is_verified = False user.user_level = 1 user.reputation_score = 0 user.contribution_point = 0 user.post_count = 0 user.comment_count = 0 user.follower_count = 0 user.following_count = 0 # 设置默认偏好 user.email_notifications = True user.privacy_level = 'public' user.theme_preference = 'light' db.session.add(user) logger.info(f"创建新用户: {username}") else: # 更新最后登录时间 user.update_last_seen() logger.info(f"用户登录: {user.username}") # 6. 提交数据库更改 try: db.session.commit() except Exception as e: db.session.rollback() logger.error(f"保存用户信息失败: {str(e)}") return jsonify({ 'code': 500, 'message': '保存用户信息失败', 'data': None }), 500 # 7. 生成JWT token(使用原系统的生成方法) token = generate_token(32) # 使用相同的随机字符串生成器 # 存储token映射(与手机登录保持一致) user_tokens[token] = { 'user_id': user.id, 'expires': datetime.now() + timedelta(days=30) # 30天有效期 } # 设置session(可选,保持与手机登录一致) session.permanent = True session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True # 9. 构造返回数据 - 完全匹配要求的格式 response_data = { 'code': 200, 'data': { 'token': token, # 现在这个token能被token_required识别了 'user': { 'avatar_url': get_full_avatar_url(user.avatar_url), 'bio': user.bio or "", 'email': user.email, 'id': user.id, 'is_creator': user.is_creator, 'is_verified': user.is_verified, 'nickname': user.nickname or user.username, 'reputation_score': user.reputation_score, 'user_level': user.user_level, 'username': user.username } }, 'message': '登录成功' } # 10. 记录日志 logger.info( f"微信登录成功 - 用户ID: {user.id}, " f"用户名: {user.username}, " f"新用户: {is_new_user}, " f"有UnionID: {unionid is not None}" ) return jsonify(response_data), 200 except Exception as e: # 捕获所有未处理的异常 logger.error(f"微信登录处理异常: {str(e)}", exc_info=True) db.session.rollback() return jsonify({ 'code': 500, 'message': '服务器内部错误', 'data': None }), 500 @app.route('/api/auth/login/email', methods=['POST']) def api_login_email(): """邮箱登录接口""" try: data = request.get_json() email = data.get('email') password = data.get('password') if not email or not password: return jsonify({ 'code': 400, 'message': '邮箱和密码不能为空', 'data': None }), 400 user = User.query.filter_by(email=email).first() if not user or not user.check_password(password): return jsonify({ 'code': 400, 'message': '邮箱或密码错误', 'data': None }), 400 token = generate_jwt_token(user.id) login_user(user) user.update_last_seen() db.session.commit() return jsonify({ 'code': 200, 'message': '登录成功', 'data': { 'token': token, 'user_id': user.id, 'username': user.username, 'email': user.email, 'is_verified': user.is_verified, 'user_level': user.user_level } }) except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 # 5. 事件详情-相关标的接口 @app.route('/api/event//related-stocks-detail', methods=['GET']) def api_event_related_stocks(event_id): """事件相关标的详情接口 - 仅限 Pro/Max 会员""" try: from datetime import datetime, timedelta, time as dt_time from sqlalchemy import text event = Event.query.get_or_404(event_id) related_stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all() # 获取ClickHouse客户端用于分时数据查询 client = get_clickhouse_client() # 获取事件时间(如果事件有开始时间,使用开始时间;否则使用创建时间) event_time = event.start_time if event.start_time else event.created_at current_time = datetime.now() # 定义交易日和时间范围计算函数(与 app.py 中的逻辑完全一致) def get_trading_day_and_times(event_datetime): event_date = event_datetime.date() event_time_only = event_datetime.time() # Trading hours market_open = dt_time(9, 30) market_close = dt_time(15, 0) with engine.connect() as conn: # First check if the event date itself is a trading day is_trading_day = conn.execute(text(""" SELECT 1 FROM trading_days WHERE EXCHANGE_DATE = :date """), {"date": event_date}).fetchone() is not None if is_trading_day: # If it's a trading day, determine time period based on event time if event_time_only < market_open: # Before market opens - use full trading day return event_date, market_open, market_close elif event_time_only > market_close: # After market closes - get next trading day next_trading_day = conn.execute(text(""" SELECT EXCHANGE_DATE FROM trading_days WHERE EXCHANGE_DATE > :date ORDER BY EXCHANGE_DATE LIMIT 1 """), {"date": event_date}).fetchone() # Convert to date object if we found a next trading day return (next_trading_day[0].date() if next_trading_day else None, market_open, market_close) else: # During trading hours return event_date, event_time_only, market_close else: # If not a trading day, get next trading day next_trading_day = conn.execute(text(""" SELECT EXCHANGE_DATE FROM trading_days WHERE EXCHANGE_DATE > :date ORDER BY EXCHANGE_DATE LIMIT 1 """), {"date": event_date}).fetchone() # Convert to date object if we found a next trading day return (next_trading_day[0].date() if next_trading_day else None, market_open, market_close) trading_day, start_time, end_time = get_trading_day_and_times(event_time) if not trading_day: # 如果没有交易日,返回空数据 return jsonify({ 'code': 200, 'message': 'success', 'data': { 'event_id': event_id, 'event_title': event.title, 'event_desc': event.description, 'event_type': event.event_type, 'event_importance': event.importance, 'event_status': event.status, 'event_created_at': event.created_at.strftime("%Y-%m-%d %H:%M:%S"), 'event_start_time': event.start_time.isoformat() if event.start_time else None, 'event_end_time': event.end_time.isoformat() if event.end_time else None, 'keywords': event.keywords, 'view_count': event.view_count, 'post_count': event.post_count, 'follower_count': event.follower_count, 'related_stocks': [], 'total_count': 0 } }) # For historical dates, ensure we're using actual data start_datetime = datetime.combine(trading_day, start_time) end_datetime = datetime.combine(trading_day, end_time) # If the trading day is in the future relative to current time, return only names without data if trading_day > current_time.date(): start_datetime = datetime.combine(trading_day, start_time) end_datetime = datetime.combine(trading_day, end_time) print(f"事件时间: {event_time}, 交易日: {trading_day}, 时间范围: {start_datetime} - {end_datetime}") def get_minute_chart_data(stock_code): """获取股票分时图数据""" try: # 获取当前日期或最新交易日的分时数据 from datetime import datetime, timedelta, time as dt_time today = datetime.now().date() # 获取最新交易日的分时数据 data = client.execute(""" SELECT timestamp, open, high, low, close, volume, amt FROM stock_minute WHERE code = %(code)s AND timestamp >= %(start)s AND timestamp <= %(end)s ORDER BY timestamp """, { 'code': stock_code, 'start': datetime.combine(today, dt_time(9, 30)), 'end': datetime.combine(today, dt_time(15, 0)) }) # 如果今天没有数据,获取最近的交易日数据 if not data: # 获取最近的交易日数据 recent_data = client.execute(""" SELECT timestamp, open, high, low, close, volume, amt FROM stock_minute WHERE code = %(code)s AND timestamp >= ( SELECT MAX (timestamp) - INTERVAL 1 DAY FROM stock_minute WHERE code = %(code)s ) ORDER BY timestamp """, { 'code': stock_code }) data = recent_data # 格式化数据 minute_data = [] for row in data: minute_data.append({ 'time': row[0].strftime('%H:%M'), 'open': float(row[1]) if row[1] else None, 'high': float(row[2]) if row[2] else None, 'low': float(row[3]) if row[3] else None, 'close': float(row[4]) if row[4] else None, 'volume': float(row[5]) if row[5] else None, 'amount': float(row[6]) if row[6] else None }) return minute_data except Exception as e: print(f"Error fetching minute data for {stock_code}: {e}") return [] # ==================== 性能优化:批量查询所有股票数据 ==================== # 1. 收集所有股票代码 stock_codes = [stock.stock_code for stock in related_stocks] # 2. 批量查询股票基本信息 stock_info_map = {} if stock_codes: stock_infos = StockBasicInfo.query.filter(StockBasicInfo.SECCODE.in_(stock_codes)).all() for info in stock_infos: stock_info_map[info.SECCODE] = info # 处理不带后缀的股票代码 base_codes = [code.split('.')[0] for code in stock_codes if '.' in code and code not in stock_info_map] if base_codes: base_infos = StockBasicInfo.query.filter(StockBasicInfo.SECCODE.in_(base_codes)).all() for info in base_infos: # 将不带后缀的信息映射到带后缀的代码 for code in stock_codes: if code.split('.')[0] == info.SECCODE and code not in stock_info_map: stock_info_map[code] = info # 3. 批量查询 ClickHouse 数据(价格、涨跌幅、分时图数据) price_data_map = {} # 存储价格和涨跌幅数据 minute_chart_map = {} # 存储分时图数据 try: if stock_codes: print(f"批量查询 {len(stock_codes)} 只股票的价格数据...") # 3.1 批量查询价格和涨跌幅数据(使用子查询方式,避免窗口函数与 GROUP BY 冲突) batch_price_query = """ WITH first_prices AS (SELECT code, close as first_price \ , ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp ASC) as rn FROM stock_minute WHERE code IN %(codes)s AND timestamp >= %(start)s AND timestamp <= %(end)s ) \ , last_prices AS ( SELECT code, close as last_price, open as open_price, high as high_price, low as low_price, volume, amt as amount, ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn FROM stock_minute WHERE code IN %(codes)s AND timestamp >= %(start)s AND timestamp <= %(end)s ) SELECT fp.code, \ fp.first_price, \ lp.last_price, \ (lp.last_price - fp.first_price) / fp.first_price * 100 as change_pct, \ lp.open_price, \ lp.high_price, \ lp.low_price, \ lp.volume, \ lp.amount FROM first_prices fp INNER JOIN last_prices lp ON fp.code = lp.code WHERE fp.rn = 1 \ AND lp.rn = 1 \ """ price_data = client.execute(batch_price_query, { 'codes': stock_codes, 'start': start_datetime, 'end': end_datetime }) print(f"批量查询返回 {len(price_data)} 条价格数据") # 解析批量查询结果 for row in price_data: code = row[0] first_price = float(row[1]) if row[1] is not None else None last_price = float(row[2]) if row[2] is not None else None change_pct = float(row[3]) if row[3] is not None else None open_price = float(row[4]) if row[4] is not None else None high_price = float(row[5]) if row[5] is not None else None low_price = float(row[6]) if row[6] is not None else None volume = int(row[7]) if row[7] is not None else None amount = float(row[8]) if row[8] is not None else None change_amount = None if last_price is not None and first_price is not None: change_amount = last_price - first_price price_data_map[code] = { 'latest_price': last_price, 'first_price': first_price, 'change_pct': change_pct, 'change_amount': change_amount, 'open_price': open_price, 'high_price': high_price, 'low_price': low_price, 'volume': volume, 'amount': amount, } # 3.2 批量查询分时图数据 print(f"批量查询分时图数据...") minute_chart_query = """ SELECT code, timestamp, open, high, low, close, volume, amt FROM stock_minute WHERE code IN %(codes)s AND timestamp >= %(start)s AND timestamp <= %(end)s ORDER BY code, timestamp \ """ minute_data = client.execute(minute_chart_query, { 'codes': stock_codes, 'start': start_datetime, 'end': end_datetime }) print(f"批量查询返回 {len(minute_data)} 条分时数据") # 按股票代码分组分时数据 for row in minute_data: code = row[0] if code not in minute_chart_map: minute_chart_map[code] = [] minute_chart_map[code].append({ 'time': row[1].strftime('%H:%M'), 'open': float(row[2]) if row[2] else None, 'high': float(row[3]) if row[3] else None, 'low': float(row[4]) if row[4] else None, 'close': float(row[5]) if row[5] else None, 'volume': float(row[6]) if row[6] else None, 'amount': float(row[7]) if row[7] else None }) except Exception as e: print(f"批量查询 ClickHouse 失败: {e}") # 如果批量查询失败,price_data_map 和 minute_chart_map 为空,后续会使用降级方案 # 4. 组装每个股票的数据(从批量查询结果中获取) stocks_data = [] for stock in related_stocks: print(f"正在组装股票 {stock.stock_code} 的数据...") # 从批量查询结果中获取股票基本信息 stock_info = stock_info_map.get(stock.stock_code) # 从批量查询结果中获取价格数据 price_info = price_data_map.get(stock.stock_code) latest_price = None first_price = None change_pct = None change_amount = None open_price = None high_price = None low_price = None volume = None amount = None trade_date = trading_day if price_info: # 使用批量查询的结果 latest_price = price_info['latest_price'] first_price = price_info['first_price'] change_pct = price_info['change_pct'] change_amount = price_info['change_amount'] open_price = price_info['open_price'] high_price = price_info['high_price'] low_price = price_info['low_price'] volume = price_info['volume'] amount = price_info['amount'] else: # 如果批量查询没有返回数据,使用降级方案(TradeData) print(f"股票 {stock.stock_code} 批量查询无数据,使用降级方案...") try: latest_trade = None search_codes = [stock.stock_code, stock.stock_code.split('.')[0]] for code in search_codes: latest_trade = TradeData.query.filter_by(SECCODE=code) \ .order_by(TradeData.TRADEDATE.desc()).first() if latest_trade: break if latest_trade: latest_price = float(latest_trade.F007N) if latest_trade.F007N else None open_price = float(latest_trade.F003N) if latest_trade.F003N else None high_price = float(latest_trade.F005N) if latest_trade.F005N else None low_price = float(latest_trade.F006N) if latest_trade.F006N else None first_price = float(latest_trade.F002N) if latest_trade.F002N else None volume = float(latest_trade.F004N) if latest_trade.F004N else None amount = float(latest_trade.F011N) if latest_trade.F011N else None trade_date = latest_trade.TRADEDATE # 计算涨跌幅 if latest_trade.F010N: change_pct = float(latest_trade.F010N) if latest_trade.F009N: change_amount = float(latest_trade.F009N) except Exception as fallback_error: print(f"降级查询也失败 {stock.stock_code}: {fallback_error}") # 从批量查询结果中获取分时图数据 minute_chart_data = minute_chart_map.get(stock.stock_code, []) stock_data = { 'id': stock.id, 'stock_code': stock.stock_code, 'stock_name': stock.stock_name, 'sector': stock.sector, 'relation_desc': stock.relation_desc, 'correlation': stock.correlation, 'momentum': stock.momentum, 'listing_date': stock_info.F006D.isoformat() if stock_info and stock_info.F006D else None, 'market': stock_info.F005V if stock_info else None, # 交易数据 'trade_data': { 'latest_price': latest_price, 'first_price': first_price, # 事件发生时的价格 'open_price': open_price, 'high_price': high_price, 'low_price': low_price, 'change_amount': round(change_amount, 2) if change_amount is not None else None, 'change_pct': round(change_pct, 2) if change_pct is not None else None, 'volume': volume, 'amount': amount, 'trade_date': trade_date.isoformat() if trade_date else None, 'event_start_time': start_datetime.isoformat() if start_datetime else None, # 事件开始时间 'event_end_time': end_datetime.isoformat() if end_datetime else None, # 查询结束时间 } if latest_price is not None else None, # 分时图数据 'minute_chart_data': minute_chart_data, # 图表URL 'charts': { 'minute_chart_url': f"/api/stock/{stock.stock_code}/minute-chart", 'daily_chart_url': f"/api/stock/{stock.stock_code}/kline", } } stocks_data.append(stock_data) return jsonify({ 'code': 200, 'message': 'success', 'data': { 'event_id': event_id, 'event_title': event.title, 'event_desc': event.description, 'event_type': event.event_type, 'event_importance': event.importance, 'event_status': event.status, 'event_created_at': event.created_at.strftime("%Y-%m-%d %H:%M:%S"), 'event_start_time': event.start_time.isoformat() if event.start_time else None, 'event_end_time': event.end_time.isoformat() if event.end_time else None, 'keywords': event.keywords, 'view_count': event.view_count, 'post_count': event.post_count, 'follower_count': event.follower_count, 'related_stocks': stocks_data, 'total_count': len(stocks_data) } }) except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 @app.route('/api/stock//minute-chart', methods=['GET']) def get_minute_chart_data(stock_code): """获取股票分时图数据 - 仅限 Pro/Max 会员""" client = get_clickhouse_client() try: # 获取当前日期或最新交易日的分时数据 from datetime import datetime, timedelta, time as dt_time today = datetime.now().date() # 获取最新交易日的分时数据 data = client.execute(""" SELECT timestamp, open, high, low, close, volume, amt FROM stock_minute WHERE code = %(code)s AND timestamp >= %(start)s AND timestamp <= %(end)s ORDER BY timestamp """, { 'code': stock_code, 'start': datetime.combine(today, dt_time(9, 30)), 'end': datetime.combine(today, dt_time(15, 0)) }) # 如果今天没有数据,获取最近的交易日数据 if not data: # 获取最近的交易日数据 recent_data = client.execute(""" SELECT timestamp, open, high, low, close, volume, amt FROM stock_minute WHERE code = %(code)s AND timestamp >= ( SELECT MAX (timestamp) - INTERVAL 1 DAY FROM stock_minute WHERE code = %(code)s ) ORDER BY timestamp """, { 'code': stock_code }) data = recent_data # 格式化数据 minute_data = [] for row in data: minute_data.append({ 'time': row[0].strftime('%H:%M'), 'open': float(row[1]) if row[1] else None, 'high': float(row[2]) if row[2] else None, 'low': float(row[3]) if row[3] else None, 'close': float(row[4]) if row[4] else None, 'volume': float(row[5]) if row[5] else None, 'amount': float(row[6]) if row[6] else None }) return minute_data except Exception as e: print(f"Error getting minute chart data: {e}") return [] @app.route('/api/event//stock//detail', methods=['GET']) def api_stock_detail(event_id, stock_code): """个股详情接口 - 仅限 Pro/Max 会员""" try: # 验证事件是否存在 event = Event.query.get_or_404(event_id) # 获取查询参数 include_minute_data = request.args.get('include_minute_data', 'true').lower() == 'true' include_full_sources = request.args.get('include_full_sources', 'false').lower() == 'true' # 是否包含完整研报来源 # 获取股票基本信息 basic_info = None base_code = stock_code.split('.')[0] # 去掉后缀 # 按优先级查找股票信息 basic_info = StockBasicInfo.query.filter_by(SECCODE=stock_code).first() if not basic_info: basic_info = StockBasicInfo.query.filter( StockBasicInfo.SECCODE.ilike(f"{stock_code}%") ).first() if not basic_info: basic_info = StockBasicInfo.query.filter( StockBasicInfo.SECCODE.ilike(f"{base_code}%") ).first() company_info = CompanyInfo.query.filter_by(SECCODE=stock_code).first() if not company_info: company_info = CompanyInfo.query.filter_by(SECCODE=base_code).first() if not basic_info: return jsonify({ 'code': 404, 'stock_code': stock_code, 'message': '股票不存在', 'data': None }), 404 # 获取最新交易数据 latest_trade = TradeData.query.filter_by(SECCODE=stock_code) \ .order_by(TradeData.TRADEDATE.desc()).first() if not latest_trade: latest_trade = TradeData.query.filter_by(SECCODE=base_code) \ .order_by(TradeData.TRADEDATE.desc()).first() # 获取分时数据 minute_chart_data = [] if include_minute_data: minute_chart_data = get_minute_chart_data(stock_code) # 获取该事件的相关描述 related_stock = RelatedStock.query.filter_by( event_id=event_id ).filter( db.or_( RelatedStock.stock_code == stock_code, RelatedStock.stock_code == base_code, RelatedStock.stock_code.like(f"{base_code}.%") ) ).first() related_desc = None if related_stock: # 处理研报来源数据 retrieved_sources_data = None sources_summary = None if related_stock.retrieved_sources: try: # 解析研报来源 import json sources = related_stock.retrieved_sources if isinstance(related_stock.retrieved_sources, list) else json.loads( related_stock.retrieved_sources) # 统计信息 sources_summary = { 'total_count': len(sources), 'has_sources': True, 'match_scores': {} } # 统计匹配分数分布 for source in sources: score = source.get('match_score', '未知') sources_summary['match_scores'][score] = sources_summary['match_scores'].get(score, 0) + 1 # 根据参数决定返回完整数据还是摘要 if include_full_sources: # 返回完整的研报来源 retrieved_sources_data = sources else: # 只返回前5条高质量来源作为预览 # 优先返回匹配度高的 high_quality_sources = [s for s in sources if s.get('match_score') == '好'][:3] medium_quality_sources = [s for s in sources if s.get('match_score') == '中'][:2] preview_sources = high_quality_sources + medium_quality_sources if not preview_sources: # 如果没有高中匹配度的,返回前5条 preview_sources = sources[:5] retrieved_sources_data = [] for source in preview_sources: retrieved_sources_data.append({ 'report_title': source.get('report_title', ''), 'author': source.get('author', ''), 'sentences': source.get('sentences', '')[:200] + '...' if len( source.get('sentences', '')) > 200 else source.get('sentences', ''), # 限制长度 'match_score': source.get('match_score', ''), 'declare_date': source.get('declare_date', '') }) except Exception as e: print(f"Error processing retrieved_sources for stock {stock_code}: {e}") sources_summary = {'has_sources': False, 'error': str(e)} else: sources_summary = {'has_sources': False, 'total_count': 0} related_desc = { 'event_id': related_stock.event_id, 'relation_desc': related_stock.relation_desc, 'sector': related_stock.sector, 'correlation': float(related_stock.correlation) if related_stock.correlation else None, 'momentum': related_stock.momentum, # 新增研报来源相关字段 'retrieved_sources': retrieved_sources_data, 'sources_summary': sources_summary, 'retrieved_update_time': related_stock.retrieved_update_time.isoformat() if related_stock.retrieved_update_time else None, # 添加获取完整来源的URL 'sources_detail_url': f"/api/event/{event_id}/stock/{stock_code}/sources" if sources_summary.get( 'has_sources') else None } response_data = { 'code': 200, 'message': 'success', 'data': { 'event_info': { 'event_id': event.id, 'event_title': event.title, 'event_description': event.description }, 'basic_info': { 'stock_code': basic_info.SECCODE, 'stock_name': basic_info.SECNAME, 'org_name': basic_info.ORGNAME, 'pinyin': basic_info.F001V, 'category': basic_info.F003V, 'market': basic_info.F005V, 'listing_date': basic_info.F006D.isoformat() if basic_info.F006D else None, 'status': basic_info.F011V }, 'company_info': { 'english_name': company_info.F001V if company_info else None, 'legal_representative': company_info.F003V if company_info else None, 'main_business': company_info.F015V if company_info else None, 'business_scope': company_info.F016V if company_info else None, 'company_intro': company_info.F017V if company_info else None, 'csrc_industry_l1': company_info.F030V if company_info else None, 'csrc_industry_l2': company_info.F032V if company_info else None }, 'latest_trade': { 'trade_date': latest_trade.TRADEDATE.isoformat() if latest_trade else None, 'close_price': float(latest_trade.F007N) if latest_trade and latest_trade.F007N else None, 'change': float(latest_trade.F009N) if latest_trade and latest_trade.F009N else None, 'change_pct': float(latest_trade.F010N) if latest_trade and latest_trade.F010N else None, 'volume': float(latest_trade.F004N) if latest_trade and latest_trade.F004N else None, 'amount': float(latest_trade.F011N) if latest_trade and latest_trade.F011N else None } if latest_trade else None, 'minute_chart_data': minute_chart_data, 'related_desc': related_desc } } response = jsonify(response_data) response.headers['Content-Type'] = 'application/json; charset=utf-8' return response except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 def get_stock_minute_chart_data(stock_code): """获取股票分时图数据""" try: client = get_clickhouse_client() # 获取当前日期(使用最新的交易日) from datetime import datetime, timedelta, time as dt_time import csv def get_trading_days(): trading_days = set() with open('tdays.csv', 'r') as f: reader = csv.DictReader(f) for row in reader: trading_days.add(datetime.strptime(row['DateTime'], '%Y/%m/%d').date()) return trading_days trading_days = get_trading_days() def find_latest_trading_day(current_date): """找到最新的交易日""" while current_date >= min(trading_days): if current_date in trading_days: return current_date current_date -= timedelta(days=1) return None target_date = find_latest_trading_day(datetime.now().date()) if not target_date: return [] # 获取分时数据 data = client.execute(""" SELECT timestamp, open, high, low, close, volume, amt FROM stock_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': stock_code, 'start': datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)) }) minute_data = [] for row in data: minute_data.append({ 'time': row[0].strftime('%H:%M'), 'open': float(row[1]), 'high': float(row[2]), 'low': float(row[3]), 'close': float(row[4]), 'volume': float(row[5]), 'amount': float(row[6]) }) return minute_data except Exception as e: print(f"Error getting minute chart data: {e}") return [] # 7. 事件详情-相关概念接口 @app.route('/api/event//related-concepts', methods=['GET']) def api_event_related_concepts(event_id): """事件相关概念接口""" try: event = Event.query.get_or_404(event_id) related_concepts = event.related_concepts.all() base_url = request.host_url concepts_data = [] for concept in related_concepts: image_paths = concept.image_paths_list image_urls = [base_url + 'data/concepts/' + p for p in image_paths] concepts_data.append({ 'id': concept.id, 'concept_code': concept.concept_code, 'concept': concept.concept, 'reason': concept.reason, 'image_paths': image_paths, 'image_urls': image_urls, 'first_image': image_urls[0] if image_urls else None }) return jsonify({ 'code': 200, 'message': 'success', 'data': { 'event_id': event_id, 'event_title': event.title, 'related_concepts': concepts_data, 'total_count': len(concepts_data) } }) except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 # 8. 事件详情-历史事件接口 @app.route('/api/event//historical-events', methods=['GET']) def api_event_historical_events(event_id): """事件历史事件接口""" try: event = Event.query.get_or_404(event_id) historical_events = event.historical_events.order_by( HistoricalEvent.importance.desc(), HistoricalEvent.event_date.desc() ).all() events_data = [] for hist_event in historical_events: # 获取相关股票信息 related_stocks = [] valid_changes = [] # 用于计算涨跌幅 for stock in hist_event.stocks.all(): base_stock_code = stock.stock_code.split('.')[0] # 获取股票当日交易数据 trade_data = TradeData.query.filter( TradeData.SECCODE.startswith(base_stock_code) ).order_by(TradeData.TRADEDATE.desc()).first() if trade_data and trade_data.F010N is not None: daily_change = float(trade_data.F010N) valid_changes.append(daily_change) else: daily_change = None stock_data = { 'stock_code': stock.stock_code, 'stock_name': stock.stock_name, 'relation_desc': stock.relation_desc, 'correlation': stock.correlation, 'sector': stock.sector, 'daily_change': daily_change, 'has_trade_data': True if trade_data else False } related_stocks.append(stock_data) # 计算相关股票的平均涨幅和最大涨幅 avg_change = None max_change = None if valid_changes: avg_change = sum(valid_changes) / len(valid_changes) max_change = max(valid_changes) events_data.append({ 'id': hist_event.id, 'title': hist_event.title, 'content': hist_event.content, 'event_date': hist_event.event_date.isoformat() if hist_event.event_date else None, 'relevance': hist_event.relevance, 'importance': hist_event.importance, 'related_stocks': related_stocks, # 使用计算得到的涨幅数据 'related_avg_chg': round(avg_change, 2) if avg_change is not None else None, 'related_max_chg': round(max_change, 2) if max_change is not None else None }) # 计算当前事件的相关股票涨幅数据 current_valid_changes = [] for stock in event.related_stocks: base_stock_code = stock.stock_code.split('.')[0] trade_data = TradeData.query.filter( TradeData.SECCODE.startswith(base_stock_code) ).order_by(TradeData.TRADEDATE.desc()).first() if trade_data and trade_data.F010N is not None: current_valid_changes.append(float(trade_data.F010N)) current_avg_change = None current_max_change = None if current_valid_changes: current_avg_change = sum(current_valid_changes) / len(current_valid_changes) current_max_change = max(current_valid_changes) return jsonify({ 'code': 200, 'message': 'success', 'data': { 'event_id': event_id, 'event_title': event.title, 'invest_score': event.invest_score, 'related_avg_chg': round(current_avg_change, 2) if current_avg_change is not None else None, 'related_max_chg': round(current_max_change, 2) if current_max_change is not None else None, 'historical_events': events_data, 'total_count': len(events_data) } }) except Exception as e: print(f"Error in api_event_historical_events: {str(e)}") return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 @app.route('/api/event//comments', methods=['GET']) def get_event_comments(event_id): """获取事件的所有评论和帖子(嵌套格式) Query参数: - page: 页码(默认1) - per_page: 每页评论数(默认20) - sort: 排序方式(time_desc/time_asc/hot, 默认time_desc) - include_posts: 是否包含帖子信息(默认true) - reply_limit: 每个评论显示的回复数量限制(默认3) 返回: { "success": true, "data": { "event": { "id": 事件ID, "title": "事件标题", "description": "事件描述" }, "posts": [帖子信息], "total": 总评论数, "current_page": 当前页码, "total_pages": 总页数, "comments": [ { "comment_id": 评论ID, "content": "评论内容", "created_at": "评论时间", "post_id": 所属帖子ID, "post_title": "帖子标题", "user": { "user_id": 用户ID, "nickname": "用户昵称", "avatar_url": "头像URL" }, "reply_count": 总回复数量, "has_more_replies": 是否有更多回复, "list": [ # 回复列表 { "comment_id": 回复ID, "content": "回复内容", "created_at": "回复时间", "user": { "user_id": 用户ID, "nickname": "用户昵称", "avatar_url": "头像URL" }, "reply_to": { # 被回复的用户信息 "user_id": 用户ID, "nickname": "用户昵称" } } ] } ] } } """ try: # 获取查询参数 page = request.args.get('page', 1, type=int) per_page = request.args.get('per_page', 20, type=int) sort = request.args.get('sort', 'time_desc') include_posts = request.args.get('include_posts', 'true').lower() == 'true' reply_limit = request.args.get('reply_limit', 3, type=int) # 每个评论显示的回复数限制 # 参数验证 if page < 1: page = 1 if per_page < 1 or per_page > 100: per_page = 20 if reply_limit < 0 or reply_limit > 50: # 限制回复数量 reply_limit = 3 # 获取事件信息 event = Event.query.get_or_404(event_id) # 获取事件下的所有帖子 posts_query = Post.query.filter_by(event_id=event_id, status='active') \ .order_by(Post.is_top.desc(), Post.created_at.desc()) posts = posts_query.all() # 格式化帖子数据 posts_data = [] if include_posts: for post in posts: posts_data.append({ 'post_id': post.id, 'title': post.title, 'content': post.content, 'content_type': post.content_type, 'created_at': post.created_at.strftime('%Y-%m-%d %H:%M:%S'), 'updated_at': post.updated_at.strftime('%Y-%m-%d %H:%M:%S') if post.updated_at else None, 'likes_count': post.likes_count, 'comments_count': post.comments_count, 'view_count': post.view_count, 'is_top': post.is_top, 'user': { 'user_id': post.user.id, 'username': post.user.username, 'nickname': post.user.nickname or post.user.username, 'avatar_url': get_full_avatar_url(post.user.avatar_url), 'user_level': post.user.user_level, 'is_verified': post.user.is_verified, 'is_creator': post.user.is_creator } }) # 获取帖子ID列表用于查询评论 post_ids = [p.id for p in posts] if not post_ids: return jsonify({ 'success': True, 'data': { 'event': { 'id': event.id, 'title': event.title, 'description': event.description, 'event_type': event.event_type, 'importance': event.importance, 'status': event.status }, 'posts': posts_data, 'total': 0, 'current_page': page, 'total_pages': 0, 'comments': [] } }) # 构建基础查询 - 只查询主评论 base_query = Comment.query.filter( Comment.post_id.in_(post_ids), Comment.parent_id == None, # 只查询主评论 Comment.status == 'active' ) # 排序处理 if sort == 'time_asc': base_query = base_query.order_by(Comment.created_at.asc()) elif sort == 'hot': # 这里可以根据你的业务逻辑添加热度排序 base_query = base_query.order_by(Comment.created_at.desc()) else: # 默认按时间倒序 base_query = base_query.order_by(Comment.created_at.desc()) # 执行分页查询 pagination = base_query.paginate(page=page, per_page=per_page, error_out=False) # 格式化评论数据(嵌套格式) comments_data = [] for comment in pagination.items: # 获取评论的总回复数量 reply_count = Comment.query.filter_by( parent_id=comment.id, status='active' ).count() # 获取指定数量的回复 replies_query = Comment.query.filter_by( parent_id=comment.id, status='active' ).order_by(Comment.created_at.asc()) # 回复按时间正序排列 if reply_limit > 0: replies = replies_query.limit(reply_limit).all() else: replies = [] # 格式化回复数据 - 作为list字段 replies_list = [] for reply in replies: # 获取被回复的用户信息(这里是回复主评论,所以reply_to就是主评论的用户) reply_to_user = { 'user_id': comment.user.id, 'nickname': comment.user.nickname or comment.user.username } replies_list.append({ 'comment_id': reply.id, 'content': reply.content, 'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S'), 'user': { 'user_id': reply.user.id, 'username': reply.user.username, 'nickname': reply.user.nickname or reply.user.username, 'avatar_url': get_full_avatar_url(reply.user.avatar_url), 'user_level': reply.user.user_level, 'is_verified': reply.user.is_verified }, 'reply_to': reply_to_user }) # 获取评论所属的帖子信息 post = comment.post # 构建嵌套格式的评论数据 comments_data.append({ 'comment_id': comment.id, 'content': comment.content, 'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'), 'post_id': comment.post_id, 'post_title': post.title if post else None, 'post_content_preview': post.content[:100] + '...' if post and len( post.content) > 100 else post.content if post else None, 'user': { 'user_id': comment.user.id, 'username': comment.user.username, 'nickname': comment.user.nickname or comment.user.username, 'avatar_url': get_full_avatar_url(comment.user.avatar_url), 'user_level': comment.user.user_level, 'is_verified': comment.user.is_verified }, 'reply_count': reply_count, 'has_more_replies': reply_count > len(replies_list), 'list': replies_list # 嵌套的回复列表 }) return jsonify({ 'success': True, 'data': { 'event': { 'id': event.id, 'title': event.title, 'description': event.description, 'event_type': event.event_type, 'importance': event.importance, 'status': event.status, 'created_at': event.created_at.strftime('%Y-%m-%d %H:%M:%S'), 'hot_score': event.hot_score, 'view_count': event.view_count, 'post_count': event.post_count, 'follower_count': event.follower_count }, 'posts': posts_data, 'posts_count': len(posts_data), 'total': pagination.total, 'current_page': pagination.page, 'total_pages': pagination.pages, 'comments': comments_data } }) except Exception as e: return jsonify({ 'success': False, 'message': '获取评论列表失败', 'error': str(e) }), 500 @app.route('/api/comment//replies', methods=['GET']) def get_comment_replies(comment_id): """获取某条评论的所有回复 Query参数: - page: 页码(默认1) - per_page: 每页回复数(默认20) - sort: 排序方式(time_desc/time_asc, 默认time_desc) 返回格式: { "code": 200, "message": "success", "data": { "comment": { # 原评论信息 "id": 评论ID, "content": "评论内容", "created_at": "评论时间", "user": { "id": 用户ID, "nickname": "用户昵称", "avatar_url": "头像URL" } }, "replies": { # 回复信息 "total": 总回复数, "current_page": 当前页码, "total_pages": 总页数, "items": [ { "id": 回复ID, "content": "回复内容", "created_at": "回复时间", "user": { "id": 用户ID, "nickname": "用户昵称", "avatar_url": "头像URL" }, "reply_to": { # 被回复的用户信息 "id": 用户ID, "nickname": "用户昵称" } } ] } } } """ try: # 获取查询参数 page = request.args.get('page', 1, type=int) per_page = request.args.get('per_page', 20, type=int) sort = request.args.get('sort', 'time_desc') # 参数验证 if page < 1: page = 1 if per_page < 1 or per_page > 100: per_page = 20 # 获取原评论信息 comment = Comment.query.get_or_404(comment_id) if comment.status != 'active': return jsonify({ 'code': 404, 'message': '评论不存在或已被删除', 'data': None }), 404 # 构建原评论数据 comment_data = { 'comment_id': comment.id, 'content': comment.content, 'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'), 'user': { 'user_id': comment.user.id, 'nickname': comment.user.nickname or comment.user.username, 'avatar_url': get_full_avatar_url(comment.user.avatar_url), # 修改这里 } } # 构建回复查询 replies_query = Comment.query.filter_by( parent_id=comment_id, status='active' ) # 排序处理 if sort == 'time_asc': replies_query = replies_query.order_by(Comment.created_at.asc()) else: # 默认按时间倒序 replies_query = replies_query.order_by(Comment.created_at.desc()) # 执行分页查询 pagination = replies_query.paginate(page=page, per_page=per_page, error_out=False) # 格式化回复数据 replies_data = [] for reply in pagination.items: # 获取被回复的用户信息 reply_to_user = None if reply.parent_id: parent_comment = Comment.query.get(reply.parent_id) if parent_comment: reply_to_user = { 'id': parent_comment.user.id, 'nickname': parent_comment.user.nickname or parent_comment.user.username } replies_data.append({ 'reply_id': reply.id, 'content': reply.content, 'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S'), 'user': { 'user_id': reply.user.id, 'nickname': reply.user.nickname or reply.user.username, 'avatar_url': get_full_avatar_url(reply.user.avatar_url), # 修改这里 }, 'reply_to': reply_to_user }) return jsonify({ 'code': 200, 'message': 'success', 'data': { 'comment': comment_data, 'replies': { 'total': pagination.total, 'current_page': pagination.page, 'total_pages': pagination.pages, 'items': replies_data } } }) except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 # 工具函数:处理转义字符,保留 Markdown 格式 def unescape_markdown_text(text): """ 将数据库中存储的转义字符串转换为真正的换行符和特殊字符 例如:'\\n\\n#### 标题' -> '\n\n#### 标题' """ if not text: return text # 将转义的换行符转换为真正的换行符 # 注意:这里处理的是字符串字面量 '\\n',不是转义序列 text = text.replace('\\n', '\n') text = text.replace('\\r', '\r') text = text.replace('\\t', '\t') return text.strip() # 工具函数:清理 Markdown 文本 def clean_markdown_text(text): """清理文本中的 Markdown 符号和多余的换行符 Args: text: 原始文本(可能包含 Markdown 符号) Returns: 清理后的纯文本 """ if not text: return text import re # 1. 移除 Markdown 标题符号 (### , ## , # ) text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE) # 2. 移除 Markdown 加粗符号 (**text** 或 __text__) text = re.sub(r'\*\*(.+?)\*\*', r'\1', text) text = re.sub(r'__(.+?)__', r'\1', text) # 3. 移除 Markdown 斜体符号 (*text* 或 _text_) text = re.sub(r'\*(.+?)\*', r'\1', text) text = re.sub(r'_(.+?)_', r'\1', text) # 4. 移除 Markdown 列表符号 (- , * , + , 1. ) text = re.sub(r'^[\s]*[-*+]\s+', '', text, flags=re.MULTILINE) text = re.sub(r'^[\s]*\d+\.\s+', '', text, flags=re.MULTILINE) # 5. 移除 Markdown 引用符号 (> ) text = re.sub(r'^>\s+', '', text, flags=re.MULTILINE) # 6. 移除 Markdown 代码块符号 (``` 或 `) text = re.sub(r'```[\s\S]*?```', '', text) text = re.sub(r'`(.+?)`', r'\1', text) # 7. 移除 Markdown 链接 ([text](url) -> text) text = re.sub(r'\[(.+?)\]\(.+?\)', r'\1', text) # 8. 清理多余的换行符 # 将多个连续的换行符(\n\n\n...)替换为单个换行符 text = re.sub(r'\n{3,}', '\n\n', text) # 9. 清理行首行尾的空白字符 text = re.sub(r'^\s+|\s+$', '', text, flags=re.MULTILINE) # 10. 移除多余的空格(连续多个空格替换为单个空格) text = re.sub(r' {2,}', ' ', text) # 11. 清理首尾空白 text = text.strip() return text # 10. 投资日历-事件接口(增强版) @app.route('/api/calendar/events', methods=['GET']) def api_calendar_events(): """投资日历事件接口 - 连接 future_events 表 (修正版)""" try: start_date = request.args.get('start') end_date = request.args.get('end') importance = request.args.get('importance', 'all') category = request.args.get('category', 'all') search_query = request.args.get('q', '').strip() # 新增搜索参数 page = int(request.args.get('page', 1)) per_page = int(request.args.get('per_page', 10)) offset = (page - 1) * per_page # 构建基础查询 - 使用 future_events 表 query = """ SELECT data_id, \ calendar_time, \ type, \ star, \ title, \ former, \ forecast, \ fact, \ related_stocks, \ concepts, \ inferred_tag FROM future_events WHERE 1 = 1 \ """ params = {} if start_date: query += " AND calendar_time >= :start_date" params['start_date'] = datetime.fromisoformat(start_date) if end_date: query += " AND calendar_time <= :end_date" params['end_date'] = datetime.fromisoformat(end_date) if importance != 'all': query += " AND star = :importance" params['importance'] = importance if category != 'all': # category参数用于筛选inferred_tag字段(如"大周期"、"大消费"等) query += " AND inferred_tag = :category" params['category'] = category # 新增搜索条件 if search_query: # 使用LIKE进行模糊搜索,同时搜索title和related_stocks字段 # 对于JSON字段,MySQL会将其作为文本进行搜索 query += """ AND ( title LIKE :search_pattern OR CAST(related_stocks AS CHAR) LIKE :search_pattern OR CAST(concepts AS CHAR) LIKE :search_pattern )""" params['search_pattern'] = f'%{search_query}%' query += " ORDER BY calendar_time LIMIT :limit OFFSET :offset" params['limit'] = per_page params['offset'] = offset result = db.session.execute(text(query), params) events = result.fetchall() # 总数统计(不包含分页) count_query = """ SELECT COUNT(*) as count \ FROM future_events \ WHERE 1=1 \ """ count_params = params.copy() count_params.pop('limit', None) count_params.pop('offset', None) if start_date: count_query += " AND calendar_time >= :start_date" if end_date: count_query += " AND calendar_time <= :end_date" if importance != 'all': count_query += " AND star = :importance" if category != 'all': count_query += " AND inferred_tag = :category" # 新增搜索条件到计数查询 if search_query: count_query += """ AND ( title LIKE :search_pattern OR CAST(related_stocks AS CHAR) LIKE :search_pattern OR CAST(concepts AS CHAR) LIKE :search_pattern )""" total_count_result = db.session.execute(text(count_query), count_params).fetchone() total_count = total_count_result.count if total_count_result else 0 events_data = [] for event in events: # 解析相关股票 related_stocks_list = [] related_avg_chg = 0 related_max_chg = 0 related_week_chg = 0 # 处理相关股票数据 if event.related_stocks: try: import json import ast # 使用与detail接口相同的解析逻辑 if isinstance(event.related_stocks, str): try: stock_data = json.loads(event.related_stocks) except: stock_data = ast.literal_eval(event.related_stocks) else: stock_data = event.related_stocks if stock_data: daily_changes = [] week_changes = [] # 处理正确的数据格式 [股票代码, 股票名称, 描述, 分数] for stock_info in stock_data: if isinstance(stock_info, list) and len(stock_info) >= 2: stock_code = stock_info[0] # 股票代码 stock_name = stock_info[1] # 股票名称 description = stock_info[2] if len(stock_info) > 2 else '' score = stock_info[3] if len(stock_info) > 3 else 0 else: continue if stock_code: # 规范化股票代码,移除后缀 clean_code = stock_code.replace('.SZ', '').replace('.SH', '').replace('.BJ', '') # 使用模糊匹配查询真实的交易数据 trade_query = """ SELECT F007N as close_price, F010N as change_pct, TRADEDATE FROM ea_trade WHERE SECCODE LIKE :stock_code_pattern ORDER BY TRADEDATE DESC LIMIT 7 \ """ trade_result = db.session.execute(text(trade_query), {'stock_code_pattern': f'{clean_code}%'}) trade_data = trade_result.fetchall() daily_chg = 0 week_chg = 0 if trade_data: # 日涨跌幅(当日) daily_chg = float(trade_data[0].change_pct or 0) # 周涨跌幅(5个交易日) if len(trade_data) >= 5: current_price = float(trade_data[0].close_price or 0) week_ago_price = float(trade_data[4].close_price or 0) if week_ago_price > 0: week_chg = ((current_price - week_ago_price) / week_ago_price) * 100 # 收集涨跌幅数据 daily_changes.append(daily_chg) week_changes.append(week_chg) related_stocks_list.append({ 'code': stock_code, 'name': stock_name, 'description': description, 'score': score, 'daily_chg': daily_chg, 'week_chg': week_chg }) # 计算平均收益率 if daily_changes: related_avg_chg = round(sum(daily_changes) / len(daily_changes), 4) related_max_chg = round(max(daily_changes), 4) if week_changes: related_week_chg = round(sum(week_changes) / len(week_changes), 4) except Exception as e: print(f"Error processing related stocks for event {event.data_id}: {e}") # 解析相关概念 related_concepts = extract_concepts_from_concepts_field(event.concepts) # 获取评星等级 star_rating = event.star # 如果有搜索关键词,可以高亮显示匹配的部分(可选功能) highlight_match = False if search_query: # 检查是否在标题中匹配 if search_query.lower() in (event.title or '').lower(): highlight_match = 'title' # 检查是否在股票中匹配 elif any(search_query.lower() in str(stock).lower() for stock in related_stocks_list): highlight_match = 'stocks' # 检查是否在概念中匹配 elif search_query.lower() in str(related_concepts).lower(): highlight_match = 'concepts' # 将转义的换行符转换为真正的换行符,保留 Markdown 格式 cleaned_former = unescape_markdown_text(event.former) cleaned_forecast = unescape_markdown_text(event.forecast) cleaned_fact = unescape_markdown_text(event.fact) event_dict = { 'id': event.data_id, 'title': event.title, 'description': f"前值: {cleaned_former}, 预测: {cleaned_forecast}, 实际: {cleaned_fact}" if cleaned_former or cleaned_forecast or cleaned_fact else "", 'start_time': event.calendar_time.isoformat() if event.calendar_time else None, 'end_time': None, # future_events 表没有结束时间 'category': { 'event_type': event.type, 'importance': event.star, 'star_rating': star_rating, 'inferred_tag': event.inferred_tag # 添加inferred_tag到返回数据 }, 'star_rating': star_rating, 'inferred_tag': event.inferred_tag, # 直接返回行业标签 'related_concepts': related_concepts, 'related_stocks': related_stocks_list, 'related_avg_chg': round(related_avg_chg, 2), 'related_max_chg': round(related_max_chg, 2), 'related_week_chg': round(related_week_chg, 2), 'former': cleaned_former, 'forecast': cleaned_forecast, 'fact': cleaned_fact } # 可选:添加搜索匹配标记 if search_query and highlight_match: event_dict['search_match'] = highlight_match events_data.append(event_dict) return jsonify({ 'code': 200, 'message': 'success', 'data': { 'events': events_data, 'total_count': total_count, 'page': page, 'per_page': per_page, 'total_pages': (total_count + per_page - 1) // per_page, 'search_query': search_query # 返回搜索关键词 } }) except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 # 11. 投资日历-数据接口 @app.route('/api/calendar/data', methods=['GET']) def api_calendar_data(): """投资日历数据接口""" try: start_date = request.args.get('start') end_date = request.args.get('end') data_type = request.args.get('type', 'all') # 分页参数 page = int(request.args.get('page', 1)) page_size = int(request.args.get('page_size', 20)) # 默认每页20条 # 验证分页参数 if page < 1: page = 1 if page_size < 1 or page_size > 100: # 限制每页最大100条 page_size = 20 query1 = RelatedData.query if start_date: query1 = query1.filter(RelatedData.created_at >= datetime.fromisoformat(start_date)) if end_date: query1 = query1.filter(RelatedData.created_at <= datetime.fromisoformat(end_date)) if data_type != 'all': query1 = query1.filter_by(data_type=data_type) data_list1 = query1.order_by(RelatedData.created_at.desc()).all() query2_sql = """ SELECT data_id as id, \ title, \ type as data_type, \ former, \ forecast, \ fact, \ star, \ calendar_time as created_at FROM future_events WHERE type = 'data' \ """ # 添加时间筛选条件 params = {} if start_date: query2_sql += " AND calendar_time >= :start_date" params['start_date'] = start_date if end_date: query2_sql += " AND calendar_time <= :end_date" params['end_date'] = end_date if data_type != 'all': query2_sql += " AND type = :data_type" params['data_type'] = data_type query2_sql += " ORDER BY calendar_time DESC" result2 = db.session.execute(text(query2_sql), params) result_data = [] # 处理 RelatedData 的数据 for data in data_list1: result_data.append({ 'id': data.id, 'title': data.title, 'data_type': data.data_type, 'data_content': data.data_content, 'description': data.description, 'created_at': data.created_at.isoformat() if data.created_at else None, 'event_id': data.event_id, 'source': 'related_data', # 标识数据来源 'former': None, 'forecast': None, 'fact': None, 'star': None }) # 处理 future_events 的数据 for row in result2: result_data.append({ 'id': row.id, 'title': row.title, 'data_type': row.data_type, 'data_content': None, 'description': None, 'created_at': row.created_at.isoformat() if row.created_at else None, 'event_id': None, 'source': 'future_events', # 标识数据来源 'former': row.former, 'forecast': row.forecast, 'fact': row.fact, 'star': row.star }) # 按时间排序(最新的在前面) result_data.sort(key=lambda x: x['created_at'] or '1900-01-01', reverse=True) # 计算分页 total_count = len(result_data) total_pages = (total_count + page_size - 1) // page_size # 向上取整 # 计算起始和结束索引 start_index = (page - 1) * page_size end_index = start_index + page_size # 获取当前页数据 current_page_data = result_data[start_index:end_index] # 分别统计两个数据源的数量(用于原有逻辑) related_data_count = len(data_list1) future_events_count = len(list(result2)) return jsonify({ 'code': 200, 'message': 'success', 'data': { 'data_list': current_page_data, 'pagination': { 'current_page': page, 'page_size': page_size, 'total_count': total_count, 'total_pages': total_pages, 'has_next': page < total_pages, 'has_prev': page > 1 }, # 保留原有字段,便于兼容 'total_count': total_count, 'related_data_count': related_data_count, 'future_events_count': future_events_count } }) except ValueError as ve: # 处理分页参数格式错误 return jsonify({ 'code': 400, 'message': f'分页参数格式错误: {str(ve)}', 'data': None }), 400 except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 # 12. 投资日历-详情接口 def extract_concepts_from_concepts_field(concepts_text): """从concepts字段中提取概念信息""" if not concepts_text: return [] try: import json import ast # 解析concepts字段的JSON/字符串数据 if isinstance(concepts_text, str): try: # 先尝试JSON解析 concepts_data = json.loads(concepts_text) except: # 如果JSON解析失败,尝试ast.literal_eval解析 concepts_data = ast.literal_eval(concepts_text) else: concepts_data = concepts_text extracted_concepts = [] for concept_info in concepts_data: if isinstance(concept_info, list) and len(concept_info) >= 3: concept_name = concept_info[0] # 概念名称 reason = concept_info[1] # 原因/描述 score = concept_info[2] # 分数 extracted_concepts.append({ 'name': concept_name, 'reason': reason, 'score': score }) return extracted_concepts except Exception as e: print(f"Error extracting concepts: {e}") return [] @app.route('/api/calendar/detail/', methods=['GET']) def api_future_event_detail(item_id): """未来事件详情接口 - 连接 future_events 表 (修正数据解析) - 仅限 Pro/Max 会员""" try: # 从 future_events 表查询事件详情 query = """ SELECT data_id, \ calendar_time, \ type, \ star, \ title, \ former, \ forecast, \ fact, \ related_stocks, \ concepts FROM future_events WHERE data_id = :item_id \ """ result = db.session.execute(text(query), {'item_id': item_id}) event = result.fetchone() if not event: return jsonify({ 'code': 404, 'message': 'Event not found', 'data': None }), 404 extracted_concepts = extract_concepts_from_concepts_field(event.concepts) # 解析相关股票 related_stocks_list = [] sector_stats = { '全部股票': 0, '大周期': 0, '大消费': 0, 'TMT板块': 0, '大金融地产': 0, '公共产业板块': 0, '其他': 0 } # 申万一级行业到主板块的映射 sector_map = { # 大周期 '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期', '钢铁': '大周期', '基础化工': '大周期', '建筑材料': '大周期', '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期', '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期', '太阳能发电': '大周期', '建筑装饰': '大周期', '建筑': '大周期', '交通运输': '大周期', '采掘': '大周期', '公用事业': '大周期', # 大消费 '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费', '食品饮料': '大消费', '医药生物': '大消费', '纺织服饰': '大消费', '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费', '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费', '纺织服装': '大消费', '商业贸易': '大消费', '休闲服务': '大消费', # 大金融地产 '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产', '多元金融': '大金融地产', '综合金融': '大金融地产', '房地产': '大金融地产', '非银金融': '大金融地产', # TMT板块 '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块', # 公共产业 '环保': '公共产业板块', '综合': '公共产业板块' } # 处理相关股票 related_avg_chg = 0 related_max_chg = 0 related_week_chg = 0 if event.related_stocks: try: import json import ast # **修正:正确解析related_stocks数据结构** if isinstance(event.related_stocks, str): try: # 先尝试JSON解析 stock_data = json.loads(event.related_stocks) except: # 如果JSON解析失败,尝试ast.literal_eval解析 stock_data = ast.literal_eval(event.related_stocks) else: stock_data = event.related_stocks print(f"Parsed stock_data: {stock_data}") # 调试输出 if stock_data: daily_changes = [] week_changes = [] # **修正:处理正确的数据格式 [股票代码, 股票名称, 描述, 分数]** for stock_info in stock_data: if isinstance(stock_info, list) and len(stock_info) >= 2: stock_code = stock_info[0] # 第一个元素是股票代码 stock_name = stock_info[1] # 第二个元素是股票名称 description = stock_info[2] if len(stock_info) > 2 else '' score = stock_info[3] if len(stock_info) > 3 else 0 else: continue # 跳过格式不正确的数据 if stock_code: # 规范化股票代码,移除后缀 clean_code = stock_code.replace('.SZ', '').replace('.SH', '').replace('.BJ', '') print(f"Processing stock: {clean_code} - {stock_name}") # 调试输出 # 使用模糊匹配LIKE查询申万一级行业F004V sector_query = """ SELECT F004V as sw_primary_sector FROM ea_sector WHERE SECCODE LIKE :stock_code_pattern AND F002V = '申银万国行业分类' LIMIT 1 \ """ sector_result = db.session.execute(text(sector_query), {'stock_code_pattern': f'{clean_code}%'}) sector_row = sector_result.fetchone() # 根据申万一级行业(F004V)映射到主板块 sw_primary_sector = sector_row.sw_primary_sector if sector_row else None primary_sector = sector_map.get(sw_primary_sector, '其他') if sw_primary_sector else '其他' print( f"Stock: {clean_code}, SW Primary: {sw_primary_sector}, Primary Sector: {primary_sector}") # 通过SQL查询获取真实的日涨跌幅和周涨跌幅 trade_query = """ SELECT F007N as close_price, F010N as change_pct, TRADEDATE FROM ea_trade WHERE SECCODE LIKE :stock_code_pattern ORDER BY TRADEDATE DESC LIMIT 7 \ """ trade_result = db.session.execute(text(trade_query), {'stock_code_pattern': f'{clean_code}%'}) trade_data = trade_result.fetchall() daily_chg = 0 week_chg = 0 if trade_data: # 日涨跌幅(当日) daily_chg = float(trade_data[0].change_pct or 0) # 周涨跌幅(5个交易日) if len(trade_data) >= 5: current_price = float(trade_data[0].close_price or 0) week_ago_price = float(trade_data[4].close_price or 0) if week_ago_price > 0: week_chg = ((current_price - week_ago_price) / week_ago_price) * 100 print( f"Trade data found: {len(trade_data) if trade_data else 0} records, daily_chg: {daily_chg}") # 统计各分类数量 sector_stats['全部股票'] += 1 sector_stats[primary_sector] += 1 # 收集涨跌幅数据 daily_changes.append(daily_chg) week_changes.append(week_chg) related_stocks_list.append({ 'code': stock_code, # 原始股票代码 'name': stock_name, # 股票名称 'description': description, # 关联描述 'score': score, # 关联分数 'sw_primary_sector': sw_primary_sector, # 申万一级行业(F004V) 'primary_sector': primary_sector, # 主板块分类 'daily_change': daily_chg, # 真实的日涨跌幅 'week_change': week_chg # 真实的周涨跌幅 }) # 计算平均收益率 if daily_changes: related_avg_chg = sum(daily_changes) / len(daily_changes) related_max_chg = max(daily_changes) if week_changes: related_week_chg = sum(week_changes) / len(week_changes) except Exception as e: print(f"Error processing related stocks: {e}") import traceback traceback.print_exc() # 构建返回数据 detail_data = { 'id': event.data_id, 'title': event.title, 'type': event.type, 'star': event.star, 'calendar_time': event.calendar_time.isoformat() if event.calendar_time else None, 'former': event.former, 'forecast': event.forecast, 'fact': event.fact, 'concepts': event.concepts, 'extracted_concepts': extracted_concepts, 'related_stocks': related_stocks_list, 'sector_stats': sector_stats, 'related_avg_chg': round(related_avg_chg, 2), 'related_max_chg': round(related_max_chg, 2), 'related_week_chg': round(related_week_chg, 2) } return jsonify({ 'code': 200, 'message': 'success', 'data': { 'type': 'future_event', 'detail': detail_data } }) except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 # 13-15. 筛选弹窗接口(已有,优化格式) @app.route('/api/filter/options', methods=['GET']) def api_filter_options(): """筛选选项接口""" try: # 获取排序选项 sort_options = [ {'key': 'new', 'name': '最新', 'desc': '按创建时间排序'}, {'key': 'hot', 'name': '热门', 'desc': '按热度分数排序'}, {'key': 'returns', 'name': '收益率', 'desc': '按收益率排序'}, {'key': 'importance', 'name': '重要性', 'desc': '按重要性等级排序'}, {'key': 'view_count', 'name': '浏览量', 'desc': '按浏览次数排序'} ] # 获取行业筛选选项 industry_options = db.session.execute(text(""" SELECT DISTINCT f002v as classification_name, COUNT(*) as count FROM ea_sector WHERE f002v IS NOT NULL GROUP BY f002v ORDER BY f002v """)).fetchall() # 获取重要性选项 importance_options = [ {'key': 'S', 'name': 'S级', 'desc': '重大事件'}, {'key': 'A', 'name': 'A级', 'desc': '重要事件'}, {'key': 'B', 'name': 'B级', 'desc': '普通事件'}, {'key': 'C', 'name': 'C级', 'desc': '参考事件'} ] return jsonify({ 'code': 200, 'message': 'success', 'data': { 'sort_options': sort_options, 'industry_options': [{ 'name': row.classification_name, 'count': row.count } for row in industry_options], 'importance_options': importance_options } }) except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 # 16-17. 会员权益接口 @app.route('/api/membership/status', methods=['GET']) @token_required def api_membership_status(): """会员状态接口""" try: user = request.user # TODO: 根据实际业务逻辑判断会员状态 # 这里假设用户表中有会员相关字段 is_member = getattr(user, 'is_member', False) member_expire_date = getattr(user, 'member_expire_date', None) return jsonify({ 'code': 200, 'message': 'success', 'data': { 'user_id': user.id, 'is_member': is_member, 'member_expire_date': member_expire_date.isoformat() if member_expire_date else None, 'user_level': user.user_level, 'benefits': { 'unlimited_access': is_member, 'priority_support': is_member, 'advanced_analytics': is_member, 'custom_alerts': is_member } } }) except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 # 18-19. 个人中心接口 @app.route('/api/user/profile', methods=['GET']) @token_required def api_user_profile(): """个人资料接口""" try: user = request.user likes_count = PostLike.query.filter_by(user_id=user.id).count() follows_count = EventFollow.query.filter_by(user_id=user.id).count() comments_made = Comment.query.filter_by(user_id=user.id).count() comments_received = db.session.query(Comment) \ .join(Post, Comment.post_id == Post.id) \ .filter(Post.user_id == user.id).count() replies_received = Comment.query.filter( Comment.parent_id.in_( db.session.query(Comment.id).filter_by(user_id=user.id) ) ).count() # 总评论数(发出的评论 + 收到的评论和回复) total_comments = comments_made + comments_received + replies_received profile_data = { 'basic_info': { 'user_id': user.id, 'username': user.username, 'email': user.email, 'phone': user.phone, 'nickname': user.nickname, 'avatar_url': get_full_avatar_url(user.avatar_url), # 修改这里 'bio': user.bio, 'gender': user.gender, 'birth_date': user.birth_date.isoformat() if user.birth_date else None, 'location': user.location }, 'account_status': { 'email_confirmed': user.email_confirmed, 'phone_confirmed': user.phone_confirmed, 'is_verified': user.is_verified, 'verify_time': user.verify_time.isoformat() if user.verify_time else None, 'created_at': user.created_at.isoformat() if user.created_at else None, 'last_seen': user.last_seen.isoformat() if user.last_seen else None }, 'statistics': { 'likes_count': likes_count, # 点赞数 'follows_count': follows_count, # 关注数 'total_comments': total_comments, # 总评论数 'comments_detail': { 'comments_made': comments_made, # 发出的评论 'comments_received': comments_received, # 收到的评论 'replies_received': replies_received # 收到的回复 } }, 'investment_preferences': { 'trading_experience': user.trading_experience, 'investment_style': user.investment_style, 'risk_preference': user.risk_preference, 'investment_amount': user.investment_amount, 'preferred_markets': user.preferred_markets }, 'community_stats': { 'user_level': user.user_level, 'reputation_score': user.reputation_score, 'contribution_point': user.contribution_point, 'post_count': user.post_count, 'comment_count': user.comment_count, 'follower_count': user.follower_count, 'following_count': user.following_count }, 'settings': { 'email_notifications': user.email_notifications, 'sms_notifications': user.sms_notifications, 'privacy_level': user.privacy_level, 'theme_preference': user.theme_preference } } return jsonify({ 'code': 200, 'message': 'success', 'data': profile_data }) except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 # 在文件开头添加缓存变量 _agreements_cache = {} _cache_loaded = False def load_agreements_from_docx(): """从docx文件中加载协议内容,只读取一次""" global _agreements_cache, _cache_loaded if _cache_loaded: return _agreements_cache try: # 定义文件路径和对应的协议类型 docx_files = { 'about_us': 'about_us.docx', # 关于我们 'service_terms': 'service_terms.docx', # 服务条款 'privacy_policy': 'privacy_policy.docx' # 隐私政策 } # 定义协议标题 titles = { 'about_us': '关于我们', 'service_terms': '服务条款', 'privacy_policy': '隐私政策' } for agreement_type, filename in docx_files.items(): file_path = os.path.join(os.path.dirname(__file__), filename) if os.path.exists(file_path): try: # 读取docx文件 doc = Document(file_path) # 提取文本内容 content_paragraphs = [] for paragraph in doc.paragraphs: if paragraph.text.strip(): # 跳过空段落 content_paragraphs.append(paragraph.text.strip()) # 合并所有段落 content = '\n\n'.join(content_paragraphs) # 获取文件修改时间作为版本标识 file_stat = os.stat(file_path) last_modified = file_stat.st_mtime # 缓存内容 _agreements_cache[agreement_type] = { 'title': titles.get(agreement_type, agreement_type), 'content': content, 'last_updated': last_modified, 'version': '1.0', 'file_path': filename } print(f"Successfully loaded {agreement_type} from {filename}") except Exception as e: print(f"Error reading {filename}: {str(e)}") # 如果读取失败,使用默认内容 _agreements_cache[agreement_type] = { 'title': titles.get(agreement_type, agreement_type), 'content': f"协议内容正在加载中,请稍后再试。(文件:{filename})", 'last_updated': None, 'version': '1.0', 'file_path': filename, 'error': str(e) } else: print(f"File not found: {filename}") # 如果文件不存在,使用默认内容 _agreements_cache[agreement_type] = { 'title': titles.get(agreement_type, agreement_type), 'content': f"协议文件未找到,请联系管理员。(文件:{filename})", 'last_updated': None, 'version': '1.0', 'file_path': filename, 'error': 'File not found' } _cache_loaded = True print(f"Agreements cache loaded successfully. Total: {len(_agreements_cache)} agreements") except Exception as e: print(f"Error loading agreements: {str(e)}") _cache_loaded = False return _agreements_cache @app.route('/api/agreements', methods=['GET']) def api_agreements(): """平台协议接口 - 从docx文件读取""" try: # 获取查询参数 agreement_type = request.args.get('type') # about_us, service_terms, privacy_policy force_reload = request.args.get('reload', 'false').lower() == 'true' # 强制重新加载 # 如果需要强制重新加载,清除缓存 if force_reload: global _cache_loaded _cache_loaded = False _agreements_cache.clear() # 加载协议内容 agreements_data = load_agreements_from_docx() if not agreements_data: return jsonify({ 'code': 500, 'message': 'Failed to load agreements', 'data': None }), 500 # 如果指定了特定协议类型,只返回该协议 if agreement_type and agreement_type in agreements_data: return jsonify({ 'code': 200, 'message': 'success', 'data': { 'agreement_type': agreement_type, **agreements_data[agreement_type] } }) # 返回所有协议 return jsonify({ 'code': 200, 'message': 'success', 'data': { 'agreements': agreements_data, 'available_types': list(agreements_data.keys()), 'cache_loaded': _cache_loaded, 'total_agreements': len(agreements_data) } }) except Exception as e: return jsonify({ 'code': 500, 'message': str(e), 'data': None }), 500 # 20. 个人中心-我的关注接口 @app.route('/api/user/activities', methods=['GET']) @token_required def api_user_activities(): """用户活动接口(我的关注、评论、点赞)""" try: user = request.user activity_type = request.args.get('type', 'follows') # follows, comments, likes, commented page = request.args.get('page', 1, type=int) per_page = min(50, request.args.get('per_page', 20, type=int)) if activity_type == 'follows': # 我的关注列表 follows = EventFollow.query.filter_by(user_id=user.id) \ .order_by(EventFollow.created_at.desc()) \ .paginate(page=page, per_page=per_page, error_out=False) activities = [] for follow in follows.items: # 获取相关股票并添加单日涨幅 related_stocks_data = [] for stock in follow.event.related_stocks.limit(5): # 处理股票代码,移除可能的后缀(如 .SZ 或 .SH) base_stock_code = stock.stock_code.split('.')[0] # 获取股票最新交易数据 trade_data = TradeData.query.filter( TradeData.SECCODE.startswith(base_stock_code) ).order_by(TradeData.TRADEDATE.desc()).first() # 计算单日涨幅 daily_change = None if trade_data and trade_data.F010N is not None: daily_change = float(trade_data.F010N) related_stocks_data.append({ 'stock_code': stock.stock_code, 'stock_name': stock.stock_name, 'correlation': stock.correlation, 'daily_change': daily_change, # 新增:单日涨幅 'daily_change_formatted': f"{daily_change:.2f}%" if daily_change is not None else "暂无数据" # 格式化显示 }) activities.append({ 'event_id': follow.event_id, 'event_title': follow.event.title, 'event_description': follow.event.description, 'follow_time': follow.created_at.isoformat() if follow.created_at else None, 'event_hot_score': follow.event.hot_score, # 新增字段 'importance': follow.event.importance, # 重要性 'related_avg_chg': follow.event.related_avg_chg, # 平均涨幅 'related_max_chg': follow.event.related_max_chg, # 最大涨幅 'related_week_chg': follow.event.related_week_chg, # 周涨幅 'related_stocks': related_stocks_data, # 修改:包含单日涨幅的相关股票 'created_at': follow.event.created_at.isoformat() if follow.event.created_at else None, # 发布时间 'preview': follow.event.description[:200] if follow.event.description else None, # 预览(限制200字) 'comment_count': follow.event.post_count, # 评论数 'view_count': follow.event.view_count, # 评论数 'follower_count': follow.event.follower_count # 关注数 }) total = follows.total pages = follows.pages elif activity_type == 'likes': # 我的点赞列表 likes = PostLike.query.filter_by(user_id=user.id) \ .order_by(PostLike.created_at.desc()) \ .paginate(page=page, per_page=per_page, error_out=False) activities = [{ 'like_id': like.id, 'post_id': like.post_id, 'post_content': like.post.content, 'like_time': like.created_at.isoformat() if like.created_at else None, # 新增发布人信息 'author': { 'nickname': like.post.user.nickname or like.post.user.username, 'avatar_url': get_full_avatar_url(like.post.user.avatar_url), # 修改这里 } } for like in likes.items] total = likes.total pages = likes.pages elif activity_type == 'comments': # 我的评论列表(增强版 - 添加重要性和事件内容) comments = Comment.query.filter_by(user_id=user.id) \ .join(Post, Comment.post_id == Post.id) \ .join(Event, Post.event_id == Event.id) \ .order_by(Comment.created_at.desc()) \ .paginate(page=page, per_page=per_page, error_out=False) activities = [] for comment in comments.items: # 通过关联路径获取事件信息:comment.post_id -> post.id -> post.event_id -> event.id post = comment.post event = post.event if post else None activity_data = { 'comment_id': comment.id, 'post_id': comment.post_id, 'content': comment.content, # 评论内容 'created_at': comment.created_at.isoformat() if comment.created_at else None, 'post_title': post.title if post and post.title else None, 'post_content': post.content if post else None, # 新增:评论者信息(当前用户) 'commenter': { 'id': comment.user.id, 'username': comment.user.username, 'nickname': comment.user.nickname or comment.user.username, 'avatar_url': get_full_avatar_url(comment.user.avatar_url), 'user_level': comment.user.user_level, 'is_verified': comment.user.is_verified }, # 新增字段:事件信息 'event': { 'id': event.id if event else None, 'title': event.title if event else None, 'description': event.description if event else None, # 事件内容 'importance': event.importance if event else None, # 重要性 'event_type': event.event_type if event else None, 'hot_score': event.hot_score if event else None, 'view_count': event.view_count if event else None, 'related_avg_chg': event.related_avg_chg if event else None, 'created_at': event.created_at.isoformat() if event and event.created_at else None }, # 新增:帖子作者信息 'post_author': { 'id': post.user.id if post else None, 'username': post.user.username if post else None, 'nickname': post.user.nickname or post.user.username if post else None, 'avatar_url': get_full_avatar_url(post.user.avatar_url) if post else None, } } activities.append(activity_data) total = comments.total pages = comments.pages elif activity_type == 'commented': # 评论了我的帖子 my_posts = Post.query.filter_by(user_id=user.id).subquery() comments = Comment.query.join(my_posts, Comment.post_id == my_posts.c.id) \ .filter(Comment.user_id != user.id) \ .order_by(Comment.created_at.desc()) \ .paginate(page=page, per_page=per_page, error_out=False) activities = [{ 'comment_id': comment.id, 'comment_content': comment.content, 'comment_time': comment.created_at.isoformat() if comment.created_at else None, 'commenter_nickname': comment.user.nickname or comment.user.username, 'commenter_avatar': get_full_avatar_url(comment.user.avatar_url), # 修改这里 'post_content': comment.post.content, 'event_title': comment.post.event.title, 'event_id': comment.post.event_id } for comment in comments.items] total = comments.total pages = comments.pages return jsonify({ 'code': 200, 'message': 'success', 'data': { 'activities': activities, 'total': total, 'pages': pages, 'current_page': page } }) except Exception as e: print(f"Error in api_user_activities: {str(e)}") return jsonify({ 'code': 500, 'message': '服务器内部错误', 'data': None }), 500 class UserFeedback(db.Model): """用户反馈模型""" id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) type = db.Column(db.String(50), nullable=False) # 反馈类型 content = db.Column(db.Text, nullable=False) # 反馈内容 contact_info = db.Column(db.String(100)) # 联系方式 status = db.Column(db.String(20), default='pending') # 状态:pending/processing/resolved/closed admin_reply = db.Column(db.Text) # 管理员回复 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 关联关系 user = db.relationship('User', backref='feedbacks') def __init__(self, user_id, type, content, contact_info=None): self.user_id = user_id self.type = type self.content = content self.contact_info = contact_info def to_dict(self): return { 'id': self.id, 'type': self.type, 'content': self.content, 'contact_info': self.contact_info, 'status': self.status, 'admin_reply': self.admin_reply, 'created_at': self.created_at.strftime('%Y-%m-%d %H:%M:%S'), 'updated_at': self.updated_at.strftime('%Y-%m-%d %H:%M:%S') } # 通用错误处理 @app.errorhandler(404) def api_not_found(error): if request.path.startswith('/api/'): return jsonify({ 'code': 404, 'message': '接口不存在', 'data': None }), 404 return error @app.errorhandler(405) def api_method_not_allowed(error): if request.path.startswith('/api/'): return jsonify({ 'code': 405, 'message': '请求方法不允许', 'data': None }), 405 return error if __name__ == '__main__': # 初始化申银万国行业分类缓存 with app.app_context(): init_sywg_industry_cache() app.run( host='0.0.0.0', port=5002, debug=True, ssl_context=( '/etc/letsencrypt/live/api.valuefrontier.cn/fullchain.pem', '/etc/letsencrypt/live/api.valuefrontier.cn/privkey.pem' ) )