diff --git a/BYTEDESK_INTEGRATION_FILES.txt b/BYTEDESK_INTEGRATION_FILES.txt deleted file mode 100644 index 9088fcc9..00000000 --- a/BYTEDESK_INTEGRATION_FILES.txt +++ /dev/null @@ -1,49 +0,0 @@ -# Bytedesk 客服系统集成文件 - -以下文件和目录属于客服系统集成功能,未提交到当前分支: - -## 1. Dify 机器人控制逻辑 -**位置**: public/index.html -**状态**: 已存入 stash -**Stash ID**: stash@{0} -**说明**: 根据路径控制 Dify 机器人显示(已设置为完全不显示,只使用 Bytedesk 客服) - -## 2. Bytedesk 集成代码 -**位置**: src/bytedesk-integration/ -**状态**: 未跟踪文件(需要手动管理) -**内容**: - - .env.bytedesk.example - Bytedesk 环境变量配置示例 - - App.jsx.example - 集成 Bytedesk 的示例代码 - - components/ - Bytedesk 相关组件 - - config/ - Bytedesk 配置文件 - - 前端工程师集成手册.md - 详细集成文档 - -## 恢复方法 - -### 恢复 public/index.html 的改动: -```bash -git stash apply stash@{0} -``` - -### 使用 Bytedesk 集成代码: -```bash -# 查看集成手册 -cat src/bytedesk-integration/前端工程师集成手册.md - -# 复制示例配置 -cp src/bytedesk-integration/.env.bytedesk.example .env.bytedesk -cp src/bytedesk-integration/App.jsx.example src/App.jsx -``` - -## 注意事项 - -⚠️ **重要提示:** -- `src/bytedesk-integration/` 目录中的文件是未跟踪的(untracked) -- 如果需要提交客服功能,需要先添加到 git: - ```bash - git add src/bytedesk-integration/ - git commit -m "feat: 集成 Bytedesk 客服系统" - ``` - -- 当前分支(feature_bugfix/251110_event)专注于非客服功能 -- 建议在单独的分支中开发客服功能 diff --git a/app_vx.py b/app_vx.py index e27d7b1a..f03f89df 100644 --- a/app_vx.py +++ b/app_vx.py @@ -16,6 +16,28 @@ import time from sqlalchemy import create_engine, text, func, or_, case, event, desc, asc from flask import Flask, has_request_context, render_template, request, jsonify, redirect, url_for, flash, session, render_template_string, current_app, send_from_directory +# Flask 3.x 兼容性补丁:flask-sqlalchemy 旧版本需要 _app_ctx_stack +try: + from flask import _app_ctx_stack +except ImportError: + import flask + from werkzeug.local import LocalStack + import threading + + # 创建一个兼容的 LocalStack 子类 + class CompatLocalStack(LocalStack): + @property + def __ident_func__(self): + # 返回当前线程的标识函数 + # 优先使用 greenlet(协程),否则使用 threading + try: + from greenlet import getcurrent + return getcurrent + except ImportError: + return threading.get_ident + + flask._app_ctx_stack = CompatLocalStack() + from flask_sqlalchemy import SQLAlchemy from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user from flask_mail import Mail, Message @@ -1518,9 +1540,6 @@ def like_post(post_id): post.likes_count += 1 message = '已点赞' - # 可以在这里添加点赞通知 - if post.user_id != request.user.id: - notify_user_post_liked(post) db.session.commit() return jsonify({ @@ -1597,15 +1616,6 @@ def add_comment(post_id): db.session.add(comment) post.comments_count += 1 - # 如果是回复评论,可以添加通知 - if parent_id: - parent_comment = Comment.query.get(parent_id) - if parent_comment and parent_comment.user_id != request.user.id: - notify_user_comment_replied(parent_comment) - - # 如果是评论帖子,通知帖子作者 - elif post.user_id != request.user.id: - notify_user_post_commented(post) db.session.commit() @@ -3853,17 +3863,171 @@ def api_event_related_stocks(event_id): print(f"Error fetching minute data for {stock_code}: {e}") return [] + # ==================== 性能优化:批量查询所有股票数据 ==================== + # 1. 收集所有股票代码 + stock_codes = [stock.stock_code for stock in related_stocks] + + # 2. 批量查询股票基本信息 + stock_info_map = {} + if stock_codes: + stock_infos = StockBasicInfo.query.filter(StockBasicInfo.SECCODE.in_(stock_codes)).all() + for info in stock_infos: + stock_info_map[info.SECCODE] = info + + # 处理不带后缀的股票代码 + base_codes = [code.split('.')[0] for code in stock_codes if '.' in code and code not in stock_info_map] + if base_codes: + base_infos = StockBasicInfo.query.filter(StockBasicInfo.SECCODE.in_(base_codes)).all() + for info in base_infos: + # 将不带后缀的信息映射到带后缀的代码 + for code in stock_codes: + if code.split('.')[0] == info.SECCODE and code not in stock_info_map: + stock_info_map[code] = info + + # 3. 批量查询 ClickHouse 数据(价格、涨跌幅、分时图数据) + price_data_map = {} # 存储价格和涨跌幅数据 + minute_chart_map = {} # 存储分时图数据 + + try: + if stock_codes: + print(f"批量查询 {len(stock_codes)} 只股票的价格数据...") + + # 3.1 批量查询价格和涨跌幅数据(使用子查询方式,避免窗口函数与 GROUP BY 冲突) + batch_price_query = """ + WITH first_prices AS ( + SELECT + code, + close as first_price, + ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp ASC) as rn + FROM stock_minute + WHERE code IN %(codes)s + AND timestamp >= %(start)s + AND timestamp <= %(end)s + ), + last_prices AS ( + SELECT + code, + close as last_price, + open as open_price, + high as high_price, + low as low_price, + volume, + amt as amount, + ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn + FROM stock_minute + WHERE code IN %(codes)s + AND timestamp >= %(start)s + AND timestamp <= %(end)s + ) + SELECT + fp.code, + fp.first_price, + lp.last_price, + (lp.last_price - fp.first_price) / fp.first_price * 100 as change_pct, + lp.open_price, + lp.high_price, + lp.low_price, + lp.volume, + lp.amount + FROM first_prices fp + INNER JOIN last_prices lp ON fp.code = lp.code + WHERE fp.rn = 1 AND lp.rn = 1 + """ + + price_data = client.execute(batch_price_query, { + 'codes': stock_codes, + 'start': start_datetime, + 'end': end_datetime + }) + + print(f"批量查询返回 {len(price_data)} 条价格数据") + + # 解析批量查询结果 + for row in price_data: + code = row[0] + first_price = float(row[1]) if row[1] is not None else None + last_price = float(row[2]) if row[2] is not None else None + change_pct = float(row[3]) if row[3] is not None else None + open_price = float(row[4]) if row[4] is not None else None + high_price = float(row[5]) if row[5] is not None else None + low_price = float(row[6]) if row[6] is not None else None + volume = int(row[7]) if row[7] is not None else None + amount = float(row[8]) if row[8] is not None else None + + change_amount = None + if last_price is not None and first_price is not None: + change_amount = last_price - first_price + + price_data_map[code] = { + 'latest_price': last_price, + 'first_price': first_price, + 'change_pct': change_pct, + 'change_amount': change_amount, + 'open_price': open_price, + 'high_price': high_price, + 'low_price': low_price, + 'volume': volume, + 'amount': amount, + } + + # 3.2 批量查询分时图数据 + print(f"批量查询分时图数据...") + minute_chart_query = """ + SELECT + code, + timestamp, + open, + high, + low, + close, + volume, + amt + FROM stock_minute + WHERE code IN %(codes)s + AND timestamp >= %(start)s + AND timestamp <= %(end)s + ORDER BY code, timestamp + """ + + minute_data = client.execute(minute_chart_query, { + 'codes': stock_codes, + 'start': start_datetime, + 'end': end_datetime + }) + + print(f"批量查询返回 {len(minute_data)} 条分时数据") + + # 按股票代码分组分时数据 + for row in minute_data: + code = row[0] + if code not in minute_chart_map: + minute_chart_map[code] = [] + + minute_chart_map[code].append({ + 'time': row[1].strftime('%H:%M'), + 'open': float(row[2]) if row[2] else None, + 'high': float(row[3]) if row[3] else None, + 'low': float(row[4]) if row[4] else None, + 'close': float(row[5]) if row[5] else None, + 'volume': float(row[6]) if row[6] else None, + 'amount': float(row[7]) if row[7] else None + }) + + except Exception as e: + print(f"批量查询 ClickHouse 失败: {e}") + # 如果批量查询失败,price_data_map 和 minute_chart_map 为空,后续会使用降级方案 + + # 4. 组装每个股票的数据(从批量查询结果中获取) stocks_data = [] for stock in related_stocks: - print(f"正在处理股票 {stock.stock_code} 的价格数据...") + print(f"正在组装股票 {stock.stock_code} 的数据...") - # 获取股票基本信息 - stock_info = StockBasicInfo.query.filter_by(SECCODE=stock.stock_code).first() - if not stock_info: - base_code = stock.stock_code.split('.')[0] - stock_info = StockBasicInfo.query.filter_by(SECCODE=base_code).first() + # 从批量查询结果中获取股票基本信息 + stock_info = stock_info_map.get(stock.stock_code) + + # 从批量查询结果中获取价格数据 + price_info = price_data_map.get(stock.stock_code) - # 使用与 get_stock_quotes 完全相同的逻辑计算涨跌幅 latest_price = None first_price = None change_pct = None @@ -3875,79 +4039,20 @@ def api_event_related_stocks(event_id): amount = None trade_date = trading_day - try: - # 使用与 get_stock_quotes 完全相同的 SQL 查询 - # 获取事件时间点的第一个价格 (first_price) 和当前时间的最后一个价格 (last_price) - data = client.execute(""" - WITH first_price AS ( - SELECT close - FROM stock_minute - WHERE code = %(code)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ORDER BY timestamp - LIMIT 1 - ), - last_price AS ( - SELECT close - FROM stock_minute - WHERE code = %(code)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ORDER BY timestamp DESC - LIMIT 1 - ) - SELECT - last_price.close as last_price, - (last_price.close - first_price.close) / first_price.close * 100 as change, - first_price.close as first_price - FROM last_price - CROSS JOIN first_price - WHERE EXISTS (SELECT 1 FROM first_price) - AND EXISTS (SELECT 1 FROM last_price) - """, { - 'code': stock.stock_code, - 'start': start_datetime, - 'end': end_datetime - }) - - print(f"股票 {stock.stock_code} 查询结果: {data}") - - if data and data[0] and data[0][0] is not None: - latest_price = float(data[0][0]) - change_pct = float(data[0][1]) if data[0][1] is not None else None - first_price = float(data[0][2]) if len(data[0]) > 2 and data[0][2] is not None else None - - # 计算涨跌额 - if latest_price is not None and first_price is not None: - change_amount = latest_price - first_price - - # 获取额外的价格信息(开盘价、最高价、最低价等) - extra_data = client.execute(""" - SELECT - open, high, low, volume, amt - FROM stock_minute - WHERE code = %(code)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ORDER BY timestamp DESC - LIMIT 1 - """, { - 'code': stock.stock_code, - 'start': start_datetime, - 'end': end_datetime - }) - - if extra_data and extra_data[0]: - open_price = float(extra_data[0][0]) if extra_data[0][0] else None - high_price = float(extra_data[0][1]) if extra_data[0][1] else None - low_price = float(extra_data[0][2]) if extra_data[0][2] else None - volume = int(extra_data[0][3]) if extra_data[0][3] else None - amount = float(extra_data[0][4]) if extra_data[0][4] else None - - except Exception as e: - print(f"Error fetching price data for {stock.stock_code}: {e}") - # 如果 ClickHouse 查询失败,尝试使用 TradeData 作为降级方案 + if price_info: + # 使用批量查询的结果 + latest_price = price_info['latest_price'] + first_price = price_info['first_price'] + change_pct = price_info['change_pct'] + change_amount = price_info['change_amount'] + open_price = price_info['open_price'] + high_price = price_info['high_price'] + low_price = price_info['low_price'] + volume = price_info['volume'] + amount = price_info['amount'] + else: + # 如果批量查询没有返回数据,使用降级方案(TradeData) + print(f"股票 {stock.stock_code} 批量查询无数据,使用降级方案...") try: latest_trade = None search_codes = [stock.stock_code, stock.stock_code.split('.')[0]] @@ -3974,10 +4079,10 @@ def api_event_related_stocks(event_id): if latest_trade.F009N: change_amount = float(latest_trade.F009N) except Exception as fallback_error: - print(f"Fallback query also failed for {stock.stock_code}: {fallback_error}") + print(f"降级查询也失败 {stock.stock_code}: {fallback_error}") - # 获取分时图数据 - minute_chart_data = get_minute_chart_data(stock.stock_code) + # 从批量查询结果中获取分时图数据 + minute_chart_data = minute_chart_map.get(stock.stock_code, []) stock_data = { 'id': stock.id, diff --git a/app_vx.py.optimized_backup b/app_vx.py.optimized_backup deleted file mode 100644 index a30b4af7..00000000 --- a/app_vx.py.optimized_backup +++ /dev/null @@ -1,6318 +0,0 @@ -import csv -import logging -import random -import re -import math -import os -import secrets -import string - -import pytz -import requests -from flask_compress import Compress -from functools import wraps -from pathlib import Path -import time -from sqlalchemy import create_engine, text, func, or_, case, event, desc, asc -from flask import Flask, has_request_context, render_template, request, jsonify, redirect, url_for, flash, session, render_template_string, current_app, send_from_directory -from flask_sqlalchemy import SQLAlchemy -from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user -from flask_mail import Mail, Message -from itsdangerous import URLSafeTimedSerializer -from flask_migrate import Migrate -from flask_session import Session # type: ignore -from sqlalchemy.dialects.mysql.base import MySQLDialect -from sqlalchemy.dialects.postgresql import JSONB -from werkzeug.utils import secure_filename -from PIL import Image -from datetime import datetime, timedelta, time as dt_time -from werkzeug.security import generate_password_hash, check_password_hash -import json -from clickhouse_driver import Client as Cclient -import jwt -from docx import Document -from tencentcloud.common import credential -from tencentcloud.common.profile.client_profile import ClientProfile -from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.sms.v20210111 import sms_client, models -from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException - -engine = create_engine("mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock", echo=False, pool_size=20, - max_overflow=50) -engine_med = create_engine("mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/med", echo=False) -engine_2 = create_engine("mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/valuefrontier", echo=False) -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) -app = Flask(__name__) -Compress(app) -UPLOAD_FOLDER = 'static/uploads/avatars' -ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'} -MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB max file size -# Configure Flask-Compress -app.config['COMPRESS_ALGORITHM'] = ['gzip', 'br'] -app.config['COMPRESS_MIMETYPES'] = [ - 'text/html', - 'text/css', - 'text/xml', - 'application/json', - 'application/javascript', - 'application/x-javascript' -] -user_tokens = {} -app.config['SECRET_KEY'] = 'vf7891574233241' -app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock?charset=utf8mb4' -app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False - -app.config['JSON_AS_ASCII'] = False -app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True - -# 邮件配置 -app.config['MAIL_SERVER'] = 'smtp.exmail.qq.com' -app.config['MAIL_PORT'] = 465 -app.config['MAIL_USE_SSL'] = True -app.config['MAIL_USERNAME'] = 'admin@valuefrontier.cn' -app.config['MAIL_PASSWORD'] = 'QYncRu6WUdASvTg4' - -app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER -app.config['MAX_CONTENT_LENGTH'] = MAX_CONTENT_LENGTH - -# 腾讯云短信配置 -SMS_SECRET_ID = 'AKID2we9TacdTAhCjCSYTErHVimeJo9Yr00s' -SMS_SECRET_KEY = 'pMlBWijlkgT9fz5ziEXdWEnAPTJzRfkf' -SMS_SDK_APP_ID = "1400972398" -SMS_SIGN_NAME = "价值前沿科技" -SMS_TEMPLATE_REGISTER = "2386557" # 注册模板 -SMS_TEMPLATE_LOGIN = "2386540" # 登录模板 -verification_codes = {} - -#微信小程序 -app.config['WECHAT_APP_ID'] = 'wx0edeaab76d4fa414' -app.config['WECHAT_APP_SECRET'] = '0d0c70084f05a8c1411f6b89da7e815d' -app.config['BASE_URL'] = 'http://43.143.189.195:5002' -app.config['WECHAT_REDIRECT_URI'] = f"{app.config['BASE_URL']}/api/wechat/callback" -WECHAT_APP_ID = 'wx0edeaab76d4fa414' -WECHAT_APP_SECRET = '0d0c70084f05a8c1411f6b89da7e815d' -JWT_SECRET_KEY = 'vfllmgreat33818!' # 请修改为安全的密钥 -JWT_ALGORITHM = 'HS256' -JWT_EXPIRATION_HOURS = 24 * 7 # Token有效期7天 - -# Session 配置 - 使用文件系统存储(替代 Redis) -app.config['SESSION_TYPE'] = 'filesystem' -app.config['SESSION_FILE_DIR'] = os.path.join(os.path.dirname(__file__), 'flask_session') -app.config['SESSION_PERMANENT'] = True -app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(days=7) # Session 有效期 7 天 - -# 确保 session 目录存在 -os.makedirs(app.config['SESSION_FILE_DIR'], exist_ok=True) - -# Cache directory setup -CACHE_DIR = Path('cache') -CACHE_DIR.mkdir(exist_ok=True) - -# Memory management constants -MAX_MEMORY_PERCENT = 75 -MEMORY_CHECK_INTERVAL = 300 -MAX_CACHE_ITEMS = 50 - -# 初始化扩展 -db = SQLAlchemy(app) -mail = Mail(app) -login_manager = LoginManager(app) -login_manager.login_view = 'login' -serializer = URLSafeTimedSerializer(app.config['SECRET_KEY']) - -migrate = Migrate(app, db) - -DOMAIN = 'http://43.143.189.195:5002' - -JWT_SECRET = 'Llmgreat123' -JWT_EXPIRES_SECONDS = 3600 # 1小时有效期 - -Session(app) - - -def token_required(f): - """装饰器:需要token认证的接口""" - from functools import wraps - - @wraps(f) - def decorated_function(*args, **kwargs): - token = None - - # 从请求头获取token - auth_header = request.headers.get('Authorization') - if auth_header and auth_header.startswith('Bearer '): - token = auth_header[7:] - - if not token: - return jsonify({'message': '缺少认证token'}), 401 - - token_data = user_tokens.get(token) - if not token_data: - return jsonify({'message': 'Token无效','code':401}), 401 - - # 检查是否过期 - if token_data['expires'] < datetime.now(): - del user_tokens[token] - return jsonify({'message': 'Token已过期'}), 401 - - # 获取用户对象并添加到请求上下文 - user = User.query.get(token_data['user_id']) - if not user: - return jsonify({'message': '用户不存在'}), 404 - - # 将用户对象添加到request - request.user = user - request.current_user_id = token_data['user_id'] - - return f(*args, **kwargs) - - return decorated_function - - - - -def beijing_now(): - # 使用 pytz 处理时区 - beijing_tz = pytz.timezone('Asia/Shanghai') - return datetime.now(beijing_tz) - - -# ============================================ -# 订阅功能模块(与 app.py 保持一致) -# ============================================ -class UserSubscription(db.Model): - """用户订阅表 - 独立于现有User表""" - __tablename__ = 'user_subscriptions' - - id = db.Column(db.Integer, primary_key=True, autoincrement=True) - user_id = db.Column(db.Integer, nullable=False, unique=True, index=True) - subscription_type = db.Column(db.String(10), nullable=False, default='free') - subscription_status = db.Column(db.String(20), nullable=False, default='active') - start_date = db.Column(db.DateTime, nullable=True) - end_date = db.Column(db.DateTime, nullable=True) - billing_cycle = db.Column(db.String(10), nullable=True) - auto_renewal = db.Column(db.Boolean, nullable=False, default=False) - created_at = db.Column(db.DateTime, default=beijing_now) - updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) - - def is_active(self): - if self.subscription_status != 'active': - return False - if self.subscription_type == 'free': - return True - if self.end_date: - # 将数据库的 naive datetime 转换为带时区的 aware datetime - beijing_tz = pytz.timezone('Asia/Shanghai') - end_date_aware = self.end_date if self.end_date.tzinfo else beijing_tz.localize(self.end_date) - return beijing_now() <= end_date_aware - return True - - def days_left(self): - if not self.is_active(): - return 0 - if self.subscription_type == 'free': - return 999 - if not self.end_date: - return 999 - try: - now = beijing_now() - # 将数据库的 naive datetime 转换为带时区的 aware datetime - beijing_tz = pytz.timezone('Asia/Shanghai') - end_date_aware = self.end_date if self.end_date.tzinfo else beijing_tz.localize(self.end_date) - delta = end_date_aware - now - return max(0, delta.days) - except Exception as e: - return 0 - - def to_dict(self): - return { - 'type': self.subscription_type, - 'status': self.subscription_status, - 'is_active': self.is_active(), - 'start_date': self.start_date.isoformat() if self.start_date else None, - 'end_date': self.end_date.isoformat() if self.end_date else None, - 'days_left': self.days_left(), - 'billing_cycle': self.billing_cycle, - 'auto_renewal': self.auto_renewal - } - - -# ============================================ -# 订阅等级工具函数 -# ============================================ -def get_user_subscription_safe(user_id): - """ - 安全地获取用户订阅信息 - :param user_id: 用户ID - :return: UserSubscription 对象或默认免费订阅 - """ - try: - subscription = UserSubscription.query.filter_by(user_id=user_id).first() - if not subscription: - # 如果用户没有订阅记录,创建默认免费订阅 - subscription = UserSubscription( - user_id=user_id, - subscription_type='free', - subscription_status='active' - ) - db.session.add(subscription) - db.session.commit() - return subscription - except Exception as e: - print(f"获取用户订阅信息失败: {e}") - # 返回一个临时的免费订阅对象(不保存到数据库) - temp_sub = UserSubscription( - user_id=user_id, - subscription_type='free', - subscription_status='active' - ) - return temp_sub - - -def _get_current_subscription_info(): - """ - 获取当前登录用户订阅信息的字典形式,未登录或异常时视为免费用户。 - 小程序场景下从 request.current_user_id 获取用户ID - """ - try: - user_id = getattr(request, 'current_user_id', None) - if not user_id: - return { - 'type': 'free', - 'status': 'active', - 'is_active': True - } - sub = get_user_subscription_safe(user_id) - return { - 'type': sub.subscription_type, - 'status': sub.subscription_status, - 'is_active': sub.is_active(), - 'start_date': sub.start_date.isoformat() if sub.start_date else None, - 'end_date': sub.end_date.isoformat() if sub.end_date else None, - 'days_left': sub.days_left() - } - except Exception as e: - print(f"获取订阅信息异常: {e}") - return { - 'type': 'free', - 'status': 'active', - 'is_active': True - } - - -def _subscription_level(sub_type): - """将订阅类型映射到等级数值,free=0, pro=1, max=2。""" - mapping = {'free': 0, 'pro': 1, 'max': 2} - return mapping.get((sub_type or 'free').lower(), 0) - - -def _has_required_level(required: str) -> bool: - """判断当前用户是否达到所需订阅级别。""" - info = _get_current_subscription_info() - if not info.get('is_active', True): - return False - return _subscription_level(info.get('type')) >= _subscription_level(required) - - -# ============================================ -# 权限装饰器 -# ============================================ -def subscription_required(level='pro'): - """ - 订阅等级装饰器 - 小程序专用 - 用法: - @subscription_required('pro') # 需要 Pro 或 Max 用户 - @subscription_required('max') # 仅限 Max 用户 - - 注意:此装饰器需要配合 使用 - """ - from functools import wraps - - def decorator(f): - @wraps(f) - def decorated_function(*args, **kwargs): - if not _has_required_level(level): - current_info = _get_current_subscription_info() - current_type = current_info.get('type', 'free') - is_active = current_info.get('is_active', False) - - if not is_active: - return jsonify({ - 'success': False, - 'error': '您的订阅已过期,请续费后继续使用', - 'error_code': 'SUBSCRIPTION_EXPIRED', - 'current_subscription': current_type, - 'required_subscription': level - }), 403 - - return jsonify({ - 'success': False, - 'error': f'此功能需要 {level.upper()} 或更高等级会员', - 'error_code': 'SUBSCRIPTION_REQUIRED', - 'current_subscription': current_type, - 'required_subscription': level - }), 403 - - return f(*args, **kwargs) - - return decorated_function - - return decorator - - -def pro_or_max_required(f): - """ - 快捷装饰器:要求 Pro 或 Max 用户(小程序专用场景) - 等同于 @subscription_required('pro') - """ - from functools import wraps - - @wraps(f) - def decorated_function(*args, **kwargs): - if not _has_required_level('pro'): - current_info = _get_current_subscription_info() - current_type = current_info.get('type', 'free') - - return jsonify({ - 'success': False, - 'error': '小程序功能仅对 Pro 和 Max 会员开放', - 'error_code': 'MINIPROGRAM_PRO_REQUIRED', - 'current_subscription': current_type, - 'required_subscription': 'pro', - 'message': '请升级到 Pro 或 Max 会员以使用小程序完整功能' - }), 403 - - return f(*args, **kwargs) - - return decorated_function - - -class User(UserMixin, db.Model): - """用户模型""" - id = db.Column(db.Integer, primary_key=True) - - # 基础账号信息(注册时必填) - username = db.Column(db.String(80), unique=True, nullable=False) # 用户名 - email = db.Column(db.String(120), unique=True, nullable=False) # 邮箱 - password_hash = db.Column(db.String(128), nullable=False) # 密码哈希 - email_confirmed = db.Column(db.Boolean, default=False) # 邮箱是否验证 - wechat_union_id = db.Column(db.String(100), unique=True) # 微信 UnionID - wechat_open_id = db.Column(db.String(100)) # 微信 OpenID - - # 账号状态 - created_at = db.Column(db.DateTime, default=beijing_now) # 注册时间 - last_seen = db.Column(db.DateTime, default=beijing_now) # 最后活跃时间 - status = db.Column(db.String(20), default='active') # 账号状态 active/banned/deleted - - # 个人资料(可选,后续在个人中心完善) - nickname = db.Column(db.String(30)) # 社区昵称 - avatar_url = db.Column(db.String(200)) # 头像URL - banner_url = db.Column(db.String(200)) # 个人主页背景图 - bio = db.Column(db.String(200)) # 个人简介 - gender = db.Column(db.String(10)) # 性别 - birth_date = db.Column(db.Date) # 生日 - location = db.Column(db.String(100)) # 所在地 - - # 联系方式(可选) - phone = db.Column(db.String(20)) # 手机号 - wechat_id = db.Column(db.String(80)) # 微信号 - - # 实名认证信息(可选) - real_name = db.Column(db.String(30)) # 真实姓名 - id_number = db.Column(db.String(18)) # 身份证号(加密存储) - is_verified = db.Column(db.Boolean, default=False) # 是否实名认证 - verify_time = db.Column(db.DateTime) # 实名认证时间 - - # 投资相关信息(可选) - trading_experience = db.Column(db.String(200)) # 炒股年限 - investment_style = db.Column(db.String(50)) # 投资风格 - risk_preference = db.Column(db.String(20)) # 风险偏好 - investment_amount = db.Column(db.String(20)) # 投资规模 - preferred_markets = db.Column(db.String(200), default='[]') # 偏好市场 JSON - - # 社区信息(系统自动更新) - user_level = db.Column(db.Integer, default=1) # 用户等级 - reputation_score = db.Column(db.Integer, default=0) # 信用积分 - contribution_point = db.Column(db.Integer, default=0) # 贡献点数 - post_count = db.Column(db.Integer, default=0) # 发帖数 - comment_count = db.Column(db.Integer, default=0) # 评论数 - follower_count = db.Column(db.Integer, default=0) # 粉丝数 - following_count = db.Column(db.Integer, default=0) # 关注数 - - # 创作者信息(可选) - is_creator = db.Column(db.Boolean, default=False) # 是否创作者 - creator_type = db.Column(db.String(20)) # 创作者类型 - creator_tags = db.Column(db.String(200), default='[]') # 创作者标签 JSON - - # 系统设置 - email_notifications = db.Column(db.Boolean, default=True) # 邮件通知 - sms_notifications = db.Column(db.Boolean, default=False) # 短信通知 - wechat_notifications = db.Column(db.Boolean, default=False) # 微信通知 - notification_preferences = db.Column(db.String(500), default='{}') # 通知偏好 JSON - privacy_level = db.Column(db.String(20), default='public') # 隐私级别 - theme_preference = db.Column(db.String(20), default='light') # 主题偏好 - blocked_keywords = db.Column(db.String(500), default='[]') # 屏蔽关键词 JSON - # 手机号验证 - phone_confirmed = db.Column(db.Boolean, default=False) # 手机是否验证 - phone_confirm_time = db.Column(db.DateTime) # 手机验证时间 - - def __init__(self, username, email=None, password=None, phone=None): - """初始化用户,只需要基本信息""" - self.username = username - if email: - self.email = email - if phone: - self.phone = phone - if password: - self.set_password(password) - self.created_at = beijing_now() - self.last_seen = beijing_now() - - def set_password(self, password): - """设置密码""" - self.password_hash = generate_password_hash(password) - - def check_password(self, password): - """验证密码""" - return check_password_hash(self.password_hash, password) - - def update_last_seen(self): - """更新最后活跃时间""" - self.last_seen = beijing_now() - - # JSON 字段的getter和setter - def get_preferred_markets(self): - """获取偏好市场列表""" - if self.preferred_markets: - try: - return json.loads(self.preferred_markets) - except: - return [] - return [] - - def get_blocked_keywords(self): - """获取屏蔽关键词列表""" - if self.blocked_keywords: - try: - return json.loads(self.blocked_keywords) - except: - return [] - return [] - - def get_notification_preferences(self): - """获取通知偏好设置""" - if self.notification_preferences: - try: - return json.loads(self.notification_preferences) - except: - return {} - return {} - - def get_creator_tags(self): - """获取创作者标签""" - if self.creator_tags: - try: - return json.loads(self.creator_tags) - except: - return [] - return [] - - def set_preferred_markets(self, markets): - """设置偏好市场""" - self.preferred_markets = json.dumps(markets) - - def set_blocked_keywords(self, keywords): - """设置屏蔽关键词""" - self.blocked_keywords = json.dumps(keywords) - - def set_notification_preferences(self, preferences): - """设置通知偏好""" - self.notification_preferences = json.dumps(preferences) - - def set_creator_tags(self, tags): - """设置创作者标签""" - self.creator_tags = json.dumps(tags) - - def to_dict(self): - """返回用户的字典表示""" - return { - 'id': self.id, - 'username': self.username, - 'email': self.email, - 'nickname': self.nickname, - 'avatar_url': get_full_avatar_url(self.avatar_url), # 修改这里 - 'bio': self.bio, - 'is_verified': self.is_verified, - 'user_level': self.user_level, - 'reputation_score': self.reputation_score, - 'is_creator': self.is_creator - } - - def __repr__(self): - return f'' - - -class Notification(db.Model): - """通知模型""" - id = db.Column(db.Integer, primary_key=True) - user_id = db.Column(db.Integer, db.ForeignKey('user.id')) - type = db.Column(db.String(50)) # 通知类型 - content = db.Column(db.Text) # 通知内容 - link = db.Column(db.String(200)) # 相关链接 - is_read = db.Column(db.Boolean, default=False) # 是否已读 - created_at = db.Column(db.DateTime, default=beijing_now) - - def __init__(self, user_id, type, content, link=None): - self.user_id = user_id - self.type = type - self.content = content - self.link = link - - -class Event(db.Model): - """事件模型""" - id = db.Column(db.Integer, primary_key=True) - title = db.Column(db.String(200), nullable=False) - description = db.Column(db.Text) - - # 事件类型与状态 - event_type = db.Column(db.String(50)) - status = db.Column(db.String(20), default='active') - - # 时间相关 - start_time = db.Column(db.DateTime, default=beijing_now) - end_time = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, default=beijing_now) - updated_at = db.Column(db.DateTime, default=beijing_now) - - # 热度与统计 - hot_score = db.Column(db.Float, default=0) - view_count = db.Column(db.Integer, default=0) - trending_score = db.Column(db.Float, default=0) - post_count = db.Column(db.Integer, default=0) - follower_count = db.Column(db.Integer, default=0) - - # 关联信息 - related_industries = db.Column(db.JSON) - keywords = db.Column(db.JSON) - files = db.Column(db.JSON) - importance = db.Column(db.String(20)) - related_avg_chg = db.Column(db.Float, default=0) - related_max_chg = db.Column(db.Float, default=0) - related_week_chg = db.Column(db.Float, default=0) - - # 新增字段 - invest_score = db.Column(db.Integer) # 超预期得分 - expectation_surprise_score = db.Column(db.Integer) - # 创建者信息 - creator_id = db.Column(db.Integer, db.ForeignKey('user.id')) - creator = db.relationship('User', backref='created_events') - - # 关系 - posts = db.relationship('Post', backref='event', lazy='dynamic') - followers = db.relationship('EventFollow', backref='event', lazy='dynamic') - related_stocks = db.relationship('RelatedStock', backref='event', lazy='dynamic') - historical_events = db.relationship('HistoricalEvent', backref='event', lazy='dynamic') - related_data = db.relationship('RelatedData', backref='event', lazy='dynamic') - related_concepts = db.relationship('RelatedConcepts', backref='event', lazy='dynamic') - ind_type = db.Column(db.String(255)) - @property - def keywords_list(self): - """返回解析后的关键词列表""" - if not self.keywords: - return [] - - if isinstance(self.keywords, list): - return self.keywords - - try: - # 如果是字符串,尝试解析JSON - if isinstance(self.keywords, str): - decoded = json.loads(self.keywords) - # 处理Unicode编码的情况 - if isinstance(decoded, list): - return [ - keyword.encode('utf-8').decode('unicode_escape') - if isinstance(keyword, str) and '\\u' in keyword - else keyword - for keyword in decoded - ] - return [] - - # 如果已经是字典或其他格式,尝试转换为列表 - return list(self.keywords) - except (json.JSONDecodeError, AttributeError, TypeError): - return [] - - def set_keywords(self, keywords): - """设置关键词列表""" - if isinstance(keywords, list): - self.keywords = json.dumps(keywords, ensure_ascii=False) - elif isinstance(keywords, str): - try: - # 尝试解析JSON字符串 - parsed = json.loads(keywords) - if isinstance(parsed, list): - self.keywords = json.dumps(parsed, ensure_ascii=False) - else: - self.keywords = json.dumps([keywords], ensure_ascii=False) - except json.JSONDecodeError: - # 如果不是有效的JSON,将其作为单个关键词 - self.keywords = json.dumps([keywords], ensure_ascii=False) - - -class RelatedStock(db.Model): - """相关标的模型""" - id = db.Column(db.Integer, primary_key=True) - event_id = db.Column(db.Integer, db.ForeignKey('event.id')) - stock_code = db.Column(db.String(20)) # 股票代码 - stock_name = db.Column(db.String(100)) # 股票名称 - sector = db.Column(db.String(100)) # 关联类型 - relation_desc = db.Column(db.String(1024)) # 关联原因描述 - created_at = db.Column(db.DateTime, default=beijing_now) - updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) - correlation = db.Column(db.Float()) - momentum = db.Column(db.String(1024)) # 动量 - # 新增字段 - retrieved_sources = db.Column(db.JSON) # 研报检索源数据 - retrieved_update_time = db.Column(db.DateTime) # 检索数据更新时间 - - -class RelatedData(db.Model): - """关联数据模型""" - id = db.Column(db.Integer, primary_key=True) - event_id = db.Column(db.Integer, db.ForeignKey('event.id')) - title = db.Column(db.String(200)) # 数据标题 - data_type = db.Column(db.String(50)) # 数据类型 - data_content = db.Column(db.JSON) # 数据内容(JSON格式) - description = db.Column(db.Text) # 数据描述 - created_at = db.Column(db.DateTime, default=beijing_now) - - -class RelatedConcepts(db.Model): - """关联数据模型""" - id = db.Column(db.Integer, primary_key=True) - event_id = db.Column(db.Integer, db.ForeignKey('event.id')) - concept_code = db.Column(db.String(20)) # 数据标题 - concept = db.Column(db.String(100)) # 数据类型 - reason = db.Column(db.Text) # 数据描述 - image_paths = db.Column(db.JSON) # 数据内容(JSON格式) - created_at = db.Column(db.DateTime, default=beijing_now) - - @property - def image_paths_list(self): - """返回解析后的图片路径列表""" - if not self.image_paths: - return [] - - try: - # 如果是字符串,先解析成JSON - if isinstance(self.image_paths, str): - paths = json.loads(self.image_paths) - else: - paths = self.image_paths - - # 确保paths是列表 - if not isinstance(paths, list): - paths = [paths] - - # 从每个对象中提取path字段 - return [item['path'] if isinstance(item, dict) and 'path' in item - else item for item in paths] - except Exception as e: - print(f"Error processing image paths: {e}") - return [] - - def get_first_image_path(self): - """获取第一张图片的完整路径""" - paths = self.image_paths_list - if not paths: - return None - - # 获取第一个路径 - first_path = paths[0] - # 返回完整路径 - return first_path - - -class EventHotHistory(db.Model): - """事件热度历史记录""" - id = db.Column(db.Integer, primary_key=True) - event_id = db.Column(db.Integer, db.ForeignKey('event.id')) - score = db.Column(db.Float) # 总分 - interaction_score = db.Column(db.Float) # 互动分数 - follow_score = db.Column(db.Float) # 关注度分数 - view_score = db.Column(db.Float) # 浏览量分数 - recent_activity_score = db.Column(db.Float) # 最近活跃度分数 - time_decay = db.Column(db.Float) # 时间衰减因子 - created_at = db.Column(db.DateTime, default=beijing_now) - - event = db.relationship('Event', backref='hot_history') - - -class Post(db.Model): - """帖子模型""" - id = db.Column(db.Integer, primary_key=True) - event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) - user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) - - # 内容 - title = db.Column(db.String(200)) # 标题(可选) - content = db.Column(db.Text, nullable=False) # 内容 - content_type = db.Column(db.String(20), default='text') # 内容类型:text/rich_text/link - - # 时间 - created_at = db.Column(db.DateTime, default=beijing_now) - updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) - - # 统计 - likes_count = db.Column(db.Integer, default=0) - comments_count = db.Column(db.Integer, default=0) - view_count = db.Column(db.Integer, default=0) - - # 状态 - status = db.Column(db.String(20), default='active') # active/hidden/deleted - is_top = db.Column(db.Boolean, default=False) # 是否置顶 - - # 关系 - user = db.relationship('User', backref='posts') - likes = db.relationship('PostLike', backref='post', lazy='dynamic') - comments = db.relationship('Comment', backref='post', lazy='dynamic') - - -# 辅助模型 -class EventFollow(db.Model): - """事件关注""" - id = db.Column(db.Integer, primary_key=True) - user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) - event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) - created_at = db.Column(db.DateTime, default=beijing_now) - - user = db.relationship('User', backref='event_follows') - - __table_args__ = (db.UniqueConstraint('user_id', 'event_id'),) - - -class PostLike(db.Model): - """帖子点赞""" - id = db.Column(db.Integer, primary_key=True) - user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) - post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) - created_at = db.Column(db.DateTime, default=beijing_now) - - user = db.relationship('User', backref='post_likes') - - __table_args__ = (db.UniqueConstraint('user_id', 'post_id'),) - - -class Comment(db.Model): - """评论""" - id = db.Column(db.Integer, primary_key=True) - post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) - user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) - content = db.Column(db.Text, nullable=False) - parent_id = db.Column(db.Integer, db.ForeignKey('comment.id')) # 父评论ID,用于回复 - created_at = db.Column(db.DateTime, default=beijing_now) - status = db.Column(db.String(20), default='active') - - user = db.relationship('User', backref='comments') - replies = db.relationship('Comment', backref=db.backref('parent', remote_side=[id])) - - -class StockBasicInfo(db.Model): - __tablename__ = 'ea_stocklist' - - SECCODE = db.Column(db.String(10), primary_key=True) - SECNAME = db.Column(db.String(40)) - ORGNAME = db.Column(db.String(100)) - F001V = db.Column(db.String(100)) # Pinyin abbreviation - F003V = db.Column(db.String(50)) # Security category - F005V = db.Column(db.String(50)) # Trading market - F006D = db.Column(db.DateTime) # Listing date - F011V = db.Column(db.String(50)) # Listing status - - -class CompanyInfo(db.Model): - __tablename__ = 'ea_baseinfo' - - SECCODE = db.Column(db.String(10), primary_key=True) - SECNAME = db.Column(db.String(40)) - ORGNAME = db.Column(db.String(100)) - F001V = db.Column(db.String(100)) # English name - F003V = db.Column(db.String(40)) # Legal representative - F015V = db.Column(db.String(500)) # Main business - F016V = db.Column(db.String(4000)) # Business scope - F017V = db.Column(db.String(2000)) # Company introduction - F030V = db.Column(db.String(60)) # CSRC industry first level - F032V = db.Column(db.String(60)) # CSRC industry second level - - - -class TradeData(db.Model): - __tablename__ = 'ea_trade' - - SECCODE = db.Column(db.String(10), primary_key=True) - SECNAME = db.Column(db.String(40)) - TRADEDATE = db.Column(db.Date, primary_key=True) - F002N = db.Column(db.Numeric(18, 4)) # Previous close - F003N = db.Column(db.Numeric(18, 4)) # Open price - F004N = db.Column(db.Numeric(18, 4)) # Trading volume - F005N = db.Column(db.Numeric(18, 4)) # High price - F006N = db.Column(db.Numeric(18, 4)) # Low price - F007N = db.Column(db.Numeric(18, 4)) # Close price - F009N = db.Column(db.Numeric(18, 4)) # Change - F010N = db.Column(db.Numeric(18, 4)) # Change percentage - F011N = db.Column(db.Numeric(18, 4)) # Trading amount - - -class SectorInfo(db.Model): - __tablename__ = 'ea_sector' - - SECCODE = db.Column(db.String(10), primary_key=True) - SECNAME = db.Column(db.String(40)) - F001V = db.Column(db.String(50), primary_key=True) # Classification standard code - F002V = db.Column(db.String(50)) # Classification standard - F003V = db.Column(db.String(50)) # Sector code - F004V = db.Column(db.String(50)) # Sector level 1 name - F005V = db.Column(db.String(50)) # Sector level 2 name - F006V = db.Column(db.String(50)) # Sector level 3 name - F007V = db.Column(db.String(50)) # Sector level 4 name -def send_async_email(msg): - """异步发送邮件""" - try: - mail.send(msg) - except Exception as e: - app.logger.error(f"Error sending async email: {str(e)}") -def verify_sms_code(phone_number, code): - """验证短信验证码""" - stored_code = session.get('sms_verification_code') - stored_phone = session.get('sms_verification_phone') - expiration = session.get('sms_verification_expiration') - - if not all([stored_code, stored_phone, expiration]): - return False, "请先获取验证码" - - if stored_phone != phone_number: - return False, "手机号与验证码不匹配" - - if beijing_now().timestamp() > expiration: - return False, "验证码已过期" - - if code != stored_code: - return False, "验证码错误" - - return True, "验证成功" - - -def allowed_file(filename): - return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS - - -# ============================================ -# 订阅相关 API 接口(小程序专用) -# ============================================ -@app.route('/api/subscription/info', methods=['GET']) -@token_required -def get_subscription_info(): - """ - 获取当前用户的订阅信息 - 小程序专用接口 - 返回用户当前订阅类型、状态、剩余天数等信息 - """ - try: - info = _get_current_subscription_info() - return jsonify({ - 'success': True, - 'data': info - }) - except Exception as e: - print(f"获取订阅信息错误: {e}") - return jsonify({ - 'success': True, - 'data': { - 'type': 'free', - 'status': 'active', - 'is_active': True, - 'days_left': 0 - } - }) - - -@app.route('/api/subscription/check', methods=['GET']) -@token_required -def check_subscription_access(): - """ - 检查当前用户是否有权限使用小程序功能 - 返回:是否为 Pro/Max 用户 - """ - try: - has_access = _has_required_level('pro') - info = _get_current_subscription_info() - - return jsonify({ - 'success': True, - 'data': { - 'has_access': has_access, - 'subscription_type': info.get('type', 'free'), - 'is_active': info.get('is_active', False), - 'message': '您可以使用小程序功能' if has_access else '小程序功能仅对 Pro 和 Max 会员开放' - } - }) - except Exception as e: - print(f"检查订阅权限错误: {e}") - return jsonify({ - 'success': False, - 'error': str(e) - }), 500 - - -# ============================================ -# 现有接口示例(应用权限控制) -# ============================================ - -# 更新视图函数 -@app.route('/settings/profile', methods=['POST']) -@token_required -def update_profile(): - """更新个人资料""" - try: - user = request.user - form = request.form - - # 基本信息更新 - user.nickname = form.get('nickname') - user.bio = form.get('bio') - user.gender = form.get('gender') - user.birth_date = datetime.strptime(form.get('birth_date'), '%Y-%m-%d') if form.get('birth_date') else None - user.phone = form.get('phone') - user.location = form.get('location') - user.wechat_id = form.get('wechat_id') - - # 处理头像上传 - if 'avatar' in request.files: - file = request.files['avatar'] - if file and allowed_file(file.filename): - # 生成安全的文件名 - filename = secure_filename(f"{user.id}_{int(datetime.now().timestamp())}_{file.filename}") - filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename) - - # 确保上传目录存在 - os.makedirs(os.path.dirname(filepath), exist_ok=True) - - # 保存并处理图片 - image = Image.open(file) - image.thumbnail((300, 300)) # 调整图片大小 - image.save(filepath) - - # 更新用户头像URL - user.avatar_url = f'{DOMAIN}/static/uploads/avatars/{filename}' - - db.session.commit() - return jsonify({'success': True, 'message': '个人资料已更新'}) - - except Exception as e: - db.session.rollback() - app.logger.error(f"Error updating profile: {str(e)}") - return jsonify({'success': False, 'message': '更新失败,请重试'}) - - - -# 投资偏好设置 -@app.route('/settings/investment_preferences', methods=['POST']) -@token_required -def update_investment_preferences(): - """更新投资偏好""" - try: - user = request.user - form = request.form - - user.trading_experience = form.get('trading_experience') - user.investment_style = form.get('investment_style') - user.risk_preference = form.get('risk_preference') - user.investment_amount = form.get('investment_amount') - user.preferred_markets = json.dumps(request.form.getlist('preferred_markets')) - - db.session.commit() - return jsonify({'success': True, 'message': '投资偏好已更新'}) - - except Exception as e: - db.session.rollback() - app.logger.error(f"Error updating investment preferences: {str(e)}") - return jsonify({'success': False, 'message': '更新失败,请重试'}) - - -def get_clickhouse_client(): - return Cclient( - host='222.128.1.157', - port=18000, - user='default', - password='Zzl33818!', - database='stock' - ) - - -@app.route('/api/stock//kline') -def get_stock_kline(stock_code): - """获取股票K线数据 - 仅限 Pro/Max 会员(小程序功能)""" - chart_type = request.args.get('chart_type', 'daily') # 默认改为daily - event_time = request.args.get('event_time') - - try: - event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now() - except ValueError: - return jsonify({'error': 'Invalid event_time format'}), 400 - - # 获取股票名称 - try: - with engine.connect() as conn: - result = conn.execute(text( - "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" - ), {"code": stock_code.split('.')[0]}).fetchone() - stock_name = result[0] if result else 'Unknown' - except Exception as e: - print(f"Error getting stock name: {e}") - stock_name = 'Unknown' - - if chart_type == 'daily': - return get_daily_kline(stock_code, event_datetime, stock_name) - elif chart_type == 'minute': - return get_minute_kline(stock_code, event_datetime, stock_name) - else: - return jsonify({ - 'error': 'Invalid chart type', - 'message': 'Supported types: daily, minute', - 'code': stock_code, - 'name': stock_name - }), 400 - - -def get_daily_kline(stock_code, event_datetime, stock_name): - """处理日K线数据""" - stock_code = stock_code.split('.')[0] - - print(f"Debug: stock_code={stock_code}, event_datetime={event_datetime}, stock_name={stock_name}") - - try: - with engine.connect() as conn: - # 获取事件日期前后的数据 - kline_sql = """ - WITH date_range AS ( - SELECT TRADEDATE - FROM ea_trade - WHERE SECCODE = :stock_code - AND TRADEDATE BETWEEN DATE_SUB(:trade_date, INTERVAL 60 DAY) - AND :trade_date - GROUP BY TRADEDATE - ORDER BY TRADEDATE - ) - SELECT t.TRADEDATE, - CAST(t.F003N AS FLOAT) as open, - CAST(t.F007N AS FLOAT) as close, - CAST(t.F005N AS FLOAT) as high, - CAST(t.F006N AS FLOAT) as low, - CAST(t.F004N AS FLOAT) as volume - FROM ea_trade t - JOIN date_range d ON t.TRADEDATE = d.TRADEDATE - WHERE t.SECCODE = :stock_code - ORDER BY t.TRADEDATE - """ - - result = conn.execute(text(kline_sql), { - "stock_code": stock_code, - "trade_date": event_datetime.date() - }).fetchall() - - print(f"Debug: Query result count: {len(result)}") - - if not result: - print("Debug: No data found, trying fallback query...") - # 如果没有数据,尝试获取最近的交易数据 - fallback_sql = """ - SELECT TRADEDATE, - CAST(F003N AS FLOAT) as open, - CAST(F007N AS FLOAT) as close, - CAST(F005N AS FLOAT) as high, - CAST(F006N AS FLOAT) as low, - CAST(F004N AS FLOAT) as volume - FROM ea_trade - WHERE SECCODE = :stock_code - AND TRADEDATE <= :trade_date - AND F003N IS NOT NULL - AND F007N IS NOT NULL - AND F005N IS NOT NULL - AND F006N IS NOT NULL - AND F004N IS NOT NULL - ORDER BY TRADEDATE - LIMIT 100 - """ - - result = conn.execute(text(fallback_sql), { - "stock_code": stock_code, - "trade_date": event_datetime.date() - }).fetchall() - - print(f"Debug: Fallback query result count: {len(result)}") - - if not result: - return jsonify({ - 'error': 'No data available', - 'code': stock_code, - 'name': stock_name, - 'data': [], - 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), - 'type': 'daily' - }) - - kline_data = [] - for row in result: - try: - kline_data.append({ - 'time': row.TRADEDATE.strftime('%Y-%m-%d'), - 'open': float(row.open) if row.open else 0, - 'high': float(row.high) if row.high else 0, - 'low': float(row.low) if row.low else 0, - 'close': float(row.close) if row.close else 0, - 'volume': float(row.volume) if row.volume else 0 - }) - except (ValueError, TypeError) as e: - print(f"Debug: Error processing row: {e}") - continue - - print(f"Debug: Final kline_data count: {len(kline_data)}") - - return jsonify({ - 'code': stock_code, - 'name': stock_name, - 'data': kline_data, - 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), - 'type': 'daily', - 'is_history': True, - 'data_count': len(kline_data) - }) - - except Exception as e: - print(f"Error in get_daily_kline: {e}") - return jsonify({ - 'error': f'Database error: {str(e)}', - 'code': stock_code, - 'name': stock_name, - 'data': [], - 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), - 'type': 'daily' - }), 500 - - -def get_minute_kline(stock_code, event_datetime, stock_name): - """处理分钟K线数据 - 包含零轴(昨日收盘价)""" - client = get_clickhouse_client() - stock_code_short = stock_code.split('.')[0] # 获取不带后缀的股票代码 - - def get_trading_days(): - trading_days = set() - with open('tdays.csv', 'r') as f: - reader = csv.DictReader(f) - for row in reader: - trading_days.add(datetime.strptime(row['DateTime'], '%Y/%m/%d').date()) - return trading_days - - trading_days = get_trading_days() - - def find_next_trading_day(current_date): - """找到下一个交易日""" - while current_date <= max(trading_days): - current_date += timedelta(days=1) - if current_date in trading_days: - return current_date - return None - - def find_prev_trading_day(current_date): - """找到前一个交易日""" - while current_date >= min(trading_days): - current_date -= timedelta(days=1) - if current_date in trading_days: - return current_date - return None - - def get_prev_close(stock_code_short, target_date): - """获取前一交易日的收盘价作为零轴基准""" - prev_date = find_prev_trading_day(target_date) - if not prev_date: - return None - - try: - with engine.connect() as conn: - # 查询前一交易日的收盘价 - sql = """ - SELECT CAST(F007N AS FLOAT) as close - FROM ea_trade - WHERE SECCODE = :stock_code - AND TRADEDATE = :prev_date - AND F007N IS NOT NULL - LIMIT 1 - """ - result = conn.execute(text(sql), { - "stock_code": stock_code_short, - "prev_date": prev_date - }).fetchone() - - if result: - return float(result.close) - else: - # 如果指定日期没有数据,尝试获取最近的收盘价 - fallback_sql = """ - SELECT CAST(F007N AS FLOAT) as close, TRADEDATE - FROM ea_trade - WHERE SECCODE = :stock_code - AND TRADEDATE < :target_date - AND F007N IS NOT NULL - ORDER BY TRADEDATE DESC - LIMIT 1 - """ - result = conn.execute(text(fallback_sql), { - "stock_code": stock_code_short, - "target_date": target_date - }).fetchone() - - if result: - print(f"Using close price from {result.TRADEDATE} as zero axis") - return float(result.close) - - except Exception as e: - print(f"Error getting previous close: {e}") - - return None - - target_date = event_datetime.date() - is_after_market = event_datetime.time() > dt_time(15, 0) - - # 核心逻辑:先判断当前日期是否是交易日,以及是否已收盘 - if target_date in trading_days and is_after_market: - # 如果是交易日且已收盘,查找下一个交易日 - next_trade_date = find_next_trading_day(target_date) - if next_trade_date: - target_date = next_trade_date - elif target_date not in trading_days: - # 如果不是交易日,先尝试找下一个交易日 - next_trade_date = find_next_trading_day(target_date) - if next_trade_date: - target_date = next_trade_date - else: - # 如果找不到下一个交易日,找最近的历史交易日 - target_date = find_prev_trading_day(target_date) - - if not target_date: - return jsonify({ - 'error': 'No data available', - 'code': stock_code, - 'name': stock_name, - 'data': [], - 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), - 'type': 'minute' - }) - - # 获取前一交易日收盘价作为零轴 - zero_axis = get_prev_close(stock_code_short, target_date) - - # 获取目标日期的完整交易时段数据 - data = client.execute(""" - SELECT - timestamp, - open, - high, - low, - close, - volume, - amt - FROM stock_minute - WHERE code = %(code)s - AND timestamp BETWEEN %(start)s AND %(end)s - ORDER BY timestamp - """, { - 'code': stock_code, - 'start': datetime.combine(target_date, dt_time(9, 30)), - 'end': datetime.combine(target_date, dt_time(15, 0)) - }) - - kline_data = [] - for row in data: - point = { - 'time': row[0].strftime('%H:%M'), - 'open': float(row[1]), - 'high': float(row[2]), - 'low': float(row[3]), - 'close': float(row[4]), - 'volume': float(row[5]), - 'amount': float(row[6]) - } - - # 如果有零轴数据,计算涨跌幅和涨跌额 - if zero_axis: - point['prev_close'] = zero_axis - point['change'] = point['close'] - zero_axis # 涨跌额 - point['change_pct'] = ((point['close'] - zero_axis) / zero_axis * 100) if zero_axis != 0 else 0 # 涨跌幅百分比 - - kline_data.append(point) - - response_data = { - 'code': stock_code, - 'name': stock_name, - 'data': kline_data, - 'trade_date': target_date.strftime('%Y-%m-%d'), - 'type': 'minute', - 'is_history': target_date < event_datetime.date() - } - - # 添加零轴信息到响应中 - if zero_axis: - response_data['zero_axis'] = zero_axis - response_data['prev_close'] = zero_axis - - # 计算当日整体涨跌幅(如果有数据) - if kline_data: - last_close = kline_data[-1]['close'] - response_data['day_change'] = last_close - zero_axis - response_data['day_change_pct'] = ((last_close - zero_axis) / zero_axis * 100) if zero_axis != 0 else 0 - - return jsonify(response_data) - -class HistoricalEvent(db.Model): - """历史事件模型""" - id = db.Column(db.Integer, primary_key=True) - event_id = db.Column(db.Integer, db.ForeignKey('event.id')) - title = db.Column(db.String(200)) - content = db.Column(db.Text) - event_date = db.Column(db.DateTime) - relevance = db.Column(db.Integer) # 相关性 - importance = db.Column(db.Integer) # 重要程度 - related_stock = db.Column(db.JSON) # 保留JSON字段 - created_at = db.Column(db.DateTime, default=beijing_now) - - # 新增关系 - stocks = db.relationship('HistoricalEventStock', backref='historical_event', lazy='dynamic', - cascade='all, delete-orphan') - - -class HistoricalEventStock(db.Model): - """历史事件相关股票模型""" - __tablename__ = 'historical_event_stocks' - - id = db.Column(db.Integer, primary_key=True) - historical_event_id = db.Column(db.Integer, db.ForeignKey('historical_event.id'), nullable=False) - stock_code = db.Column(db.String(20), nullable=False) - stock_name = db.Column(db.String(50)) - relation_desc = db.Column(db.Text) - correlation = db.Column(db.Float, default=0.5) - sector = db.Column(db.String(100)) - created_at = db.Column(db.DateTime, default=beijing_now) - - __table_args__ = ( - db.UniqueConstraint('historical_event_id', 'stock_code', name='unique_event_stock'), - ) - - - -@app.route('/event/follow/', methods=['POST']) -@token_required -def follow_event(event_id): - """关注/取消关注事件""" - event = Event.query.get_or_404(event_id) - follow = EventFollow.query.filter_by( - user_id=request.user.id, - event_id=event_id - ).first() - - try: - if follow: - db.session.delete(follow) - event.follower_count -= 1 - message = '已取消关注' - else: - follow = EventFollow(user_id=request.user.id, event_id=event_id) - db.session.add(follow) - event.follower_count += 1 - message = '已关注' - - db.session.commit() - return jsonify({'success': True, 'message': message}) - - except Exception as e: - db.session.rollback() - return jsonify({'success': False, 'message': '操作失败,请重试'}) - - -# 帖子相关路由 -@app.route('/post/create/', methods=['GET', 'POST']) -@token_required -def create_post(event_id): - """创建新帖子""" - event = Event.query.get_or_404(event_id) - - if request.method == 'POST': - try: - post = Post( - event_id=event_id, - user_id=request.user.id, - title=request.form.get('title'), - content=request.form['content'], - content_type=request.form.get('content_type', 'text') - ) - - db.session.add(post) - event.post_count += 1 - db.session.commit() - - # 检查是否是 API 请求(通过 Accept header 或 Content-Type 判断) - if request.headers.get('Accept') == 'application/json' or \ - request.headers.get('Content-Type', '').startswith('application/json'): - return jsonify({ - 'success': True, - 'message': '发布成功', - 'data': { - 'post_id': post.id, - 'event_id': event_id, - 'redirect_url': url_for('event_detail', event_id=event_id) - } - }) - else: - # 传统表单提交,添加成功消息并重定向 - flash('发布成功', 'success') - return redirect(url_for('event_detail', event_id=event_id)) - - except Exception as e: - db.session.rollback() - if request.headers.get('Accept') == 'application/json' or \ - request.headers.get('Content-Type', '').startswith('application/json'): - return jsonify({ - 'success': False, - 'message': '发布失败,请重试' - }), 400 - else: - flash('发布失败,请重试', 'error') - app.logger.error(f"Error creating post: {str(e)}") - - return render_template('projects/create_post.html', event=event) - - -# 点赞相关路由 -@app.route('/post/like/', methods=['POST']) -@token_required -def like_post(post_id): - """点赞/取消点赞帖子""" - post = Post.query.get_or_404(post_id) - like = PostLike.query.filter_by( - user_id=request.user.id, - post_id=post_id - ).first() - - try: - if like: - # 取消点赞 - db.session.delete(like) - post.likes_count -= 1 - message = '已取消点赞' - else: - # 添加点赞 - like = PostLike(user_id=request.user.id, post_id=post_id) - db.session.add(like) - post.likes_count += 1 - message = '已点赞' - - # 可以在这里添加点赞通知 - if post.user_id != request.user.id: - notify_user_post_liked(post) - - db.session.commit() - return jsonify({ - 'success': True, - 'message': message, - 'likes_count': post.likes_count - }) - - except Exception as e: - db.session.rollback() - return jsonify({'success': False, 'message': '操作失败,请重试'}) - - -def update_user_activity(): - """更新用户活跃度""" - with app.app_context(): - try: - # 获取过去7天内的用户活动数据 - seven_days_ago = beijing_now() - timedelta(days=7) - - # 统计用户发帖、评论、点赞等活动 - active_users = db.session.query( - User.id, - db.func.count(Post.id).label('post_count'), - db.func.count(Comment.id).label('comment_count'), - db.func.count(PostLike.id).label('like_count') - ).outerjoin(Post, User.id == Post.user_id) \ - .outerjoin(Comment, User.id == Comment.user_id) \ - .outerjoin(PostLike, User.id == PostLike.user_id) \ - .filter( - db.or_( - Post.created_at >= seven_days_ago, - Comment.created_at >= seven_days_ago, - PostLike.created_at >= seven_days_ago - ) - ).group_by(User.id).all() - - # 更新用户活跃度分数 - for user_id, post_count, comment_count, like_count in active_users: - activity_score = post_count * 2 + comment_count * 1 + like_count * 0.5 - User.query.filter_by(id=user_id).update({ - 'activity_score': activity_score, - 'last_active': beijing_now() - }) - - db.session.commit() - current_app.logger.info("Successfully updated user activity scores") - - except Exception as e: - db.session.rollback() - current_app.logger.error(f"Error updating user activity: {str(e)}") - - -@app.route('/post/comment/', methods=['POST']) -@token_required -def add_comment(post_id): - """添加评论""" - post = Post.query.get_or_404(post_id) - - try: - content = request.form.get('content') - parent_id = request.form.get('parent_id', type=int) - - if not content: - return jsonify({'success': False, 'message': '评论内容不能为空'}) - - comment = Comment( - post_id=post_id, - user_id=request.user.id, - content=content, - parent_id=parent_id - ) - - db.session.add(comment) - post.comments_count += 1 - - # 如果是回复评论,可以添加通知 - if parent_id: - parent_comment = Comment.query.get(parent_id) - if parent_comment and parent_comment.user_id != request.user.id: - notify_user_comment_replied(parent_comment) - - # 如果是评论帖子,通知帖子作者 - elif post.user_id != request.user.id: - notify_user_post_commented(post) - - db.session.commit() - - return jsonify({ - 'success': True, - 'message': '评论成功', - 'comment': { - 'id': comment.id, - 'content': comment.content, - 'user_name': request.user.username, - 'user_avatar': get_full_avatar_url(request.user.avatar_url), # 修改这里 - 'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S') - } - }) - - except Exception as e: - db.session.rollback() - return jsonify({'success': False, 'message': '评论失败,请重试'}) - - -@app.route('/post/comments/') - -def get_comments(post_id): - """获取帖子评论列表""" - page = request.args.get('page', 1, type=int) - - # 获取顶层评论 - comments = Comment.query.filter_by( - post_id=post_id, - parent_id=None, - status='active' - ).order_by( - Comment.created_at.desc() - ).paginate(page=page, per_page=20) - - # 同时获取每个顶层评论的部分回复 - comments_data = [] - for comment in comments.items: - replies = Comment.query.filter_by( - parent_id=comment.id, - status='active' - ).order_by( - Comment.created_at.asc() - ).limit(3).all() - - comments_data.append({ - 'id': comment.id, - 'content': comment.content, - 'user': { - 'id': comment.user.id, - 'username': comment.user.username, - 'avatar_url': get_full_avatar_url(comment.user.avatar_url), # 修改这里 - }, - 'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'), - 'replies': [{ - 'id': reply.id, - 'content': reply.content, - 'user': { - 'id': reply.user.id, - 'username': reply.user.username, - 'avatar_url': get_full_avatar_url(reply.user.avatar_url), # 修改这里 - }, - 'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S') - } for reply in replies] - }) - - return jsonify({ - 'comments': comments_data, - 'total': comments.total, - 'pages': comments.pages, - 'current_page': comments.page - }) - - -beijing_tz = pytz.timezone('Asia/Shanghai') - - -def update_hot_scores(): - """ - 更新所有事件的热度分数 - 在Flask应用上下文中执行数据库操作 - """ - with app.app_context(): - try: - # 获取所有活跃事件 - events = Event.query.filter_by(status='active').all() - current_time = beijing_now() - - for event in events: - # 确保created_at有时区信息,解决naive和aware datetime比较问题 - created_at = beijing_tz.localize( - event.created_at) if event.created_at.tzinfo is None else event.created_at - - # 使用处理后的created_at计算hours_passed - hours_passed = (current_time - created_at).total_seconds() / 3600 - - # 基础分数 - 帖子数和评论数 - posts = Post.query.filter_by(event_id=event.id).all() - post_count = len(posts) - comment_count = sum(post.comments_count for post in posts) - - # 获取24小时内的新增帖子数 - recent_posts = Post.query.filter( - Post.event_id == event.id, - Post.created_at >= current_time - timedelta(hours=24) - ).count() - - # 获取点赞数 - like_count = db.session.query(func.sum(Post.likes_count)).filter( - Post.event_id == event.id - ).scalar() or 0 - - # 基础互动分数 = 帖子数 * 2 + 评论数 * 1 + 点赞数 * 0.5 - interaction_score = (post_count * 2) + (comment_count * 1) + (like_count * 0.5) - - # 关注度分数 = 关注人数 * 3 - follow_score = event.follower_count * 3 - - # 浏览量分数 = log(浏览量) - if event.view_count > 0: - view_score = math.log(event.view_count) * 2 - else: - view_score = 0 - - # 时间衰减因子 - 使用上面已经计算好的hours_passed - time_decay = math.exp(-hours_passed / 72) # 3天后衰减为原始分数的1/e - - # 最近活跃度权重 - recent_activity_weight = (recent_posts * 5) # 24小时内的新帖权重高 - - # 总分 = (互动分数 + 关注度分数 + 浏览量分数 + 最近活跃度) * 时间衰减 - total_score = (interaction_score + follow_score + view_score + recent_activity_weight) * time_decay - - # 更新热度分数 - event.hot_score = round(total_score, 2) - - # 分数的对数值作为事件的trending_score (用于趋势排序) - if total_score > 0: - event.trending_score = math.log(total_score) * time_decay - else: - event.trending_score = 0 - - # 记录热度历史 - history = EventHotHistory( - event_id=event.id, - score=event.hot_score, - interaction_score=interaction_score, - follow_score=follow_score, - view_score=view_score, - recent_activity_score=recent_activity_weight, - time_decay=time_decay - ) - db.session.add(history) - - db.session.commit() - app.logger.info("Successfully updated event hot scores") - - except Exception as e: - db.session.rollback() - app.logger.error(f"Error updating hot scores: {str(e)}") - raise - - -# 添加热度历史记录模型 - - -def calculate_hot_score(event): - """计算事件热度分数""" - current_time = beijing_now() - time_diff = (current_time - event.created_at).total_seconds() / 3600 # 转换为小时 - - # 基础分数 = 浏览量 * 0.1 + 帖子数 * 0.5 + 关注数 * 1 - base_score = ( - event.view_count * 0.1 + - event.post_count * 0.5 + - event.follower_count * 1 - ) - - # 时间衰减因子,72小时(3天)内的事件获得较高权重 - time_factor = max(1 - (time_diff / 72), 0.1) - - return base_score * time_factor - - -@app.route('/api/sector/hierarchy', methods=['GET']) -def api_sector_hierarchy(): - """行业层级关系接口:展示多个行业分类体系的层级结构""" - try: - # 定义需要返回的行业分类体系 - classification_systems = [ - '申银万国行业分类' - ] - - result = [] # 改为数组 - - for classification in classification_systems: - # 查询特定分类标准的数据 - sectors = SectorInfo.query.filter_by(F002V=classification).all() - - if not sectors: - continue - - # 构建该分类体系的层级结构 - hierarchy = {} - - for sector in sectors: - level1 = sector.F004V # 一级行业 - level2 = sector.F005V # 二级行业 - level3 = sector.F006V # 三级行业 - level4 = sector.F007V # 四级行业 - - # 统计股票数量 - stock_code = sector.SECCODE - - # 初始化一级行业 - if level1 not in hierarchy: - hierarchy[level1] = { - 'level2_sectors': {}, - 'stocks': set(), - 'stocks_count': 0 - } - - # 添加股票到一级行业 - if stock_code: - hierarchy[level1]['stocks'].add(stock_code) - - # 处理二级行业 - if level2: - if level2 not in hierarchy[level1]['level2_sectors']: - hierarchy[level1]['level2_sectors'][level2] = { - 'level3_sectors': {}, - 'stocks': set(), - 'stocks_count': 0 - } - - # 添加股票到二级行业 - if stock_code: - hierarchy[level1]['level2_sectors'][level2]['stocks'].add(stock_code) - - # 处理三级行业 - if level3: - if level3 not in hierarchy[level1]['level2_sectors'][level2]['level3_sectors']: - hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3] = { - 'level4_sectors': [], - 'stocks': set(), - 'stocks_count': 0 - } - - # 添加股票到三级行业 - if stock_code: - hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3]['stocks'].add( - stock_code) - - # 处理四级行业 - if level4 and level4 not in \ - hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3]['level4_sectors']: - hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3][ - 'level4_sectors'].append(level4) - - # 计算股票数量并清理set对象 - formatted_hierarchy = [] - for level1, level1_data in hierarchy.items(): - level1_item = { - 'level1_sector': level1, - 'stocks_count': len(level1_data['stocks']), - 'level2_sectors': [] - } - - for level2, level2_data in level1_data['level2_sectors'].items(): - level2_item = { - 'level2_sector': level2, - 'stocks_count': len(level2_data['stocks']), - 'level3_sectors': [] - } - - for level3, level3_data in level2_data['level3_sectors'].items(): - level3_item = { - 'level3_sector': level3, - 'stocks_count': len(level3_data['stocks']), - 'level4_sectors': level3_data['level4_sectors'] - } - level2_item['level3_sectors'].append(level3_item) - - # 按股票数量排序 - level2_item['level3_sectors'].sort(key=lambda x: x['stocks_count'], reverse=True) - level1_item['level2_sectors'].append(level2_item) - - # 按股票数量排序 - level1_item['level2_sectors'].sort(key=lambda x: x['stocks_count'], reverse=True) - formatted_hierarchy.append(level1_item) - - # 按股票数量排序 - formatted_hierarchy.sort(key=lambda x: x['stocks_count'], reverse=True) - - # 将该分类体系添加到结果数组中 - result.append({ - 'classification_name': classification, - 'total_level1_count': len(formatted_hierarchy), - 'total_stocks_count': sum(item['stocks_count'] for item in formatted_hierarchy), - 'hierarchy': formatted_hierarchy - }) - - # 按总股票数量排序 - result.sort(key=lambda x: x['total_stocks_count'], reverse=True) - - return jsonify({ - "code": 200, - "message": "success", - "data": result - }) - - except Exception as e: - return jsonify({ - "code": 500, - "message": str(e), - "data": None - }), 500 - - -@app.route('/api/sector/banner', methods=['GET']) -def api_sector_banner(): - """行业分类 banner 接口:返回一级分类和对应二级行业列表""" - try: - # 原始映射 - sector_map = { - '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期', '钢铁': '大周期', '基础化工': '大周期', - '建筑材料': '大周期', '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期', - '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期', '太阳能发电': '大周期', - '建筑装饰': '大周期', - - '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费', '食品饮料': '大消费', '医药生物': '大消费', - '纺织服饰': '大消费', '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费', - '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费', - - '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产', '多元金融': '大金融地产', - '综合金融': '大金融地产', '房地产': '大金融地产', '非银金融': '大金融地产', - - '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块', - - '交通运输': '公共产业板块', '电力公用事业': '公共产业板块', '建筑': '公共产业板块', - '环保': '公共产业板块', '综合': '公共产业板块', '公用事业': '公共产业板块' - } - - # 重组结构为 一级 → [二级...] - result_dict = {} - for sub_sector, primary_sector in sector_map.items(): - result_dict.setdefault(primary_sector, []).append(sub_sector) - - # 格式化成列表 - result_list = [ - {"primary_sector": primary, "sub_sectors": subs} - for primary, subs in result_dict.items() - ] - - return jsonify({ - "code": 200, - "message": "success", - "data": result_list - }) - - except Exception as e: - return jsonify({ - "code": 500, - "message": str(e), - "data": None - }), 500 - - -def get_limit_rate(stock_code): - """ - 根据股票代码获取涨跌停限制比例 - - Args: - stock_code: 股票代码 - - Returns: - float: 涨跌停限制比例 - """ - if not stock_code: - return 10.0 - - # 去掉市场后缀 - clean_code = stock_code.replace('.SH', '').replace('.SZ', '').replace('.BJ', '') - - # ST股票 (5%涨跌停) - if 'ST' in stock_code.upper(): - return 5.0 - - # 科创板 (688开头, 20%涨跌停) - if clean_code.startswith('688'): - return 20.0 - - # 创业板注册制 (30开头, 20%涨跌停) - if clean_code.startswith('30'): - return 20.0 - - # 北交所 (43、83、87开头, 30%涨跌停) - if clean_code.startswith(('43', '83', '87')): - return 30.0 - - # 主板、中小板默认 (10%涨跌停) - return 10.0 - - -@app.route('/api/events', methods=['GET']) - - -def api_get_events(): - """ - 获取事件列表API - 优化版本(保持完全兼容) - 仅限 Pro/Max 会员访问(小程序功能) - - 优化策略: - 1. 使用ind_type字段简化内部逻辑 - 2. 批量获取股票行情,包括周涨跌计算 - 3. 保持原有返回数据结构不变 - """ - try: - # ==================== 参数解析 ==================== - - # 分页参数 - page = max(1, request.args.get('page', 1, type=int)) - per_page = min(100, max(1, request.args.get('per_page', 10, type=int))) - - # 基础筛选参数 - event_type = request.args.get('type', 'all') - event_status = request.args.get('status', 'active') - importance = request.args.get('importance', 'all') - - # 日期筛选参数 - start_date = request.args.get('start_date') - end_date = request.args.get('end_date') - date_range = request.args.get('date_range') - recent_days = request.args.get('recent_days', type=int) - - # 行业筛选参数(重新设计) - ind_type = request.args.get('ind_type', 'all') - stock_sector = request.args.get('stock_sector', 'all') - secondary_sector = request.args.get('secondary_sector', 'all') - - # 新的行业层级筛选参数 - industry_level = request.args.get('industry_level', type=int) # 筛选层级:1-4 - industry_classification = request.args.get('industry_classification') # 行业名称 - - # 如果使用旧参数,映射到ind_type - if ind_type == 'all' and stock_sector != 'all': - ind_type = stock_sector - - # 标签筛选参数 - tag = request.args.get('tag') - tags = request.args.get('tags') - keywords = request.args.get('keywords') - - # 搜索参数 - search_query = request.args.get('q') - search_type = request.args.get('search_type', 'topic') - search_fields = request.args.get('search_fields', 'title,description').split(',') - - # 排序参数 - sort_by = request.args.get('sort', 'new') - return_type = request.args.get('return_type', 'avg') - order = request.args.get('order', 'desc') - - # 收益率筛选参数 - min_avg_return = request.args.get('min_avg_return', type=float) - max_avg_return = request.args.get('max_avg_return', type=float) - min_max_return = request.args.get('min_max_return', type=float) - max_max_return = request.args.get('max_max_return', type=float) - min_week_return = request.args.get('min_week_return', type=float) - max_week_return = request.args.get('max_week_return', type=float) - - # 其他筛选参数 - min_hot_score = request.args.get('min_hot_score', type=float) - max_hot_score = request.args.get('max_hot_score', type=float) - min_view_count = request.args.get('min_view_count', type=int) - creator_id = request.args.get('creator_id', type=int) - - # 返回格式参数 - include_creator = request.args.get('include_creator', 'true').lower() == 'true' - include_stats = request.args.get('include_stats', 'true').lower() == 'true' - include_related_data = request.args.get('include_related_data', 'false').lower() == 'true' - - # ==================== 构建查询 ==================== - - query = Event.query - - # 状态筛选 - if event_status != 'all': - query = query.filter_by(status=event_status) - - # 事件类型筛选 - if event_type != 'all': - query = query.filter_by(event_type=event_type) - - # 重要性筛选 - if importance != 'all': - query = query.filter_by(importance=importance) - - # 行业类型筛选(使用ind_type字段) - if ind_type != 'all': - query = query.filter_by(ind_type=ind_type) - - # 创建者筛选 - if creator_id: - query = query.filter_by(creator_id=creator_id) - - # ==================== 日期筛选 ==================== - - if recent_days: - cutoff_date = datetime.now() - timedelta(days=recent_days) - query = query.filter(Event.created_at >= cutoff_date) - else: - # 处理日期范围字符串 - if date_range and ' 至 ' in date_range: - try: - start_date_str, end_date_str = date_range.split(' 至 ') - start_date = start_date_str.strip() - end_date = end_date_str.strip() - except ValueError: - pass - - # 开始日期 - if start_date: - try: - if len(start_date) == 10: - start_datetime = datetime.strptime(start_date, '%Y-%m-%d') - else: - start_datetime = datetime.strptime(start_date, '%Y-%m-%d %H:%M:%S') - query = query.filter(Event.created_at >= start_datetime) - except ValueError: - pass - - # 结束日期 - if end_date: - try: - if len(end_date) == 10: - end_datetime = datetime.strptime(end_date, '%Y-%m-%d') - end_datetime = end_datetime.replace(hour=23, minute=59, second=59) - else: - end_datetime = datetime.strptime(end_date, '%Y-%m-%d %H:%M:%S') - query = query.filter(Event.created_at <= end_datetime) - except ValueError: - pass - - # ==================== 行业层级筛选(申银万国行业分类) ==================== - - if industry_level and industry_classification: - # 排除行业分类体系名称本身,这些不是具体的行业 - classification_systems = [ - '申银万国行业分类', '中上协行业分类', '巨潮行业分类', - '新财富行业分类', '证监会行业分类', '证监会行业分类(2001)' - ] - - if industry_classification not in classification_systems: - # 根据层级和名称查询对应的行业代码 - # 前端发送的level值直接对应数据库字段: - # level=2 -> f004v(一级行业) - # level=3 -> f005v(二级行业) - # level=4 -> f006v(三级行业) - # level=5 -> f007v(四级行业) - level_column_map = { - 2: 'f004v', # level2 对应一级行业 - 3: 'f005v', # level3 对应二级行业 - 4: 'f006v', # level4 对应三级行业 - 5: 'f007v' # level5 对应四级行业 - } - - if industry_level in level_column_map: - level_column = level_column_map[industry_level] - - # 查询所有匹配该行业名称的代码 - sector_codes_sql = f""" - SELECT DISTINCT f003v - FROM ea_sector - WHERE f002v = '申银万国行业分类' - AND {level_column} = :industry_name - """ - - result = db.session.execute( - text(sector_codes_sql), - {'industry_name': industry_classification} - ) - - matching_codes = [row[0] for row in result.fetchall()] - - if matching_codes: - # 根据层级确定代码前缀长度 - # 申银万国代码规则:S + 2位一级 + 2位二级 + 2位三级 + 2位四级 - prefix_length_map = { - 2: 3, # level2: S + 2位(一级行业) - 3: 5, # level3: S + 2位 + 2位(二级行业) - 4: 7, # level4: S + 2位 + 2位 + 2位(三级行业) - 5: 9 # level5: 完整代码(四级行业) - } - - prefix_length = prefix_length_map.get(industry_level, 9) - - # 获取所有代码的共同前缀(用于模糊匹配) - code_prefixes = list(set([code[:prefix_length] for code in matching_codes if code])) - - if code_prefixes: - # 构建查询条件:查找related_industries中包含这些前缀的事件 - if isinstance(db.engine.dialect, MySQLDialect): - # MySQL JSON查询 - conditions = [] - for prefix in code_prefixes: - conditions.append( - text(""" - JSON_SEARCH( - related_industries, - 'one', - CONCAT(:prefix, '%'), - NULL, - '$[*]."申银万国行业分类"' - ) IS NOT NULL - """).params(prefix=prefix) - ) - - if conditions: - query = query.filter(or_(*conditions)) - else: - # 其他数据库 - pattern_conditions = [] - for prefix in code_prefixes: - pattern_conditions.append( - text("related_industries::text LIKE :pattern").params( - pattern=f'%"申银万国行业分类": "{prefix}%' - ) - ) - - if pattern_conditions: - query = query.filter(or_(*pattern_conditions)) - else: - # 没有找到匹配的行业代码,返回空结果 - query = query.filter(Event.id == -1) - else: - # 无效的层级参数 - app.logger.warning(f"Invalid industry_level: {industry_level}") - else: - # industry_classification 是分类体系名称,不进行筛选 - app.logger.info( - f"Skipping filter: industry_classification '{industry_classification}' is a classification system name") - - # ==================== 细分行业筛选(保留向后兼容) ==================== - - elif secondary_sector != 'all': - # 直接按行业名称查询(最后一级行业 - level5/f007v) - sector_code_query = db.session.query(text("DISTINCT f003v")).select_from( - text("ea_sector") - ).filter( - text("f002v = '申银万国行业分类' AND f007v = :sector_name") - ).params(sector_name=secondary_sector) - - sector_result = sector_code_query.first() - - if sector_result and sector_result[0]: - industry_code_to_search = sector_result[0] - - # 在related_industries JSON中查找包含该代码的事件 - if isinstance(db.engine.dialect, MySQLDialect): - query = query.filter( - text(""" - JSON_SEARCH( - related_industries, - 'one', - :industry_code, - NULL, - '$[*]."申银万国行业分类"' - ) IS NOT NULL - """) - ).params(industry_code=industry_code_to_search) - else: - query = query.filter( - text(""" - related_industries::text LIKE :pattern - """) - ).params(pattern=f'%"申银万国行业分类": "{industry_code_to_search}"%') - else: - # 如果没有找到对应的行业代码,返回空结果 - query = query.filter(Event.id == -1) - - # ==================== 概念/标签筛选 ==================== - - # 单个标签筛选 - if tag: - if isinstance(db.engine.dialect, MySQLDialect): - query = query.filter(text("JSON_CONTAINS(keywords, :tag, '$')")) - query = query.params(tag=json.dumps(tag)) - else: - query = query.filter(Event.keywords.cast(JSONB).contains([tag])) - - # 多个标签筛选 (AND逻辑) - if tags: - tag_list = [t.strip() for t in tags.split(',') if t.strip()] - for single_tag in tag_list: - if isinstance(db.engine.dialect, MySQLDialect): - query = query.filter(text("JSON_CONTAINS(keywords, :tag, '$')")) - query = query.params(tag=json.dumps(single_tag)) - else: - query = query.filter(Event.keywords.cast(JSONB).contains([single_tag])) - - # 关键词筛选 (OR逻辑) - if keywords: - keyword_list = [k.strip() for k in keywords.split(',') if k.strip()] - keyword_filters = [] - for keyword in keyword_list: - if isinstance(db.engine.dialect, MySQLDialect): - keyword_filters.append(text("JSON_CONTAINS(keywords, :keyword, '$')")) - else: - keyword_filters.append(Event.keywords.cast(JSONB).contains([keyword])) - if keyword_filters: - query = query.filter(or_(*keyword_filters)) - - # ==================== 搜索功能 ==================== - - if search_query: - search_terms = search_query.strip().split() - - if search_type == 'stock': - # 股票搜索 - query = query.join(RelatedStock).filter( - or_( - RelatedStock.stock_code.ilike(f'%{search_query}%'), - RelatedStock.stock_name.ilike(f'%{search_query}%') - ) - ).distinct() - elif search_type == 'all': - # 全局搜索 - search_filters = [] - - # 文本字段搜索 - for term in search_terms: - term_filters = [] - if 'title' in search_fields: - term_filters.append(Event.title.ilike(f'%{term}%')) - if 'description' in search_fields: - term_filters.append(Event.description.ilike(f'%{term}%')) - if 'keywords' in search_fields: - if isinstance(db.engine.dialect, MySQLDialect): - term_filters.append(text("JSON_CONTAINS(keywords, :term, '$')")) - else: - term_filters.append(Event.keywords.cast(JSONB).contains([term])) - - if term_filters: - search_filters.append(or_(*term_filters)) - - # 股票搜索 - stock_subquery = db.session.query(RelatedStock.event_id).filter( - or_( - RelatedStock.stock_code.ilike(f'%{search_query}%'), - RelatedStock.stock_name.ilike(f'%{search_query}%') - ) - ).subquery() - - search_filters.append(Event.id.in_(stock_subquery)) - - if search_filters: - query = query.filter(or_(*search_filters)) - else: - # 话题搜索 (默认) - for term in search_terms: - term_filters = [] - if 'title' in search_fields: - term_filters.append(Event.title.ilike(f'%{term}%')) - if 'description' in search_fields: - term_filters.append(Event.description.ilike(f'%{term}%')) - if 'keywords' in search_fields: - if isinstance(db.engine.dialect, MySQLDialect): - term_filters.append(text("JSON_CONTAINS(keywords, :term, '$')")) - else: - term_filters.append(Event.keywords.cast(JSONB).contains([term])) - - if term_filters: - query = query.filter(or_(*term_filters)) - - # ==================== 收益率筛选 ==================== - - if min_avg_return is not None: - query = query.filter(Event.related_avg_chg >= min_avg_return) - if max_avg_return is not None: - query = query.filter(Event.related_avg_chg <= max_avg_return) - - if min_max_return is not None: - query = query.filter(Event.related_max_chg >= min_max_return) - if max_max_return is not None: - query = query.filter(Event.related_max_chg <= max_max_return) - - if min_week_return is not None: - query = query.filter(Event.related_week_chg >= min_week_return) - if max_week_return is not None: - query = query.filter(Event.related_week_chg <= max_week_return) - - # ==================== 其他数值筛选 ==================== - - if min_hot_score is not None: - query = query.filter(Event.hot_score >= min_hot_score) - if max_hot_score is not None: - query = query.filter(Event.hot_score <= max_hot_score) - - if min_view_count is not None: - query = query.filter(Event.view_count >= min_view_count) - - # ==================== 排序逻辑 ==================== - - order_func = desc if order.lower() == 'desc' else asc - - if sort_by == 'hot': - query = query.order_by(order_func(Event.hot_score)) - elif sort_by == 'new': - query = query.order_by(order_func(Event.created_at)) - elif sort_by == 'returns': - if return_type == 'avg': - query = query.order_by(order_func(Event.related_avg_chg)) - elif return_type == 'max': - query = query.order_by(order_func(Event.related_max_chg)) - elif return_type == 'week': - query = query.order_by(order_func(Event.related_week_chg)) - elif sort_by == 'importance': - importance_order = case( - (Event.importance == 'S', 1), - (Event.importance == 'A', 2), - (Event.importance == 'B', 3), - (Event.importance == 'C', 4), - else_=5 - ) - if order.lower() == 'desc': - query = query.order_by(importance_order) - else: - query = query.order_by(desc(importance_order)) - elif sort_by == 'view_count': - query = query.order_by(order_func(Event.view_count)) - elif sort_by == 'follow' and hasattr(request, 'user') and request.user.is_authenticated: - # 关注的事件排序 - query = query.join(EventFollow).filter( - EventFollow.user_id == request.user.id - ).order_by(order_func(Event.created_at)) - - # ==================== 分页查询 ==================== - - paginated = query.paginate(page=page, per_page=per_page, error_out=False) - - # ==================== 批量获取股票行情数据(优化版) ==================== - - # 1. 收集当前页所有事件的ID - event_ids = [event.id for event in paginated.items] - - # 2. 获取所有相关股票 - all_related_stocks = {} - if event_ids: - related_stocks = RelatedStock.query.filter( - RelatedStock.event_id.in_(event_ids) - ).all() - - # 按事件ID分组 - for stock in related_stocks: - if stock.event_id not in all_related_stocks: - all_related_stocks[stock.event_id] = [] - all_related_stocks[stock.event_id].append(stock) - - # 3. 收集所有股票代码 - all_stock_codes = [] - stock_code_mapping = {} # 清理后的代码 -> 原始代码的映射 - - for stocks in all_related_stocks.values(): - for stock in stocks: - clean_code = stock.stock_code.replace('.SH', '').replace('.SZ', '').replace('.BJ', '') - all_stock_codes.append(clean_code) - stock_code_mapping[clean_code] = stock.stock_code - - # 去重 - all_stock_codes = list(set(all_stock_codes)) - - # 4. 批量查询最近7个交易日的数据(用于计算日涨跌和周涨跌) - stock_price_data = {} - - if all_stock_codes: - # 构建SQL查询 - 获取最近7个交易日的数据 - codes_str = "'" + "', '".join(all_stock_codes) + "'" - - # 获取最近7个交易日的数据 - recent_trades_sql = f""" - SELECT - SECCODE, - SECNAME, - F007N as close_price, - F010N as daily_change, - TRADEDATE, - ROW_NUMBER() OVER (PARTITION BY SECCODE ORDER BY TRADEDATE DESC) as rn - FROM ea_trade - WHERE SECCODE IN ({codes_str}) - AND F007N IS NOT NULL - AND TRADEDATE >= DATE_SUB(CURDATE(), INTERVAL 10 DAY) - ORDER BY SECCODE, TRADEDATE DESC - """ - - result = db.session.execute(text(recent_trades_sql)) - - # 整理数据 - for row in result.fetchall(): - sec_code = row[0] - if sec_code not in stock_price_data: - stock_price_data[sec_code] = { - 'stock_name': row[1], - 'prices': [] - } - - stock_price_data[sec_code]['prices'].append({ - 'close_price': float(row[2]) if row[2] else 0, - 'daily_change': float(row[3]) if row[3] else 0, - 'trade_date': row[4], - 'rank': row[5] - }) - - # 5. 计算日涨跌和周涨跌 - stock_changes = {} - - for sec_code, data in stock_price_data.items(): - prices = data['prices'] - - # 最新日涨跌(第1条记录) - daily_change = 0 - if prices and prices[0]['rank'] == 1: - daily_change = prices[0]['daily_change'] - - # 计算周涨跌(最新价 vs 5个交易日前的价格) - week_change = 0 - if len(prices) >= 2: - latest_price = prices[0]['close_price'] - # 找到第5个交易日的数据(如果有) - week_ago_price = None - for price_data in prices: - if price_data['rank'] >= 5: - week_ago_price = price_data['close_price'] - break - - # 如果没有第5天的数据,使用最早的数据 - if week_ago_price is None and len(prices) > 1: - week_ago_price = prices[-1]['close_price'] - - if week_ago_price and week_ago_price > 0: - week_change = (latest_price - week_ago_price) / week_ago_price * 100 - - stock_changes[sec_code] = { - 'stock_name': data['stock_name'], - 'daily_change': daily_change, - 'week_change': week_change - } - - # ==================== 获取整体统计信息 ==================== - - # 获取所有筛选条件下的事件和股票(用于统计) - all_filtered_events = query.limit(500).all() - all_event_ids = [e.id for e in all_filtered_events] - - overall_distribution = { - 'limit_down': 0, - 'down_over_5': 0, - 'down_5_to_1': 0, - 'down_within_1': 0, - 'flat': 0, - 'up_within_1': 0, - 'up_1_to_5': 0, - 'up_over_5': 0, - 'limit_up': 0 - } - - if all_event_ids: - # 获取所有相关股票 - all_stocks_for_stats = RelatedStock.query.filter( - RelatedStock.event_id.in_(all_event_ids) - ).all() - - # 统计涨跌分布 - for stock in all_stocks_for_stats: - clean_code = stock.stock_code.replace('.SH', '').replace('.SZ', '').replace('.BJ', '') - if clean_code in stock_changes: - daily_change = stock_changes[clean_code]['daily_change'] - - # 计算涨跌停限制 - limit_rate = get_limit_rate(stock.stock_code) - - # 分类统计 - if daily_change <= -limit_rate + 0.01: - overall_distribution['limit_down'] += 1 - elif daily_change >= limit_rate - 0.01: - overall_distribution['limit_up'] += 1 - elif daily_change > 5: - overall_distribution['up_over_5'] += 1 - elif daily_change > 1: - overall_distribution['up_1_to_5'] += 1 - elif daily_change > 0.1: - overall_distribution['up_within_1'] += 1 - elif daily_change >= -0.1: - overall_distribution['flat'] += 1 - elif daily_change > -1: - overall_distribution['down_within_1'] += 1 - elif daily_change > -5: - overall_distribution['down_5_to_1'] += 1 - else: - overall_distribution['down_over_5'] += 1 - - # ==================== 构建响应数据 ==================== - - events_data = [] - for event in paginated.items: - event_stocks = all_related_stocks.get(event.id, []) - stocks_data = [] - - total_daily_change = 0 - max_daily_change = -999 - total_week_change = 0 - max_week_change = -999 - valid_stocks_count = 0 - - # 处理每个股票的数据 - for stock in event_stocks: - clean_code = stock.stock_code.replace('.SH', '').replace('.SZ', '').replace('.BJ', '') - stock_info = stock_changes.get(clean_code, {}) - - daily_change = stock_info.get('daily_change', 0) - week_change = stock_info.get('week_change', 0) - - if stock_info: - total_daily_change += daily_change - max_daily_change = max(max_daily_change, daily_change) - total_week_change += week_change - max_week_change = max(max_week_change, week_change) - valid_stocks_count += 1 - - stocks_data.append({ - "stock_code": stock.stock_code, - "stock_name": stock.stock_name, - "sector": stock.sector, - "week_change": round(week_change, 2), - "daily_change": round(daily_change, 2) - }) - - avg_daily_change = total_daily_change / valid_stocks_count if valid_stocks_count > 0 else 0 - avg_week_change = total_week_change / valid_stocks_count if valid_stocks_count > 0 else 0 - - if max_daily_change == -999: - max_daily_change = 0 - if max_week_change == -999: - max_week_change = 0 - - # 构建事件数据 - event_dict = { - 'id': event.id, - 'title': event.title, - 'description': event.description, - 'event_type': event.event_type, - 'importance': event.importance, - 'status': event.status, - 'created_at': event.created_at.isoformat() if event.created_at else None, - 'updated_at': event.updated_at.isoformat() if event.updated_at else None, - 'start_time': event.start_time.isoformat() if event.start_time else None, - 'end_time': event.end_time.isoformat() if event.end_time else None, - 'related_stocks': stocks_data, - 'stocks_stats': { - 'stocks_count': len(event_stocks), - 'valid_stocks_count': valid_stocks_count, - # 周涨跌统计 - 'avg_week_change': round(avg_week_change, 2), - 'max_week_change': round(max_week_change, 2), - # 日涨跌统计 - 'avg_daily_change': round(avg_daily_change, 2), - 'max_daily_change': round(max_daily_change, 2) - } - } - - # 统计信息 - if include_stats: - event_dict.update({ - 'hot_score': event.hot_score, - 'view_count': event.view_count, - 'post_count': event.post_count, - 'follower_count': event.follower_count, - 'related_avg_chg': event.related_avg_chg, - 'related_max_chg': event.related_max_chg, - 'related_week_chg': event.related_week_chg, - 'invest_score': event.invest_score, - 'trending_score': event.trending_score, - }) - - # 创建者信息 - if include_creator: - event_dict['creator'] = { - 'id': event.creator.id if event.creator else None, - 'username': event.creator.username if event.creator else 'Anonymous', - 'avatar_url': get_full_avatar_url(event.creator.avatar_url) if event.creator else None, - 'is_creator': event.creator.is_creator if event.creator else False, - 'creator_type': event.creator.creator_type if event.creator else None - } - - # 关联数据 - event_dict['keywords'] = event.keywords if isinstance(event.keywords, list) else [] - event_dict['related_industries'] = event.related_industries - - # 包含统计信息 - if include_stats: - event_dict['stats'] = { - 'related_stocks_count': len(event_stocks), - 'historical_events_count': 0, # 需要额外查询 - 'related_data_count': 0, # 需要额外查询 - 'related_concepts_count': 0 # 需要额外查询 - } - - # 包含关联数据 - if include_related_data: - event_dict['related_stocks'] = [{ - 'id': stock.id, - 'stock_code': stock.stock_code, - 'stock_name': stock.stock_name, - 'sector': stock.sector, - 'correlation': float(stock.correlation) if stock.correlation else 0 - } for stock in event_stocks[:5]] # 限制返回5个 - - events_data.append(event_dict) - - # ==================== 构建筛选信息 ==================== - - applied_filters = {} - - # 记录已应用的筛选条件 - if event_type != 'all': - applied_filters['type'] = event_type - if importance != 'all': - applied_filters['importance'] = importance - if start_date: - applied_filters['start_date'] = start_date - if end_date: - applied_filters['end_date'] = end_date - if industry_level and industry_classification: - applied_filters['industry_level'] = industry_level - applied_filters['industry_classification'] = industry_classification - if tag: - applied_filters['tag'] = tag - if tags: - applied_filters['tags'] = tags - if search_query: - applied_filters['search_query'] = search_query - applied_filters['search_type'] = search_type - - # ==================== 返回结果(保持完全兼容) ==================== - - return jsonify({ - 'success': True, - 'data': { - 'events': events_data, - 'pagination': { - 'page': paginated.page, - 'per_page': paginated.per_page, - 'total': paginated.total, - 'pages': paginated.pages, - 'has_prev': paginated.has_prev, - 'has_next': paginated.has_next - }, - 'filters': { - 'applied_filters': { - 'type': event_type, - 'status': event_status, - 'importance': importance, - 'stock_sector': stock_sector, # 保持兼容 - 'secondary_sector': secondary_sector, # 保持兼容 - 'sort': sort_by, - 'order': order - } - }, - # 整体股票涨跌幅分布统计 - 'overall_stats': { - 'total_stocks': len(all_stocks_for_stats) if 'all_stocks_for_stats' in locals() else 0, - 'change_distribution': overall_distribution, - 'change_distribution_percentages': { - k: v for k, v in overall_distribution.items() - } - } - } - }) - - except Exception as e: - app.logger.error(f"获取事件列表出错: {str(e)}", exc_info=True) - return jsonify({ - 'success': False, - 'error': str(e), - 'error_type': type(e).__name__ - }), 500 - - -def get_filter_counts(base_query): - """ - 获取各个筛选条件的计数信息 - 可用于前端显示筛选选项的可用数量 - """ - try: - counts = {} - - # 重要性计数 - importance_counts = db.session.query( - Event.importance, - func.count(Event.id).label('count') - ).filter( - Event.id.in_(base_query.with_entities(Event.id).subquery()) - ).group_by(Event.importance).all() - - counts['importance'] = {item.importance or 'unknown': item.count for item in importance_counts} - - # 事件类型计数 - type_counts = db.session.query( - Event.event_type, - func.count(Event.id).label('count') - ).filter( - Event.id.in_(base_query.with_entities(Event.id).subquery()) - ).group_by(Event.event_type).all() - - counts['event_type'] = {item.event_type or 'unknown': item.count for item in type_counts} - - return counts - except Exception: - return {} - - -def get_event_class(count): - """根据事件数量返回对应的样式类""" - if count >= 10: - return 'bg-gradient-danger' - elif count >= 7: - return 'bg-gradient-warning' - elif count >= 4: - return 'bg-gradient-info' - else: - return 'bg-gradient-success' -@app.route('/api/calendar-event-counts') - - -def get_calendar_event_counts(): - """获取整月的事件数量统计,仅统计type为event的事件""" - try: - # 获取当前月份的开始和结束日期 - today = datetime.now() - start_date = today.replace(day=1) - if today.month == 12: - end_date = today.replace(year=today.year + 1, month=1, day=1) - else: - end_date = today.replace(month=today.month + 1, day=1) - - # 修改查询以仅统计type为event的事件数量 - query = """ - SELECT DATE(calendar_time) as date, COUNT(*) as count - FROM future_events - WHERE calendar_time BETWEEN :start_date AND :end_date - AND type = 'event' - GROUP BY DATE(calendar_time) - """ - - result = db.session.execute(text(query), { - 'start_date': start_date, - 'end_date': end_date - }) - - # 格式化结果为日历事件格式 - events = [{ - 'title': f'{day.count} 个事件', - 'start': day.date.isoformat() if day.date else None, - 'className': get_event_class(day.count) - } for day in result] - - return jsonify(events) - - except Exception as e: - return jsonify({'error': str(e)}), 500 - - - -def get_full_avatar_url(avatar_url): - """ - 统一处理头像URL,确保返回完整的可访问URL - - Args: - avatar_url: 头像URL字符串 - - Returns: - 完整的头像URL,如果没有头像则返回默认头像URL - """ - if not avatar_url: - # 返回默认头像 - return f"{DOMAIN}/static/assets/img/default-avatar.png" - - # 如果已经是完整URL(http或https开头),直接返回 - if avatar_url.startswith(('http://', 'https://')): - return avatar_url - - # 如果是相对路径,拼接域名 - if avatar_url.startswith('/'): - return f"{DOMAIN}{avatar_url}" - else: - return f"{DOMAIN}/{avatar_url}" - - -# 修改User模型的to_dict方法 -def to_dict(self): - """转换为字典格式,方便API返回""" - return { - 'id': self.id, - 'username': self.username, - 'email': self.email, - 'nickname': self.nickname, - 'avatar_url': get_full_avatar_url(self.avatar_url), # 使用统一处理函数 - 'bio': self.bio, - 'location': self.location, - 'is_verified': self.is_verified, - 'user_level': self.user_level, - 'reputation_score': self.reputation_score, - 'post_count': self.post_count, - 'follower_count': self.follower_count, - 'following_count': self.following_count, - 'created_at': self.created_at.isoformat() if self.created_at else None, - 'last_seen': self.last_seen.isoformat() if self.last_seen else None - } - -# ==================== 标准化API接口 ==================== - -# 1. 首页接口 -@app.route('/api/home', methods=['GET']) - - -def api_home(): - try: - seven_days_ago = datetime.now() - timedelta(days=7) - hot_events = Event.query.filter( - Event.status == 'active', - Event.created_at >= seven_days_ago - ).order_by(Event.hot_score.desc()).limit(10).all() - - events_data = [] - for event in hot_events: - related_stocks = RelatedStock.query.filter_by(event_id=event.id).all() - - # 计算相关性统计数据 - correlations = [float(stock.correlation or 0) for stock in related_stocks] - avg_correlation = sum(correlations) / len(correlations) if correlations else 0 - max_correlation = max(correlations) if correlations else 0 - - stocks_data = [] - total_week_change = 0 - max_week_change = 0 - total_daily_change = 0 # 新增:日涨跌幅总和 - max_daily_change = 0 # 新增:最大日涨跌幅 - valid_stocks_count = 0 - - for stock in related_stocks: - stock_code = stock.stock_code.split('.')[0] - - # 获取最新交易日数据 - latest_trade = db.session.execute(text(""" - SELECT * FROM ea_trade - WHERE SECCODE = :stock_code - ORDER BY TRADEDATE DESC - LIMIT 1 - """), {"stock_code": stock_code}).first() - - week_change = 0 - daily_change = 0 # 新增:日涨跌幅 - - if latest_trade and latest_trade.F007N: - latest_price = float(latest_trade.F007N or 0) - latest_date = latest_trade.TRADEDATE - daily_change = float(latest_trade.F010N or 0) # F010N是日涨跌幅字段 - - # 更新日涨跌幅统计 - total_daily_change += daily_change - max_daily_change = max(max_daily_change, daily_change) - - # 获取最近5条交易记录 - week_ago_trades = db.session.execute(text(""" - SELECT * FROM ea_trade - WHERE SECCODE = :stock_code - AND TRADEDATE < :latest_date - ORDER BY TRADEDATE DESC - LIMIT 5 - """), { - "stock_code": stock_code, - "latest_date": latest_date - }).fetchall() - - if week_ago_trades and week_ago_trades[-1].F007N: - week_ago_price = float(week_ago_trades[-1].F007N or 0) - if week_ago_price > 0: - week_change = (latest_price - week_ago_price) / week_ago_price * 100 - total_week_change += week_change - max_week_change = max(max_week_change, week_change) - valid_stocks_count += 1 - - stocks_data.append({ - "stock_code": stock.stock_code, - "stock_name": stock.stock_name, - "correlation": float(stock.correlation or 0), - "sector": stock.sector, - "week_change": round(week_change, 2), - "daily_change": round(daily_change, 2), # 新增:个股日涨跌幅 - "latest_trade_date": latest_trade.TRADEDATE.strftime("%Y-%m-%d") if latest_trade else None - }) - - # 计算平均值 - avg_week_change = total_week_change / valid_stocks_count if valid_stocks_count > 0 else 0 - avg_daily_change = total_daily_change / valid_stocks_count if valid_stocks_count > 0 else 0 - - events_data.append({ - "id": event.id, - "title": event.title, - "description": event.description, - 'event_type': event.event_type, - 'importance': event.importance, # 添加重要性 - 'status': event.status, - "created_at": event.created_at.strftime("%Y-%m-%d %H:%M:%S"), - 'updated_at': event.updated_at.isoformat() if event.updated_at else None, - 'start_time': event.start_time.isoformat() if event.start_time else None, - 'end_time': event.end_time.isoformat() if event.end_time else None, - 'view_count': event.view_count, # 添加浏览量 - 'post_count': event.post_count, # 添加帖子数 - 'follower_count': event.follower_count, # 添加关注者数 - "related_stocks": stocks_data, - "stocks_stats": { - "avg_correlation": round(avg_correlation, 2), - "max_correlation": round(max_correlation, 2), - "stocks_count": len(related_stocks), - "valid_stocks_count": valid_stocks_count, - # 周涨跌统计 - "avg_week_change": round(avg_week_change, 2), - "max_week_change": round(max_week_change, 2), - # 日涨跌统计 - "avg_daily_change": round(avg_daily_change, 2), - "max_daily_change": round(max_daily_change, 2) - } - }) - - return jsonify({ - "code": 200, - "message": "success", - "data": { - "events": events_data - } - }) - - except Exception as e: - print(f"Error in api_home: {str(e)}") - return jsonify({ - "code": 500, - "message": str(e), - "data": None - }), 500 - -@app.route('/api/auth/logout', methods=['POST']) -def logout_with_token(): - """使用token登出""" - # 从请求头获取token - auth_header = request.headers.get('Authorization') - if auth_header and auth_header.startswith('Bearer '): - token = auth_header[7:] - else: - data = request.get_json() - token = data.get('token') if data else None - - if token and token in user_tokens: - del user_tokens[token] - - # 清除session - session.clear() - - return jsonify({'message': '登出成功'}), 200 -def send_sms_code(phone, code, template_id): - """发送短信验证码""" - try: - cred = credential.Credential(SMS_SECRET_ID, SMS_SECRET_KEY) - httpProfile = HttpProfile() - httpProfile.endpoint = "sms.tencentcloudapi.com" - - clientProfile = ClientProfile() - clientProfile.httpProfile = httpProfile - client = sms_client.SmsClient(cred, "ap-beijing", clientProfile) - - req = models.SendSmsRequest() - params = { - "PhoneNumberSet": [phone], - "SmsSdkAppId": SMS_SDK_APP_ID, - "TemplateId": template_id, - "SignName": SMS_SIGN_NAME, - "TemplateParamSet": [code] if template_id == SMS_TEMPLATE_REGISTER else [code, "5"] - } - req.from_json_string(json.dumps(params)) - - resp = client.SendSms(req) - return True - except TencentCloudSDKException as err: - print(f"SMS Error: {err}") - return False - -def generate_verification_code(): - """生成6位数字验证码""" - return ''.join(random.choices(string.digits, k=6)) - -@app.route('/api/auth/send-sms', methods=['POST']) -def send_sms_verification(): - """发送手机验证码(统一接口,自动判断场景)""" - data = request.get_json() - phone = data.get('phone') - - if not phone: - return jsonify({'error': '手机号不能为空'}), 400 - - # 检查手机号是否已注册 - user_exists = User.query.filter_by(phone=phone).first() is not None - - # 生成验证码 - code = generate_verification_code() - - # 根据用户是否存在自动选择模板 - template_id = SMS_TEMPLATE_LOGIN if user_exists else SMS_TEMPLATE_REGISTER - - # 发送短信 - if send_sms_code(phone, code, template_id): - # 统一存储验证码(5分钟有效) - verification_codes[phone] = { - 'code': code, - 'expires': time.time() + 300 - } - - # 简单返回成功,不暴露用户是否存在的信息 - return jsonify({ - 'message': '验证码已发送', - 'expires_in': 300 # 告诉前端验证码有效期(秒) - }), 200 - else: - return jsonify({'error': '验证码发送失败'}), 500 - - -def generate_token(length=32): - """生成随机token""" - characters = string.ascii_letters + string.digits - return ''.join(secrets.choice(characters) for _ in range(length)) - - - -@app.route('/api/auth/login/phone', methods=['POST']) -def login_with_phone(): - """统一的手机号登录/注册接口""" - data = request.get_json() - phone = data.get('phone') - code = data.get('code') - username = data.get('username') # 可选,新用户可以提供 - password = data.get('password') # 可选,新用户可以提供 - - if not all([phone, code]): - return jsonify({ - 'code': 400, - 'message': '手机号和验证码不能为空' - }), 400 - - # 验证验证码 - stored_code = verification_codes.get(phone) - - if not stored_code or stored_code['expires'] < time.time(): - return jsonify({ - 'code': 400, - 'message': '验证码已过期' - }), 400 - - if stored_code['code'] != code: - return jsonify({ - 'code': 400, - 'message': '验证码错误' - }), 400 - - try: - # 查找用户 - user = User.query.filter_by(phone=phone).first() - is_new_user = False - - # 如果用户不存在,自动注册 - if not user: - is_new_user = True - - # 如果提供了用户名,检查是否已存在 - if username: - if User.query.filter_by(username=username).first(): - return jsonify({ - 'code': 400, - 'message': '用户名已被使用,请换一个' - }), 400 - else: - # 自动生成用户名 - base_username = f"user_{phone[-4:]}" - username = base_username - counter = 1 - while User.query.filter_by(username=username).first(): - username = f"{base_username}_{counter}" - counter += 1 - - # 创建新用户 - user = User(username=username, phone=phone) - user.email = f"{username}@valuefrontier.temp" - - # 如果提供了密码就使用,否则生成随机密码 - if password: - user.set_password(password) - else: - random_password = generate_token(16) - user.set_password(random_password) - - user.phone_confirmed = True - - db.session.add(user) - db.session.commit() - - # 生成token - token = generate_token(32) - - # 存储token映射(30天有效期) - user_tokens[token] = { - 'user_id': user.id, - 'expires': datetime.now() + timedelta(days=30) - } - - # 清除验证码 - del verification_codes[phone] - - # 设置session(保持与原有逻辑兼容) - session.permanent = True - session['user_id'] = user.id - session['username'] = user.username - session['logged_in'] = True - - # 返回响应 - response_data = { - 'code': 0, - 'message': '欢迎回来' if not is_new_user else '注册成功,欢迎加入', - 'token': token, - 'is_new_user': is_new_user, # 告诉前端是否是新用户 - 'user': { - 'id': user.id, - 'username': user.username, - 'phone': user.phone, - 'need_complete_profile': is_new_user # 提示新用户完善资料 - } - } - - return jsonify(response_data), 200 - - except Exception as e: - db.session.rollback() - print(f"Login/Register error: {e}") - return jsonify({ - 'code': 500, - 'message': '操作失败,请重试' - }), 500 - - -@app.route('/api/auth/verify-token', methods=['POST']) -def verify_token(): - """验证token有效性(可选接口)""" - data = request.get_json() - token = data.get('token') - - if not token: - return jsonify({'valid': False, 'message': 'Token不能为空'}), 400 - - token_data = user_tokens.get(token) - - if not token_data: - return jsonify({'valid': False, 'message': 'Token无效','code':401}), 401 - - # 检查是否过期 - if token_data['expires'] < datetime.now(): - del user_tokens[token] - return jsonify({'valid': False, 'message': 'Token已过期'}), 401 - - # 获取用户信息 - user = User.query.get(token_data['user_id']) - if not user: - return jsonify({'valid': False, 'message': '用户不存在'}), 404 - - return jsonify({ - 'valid': True, - 'user': { - 'id': user.id, - 'username': user.username, - 'phone': user.phone - } - }), 200 - - - - -def generate_jwt_token(user_id): - """ - 生成JWT Token - 与原系统保持一致 - - Args: - user_id: 用户ID - - Returns: - str: JWT token字符串 - """ - payload = { - 'user_id': user_id, - 'exp': datetime.utcnow() + timedelta(hours=JWT_EXPIRATION_HOURS), - 'iat': datetime.utcnow() - } - - token = jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) - return token - - - - -@app.route('/api/auth/login/wechat', methods=['POST']) -def api_login_wechat(): - try: - # 1. 获取请求数据 - data = request.get_json() - code = data.get('code') if data else None - - if not code: - return jsonify({ - 'code': 400, - 'message': '缺少必要的参数', - 'data': None - }), 400 - - # 2. 验证code格式 - if not isinstance(code, str) or len(code) < 10: - return jsonify({ - 'code': 400, - 'message': 'code格式无效', - 'data': None - }), 400 - - logger.info(f"开始处理微信登录,code长度: {len(code)}") - - # 3. 调用微信接口获取用户信息 - wx_api_url = 'https://api.weixin.qq.com/sns/jscode2session' - params = { - 'appid': WECHAT_APP_ID, - 'secret': WECHAT_APP_SECRET, - 'js_code': code, - 'grant_type': 'authorization_code' - } - - try: - response = requests.get(wx_api_url, params=params, timeout=10) - response.raise_for_status() - wx_data = response.json() - - # 检查微信API返回的错误 - if 'errcode' in wx_data and wx_data['errcode'] != 0: - error_messages = { - -1: '系统繁忙,请稍后重试', - 40029: 'code无效或已过期', - 45011: '频率限制,请稍后再试', - 40013: 'AppID错误', - 40125: 'AppSecret错误', - 40226: '高风险用户,登录被拦截' - } - - error_msg = error_messages.get( - wx_data['errcode'], - f"微信接口错误: {wx_data.get('errmsg', '未知错误')}" - ) - - logger.error(f"WeChat API error {wx_data['errcode']}: {error_msg}") - - return jsonify({ - 'code': 400, - 'message': error_msg, - 'data': None - }), 400 - - # 验证必需字段 - if 'openid' not in wx_data or 'session_key' not in wx_data: - logger.error("响应缺少必需字段") - return jsonify({ - 'code': 500, - 'message': '微信响应格式错误', - 'data': None - }), 500 - - openid = wx_data['openid'] - session_key = wx_data['session_key'] - unionid = wx_data.get('unionid') # 可能为None - - logger.info(f"成功获取微信用户信息 - OpenID: {openid[:8]}...") - if unionid: - logger.info(f"获取到UnionID: {unionid[:8]}...") - - except requests.exceptions.Timeout: - logger.error("请求微信API超时") - return jsonify({ - 'code': 500, - 'message': '请求超时,请重试', - 'data': None - }), 500 - except requests.exceptions.RequestException as e: - logger.error(f"网络请求失败: {str(e)}") - return jsonify({ - 'code': 500, - 'message': '网络错误', - 'data': None - }), 500 - - # 4. 查找或创建用户 - 核心逻辑 - user = None - is_new_user = False - - - logger.info(f"开始查找用户 - UnionID: {unionid}, OpenID: {openid[:8]}...") - - if unionid: - # 情况1: 有unionid,优先通过unionid查找 - user = User.query.filter_by(wechat_union_id=unionid).first() - - if user: - logger.info(f"通过UnionID找到用户: {user.username}") - # 更新openid(可能用户从不同小程序登录) - if user.wechat_open_id != openid: - user.wechat_open_id = openid - logger.info(f"更新用户OpenID: {openid[:8]}...") - else: - # unionid没找到,再尝试用openid查找(处理历史数据) - user = User.query.filter_by(wechat_open_id=openid).first() - if user: - logger.info(f"通过OpenID找到用户: {user.username}") - # 补充unionid - user.wechat_union_id = unionid - logger.info(f"为用户补充UnionID: {unionid[:8]}...") - else: - # 情况2: 没有unionid,只能通过openid查找 - logger.warning("未获取到UnionID(小程序可能未绑定开放平台)") - user = User.query.filter_by(wechat_open_id=openid).first() - if user: - logger.info(f"通过OpenID找到用户: {user.username}") - - # 5. 创建新用户 - if not user: - is_new_user = True - - # 生成唯一用户名 - timestamp = int(time.time()) - username = f"wx_{timestamp}_{openid[-6:]}" - - # 确保用户名唯一 - counter = 0 - base_username = username - while User.query.filter_by(username=username).first(): - counter += 1 - username = f"{base_username}_{counter}" - - # 创建用户对象(使用你的User模型) - user = User( - username=username, - email=f"{username}@wechat.local", # 占位邮箱 - password="wechat_login_no_password" # 微信登录不需要密码 - ) - - # 设置微信相关字段 - user.wechat_open_id = openid - user.wechat_union_id = unionid - user.status = 'active' - user.email_confirmed = False - - # 设置默认值 - user.nickname = f"微信用户{openid[-4:]}" - user.bio = "" # 空的个人简介 - user.avatar_url = None # 稍后会处理 - user.is_creator = False - user.is_verified = False - user.user_level = 1 - user.reputation_score = 0 - user.contribution_point = 0 - user.post_count = 0 - user.comment_count = 0 - user.follower_count = 0 - user.following_count = 0 - - # 设置默认偏好 - user.email_notifications = True - user.privacy_level = 'public' - user.theme_preference = 'light' - - db.session.add(user) - logger.info(f"创建新用户: {username}") - else: - # 更新最后登录时间 - user.update_last_seen() - logger.info(f"用户登录: {user.username}") - - # 6. 提交数据库更改 - try: - db.session.commit() - except Exception as e: - db.session.rollback() - logger.error(f"保存用户信息失败: {str(e)}") - return jsonify({ - 'code': 500, - 'message': '保存用户信息失败', - 'data': None - }), 500 - - # 7. 生成JWT token(使用原系统的生成方法) - token = generate_token(32) # 使用相同的随机字符串生成器 - - # 存储token映射(与手机登录保持一致) - user_tokens[token] = { - 'user_id': user.id, - 'expires': datetime.now() + timedelta(days=30) # 30天有效期 - } - - # 设置session(可选,保持与手机登录一致) - session.permanent = True - session['user_id'] = user.id - session['username'] = user.username - session['logged_in'] = True - - # 9. 构造返回数据 - 完全匹配要求的格式 - response_data = { - 'code': 200, - 'data': { - 'token': token, # 现在这个token能被token_required识别了 - 'user': { - 'avatar_url': get_full_avatar_url(user.avatar_url), - 'bio': user.bio or "", - 'email': user.email, - 'id': user.id, - 'is_creator': user.is_creator, - 'is_verified': user.is_verified, - 'nickname': user.nickname or user.username, - 'reputation_score': user.reputation_score, - 'user_level': user.user_level, - 'username': user.username - } - }, - 'message': '登录成功' - } - - # 10. 记录日志 - logger.info( - f"微信登录成功 - 用户ID: {user.id}, " - f"用户名: {user.username}, " - f"新用户: {is_new_user}, " - f"有UnionID: {unionid is not None}" - ) - - return jsonify(response_data), 200 - - except Exception as e: - # 捕获所有未处理的异常 - logger.error(f"微信登录处理异常: {str(e)}", exc_info=True) - db.session.rollback() - - return jsonify({ - 'code': 500, - 'message': '服务器内部错误', - 'data': None - }), 500 - - -@app.route('/api/auth/login/email', methods=['POST']) -def api_login_email(): - """邮箱登录接口""" - try: - data = request.get_json() - email = data.get('email') - password = data.get('password') - - if not email or not password: - return jsonify({ - 'code': 400, - 'message': '邮箱和密码不能为空', - 'data': None - }), 400 - - user = User.query.filter_by(email=email).first() - if not user or not user.check_password(password): - return jsonify({ - 'code': 400, - 'message': '邮箱或密码错误', - 'data': None - }), 400 - token = generate_jwt_token(user.id) - login_user(user) - user.update_last_seen() - db.session.commit() - - return jsonify({ - 'code': 200, - 'message': '登录成功', - 'data': { - 'token': token, - 'user_id': user.id, - 'username': user.username, - 'email': user.email, - 'is_verified': user.is_verified, - 'user_level': user.user_level - } - }) - except Exception as e: - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - - -# 5. 事件详情-相关标的接口 -@app.route('/api/event//related-stocks-detail', methods=['GET']) -def api_event_related_stocks(event_id): - """事件相关标的详情接口 - 仅限 Pro/Max 会员""" - try: - import time as time_module - from datetime import datetime, timedelta, time as dt_time - from sqlalchemy import text - - # 性能计时开始 - start_time = time_module.time() - - event = Event.query.get_or_404(event_id) - related_stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all() - - # 获取ClickHouse客户端用于分时数据查询 - client = get_clickhouse_client() - - # 获取事件时间(如果事件有开始时间,使用开始时间;否则使用创建时间) - event_time = event.start_time if event.start_time else event.created_at - current_time = datetime.now() - - # 定义交易日和时间范围计算函数(与 app.py 中的逻辑完全一致) - def get_trading_day_and_times(event_datetime): - event_date = event_datetime.date() - event_time_only = event_datetime.time() - - # Trading hours - market_open = dt_time(9, 30) - market_close = dt_time(15, 0) - - with engine.connect() as conn: - # First check if the event date itself is a trading day - is_trading_day = conn.execute(text(""" - SELECT 1 - FROM trading_days - WHERE EXCHANGE_DATE = :date - """), {"date": event_date}).fetchone() is not None - - if is_trading_day: - # If it's a trading day, determine time period based on event time - if event_time_only < market_open: - # Before market opens - use full trading day - return event_date, market_open, market_close - elif event_time_only > market_close: - # After market closes - get next trading day - next_trading_day = conn.execute(text(""" - SELECT EXCHANGE_DATE - FROM trading_days - WHERE EXCHANGE_DATE > :date - ORDER BY EXCHANGE_DATE LIMIT 1 - """), {"date": event_date}).fetchone() - # Convert to date object if we found a next trading day - return (next_trading_day[0].date() if next_trading_day else None, - market_open, market_close) - else: - # During trading hours - return event_date, event_time_only, market_close - else: - # If not a trading day, get next trading day - next_trading_day = conn.execute(text(""" - SELECT EXCHANGE_DATE - FROM trading_days - WHERE EXCHANGE_DATE > :date - ORDER BY EXCHANGE_DATE LIMIT 1 - """), {"date": event_date}).fetchone() - # Convert to date object if we found a next trading day - return (next_trading_day[0].date() if next_trading_day else None, - market_open, market_close) - - trading_day, start_time, end_time = get_trading_day_and_times(event_time) - - if not trading_day: - # 如果没有交易日,返回空数据 - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'event_id': event_id, - 'event_title': event.title, - 'event_desc': event.description, - 'event_type': event.event_type, - 'event_importance': event.importance, - 'event_status': event.status, - 'event_created_at': event.created_at.strftime("%Y-%m-%d %H:%M:%S"), - 'event_start_time': event.start_time.isoformat() if event.start_time else None, - 'event_end_time': event.end_time.isoformat() if event.end_time else None, - 'keywords': event.keywords, - 'view_count': event.view_count, - 'post_count': event.post_count, - 'follower_count': event.follower_count, - 'related_stocks': [], - 'total_count': 0 - } - }) - - # For historical dates, ensure we're using actual data - start_datetime = datetime.combine(trading_day, start_time) - end_datetime = datetime.combine(trading_day, end_time) - - # If the trading day is in the future relative to current time, return only names without data - if trading_day > current_time.date(): - start_datetime = datetime.combine(trading_day, start_time) - end_datetime = datetime.combine(trading_day, end_time) - - print(f"事件时间: {event_time}, 交易日: {trading_day}, 时间范围: {start_datetime} - {end_datetime}") - - # ======================================== - # 辅助函数:安全的浮点数计算 - # ======================================== - def safe_float(value): - """安全地转换为 float,如果是 None 或无效值则返回 None""" - try: - if value is None: - return None - return float(value) - except (ValueError, TypeError): - return None - - def safe_calculate_change(latest, first): - """安全地计算涨跌额和涨跌幅""" - try: - if latest is None or first is None or first == 0: - return None, None - change_amount = latest - first - change_pct = (change_amount / first) * 100 - return change_amount, change_pct - except (TypeError, ZeroDivisionError): - return None, None - - # ======================================== - # 性能优化:批量查询所有股票的价格数据 - # ======================================== - stock_codes = [stock.stock_code for stock in related_stocks] - print(f"批量查询 {len(stock_codes)} 只股票的价格数据...") - - # 批量查询:使用 IN 语句一次性获取所有股票的 first_price 和 last_price - stock_price_map = {} - try: - batch_query = """ - SELECT - code, - first_value(close) OVER (PARTITION BY code ORDER BY timestamp ASC) as first_price, - last_value(close) OVER (PARTITION BY code ORDER BY timestamp ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as last_price - FROM stock_minute - WHERE code IN %(codes)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - GROUP BY code, close, timestamp - ORDER BY code, timestamp DESC - """ - - # 更优化的查询:使用子查询分别获取 first 和 last - optimized_batch_query = """ - WITH first_prices AS ( - SELECT - code, - close as first_price - FROM ( - SELECT - code, - close, - ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp ASC) as rn - FROM stock_minute - WHERE code IN %(codes)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ) t - WHERE rn = 1 - ), - last_prices AS ( - SELECT - code, - close as last_price, - open as open_price, - high as high_price, - low as low_price, - volume, - amt as amount - FROM ( - SELECT - code, - close, - open, - high, - low, - volume, - amt, - ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn - FROM stock_minute - WHERE code IN %(codes)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ) t - WHERE rn = 1 - ) - SELECT - f.code, - f.first_price, - l.last_price, - (l.last_price - f.first_price) / f.first_price * 100 as change_pct, - l.last_price - f.first_price as change_amount, - l.open_price, - l.high_price, - l.low_price, - l.volume, - l.amount - FROM first_prices f - INNER JOIN last_prices l ON f.code = l.code - """ - - batch_data = client.execute(optimized_batch_query, { - 'codes': tuple(stock_codes), - 'start': start_datetime, - 'end': end_datetime - }) - - # 将查询结果存入字典,key 为股票代码 - for row in batch_data: - # 使用辅助函数安全转换 - first_price_val = safe_float(row[1]) - latest_price_val = safe_float(row[2]) - - # 使用 SQL 计算的结果,如果为 None 则重新计算 - change_pct_val = safe_float(row[3]) - change_amount_val = safe_float(row[4]) - - # 如果 SQL 计算失败,使用辅助函数重新计算 - if change_pct_val is None or change_amount_val is None: - calc_amount, calc_pct = safe_calculate_change(latest_price_val, first_price_val) - if change_amount_val is None: - change_amount_val = calc_amount - if change_pct_val is None: - change_pct_val = calc_pct - - stock_price_map[row[0]] = { - 'first_price': first_price_val, - 'latest_price': latest_price_val, - 'change_pct': change_pct_val, - 'change_amount': change_amount_val, - 'open_price': safe_float(row[5]), - 'high_price': safe_float(row[6]), - 'low_price': safe_float(row[7]), - 'volume': int(row[8]) if row[8] is not None else None, - 'amount': safe_float(row[9]), - } - - print(f"批量查询成功,获取到 {len(stock_price_map)} 只股票的数据") - - except Exception as e: - import traceback - print(f"批量查询失败,将回退到逐个查询: {e}") - print(f"详细错误信息: {traceback.format_exc()}") - stock_price_map = {} - - def get_minute_chart_data(stock_code): - """获取股票分时图数据""" - try: - # 获取当前日期或最新交易日的分时数据 - from datetime import datetime, timedelta, time as dt_time - today = datetime.now().date() - - # 获取最新交易日的分时数据 - data = client.execute(""" - SELECT - timestamp, - open, - high, - low, - close, - volume, - amt - FROM stock_minute - WHERE code = %(code)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ORDER BY timestamp - """, { - 'code': stock_code, - 'start': datetime.combine(today, dt_time(9, 30)), - 'end': datetime.combine(today, dt_time(15, 0)) - }) - - # 如果今天没有数据,获取最近的交易日数据 - if not data: - # 获取最近的交易日数据 - recent_data = client.execute(""" - SELECT - timestamp, - open, - high, - low, - close, - volume, - amt - FROM stock_minute - WHERE code = %(code)s - AND timestamp >= ( - SELECT MAX(timestamp) - INTERVAL 1 DAY - FROM stock_minute - WHERE code = %(code)s - ) - ORDER BY timestamp - """, { - 'code': stock_code - }) - data = recent_data - - # 格式化数据 - minute_data = [] - for row in data: - minute_data.append({ - 'time': row[0].strftime('%H:%M'), - 'open': float(row[1]) if row[1] else None, - 'high': float(row[2]) if row[2] else None, - 'low': float(row[3]) if row[3] else None, - 'close': float(row[4]) if row[4] else None, - 'volume': float(row[5]) if row[5] else None, - 'amount': float(row[6]) if row[6] else None - }) - - return minute_data - - except Exception as e: - print(f"Error fetching minute data for {stock_code}: {e}") - return [] - - stocks_data = [] - for stock in related_stocks: - # 获取股票基本信息 - stock_info = StockBasicInfo.query.filter_by(SECCODE=stock.stock_code).first() - if not stock_info: - base_code = stock.stock_code.split('.')[0] - stock_info = StockBasicInfo.query.filter_by(SECCODE=base_code).first() - - # 初始化变量 - latest_price = None - first_price = None - change_pct = None - change_amount = None - open_price = None - high_price = None - low_price = None - volume = None - amount = None - trade_date = trading_day - - # 优先使用批量查询的结果 - if stock.stock_code in stock_price_map: - print(f"使用批量查询结果: {stock.stock_code}") - price_data = stock_price_map[stock.stock_code] - latest_price = price_data['latest_price'] - first_price = price_data['first_price'] - change_pct = price_data['change_pct'] - change_amount = price_data['change_amount'] - open_price = price_data['open_price'] - high_price = price_data['high_price'] - low_price = price_data['low_price'] - volume = price_data['volume'] - amount = price_data['amount'] - else: - # 批量查询失败或该股票无数据,回退到单独查询 - print(f"批量查询无结果,单独查询: {stock.stock_code}") - try: - # 使用与 get_stock_quotes 完全相同的 SQL 查询 - # 获取事件时间点的第一个价格 (first_price) 和当前时间的最后一个价格 (last_price) - data = client.execute(""" - WITH first_price AS ( - SELECT close - FROM stock_minute - WHERE code = %(code)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ORDER BY timestamp - LIMIT 1 - ), - last_price AS ( - SELECT close - FROM stock_minute - WHERE code = %(code)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ORDER BY timestamp DESC - LIMIT 1 - ) - SELECT - last_price.close as last_price, - (last_price.close - first_price.close) / first_price.close * 100 as change, - first_price.close as first_price - FROM last_price - CROSS JOIN first_price - WHERE EXISTS (SELECT 1 FROM first_price) - AND EXISTS (SELECT 1 FROM last_price) - """, { - 'code': stock.stock_code, - 'start': start_datetime, - 'end': end_datetime - }) - - print(f"股票 {stock.stock_code} 查询结果: {data}") - - if data and data[0] and data[0][0] is not None: - latest_price = safe_float(data[0][0]) - change_pct = safe_float(data[0][1]) - first_price = safe_float(data[0][2]) if len(data[0]) > 2 else None - - # 使用辅助函数计算涨跌额和涨跌幅 - if change_pct is None or change_amount is None: - calc_amount, calc_pct = safe_calculate_change(latest_price, first_price) - if change_amount is None: - change_amount = calc_amount - if change_pct is None: - change_pct = calc_pct - - # 获取额外的价格信息(开盘价、最高价、最低价等) - extra_data = client.execute(""" - SELECT - open, high, low, volume, amt - FROM stock_minute - WHERE code = %(code)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ORDER BY timestamp DESC - LIMIT 1 - """, { - 'code': stock.stock_code, - 'start': start_datetime, - 'end': end_datetime - }) - - if extra_data and extra_data[0]: - open_price = safe_float(extra_data[0][0]) - high_price = safe_float(extra_data[0][1]) - low_price = safe_float(extra_data[0][2]) - volume = int(extra_data[0][3]) if extra_data[0][3] is not None else None - amount = safe_float(extra_data[0][4]) - - except Exception as e: - import traceback - print(f"Error fetching price data for {stock.stock_code}: {e}") - print(f"详细错误: {traceback.format_exc()}") - # 如果 ClickHouse 查询失败,尝试使用 TradeData 作为降级方案 - try: - latest_trade = None - search_codes = [stock.stock_code, stock.stock_code.split('.')[0]] - - for code in search_codes: - latest_trade = TradeData.query.filter_by(SECCODE=code) \ - .order_by(TradeData.TRADEDATE.desc()).first() - if latest_trade: - break - - if latest_trade: - latest_price = safe_float(latest_trade.F007N) - open_price = safe_float(latest_trade.F003N) - high_price = safe_float(latest_trade.F005N) - low_price = safe_float(latest_trade.F006N) - first_price = safe_float(latest_trade.F002N) - volume = safe_float(latest_trade.F004N) - amount = safe_float(latest_trade.F011N) - trade_date = latest_trade.TRADEDATE - - # 优先使用数据库字段 - change_pct = safe_float(latest_trade.F010N) - change_amount = safe_float(latest_trade.F009N) - - # 如果数据库字段为空,使用辅助函数计算 - if change_pct is None or change_amount is None: - calc_amount, calc_pct = safe_calculate_change(latest_price, first_price) - if change_amount is None: - change_amount = calc_amount - if change_pct is None: - change_pct = calc_pct - except Exception as fallback_error: - print(f"Fallback query also failed for {stock.stock_code}: {fallback_error}") - - # 获取分时图数据 - minute_chart_data = get_minute_chart_data(stock.stock_code) - - stock_data = { - 'id': stock.id, - 'stock_code': stock.stock_code, - 'stock_name': stock.stock_name, - 'sector': stock.sector, - 'relation_desc': stock.relation_desc, - 'correlation': stock.correlation, - 'momentum': stock.momentum, - 'listing_date': stock_info.F006D.isoformat() if stock_info and stock_info.F006D else None, - 'market': stock_info.F005V if stock_info else None, - - # 交易数据 - 'trade_data': { - 'latest_price': latest_price, - 'first_price': first_price, # 事件发生时的价格 - 'open_price': open_price, - 'high_price': high_price, - 'low_price': low_price, - 'change_amount': round(change_amount, 2) if change_amount is not None else None, - 'change_pct': round(change_pct, 2) if change_pct is not None else None, - 'volume': volume, - 'amount': amount, - 'trade_date': trade_date.isoformat() if trade_date else None, - 'event_start_time': start_datetime.isoformat() if start_datetime else None, # 事件开始时间 - 'event_end_time': end_datetime.isoformat() if end_datetime else None, # 查询结束时间 - } if latest_price is not None else None, - - # 分时图数据 - 'minute_chart_data': minute_chart_data, - - # 图表URL - 'charts': { - 'minute_chart_url': f"/api/stock/{stock.stock_code}/minute-chart", - 'daily_chart_url': f"/api/stock/{stock.stock_code}/kline", - } - } - - stocks_data.append(stock_data) - - # 性能计时结束 - end_time = time_module.time() - elapsed_time = end_time - start_time - print(f"⏱️ 接口总耗时: {elapsed_time:.2f}秒,处理了 {len(stocks_data)} 只股票") - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'event_id': event_id, - 'event_title': event.title, - 'event_desc': event.description, - 'event_type': event.event_type, - 'event_importance': event.importance, - 'event_status': event.status, - 'event_created_at': event.created_at.strftime("%Y-%m-%d %H:%M:%S"), - 'event_start_time': event.start_time.isoformat() if event.start_time else None, - 'event_end_time': event.end_time.isoformat() if event.end_time else None, - 'keywords': event.keywords, - 'view_count': event.view_count, - 'post_count': event.post_count, - 'follower_count': event.follower_count, - 'related_stocks': stocks_data, - 'total_count': len(stocks_data) - } - }) - except Exception as e: - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - - -@app.route('/api/stock//minute-chart', methods=['GET']) - - -def get_minute_chart_data(stock_code): - """获取股票分时图数据 - 仅限 Pro/Max 会员""" - client = get_clickhouse_client() - try: - # 获取当前日期或最新交易日的分时数据 - from datetime import datetime, timedelta, time as dt_time - today = datetime.now().date() - - # 获取最新交易日的分时数据 - data = client.execute(""" - SELECT - timestamp, - open, - high, - low, - close, - volume, - amt - FROM stock_minute - WHERE code = %(code)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ORDER BY timestamp - """, { - 'code': stock_code, - 'start': datetime.combine(today, dt_time(9, 30)), - 'end': datetime.combine(today, dt_time(15, 0)) - }) - - # 如果今天没有数据,获取最近的交易日数据 - if not data: - # 获取最近的交易日数据 - recent_data = client.execute(""" - SELECT - timestamp, - open, - high, - low, - close, - volume, - amt - FROM stock_minute - WHERE code = %(code)s - AND timestamp >= ( - SELECT MAX(timestamp) - INTERVAL 1 DAY - FROM stock_minute - WHERE code = %(code)s - ) - ORDER BY timestamp - """, { - 'code': stock_code - }) - data = recent_data - - # 格式化数据 - minute_data = [] - for row in data: - minute_data.append({ - 'time': row[0].strftime('%H:%M'), - 'open': float(row[1]) if row[1] else None, - 'high': float(row[2]) if row[2] else None, - 'low': float(row[3]) if row[3] else None, - 'close': float(row[4]) if row[4] else None, - 'volume': float(row[5]) if row[5] else None, - 'amount': float(row[6]) if row[6] else None - }) - - return minute_data - except Exception as e: - print(f"Error getting minute chart data: {e}") - return [] - - -@app.route('/api/event//stock//detail', methods=['GET']) - - -def api_stock_detail(event_id, stock_code): - """个股详情接口 - 仅限 Pro/Max 会员""" - try: - # 验证事件是否存在 - event = Event.query.get_or_404(event_id) - - # 获取查询参数 - include_minute_data = request.args.get('include_minute_data', 'true').lower() == 'true' - include_full_sources = request.args.get('include_full_sources', 'false').lower() == 'true' # 是否包含完整研报来源 - - # 获取股票基本信息 - basic_info = None - base_code = stock_code.split('.')[0] # 去掉后缀 - - # 按优先级查找股票信息 - basic_info = StockBasicInfo.query.filter_by(SECCODE=stock_code).first() - if not basic_info: - basic_info = StockBasicInfo.query.filter( - StockBasicInfo.SECCODE.ilike(f"{stock_code}%") - ).first() - if not basic_info: - basic_info = StockBasicInfo.query.filter( - StockBasicInfo.SECCODE.ilike(f"{base_code}%") - ).first() - - company_info = CompanyInfo.query.filter_by(SECCODE=stock_code).first() - if not company_info: - company_info = CompanyInfo.query.filter_by(SECCODE=base_code).first() - - if not basic_info: - return jsonify({ - 'code': 404, - 'stock_code': stock_code, - 'message': '股票不存在', - 'data': None - }), 404 - - # 获取最新交易数据 - latest_trade = TradeData.query.filter_by(SECCODE=stock_code) \ - .order_by(TradeData.TRADEDATE.desc()).first() - if not latest_trade: - latest_trade = TradeData.query.filter_by(SECCODE=base_code) \ - .order_by(TradeData.TRADEDATE.desc()).first() - - # 获取分时数据 - minute_chart_data = [] - if include_minute_data: - minute_chart_data = get_minute_chart_data(stock_code) - - # 获取该事件的相关描述 - related_stock = RelatedStock.query.filter_by( - event_id=event_id - ).filter( - db.or_( - RelatedStock.stock_code == stock_code, - RelatedStock.stock_code == base_code, - RelatedStock.stock_code.like(f"{base_code}.%") - ) - ).first() - - related_desc = None - if related_stock: - # 处理研报来源数据 - retrieved_sources_data = None - sources_summary = None - - if related_stock.retrieved_sources: - try: - # 解析研报来源 - import json - sources = related_stock.retrieved_sources if isinstance(related_stock.retrieved_sources, - list) else json.loads( - related_stock.retrieved_sources) - - # 统计信息 - sources_summary = { - 'total_count': len(sources), - 'has_sources': True, - 'match_scores': {} - } - - # 统计匹配分数分布 - for source in sources: - score = source.get('match_score', '未知') - sources_summary['match_scores'][score] = sources_summary['match_scores'].get(score, 0) + 1 - - # 根据参数决定返回完整数据还是摘要 - if include_full_sources: - # 返回完整的研报来源 - retrieved_sources_data = sources - else: - # 只返回前5条高质量来源作为预览 - # 优先返回匹配度高的 - high_quality_sources = [s for s in sources if s.get('match_score') == '好'][:3] - medium_quality_sources = [s for s in sources if s.get('match_score') == '中'][:2] - - preview_sources = high_quality_sources + medium_quality_sources - if not preview_sources: # 如果没有高中匹配度的,返回前5条 - preview_sources = sources[:5] - - retrieved_sources_data = [] - for source in preview_sources: - retrieved_sources_data.append({ - 'report_title': source.get('report_title', ''), - 'author': source.get('author', ''), - 'sentences': source.get('sentences', '')[:200] + '...' if len( - source.get('sentences', '')) > 200 else source.get('sentences', ''), # 限制长度 - 'match_score': source.get('match_score', ''), - 'declare_date': source.get('declare_date', '') - }) - - except Exception as e: - print(f"Error processing retrieved_sources for stock {stock_code}: {e}") - sources_summary = {'has_sources': False, 'error': str(e)} - else: - sources_summary = {'has_sources': False, 'total_count': 0} - - related_desc = { - 'event_id': related_stock.event_id, - 'relation_desc': related_stock.relation_desc, - 'sector': related_stock.sector, - 'correlation': float(related_stock.correlation) if related_stock.correlation else None, - 'momentum': related_stock.momentum, - - # 新增研报来源相关字段 - 'retrieved_sources': retrieved_sources_data, - 'sources_summary': sources_summary, - 'retrieved_update_time': related_stock.retrieved_update_time.isoformat() if related_stock.retrieved_update_time else None, - - # 添加获取完整来源的URL - 'sources_detail_url': f"/api/event/{event_id}/stock/{stock_code}/sources" if sources_summary.get( - 'has_sources') else None - } - - response_data = { - 'code': 200, - 'message': 'success', - 'data': { - 'event_info': { - 'event_id': event.id, - 'event_title': event.title, - 'event_description': event.description - }, - 'basic_info': { - 'stock_code': basic_info.SECCODE, - 'stock_name': basic_info.SECNAME, - 'org_name': basic_info.ORGNAME, - 'pinyin': basic_info.F001V, - 'category': basic_info.F003V, - 'market': basic_info.F005V, - 'listing_date': basic_info.F006D.isoformat() if basic_info.F006D else None, - 'status': basic_info.F011V - }, - 'company_info': { - 'english_name': company_info.F001V if company_info else None, - 'legal_representative': company_info.F003V if company_info else None, - 'main_business': company_info.F015V if company_info else None, - 'business_scope': company_info.F016V if company_info else None, - 'company_intro': company_info.F017V if company_info else None, - 'csrc_industry_l1': company_info.F030V if company_info else None, - 'csrc_industry_l2': company_info.F032V if company_info else None - }, - 'latest_trade': { - 'trade_date': latest_trade.TRADEDATE.isoformat() if latest_trade else None, - 'close_price': float(latest_trade.F007N) if latest_trade and latest_trade.F007N else None, - 'change': float(latest_trade.F009N) if latest_trade and latest_trade.F009N else None, - 'change_pct': float(latest_trade.F010N) if latest_trade and latest_trade.F010N else None, - 'volume': float(latest_trade.F004N) if latest_trade and latest_trade.F004N else None, - 'amount': float(latest_trade.F011N) if latest_trade and latest_trade.F011N else None - } if latest_trade else None, - 'minute_chart_data': minute_chart_data, - 'related_desc': related_desc - } - } - - response = jsonify(response_data) - response.headers['Content-Type'] = 'application/json; charset=utf-8' - return response - - except Exception as e: - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - -def get_stock_minute_chart_data(stock_code): - """获取股票分时图数据""" - try: - client = get_clickhouse_client() - - # 获取当前日期(使用最新的交易日) - from datetime import datetime, timedelta, time as dt_time - import csv - - def get_trading_days(): - trading_days = set() - with open('tdays.csv', 'r') as f: - reader = csv.DictReader(f) - for row in reader: - trading_days.add(datetime.strptime(row['DateTime'], '%Y/%m/%d').date()) - return trading_days - - trading_days = get_trading_days() - - def find_latest_trading_day(current_date): - """找到最新的交易日""" - while current_date >= min(trading_days): - if current_date in trading_days: - return current_date - current_date -= timedelta(days=1) - return None - - target_date = find_latest_trading_day(datetime.now().date()) - - if not target_date: - return [] - - # 获取分时数据 - data = client.execute(""" - SELECT - timestamp, - open, - high, - low, - close, - volume, - amt - FROM stock_minute - WHERE code = %(code)s - AND timestamp BETWEEN %(start)s AND %(end)s - ORDER BY timestamp - """, { - 'code': stock_code, - 'start': datetime.combine(target_date, dt_time(9, 30)), - 'end': datetime.combine(target_date, dt_time(15, 0)) - }) - - minute_data = [] - for row in data: - minute_data.append({ - 'time': row[0].strftime('%H:%M'), - 'open': float(row[1]), - 'high': float(row[2]), - 'low': float(row[3]), - 'close': float(row[4]), - 'volume': float(row[5]), - 'amount': float(row[6]) - }) - - return minute_data - - except Exception as e: - print(f"Error getting minute chart data: {e}") - return [] - - -# 7. 事件详情-相关概念接口 -@app.route('/api/event//related-concepts', methods=['GET']) - - -def api_event_related_concepts(event_id): - """事件相关概念接口""" - try: - event = Event.query.get_or_404(event_id) - related_concepts = event.related_concepts.all() - base_url = request.host_url - - concepts_data = [] - for concept in related_concepts: - image_paths = concept.image_paths_list - image_urls = [base_url + 'data/concepts/' + p for p in image_paths] - concepts_data.append({ - 'id': concept.id, - 'concept_code': concept.concept_code, - 'concept': concept.concept, - 'reason': concept.reason, - 'image_paths': image_paths, - 'image_urls': image_urls, - 'first_image': image_urls[0] if image_urls else None - }) - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'event_id': event_id, - 'event_title': event.title, - 'related_concepts': concepts_data, - 'total_count': len(concepts_data) - } - }) - except Exception as e: - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - - -# 8. 事件详情-历史事件接口 -@app.route('/api/event//historical-events', methods=['GET']) - - -def api_event_historical_events(event_id): - """事件历史事件接口""" - try: - event = Event.query.get_or_404(event_id) - historical_events = event.historical_events.order_by( - HistoricalEvent.importance.desc(), - HistoricalEvent.event_date.desc() - ).all() - - events_data = [] - for hist_event in historical_events: - # 获取相关股票信息 - related_stocks = [] - valid_changes = [] # 用于计算涨跌幅 - - for stock in hist_event.stocks.all(): - base_stock_code = stock.stock_code.split('.')[0] - # 获取股票当日交易数据 - trade_data = TradeData.query.filter( - TradeData.SECCODE.startswith(base_stock_code) - ).order_by(TradeData.TRADEDATE.desc()).first() - - if trade_data and trade_data.F010N is not None: - daily_change = float(trade_data.F010N) - valid_changes.append(daily_change) - else: - daily_change = None - - stock_data = { - 'stock_code': stock.stock_code, - 'stock_name': stock.stock_name, - 'relation_desc': stock.relation_desc, - 'correlation': stock.correlation, - 'sector': stock.sector, - 'daily_change': daily_change, - 'has_trade_data': True if trade_data else False - } - related_stocks.append(stock_data) - - # 计算相关股票的平均涨幅和最大涨幅 - avg_change = None - max_change = None - if valid_changes: - avg_change = sum(valid_changes) / len(valid_changes) - max_change = max(valid_changes) - - events_data.append({ - 'id': hist_event.id, - 'title': hist_event.title, - 'content': hist_event.content, - 'event_date': hist_event.event_date.isoformat() if hist_event.event_date else None, - 'relevance': hist_event.relevance, - 'importance': hist_event.importance, - 'related_stocks': related_stocks, - # 使用计算得到的涨幅数据 - 'related_avg_chg': round(avg_change, 2) if avg_change is not None else None, - 'related_max_chg': round(max_change, 2) if max_change is not None else None - }) - - # 计算当前事件的相关股票涨幅数据 - current_valid_changes = [] - for stock in event.related_stocks: - base_stock_code = stock.stock_code.split('.')[0] - trade_data = TradeData.query.filter( - TradeData.SECCODE.startswith(base_stock_code) - ).order_by(TradeData.TRADEDATE.desc()).first() - - if trade_data and trade_data.F010N is not None: - current_valid_changes.append(float(trade_data.F010N)) - - current_avg_change = None - current_max_change = None - if current_valid_changes: - current_avg_change = sum(current_valid_changes) / len(current_valid_changes) - current_max_change = max(current_valid_changes) - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'event_id': event_id, - 'event_title': event.title, - 'invest_score': event.invest_score, - 'related_avg_chg': round(current_avg_change, 2) if current_avg_change is not None else None, - 'related_max_chg': round(current_max_change, 2) if current_max_change is not None else None, - 'historical_events': events_data, - 'total_count': len(events_data) - } - }) - except Exception as e: - print(f"Error in api_event_historical_events: {str(e)}") - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - - -@app.route('/api/event//comments', methods=['GET']) - - -def get_event_comments(event_id): - """获取事件的所有评论和帖子(嵌套格式) - - Query参数: - - page: 页码(默认1) - - per_page: 每页评论数(默认20) - - sort: 排序方式(time_desc/time_asc/hot, 默认time_desc) - - include_posts: 是否包含帖子信息(默认true) - - reply_limit: 每个评论显示的回复数量限制(默认3) - - 返回: - { - "success": true, - "data": { - "event": { - "id": 事件ID, - "title": "事件标题", - "description": "事件描述" - }, - "posts": [帖子信息], - "total": 总评论数, - "current_page": 当前页码, - "total_pages": 总页数, - "comments": [ - { - "comment_id": 评论ID, - "content": "评论内容", - "created_at": "评论时间", - "post_id": 所属帖子ID, - "post_title": "帖子标题", - "user": { - "user_id": 用户ID, - "nickname": "用户昵称", - "avatar_url": "头像URL" - }, - "reply_count": 总回复数量, - "has_more_replies": 是否有更多回复, - "list": [ # 回复列表 - { - "comment_id": 回复ID, - "content": "回复内容", - "created_at": "回复时间", - "user": { - "user_id": 用户ID, - "nickname": "用户昵称", - "avatar_url": "头像URL" - }, - "reply_to": { # 被回复的用户信息 - "user_id": 用户ID, - "nickname": "用户昵称" - } - } - ] - } - ] - } - } - """ - try: - # 获取查询参数 - page = request.args.get('page', 1, type=int) - per_page = request.args.get('per_page', 20, type=int) - sort = request.args.get('sort', 'time_desc') - include_posts = request.args.get('include_posts', 'true').lower() == 'true' - reply_limit = request.args.get('reply_limit', 3, type=int) # 每个评论显示的回复数限制 - - # 参数验证 - if page < 1: - page = 1 - if per_page < 1 or per_page > 100: - per_page = 20 - if reply_limit < 0 or reply_limit > 50: # 限制回复数量 - reply_limit = 3 - - # 获取事件信息 - event = Event.query.get_or_404(event_id) - - # 获取事件下的所有帖子 - posts_query = Post.query.filter_by(event_id=event_id, status='active') \ - .order_by(Post.is_top.desc(), Post.created_at.desc()) - posts = posts_query.all() - - # 格式化帖子数据 - posts_data = [] - if include_posts: - for post in posts: - posts_data.append({ - 'post_id': post.id, - 'title': post.title, - 'content': post.content, - 'content_type': post.content_type, - 'created_at': post.created_at.strftime('%Y-%m-%d %H:%M:%S'), - 'updated_at': post.updated_at.strftime('%Y-%m-%d %H:%M:%S') if post.updated_at else None, - 'likes_count': post.likes_count, - 'comments_count': post.comments_count, - 'view_count': post.view_count, - 'is_top': post.is_top, - 'user': { - 'user_id': post.user.id, - 'username': post.user.username, - 'nickname': post.user.nickname or post.user.username, - 'avatar_url': get_full_avatar_url(post.user.avatar_url), - 'user_level': post.user.user_level, - 'is_verified': post.user.is_verified, - 'is_creator': post.user.is_creator - } - }) - - # 获取帖子ID列表用于查询评论 - post_ids = [p.id for p in posts] - - if not post_ids: - return jsonify({ - 'success': True, - 'data': { - 'event': { - 'id': event.id, - 'title': event.title, - 'description': event.description, - 'event_type': event.event_type, - 'importance': event.importance, - 'status': event.status - }, - 'posts': posts_data, - 'total': 0, - 'current_page': page, - 'total_pages': 0, - 'comments': [] - } - }) - - # 构建基础查询 - 只查询主评论 - base_query = Comment.query.filter( - Comment.post_id.in_(post_ids), - Comment.parent_id == None, # 只查询主评论 - Comment.status == 'active' - ) - - # 排序处理 - if sort == 'time_asc': - base_query = base_query.order_by(Comment.created_at.asc()) - elif sort == 'hot': - # 这里可以根据你的业务逻辑添加热度排序 - base_query = base_query.order_by(Comment.created_at.desc()) - else: # 默认按时间倒序 - base_query = base_query.order_by(Comment.created_at.desc()) - - # 执行分页查询 - pagination = base_query.paginate(page=page, per_page=per_page, error_out=False) - - # 格式化评论数据(嵌套格式) - comments_data = [] - for comment in pagination.items: - # 获取评论的总回复数量 - reply_count = Comment.query.filter_by( - parent_id=comment.id, - status='active' - ).count() - - # 获取指定数量的回复 - replies_query = Comment.query.filter_by( - parent_id=comment.id, - status='active' - ).order_by(Comment.created_at.asc()) # 回复按时间正序排列 - - if reply_limit > 0: - replies = replies_query.limit(reply_limit).all() - else: - replies = [] - - # 格式化回复数据 - 作为list字段 - replies_list = [] - for reply in replies: - # 获取被回复的用户信息(这里是回复主评论,所以reply_to就是主评论的用户) - reply_to_user = { - 'user_id': comment.user.id, - 'nickname': comment.user.nickname or comment.user.username - } - - replies_list.append({ - 'comment_id': reply.id, - 'content': reply.content, - 'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S'), - 'user': { - 'user_id': reply.user.id, - 'username': reply.user.username, - 'nickname': reply.user.nickname or reply.user.username, - 'avatar_url': get_full_avatar_url(reply.user.avatar_url), - 'user_level': reply.user.user_level, - 'is_verified': reply.user.is_verified - }, - 'reply_to': reply_to_user - }) - - # 获取评论所属的帖子信息 - post = comment.post - - # 构建嵌套格式的评论数据 - comments_data.append({ - 'comment_id': comment.id, - 'content': comment.content, - 'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'), - 'post_id': comment.post_id, - 'post_title': post.title if post else None, - 'post_content_preview': post.content[:100] + '...' if post and len( - post.content) > 100 else post.content if post else None, - 'user': { - 'user_id': comment.user.id, - 'username': comment.user.username, - 'nickname': comment.user.nickname or comment.user.username, - 'avatar_url': get_full_avatar_url(comment.user.avatar_url), - 'user_level': comment.user.user_level, - 'is_verified': comment.user.is_verified - }, - 'reply_count': reply_count, - 'has_more_replies': reply_count > len(replies_list), - 'list': replies_list # 嵌套的回复列表 - }) - - return jsonify({ - 'success': True, - 'data': { - 'event': { - 'id': event.id, - 'title': event.title, - 'description': event.description, - 'event_type': event.event_type, - 'importance': event.importance, - 'status': event.status, - 'created_at': event.created_at.strftime('%Y-%m-%d %H:%M:%S'), - 'hot_score': event.hot_score, - 'view_count': event.view_count, - 'post_count': event.post_count, - 'follower_count': event.follower_count - }, - 'posts': posts_data, - 'posts_count': len(posts_data), - 'total': pagination.total, - 'current_page': pagination.page, - 'total_pages': pagination.pages, - 'comments': comments_data - } - }) - - except Exception as e: - return jsonify({ - 'success': False, - 'message': '获取评论列表失败', - 'error': str(e) - }), 500 - - -@app.route('/api/comment//replies', methods=['GET']) - - -def get_comment_replies(comment_id): - """获取某条评论的所有回复 - - Query参数: - - page: 页码(默认1) - - per_page: 每页回复数(默认20) - - sort: 排序方式(time_desc/time_asc, 默认time_desc) - - 返回格式: - { - "code": 200, - "message": "success", - "data": { - "comment": { # 原评论信息 - "id": 评论ID, - "content": "评论内容", - "created_at": "评论时间", - "user": { - "id": 用户ID, - "nickname": "用户昵称", - "avatar_url": "头像URL" - } - }, - "replies": { # 回复信息 - "total": 总回复数, - "current_page": 当前页码, - "total_pages": 总页数, - "items": [ - { - "id": 回复ID, - "content": "回复内容", - "created_at": "回复时间", - "user": { - "id": 用户ID, - "nickname": "用户昵称", - "avatar_url": "头像URL" - }, - "reply_to": { # 被回复的用户信息 - "id": 用户ID, - "nickname": "用户昵称" - } - } - ] - } - } - } - """ - try: - # 获取查询参数 - page = request.args.get('page', 1, type=int) - per_page = request.args.get('per_page', 20, type=int) - sort = request.args.get('sort', 'time_desc') - - # 参数验证 - if page < 1: - page = 1 - if per_page < 1 or per_page > 100: - per_page = 20 - - # 获取原评论信息 - comment = Comment.query.get_or_404(comment_id) - if comment.status != 'active': - return jsonify({ - 'code': 404, - 'message': '评论不存在或已被删除', - 'data': None - }), 404 - - # 构建原评论数据 - comment_data = { - 'comment_id': comment.id, - 'content': comment.content, - 'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'), - 'user': { - 'user_id': comment.user.id, - 'nickname': comment.user.nickname or comment.user.username, - 'avatar_url': get_full_avatar_url(comment.user.avatar_url), # 修改这里 - } - } - - # 构建回复查询 - replies_query = Comment.query.filter_by( - parent_id=comment_id, - status='active' - ) - - # 排序处理 - if sort == 'time_asc': - replies_query = replies_query.order_by(Comment.created_at.asc()) - else: # 默认按时间倒序 - replies_query = replies_query.order_by(Comment.created_at.desc()) - - # 执行分页查询 - pagination = replies_query.paginate(page=page, per_page=per_page, error_out=False) - - # 格式化回复数据 - replies_data = [] - for reply in pagination.items: - # 获取被回复的用户信息 - reply_to_user = None - if reply.parent_id: - parent_comment = Comment.query.get(reply.parent_id) - if parent_comment: - reply_to_user = { - 'id': parent_comment.user.id, - 'nickname': parent_comment.user.nickname or parent_comment.user.username - } - - replies_data.append({ - 'reply_id': reply.id, - 'content': reply.content, - 'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S'), - 'user': { - 'user_id': reply.user.id, - 'nickname': reply.user.nickname or reply.user.username, - 'avatar_url': get_full_avatar_url(reply.user.avatar_url), # 修改这里 - }, - 'reply_to': reply_to_user - }) - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'comment': comment_data, - 'replies': { - 'total': pagination.total, - 'current_page': pagination.page, - 'total_pages': pagination.pages, - 'items': replies_data - } - } - }) - - except Exception as e: - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - - -# 10. 投资日历-事件接口(增强版) -@app.route('/api/calendar/events', methods=['GET']) - - -def api_calendar_events(): - """投资日历事件接口 - 连接 future_events 表 (修正版)""" - try: - start_date = request.args.get('start') - end_date = request.args.get('end') - importance = request.args.get('importance', 'all') - category = request.args.get('category', 'all') - search_query = request.args.get('q', '').strip() # 新增搜索参数 - page = int(request.args.get('page', 1)) - per_page = int(request.args.get('per_page', 10)) - offset = (page - 1) * per_page - - # 构建基础查询 - 使用 future_events 表 - query = """ - SELECT - data_id, - calendar_time, - type, - star, - title, - former, - forecast, - fact, - related_stocks, - concepts, - inferred_tag - FROM future_events - WHERE 1=1 - """ - - params = {} - - if start_date: - query += " AND calendar_time >= :start_date" - params['start_date'] = datetime.fromisoformat(start_date) - if end_date: - query += " AND calendar_time <= :end_date" - params['end_date'] = datetime.fromisoformat(end_date) - if importance != 'all': - query += " AND star = :importance" - params['importance'] = importance - if category != 'all': - # category参数用于筛选inferred_tag字段(如"大周期"、"大消费"等) - query += " AND inferred_tag = :category" - params['category'] = category - - # 新增搜索条件 - if search_query: - # 使用LIKE进行模糊搜索,同时搜索title和related_stocks字段 - # 对于JSON字段,MySQL会将其作为文本进行搜索 - query += """ AND ( - title LIKE :search_pattern - OR CAST(related_stocks AS CHAR) LIKE :search_pattern - OR CAST(concepts AS CHAR) LIKE :search_pattern - )""" - params['search_pattern'] = f'%{search_query}%' - - query += " ORDER BY calendar_time LIMIT :limit OFFSET :offset" - params['limit'] = per_page - params['offset'] = offset - - result = db.session.execute(text(query), params) - events = result.fetchall() - - # 总数统计(不包含分页) - count_query = """ - SELECT COUNT(*) as count FROM future_events WHERE 1=1 - """ - count_params = params.copy() - count_params.pop('limit', None) - count_params.pop('offset', None) - - if start_date: - count_query += " AND calendar_time >= :start_date" - if end_date: - count_query += " AND calendar_time <= :end_date" - if importance != 'all': - count_query += " AND star = :importance" - if category != 'all': - count_query += " AND inferred_tag = :category" - - # 新增搜索条件到计数查询 - if search_query: - count_query += """ AND ( - title LIKE :search_pattern - OR CAST(related_stocks AS CHAR) LIKE :search_pattern - OR CAST(concepts AS CHAR) LIKE :search_pattern - )""" - - total_count_result = db.session.execute(text(count_query), count_params).fetchone() - total_count = total_count_result.count if total_count_result else 0 - - events_data = [] - for event in events: - # 解析相关股票 - related_stocks_list = [] - related_avg_chg = 0 - related_max_chg = 0 - related_week_chg = 0 - - # 处理相关股票数据 - if event.related_stocks: - try: - import json - import ast - - # 使用与detail接口相同的解析逻辑 - if isinstance(event.related_stocks, str): - try: - stock_data = json.loads(event.related_stocks) - except: - stock_data = ast.literal_eval(event.related_stocks) - else: - stock_data = event.related_stocks - - if stock_data: - daily_changes = [] - week_changes = [] - - # 处理正确的数据格式 [股票代码, 股票名称, 描述, 分数] - for stock_info in stock_data: - if isinstance(stock_info, list) and len(stock_info) >= 2: - stock_code = stock_info[0] # 股票代码 - stock_name = stock_info[1] # 股票名称 - description = stock_info[2] if len(stock_info) > 2 else '' - score = stock_info[3] if len(stock_info) > 3 else 0 - else: - continue - - if stock_code: - # 规范化股票代码,移除后缀 - clean_code = stock_code.replace('.SZ', '').replace('.SH', '').replace('.BJ', '') - - # 使用模糊匹配查询真实的交易数据 - trade_query = """ - SELECT F007N as close_price, F010N as change_pct, TRADEDATE - FROM ea_trade - WHERE SECCODE LIKE :stock_code_pattern - ORDER BY TRADEDATE DESC - LIMIT 7 - """ - trade_result = db.session.execute(text(trade_query), - {'stock_code_pattern': f'{clean_code}%'}) - trade_data = trade_result.fetchall() - - daily_chg = 0 - week_chg = 0 - - if trade_data: - # 日涨跌幅(当日) - daily_chg = float(trade_data[0].change_pct or 0) - - # 周涨跌幅(5个交易日) - if len(trade_data) >= 5: - current_price = float(trade_data[0].close_price or 0) - week_ago_price = float(trade_data[4].close_price or 0) - if week_ago_price > 0: - week_chg = ((current_price - week_ago_price) / week_ago_price) * 100 - - # 收集涨跌幅数据 - daily_changes.append(daily_chg) - week_changes.append(week_chg) - - related_stocks_list.append({ - 'code': stock_code, - 'name': stock_name, - 'description': description, - 'score': score, - 'daily_chg': daily_chg, - 'week_chg': week_chg - }) - - # 计算平均收益率 - if daily_changes: - related_avg_chg = round(sum(daily_changes) / len(daily_changes), 4) - related_max_chg = round(max(daily_changes), 4) - - if week_changes: - related_week_chg = round(sum(week_changes) / len(week_changes), 4) - - except Exception as e: - print(f"Error processing related stocks for event {event.data_id}: {e}") - - # 解析相关概念 - related_concepts = extract_concepts_from_concepts_field(event.concepts) - - # 获取评星等级 - star_rating = event.star - - # 如果有搜索关键词,可以高亮显示匹配的部分(可选功能) - highlight_match = False - if search_query: - # 检查是否在标题中匹配 - if search_query.lower() in (event.title or '').lower(): - highlight_match = 'title' - # 检查是否在股票中匹配 - elif any(search_query.lower() in str(stock).lower() for stock in related_stocks_list): - highlight_match = 'stocks' - # 检查是否在概念中匹配 - elif search_query.lower() in str(related_concepts).lower(): - highlight_match = 'concepts' - - event_dict = { - 'id': event.data_id, - 'title': event.title, - 'description': f"前值: {event.former}, 预测: {event.forecast}, 实际: {event.fact}" if event.former or event.forecast or event.fact else "", - 'start_time': event.calendar_time.isoformat() if event.calendar_time else None, - 'end_time': None, # future_events 表没有结束时间 - 'category': { - 'event_type': event.type, - 'importance': event.star, - 'star_rating': star_rating, - 'inferred_tag': event.inferred_tag # 添加inferred_tag到返回数据 - }, - 'star_rating': star_rating, - 'inferred_tag': event.inferred_tag, # 直接返回行业标签 - 'related_concepts': related_concepts, - 'related_stocks': related_stocks_list, - 'related_avg_chg': round(related_avg_chg, 2), - 'related_max_chg': round(related_max_chg, 2), - 'related_week_chg': round(related_week_chg, 2), - 'former': event.former, - 'forecast': event.forecast, - 'fact': event.fact - } - - # 可选:添加搜索匹配标记 - if search_query and highlight_match: - event_dict['search_match'] = highlight_match - - events_data.append(event_dict) - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'events': events_data, - 'total_count': total_count, - 'page': page, - 'per_page': per_page, - 'total_pages': (total_count + per_page - 1) // per_page, - 'search_query': search_query # 返回搜索关键词 - } - }) - - except Exception as e: - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - - -# 11. 投资日历-数据接口 -@app.route('/api/calendar/data', methods=['GET']) - - -def api_calendar_data(): - """投资日历数据接口""" - try: - start_date = request.args.get('start') - end_date = request.args.get('end') - data_type = request.args.get('type', 'all') - - # 分页参数 - page = int(request.args.get('page', 1)) - page_size = int(request.args.get('page_size', 20)) # 默认每页20条 - - # 验证分页参数 - if page < 1: - page = 1 - if page_size < 1 or page_size > 100: # 限制每页最大100条 - page_size = 20 - - query1 = RelatedData.query - - if start_date: - query1 = query1.filter(RelatedData.created_at >= datetime.fromisoformat(start_date)) - if end_date: - query1 = query1.filter(RelatedData.created_at <= datetime.fromisoformat(end_date)) - if data_type != 'all': - query1 = query1.filter_by(data_type=data_type) - - data_list1 = query1.order_by(RelatedData.created_at.desc()).all() - - query2_sql = """ - SELECT - data_id as id, - title, - type as data_type, - former, - forecast, - fact, - star, - calendar_time as created_at - FROM future_events - WHERE type = 'data' - """ - - # 添加时间筛选条件 - params = {} - if start_date: - query2_sql += " AND calendar_time >= :start_date" - params['start_date'] = start_date - if end_date: - query2_sql += " AND calendar_time <= :end_date" - params['end_date'] = end_date - if data_type != 'all': - query2_sql += " AND type = :data_type" - params['data_type'] = data_type - - query2_sql += " ORDER BY calendar_time DESC" - - result2 = db.session.execute(text(query2_sql), params) - - result_data = [] - - # 处理 RelatedData 的数据 - for data in data_list1: - result_data.append({ - 'id': data.id, - 'title': data.title, - 'data_type': data.data_type, - 'data_content': data.data_content, - 'description': data.description, - 'created_at': data.created_at.isoformat() if data.created_at else None, - 'event_id': data.event_id, - 'source': 'related_data', # 标识数据来源 - 'former': None, - 'forecast': None, - 'fact': None, - 'star': None - }) - - # 处理 future_events 的数据 - for row in result2: - result_data.append({ - 'id': row.id, - 'title': row.title, - 'data_type': row.data_type, - 'data_content': None, - 'description': None, - 'created_at': row.created_at.isoformat() if row.created_at else None, - 'event_id': None, - 'source': 'future_events', # 标识数据来源 - 'former': row.former, - 'forecast': row.forecast, - 'fact': row.fact, - 'star': row.star - }) - - # 按时间排序(最新的在前面) - result_data.sort(key=lambda x: x['created_at'] or '1900-01-01', reverse=True) - - # 计算分页 - total_count = len(result_data) - total_pages = (total_count + page_size - 1) // page_size # 向上取整 - - # 计算起始和结束索引 - start_index = (page - 1) * page_size - end_index = start_index + page_size - - # 获取当前页数据 - current_page_data = result_data[start_index:end_index] - - # 分别统计两个数据源的数量(用于原有逻辑) - related_data_count = len(data_list1) - future_events_count = len(list(result2)) - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'data_list': current_page_data, - 'pagination': { - 'current_page': page, - 'page_size': page_size, - 'total_count': total_count, - 'total_pages': total_pages, - 'has_next': page < total_pages, - 'has_prev': page > 1 - }, - # 保留原有字段,便于兼容 - 'total_count': total_count, - 'related_data_count': related_data_count, - 'future_events_count': future_events_count - } - }) - except ValueError as ve: - # 处理分页参数格式错误 - return jsonify({ - 'code': 400, - 'message': f'分页参数格式错误: {str(ve)}', - 'data': None - }), 400 - except Exception as e: - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - -# 12. 投资日历-详情接口 -def extract_concepts_from_concepts_field(concepts_text): - """从concepts字段中提取概念信息""" - if not concepts_text: - return [] - - try: - import json - import ast - - # 解析concepts字段的JSON/字符串数据 - if isinstance(concepts_text, str): - try: - # 先尝试JSON解析 - concepts_data = json.loads(concepts_text) - except: - # 如果JSON解析失败,尝试ast.literal_eval解析 - concepts_data = ast.literal_eval(concepts_text) - else: - concepts_data = concepts_text - - extracted_concepts = [] - for concept_info in concepts_data: - if isinstance(concept_info, list) and len(concept_info) >= 3: - concept_name = concept_info[0] # 概念名称 - reason = concept_info[1] # 原因/描述 - score = concept_info[2] # 分数 - - extracted_concepts.append({ - 'name': concept_name, - 'reason': reason, - 'score': score - }) - - return extracted_concepts - except Exception as e: - print(f"Error extracting concepts: {e}") - return [] - - -@app.route('/api/calendar/detail/', methods=['GET']) - - -def api_future_event_detail(item_id): - """未来事件详情接口 - 连接 future_events 表 (修正数据解析) - 仅限 Pro/Max 会员""" - try: - # 从 future_events 表查询事件详情 - query = """ - SELECT - data_id, - calendar_time, - type, - star, - title, - former, - forecast, - fact, - related_stocks, - concepts - FROM future_events - WHERE data_id = :item_id - """ - - result = db.session.execute(text(query), {'item_id': item_id}) - event = result.fetchone() - - if not event: - return jsonify({ - 'code': 404, - 'message': 'Event not found', - 'data': None - }), 404 - - extracted_concepts = extract_concepts_from_concepts_field(event.concepts) - - # 解析相关股票 - related_stocks_list = [] - sector_stats = { - '全部股票': 0, - '大周期': 0, - '大消费': 0, - 'TMT板块': 0, - '大金融地产': 0, - '公共产业板块': 0, - '其他': 0 - } - - # 申万一级行业到主板块的映射 - sector_map = { - # 大周期 - '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期', - '钢铁': '大周期', '基础化工': '大周期', '建筑材料': '大周期', - '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期', - '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期', - '太阳能发电': '大周期', '建筑装饰': '大周期', '建筑': '大周期', - '交通运输': '大周期', '采掘': '大周期', '公用事业': '大周期', - - # 大消费 - '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费', - '食品饮料': '大消费', '医药生物': '大消费', '纺织服饰': '大消费', - '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费', - '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费', - '纺织服装': '大消费', '商业贸易': '大消费', '休闲服务': '大消费', - - # 大金融地产 - '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产', - '多元金融': '大金融地产', '综合金融': '大金融地产', - '房地产': '大金融地产', '非银金融': '大金融地产', - - # TMT板块 - '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块', - - # 公共产业 - '环保': '公共产业板块', '综合': '公共产业板块' - } - - # 处理相关股票 - related_avg_chg = 0 - related_max_chg = 0 - related_week_chg = 0 - - if event.related_stocks: - try: - import json - import ast - - # **修正:正确解析related_stocks数据结构** - if isinstance(event.related_stocks, str): - try: - # 先尝试JSON解析 - stock_data = json.loads(event.related_stocks) - except: - # 如果JSON解析失败,尝试ast.literal_eval解析 - stock_data = ast.literal_eval(event.related_stocks) - else: - stock_data = event.related_stocks - - print(f"Parsed stock_data: {stock_data}") # 调试输出 - - if stock_data: - daily_changes = [] - week_changes = [] - - # **修正:处理正确的数据格式 [股票代码, 股票名称, 描述, 分数]** - for stock_info in stock_data: - if isinstance(stock_info, list) and len(stock_info) >= 2: - stock_code = stock_info[0] # 第一个元素是股票代码 - stock_name = stock_info[1] # 第二个元素是股票名称 - description = stock_info[2] if len(stock_info) > 2 else '' - score = stock_info[3] if len(stock_info) > 3 else 0 - else: - continue # 跳过格式不正确的数据 - - if stock_code: - # 规范化股票代码,移除后缀 - clean_code = stock_code.replace('.SZ', '').replace('.SH', '').replace('.BJ', '') - - print(f"Processing stock: {clean_code} - {stock_name}") # 调试输出 - - # 使用模糊匹配LIKE查询申万一级行业F004V - sector_query = """ - SELECT F004V as sw_primary_sector - FROM ea_sector - WHERE SECCODE LIKE :stock_code_pattern - AND F002V = '申银万国行业分类' - LIMIT 1 - """ - sector_result = db.session.execute(text(sector_query), - {'stock_code_pattern': f'{clean_code}%'}) - sector_row = sector_result.fetchone() - - # 根据申万一级行业(F004V)映射到主板块 - sw_primary_sector = sector_row.sw_primary_sector if sector_row else None - primary_sector = sector_map.get(sw_primary_sector, '其他') if sw_primary_sector else '其他' - - print( - f"Stock: {clean_code}, SW Primary: {sw_primary_sector}, Primary Sector: {primary_sector}") - - # 通过SQL查询获取真实的日涨跌幅和周涨跌幅 - trade_query = """ - SELECT F007N as close_price, F010N as change_pct, TRADEDATE - FROM ea_trade - WHERE SECCODE LIKE :stock_code_pattern - ORDER BY TRADEDATE DESC - LIMIT 7 - """ - trade_result = db.session.execute(text(trade_query), - {'stock_code_pattern': f'{clean_code}%'}) - trade_data = trade_result.fetchall() - - daily_chg = 0 - week_chg = 0 - - if trade_data: - # 日涨跌幅(当日) - daily_chg = float(trade_data[0].change_pct or 0) - - # 周涨跌幅(5个交易日) - if len(trade_data) >= 5: - current_price = float(trade_data[0].close_price or 0) - week_ago_price = float(trade_data[4].close_price or 0) - if week_ago_price > 0: - week_chg = ((current_price - week_ago_price) / week_ago_price) * 100 - - print( - f"Trade data found: {len(trade_data) if trade_data else 0} records, daily_chg: {daily_chg}") - - # 统计各分类数量 - sector_stats['全部股票'] += 1 - sector_stats[primary_sector] += 1 - - # 收集涨跌幅数据 - daily_changes.append(daily_chg) - week_changes.append(week_chg) - - related_stocks_list.append({ - 'code': stock_code, # 原始股票代码 - 'name': stock_name, # 股票名称 - 'description': description, # 关联描述 - 'score': score, # 关联分数 - 'sw_primary_sector': sw_primary_sector, # 申万一级行业(F004V) - 'primary_sector': primary_sector, # 主板块分类 - 'daily_change': daily_chg, # 真实的日涨跌幅 - 'week_change': week_chg # 真实的周涨跌幅 - }) - - # 计算平均收益率 - if daily_changes: - related_avg_chg = sum(daily_changes) / len(daily_changes) - related_max_chg = max(daily_changes) - - if week_changes: - related_week_chg = sum(week_changes) / len(week_changes) - - except Exception as e: - print(f"Error processing related stocks: {e}") - import traceback - traceback.print_exc() - - # 构建返回数据 - detail_data = { - 'id': event.data_id, - 'title': event.title, - 'type': event.type, - 'star': event.star, - 'calendar_time': event.calendar_time.isoformat() if event.calendar_time else None, - 'former': event.former, - 'forecast': event.forecast, - 'fact': event.fact, - 'concepts': event.concepts, - 'extracted_concepts': extracted_concepts, - 'related_stocks': related_stocks_list, - 'sector_stats': sector_stats, - 'related_avg_chg': round(related_avg_chg, 2), - 'related_max_chg': round(related_max_chg, 2), - 'related_week_chg': round(related_week_chg, 2) - } - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'type': 'future_event', - 'detail': detail_data - } - }) - - except Exception as e: - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - - -# 13-15. 筛选弹窗接口(已有,优化格式) -@app.route('/api/filter/options', methods=['GET']) - - -def api_filter_options(): - """筛选选项接口""" - try: - # 获取排序选项 - sort_options = [ - {'key': 'new', 'name': '最新', 'desc': '按创建时间排序'}, - {'key': 'hot', 'name': '热门', 'desc': '按热度分数排序'}, - {'key': 'returns', 'name': '收益率', 'desc': '按收益率排序'}, - {'key': 'importance', 'name': '重要性', 'desc': '按重要性等级排序'}, - {'key': 'view_count', 'name': '浏览量', 'desc': '按浏览次数排序'} - ] - - # 获取行业筛选选项 - industry_options = db.session.execute(text(""" - SELECT DISTINCT f002v as classification_name, COUNT(*) as count - FROM ea_sector - WHERE f002v IS NOT NULL - GROUP BY f002v - ORDER BY f002v - """)).fetchall() - - # 获取重要性选项 - importance_options = [ - {'key': 'S', 'name': 'S级', 'desc': '重大事件'}, - {'key': 'A', 'name': 'A级', 'desc': '重要事件'}, - {'key': 'B', 'name': 'B级', 'desc': '普通事件'}, - {'key': 'C', 'name': 'C级', 'desc': '参考事件'} - ] - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'sort_options': sort_options, - 'industry_options': [{ - 'name': row.classification_name, - 'count': row.count - } for row in industry_options], - 'importance_options': importance_options - } - }) - except Exception as e: - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - - -# 16-17. 会员权益接口 -@app.route('/api/membership/status', methods=['GET']) - -def api_membership_status(): - """会员状态接口""" - try: - user = request.user - - # TODO: 根据实际业务逻辑判断会员状态 - # 这里假设用户表中有会员相关字段 - is_member = getattr(user, 'is_member', False) - member_expire_date = getattr(user, 'member_expire_date', None) - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'user_id': user.id, - 'is_member': is_member, - 'member_expire_date': member_expire_date.isoformat() if member_expire_date else None, - 'user_level': user.user_level, - 'benefits': { - 'unlimited_access': is_member, - 'priority_support': is_member, - 'advanced_analytics': is_member, - 'custom_alerts': is_member - } - } - }) - except Exception as e: - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - - -# 18-19. 个人中心接口 -@app.route('/api/user/profile', methods=['GET']) -@token_required -def api_user_profile(): - """个人资料接口""" - try: - user = request.user - - likes_count = PostLike.query.filter_by(user_id=user.id).count() - follows_count = EventFollow.query.filter_by(user_id=user.id).count() - comments_made = Comment.query.filter_by(user_id=user.id).count() - - comments_received = db.session.query(Comment) \ - .join(Post, Comment.post_id == Post.id) \ - .filter(Post.user_id == user.id).count() - - replies_received = Comment.query.filter( - Comment.parent_id.in_( - db.session.query(Comment.id).filter_by(user_id=user.id) - ) - ).count() - - # 总评论数(发出的评论 + 收到的评论和回复) - total_comments = comments_made + comments_received + replies_received - profile_data = { - 'basic_info': { - 'user_id': user.id, - 'username': user.username, - 'email': user.email, - 'phone': user.phone, - 'nickname': user.nickname, - 'avatar_url': get_full_avatar_url(user.avatar_url), # 修改这里 - 'bio': user.bio, - 'gender': user.gender, - 'birth_date': user.birth_date.isoformat() if user.birth_date else None, - 'location': user.location - }, - 'account_status': { - 'email_confirmed': user.email_confirmed, - 'phone_confirmed': user.phone_confirmed, - 'is_verified': user.is_verified, - 'verify_time': user.verify_time.isoformat() if user.verify_time else None, - 'created_at': user.created_at.isoformat() if user.created_at else None, - 'last_seen': user.last_seen.isoformat() if user.last_seen else None - }, - 'statistics': { - 'likes_count': likes_count, # 点赞数 - 'follows_count': follows_count, # 关注数 - 'total_comments': total_comments, # 总评论数 - 'comments_detail': { - 'comments_made': comments_made, # 发出的评论 - 'comments_received': comments_received, # 收到的评论 - 'replies_received': replies_received # 收到的回复 - } - }, - 'investment_preferences': { - 'trading_experience': user.trading_experience, - 'investment_style': user.investment_style, - 'risk_preference': user.risk_preference, - 'investment_amount': user.investment_amount, - 'preferred_markets': user.preferred_markets - }, - 'community_stats': { - 'user_level': user.user_level, - 'reputation_score': user.reputation_score, - 'contribution_point': user.contribution_point, - 'post_count': user.post_count, - 'comment_count': user.comment_count, - 'follower_count': user.follower_count, - 'following_count': user.following_count - }, - 'settings': { - 'email_notifications': user.email_notifications, - 'sms_notifications': user.sms_notifications, - 'privacy_level': user.privacy_level, - 'theme_preference': user.theme_preference - } - } - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': profile_data - }) - except Exception as e: - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - - -# 在文件开头添加缓存变量 -_agreements_cache = {} -_cache_loaded = False - - -def load_agreements_from_docx(): - """从docx文件中加载协议内容,只读取一次""" - global _agreements_cache, _cache_loaded - - if _cache_loaded: - return _agreements_cache - - try: - # 定义文件路径和对应的协议类型 - docx_files = { - 'about_us': 'about_us.docx', # 关于我们 - 'service_terms': 'service_terms.docx', # 服务条款 - 'privacy_policy': 'privacy_policy.docx' # 隐私政策 - } - - # 定义协议标题 - titles = { - 'about_us': '关于我们', - 'service_terms': '服务条款', - 'privacy_policy': '隐私政策' - } - - for agreement_type, filename in docx_files.items(): - file_path = os.path.join(os.path.dirname(__file__), filename) - - if os.path.exists(file_path): - try: - # 读取docx文件 - doc = Document(file_path) - - # 提取文本内容 - content_paragraphs = [] - for paragraph in doc.paragraphs: - if paragraph.text.strip(): # 跳过空段落 - content_paragraphs.append(paragraph.text.strip()) - - # 合并所有段落 - content = '\n\n'.join(content_paragraphs) - - # 获取文件修改时间作为版本标识 - file_stat = os.stat(file_path) - last_modified = file_stat.st_mtime - - # 缓存内容 - _agreements_cache[agreement_type] = { - 'title': titles.get(agreement_type, agreement_type), - 'content': content, - 'last_updated': last_modified, - 'version': '1.0', - 'file_path': filename - } - - print(f"Successfully loaded {agreement_type} from {filename}") - - except Exception as e: - print(f"Error reading {filename}: {str(e)}") - # 如果读取失败,使用默认内容 - _agreements_cache[agreement_type] = { - 'title': titles.get(agreement_type, agreement_type), - 'content': f"协议内容正在加载中,请稍后再试。(文件:{filename})", - 'last_updated': None, - 'version': '1.0', - 'file_path': filename, - 'error': str(e) - } - else: - print(f"File not found: {filename}") - # 如果文件不存在,使用默认内容 - _agreements_cache[agreement_type] = { - 'title': titles.get(agreement_type, agreement_type), - 'content': f"协议文件未找到,请联系管理员。(文件:{filename})", - 'last_updated': None, - 'version': '1.0', - 'file_path': filename, - 'error': 'File not found' - } - - _cache_loaded = True - print(f"Agreements cache loaded successfully. Total: {len(_agreements_cache)} agreements") - - except Exception as e: - print(f"Error loading agreements: {str(e)}") - _cache_loaded = False - - return _agreements_cache - - -@app.route('/api/agreements', methods=['GET']) -def api_agreements(): - """平台协议接口 - 从docx文件读取""" - try: - # 获取查询参数 - agreement_type = request.args.get('type') # about_us, service_terms, privacy_policy - force_reload = request.args.get('reload', 'false').lower() == 'true' # 强制重新加载 - - # 如果需要强制重新加载,清除缓存 - if force_reload: - global _cache_loaded - _cache_loaded = False - _agreements_cache.clear() - - # 加载协议内容 - agreements_data = load_agreements_from_docx() - - if not agreements_data: - return jsonify({ - 'code': 500, - 'message': 'Failed to load agreements', - 'data': None - }), 500 - - # 如果指定了特定协议类型,只返回该协议 - if agreement_type and agreement_type in agreements_data: - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'agreement_type': agreement_type, - **agreements_data[agreement_type] - } - }) - - # 返回所有协议 - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'agreements': agreements_data, - 'available_types': list(agreements_data.keys()), - 'cache_loaded': _cache_loaded, - 'total_agreements': len(agreements_data) - } - }) - - except Exception as e: - return jsonify({ - 'code': 500, - 'message': str(e), - 'data': None - }), 500 - - -# 20. 个人中心-我的关注接口 -@app.route('/api/user/activities', methods=['GET']) -@token_required -def api_user_activities(): - """用户活动接口(我的关注、评论、点赞)""" - try: - user = request.user - activity_type = request.args.get('type', 'follows') # follows, comments, likes, commented - page = request.args.get('page', 1, type=int) - per_page = min(50, request.args.get('per_page', 20, type=int)) - - if activity_type == 'follows': - # 我的关注列表 - follows = EventFollow.query.filter_by(user_id=user.id) \ - .order_by(EventFollow.created_at.desc()) \ - .paginate(page=page, per_page=per_page, error_out=False) - - activities = [] - for follow in follows.items: - # 获取相关股票并添加单日涨幅 - related_stocks_data = [] - for stock in follow.event.related_stocks.limit(5): - # 处理股票代码,移除可能的后缀(如 .SZ 或 .SH) - base_stock_code = stock.stock_code.split('.')[0] - - # 获取股票最新交易数据 - trade_data = TradeData.query.filter( - TradeData.SECCODE.startswith(base_stock_code) - ).order_by(TradeData.TRADEDATE.desc()).first() - - # 计算单日涨幅 - daily_change = None - if trade_data and trade_data.F010N is not None: - daily_change = float(trade_data.F010N) - - related_stocks_data.append({ - 'stock_code': stock.stock_code, - 'stock_name': stock.stock_name, - 'correlation': stock.correlation, - 'daily_change': daily_change, # 新增:单日涨幅 - 'daily_change_formatted': f"{daily_change:.2f}%" if daily_change is not None else "暂无数据" - # 格式化显示 - }) - - activities.append({ - 'event_id': follow.event_id, - 'event_title': follow.event.title, - 'event_description': follow.event.description, - 'follow_time': follow.created_at.isoformat() if follow.created_at else None, - 'event_hot_score': follow.event.hot_score, - # 新增字段 - 'importance': follow.event.importance, # 重要性 - 'related_avg_chg': follow.event.related_avg_chg, # 平均涨幅 - 'related_max_chg': follow.event.related_max_chg, # 最大涨幅 - 'related_week_chg': follow.event.related_week_chg, # 周涨幅 - 'related_stocks': related_stocks_data, # 修改:包含单日涨幅的相关股票 - 'created_at': follow.event.created_at.isoformat() if follow.event.created_at else None, # 发布时间 - 'preview': follow.event.description[:200] if follow.event.description else None, # 预览(限制200字) - 'comment_count': follow.event.post_count, # 评论数 - 'view_count': follow.event.view_count, # 评论数 - 'follower_count': follow.event.follower_count # 关注数 - }) - - total = follows.total - pages = follows.pages - - elif activity_type == 'likes': - # 我的点赞列表 - likes = PostLike.query.filter_by(user_id=user.id) \ - .order_by(PostLike.created_at.desc()) \ - .paginate(page=page, per_page=per_page, error_out=False) - - activities = [{ - 'like_id': like.id, - 'post_id': like.post_id, - 'post_content': like.post.content, - 'like_time': like.created_at.isoformat() if like.created_at else None, - # 新增发布人信息 - 'author': { - 'nickname': like.post.user.nickname or like.post.user.username, - 'avatar_url': get_full_avatar_url(like.post.user.avatar_url), # 修改这里 - } - } for like in likes.items] - - total = likes.total - pages = likes.pages - - elif activity_type == 'comments': - # 我的评论列表(增强版 - 添加重要性和事件内容) - comments = Comment.query.filter_by(user_id=user.id) \ - .join(Post, Comment.post_id == Post.id) \ - .join(Event, Post.event_id == Event.id) \ - .order_by(Comment.created_at.desc()) \ - .paginate(page=page, per_page=per_page, error_out=False) - - activities = [] - for comment in comments.items: - # 通过关联路径获取事件信息:comment.post_id -> post.id -> post.event_id -> event.id - post = comment.post - event = post.event if post else None - - activity_data = { - 'comment_id': comment.id, - 'post_id': comment.post_id, - 'content': comment.content, # 评论内容 - 'created_at': comment.created_at.isoformat() if comment.created_at else None, - 'post_title': post.title if post and post.title else None, - 'post_content': post.content if post else None, - - # 新增:评论者信息(当前用户) - 'commenter': { - 'id': comment.user.id, - 'username': comment.user.username, - 'nickname': comment.user.nickname or comment.user.username, - 'avatar_url': get_full_avatar_url(comment.user.avatar_url), - 'user_level': comment.user.user_level, - 'is_verified': comment.user.is_verified - }, - - # 新增字段:事件信息 - 'event': { - 'id': event.id if event else None, - 'title': event.title if event else None, - 'description': event.description if event else None, # 事件内容 - 'importance': event.importance if event else None, # 重要性 - 'event_type': event.event_type if event else None, - 'hot_score': event.hot_score if event else None, - 'view_count': event.view_count if event else None, - 'related_avg_chg': event.related_avg_chg if event else None, - 'created_at': event.created_at.isoformat() if event and event.created_at else None - }, - - # 新增:帖子作者信息 - 'post_author': { - 'id': post.user.id if post else None, - 'username': post.user.username if post else None, - 'nickname': post.user.nickname or post.user.username if post else None, - 'avatar_url': get_full_avatar_url(post.user.avatar_url) if post else None, - } - } - activities.append(activity_data) - - total = comments.total - pages = comments.pages - - elif activity_type == 'commented': - # 评论了我的帖子 - my_posts = Post.query.filter_by(user_id=user.id).subquery() - comments = Comment.query.join(my_posts, Comment.post_id == my_posts.c.id) \ - .filter(Comment.user_id != user.id) \ - .order_by(Comment.created_at.desc()) \ - .paginate(page=page, per_page=per_page, error_out=False) - - activities = [{ - 'comment_id': comment.id, - 'comment_content': comment.content, - 'comment_time': comment.created_at.isoformat() if comment.created_at else None, - 'commenter_nickname': comment.user.nickname or comment.user.username, - 'commenter_avatar': get_full_avatar_url(comment.user.avatar_url), # 修改这里 - 'post_content': comment.post.content, - 'event_title': comment.post.event.title, - 'event_id': comment.post.event_id - } for comment in comments.items] - - total = comments.total - pages = comments.pages - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'activities': activities, - 'total': total, - 'pages': pages, - 'current_page': page - } - }) - - except Exception as e: - print(f"Error in api_user_activities: {str(e)}") - return jsonify({ - 'code': 500, - 'message': '服务器内部错误', - 'data': None - }), 500 - - -class UserFeedback(db.Model): - """用户反馈模型""" - id = db.Column(db.Integer, primary_key=True) - user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) - type = db.Column(db.String(50), nullable=False) # 反馈类型 - content = db.Column(db.Text, nullable=False) # 反馈内容 - contact_info = db.Column(db.String(100)) # 联系方式 - status = db.Column(db.String(20), default='pending') # 状态:pending/processing/resolved/closed - admin_reply = db.Column(db.Text) # 管理员回复 - created_at = db.Column(db.DateTime, default=beijing_now) - updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) - - # 关联关系 - user = db.relationship('User', backref='feedbacks') - - def __init__(self, user_id, type, content, contact_info=None): - self.user_id = user_id - self.type = type - self.content = content - self.contact_info = contact_info - - def to_dict(self): - return { - 'id': self.id, - 'type': self.type, - 'content': self.content, - 'contact_info': self.contact_info, - 'status': self.status, - 'admin_reply': self.admin_reply, - 'created_at': self.created_at.strftime('%Y-%m-%d %H:%M:%S'), - 'updated_at': self.updated_at.strftime('%Y-%m-%d %H:%M:%S') - } - - - -# 通用错误处理 -@app.errorhandler(404) -def api_not_found(error): - if request.path.startswith('/api/'): - return jsonify({ - 'code': 404, - 'message': '接口不存在', - 'data': None - }), 404 - return error - - -@app.errorhandler(405) -def api_method_not_allowed(error): - if request.path.startswith('/api/'): - return jsonify({ - 'code': 405, - 'message': '请求方法不允许', - 'data': None - }), 405 - return error - - - -if __name__ == '__main__': - - app.run( - host='0.0.0.0', - port=5002, - debug=True, - ssl_context=( - '/home/ubuntu/dify/docker/nginx/ssl/fullchain.pem', - '/home/ubuntu/dify/docker/nginx/ssl/privkey.pem' - ) - ) diff --git a/app_vx.py.backup b/app_vx_raw.py similarity index 61% rename from app_vx.py.backup rename to app_vx_raw.py index 68a1dd1c..6a7d98f4 100644 --- a/app_vx.py.backup +++ b/app_vx_raw.py @@ -1,24 +1,38 @@ +import hmac +import json +import base64 +from hashlib import sha1 import csv import logging -import random import re +from urllib.parse import urlencode, quote import math import os -import secrets -import string - import pytz import requests from flask_compress import Compress -from functools import wraps +from collections import defaultdict +import jieba +import jieba.analyse +from functools import lru_cache, wraps +import threading from pathlib import Path +import pickle +import psutil import time -from sqlalchemy import create_engine, text, func, or_, case, event, desc, asc -from flask import Flask, has_request_context, render_template, request, jsonify, redirect, url_for, flash, session, render_template_string, current_app, send_from_directory +import gc +from typing import Dict, Any, Optional, Tuple +import pandas as pd +from sqlalchemy import Column, Integer, String, Boolean, DateTime, create_engine, text, func, or_, case, event, desc, \ + JSON, asc +from flask import Flask, session, has_request_context, render_template, request, jsonify, redirect, url_for, flash, \ + session, render_template_string, current_app, send_from_directory from flask_sqlalchemy import SQLAlchemy from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user from flask_mail import Mail, Message from itsdangerous import URLSafeTimedSerializer +import random +import string from flask_migrate import Migrate from flask_session import Session # type: ignore from sqlalchemy.dialects.mysql.base import MySQLDialect @@ -26,21 +40,22 @@ from sqlalchemy.dialects.postgresql import JSONB from werkzeug.utils import secure_filename from PIL import Image from datetime import datetime, timedelta, time as dt_time +import hashlib from werkzeug.security import generate_password_hash, check_password_hash import json +from config import STOP_WORDS from clickhouse_driver import Client as Cclient import jwt +import uuid +import redis from docx import Document -from tencentcloud.common import credential -from tencentcloud.common.profile.client_profile import ClientProfile -from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.sms.v20210111 import sms_client, models -from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException -engine = create_engine("mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock", echo=False, pool_size=20, +# 初始化 Flask-Migrate + +engine = create_engine("mysql+pymysql://root:Zzl5588161!@111.198.58.126:33060/stock", echo=False, pool_size=20, max_overflow=50) -engine_med = create_engine("mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/med", echo=False) -engine_2 = create_engine("mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/valuefrontier", echo=False) +engine_med = create_engine("mysql+pymysql://root:Zzl5588161!@111.198.58.126:33060/med", echo=False) +engine_2 = create_engine("mysql+pymysql://root:Zzl5588161!@111.198.58.126:33060/valuefrontier", echo=False) logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) app = Flask(__name__) @@ -58,9 +73,17 @@ app.config['COMPRESS_MIMETYPES'] = [ 'application/javascript', 'application/x-javascript' ] -user_tokens = {} +# Redis 初始化 +redis_client = redis.StrictRedis( + host='43.143.189.195', # 改为你的 Redis 服务器地址 + port=6379, + password='Zzl338180', + db=0, + decode_responses=True +) + app.config['SECRET_KEY'] = 'vf7891574233241' -app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock?charset=utf8mb4' +app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://root:Zzl5588161!@111.198.58.126:33060/stock?charset=utf8mb4' app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False app.config['JSON_AS_ASCII'] = False @@ -76,34 +99,24 @@ app.config['MAIL_PASSWORD'] = 'QYncRu6WUdASvTg4' app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER app.config['MAX_CONTENT_LENGTH'] = MAX_CONTENT_LENGTH -# 腾讯云短信配置 -SMS_SECRET_ID = 'AKID2we9TacdTAhCjCSYTErHVimeJo9Yr00s' -SMS_SECRET_KEY = 'pMlBWijlkgT9fz5ziEXdWEnAPTJzRfkf' -SMS_SDK_APP_ID = "1400972398" -SMS_SIGN_NAME = "价值前沿科技" -SMS_TEMPLATE_REGISTER = "2386557" # 注册模板 -SMS_TEMPLATE_LOGIN = "2386540" # 登录模板 -verification_codes = {} +# 短信验证模块 +app.config['QINIU_ACCESS_KEY'] = "0MIwksc8RvcNten1iUdTbtwB6orPOfzRiqYTXVOU" +app.config['QINIU_SECRET_KEY'] = "QmjyOi27XoLtBccv5AAI6khIcJncLfr0ErjSMu_i" +app.config['QINIU_TEMPLATE_ID'] = "1901640687774875648" +app.config['QINIU_SIGNATURE_ID'] = "1900745528702943232" -#微信小程序 -app.config['WECHAT_APP_ID'] = 'wx0edeaab76d4fa414' -app.config['WECHAT_APP_SECRET'] = '0d0c70084f05a8c1411f6b89da7e815d' +app.config['WECHAT_APP_ID'] = 'wxa8d74c47041b5f87' +app.config['WECHAT_APP_SECRET'] = 'eedef95b11787fd7ca7f1acc6c9061bc' app.config['BASE_URL'] = 'http://43.143.189.195:5002' app.config['WECHAT_REDIRECT_URI'] = f"{app.config['BASE_URL']}/api/wechat/callback" -WECHAT_APP_ID = 'wx0edeaab76d4fa414' -WECHAT_APP_SECRET = '0d0c70084f05a8c1411f6b89da7e815d' -JWT_SECRET_KEY = 'vfllmgreat33818!' # 请修改为安全的密钥 -JWT_ALGORITHM = 'HS256' -JWT_EXPIRATION_HOURS = 24 * 7 # Token有效期7天 -# Session 配置 - 使用文件系统存储(替代 Redis) -app.config['SESSION_TYPE'] = 'filesystem' -app.config['SESSION_FILE_DIR'] = os.path.join(os.path.dirname(__file__), 'flask_session') -app.config['SESSION_PERMANENT'] = True -app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(days=7) # Session 有效期 7 天 - -# 确保 session 目录存在 -os.makedirs(app.config['SESSION_FILE_DIR'], exist_ok=True) +app.config['CACHE_TYPE'] = 'redis' +app.config['CACHE_REDIS_HOST'] = '43.143.189.195' # 使用实际的服务器IP +app.config['CACHE_REDIS_PORT'] = 6379 +app.config['CACHE_REDIS_PASSWORD'] = 'Zzl338180' +app.config['CACHE_DEFAULT_TIMEOUT'] = 300 +app.config['SESSION_TYPE'] = 'redis' +app.config['SESSION_REDIS'] = redis.Redis(host='43.143.189.195', port=6379, password='Zzl338180') # Cache directory setup CACHE_DIR = Path('cache') @@ -126,50 +139,202 @@ migrate = Migrate(app, db) DOMAIN = 'http://43.143.189.195:5002' JWT_SECRET = 'Llmgreat123' +JWT_ALGORITHM = 'HS256' JWT_EXPIRES_SECONDS = 3600 # 1小时有效期 Session(app) def token_required(f): - """装饰器:需要token认证的接口""" - from functools import wraps - @wraps(f) - def decorated_function(*args, **kwargs): + def decorated(*args, **kwargs): token = None - # 从请求头获取token - auth_header = request.headers.get('Authorization') - if auth_header and auth_header.startswith('Bearer '): - token = auth_header[7:] + # 从请求头中提取 Authorization: Bearer + if 'Authorization' in request.headers: + auth_header = request.headers['Authorization'] + if auth_header.startswith('Bearer '): + token = auth_header.split(" ")[1] if not token: - return jsonify({'message': '缺少认证token'}), 401 + return jsonify({'code': 401, 'message': '未提供 Token'}), 401 - token_data = user_tokens.get(token) - if not token_data: - return jsonify({'message': 'Token无效','code':401}), 401 + try: + payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM]) + user = User.query.get(payload['user_id']) + if not user: + return jsonify({'code': 401, 'message': '无效 Token'}), 401 + # 动态绑定用户对象(可选) + request.user = user - # 检查是否过期 - if token_data['expires'] < datetime.now(): - del user_tokens[token] - return jsonify({'message': 'Token已过期'}), 401 - - # 获取用户对象并添加到请求上下文 - user = User.query.get(token_data['user_id']) - if not user: - return jsonify({'message': '用户不存在'}), 404 - - # 将用户对象添加到request - request.user = user - request.current_user_id = token_data['user_id'] + except jwt.ExpiredSignatureError: + return jsonify({'code': 401, 'message': 'Token 已过期'}), 401 + except jwt.InvalidTokenError: + return jsonify({'code': 401, 'message': '无效 Token'}), 401 return f(*args, **kwargs) - return decorated_function + return decorated +def generate_qiniu_token(access_key, secret_key, method, path, host, content_type=None, body=None): + """生成七牛云认证令牌""" + # 步骤1-2: 添加方法和路径 + data = f"{method} {path}" + + # 步骤3: 添加主机 + data += f"\nHost: {host}" + + # 步骤4: 添加内容类型(如果存在) + if content_type: + data += f"\nContent-Type: {content_type}" + + # 步骤5: 添加回车 + data += "\n\n" + + # 步骤6: 添加请求体 + if body and content_type and content_type != "application/octet-stream": + data += body + + # 计算HMAC-SHA1签名并编码 + sign = hmac.new(secret_key.encode(), data.encode(), sha1).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode('utf-8') + + # 生成七牛令牌 + return f"Qiniu {access_key}:{encoded_sign}" + + +def send_sms_verification_minimal(phone_number, redis_client): + """ + 最小化版本 - 使用urllib避免requests库问题 + """ + import urllib.request + import urllib.parse + import urllib.error + import json as json_lib + import hmac + from hashlib import sha1 + import base64 + import random + import string + import ssl + + try: + # 生成6位数字验证码 + verification_code = ''.join(random.choices(string.digits, k=6)) + print(f"生成验证码: {verification_code}") + + # 存储验证码到 Redis + redis_key = f"sms:login:{phone_number}" # 统一key格式 + redis_client.setex(redis_key, 600, verification_code) + print(f"验证码已存储到Redis,key: {redis_key}") + + # 硬编码配置 + QINIU_ACCESS_KEY = "0MIwksc8RvcNten1iUdTbtwB6orPOfzRiqYTXVOU" + QINIU_SECRET_KEY = "QmjyOi27XoLtBccv5AAI6khIcJncLfr0ErjSMu_i" + QINIU_TEMPLATE_ID = "1901640687774875648" + QINIU_SIGNATURE_ID = "1900745528702943232" + + # 准备请求数据 + host = "sms.qiniuapi.com" + path = "/v1/message/single" + method = "POST" + content_type = "application/json" + + request_data = { + "template_id": QINIU_TEMPLATE_ID, + "mobile": phone_number, + "parameters": { + "code": verification_code + }, + "signature_id": QINIU_SIGNATURE_ID + } + + body = json_lib.dumps(request_data) + print("请求数据准备完成") + + # 生成签名 + data = f"{method} {path}" + data += f"\nHost: {host}" + data += f"\nContent-Type: {content_type}" + data += "\n\n" + data += body + + sign = hmac.new(QINIU_SECRET_KEY.encode(), data.encode(), sha1).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode('utf-8') + token = f"Qiniu {QINIU_ACCESS_KEY}:{encoded_sign}" + print("签名生成完成") + + # 创建请求 + url = f"https://{host}{path}" + headers = { + "Content-Type": content_type, + "Authorization": token, + "User-Agent": "ValueFrontier/1.0" + } + + # 使用urllib发送请求 + req = urllib.request.Request( + url, + data=body.encode('utf-8'), + headers=headers, + method=method + ) + + # 创建SSL上下文,忽略证书验证 + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + print("开始发送请求") + + # 发送请求 + with urllib.request.urlopen(req, timeout=10, context=ctx) as response: + response_data = response.read().decode('utf-8') + result = json_lib.loads(response_data) + + print(f"响应: {result}") + print(f"实际发送的验证码: {verification_code}") # 重要:记录实际验证码 + + # 检查结果 + if 'message_id' in result: + print(f"发送成功: {phone_number}, message_id: {result.get('message_id')}") + return True, "验证码发送成功", verification_code # 返回验证码用于测试 + else: + error_msg = result.get('message', '发送失败') + print(f"发送失败: {error_msg}") + return False, f"发送失败: {error_msg}", None + + except urllib.error.HTTPError as e: + print(f"HTTP错误: {e.code} - {e.reason}") + return False, f"网络错误: {e.code}" + except urllib.error.URLError as e: + print(f"URL错误: {e.reason}") + return False, "网络连接失败" + except Exception as e: + print(f"其他异常: {type(e).__name__} - {str(e)}") + return False, "发送失败,请重试" + + +def verify_sms_code(phone_number, code): + """验证短信验证码""" + stored_code = session.get('sms_verification_code') + stored_phone = session.get('sms_verification_phone') + expiration = session.get('sms_verification_expiration') + + if not all([stored_code, stored_phone, expiration]): + return False, "请先获取验证码" + + if stored_phone != phone_number: + return False, "手机号与验证码不匹配" + + if beijing_now().timestamp() > expiration: + return False, "验证码已过期" + + if code != stored_code: + return False, "验证码错误" + + return True, "验证成功" def beijing_now(): @@ -178,216 +343,6 @@ def beijing_now(): return datetime.now(beijing_tz) -# ============================================ -# 订阅功能模块(与 app.py 保持一致) -# ============================================ -class UserSubscription(db.Model): - """用户订阅表 - 独立于现有User表""" - __tablename__ = 'user_subscriptions' - - id = db.Column(db.Integer, primary_key=True, autoincrement=True) - user_id = db.Column(db.Integer, nullable=False, unique=True, index=True) - subscription_type = db.Column(db.String(10), nullable=False, default='free') - subscription_status = db.Column(db.String(20), nullable=False, default='active') - start_date = db.Column(db.DateTime, nullable=True) - end_date = db.Column(db.DateTime, nullable=True) - billing_cycle = db.Column(db.String(10), nullable=True) - auto_renewal = db.Column(db.Boolean, nullable=False, default=False) - created_at = db.Column(db.DateTime, default=beijing_now) - updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) - - def is_active(self): - if self.subscription_status != 'active': - return False - if self.subscription_type == 'free': - return True - if self.end_date: - # 将数据库的 naive datetime 转换为带时区的 aware datetime - beijing_tz = pytz.timezone('Asia/Shanghai') - end_date_aware = self.end_date if self.end_date.tzinfo else beijing_tz.localize(self.end_date) - return beijing_now() <= end_date_aware - return True - - def days_left(self): - if not self.is_active(): - return 0 - if self.subscription_type == 'free': - return 999 - if not self.end_date: - return 999 - try: - now = beijing_now() - # 将数据库的 naive datetime 转换为带时区的 aware datetime - beijing_tz = pytz.timezone('Asia/Shanghai') - end_date_aware = self.end_date if self.end_date.tzinfo else beijing_tz.localize(self.end_date) - delta = end_date_aware - now - return max(0, delta.days) - except Exception as e: - return 0 - - def to_dict(self): - return { - 'type': self.subscription_type, - 'status': self.subscription_status, - 'is_active': self.is_active(), - 'start_date': self.start_date.isoformat() if self.start_date else None, - 'end_date': self.end_date.isoformat() if self.end_date else None, - 'days_left': self.days_left(), - 'billing_cycle': self.billing_cycle, - 'auto_renewal': self.auto_renewal - } - - -# ============================================ -# 订阅等级工具函数 -# ============================================ -def get_user_subscription_safe(user_id): - """ - 安全地获取用户订阅信息 - :param user_id: 用户ID - :return: UserSubscription 对象或默认免费订阅 - """ - try: - subscription = UserSubscription.query.filter_by(user_id=user_id).first() - if not subscription: - # 如果用户没有订阅记录,创建默认免费订阅 - subscription = UserSubscription( - user_id=user_id, - subscription_type='free', - subscription_status='active' - ) - db.session.add(subscription) - db.session.commit() - return subscription - except Exception as e: - print(f"获取用户订阅信息失败: {e}") - # 返回一个临时的免费订阅对象(不保存到数据库) - temp_sub = UserSubscription( - user_id=user_id, - subscription_type='free', - subscription_status='active' - ) - return temp_sub - - -def _get_current_subscription_info(): - """ - 获取当前登录用户订阅信息的字典形式,未登录或异常时视为免费用户。 - 小程序场景下从 request.current_user_id 获取用户ID - """ - try: - user_id = getattr(request, 'current_user_id', None) - if not user_id: - return { - 'type': 'free', - 'status': 'active', - 'is_active': True - } - sub = get_user_subscription_safe(user_id) - return { - 'type': sub.subscription_type, - 'status': sub.subscription_status, - 'is_active': sub.is_active(), - 'start_date': sub.start_date.isoformat() if sub.start_date else None, - 'end_date': sub.end_date.isoformat() if sub.end_date else None, - 'days_left': sub.days_left() - } - except Exception as e: - print(f"获取订阅信息异常: {e}") - return { - 'type': 'free', - 'status': 'active', - 'is_active': True - } - - -def _subscription_level(sub_type): - """将订阅类型映射到等级数值,free=0, pro=1, max=2。""" - mapping = {'free': 0, 'pro': 1, 'max': 2} - return mapping.get((sub_type or 'free').lower(), 0) - - -def _has_required_level(required: str) -> bool: - """判断当前用户是否达到所需订阅级别。""" - info = _get_current_subscription_info() - if not info.get('is_active', True): - return False - return _subscription_level(info.get('type')) >= _subscription_level(required) - - -# ============================================ -# 权限装饰器 -# ============================================ -def subscription_required(level='pro'): - """ - 订阅等级装饰器 - 小程序专用 - 用法: - @subscription_required('pro') # 需要 Pro 或 Max 用户 - @subscription_required('max') # 仅限 Max 用户 - - 注意:此装饰器需要配合 使用 - """ - from functools import wraps - - def decorator(f): - @wraps(f) - def decorated_function(*args, **kwargs): - if not _has_required_level(level): - current_info = _get_current_subscription_info() - current_type = current_info.get('type', 'free') - is_active = current_info.get('is_active', False) - - if not is_active: - return jsonify({ - 'success': False, - 'error': '您的订阅已过期,请续费后继续使用', - 'error_code': 'SUBSCRIPTION_EXPIRED', - 'current_subscription': current_type, - 'required_subscription': level - }), 403 - - return jsonify({ - 'success': False, - 'error': f'此功能需要 {level.upper()} 或更高等级会员', - 'error_code': 'SUBSCRIPTION_REQUIRED', - 'current_subscription': current_type, - 'required_subscription': level - }), 403 - - return f(*args, **kwargs) - - return decorated_function - - return decorator - - -def pro_or_max_required(f): - """ - 快捷装饰器:要求 Pro 或 Max 用户(小程序专用场景) - 等同于 @subscription_required('pro') - """ - from functools import wraps - - @wraps(f) - def decorated_function(*args, **kwargs): - if not _has_required_level('pro'): - current_info = _get_current_subscription_info() - current_type = current_info.get('type', 'free') - - return jsonify({ - 'success': False, - 'error': '小程序功能仅对 Pro 和 Max 会员开放', - 'error_code': 'MINIPROGRAM_PRO_REQUIRED', - 'current_subscription': current_type, - 'required_subscription': 'pro', - 'message': '请升级到 Pro 或 Max 会员以使用小程序完整功能' - }), 403 - - return f(*args, **kwargs) - - return decorated_function - - class User(UserMixin, db.Model): """用户模型""" id = db.Column(db.Integer, primary_key=True) @@ -553,6 +508,574 @@ class User(UserMixin, db.Model): return f'' +@app.route('/send_sms_verification_route', methods=['POST']) +def send_sms_verification_route(): + """发送短信验证码API""" + try: + print("收到短信发送请求") + + if not request.is_json: + print("请求不是JSON格式") + return jsonify({ + 'success': False, + 'message': '无效的请求' + }) + + phone = request.json.get('phone') + print(f"手机号: {phone}") + + if not phone: + return jsonify({ + 'success': False, + 'message': '手机号不能为空' + }) + + # 验证手机号格式 + if not phone.isdigit() or len(phone) != 11 or not phone.startswith('1'): + return jsonify({ + 'success': False, + 'message': '无效的手机号码格式' + }) + + # 使用最小化函数 + success, message = send_sms_verification_minimal(phone, redis_client) + + return jsonify({ + 'success': success, + 'message': message + }) + + except Exception as e: + print(f"路由异常: {type(e).__name__} - {str(e)}") + return jsonify({ + 'success': False, + 'message': '系统错误' + }) + + +@app.route('/check_phone', methods=['POST']) +def check_phone(): + """检查手机号是否已注册""" + phone = request.json.get('phone', '') + user = User.query.filter_by(phone=phone).first() + return jsonify({ + 'exists': user is not None, + 'message': '该手机号已被注册' if user else 'ok' + }) + + +@app.route('/register_with_phone', methods=['POST']) +def register_with_phone(): + """手机号注册""" + username = request.form.get('username') + phone = request.form.get('phone') + password = request.form.get('password') + verification_code = request.form.get('verification_code') + + # 验证数据 + if User.query.filter_by(username=username).first(): + return jsonify({ + 'success': False, + 'error_field': 'username', + 'message': '用户名已存在' + }) + + if User.query.filter_by(phone=phone).first(): + return jsonify({ + 'success': False, + 'error_field': 'phone', + 'message': '手机号已被注册' + }) + + # 验证短信验证码 + success, message = verify_sms_code(phone, verification_code) + if not success: + return jsonify({ + 'success': False, + 'error_field': 'verification_code', + 'message': message + }) + + # 创建用户 + try: + user = User(username=username, password=password) + user.phone = phone + user.phone_confirmed = True + user.phone_confirm_time = beijing_now() + + db.session.add(user) + db.session.commit() + + # 清除会话中的验证信息 + session.pop('sms_verification_code', None) + session.pop('sms_verification_phone', None) + session.pop('sms_verification_expiration', None) + + # 自动登录 + login_user(user) + + return jsonify({ + 'success': True, + 'message': '注册成功!', + 'redirect_url': url_for('index') + }) + except Exception as e: + db.session.rollback() + app.logger.error(f"手机注册错误: {str(e)}") + return jsonify({ + 'success': False, + 'message': '注册失败,请重试' + }) + + +@app.route('/login_with_phone', methods=['POST']) +def login_with_phone(): + """手机号验证码登录""" + phone = request.form.get('phone') + verification_code = request.form.get('verification_code') + next_page = request.form.get('next') + + # 验证短信验证码 + success, message = verify_sms_code(phone, verification_code) + if not success: + return jsonify({ + 'success': False, + 'message': message + }) + + # 根据手机号查找用户 + user = User.query.filter_by(phone=phone).first() + if not user: + return jsonify({ + 'success': False, + 'message': '该手机号未注册' + }) + + # 登录用户 + try: + login_user(user) + + # 清除会话中的验证信息 + session.pop('sms_verification_code', None) + session.pop('sms_verification_phone', None) + session.pop('sms_verification_expiration', None) + + # 更新用户最后活跃时间 + user.update_last_seen() + db.session.commit() + + # 重定向到下一页或首页 + redirect_url = next_page if next_page and next_page.startswith('/') else url_for('index') + + return jsonify({ + 'success': True, + 'message': '登录成功!', + 'redirect_url': redirect_url + }) + except Exception as e: + app.logger.error(f"手机登录错误: {str(e)}") + return jsonify({ + 'success': False, + 'message': '登录失败,请重试' + }) + + +@app.route('/bind_phone', methods=['POST']) +@token_required +def bind_phone(): + """绑定手机号到现有账号""" + phone = request.form.get('phone') + verification_code = request.form.get('verification_code') + + # 验证短信验证码 + success, message = verify_sms_code(phone, verification_code) + if not success: + return jsonify({ + 'success': False, + 'message': message + }) + + # 检查手机号是否已被其他账号使用 + existing_user = User.query.filter_by(phone=phone).first() + if existing_user and existing_user.id != request.user.id: + return jsonify({ + 'success': False, + 'message': '该手机号已被其他账号绑定' + }) + + try: + request.user.phone = phone + request.user.phone_confirmed = True + request.user.phone_confirm_time = beijing_now() + db.session.commit() + + # 清除会话中的验证信息 + session.pop('sms_verification_code', None) + session.pop('sms_verification_phone', None) + session.pop('sms_verification_expiration', None) + + return jsonify({ + 'success': True, + 'message': '手机号绑定成功' + }) + except Exception as e: + db.session.rollback() + app.logger.error(f"手机绑定错误: {str(e)}") + return jsonify({ + 'success': False, + 'message': '绑定失败,请重试' + }) + + +# 路由 +@app.route('/register', methods=['GET', 'POST']) +def register(): + if request.method == 'POST': + username = request.form.get('username') + email = request.form.get('email') + password = request.form.get('password') + verification_code = request.form.get('verification_code') + + # 验证数据 + if User.query.filter_by(username=username).first(): + return jsonify({ + 'success': False, + 'error_field': 'username', + 'message': '用户名已存在' + }) + + if User.query.filter_by(email=email).first(): + return jsonify({ + 'success': False, + 'error_field': 'email', + 'message': '邮箱已被注册' + }) + + # 验证验证码 + stored_code = session.get('verification_code') + stored_email = session.get('verification_email') + expiration = session.get('verification_expiration') + + if not all([stored_code, stored_email, expiration]): + return jsonify({ + 'success': False, + 'error_field': 'verification_code', + 'message': '请先获取验证码' + }) + + if stored_email != email: + return jsonify({ + 'success': False, + 'error_field': 'verification_code', + 'message': '邮箱与验证码不匹配' + }) + + if beijing_now().timestamp() > expiration: + return jsonify({ + 'success': False, + 'error_field': 'verification_code', + 'message': '验证码已过期' + }) + + if verification_code != stored_code: + return jsonify({ + 'success': False, + 'error_field': 'verification_code', + 'message': '验证码错误' + }) + + # 创建用户 + try: + user = User(username=username, email=email, password=password) + # 已验证邮箱,直接设置为已验证状态 + user.email_confirmed = True + db.session.add(user) + db.session.commit() + + # 清除会话中的验证信息 + session.pop('verification_code', None) + session.pop('verification_email', None) + session.pop('verification_expiration', None) + + return jsonify({ + 'success': True, + 'message': '注册成功!', + 'redirect_url': url_for('login') + }) + except Exception as e: + db.session.rollback() + app.logger.error(f"Registration error: {str(e)}") + return jsonify({ + 'success': False, + 'message': '注册失败,请重试' + }) + + return render_template('pages/sign-up/sign-up-basic.html') + + +@app.route('/check_username', methods=['POST']) +def check_username(): + username = request.json.get('username', '') + user = User.query.filter_by(username=username).first() + return jsonify({ + 'exists': user is not None, + 'message': '该用户名已被使用' if user else 'ok' + }) + + +@app.route('/check_email', methods=['POST']) +def check_email(): + email = request.json.get('email', '') + user = User.query.filter_by(email=email).first() + return jsonify({ + 'exists': user is not None, + 'message': '该邮箱已被注册' if user else 'ok' + }) + + +@app.route('/resend_verification', methods=['GET']) +def resend_verification(): + email = session.get('verification_email') + if not email: + return jsonify({ + 'success': False, + 'message': '请重新注册' + }) + + user = User.query.filter_by(email=email).first() + if not user: + return jsonify({ + 'success': False, + 'message': '用户不存在' + }) + + if user.email_confirmed: + return jsonify({ + 'success': False, + 'message': '邮箱已经验证过了' + }) + + # 发送新的验证码 + send_verification_email(email) + return jsonify({ + 'success': True, + 'message': '新的验证码已发送,请查收' + }) + + +@app.route('/send_verification', methods=['POST']) +def send_verification(): + if not request.is_json: + return jsonify({ + 'success': False, + 'message': '无效的请求' + }) + + email = request.json.get('email') + + if not email: + return jsonify({ + 'success': False, + 'message': '邮箱不能为空' + }) + + # Check if email is already registered + user = User.query.filter_by(email=email).first() + if user and user.email_confirmed: + return jsonify({ + 'success': False, + 'message': '该邮箱已注册且已验证' + }) + + # Send verification email + send_verification_email(email) + + return jsonify({ + 'success': True, + 'message': '验证码已发送' + }) + + +# 修改现有的登录路由以支持手机号登录 +# 修改现有的登录路由以支持手机号登录 +@app.route('/login', methods=['GET', 'POST']) +def login(): + if request.method == 'POST': + login_type = request.form.get('login_type', 'username') + next_page = request.form.get('next') + + # 常规用户名/邮箱登录 + if login_type in ['username', 'email']: + # 获取用户名/邮箱和密码 + if login_type == 'username': + credential = request.form.get('username') + password = request.form.get('password') + else: # email + credential = request.form.get('email') + password = request.form.get('email_password') + + # 查找用户 + user = None + if '@' in credential and login_type == 'email': + user = User.query.filter_by(email=credential).first() + elif login_type == 'username': + # 先尝试用户名查找 + user = User.query.filter_by(username=credential).first() + # 如果没找到,尝试手机号 + if user is None and credential.isdigit() and len(credential) == 11: + user = User.query.filter_by(phone=credential).first() + # 如果还没找到,尝试邮箱 + if user is None and '@' in credential: + user = User.query.filter_by(email=credential).first() + + # 验证密码 + if user and user.check_password(password): + # 检查账号状态 + if user.status != 'active': + flash('您的账号已被禁用或删除') + return redirect(url_for('login', next=next_page)) + + # 优先检查邮箱验证,其次检查手机号验证 + if not user.email_confirmed and not user.phone_confirmed: + flash('请先验证您的邮箱或手机号') + return redirect(url_for('login', next=next_page)) + + login_user(user) + + # 更新用户最后活跃时间 + user.update_last_seen() + db.session.commit() + + # 确保 next_page 是安全的 URL + if next_page and next_page.startswith('/'): + return redirect(next_page) + return redirect(url_for('index')) + + flash('账号或密码错误') + return redirect(url_for('login', next=next_page)) + + # 手机号验证码登录 + elif login_type == 'phone': + return redirect(url_for('login_with_phone_page')) + + # GET 请求,显示登录页面 + return render_template('pages/sign-in/index.html') + + +# 手机号登录页面 +@app.route('/login/phone', methods=['GET']) +def login_with_phone_page(): + next_page = request.args.get('next') + return render_template('pages/sign-in/sign-in-phone.html', next=next_page) + + +def send_notification_email(recipient, subject, template, **kwargs): + """ + 发送通知邮件 + :param recipient: 收件人邮箱 + :param subject: 邮件主题 + :param template: 模板文件名 + :param kwargs: 传递给模板的参数 + """ + try: + # 读取邮件模板内容 + msg = Message( + subject=subject, + sender=app.config['MAIL_USERNAME'], + recipients=[recipient] + ) + + # 渲染HTML邮件内容 + if template == 'emails/notification_post_liked.html': + msg.html = render_template_string(""" +
+

你的帖子收到了新的点赞

+
+

{{ liker.username }} 点赞了你的帖子

+

帖子内容: {{ post.content[:100] }}...

+
+ + 查看详情 + +

+ 如果你不想再收到此类通知,可以在个人设置中关闭邮件通知 +

+
+ """, **kwargs) + + elif template == 'emails/notification_post_commented.html': + msg.html = render_template_string(""" +
+

你的帖子收到了新的评论

+
+

{{ commenter.username }} 评论了你的帖子:

+

帖子内容: {{ post.content[:100] }}...

+
+ + 查看详情 + +

+ 如果你不想再收到此类通知,可以在个人设置中关闭邮件通知 +

+
+ """, **kwargs) + + elif template == 'emails/notification_comment_replied.html': + msg.html = render_template_string(""" +
+

你的评论收到了新的回复

+
+

{{ replier.username }} 回复了你的评论:

+

你的评论: {{ comment.content[:100] }}...

+
+ + 查看详情 + +

+ 如果你不想再收到此类通知,可以在个人设置中关闭邮件通知 +

+
+ """, **kwargs) + + # 使用异步任务发送邮件 + send_async_email(msg) + return True + + except Exception as e: + app.logger.error(f"Error sending notification email: {str(e)}") + return False + + +def send_async_email(msg): + """异步发送邮件""" + try: + mail.send(msg) + except Exception as e: + app.logger.error(f"Error sending async email: {str(e)}") + + +@app.route('/logout') +@token_required +def logout(): + logout_user() + return redirect(url_for('index')) + + +@app.context_processor +def inject_user(): + if has_request_context() and hasattr(request, 'user'): + return dict(current_user=request.user) + return dict(current_user=None) + + +@login_manager.user_loader +def load_user(user_id): + return User.query.get(int(user_id)) + + class Notification(db.Model): """通知模型""" id = db.Column(db.Integer, primary_key=True) @@ -570,6 +1093,61 @@ class Notification(db.Model): self.link = link +class CreatorApplication(db.Model): + """创作者申请模型""" + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id')) + status = db.Column(db.String(20), default='pending') # pending, approved, rejected + application_type = db.Column(db.String(20)) # analyst, strategist, trader, researcher + description = db.Column(db.Text) # 申请说明/个人介绍 + + # 资质材料 + qualifications = db.Column(db.JSON) # 资格证书、工作经验等 + sample_works = db.Column(db.JSON) # 作品案例链接 + + # 审核信息 + reviewer_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=True) # 审核人 + review_time = db.Column(db.DateTime, nullable=True) # 审核时间 + review_notes = db.Column(db.Text, nullable=True) # 审核备注 + + # 时间戳 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + # 关系 + user = db.relationship('User', foreign_keys=[user_id], backref='creator_applications') + reviewer = db.relationship('User', foreign_keys=[reviewer_id]) + + def __init__(self, user_id, status='pending', application_type=None, description=None): + self.user_id = user_id + self.status = status + self.application_type = application_type + self.description = description + + def approve(self, reviewer_id, notes=None): + """批准申请""" + self.status = 'approved' + self.reviewer_id = reviewer_id + self.review_time = beijing_now() + self.review_notes = notes + + # 更新用户为创作者 + user = User.query.get(self.user_id) + if user: + user.is_creator = True + user.creator_type = self.application_type + + def reject(self, reviewer_id, notes=None): + """拒绝申请""" + self.status = 'rejected' + self.reviewer_id = reviewer_id + self.review_time = beijing_now() + self.review_notes = notes + + def __repr__(self): + return f'' + + class Event(db.Model): """事件模型""" id = db.Column(db.Integer, primary_key=True) @@ -674,9 +1252,6 @@ class RelatedStock(db.Model): updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) correlation = db.Column(db.Float()) momentum = db.Column(db.String(1024)) # 动量 - # 新增字段 - retrieved_sources = db.Column(db.JSON) # 研报检索源数据 - retrieved_update_time = db.Column(db.DateTime) # 检索数据更新时间 class RelatedData(db.Model): @@ -878,101 +1453,149 @@ class SectorInfo(db.Model): F005V = db.Column(db.String(50)) # Sector level 2 name F006V = db.Column(db.String(50)) # Sector level 3 name F007V = db.Column(db.String(50)) # Sector level 4 name -def send_async_email(msg): - """异步发送邮件""" + + +# 发送验证邮件 +def send_verification_email(email): + """ + 发送验证邮件 + :param email: 收件人邮箱 + """ + # 生成6位随机验证码 + verification_code = ''.join(random.choices('0123456789', k=6)) + + # 设置验证码有效期为10分钟 + expiration = beijing_now().timestamp() + 600 + + # 存储到会话 + session['verification_code'] = verification_code + session['verification_email'] = email + session['verification_expiration'] = expiration + + # 邮件主题 + subject = '价值前沿 - 邮箱验证' + + # 邮件内容 + html = f''' +
+

邮箱验证

+

感谢您注册价值前沿!请使用以下验证码完成邮箱验证:

+
+ {verification_code} +
+

验证码有效期为10分钟,请尽快完成验证。

+

如果这不是您的操作,请忽略此邮件。

+

+ 此邮件由系统自动发送,请勿回复。 +

+
+ ''' + + # 创建邮件对象 + msg = Message( + subject=subject, + sender=app.config['MAIL_USERNAME'], + recipients=[email], + html=html + ) + + # 使用异步任务发送邮件 + send_async_email(msg) + + + + +# 保护需要登录的路由 + +@app.route('/') +def index(): + return render_template('presentation.html') + + +@app.route('/profile') +@token_required +def profile(): + # Check if the user has completed the wizard + # If any essential fields are empty, redirect to the wizard + if not request.user.nickname or not request.user.investment_style or request.user.preferred_markets == '[]': + return redirect(url_for('profile_wizard')) + + # Get user's recent posts (limit to 5) + posts = Post.query.filter_by(user_id=request.user.id, status='active') \ + .order_by(Post.created_at.desc()) \ + .limit(5).all() + + # Get user's followed events (limit to 5) + followed_events = EventFollow.query.filter_by(user_id=request.user.id) \ + .order_by(EventFollow.created_at.desc()) \ + .limit(5).all() + + return render_template('pages/account/profile.html', + user=request.user, + posts=posts, + followed_events=followed_events) + + +@app.route('/profile/wizard') +@token_required +def profile_wizard(): + return render_template('pages/account/profile_wizard.html', user=request.user) + + +@app.route('/profile/wizard/save', methods=['POST']) +@token_required +def save_profile_wizard(): try: - mail.send(msg) + data = request.json + + # Update user information + user = request.user + + # Basic info (Step 1) + if 'nickname' in data: + user.nickname = data.get('nickname') + if 'gender' in data: + user.gender = data.get('gender') + if 'bio' in data: + user.bio = data.get('bio') + + # Investment preference (Step 2) + if 'investment_style' in data: + user.investment_style = data.get('investment_style') + + # Investment details (Step 3) + if 'trading_experience' in data: + user.trading_experience = int(data.get('trading_experience')) if data.get('trading_experience') else None + if 'investment_amount' in data: + user.investment_amount = data.get('investment_amount') + if 'risk_preference' in data: + user.risk_preference = data.get('risk_preference') + if 'preferred_markets' in data: + user.preferred_markets = json.dumps(data.get('preferred_markets', [])) + + # Save changes + db.session.commit() + + return jsonify({'success': True, 'message': '个人资料已成功更新'}) except Exception as e: - app.logger.error(f"Error sending async email: {str(e)}") -def verify_sms_code(phone_number, code): - """验证短信验证码""" - stored_code = session.get('sms_verification_code') - stored_phone = session.get('sms_verification_phone') - expiration = session.get('sms_verification_expiration') + db.session.rollback() + app.logger.error(f"Error saving profile wizard data: {str(e)}") + return jsonify({'success': False, 'message': f'保存失败: {str(e)}'}) - if not all([stored_code, stored_phone, expiration]): - return False, "请先获取验证码" - if stored_phone != phone_number: - return False, "手机号与验证码不匹配" - - if beijing_now().timestamp() > expiration: - return False, "验证码已过期" - - if code != stored_code: - return False, "验证码错误" - - return True, "验证成功" +@app.route('/settings', methods=['GET']) +@token_required +def settings(): + return render_template('pages/account/settings.html', user=request.user) def allowed_file(filename): return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS -# ============================================ -# 订阅相关 API 接口(小程序专用) -# ============================================ -@app.route('/api/subscription/info', methods=['GET']) - -def get_subscription_info(): - """ - 获取当前用户的订阅信息 - 小程序专用接口 - 返回用户当前订阅类型、状态、剩余天数等信息 - """ - try: - info = _get_current_subscription_info() - return jsonify({ - 'success': True, - 'data': info - }) - except Exception as e: - print(f"获取订阅信息错误: {e}") - return jsonify({ - 'success': True, - 'data': { - 'type': 'free', - 'status': 'active', - 'is_active': True, - 'days_left': 0 - } - }) - - -@app.route('/api/subscription/check', methods=['GET']) - -def check_subscription_access(): - """ - 检查当前用户是否有权限使用小程序功能 - 返回:是否为 Pro/Max 用户 - """ - try: - has_access = _has_required_level('pro') - info = _get_current_subscription_info() - - return jsonify({ - 'success': True, - 'data': { - 'has_access': has_access, - 'subscription_type': info.get('type', 'free'), - 'is_active': info.get('is_active', False), - 'message': '您可以使用小程序功能' if has_access else '小程序功能仅对 Pro 和 Max 会员开放' - } - }) - except Exception as e: - print(f"检查订阅权限错误: {e}") - return jsonify({ - 'success': False, - 'error': str(e) - }), 500 - - -# ============================================ -# 现有接口示例(应用权限控制) -# ============================================ - # 更新视图函数 @app.route('/settings/profile', methods=['POST']) - +@token_required def update_profile(): """更新个人资料""" try: @@ -1016,10 +1639,76 @@ def update_profile(): return jsonify({'success': False, 'message': '更新失败,请重试'}) +@app.route('/settings/verify_identity', methods=['POST']) +@token_required +def verify_identity(): + """处理实名认证""" + try: + user = request.user + + # 如果已经实名认证过,返回错误 + if user.is_verified: + return jsonify({'success': False, 'message': '您已完成实名认证'}) + + # 获取表单数据 + real_name = request.form.get('real_name') + id_number = request.form.get('id_number') + + # 简单的身份证号验证(实际项目中应该使用更复杂的验证方式) + if not (real_name and id_number and len(id_number) == 18): + return jsonify({'success': False, 'message': '请输入有效的身份信息'}) + + # 更新用户信息 + user.real_name = real_name + # 存储加密后的身份证号(实际项目中应使用更安全的加密方式) + user.id_number = hashlib.sha256(id_number.encode()).hexdigest() + user.is_verified = True + user.verify_time = beijing_now() + + db.session.commit() + return jsonify({'success': True, 'message': '实名认证成功'}) + + except Exception as e: + db.session.rollback() + app.logger.error(f"Error in identity verification: {str(e)}") + return jsonify({'success': False, 'message': '认证失败,请重试'}) + + +@app.route('/settings/password', methods=['POST']) +@token_required +def update_password(): + """修改密码""" + try: + current_password = request.form.get('current_password') + new_password = request.form.get('new_password') + confirm_password = request.form.get('confirm_password') + + # 验证当前密码 + if not request.user.check_password(current_password): + return jsonify({'success': False, 'message': '当前密码错误'}) + + # 验证新密码 + if new_password != confirm_password: + return jsonify({'success': False, 'message': '两次输入的密码不一致'}) + + if len(new_password) < 8: + return jsonify({'success': False, 'message': '密码长度至少为8位'}) + + # 更新密码 + request.user.set_password(new_password) + db.session.commit() + + return jsonify({'success': True, 'message': '密码已更新'}) + + except Exception as e: + db.session.rollback() + app.logger.error(f"Error updating password: {str(e)}") + return jsonify({'success': False, 'message': '密码更新失败,请重试'}) + # 投资偏好设置 @app.route('/settings/investment_preferences', methods=['POST']) - +@token_required def update_investment_preferences(): """更新投资偏好""" try: @@ -1041,10 +1730,383 @@ def update_investment_preferences(): return jsonify({'success': False, 'message': '更新失败,请重试'}) +@app.route('/settings/email', methods=['POST']) +@token_required +def update_email(): + """修改邮箱""" + try: + new_email = request.form.get('new_email') + if User.query.filter_by(email=new_email).first(): + flash('该邮箱已被使用', 'error') + return redirect(url_for('settings')) + + # 发送验证邮件 + send_verification_email(new_email, purpose='change_email') + session['new_email'] = new_email + + flash('验证邮件已发送到新邮箱,请查收', 'success') + return redirect(url_for('verify_new_email')) + + except Exception as e: + app.logger.error(f"Error updating email: {str(e)}") + flash('操作失败,请重试', 'error') + + return redirect(url_for('settings')) + + +@app.route('/settings/notifications', methods=['POST']) +@token_required +def update_notifications(): + """更新通知设置""" + try: + user = request.user + + # 更新通知方式 + user.email_notifications = 'email_notifications' in request.form + user.sms_notifications = 'sms_notifications' in request.form + user.wechat_notifications = 'wechat_notifications' in request.form + + # 更新通知类型偏好 + notification_preferences = { + 'notify_comments': 'notify_comments' in request.form, + 'notify_likes': 'notify_likes' in request.form, + 'notify_follows': 'notify_follows' in request.form, + 'notify_system': 'notify_system' in request.form, + 'notify_news': 'notify_news' in request.form + } + user.notification_preferences = json.dumps(notification_preferences) + + db.session.commit() + return jsonify({'success': True, 'message': '通知设置已更新'}) + + except Exception as e: + db.session.rollback() + app.logger.error(f"Error updating notifications: {str(e)}") + return jsonify({'success': False, 'message': '设置更新失败,请重试'}) + + +@app.route('/settings/privacy', methods=['POST']) +@token_required +def update_privacy(): + """更新隐私设置""" + try: + user = request.user + + user.privacy_level = request.form.get('privacy_level', 'public') + user.blocked_keywords = json.dumps( + [kw.strip() for kw in request.form.get('blocked_keywords', '').split('\n') if kw.strip()] + ) + + db.session.commit() + return jsonify({'success': True, 'message': '隐私设置已更新'}) + + except Exception as e: + db.session.rollback() + app.logger.error(f"Error updating privacy settings: {str(e)}") + return jsonify({'success': False, 'message': '设置更新失败,请重试'}) + + +@app.route('/settings/creator', methods=['POST']) +@token_required +def update_creator_settings(): + """更新创作者设置""" + try: + user = request.user + + if not user.is_creator: + return jsonify({'success': False, 'message': '您还不是创作者'}) + + user.creator_type = request.form.get('creator_type') + user.creator_tags = json.dumps(request.form.getlist('creator_tags')) + + db.session.commit() + return jsonify({'success': True, 'message': '创作者设置已更新'}) + + except Exception as e: + db.session.rollback() + app.logger.error(f"Error updating creator settings: {str(e)}") + return jsonify({'success': False, 'message': '设置更新失败,请重试'}) + + +@app.route('/settings/apply_creator', methods=['POST']) +@token_required +def apply_creator(): + """申请成为创作者""" + try: + user = request.user + + # 检查是否已经是创作者 + if user.is_creator: + return jsonify({'success': False, 'message': '您已经是创作者了'}) + + # 检查是否有待审核的申请 + existing_application = CreatorApplication.query.filter_by( + user_id=user.id, + status='pending' + ).first() + + if existing_application: + return jsonify({'success': False, 'message': '您已有一个正在审核的申请'}) + + # 基本条件检查 + if not user.is_verified: + return jsonify({'success': False, 'message': '请先完成实名认证'}) + + if user.post_count < 5: + return jsonify({'success': False, 'message': '需要至少发布5篇文章才能申请'}) + + # 获取申请信息 + application_type = request.form.get('creator_type') + description = request.form.get('description') + + # 验证必填信息 + if not all([application_type, description]): + return jsonify({'success': False, 'message': '请填写完整的申请信息'}) + + # 处理资质材料 + qualifications = [] + if 'qualifications' in request.files: + files = request.files.getlist('qualifications') + for file in files: + if file and allowed_file(file.filename): + filename = secure_filename(f"qual_{user.id}_{int(time.time())}_{file.filename}") + filepath = os.path.join(app.config['UPLOAD_FOLDER'], 'qualifications', filename) + os.makedirs(os.path.dirname(filepath), exist_ok=True) + file.save(filepath) + qualifications.append({ + 'name': file.filename, + 'path': f'/static/uploads/qualifications/{filename}' + }) + + # 创建申请记录 + application = CreatorApplication( + user_id=user.id, + status='pending', + application_type=application_type, + description=description + ) + + if qualifications: + application.qualifications = qualifications + + # 处理作品案例链接 + sample_works = request.form.get('sample_works') + if sample_works: + application.sample_works = json.loads(sample_works) + + db.session.add(application) + db.session.commit() + + # 发送通知给管理员 + notify_admins_new_application(application.id) + + # 发送确认邮件给申请者 + send_application_confirmation_email(user.email, application_type) + + return jsonify({ + 'success': True, + 'message': '申请已提交,我们会在3个工作日内审核并通知您结果' + }) + + except Exception as e: + db.session.rollback() + app.logger.error(f"Error in creator application: {str(e)}") + return jsonify({'success': False, 'message': '申请提交失败,请重试'}) + + +def notify_admins_new_application(application_id): + """通知管理员有新的创作者申请""" + admin_users = User.query.filter_by(is_admin=True).all() + for admin in admin_users: + # 发送站内通知 + notification = Notification( + user_id=admin.id, + type='new_creator_application', + content=f'有新的创作者申请需要审核 (ID: {application_id})', + link=f'/admin/creator_applications/{application_id}' + ) + db.session.add(notification) + + # 如果管理员开启了邮件通知 + if admin.email_notifications: + send_admin_notification_email( + admin.email, + '新的创作者申请待审核', + f'有新的创作者申请需要审核,请登录管理后台查看。\n申请ID: {application_id}' + ) + + db.session.commit() + + +def send_application_confirmation_email(email, application_type): + """发送申请确认邮件""" + subject = '创作者申请确认' + html_content = render_template( + 'emails/creator_application_confirmation.html', + application_type=application_type + ) + + msg = Message( + subject, + recipients=[email], + html=html_content + ) + mail.send(msg) + + +def send_admin_notification_email(email, subject, content): + """发送管理员通知邮件""" + try: + msg = Message( + subject, + sender=app.config['MAIL_USERNAME'], + recipients=[email] + ) + + msg.html = render_template( + 'emails/admin_notification.html', + subject=subject, + content=content + ) + + mail.send(msg) + return True + except Exception as e: + app.logger.error(f"Error sending admin notification email: {str(e)}") + return False + +@app.route('/api/stock/quotes', methods=['POST']) +def get_stock_quotes(): + codes = request.json.get('codes', []) + event_time = datetime.fromisoformat(request.json.get('event_time')) + current_time = datetime.now() + client = get_clickhouse_client() + + # Get stock names from MySQL + stock_names = {} + with engine.connect() as conn: + for code in codes: + codez = code.split('.')[0] + result = conn.execute(text( + "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" + ), {"code": codez}).fetchone() + if result: + stock_names[code] = result[0] + + def get_trading_day_and_times(event_datetime): + event_date = event_datetime.date() + event_time = event_datetime.time() + + # Trading hours + market_open = dt_time(9, 30) + market_close = dt_time(15, 0) + + with engine.connect() as conn: + # First check if the event date itself is a trading day + is_trading_day = conn.execute(text(""" + SELECT 1 FROM trading_days + WHERE EXCHANGE_DATE = :date + """), {"date": event_date}).fetchone() is not None + + if is_trading_day: + # If it's a trading day, determine time period based on event time + if event_time < market_open: + # Before market opens - use full trading day + return event_date, market_open, market_close + elif event_time > market_close: + # After market closes - get next trading day + next_trading_day = conn.execute(text(""" + SELECT EXCHANGE_DATE FROM trading_days + WHERE EXCHANGE_DATE > :date + ORDER BY EXCHANGE_DATE LIMIT 1 + """), {"date": event_date}).fetchone() + # Convert to date object if we found a next trading day + return (next_trading_day[0].date() if next_trading_day else None, + market_open, market_close) + else: + # During trading hours + return event_date, event_time, market_close + else: + # If not a trading day, get next trading day + next_trading_day = conn.execute(text(""" + SELECT EXCHANGE_DATE FROM trading_days + WHERE EXCHANGE_DATE > :date + ORDER BY EXCHANGE_DATE LIMIT 1 + """), {"date": event_date}).fetchone() + # Convert to date object if we found a next trading day + return (next_trading_day[0].date() if next_trading_day else None, + market_open, market_close) + + trading_day, start_time, end_time = get_trading_day_and_times(event_time) + + if not trading_day: + return jsonify({code: {'name': name, 'price': None, 'change': None} + for code, name in stock_names.items()}) + + # For historical dates, ensure we're using actual data + start_datetime = datetime.combine(trading_day, start_time) + end_datetime = datetime.combine(trading_day, end_time) + + # If the trading day is in the future relative to current time, + # return only names without data + if trading_day > current_time.date(): + return jsonify({code: {'name': name, 'price': None, 'change': None} + for code, name in stock_names.items()}) + + results = {} + for code in codes: + # Get the first price and last price for the trading period + data = client.execute(""" + WITH first_price AS ( + SELECT close + FROM stock_minute + WHERE code = %(code)s + AND timestamp >= %(start)s + AND timestamp <= %(end)s + ORDER BY timestamp + LIMIT 1 + ), + last_price AS ( + SELECT close + FROM stock_minute + WHERE code = %(code)s + AND timestamp >= %(start)s + AND timestamp <= %(end)s + ORDER BY timestamp DESC + LIMIT 1 + ) + SELECT + last_price.close as last_price, + (last_price.close - first_price.close) / first_price.close * 100 as change + FROM last_price + CROSS JOIN first_price + WHERE EXISTS (SELECT 1 FROM first_price) AND EXISTS (SELECT 1 FROM last_price) + """, { + 'code': code, + 'start': start_datetime, + 'end': end_datetime + }) + + if data and data[0]: + results[code] = { + 'price': data[0][0], + 'change': data[0][1], + 'name': stock_names.get(code, 'Unknown') + } + else: + results[code] = { + 'price': None, + 'change': None, + 'name': stock_names.get(code, 'Unknown') + } + + return jsonify(results) + + def get_clickhouse_client(): return Cclient( - host='222.128.1.157', - port=18000, + host='111.198.58.126', + port=18778, user='default', password='Zzl33818!', database='stock' @@ -1053,7 +2115,7 @@ def get_clickhouse_client(): @app.route('/api/stock//kline') def get_stock_kline(stock_code): - """获取股票K线数据 - 仅限 Pro/Max 会员(小程序功能)""" + """获取股票K线数据""" chart_type = request.args.get('chart_type', 'daily') # 默认改为daily event_time = request.args.get('event_time') @@ -1114,7 +2176,7 @@ def get_daily_kline(stock_code, event_datetime, stock_name): FROM ea_trade t JOIN date_range d ON t.TRADEDATE = d.TRADEDATE WHERE t.SECCODE = :stock_code - ORDER BY t.TRADEDATE + ORDER BY t.TRADEDATE DESC """ result = conn.execute(text(kline_sql), { @@ -1142,7 +2204,7 @@ def get_daily_kline(stock_code, event_datetime, stock_name): AND F005N IS NOT NULL AND F006N IS NOT NULL AND F004N IS NOT NULL - ORDER BY TRADEDATE + ORDER BY TRADEDATE DESC LIMIT 100 """ @@ -1203,9 +2265,8 @@ def get_daily_kline(stock_code, event_datetime, stock_name): def get_minute_kline(stock_code, event_datetime, stock_name): - """处理分钟K线数据 - 包含零轴(昨日收盘价)""" + """处理分钟K线数据""" client = get_clickhouse_client() - stock_code_short = stock_code.split('.')[0] # 获取不带后缀的股票代码 def get_trading_days(): trading_days = set() @@ -1233,59 +2294,10 @@ def get_minute_kline(stock_code, event_datetime, stock_name): return current_date return None - def get_prev_close(stock_code_short, target_date): - """获取前一交易日的收盘价作为零轴基准""" - prev_date = find_prev_trading_day(target_date) - if not prev_date: - return None - - try: - with engine.connect() as conn: - # 查询前一交易日的收盘价 - sql = """ - SELECT CAST(F007N AS FLOAT) as close - FROM ea_trade - WHERE SECCODE = :stock_code - AND TRADEDATE = :prev_date - AND F007N IS NOT NULL - LIMIT 1 - """ - result = conn.execute(text(sql), { - "stock_code": stock_code_short, - "prev_date": prev_date - }).fetchone() - - if result: - return float(result.close) - else: - # 如果指定日期没有数据,尝试获取最近的收盘价 - fallback_sql = """ - SELECT CAST(F007N AS FLOAT) as close, TRADEDATE - FROM ea_trade - WHERE SECCODE = :stock_code - AND TRADEDATE < :target_date - AND F007N IS NOT NULL - ORDER BY TRADEDATE DESC - LIMIT 1 - """ - result = conn.execute(text(fallback_sql), { - "stock_code": stock_code_short, - "target_date": target_date - }).fetchone() - - if result: - print(f"Using close price from {result.TRADEDATE} as zero axis") - return float(result.close) - - except Exception as e: - print(f"Error getting previous close: {e}") - - return None - target_date = event_datetime.date() is_after_market = event_datetime.time() > dt_time(15, 0) - # 核心逻辑:先判断当前日期是否是交易日,以及是否已收盘 + # 核心逻辑改动:先判断当前日期是否是交易日,以及是否已收盘 if target_date in trading_days and is_after_market: # 如果是交易日且已收盘,查找下一个交易日 next_trade_date = find_next_trading_day(target_date) @@ -1310,9 +2322,6 @@ def get_minute_kline(stock_code, event_datetime, stock_name): 'type': 'minute' }) - # 获取前一交易日收盘价作为零轴 - zero_axis = get_prev_close(stock_code_short, target_date) - # 获取目标日期的完整交易时段数据 data = client.execute(""" SELECT @@ -1333,47 +2342,187 @@ def get_minute_kline(stock_code, event_datetime, stock_name): 'end': datetime.combine(target_date, dt_time(15, 0)) }) - kline_data = [] - for row in data: - point = { - 'time': row[0].strftime('%H:%M'), - 'open': float(row[1]), - 'high': float(row[2]), - 'low': float(row[3]), - 'close': float(row[4]), - 'volume': float(row[5]), - 'amount': float(row[6]) - } + kline_data = [{ + 'time': row[0].strftime('%H:%M'), + 'open': float(row[1]), + 'high': float(row[2]), + 'low': float(row[3]), + 'close': float(row[4]), + 'volume': float(row[5]), + 'amount': float(row[6]) + } for row in data] - # 如果有零轴数据,计算涨跌幅和涨跌额 - if zero_axis: - point['prev_close'] = zero_axis - point['change'] = point['close'] - zero_axis # 涨跌额 - point['change_pct'] = ((point['close'] - zero_axis) / zero_axis * 100) if zero_axis != 0 else 0 # 涨跌幅百分比 - - kline_data.append(point) - - response_data = { + return jsonify({ 'code': stock_code, 'name': stock_name, 'data': kline_data, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'minute', 'is_history': target_date < event_datetime.date() - } + }) - # 添加零轴信息到响应中 - if zero_axis: - response_data['zero_axis'] = zero_axis - response_data['prev_close'] = zero_axis - # 计算当日整体涨跌幅(如果有数据) - if kline_data: - last_close = kline_data[-1]['close'] - response_data['day_change'] = last_close - zero_axis - response_data['day_change_pct'] = ((last_close - zero_axis) / zero_axis * 100) if zero_axis != 0 else 0 +@app.route('/api/related-stock/add', methods=['POST']) +@login_required +def add_related_stock(): + data = request.json + event_id = data.get('event_id') + stock_code = data.get('stock_code') + relation_desc = data.get('relation_desc') + + # 验证股票是否存在 + client = get_clickhouse_client() + stock_exists = client.execute( + "SELECT 1 FROM stock_minute WHERE code = %(code)s LIMIT 1", + {'code': stock_code} + ) + + if not stock_exists: + return jsonify({ + 'success': False, + 'message': '股票代码不存在' + }) + + try: + related_stock = RelatedStock( + event_id=event_id, + stock_code=stock_code, + relation_desc=relation_desc + ) + db.session.add(related_stock) + db.session.commit() + + return jsonify({'success': True}) + except Exception as e: + db.session.rollback() + return jsonify({ + 'success': False, + 'message': str(e) + }) + + +@app.route('/api/related-stock/', methods=['DELETE']) +@login_required +def delete_related_stock(stock_id): + try: + stock = RelatedStock.query.get_or_404(stock_id) + db.session.delete(stock) + db.session.commit() + return jsonify({'success': True}) + except Exception as e: + db.session.rollback() + return jsonify({ + 'success': False, + 'message': str(e) + }) + + +# 事件相关路由 +@app.route('/event/create', methods=['GET', 'POST']) +@token_required +def create_event(): + """创建新事件""" + if request.method == 'POST': + try: + app.logger.info("Received event creation request") + + # 获取表单数据 + title = request.form.get('title') + description = request.form.get('description') + event_type = request.form.get('event_type') + + if not all([title, description, event_type]): + return jsonify({ + 'success': False, + 'message': '请填写所有必填字段' + }), 400 + + # 创建新事件 + event = Event( + title=title, + description=description, + event_type=event_type, + creator_id=request.user.id, + status='active' + ) + + # 处理可选字段 + if request.form.get('is_top'): + event.is_top = True + + if request.form.getlist('keywords'): + keywords = request.form.getlist('keywords') + # 确保关键词是UTF-8编码 + keywords = [keyword.encode('utf-8').decode('utf-8') for keyword in keywords] + event.keywords = json.dumps(keywords, ensure_ascii=False) + + if request.form.getlist('related_stocks'): + event.related_stocks = json.dumps(request.form.getlist('related_stocks')) + + if request.form.getlist('related_industries'): + event.related_industries = json.dumps(request.form.getlist('related_industries')) + + db.session.add(event) + db.session.commit() + + app.logger.info(f"Event created successfully with ID: {event.id}") + + return jsonify({ + 'success': True, + 'event_id': event.id, + 'message': '话题创建成功!' + }) + + except Exception as e: + db.session.rollback() + app.logger.error(f"Error creating event: {str(e)}") + return jsonify({ + 'success': False, + 'message': f'创建失败:{str(e)}' + }), 500 + + return render_template('projects/create_event.html') + + +@app.route('/file-upload', methods=['POST']) +@token_required +def upload_file(): + """处理文件上传""" + try: + if 'file' not in request.files: + return jsonify({'error': 'No file part'}), 400 + + file = request.files['file'] + if file.filename == '': + return jsonify({'error': 'No selected file'}), 400 + + if file and allowed_file(file.filename): + filename = secure_filename(file.filename) + # 生成带时间戳的文件名以避免冲突 + unique_filename = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{filename}" + filepath = os.path.join(app.config['UPLOAD_FOLDER'], 'events', unique_filename) + + # 确保上传目录存在 + os.makedirs(os.path.dirname(filepath), exist_ok=True) + + file.save(filepath) + return jsonify({ + 'success': True, + 'file': unique_filename, + 'url': url_for('static', filename=f'uploads/events/{unique_filename}') + }) + + return jsonify({'error': 'File type not allowed'}), 400 + + except Exception as e: + app.logger.error(f"Error uploading file: {str(e)}") + return jsonify({'error': 'Upload failed'}), 500 + + +@app.route('/data/concepts/') +def concept_images(filename): + return send_from_directory(os.path.join(app.root_path, 'data', 'concepts'), filename) - return jsonify(response_data) class HistoricalEvent(db.Model): """历史事件模型""" @@ -1410,9 +2559,40 @@ class HistoricalEventStock(db.Model): ) +@app.route('/api/historical-event//stocks') +def get_historical_event_stocks(event_id): + """获取历史事件的相关股票""" + historical_event = HistoricalEvent.query.get_or_404(event_id) + stocks = historical_event.stocks.all() + + return jsonify({ + 'success': True, + 'stocks': [{ + 'id': stock.id, + 'stock_code': stock.stock_code, + 'stock_name': stock.stock_name, + 'relation_desc': stock.relation_desc, + 'correlation': stock.correlation, + 'sector': stock.sector + } for stock in stocks] + }) + + +@app.route('/api/related-data/') +def get_related_data_details(data_id): + """获取关联数据详情""" + data = RelatedData.query.get_or_404(data_id) + return jsonify({ + 'id': data.id, + 'title': data.title, + 'data_type': data.data_type, + 'data_content': data.data_content, + 'description': data.description + }) + @app.route('/event/follow/', methods=['POST']) - +@token_required def follow_event(event_id): """关注/取消关注事件""" event = Event.query.get_or_404(event_id) @@ -1442,7 +2622,7 @@ def follow_event(event_id): # 帖子相关路由 @app.route('/post/create/', methods=['GET', 'POST']) - +@token_required def create_post(event_id): """创建新帖子""" event = Event.query.get_or_404(event_id) @@ -1495,7 +2675,7 @@ def create_post(event_id): # 点赞相关路由 @app.route('/post/like/', methods=['POST']) - +@token_required def like_post(post_id): """点赞/取消点赞帖子""" post = Post.query.get_or_404(post_id) @@ -1533,6 +2713,117 @@ def like_post(post_id): return jsonify({'success': False, 'message': '操作失败,请重试'}) +# 通知相关函数 +def notify_user_post_liked(post): + """当用户的帖子被点赞时发送通知""" + try: + notification = Notification( + user_id=post.user_id, + type='post_like', + content=f'{request.user.username} 点赞了你的帖子', + link=url_for('event_detail', event_id=post.event_id, _anchor=f'post-{post.id}'), + related_user_id=request.user.id, + related_post_id=post.id + ) + db.session.add(notification) + + # 如果用户开启了邮件通知 + user = User.query.get(post.user_id) + if user.email_notifications: + send_notification_email( + recipient=user.email, + subject='你的帖子收到了新的点赞', + template='emails/notification_post_liked.html', + user=user, + post=post, + liker=request.user + ) + + except Exception as e: + app.logger.error(f"Error creating like notification: {str(e)}") + # 通知创建失败不应影响主要功能 + pass + + +def notify_user_post_commented(post): + """当用户的帖子收到评论时发送通知""" + try: + notification = Notification( + user_id=post.user_id, + type='post_comment', + content=f'{request.user.username} 评论了你的帖子', + link=url_for('event_detail', event_id=post.event_id, _anchor=f'post-{post.id}'), + related_user_id=request.user.id, + related_post_id=post.id + ) + db.session.add(notification) + + # 如果用户开启了邮件通知 + user = User.query.get(post.user_id) + if user.email_notifications: + send_notification_email( + recipient=user.email, + subject='你的帖子收到了新的评论', + template='emails/notification_post_commented.html', + user=user, + post=post, + commenter=request.user + ) + + except Exception as e: + app.logger.error(f"Error creating comment notification: {str(e)}") + pass + + +def notify_user_comment_replied(parent_comment): + """当用户的评论被回复时发送通知""" + try: + notification = Notification( + user_id=parent_comment.user_id, + type='comment_reply', + content=f'{request.user.username} 回复了你的评论', + link=url_for('event_detail', + event_id=parent_comment.post.event_id, + _anchor=f'comment-{parent_comment.id}'), + related_user_id=request.user.id, + related_post_id=parent_comment.post_id, + related_comment_id=parent_comment.id + ) + db.session.add(notification) + + # 如果用户开启了邮件通知 + user = User.query.get(parent_comment.user_id) + if user.email_notifications: + send_notification_email( + recipient=user.email, + subject='你的评论收到了新的回复', + template='emails/notification_comment_replied.html', + user=user, + comment=parent_comment, + replier=request.user + ) + + except Exception as e: + app.logger.error(f"Error creating reply notification: {str(e)}") + pass + + +def cleanup_old_notifications(): + """清理30天前的已读通知""" + with app.app_context(): + try: + thirty_days_ago = beijing_now() - timedelta(days=30) + Notification.query.filter( + Notification.created_at < thirty_days_ago, + Notification.is_read == True + ).delete() + db.session.commit() + current_app.logger.info("Successfully cleaned up old notifications") + except Exception as e: + db.session.rollback() + current_app.logger.error(f"Error cleaning up notifications: {str(e)}") + + def update_user_activity(): """更新用户活跃度""" with app.app_context(): @@ -1574,7 +2865,7 @@ def update_user_activity(): @app.route('/post/comment/', methods=['POST']) - +@token_required def add_comment(post_id): """添加评论""" post = Post.query.get_or_404(post_id) @@ -1626,7 +2917,7 @@ def add_comment(post_id): @app.route('/post/comments/') - +@token_required def get_comments(post_id): """获取帖子评论列表""" page = request.args.get('page', 1, type=int) @@ -1795,7 +3086,11 @@ def api_sector_hierarchy(): try: # 定义需要返回的行业分类体系 classification_systems = [ - '申银万国行业分类' + '申银万国行业分类', + '巨潮行业分类', + '新财富行业分类', + '证监会行业分类(2001)', + '中上协行业分类' ] result = [] # 改为数组 @@ -1924,6 +3219,135 @@ def api_sector_hierarchy(): }), 500 +@app.route('/api/sector/hierarchy/simple', methods=['GET']) +def api_sector_hierarchy_simple(): + """简化版行业层级关系接口:只展示到三级分类""" + try: + # 查询所有申银万国行业分类数据 + sectors = SectorInfo.query.filter_by(F002V='申银万国行业分类').all() + + # 构建简化层级结构 + hierarchy = {} + + for sector in sectors: + sw_primary = sector.F004V # 申万一级行业 + sw_secondary = sector.F005V # 申万二级行业 + + # 获取对应的主板块分类 + primary_sector = get_primary_sector_by_sw_primary(sw_primary) + + # 初始化主板块 + if primary_sector not in hierarchy: + hierarchy[primary_sector] = {} + + # 初始化申万一级行业 + if sw_primary not in hierarchy[primary_sector]: + hierarchy[primary_sector][sw_primary] = set() + + # 添加申万二级行业 + if sw_secondary: + hierarchy[primary_sector][sw_primary].add(sw_secondary) + + # 格式化输出结构 + result = [] + for primary_sector, sw_primaries in hierarchy.items(): + primary_data = { + 'primary_sector': primary_sector, + 'sw_primary_sectors': [] + } + + for sw_primary, sw_secondaries in sw_primaries.items(): + sw_primary_item = { + 'sw_primary_sector': sw_primary, + 'sw_secondary_sectors': sorted(list(sw_secondaries)) + } + primary_data['sw_primary_sectors'].append(sw_primary_item) + + result.append(primary_data) + + # 按主板块名称排序 + result.sort(key=lambda x: x['primary_sector']) + + return jsonify({ + "code": 200, + "message": "success", + "data": result + }) + + except Exception as e: + return jsonify({ + "code": 500, + "message": str(e), + "data": None + }), 500 + + +@app.route('/api/sector/mapping', methods=['GET']) +def api_sector_mapping(): + """行业映射关系接口:展示primary_sector到申万一级行业的映射关系""" + try: + # 使用现有的映射关系 + sector_map = { + # 大周期 + '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期', + '钢铁': '大周期', '基础化工': '大周期', '建筑材料': '大周期', + '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期', + '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期', + '太阳能发电': '大周期', '建筑装饰': '大周期', '建筑': '大周期', + '交通运输': '大周期', '采掘': '大周期', '公用事业': '大周期', + + # 大消费 + '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费', + '食品饮料': '大消费', '医药生物': '大消费', '纺织服饰': '大消费', + '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费', + '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费', + '纺织服装': '大消费', '商业贸易': '大消费', '休闲服务': '大消费', + + # 大金融地产 + '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产', + '多元金融': '大金融地产', '综合金融': '大金融地产', + '房地产': '大金融地产', '非银金融': '大金融地产', + + # TMT板块 + '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块', + + # 公共产业 + '环保': '公共产业板块', '综合': '公共产业板块' + } + + # 重组为 primary_sector -> [sw_primary_sectors] + result = {} + for sw_primary, primary_sector in sector_map.items(): + if primary_sector not in result: + result[primary_sector] = [] + result[primary_sector].append(sw_primary) + + # 格式化输出 + formatted_result = [ + { + 'primary_sector': primary, + 'sw_primary_sectors': sorted(sw_primaries) + } + for primary, sw_primaries in result.items() + ] + + # 按主板块名称排序 + formatted_result.sort(key=lambda x: x['primary_sector']) + + return jsonify({ + "code": 200, + "message": "success", + "data": formatted_result + }) + + except Exception as e: + return jsonify({ + "code": 500, + "message": str(e), + "data": None + }), 500 + + @app.route('/api/sector/banner', methods=['GET']) def api_sector_banner(): """行业分类 banner 接口:返回一级分类和对应二级行业列表""" @@ -1973,6 +3397,578 @@ def api_sector_banner(): }), 500 +def get_primary_sector_by_sw_primary(sw_primary): + """根据申万一级行业获取主板块分类""" + sector_map = { + # 大周期 + '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期', + '钢铁': '大周期', '基础化工': '大周期', '建筑材料': '大周期', + '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期', + '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期', + '太阳能发电': '大周期', '建筑装饰': '大周期', '建筑': '大周期', + '交通运输': '大周期', '采掘': '大周期', '公用事业': '大周期', + + # 大消费 + '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费', + '食品饮料': '大消费', '医药生物': '大消费', '纺织服饰': '大消费', + '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费', + '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费', + '纺织服装': '大消费', '商业贸易': '大消费', '休闲服务': '大消费', + + # 大金融地产 + '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产', + '多元金融': '大金融地产', '综合金融': '大金融地产', + '房地产': '大金融地产', '非银金融': '大金融地产', + + # TMT板块 (重点:电子 → TMT板块) + '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块', + + # 公共产业 + '环保': '公共产业板块', '综合': '公共产业板块' + } + return sector_map.get(sw_primary, '其他') + + +@app.route('/api/stock//primary') +@token_required +def get_stock_primary_sector(stock_id): + """获取股票申万一级行业及对应主板块分类""" + try: + # 处理股票代码,移除后缀如 .SH/.SZ + base_stock_code = stock_id.split('.')[0] + + # 查找该股票的行业分类信息(以申银万国行业为标准) + sector_info = SectorInfo.query.filter( + SectorInfo.SECCODE.ilike(f"{base_stock_code}%"), + SectorInfo.F002V == '申银万国行业分类' + ).first() + + if not sector_info: + return jsonify({ + 'code': 404, + 'message': f'未找到股票 {stock_id} 的行业分类信息', + 'data': None + }), 404 + + sw_primary_sector = sector_info.F004V # 申万一级行业 + primary_sector = get_primary_sector_by_sw_primary(sw_primary_sector) + + result = { + 'stock_code': sector_info.SECCODE, + 'stock_name': sector_info.SECNAME, + 'sw_primary_sector': sw_primary_sector, # 申万一级行业 + 'primary_sector': primary_sector # 对应的主板块分类 + } + + return jsonify({ + 'code': 200, + 'message': 'success', + 'data': result + }) + + except Exception as e: + print(f"Error in get_stock_primary_sector: {str(e)}") + return jsonify({ + 'code': 500, + 'message': str(e), + 'data': None + }), 500 + + +@app.route('/api/stock//secondary') +@token_required +def get_stock_secondary_sector(stock_id): + """获取股票申万二级行业信息""" + try: + base_stock_code = stock_id.split('.')[0] + + sector_info = SectorInfo.query.filter( + SectorInfo.SECCODE.ilike(f"{base_stock_code}%"), + SectorInfo.F002V == '申银万国行业分类' + ).first() + + if not sector_info: + return jsonify({ + 'code': 404, + 'message': f'未找到股票 {stock_id} 的行业分类信息', + 'data': None + }), 404 + + sw_primary_sector = sector_info.F004V # 申万一级行业 + sw_secondary_sector = sector_info.F005V # 申万二级行业 + primary_sector = get_primary_sector_by_sw_primary(sw_primary_sector) + + result = { + 'stock_code': sector_info.SECCODE, + 'stock_name': sector_info.SECNAME, + 'sw_primary_sector': sw_primary_sector, + 'sw_secondary_sector': sw_secondary_sector, + 'primary_sector': primary_sector # 基于申万一级行业映射 + } + + return jsonify({ + 'code': 200, + 'message': 'success', + 'data': result + }) + + except Exception as e: + print(f"Error in get_stock_secondary_sector: {str(e)}") + return jsonify({ + 'code': 500, + 'message': str(e), + 'data': None + }), 500 + + +@app.route('/api/stock//third') +@token_required +def get_stock_third_sector(stock_id): + """获取股票申万三级行业信息 - 对应F005V""" + try: + base_stock_code = stock_id.split('.')[0] + + sector_info = SectorInfo.query.filter( + SectorInfo.SECCODE.ilike(f"{base_stock_code}%"), + SectorInfo.F002V == '申银万国行业分类' + ).first() + + if not sector_info: + return jsonify({ + 'code': 404, + 'message': f'未找到股票 {stock_id} 的行业分类信息', + 'data': None + }), 404 + + sw_primary_sector = sector_info.F004V # 申万一级行业 + sw_secondary_sector = sector_info.F005V # 申万二级行业 + sw_third_sector = sector_info.F005V # 申万三级行业 (根据你的说明对应F005V) + primary_sector = get_primary_sector_by_sw_primary(sw_primary_sector) + + result = { + 'stock_code': sector_info.SECCODE, + 'stock_name': sector_info.SECNAME, + 'sw_primary_sector': sw_primary_sector, + 'sw_secondary_sector': sw_secondary_sector, + 'sw_third_sector': sw_third_sector, # 对应F005V + 'primary_sector': primary_sector + } + + return jsonify({ + 'code': 200, + 'message': 'success', + 'data': result + }) + + except Exception as e: + print(f"Error in get_stock_third_sector: {str(e)}") + return jsonify({ + 'code': 500, + 'message': str(e), + 'data': None + }), 500 + + +@app.route('/api/stock//fourth') +@token_required +def get_stock_fourth_sector(stock_id): + """获取股票申万四级行业信息 - 对应F006V""" + try: + base_stock_code = stock_id.split('.')[0] + + sector_info = SectorInfo.query.filter( + SectorInfo.SECCODE.ilike(f"{base_stock_code}%"), + SectorInfo.F002V == '申银万国行业分类' + ).first() + + if not sector_info: + return jsonify({ + 'code': 404, + 'message': f'未找到股票 {stock_id} 的行业分类信息', + 'data': None + }), 404 + + sw_primary_sector = sector_info.F004V # 申万一级行业 + sw_secondary_sector = sector_info.F005V # 申万二级行业 + sw_third_sector = sector_info.F005V # 申万三级行业 (F005V) + sw_fourth_sector = sector_info.F006V # 申万四级行业 (F006V) + primary_sector = get_primary_sector_by_sw_primary(sw_primary_sector) + + result = { + 'stock_code': sector_info.SECCODE, + 'stock_name': sector_info.SECNAME, + 'sw_primary_sector': sw_primary_sector, + 'sw_secondary_sector': sw_secondary_sector, + 'sw_third_sector': sw_third_sector, + 'sw_fourth_sector': sw_fourth_sector, # 对应F006V + 'primary_sector': primary_sector + } + + return jsonify({ + 'code': 200, + 'message': 'success', + 'data': result + }) + + except Exception as e: + print(f"Error in get_stock_fourth_sector: {str(e)}") + return jsonify({ + 'code': 500, + 'message': str(e), + 'data': None + }), 500 + + +@app.route('/api/stock//fifth') +@token_required +def get_stock_fifth_sector(stock_id): + """获取股票申万五级行业信息 - 对应F007V""" + try: + base_stock_code = stock_id.split('.')[0] + + sector_info = SectorInfo.query.filter( + SectorInfo.SECCODE.ilike(f"{base_stock_code}%"), + SectorInfo.F002V == '申银万国行业分类' + ).first() + + if not sector_info: + return jsonify({ + 'code': 404, + 'message': f'未找到股票 {stock_id} 的行业分类信息', + 'data': None + }), 404 + + sw_primary_sector = sector_info.F004V # 申万一级行业 + sw_secondary_sector = sector_info.F005V # 申万二级行业 + sw_third_sector = sector_info.F005V # 申万三级行业 (F005V) + sw_fourth_sector = sector_info.F006V # 申万四级行业 (F006V) + sw_fifth_sector = sector_info.F007V # 申万五级行业 (F007V) - 新增字段 + primary_sector = get_primary_sector_by_sw_primary(sw_primary_sector) + + result = { + 'stock_code': sector_info.SECCODE, + 'stock_name': sector_info.SECNAME, + 'sw_primary_sector': sw_primary_sector, + 'sw_secondary_sector': sw_secondary_sector, + 'sw_third_sector': sw_third_sector, + 'sw_fourth_sector': sw_fourth_sector, + 'sw_fifth_sector': sw_fifth_sector, # 对应F007V + 'primary_sector': primary_sector + } + + return jsonify({ + 'code': 200, + 'message': 'success', + 'data': result + }) + + except Exception as e: + print(f"Error in get_stock_fifth_sector: {str(e)}") + return jsonify({ + 'code': 500, + 'message': str(e), + 'data': None + }), 500 + + +def get_calendar_events(): + """获取日历事件数据""" + date_str = request.args.get('date') + if not date_str: + return jsonify({'error': 'Date parameter is required'}), 400 + + + try: + # 解析日期 + target_date = datetime.strptime(date_str, '%Y-%m-%d') + end_date = target_date + timedelta(days=1) + + # 更新查询以包含related_stocks和concepts + query = """ + WITH RankedEvents AS ( + SELECT + data_id, + calendar_time, + type, + star, + title, + former, + forecast, + fact, + related_stocks, -- 添加相关股票 + concepts, -- 添加相关概念 + primary_sectors, + inferred_tag + ROW_NUMBER() OVER (PARTITION BY title ORDER BY star DESC) as rn + FROM future_events + WHERE calendar_time BETWEEN :start_date AND :end_date + ) + SELECT DISTINCT + data_id, + calendar_time, + type, + star, + title, + former, + forecast, + fact, + related_stocks, + concepts, + primary_sectors, + inferred_tag + FROM RankedEvents + WHERE rn = 1 + ORDER BY calendar_time + """ + + result = db.session.execute(text(query), { + 'start_date': target_date, + 'end_date': end_date + }) + + # 更新返回的事件数据格式化 + events = [] + for event in result: + events.append({ + 'calendar_time': event.calendar_time.isoformat() if event.calendar_time else None, + 'type': event.type, + 'star': event.star, + 'title': event.title, + 'former': event.former, + 'forecast': event.forecast, + 'fact': event.fact, + 'related_stocks': event.related_stocks, + 'concepts': event.concepts, + 'primary_sectors': event.primary_sectors, + 'inferred_tag': event.inferred_tag + }) + + # 提交事务 + db.session.commit() + + return jsonify(events) + + except Exception as e: + return jsonify({'error': str(e)}), 500 + +# New API endpoint for hot news +@app.route('/api/hot-news') +def hot_news(): + """Get the 4 hottest news items from the last 3 days based on average price increase""" + def format_events(events): + """Format event list to JSON-compatible data structure""" + try: + return [{ + 'id': event.id, + 'title': event.title, + 'description': event.description, + 'created_at': event.updated_at.strftime('%Y-%m-%d %H:%M:%S'), + 'importance': event.importance, + 'creator': { + 'username': event.creator.username if event.creator else 'Anonymous', + 'avatar_url': get_full_avatar_url(event.creator.avatar_url) if event.creator else None + } if event.creator else None, + 'related_avg_chg': event.related_avg_chg, + 'related_max_chg': event.related_max_chg, + 'related_week_chg': event.related_week_chg, + 'post_count': event.post_count, + 'follower_count': event.follower_count, + 'view_count': event.view_count + } for event in events] + except Exception as e: + logger.error(f"Error formatting events: {str(e)}", exc_info=True) + return [] + + try: + # Calculate date 3 days ago + three_days_ago = datetime.now() - timedelta(days=3) + + # Query events from last 3 days, sorted by average price increase + hot_events = Event.query.filter( + Event.status == 'active', + Event.created_at >= three_days_ago, + Event.related_avg_chg != None, # Ensure price data exists + Event.related_avg_chg > 0 # Only positive changes + ).order_by(Event.related_avg_chg.desc()).limit(4).all() + + # If not enough events with positive price increases, get additional popular events + if len(hot_events) < 4: + additional_events = Event.query.filter( + Event.status == 'active', + Event.created_at >= three_days_ago, + ~Event.id.in_([event.id for event in hot_events]) + ).order_by(Event.hot_score.desc()).limit(4 - len(hot_events)).all() + + hot_events.extend(additional_events) + + # Format response data + events_data = format_events(hot_events) + + return jsonify(events_data) + + except Exception as e: + logger.error(f"Error getting hot news: {str(e)}", exc_info=True) + return jsonify({'error': str(e)}), 500 + + +@app.route('/api/event//related-stocks') +def get_event_related_stocks(event_id): + # sector_map:二级行业 → 一级行业 + + """获取事件相关股票列表""" + try: + event = Event.query.get_or_404(event_id) + related_stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all() + sector_map = { + # 大周期 + '石油石化': '大周期', + '煤炭': '大周期', + '有色金属': '大周期', + '钢铁': '大周期', + '基础化工': '大周期', + '建筑材料': '大周期', + '机械设备': '大周期', + '电力设备及新能源': '大周期', + '国防军工': '大周期', + '电力设备': '大周期', + '电网设备': '大周期', + '风力发电': '大周期', + '太阳能发电': '大周期', + '建筑装饰': '大周期', + + # 大消费 + '汽车': '大消费', + '家用电器': '大消费', + '酒类': '大消费', + '食品饮料': '大消费', + '医药生物': '大消费', + '纺织服饰': '大消费', + '农林牧渔': '大消费', + '商贸零售': '大消费', + '轻工制造': '大消费', + '消费者服务': '大消费', + '美容护理': '大消费', + '社会服务': '大消费', + + # 大金融地产 + '银行': '大金融地产', + '证券': '大金融地产', + '保险': '大金融地产', + '多元金融': '大金融地产', + '综合金融': '大金融地产', + '房地产': '大金融地产', + '非银金融': '大金融地产', + + # TMT + '计算机': 'TMT板块', + '电子': 'TMT板块', + '传媒': 'TMT板块', + '通信': 'TMT板块', + + # 公共产业 + '交通运输': '公共产业板块', + '电力公用事业': '公共产业板块', + '建筑': '公共产业板块', + '环保': '公共产业板块', + '综合': '公共产业板块', + '公用事业': '公共产业板块', + } + + stocks_data = [] + for stock in related_stocks: + # 处理股票代码,移除可能的后缀 + base_stock_code = stock.stock_code.split('.')[0] + + # 查询申万行业分类 + sector_info = SectorInfo.query.filter( + SectorInfo.SECCODE.ilike(f"{base_stock_code}%"), # 使用ilike进行不区分大小写的匹配 + SectorInfo.F002V == '申银万国行业分类' + ).first() + + # 获取申万一级行业名称 + sw_sector = sector_info.F004V if sector_info else None + + # 确定primary_sector + primary_sector = sector_map.get(sw_sector, '未知') if sw_sector else sector_map.get(stock.sector, '未知') + + stocks_data.append({ + 'stock_code': stock.stock_code, + 'stock_name': stock.stock_name, + 'sector': stock.sector, + 'sw_sector': sw_sector, # 添加申万行业分类信息 + 'primary_sector': primary_sector, + 'relation_desc': stock.relation_desc, + 'correlation': stock.correlation, + 'momentum': stock.momentum + }) + + return jsonify({ + 'code': 200, + 'message': 'success', + 'data': stocks_data + }) + + except Exception as e: + print(f"Error in get_event_related_stocks: {str(e)}") + return jsonify({ + 'code': 500, + 'message': str(e), + 'data': None + }), 500 + + +def calculate_change_distribution(stocks_data): + """ + 计算涨跌幅分布统计 + + Args: + stocks_data: 包含股票代码和涨跌幅的数据列表 + + Returns: + dict: 涨跌幅分布统计 + """ + distribution = { + 'limit_down': 0, # 跌停 + 'down_over_5': 0, # 跌5%以上 + 'down_5_to_1': 0, # 跌5%到1% + 'down_within_1': 0, # 跌1%以内 + 'flat': 0, # 平盘±0% + 'up_within_1': 0, # 涨1%以内 + 'up_1_to_5': 0, # 涨1-5% + 'up_over_5': 0, # 涨5%以上不涨停 + 'limit_up': 0 # 涨停 + } + + for stock in stocks_data: + change = stock.get('daily_change', 0) + stock_code = stock.get('stock_code', '') + + # 判断涨跌停限制 + limit_rate = get_limit_rate(stock_code) + + # 判断涨停/跌停 (允许0.01%的误差) + if change <= -limit_rate + 0.01: + distribution['limit_down'] += 1 + elif change >= limit_rate - 0.01: + distribution['limit_up'] += 1 + elif change > 5: + distribution['up_over_5'] += 1 + elif change > 1: + distribution['up_1_to_5'] += 1 + elif change > 0.1: + distribution['up_within_1'] += 1 + elif change >= -0.1: + distribution['flat'] += 1 + elif change > -1: + distribution['down_within_1'] += 1 + elif change > -5: + distribution['down_5_to_1'] += 1 + else: + distribution['down_over_5'] += 1 + + return distribution + + def get_limit_rate(stock_code): """ 根据股票代码获取涨跌停限制比例 @@ -2010,12 +4006,9 @@ def get_limit_rate(stock_code): @app.route('/api/events', methods=['GET']) - - def api_get_events(): """ 获取事件列表API - 优化版本(保持完全兼容) - 仅限 Pro/Max 会员访问(小程序功能) 优化策略: 1. 使用ind_type字段简化内部逻辑 @@ -2040,15 +4033,11 @@ def api_get_events(): date_range = request.args.get('date_range') recent_days = request.args.get('recent_days', type=int) - # 行业筛选参数(重新设计) + # 行业筛选参数(兼容新旧版本) ind_type = request.args.get('ind_type', 'all') stock_sector = request.args.get('stock_sector', 'all') secondary_sector = request.args.get('secondary_sector', 'all') - # 新的行业层级筛选参数 - industry_level = request.args.get('industry_level', type=int) # 筛选层级:1-4 - industry_classification = request.args.get('industry_classification') # 行业名称 - # 如果使用旧参数,映射到ind_type if ind_type == 'all' and stock_sector != 'all': ind_type = stock_sector @@ -2149,143 +4138,6 @@ def api_get_events(): except ValueError: pass - # ==================== 行业层级筛选(申银万国行业分类) ==================== - - if industry_level and industry_classification: - # 排除行业分类体系名称本身,这些不是具体的行业 - classification_systems = [ - '申银万国行业分类', '中上协行业分类', '巨潮行业分类', - '新财富行业分类', '证监会行业分类', '证监会行业分类(2001)' - ] - - if industry_classification not in classification_systems: - # 根据层级和名称查询对应的行业代码 - # 前端发送的level值直接对应数据库字段: - # level=2 -> f004v(一级行业) - # level=3 -> f005v(二级行业) - # level=4 -> f006v(三级行业) - # level=5 -> f007v(四级行业) - level_column_map = { - 2: 'f004v', # level2 对应一级行业 - 3: 'f005v', # level3 对应二级行业 - 4: 'f006v', # level4 对应三级行业 - 5: 'f007v' # level5 对应四级行业 - } - - if industry_level in level_column_map: - level_column = level_column_map[industry_level] - - # 查询所有匹配该行业名称的代码 - sector_codes_sql = f""" - SELECT DISTINCT f003v - FROM ea_sector - WHERE f002v = '申银万国行业分类' - AND {level_column} = :industry_name - """ - - result = db.session.execute( - text(sector_codes_sql), - {'industry_name': industry_classification} - ) - - matching_codes = [row[0] for row in result.fetchall()] - - if matching_codes: - # 根据层级确定代码前缀长度 - # 申银万国代码规则:S + 2位一级 + 2位二级 + 2位三级 + 2位四级 - prefix_length_map = { - 2: 3, # level2: S + 2位(一级行业) - 3: 5, # level3: S + 2位 + 2位(二级行业) - 4: 7, # level4: S + 2位 + 2位 + 2位(三级行业) - 5: 9 # level5: 完整代码(四级行业) - } - - prefix_length = prefix_length_map.get(industry_level, 9) - - # 获取所有代码的共同前缀(用于模糊匹配) - code_prefixes = list(set([code[:prefix_length] for code in matching_codes if code])) - - if code_prefixes: - # 构建查询条件:查找related_industries中包含这些前缀的事件 - if isinstance(db.engine.dialect, MySQLDialect): - # MySQL JSON查询 - conditions = [] - for prefix in code_prefixes: - conditions.append( - text(""" - JSON_SEARCH( - related_industries, - 'one', - CONCAT(:prefix, '%'), - NULL, - '$[*]."申银万国行业分类"' - ) IS NOT NULL - """).params(prefix=prefix) - ) - - if conditions: - query = query.filter(or_(*conditions)) - else: - # 其他数据库 - pattern_conditions = [] - for prefix in code_prefixes: - pattern_conditions.append( - text("related_industries::text LIKE :pattern").params( - pattern=f'%"申银万国行业分类": "{prefix}%' - ) - ) - - if pattern_conditions: - query = query.filter(or_(*pattern_conditions)) - else: - # 没有找到匹配的行业代码,返回空结果 - query = query.filter(Event.id == -1) - else: - # 无效的层级参数 - app.logger.warning(f"Invalid industry_level: {industry_level}") - else: - # industry_classification 是分类体系名称,不进行筛选 - app.logger.info( - f"Skipping filter: industry_classification '{industry_classification}' is a classification system name") - - # ==================== 细分行业筛选(保留向后兼容) ==================== - - elif secondary_sector != 'all': - # 直接按行业名称查询(最后一级行业 - level5/f007v) - sector_code_query = db.session.query(text("DISTINCT f003v")).select_from( - text("ea_sector") - ).filter( - text("f002v = '申银万国行业分类' AND f007v = :sector_name") - ).params(sector_name=secondary_sector) - - sector_result = sector_code_query.first() - - if sector_result and sector_result[0]: - industry_code_to_search = sector_result[0] - - # 在related_industries JSON中查找包含该代码的事件 - if isinstance(db.engine.dialect, MySQLDialect): - query = query.filter( - text(""" - JSON_SEARCH( - related_industries, - 'one', - :industry_code, - NULL, - '$[*]."申银万国行业分类"' - ) IS NOT NULL - """) - ).params(industry_code=industry_code_to_search) - else: - query = query.filter( - text(""" - related_industries::text LIKE :pattern - """) - ).params(pattern=f'%"申银万国行业分类": "{industry_code_to_search}"%') - else: - # 如果没有找到对应的行业代码,返回空结果 - query = query.filter(Event.id == -1) - # ==================== 概念/标签筛选 ==================== # 单个标签筛选 @@ -2737,9 +4589,6 @@ def api_get_events(): applied_filters['start_date'] = start_date if end_date: applied_filters['end_date'] = end_date - if industry_level and industry_classification: - applied_filters['industry_level'] = industry_level - applied_filters['industry_classification'] = industry_classification if tag: applied_filters['tag'] = tag if tags: @@ -2793,6 +4642,7 @@ def api_get_events(): }), 500 + def get_filter_counts(base_query): """ 获取各个筛选条件的计数信息 @@ -2826,6 +4676,226 @@ def get_filter_counts(base_query): return {} +@app.route('/api/events/filters', methods=['GET']) +def api_get_event_filters(): + """ + 获取事件筛选选项的统计信息 + 用于前端动态生成筛选器选项 + """ + try: + # 基础查询 (只包含激活状态的事件) + base_query = Event.query.filter_by(status='active') + + filter_counts = get_filter_counts(base_query) + + return jsonify({ + 'success': True, + 'data': { + 'filter_counts': filter_counts, + 'available_sorts': [ + {'key': 'new', 'name': '最新', 'desc': '按创建时间排序'}, + {'key': 'hot', 'name': '热门', 'desc': '按热度分数排序'}, + {'key': 'returns', 'name': '收益率', 'desc': '按收益率排序'}, + {'key': 'importance', 'name': '重要性', 'desc': '按重要性等级排序'}, + {'key': 'view_count', 'name': '浏览量', 'desc': '按浏览次数排序'}, + ], + 'available_return_types': [ + {'key': 'avg', 'name': '平均收益率'}, + {'key': 'max', 'name': '最大收益率'}, + {'key': 'week', 'name': '周收益率'}, + ], + 'available_importance_levels': [ + {'key': 'S', 'name': 'S级', 'desc': '重大事件'}, + {'key': 'A', 'name': 'A级', 'desc': '重要事件'}, + {'key': 'B', 'name': 'B级', 'desc': '普通事件'}, + {'key': 'C', 'name': 'C级', 'desc': '参考事件'}, + ] + } + }) + except Exception as e: + app.logger.error(f"获取筛选选项出错: {str(e)}") + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/industry-classifications', methods=['GET']) +def api_get_industry_classifications(): + """ + 获取行业分类体系列表 + + 返回: + { + "success": true, + "data": [ + {"classification_name": "申万行业分类标准"}, + {"classification_name": "证监会行业分类标准"} + ] + } + """ + try: + # 获取所有行业分类体系 + sql = """ + SELECT DISTINCT f002v as classification_name + FROM ea_sector + WHERE f002v IS NOT NULL + ORDER BY f002v + """ + + results = db.session.execute(text(sql)).fetchall() + + classifications = [ + {'classification_name': row.classification_name} + for row in results + ] + + return jsonify({ + 'success': True, + 'data': classifications + }) + + except Exception as e: + app.logger.error(f"获取行业分类体系出错: {str(e)}") + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/industry-level-codes') +def industry_level_codes_api(): + """ + API端点: 根据行业分类系统和层级获取行业代码和名称 + + 参数: + classification: 行业分类系统名称 + level: 行业层级 (1-4) + level1_name: 一级行业名称 (仅对level>1有效) + level2_name: 二级行业名称 (仅对level>2有效) + level3_name: 三级行业名称 (仅对level>3有效) + """ + classification = request.args.get('classification') + level = request.args.get('level', type=int) + level1_name = request.args.get('level1_name', '') + level2_name = request.args.get('level2_name', '') + level3_name = request.args.get('level3_name', '') + + # 验证参数 + if not classification or not level or level < 1 or level > 4: + return jsonify([]) + + try: + # 层级到字段的映射 + level_fields = { + 1: "f004v", # 一级行业 + 2: "f005v", # 二级行业 + 3: "f006v", # 三级行业 + 4: "f007v", # 四级行业 + } + + field_name = level_fields[level] + + # 根据层级选择不同的查询 + if level == 1: + # 一级行业查询 + sql = f""" + SELECT DISTINCT {field_name} as name, + MIN(f003v) as code + FROM ea_sector + WHERE f002v = :classification + AND {field_name} IS NOT NULL + GROUP BY name + ORDER BY name + """ + params = {"classification": classification} + + elif level == 2: + # 二级行业查询 + sql = f""" + SELECT DISTINCT {field_name} as name, + MIN(f003v) as code + FROM ea_sector + WHERE f002v = :classification + AND f004v = :level1_name + AND {field_name} IS NOT NULL + GROUP BY name + ORDER BY name + """ + params = {"classification": classification, "level1_name": level1_name} + + elif level == 3: + # 三级行业查询 + sql = f""" + SELECT DISTINCT {field_name} as name, + MIN(f003v) as code + FROM ea_sector + WHERE f002v = :classification + AND f004v = :level1_name + AND f005v = :level2_name + AND {field_name} IS NOT NULL + GROUP BY name + ORDER BY name + """ + params = { + "classification": classification, + "level1_name": level1_name, + "level2_name": level2_name + } + + elif level == 4: + # 四级行业查询 + sql = f""" + SELECT DISTINCT f003v as code, + {field_name} as name + FROM ea_sector + WHERE f002v = :classification + AND f004v = :level1_name + AND f005v = :level2_name + AND f006v = :level3_name + AND {field_name} IS NOT NULL + ORDER BY name + """ + params = { + "classification": classification, + "level1_name": level1_name, + "level2_name": level2_name, + "level3_name": level3_name + } + + # 执行查询 + results = db.session.execute(text(sql), params).all() + + # 转换为JSON响应 + return jsonify([{"code": row.code, "name": row.name} for row in results if row.name]) + + except Exception as e: + app.logger.error(f"获取行业代码出错: {str(e)}") + return jsonify({"error": str(e)}), 500 + + +@app.route('/trending') +def trending_events(): + """获取趋势事件(用于首页的趋势板块)""" + # 获取24小时内的热门事件 + recent_events = Event.query.filter( + Event.created_at >= beijing_now() - timedelta(days=1), + Event.status == 'active' + ).order_by( + Event.hot_score.desc() + ).limit(10).all() + + events_data = [{ + 'id': event.id, + 'title': event.title, + 'description': event.description[:100] + '...' if len(event.description) > 100 else event.description, + 'hot_score': event.hot_score, + 'post_count': event.post_count, + 'follower_count': event.follower_count + } for event in recent_events] + + return jsonify({'events': events_data}) + def get_event_class(count): """根据事件数量返回对应的样式类""" if count >= 10: @@ -2837,8 +4907,6 @@ def get_event_class(count): else: return 'bg-gradient-success' @app.route('/api/calendar-event-counts') - - def get_calendar_event_counts(): """获取整月的事件数量统计,仅统计type为event的事件""" try: @@ -2928,8 +4996,6 @@ def to_dict(self): # 1. 首页接口 @app.route('/api/home', methods=['GET']) - - def api_home(): try: seven_days_ago = datetime.now() - timedelta(days=7) @@ -3056,517 +5122,504 @@ def api_home(): "data": None }), 500 -@app.route('/api/auth/logout', methods=['POST']) -def logout_with_token(): - """使用token登出""" - # 从请求头获取token - auth_header = request.headers.get('Authorization') - if auth_header and auth_header.startswith('Bearer '): - token = auth_header[7:] - else: - data = request.get_json() - token = data.get('token') if data else None - if token and token in user_tokens: - del user_tokens[token] +def generate_jwt_token(user_id): + payload = { + 'user_id': user_id, + 'exp': datetime.utcnow() + timedelta(seconds=JWT_EXPIRES_SECONDS) + } + return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) - # 清除session - session.clear() - return jsonify({'message': '登出成功'}), 200 -def send_sms_code(phone, code, template_id): - """发送短信验证码""" - try: - cred = credential.Credential(SMS_SECRET_ID, SMS_SECRET_KEY) - httpProfile = HttpProfile() - httpProfile.endpoint = "sms.tencentcloudapi.com" +def save_sms_code_to_redis(phone, code, purpose='login', expire=600): + key = f"sms:{purpose}:{phone}" + redis_client.setex(key, expire, code) - clientProfile = ClientProfile() - clientProfile.httpProfile = httpProfile - client = sms_client.SmsClient(cred, "ap-beijing", clientProfile) - req = models.SendSmsRequest() - params = { - "PhoneNumberSet": [phone], - "SmsSdkAppId": SMS_SDK_APP_ID, - "TemplateId": template_id, - "SignName": SMS_SIGN_NAME, - "TemplateParamSet": [code] if template_id == SMS_TEMPLATE_REGISTER else [code, "5"] - } - req.from_json_string(json.dumps(params)) +def verify_sms_code(phone, code, purpose='login'): + key = f"sms:{purpose}:{phone}" + stored_code = redis_client.get(key) + return stored_code == code - resp = client.SendSms(req) - return True - except TencentCloudSDKException as err: - print(f"SMS Error: {err}") - return False - -def generate_verification_code(): - """生成6位数字验证码""" - return ''.join(random.choices(string.digits, k=6)) @app.route('/api/auth/send-sms', methods=['POST']) -def send_sms_verification(): - """发送手机验证码(统一接口,自动判断场景)""" +def api_send_sms_code(): data = request.get_json() phone = data.get('phone') + purpose = data.get('purpose', 'login') - if not phone: - return jsonify({'error': '手机号不能为空'}), 400 + if not phone or not re.match(r'^1[3-9]\d{9}$', phone): + return jsonify({'code': 400, 'message': '请输入正确手机号'}), 400 - # 检查手机号是否已注册 - user_exists = User.query.filter_by(phone=phone).first() is not None + # 限制发送频率(60秒内不能重复发送) + freq_key = f"sms:freq:{phone}" + if redis_client.exists(freq_key): + return jsonify({'code': 429, 'message': '请勿频繁请求验证码'}), 429 - # 生成验证码 - code = generate_verification_code() + # 检查用途:登录必须是已注册用户 + user = User.query.filter_by(phone=phone).first() + if purpose == 'login' and not user: + return jsonify({'code': 404, 'message': '该手机号尚未注册'}), 404 - # 根据用户是否存在自动选择模板 - template_id = SMS_TEMPLATE_LOGIN if user_exists else SMS_TEMPLATE_REGISTER + # 发送短信并获取验证码 + success, msg, verification_code = send_sms_verification_minimal(phone, redis_client) + if not success: + return jsonify({'code': 500, 'message': msg}), 500 - # 发送短信 - if send_sms_code(phone, code, template_id): - # 统一存储验证码(5分钟有效) - verification_codes[phone] = { - 'code': code, - 'expires': time.time() + 300 - } + # 保存验证码到 Redis(这个验证码会被登录接口使用) + save_sms_code_to_redis(phone, verification_code, purpose) + redis_client.setex(freq_key, 60, '1') # 60秒限制 - # 简单返回成功,不暴露用户是否存在的信息 - return jsonify({ - 'message': '验证码已发送', - 'expires_in': 300 # 告诉前端验证码有效期(秒) - }), 200 - else: - return jsonify({'error': '验证码发送失败'}), 500 - - -def generate_token(length=32): - """生成随机token""" - characters = string.ascii_letters + string.digits - return ''.join(secrets.choice(characters) for _ in range(length)) + # 开发环境下返回验证码用于测试 + response_data = {'code': 200, 'message': '验证码已发送'} + if app.debug: # 只在调试模式下返回验证码 + response_data['debug_code'] = verification_code + return jsonify(response_data) @app.route('/api/auth/login/phone', methods=['POST']) -def login_with_phone(): - """统一的手机号登录/注册接口""" +def login_by_phone(): + """手机验证码登录""" data = request.get_json() phone = data.get('phone') code = data.get('code') - username = data.get('username') # 可选,新用户可以提供 - password = data.get('password') # 可选,新用户可以提供 - if not all([phone, code]): - return jsonify({ - 'code': 400, - 'message': '手机号和验证码不能为空' - }), 400 + if not phone or not code: + return jsonify({'code': 400, 'message': '手机号和验证码不能为空'}), 400 # 验证验证码 - stored_code = verification_codes.get(phone) + if not verify_sms_code(phone, code): + return jsonify({'code': 400, 'message': '验证码错误或已过期'}), 400 - if not stored_code or stored_code['expires'] < time.time(): - return jsonify({ - 'code': 400, - 'message': '验证码已过期' - }), 400 - - if stored_code['code'] != code: - return jsonify({ - 'code': 400, - 'message': '验证码错误' - }), 400 - - try: - # 查找用户 - user = User.query.filter_by(phone=phone).first() - is_new_user = False - - # 如果用户不存在,自动注册 - if not user: - is_new_user = True - - # 如果提供了用户名,检查是否已存在 - if username: - if User.query.filter_by(username=username).first(): - return jsonify({ - 'code': 400, - 'message': '用户名已被使用,请换一个' - }), 400 - else: - # 自动生成用户名 - base_username = f"user_{phone[-4:]}" - username = base_username - counter = 1 - while User.query.filter_by(username=username).first(): - username = f"{base_username}_{counter}" - counter += 1 - - # 创建新用户 - user = User(username=username, phone=phone) - user.email = f"{username}@valuefrontier.temp" - - # 如果提供了密码就使用,否则生成随机密码 - if password: - user.set_password(password) - else: - random_password = generate_token(16) - user.set_password(random_password) - - user.phone_confirmed = True - - db.session.add(user) - db.session.commit() - - # 生成token - token = generate_token(32) - - # 存储token映射(30天有效期) - user_tokens[token] = { - 'user_id': user.id, - 'expires': datetime.now() + timedelta(days=30) - } - - # 清除验证码 - del verification_codes[phone] - - # 设置session(保持与原有逻辑兼容) - session.permanent = True - session['user_id'] = user.id - session['username'] = user.username - session['logged_in'] = True - - # 返回响应 - response_data = { - 'code': 0, - 'message': '欢迎回来' if not is_new_user else '注册成功,欢迎加入', - 'token': token, - 'is_new_user': is_new_user, # 告诉前端是否是新用户 - 'user': { - 'id': user.id, - 'username': user.username, - 'phone': user.phone, - 'need_complete_profile': is_new_user # 提示新用户完善资料 - } - } - - return jsonify(response_data), 200 - - except Exception as e: - db.session.rollback() - print(f"Login/Register error: {e}") - return jsonify({ - 'code': 500, - 'message': '操作失败,请重试' - }), 500 - - -@app.route('/api/auth/verify-token', methods=['POST']) -def verify_token(): - """验证token有效性(可选接口)""" - data = request.get_json() - token = data.get('token') - - if not token: - return jsonify({'valid': False, 'message': 'Token不能为空'}), 400 - - token_data = user_tokens.get(token) - - if not token_data: - return jsonify({'valid': False, 'message': 'Token无效','code':401}), 401 - - # 检查是否过期 - if token_data['expires'] < datetime.now(): - del user_tokens[token] - return jsonify({'valid': False, 'message': 'Token已过期'}), 401 - - # 获取用户信息 - user = User.query.get(token_data['user_id']) + # 查找用户 + user = User.query.filter_by(phone=phone).first() if not user: - return jsonify({'valid': False, 'message': '用户不存在'}), 404 + return jsonify({'code': 404, 'message': '该手机号尚未注册'}), 404 + + # 更新用户最后登录时间 + user.update_last_seen() + db.session.commit() + + # 生成JWT token + token = generate_jwt_token(user.id) return jsonify({ - 'valid': True, - 'user': { - 'id': user.id, - 'username': user.username, - 'phone': user.phone + 'code': 200, + 'message': '登录成功', + 'data': { + 'token': token, + 'user': user.to_dict() } - }), 200 - - - - -def generate_jwt_token(user_id): - """ - 生成JWT Token - 与原系统保持一致 - - Args: - user_id: 用户ID - - Returns: - str: JWT token字符串 - """ - payload = { - 'user_id': user_id, - 'exp': datetime.utcnow() + timedelta(hours=JWT_EXPIRATION_HOURS), - 'iat': datetime.utcnow() - } - - token = jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) - return token - - + }) @app.route('/api/auth/login/wechat', methods=['POST']) def api_login_wechat(): + """微信登录接口""" try: - # 1. 获取请求数据 data = request.get_json() - code = data.get('code') if data else None + code = data.get('code') # 微信授权码 + union_id = data.get('unionId') # 微信 UnionID - if not code: + if not code and not union_id: return jsonify({ 'code': 400, 'message': '缺少必要的参数', 'data': None }), 400 - # 2. 验证code格式 - if not isinstance(code, str) or len(code) < 10: + # 1. 通过code获取access_token和openid + # TODO: 需要您提供以下信息: + # - WECHAT_APP_ID: 微信开放平台应用ID + # - WECHAT_APP_SECRET: 微信开放平台应用密钥 + wx_api_url = f"https://api.weixin.qq.com/sns/oauth2/access_token?appid={WECHAT_APP_ID}&secret={WECHAT_APP_SECRET}&code={code}&grant_type=authorization_code" + + # 2. 获取用户信息 + user = None + if union_id: + # 通过 union_id 查找用户 + user = User.query.filter_by(wechat_union_id=union_id).first() + + if not user: + # 创建新用户 + username = f"wx_user_{int(time.time())}" # 生成临时用户名 + user = User( + username=username, + wechat_union_id=union_id, + status='active' + ) + db.session.add(user) + db.session.commit() + + # 生成JWT token + token = generate_jwt_token(user.id) + + return jsonify({ + 'code': 200, + 'message': 'success', + 'data': { + 'token': token, + 'user': { + 'id': user.id, + 'username': user.username, + 'nickname': user.nickname, + 'avatar_url': get_full_avatar_url(user.avatar_url), # 修改这里 + 'is_new_user': user.created_at > (datetime.now() - timedelta(minutes=1)) + } + } + }) + + except Exception as e: + return jsonify({ + 'code': 500, + 'message': str(e), + 'data': None + }), 500 + + +# 获取微信登录二维码 +# 获取微信登录二维码 +@app.route('/api/wechat/qrcode', methods=['GET']) +def get_wechat_qrcode(): + """获取微信登录二维码""" + try: + # 生成唯一的state参数 + state = str(uuid.uuid4()) + + # 检查必要的配置 + if not app.config.get('WECHAT_APP_ID'): return jsonify({ - 'code': 400, - 'message': 'code格式无效', + 'code': 500, + 'message': '微信配置未设置', 'data': None - }), 400 + }), 500 - logger.info(f"开始处理微信登录,code长度: {len(code)}") + # 将state保存到session或Redis中 + # 如果有Redis缓存系统: + # cache.set(f"wx_state_{state}", True, timeout=300) + # 如果没有缓存系统,可以使用session: + from flask import session + session[f"wx_state_{state}"] = True - # 3. 调用微信接口获取用户信息 - wx_api_url = 'https://api.weixin.qq.com/sns/jscode2session' - params = { - 'appid': WECHAT_APP_ID, - 'secret': WECHAT_APP_SECRET, - 'js_code': code, + # 构建微信授权URL + base_url = app.config.get('BASE_URL', 'http://43.143.189.195:5002') + redirect_uri = f"{base_url}/api/wechat/callback" + + qrcode_url = f"https://open.weixin.qq.com/connect/qrconnect?appid={app.config['WECHAT_APP_ID']}&redirect_uri={redirect_uri}&response_type=code&scope=snsapi_login&state={state}#wechat_redirect" + + return jsonify({ + 'code': 200, + 'message': 'success', + 'data': { + 'qr_url': qrcode_url, # 修改为 qr_url 匹配前端 + 'state': state + } + }) + except Exception as e: + return jsonify({ + 'code': 500, + 'message': f'生成二维码失败: {str(e)}', + 'data': None + }), 500 + + +# 检查登录状态 +# 生成微信登录二维码 - 新版本 +@app.route('/api/wechat/generate-login-qr', methods=['POST']) +def generate_wechat_login_qr(): + """生成微信登录二维码""" + try: + # 生成唯一标识 + state = str(uuid.uuid4()) + + # 保存到session + session[f"wx_state_{state}"] = { + 'created_at': datetime.now(), + 'status': 'pending' + } + + # 构建登录URL - 这是用户扫码后访问的页面 + base_url = app.config.get('BASE_URL', 'http://43.143.189.195:5002') + login_url = f"{base_url}/wechat-auth?state={state}" + + return jsonify({ + 'code': 200, + 'message': 'success', + 'data': { + 'login_url': login_url, + 'state': state + } + }) + except Exception as e: + return jsonify({ + 'code': 500, + 'message': str(e), + 'data': None + }), 500 + + +# 微信授权页面 - 用户扫码后首先访问这里 +@app.route('/wechat-auth') +def wechat_auth_page(): + """微信授权页面 - 用户扫码后访问""" + state = request.args.get('state') + + if not state or not session.get(f"wx_state_{state}"): + return render_template_string(""" + + 登录失败 + +

登录链接无效

+

请重新扫描二维码

+ + + """) + + # 构建微信授权URL - 使用网页授权接口 + app_id = app.config['WECHAT_APP_ID'] + base_url = app.config.get('BASE_URL', 'http://43.143.189.195:5002') + redirect_uri = quote(f"{base_url}/api/wechat/callback") + + # 使用网页授权接口(snsapi_userinfo) + auth_url = f"https://open.weixin.qq.com/connect/oauth2/authorize?appid={app_id}&redirect_uri={redirect_uri}&response_type=code&scope=snsapi_userinfo&state={state}#wechat_redirect" + + # 直接重定向到微信授权 + return redirect(auth_url) + + +# 微信回调处理 - 完整版本 +@app.route('/api/wechat/callback') +def wechat_callback(): + """微信登录回调处理""" + try: + code = request.args.get('code') + state = request.args.get('state') + + app.logger.info(f"微信回调: code={code}, state={state}") + + if not code or not state: + return render_template_string(""" + + 登录失败 + +

登录失败

+

参数错误,请重新尝试

+ + + + """) + + # 验证state + state_info = session.get(f"wx_state_{state}") + if not state_info: + return render_template_string(""" + + 登录失败 + +

登录失败

+

无效的登录状态

+ + + + """) + + # 1. 获取access_token + token_url = "https://api.weixin.qq.com/sns/oauth2/access_token" + token_params = { + 'appid': app.config['WECHAT_APP_ID'], + 'secret': app.config['WECHAT_APP_SECRET'], + 'code': code, 'grant_type': 'authorization_code' } - try: - response = requests.get(wx_api_url, params=params, timeout=10) - response.raise_for_status() - wx_data = response.json() + app.logger.info(f"请求微信token: {token_params}") - # 检查微信API返回的错误 - if 'errcode' in wx_data and wx_data['errcode'] != 0: - error_messages = { - -1: '系统繁忙,请稍后重试', - 40029: 'code无效或已过期', - 45011: '频率限制,请稍后再试', - 40013: 'AppID错误', - 40125: 'AppSecret错误', - 40226: '高风险用户,登录被拦截' - } + response = requests.get(token_url, params=token_params, timeout=10) + token_data = response.json() - error_msg = error_messages.get( - wx_data['errcode'], - f"微信接口错误: {wx_data.get('errmsg', '未知错误')}" - ) + app.logger.info(f"微信token响应: {token_data}") - logger.error(f"WeChat API error {wx_data['errcode']}: {error_msg}") + if 'errcode' in token_data: + error_msg = token_data.get('errmsg', '未知错误') + app.logger.error(f"获取微信token失败: {error_msg}") + return render_template_string(f""" + + 登录失败 + +

登录失败

+

获取授权失败: {error_msg}

+ + + + """) - return jsonify({ - 'code': 400, - 'message': error_msg, - 'data': None - }), 400 + access_token = token_data['access_token'] + openid = token_data['openid'] - # 验证必需字段 - if 'openid' not in wx_data or 'session_key' not in wx_data: - logger.error("响应缺少必需字段") - return jsonify({ - 'code': 500, - 'message': '微信响应格式错误', - 'data': None - }), 500 + # 2. 获取用户信息 + user_info_url = "https://api.weixin.qq.com/sns/userinfo" + user_params = { + 'access_token': access_token, + 'openid': openid, + 'lang': 'zh_CN' + } - openid = wx_data['openid'] - session_key = wx_data['session_key'] - unionid = wx_data.get('unionid') # 可能为None + user_response = requests.get(user_info_url, params=user_params, timeout=10) + user_info = user_response.json() - logger.info(f"成功获取微信用户信息 - OpenID: {openid[:8]}...") - if unionid: - logger.info(f"获取到UnionID: {unionid[:8]}...") + app.logger.info(f"微信用户信息: {user_info}") - except requests.exceptions.Timeout: - logger.error("请求微信API超时") - return jsonify({ - 'code': 500, - 'message': '请求超时,请重试', - 'data': None - }), 500 - except requests.exceptions.RequestException as e: - logger.error(f"网络请求失败: {str(e)}") - return jsonify({ - 'code': 500, - 'message': '网络错误', - 'data': None - }), 500 + if 'errcode' in user_info: + error_msg = user_info.get('errmsg', '未知错误') + app.logger.error(f"获取微信用户信息失败: {error_msg}") + return render_template_string(f""" + + 登录失败 + +

登录失败

+

获取用户信息失败: {error_msg}

+ + + + """) - # 4. 查找或创建用户 - 核心逻辑 - user = None - is_new_user = False + # 3. 处理用户登录逻辑 + unionid = user_info.get('unionid', openid) # 如果没有unionid就用openid + user = User.query.filter_by(wechat_union_id=unionid).first() - - logger.info(f"开始查找用户 - UnionID: {unionid}, OpenID: {openid[:8]}...") - - if unionid: - # 情况1: 有unionid,优先通过unionid查找 - user = User.query.filter_by(wechat_union_id=unionid).first() - - if user: - logger.info(f"通过UnionID找到用户: {user.username}") - # 更新openid(可能用户从不同小程序登录) - if user.wechat_open_id != openid: - user.wechat_open_id = openid - logger.info(f"更新用户OpenID: {openid[:8]}...") - else: - # unionid没找到,再尝试用openid查找(处理历史数据) - user = User.query.filter_by(wechat_open_id=openid).first() - if user: - logger.info(f"通过OpenID找到用户: {user.username}") - # 补充unionid - user.wechat_union_id = unionid - logger.info(f"为用户补充UnionID: {unionid[:8]}...") - else: - # 情况2: 没有unionid,只能通过openid查找 - logger.warning("未获取到UnionID(小程序可能未绑定开放平台)") - user = User.query.filter_by(wechat_open_id=openid).first() - if user: - logger.info(f"通过OpenID找到用户: {user.username}") - - # 5. 创建新用户 if not user: - is_new_user = True + # 创建新用户 + username = f"wx_user_{int(time.time())}" + email = f"wx_{openid}@temp.com" - # 生成唯一用户名 - timestamp = int(time.time()) - username = f"wx_{timestamp}_{openid[-6:]}" - - # 确保用户名唯一 - counter = 0 - base_username = username - while User.query.filter_by(username=username).first(): - counter += 1 - username = f"{base_username}_{counter}" - - # 创建用户对象(使用你的User模型) user = User( username=username, - email=f"{username}@wechat.local", # 占位邮箱 - password="wechat_login_no_password" # 微信登录不需要密码 + email=email, + password_hash=generate_password_hash(''), # 空密码 + wechat_union_id=unionid, + wechat_open_id=openid, + nickname=user_info.get('nickname', username), + avatar_url=get_full_avatar_url(user_info.get('headimgurl')), + gender='male' if user_info.get('sex') == 1 else 'female', + status='active', + email_confirmed=True # 微信用户默认已验证 ) - - # 设置微信相关字段 - user.wechat_open_id = openid - user.wechat_union_id = unionid - user.status = 'active' - user.email_confirmed = False - - # 设置默认值 - user.nickname = f"微信用户{openid[-4:]}" - user.bio = "" # 空的个人简介 - user.avatar_url = None # 稍后会处理 - user.is_creator = False - user.is_verified = False - user.user_level = 1 - user.reputation_score = 0 - user.contribution_point = 0 - user.post_count = 0 - user.comment_count = 0 - user.follower_count = 0 - user.following_count = 0 - - # 设置默认偏好 - user.email_notifications = True - user.privacy_level = 'public' - user.theme_preference = 'light' - db.session.add(user) - logger.info(f"创建新用户: {username}") - else: - # 更新最后登录时间 - user.update_last_seen() - logger.info(f"用户登录: {user.username}") - - # 6. 提交数据库更改 - try: db.session.commit() - except Exception as e: - db.session.rollback() - logger.error(f"保存用户信息失败: {str(e)}") - return jsonify({ - 'code': 500, - 'message': '保存用户信息失败', - 'data': None - }), 500 + app.logger.info(f"创建新用户: {user.username}") - # 7. 生成JWT token(使用原系统的生成方法) - token = generate_token(32) # 使用相同的随机字符串生成器 + # 4. 生成JWT token + token = generate_jwt_token(user.id) - # 存储token映射(与手机登录保持一致) - user_tokens[token] = { - 'user_id': user.id, - 'expires': datetime.now() + timedelta(days=30) # 30天有效期 + # 5. 更新用户最后活跃时间 + user.update_last_seen() + db.session.commit() + + # 6. 保存登录状态到session + login_info = { + 'status': 'completed', + 'token': token, + 'user': { + 'id': user.id, + 'username': user.username, + 'nickname': user.nickname, + 'avatar_url': get_full_avatar_url(user.avatar_url), # 修改这里 + 'is_new_user': user.created_at > (datetime.now() - timedelta(minutes=1)) + } } + session[f"wx_state_{state}"] = login_info - # 设置session(可选,保持与手机登录一致) - session.permanent = True - session['user_id'] = user.id - session['username'] = user.username - session['logged_in'] = True + app.logger.info(f"微信登录成功: {user.username}") - # 9. 构造返回数据 - 完全匹配要求的格式 - response_data = { - 'code': 200, - 'data': { - 'token': token, # 现在这个token能被token_required识别了 - 'user': { - 'avatar_url': get_full_avatar_url(user.avatar_url), - 'bio': user.bio or "", - 'email': user.email, - 'id': user.id, - 'is_creator': user.is_creator, - 'is_verified': user.is_verified, - 'nickname': user.nickname or user.username, - 'reputation_score': user.reputation_score, - 'user_level': user.user_level, - 'username': user.username + # 7. 返回成功页面,包含token传递逻辑 + return render_template_string(""" + + + 登录成功 + + + + +
+
+

登录成功

+

欢迎 {{ user.nickname }}

+

正在跳转到首页...

+
+ + + + """, user=user, token=token, user_json=json.dumps(login_info['user'])) except Exception as e: - # 捕获所有未处理的异常 - logger.error(f"微信登录处理异常: {str(e)}", exc_info=True) - db.session.rollback() + app.logger.error(f"微信回调错误: {str(e)}") + return render_template_string(f""" + + 登录失败 + +

登录失败

+

系统错误: {str(e)}

+ + + + """), 500 - return jsonify({ - 'code': 500, - 'message': '服务器内部错误', - 'data': None - }), 500 + +# 检查登录状态接口 - 修改版 +@app.route('/api/wechat/check-login', methods=['POST']) +def check_wechat_login(): + """检查微信登录状态""" + try: + data = request.get_json() + state = data.get('state') + + if not state: + return jsonify({'code': 400, 'message': '参数错误'}), 400 + + state_info = session.get(f"wx_state_{state}") + if not state_info: + return jsonify({'code': 400, 'message': '无效状态'}), 400 + + if state_info.get('status') == 'completed': + # 清除session + session.pop(f"wx_state_{state}", None) + + return jsonify({ + 'code': 200, + 'message': 'success', + 'data': { + 'status': 'authorized', + 'token': state_info['token'], + 'user': state_info['user'] + } + }) + else: + return jsonify({ + 'code': 202, + 'message': '等待扫码', + 'data': {'status': 'pending'} + }) + + except Exception as e: + return jsonify({'code': 500, 'message': str(e)}), 500 @app.route('/api/auth/login/email', methods=['POST']) @@ -3616,12 +5669,45 @@ def api_login_email(): }), 500 +@app.route('/api/all-industry-data') +def get_all_industry_data(): + """获取所有行业分类数据""" + try: + query = """ + SELECT DISTINCT + f002v as classification_name, + f003v as code, + f004v as level1, + f005v as level2, + f006v as level3, + f007v as level4 + FROM ea_sector + WHERE f002v NOT IN ('指数成份股', '市场分类', '概念板块', '地区省市分类', '中上协行业分类') + ORDER BY f003v + """ + + with engine.connect() as conn: + result = conn.execute(text(query)) + data = [dict(row) for row in result] + + return jsonify({ + "code": 200, + "message": "success", + "data": data + }) + except Exception as e: + app.logger.error(f"获取行业数据出错: {str(e)}") + return jsonify({ + "code": 500, + "message": str(e), + "data": None + }), 500 + + # 5. 事件详情-相关标的接口 @app.route('/api/event//related-stocks-detail', methods=['GET']) - - def api_event_related_stocks(event_id): - """事件相关标的详情接口 - 仅限 Pro/Max 会员""" + """事件相关标的详情接口""" try: event = Event.query.get_or_404(event_id) related_stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all() @@ -3807,11 +5893,9 @@ def api_event_related_stocks(event_id): @app.route('/api/stock//minute-chart', methods=['GET']) - - def get_minute_chart_data(stock_code): - """获取股票分时图数据 - 仅限 Pro/Max 会员""" client = get_clickhouse_client() + """获取股票分时图数据""" try: # 获取当前日期或最新交易日的分时数据 from datetime import datetime, timedelta, time as dt_time @@ -3882,18 +5966,16 @@ def get_minute_chart_data(stock_code): return [] +# 6. 事件详情-个股详情接口(增强版) @app.route('/api/event//stock//detail', methods=['GET']) - - def api_stock_detail(event_id, stock_code): - """个股详情接口 - 仅限 Pro/Max 会员""" + """个股详情接口""" try: # 验证事件是否存在 event = Event.query.get_or_404(event_id) # 获取查询参数 include_minute_data = request.args.get('include_minute_data', 'true').lower() == 'true' - include_full_sources = request.args.get('include_full_sources', 'false').lower() == 'true' # 是否包含完整研报来源 # 获取股票基本信息 basic_info = None @@ -3947,76 +6029,12 @@ def api_stock_detail(event_id, stock_code): related_desc = None if related_stock: - # 处理研报来源数据 - retrieved_sources_data = None - sources_summary = None - - if related_stock.retrieved_sources: - try: - # 解析研报来源 - import json - sources = related_stock.retrieved_sources if isinstance(related_stock.retrieved_sources, - list) else json.loads( - related_stock.retrieved_sources) - - # 统计信息 - sources_summary = { - 'total_count': len(sources), - 'has_sources': True, - 'match_scores': {} - } - - # 统计匹配分数分布 - for source in sources: - score = source.get('match_score', '未知') - sources_summary['match_scores'][score] = sources_summary['match_scores'].get(score, 0) + 1 - - # 根据参数决定返回完整数据还是摘要 - if include_full_sources: - # 返回完整的研报来源 - retrieved_sources_data = sources - else: - # 只返回前5条高质量来源作为预览 - # 优先返回匹配度高的 - high_quality_sources = [s for s in sources if s.get('match_score') == '好'][:3] - medium_quality_sources = [s for s in sources if s.get('match_score') == '中'][:2] - - preview_sources = high_quality_sources + medium_quality_sources - if not preview_sources: # 如果没有高中匹配度的,返回前5条 - preview_sources = sources[:5] - - retrieved_sources_data = [] - for source in preview_sources: - retrieved_sources_data.append({ - 'report_title': source.get('report_title', ''), - 'author': source.get('author', ''), - 'sentences': source.get('sentences', '')[:200] + '...' if len( - source.get('sentences', '')) > 200 else source.get('sentences', ''), # 限制长度 - 'match_score': source.get('match_score', ''), - 'declare_date': source.get('declare_date', '') - }) - - except Exception as e: - print(f"Error processing retrieved_sources for stock {stock_code}: {e}") - sources_summary = {'has_sources': False, 'error': str(e)} - else: - sources_summary = {'has_sources': False, 'total_count': 0} - related_desc = { 'event_id': related_stock.event_id, 'relation_desc': related_stock.relation_desc, 'sector': related_stock.sector, 'correlation': float(related_stock.correlation) if related_stock.correlation else None, - 'momentum': related_stock.momentum, - - # 新增研报来源相关字段 - 'retrieved_sources': retrieved_sources_data, - 'sources_summary': sources_summary, - 'retrieved_update_time': related_stock.retrieved_update_time.isoformat() if related_stock.retrieved_update_time else None, - - # 添加获取完整来源的URL - 'sources_detail_url': f"/api/event/{event_id}/stock/{stock_code}/sources" if sources_summary.get( - 'has_sources') else None + 'momentum': related_stock.momentum } response_data = { @@ -4071,6 +6089,7 @@ def api_stock_detail(event_id, stock_code): 'data': None }), 500 + def get_stock_minute_chart_data(stock_code): """获取股票分时图数据""" try: @@ -4144,8 +6163,6 @@ def get_stock_minute_chart_data(stock_code): # 7. 事件详情-相关概念接口 @app.route('/api/event//related-concepts', methods=['GET']) - - def api_event_related_concepts(event_id): """事件相关概念接口""" try: @@ -4187,8 +6204,6 @@ def api_event_related_concepts(event_id): # 8. 事件详情-历史事件接口 @app.route('/api/event//historical-events', methods=['GET']) - - def api_event_historical_events(event_id): """事件历史事件接口""" try: @@ -4288,8 +6303,7 @@ def api_event_historical_events(event_id): @app.route('/api/event//comments', methods=['GET']) - - +@token_required def get_event_comments(event_id): """获取事件的所有评论和帖子(嵌套格式) @@ -4543,8 +6557,7 @@ def get_event_comments(event_id): @app.route('/api/comment//replies', methods=['GET']) - - +@token_required def get_comment_replies(comment_id): """获取某条评论的所有回复 @@ -4687,10 +6700,43 @@ def get_comment_replies(comment_id): }), 500 +# 9. 事件详情-关联数据接口 +@app.route('/api/event//related-data-list', methods=['GET']) +def api_event_related_data(event_id): + """事件关联数据接口""" + try: + event = Event.query.get_or_404(event_id) + related_data = event.related_data.all() + + data_list = [{ + 'id': data.id, + 'title': data.title, + 'data_type': data.data_type, + 'data_content': data.data_content, + 'description': data.description, + 'created_at': data.created_at.isoformat() if data.created_at else None + } for data in related_data] + + return jsonify({ + 'code': 200, + 'message': 'success', + 'data': { + 'event_id': event_id, + 'event_title': event.title, + 'related_data': data_list, + 'total_count': len(data_list) + } + }) + except Exception as e: + return jsonify({ + 'code': 500, + 'message': str(e), + 'data': None + }), 500 + + # 10. 投资日历-事件接口(增强版) @app.route('/api/calendar/events', methods=['GET']) - - def api_calendar_events(): """投资日历事件接口 - 连接 future_events 表 (修正版)""" try: @@ -4698,7 +6744,6 @@ def api_calendar_events(): end_date = request.args.get('end') importance = request.args.get('importance', 'all') category = request.args.get('category', 'all') - search_query = request.args.get('q', '').strip() # 新增搜索参数 page = int(request.args.get('page', 1)) per_page = int(request.args.get('per_page', 10)) offset = (page - 1) * per_page @@ -4715,8 +6760,7 @@ def api_calendar_events(): forecast, fact, related_stocks, - concepts, - inferred_tag + concepts FROM future_events WHERE 1=1 """ @@ -4733,21 +6777,9 @@ def api_calendar_events(): query += " AND star = :importance" params['importance'] = importance if category != 'all': - # category参数用于筛选inferred_tag字段(如"大周期"、"大消费"等) - query += " AND inferred_tag = :category" + query += " AND type = :category" params['category'] = category - # 新增搜索条件 - if search_query: - # 使用LIKE进行模糊搜索,同时搜索title和related_stocks字段 - # 对于JSON字段,MySQL会将其作为文本进行搜索 - query += """ AND ( - title LIKE :search_pattern - OR CAST(related_stocks AS CHAR) LIKE :search_pattern - OR CAST(concepts AS CHAR) LIKE :search_pattern - )""" - params['search_pattern'] = f'%{search_query}%' - query += " ORDER BY calendar_time LIMIT :limit OFFSET :offset" params['limit'] = per_page params['offset'] = offset @@ -4770,28 +6802,59 @@ def api_calendar_events(): if importance != 'all': count_query += " AND star = :importance" if category != 'all': - count_query += " AND inferred_tag = :category" - - # 新增搜索条件到计数查询 - if search_query: - count_query += """ AND ( - title LIKE :search_pattern - OR CAST(related_stocks AS CHAR) LIKE :search_pattern - OR CAST(concepts AS CHAR) LIKE :search_pattern - )""" + count_query += " AND type = :category" total_count_result = db.session.execute(text(count_query), count_params).fetchone() total_count = total_count_result.count if total_count_result else 0 + # 申万一级行业到主板块的映射 + sector_map = { + # 大周期 + '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期', + '钢铁': '大周期', '基础化工': '大周期', '建筑材料': '大周期', + '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期', + '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期', + '太阳能发电': '大周期', '建筑装饰': '大周期', '建筑': '大周期', + '交通运输': '大周期', '采掘': '大周期', '公用事业': '大周期', + + # 大消费 + '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费', + '食品饮料': '大消费', '医药生物': '大消费', '纺织服饰': '大消费', + '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费', + '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费', + '纺织服装': '大消费', '商业贸易': '大消费', '休闲服务': '大消费', + + # 大金融地产 + '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产', + '多元金融': '大金融地产', '综合金融': '大金融地产', + '房地产': '大金融地产', '非银金融': '大金融地产', + + # TMT板块 + '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块', + + # 公共产业 + '环保': '公共产业板块', '综合': '公共产业板块' + } + events_data = [] for event in events: - # 解析相关股票 + # 解析相关股票 - 使用与detail接口相同的逻辑 related_stocks_list = [] + sector_stats = { + '全部股票': 0, + '大周期': 0, + '大消费': 0, + 'TMT板块': 0, + '大金融地产': 0, + '公共产业板块': 0, + '其他': 0 + } + related_avg_chg = 0 related_max_chg = 0 related_week_chg = 0 - # 处理相关股票数据 + # **修正:处理相关股票数据** if event.related_stocks: try: import json @@ -4824,6 +6887,23 @@ def api_calendar_events(): # 规范化股票代码,移除后缀 clean_code = stock_code.replace('.SZ', '').replace('.SH', '').replace('.BJ', '') + # 使用模糊匹配LIKE查询申万一级行业F004V + sector_query = """ + SELECT F004V as sw_primary_sector + FROM ea_sector + WHERE SECCODE LIKE :stock_code_pattern + AND F002V = '申银万国行业分类' + LIMIT 1 + """ + sector_result = db.session.execute(text(sector_query), + {'stock_code_pattern': f'{clean_code}%'}) + sector_row = sector_result.fetchone() + + # 根据申万一级行业(F004V)映射到主板块 + sw_primary_sector = sector_row.sw_primary_sector if sector_row else None + primary_sector = sector_map.get(sw_primary_sector, + '其他') if sw_primary_sector else '其他' + # 使用模糊匹配查询真实的交易数据 trade_query = """ SELECT F007N as close_price, F010N as change_pct, TRADEDATE @@ -4850,6 +6930,13 @@ def api_calendar_events(): if week_ago_price > 0: week_chg = ((current_price - week_ago_price) / week_ago_price) * 100 + # **修正:安全地更新sector统计** + sector_stats['全部股票'] += 1 + if primary_sector in sector_stats: + sector_stats[primary_sector] += 1 + else: + sector_stats['其他'] += 1 + # 收集涨跌幅数据 daily_changes.append(daily_chg) week_changes.append(week_chg) @@ -4859,11 +6946,13 @@ def api_calendar_events(): 'name': stock_name, 'description': description, 'score': score, + 'sw_primary_sector': sw_primary_sector, + 'primary_sector': primary_sector, 'daily_chg': daily_chg, 'week_chg': week_chg }) - # 计算平均收益率 + # **修正:确保计算平均收益率** if daily_changes: related_avg_chg = round(sum(daily_changes) / len(daily_changes), 4) related_max_chg = round(max(daily_changes), 4) @@ -4874,26 +6963,12 @@ def api_calendar_events(): except Exception as e: print(f"Error processing related stocks for event {event.data_id}: {e}") - # 解析相关概念 + # **新增:解析相关概念 - 使用与detail接口相同的逻辑** related_concepts = extract_concepts_from_concepts_field(event.concepts) # 获取评星等级 star_rating = event.star - - # 如果有搜索关键词,可以高亮显示匹配的部分(可选功能) - highlight_match = False - if search_query: - # 检查是否在标题中匹配 - if search_query.lower() in (event.title or '').lower(): - highlight_match = 'title' - # 检查是否在股票中匹配 - elif any(search_query.lower() in str(stock).lower() for stock in related_stocks_list): - highlight_match = 'stocks' - # 检查是否在概念中匹配 - elif search_query.lower() in str(related_concepts).lower(): - highlight_match = 'concepts' - - event_dict = { + events_data.append({ 'id': event.data_id, 'title': event.title, 'description': f"前值: {event.former}, 预测: {event.forecast}, 实际: {event.fact}" if event.former or event.forecast or event.fact else "", @@ -4902,26 +6977,19 @@ def api_calendar_events(): 'category': { 'event_type': event.type, 'importance': event.star, - 'star_rating': star_rating, - 'inferred_tag': event.inferred_tag # 添加inferred_tag到返回数据 + 'star_rating': star_rating }, 'star_rating': star_rating, - 'inferred_tag': event.inferred_tag, # 直接返回行业标签 - 'related_concepts': related_concepts, - 'related_stocks': related_stocks_list, - 'related_avg_chg': round(related_avg_chg, 2), - 'related_max_chg': round(related_max_chg, 2), - 'related_week_chg': round(related_week_chg, 2), + 'related_concepts': related_concepts, # **修正:使用正确解析的概念** + 'related_stocks': related_stocks_list, # **修正:使用正确解析的股票** + 'related_avg_chg': round(related_avg_chg, 2), # **修正:真实的涨跌幅** + 'related_max_chg': round(related_max_chg, 2), # **修正:真实的涨跌幅** + 'related_week_chg': round(related_week_chg, 2), # **修正:真实的涨跌幅** + 'sector_stats': sector_stats, # **修正:正确的行业统计** 'former': event.former, 'forecast': event.forecast, 'fact': event.fact - } - - # 可选:添加搜索匹配标记 - if search_query and highlight_match: - event_dict['search_match'] = highlight_match - - events_data.append(event_dict) + }) return jsonify({ 'code': 200, @@ -4931,8 +6999,7 @@ def api_calendar_events(): 'total_count': total_count, 'page': page, 'per_page': per_page, - 'total_pages': (total_count + per_page - 1) // per_page, - 'search_query': search_query # 返回搜索关键词 + 'total_pages': (total_count + per_page - 1) // per_page } }) @@ -4946,8 +7013,6 @@ def api_calendar_events(): # 11. 投资日历-数据接口 @app.route('/api/calendar/data', methods=['GET']) - - def api_calendar_data(): """投资日历数据接口""" try: @@ -5134,10 +7199,8 @@ def extract_concepts_from_concepts_field(concepts_text): @app.route('/api/calendar/detail/', methods=['GET']) - - def api_future_event_detail(item_id): - """未来事件详情接口 - 连接 future_events 表 (修正数据解析) - 仅限 Pro/Max 会员""" + """未来事件详情接口 - 连接 future_events 表 (修正数据解析)""" try: # 从 future_events 表查询事件详情 query = """ @@ -5370,8 +7433,6 @@ def api_future_event_detail(item_id): # 13-15. 筛选弹窗接口(已有,优化格式) @app.route('/api/filter/options', methods=['GET']) - - def api_filter_options(): """筛选选项接口""" try: @@ -5423,7 +7484,7 @@ def api_filter_options(): # 16-17. 会员权益接口 @app.route('/api/membership/status', methods=['GET']) - +@token_required def api_membership_status(): """会员状态接口""" try: @@ -5460,7 +7521,7 @@ def api_membership_status(): # 18-19. 个人中心接口 @app.route('/api/user/profile', methods=['GET']) - +@token_required def api_user_profile(): """个人资料接口""" try: @@ -5697,9 +7758,39 @@ def api_agreements(): }), 500 +# 可选:添加一个管理接口来手动重新加载协议 +@app.route('/api/agreements/reload', methods=['POST']) +def api_reload_agreements(): + """重新加载协议内容(管理员接口)""" + try: + # 清除缓存 + global _cache_loaded + _cache_loaded = False + _agreements_cache.clear() + + # 重新加载 + agreements_data = load_agreements_from_docx() + + return jsonify({ + 'code': 200, + 'message': 'Agreements reloaded successfully', + 'data': { + 'total_agreements': len(agreements_data), + 'agreements': list(agreements_data.keys()) + } + }) + + except Exception as e: + return jsonify({ + 'code': 500, + 'message': str(e), + 'data': None + }), 500 + + # 20. 个人中心-我的关注接口 @app.route('/api/user/activities', methods=['GET']) - +@token_required def api_user_activities(): """用户活动接口(我的关注、评论、点赞)""" try: @@ -5918,6 +8009,89 @@ class UserFeedback(db.Model): } +# 21. 个人中心-意见反馈接口 +@app.route('/api/user/feedback', methods=['POST']) +@token_required +def api_user_feedback(): + """意见反馈接口""" + try: + data = request.get_json() + feedback_type = data.get('type', 'other') # bug, suggestion, complaint, other + content = data.get('content') + contact_info = data.get('contact_info', '') + + if not content: + return jsonify({ + 'code': 400, + 'message': '反馈类型和内容不能为空', + 'data': None + }), 400 + + # 验证反馈类型是否有效 + valid_types = ['bug', 'feature', 'suggestion', 'other'] # 可以根据需求修改 + if feedback_type not in valid_types: + return jsonify({ + 'code': 400, + 'message': '无效的反馈类型', + 'data': None + }), 400 + + # 创建反馈记录 + feedback = UserFeedback( + user_id=request.user.id, + type=feedback_type, + content=content, + contact_info=contact_info + ) + + # 保存到数据库 + db.session.add(feedback) + db.session.commit() + + # 可以在这里添加通知管理员的逻辑 + notify_admin_new_feedback(feedback) + + # 可以保存到数据库或发送邮件通知管理员 + + return jsonify({ + 'code': 200, + 'message': '反馈提交成功,我们会尽快处理', + 'data': feedback.to_dict() + }) + except Exception as e: + return jsonify({ + 'code': 500, + 'message': str(e), + 'data': None + }), 500 + + +def notify_admin_new_feedback(feedback): + """通知管理员新的反馈""" + try: + # 获取管理员邮箱列表 + admin_emails = ['admin@example.com'] # 替换为实际的管理员邮箱列表 + + # 发送通知邮件 + subject = f'新的用户反馈 - {feedback.type}' + body = f""" + 收到新的用户反馈: + + 用户ID: {feedback.user_id} + 反馈类型: {feedback.type} + 反馈内容: {feedback.content} + 联系方式: {feedback.contact_info or '未提供'} + 提交时间: {feedback.created_at} + """ + + for admin_email in admin_emails: + send_notification_email(admin_email, subject, 'emails/admin_notification.html', + feedback=feedback) + + except Exception as e: + app.logger.error(f"发送管理员通知失败: {str(e)}") + + # 通用错误处理 @app.errorhandler(404) @@ -5948,9 +8122,5 @@ if __name__ == '__main__': app.run( host='0.0.0.0', port=5002, - debug=True, - ssl_context=( - '/home/ubuntu/dify/docker/nginx/ssl/fullchain.pem', - '/home/ubuntu/dify/docker/nginx/ssl/privkey.pem' - ) + debug=True ) diff --git a/change_pct_fix.py b/change_pct_fix.py deleted file mode 100644 index 658dbd7a..00000000 --- a/change_pct_fix.py +++ /dev/null @@ -1,97 +0,0 @@ -# 这是要替换的涨跌幅计算逻辑(第3729-3738行) - -# 计算基于事件时间的涨跌幅(参考 app.py /api/stock/quotes 的实现) -change_pct = None -change_amount = None -current_price = None - -try: - # 获取事件时间和当前时间 - evt_time = event.start_time if event.start_time else event.created_at - cur_time = datetime.now() - - # 获取交易日和时间范围 - evt_date = evt_time.date() - evt_time_only = evt_time.time() - market_open = dt_time(9, 30) - market_close = dt_time(15, 0) - - # 检查是否是交易日 - is_trading_day_result = db.session.execute(text(""" - SELECT 1 FROM trading_days WHERE EXCHANGE_DATE = :date - """), {"date": evt_date}).fetchone() - - trading_day = None - start_time = None - end_time = None - - if is_trading_day_result: - # 是交易日 - if evt_time_only < market_open: - # 盘前 - 使用当日开盘 - trading_day, start_time, end_time = evt_date, market_open, market_close - elif evt_time_only > market_close: - # 盘后 - 使用下一交易日 - next_day_result = db.session.execute(text(""" - SELECT EXCHANGE_DATE FROM trading_days - WHERE EXCHANGE_DATE > :date ORDER BY EXCHANGE_DATE LIMIT 1 - """), {"date": evt_date}).fetchone() - if next_day_result: - trading_day, start_time, end_time = next_day_result[0].date(), market_open, market_close - else: - # 盘中 - 从事件时间到收盘 - trading_day, start_time, end_time = evt_date, evt_time_only, market_close - else: - # 非交易日 - 获取下一交易日 - next_day_result = db.session.execute(text(""" - SELECT EXCHANGE_DATE FROM trading_days - WHERE EXCHANGE_DATE > :date ORDER BY EXCHANGE_DATE LIMIT 1 - """), {"date": evt_date}).fetchone() - if next_day_result: - trading_day, start_time, end_time = next_day_result[0].date(), market_open, market_close - - # 如果有有效的交易日且不在未来,查询涨跌幅 - if trading_day and trading_day <= cur_time.date(): - start_dt = datetime.combine(trading_day, start_time) - end_dt = datetime.combine(trading_day, end_time) - - # 查询第一个bar和最后一个bar的价格 - price_data = client.execute(""" - WITH first_price AS ( - SELECT close FROM stock_minute - WHERE code = %(code)s AND timestamp >= %(start)s AND timestamp <= %(end)s - ORDER BY timestamp LIMIT 1 - ), - last_price AS ( - SELECT close FROM stock_minute - WHERE code = %(code)s AND timestamp >= %(start)s AND timestamp <= %(end)s - ORDER BY timestamp DESC LIMIT 1 - ) - SELECT first_price.close as first_price, - last_price.close as last_price, - (last_price.close - first_price.close) / first_price.close * 100 as change_pct - FROM last_price CROSS JOIN first_price - WHERE EXISTS (SELECT 1 FROM first_price) AND EXISTS (SELECT 1 FROM last_price) - """, {'code': stock.stock_code, 'start': start_dt, 'end': end_dt}) - - if price_data and price_data[0] and price_data[0][0] is not None: - first_price = float(price_data[0][0]) - current_price = float(price_data[0][1]) - change_pct = float(price_data[0][2]) - change_amount = current_price - first_price -except Exception as e: - print(f"计算事件涨跌幅失败 {stock.stock_code}: {e}") - -# 如果ClickHouse没有数据,fallback到原来的逻辑 -if change_pct is None: - if latest_trade and prev_trade: - if prev_trade.F007N and prev_trade.F007N != 0: - change_amount = float(latest_trade.F007N) - float(prev_trade.F007N) - change_pct = (change_amount / float(prev_trade.F007N)) * 100 - elif latest_trade and latest_trade.F010N: - change_pct = float(latest_trade.F010N) - change_amount = float(latest_trade.F009N) if latest_trade.F009N else None - -# 如果还没有当前价格,使用latest_trade -if current_price is None and latest_trade and latest_trade.F007N: - current_price = float(latest_trade.F007N) diff --git a/clickhouse_optimization_guide.py b/clickhouse_optimization_guide.py deleted file mode 100644 index 37a9a1bc..00000000 --- a/clickhouse_optimization_guide.py +++ /dev/null @@ -1,540 +0,0 @@ -""" -ClickHouse 查询优化方案 - 针对 /api/event//related-stocks-detail - -问题分析: -1. N+1 查询问题:每只股票执行 3 次独立查询(共 30+ 次) -2. 重复扫描:first_price 和 last_price 需要扫描表两次 -3. 缺少批量查询优化 - -优化方案对比: -┌─────────────┬──────────────┬──────────────┬────────────┐ -│ 方案 │ 查询次数 │ 性能提升 │ 实现难度 │ -├─────────────┼──────────────┼──────────────┼────────────┤ -│ 当前代码 │ N * 3 │ 基准 │ - │ -│ 方案1 批量 │ 1 │ 80-90% │ 中等 │ -│ 方案2 并行 │ N * 3 (并行)│ 40-60% │ 简单 │ -│ 方案3 缓存 │ 减少重复 │ 20-40% │ 简单 │ -└─────────────┴──────────────┴──────────────┴────────────┘ -""" - -# ============================================================================ -# 方案 1: 批量查询(推荐)- 将所有股票的查询合并为一次 -# ============================================================================ - -def get_batch_stock_prices_optimized(client, stock_codes, start_datetime, end_datetime): - """ - 批量获取多只股票的价格数据(一次查询) - - 性能对比: - - 旧方案:10 只股票 = 20 次查询(first + last) - - 新方案:10 只股票 = 1 次查询 - - 性能提升:约 20 倍 - - Args: - client: ClickHouse 客户端 - stock_codes: 股票代码列表 ['600519.SH', '601088.SH', ...] - start_datetime: 开始时间 - end_datetime: 结束时间 - - Returns: - dict: { - '600519.SH': { - 'first_price': 1850.0, - 'last_price': 1860.0, - 'change_pct': 0.54, - 'open': 1850.0, - 'high': 1865.0, - 'low': 1848.0, - 'volume': 1234567, - 'amount': 2345678900.0 - }, - ... - } - """ - if not stock_codes: - return {} - - # 构建批量查询 SQL(使用 IN 子句) - query = """ - SELECT - code, - -- 第一个价格(事件发生时) - anyIf(close, rownum_asc = 1) as first_price, - -- 最后一个价格(当前时间) - anyIf(close, rownum_desc = 1) as last_price, - -- 涨跌幅 - (last_price - first_price) / first_price * 100 as change_pct, - -- 涨跌额 - last_price - first_price as change_amount, - -- 其他价格信息(取最后一条记录) - anyIf(open, rownum_desc = 1) as open_price, - anyIf(high, rownum_desc = 1) as high_price, - anyIf(low, rownum_desc = 1) as low_price, - anyIf(volume, rownum_desc = 1) as volume, - anyIf(amt, rownum_desc = 1) as amount - FROM ( - SELECT - code, - timestamp, - close, - open, - high, - low, - volume, - amt, - -- 正序排名(用于获取第一个价格) - ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp ASC) as rownum_asc, - -- 倒序排名(用于获取最后一个价格) - ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rownum_desc - FROM stock_minute - WHERE code IN %(codes)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ) - GROUP BY code - """ - - try: - # 执行批量查询 - data = client.execute(query, { - 'codes': tuple(stock_codes), # ClickHouse IN 需要 tuple - 'start': start_datetime, - 'end': end_datetime - }) - - # 格式化结果为字典 - result = {} - for row in data: - code = row[0] - result[code] = { - 'first_price': float(row[1]) if row[1] else None, - 'last_price': float(row[2]) if row[2] else None, - 'change_pct': float(row[3]) if row[3] else None, - 'change_amount': float(row[4]) if row[4] else None, - 'open_price': float(row[5]) if row[5] else None, - 'high_price': float(row[6]) if row[6] else None, - 'low_price': float(row[7]) if row[7] else None, - 'volume': int(row[8]) if row[8] else None, - 'amount': float(row[9]) if row[9] else None, - } - - print(f"批量查询完成,获取了 {len(result)} 只股票的数据") - return result - - except Exception as e: - print(f"批量查询失败: {e}") - return {} - - -def get_batch_minute_chart_data(client, stock_codes, start_datetime, end_datetime): - """ - 批量获取多只股票的分时图数据 - - Args: - client: ClickHouse 客户端 - stock_codes: 股票代码列表 - start_datetime: 开始时间 - end_datetime: 结束时间 - - Returns: - dict: { - '600519.SH': [ - {'time': '09:30', 'close': 1850.0, 'volume': 12345, ...}, - {'time': '09:31', 'close': 1851.0, 'volume': 12346, ...}, - ... - ], - ... - } - """ - if not stock_codes: - return {} - - query = """ - SELECT - code, - timestamp, - open, - high, - low, - close, - volume, - amt - FROM stock_minute - WHERE code IN %(codes)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ORDER BY code, timestamp - """ - - try: - data = client.execute(query, { - 'codes': tuple(stock_codes), - 'start': start_datetime, - 'end': end_datetime - }) - - # 按股票代码分组 - result = {} - for row in data: - code = row[0] - if code not in result: - result[code] = [] - - result[code].append({ - 'time': row[1].strftime('%H:%M'), - 'open': float(row[2]) if row[2] else None, - 'high': float(row[3]) if row[3] else None, - 'low': float(row[4]) if row[4] else None, - 'close': float(row[5]) if row[5] else None, - 'volume': float(row[6]) if row[6] else None, - 'amount': float(row[7]) if row[7] else None - }) - - print(f"批量获取分时数据完成,获取了 {len(result)} 只股票的数据") - return result - - except Exception as e: - print(f"批量获取分时数据失败: {e}") - return {} - - -# ============================================================================ -# 使用示例:替换原来的 for 循环 -# ============================================================================ - -def api_event_related_stocks_optimized(event_id): - """优化后的端点实现""" - try: - from datetime import datetime - - event = Event.query.get_or_404(event_id) - related_stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all() - - if not related_stocks: - return jsonify({'code': 200, 'data': {'related_stocks': []}}) - - # 获取 ClickHouse 客户端 - client = get_clickhouse_client() - - # 计算时间范围(省略交易日计算逻辑,与原代码相同) - event_time = event.start_time if event.start_time else event.created_at - trading_day, start_time, end_time = get_trading_day_and_times(event_time) - start_datetime = datetime.combine(trading_day, start_time) - end_datetime = datetime.combine(trading_day, end_time) - - # ✅ 批量查询所有股票的价格数据(只查询 1 次) - stock_codes = [stock.stock_code for stock in related_stocks] - prices_data = get_batch_stock_prices_optimized( - client, stock_codes, start_datetime, end_datetime - ) - - # ✅ 批量查询所有股票的分时图数据(只查询 1 次) - minute_data = get_batch_minute_chart_data( - client, stock_codes, start_datetime, end_datetime - ) - - # 组装返回数据 - stocks_data = [] - for stock in related_stocks: - # 从批量查询结果中获取数据(无需再次查询) - price_info = prices_data.get(stock.stock_code, {}) - chart_data = minute_data.get(stock.stock_code, []) - - # 获取股票基本信息(这里可以考虑也批量查询) - stock_info = StockBasicInfo.query.filter_by(SECCODE=stock.stock_code).first() - if not stock_info: - base_code = stock.stock_code.split('.')[0] - stock_info = StockBasicInfo.query.filter_by(SECCODE=base_code).first() - - stock_data = { - 'id': stock.id, - 'stock_code': stock.stock_code, - 'stock_name': stock.stock_name, - 'sector': stock.sector, - 'relation_desc': stock.relation_desc, - 'correlation': stock.correlation, - 'momentum': stock.momentum, - 'listing_date': stock_info.F006D.isoformat() if stock_info and stock_info.F006D else None, - 'market': stock_info.F005V if stock_info else None, - - # 交易数据(从批量查询结果获取) - 'trade_data': { - 'latest_price': price_info.get('last_price'), - 'first_price': price_info.get('first_price'), - 'open_price': price_info.get('open_price'), - 'high_price': price_info.get('high_price'), - 'low_price': price_info.get('low_price'), - 'change_amount': round(price_info['change_amount'], 2) if price_info.get('change_amount') else None, - 'change_pct': round(price_info['change_pct'], 2) if price_info.get('change_pct') else None, - 'volume': price_info.get('volume'), - 'amount': price_info.get('amount'), - 'trade_date': trading_day.isoformat(), - }, - - # 分时图数据 - 'minute_chart': chart_data - } - - stocks_data.append(stock_data) - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'event_id': event_id, - 'event_title': event.title, - 'related_stocks': stocks_data, - 'total_count': len(stocks_data) - } - }) - - except Exception as e: - print(f"Error in api_event_related_stocks_optimized: {e}") - return jsonify({'code': 500, 'message': str(e)}), 500 - - -# ============================================================================ -# 方案 2: 异步并行查询(适用于无法批量查询的场景) -# ============================================================================ - -import asyncio -from concurrent.futures import ThreadPoolExecutor - -def get_stock_price_async(client, stock_code, start_datetime, end_datetime): - """单个股票的查询函数(线程安全)""" - # 与原代码相同的查询逻辑 - try: - data = client.execute(""" - WITH first_price AS ( - SELECT close FROM stock_minute WHERE code = %(code)s ... - ) - ... - """, {'code': stock_code, 'start': start_datetime, 'end': end_datetime}) - return stock_code, data - except Exception as e: - return stock_code, None - - -def get_all_stocks_parallel(client, stock_codes, start_datetime, end_datetime): - """ - 并行查询多只股票(使用线程池) - - 性能对比: - - 串行:10 只股票 * 0.1 秒 = 1 秒 - - 并行:max(0.1 秒) = 0.1 秒(10 倍提速) - """ - with ThreadPoolExecutor(max_workers=10) as executor: - # 提交所有查询任务 - futures = [ - executor.submit(get_stock_price_async, client, code, start_datetime, end_datetime) - for code in stock_codes - ] - - # 等待所有任务完成 - results = {} - for future in futures: - stock_code, data = future.result() - results[stock_code] = data - - return results - - -# ============================================================================ -# 方案 3: 添加缓存层(Redis) -# ============================================================================ - -import redis -import json - -redis_client = redis.Redis(host='localhost', port=6379, db=0) - -def get_stock_price_with_cache(client, stock_code, start_datetime, end_datetime): - """ - 带缓存的查询(适用于历史数据) - - 缓存策略: - - 历史数据(非当日):缓存 24 小时 - - 当日数据:缓存 1 分钟 - """ - from datetime import datetime - - # 生成缓存键 - cache_key = f"stock_price:{stock_code}:{start_datetime.date()}:{end_datetime.date()}" - - # 尝试从缓存获取 - cached_data = redis_client.get(cache_key) - if cached_data: - print(f"从缓存获取 {stock_code} 数据") - return json.loads(cached_data) - - # 缓存未命中,查询数据库 - print(f"从 ClickHouse 查询 {stock_code} 数据") - data = client.execute("""...""", { - 'code': stock_code, - 'start': start_datetime, - 'end': end_datetime - }) - - # 格式化数据 - result = { - 'first_price': float(data[0][2]) if data else None, - 'last_price': float(data[0][0]) if data else None, - # ... - } - - # 写入缓存 - is_today = start_datetime.date() == datetime.now().date() - ttl = 60 if is_today else 86400 # 当日数据缓存 1 分钟,历史数据缓存 24 小时 - redis_client.setex(cache_key, ttl, json.dumps(result)) - - return result - - -# ============================================================================ -# 方案 4: ClickHouse 查询优化(索引提示) -# ============================================================================ - -def get_stock_price_with_hints(client, stock_code, start_datetime, end_datetime): - """ - 使用 ClickHouse 特性优化查询 - - 优化点: - 1. PREWHERE 子句(提前过滤,减少数据扫描) - 2. FINAL 修饰符(如果使用了 ReplacingMergeTree) - 3. 分区裁剪(如果表按日期分区) - """ - query = """ - SELECT - code, - anyLast(close) as last_price, - any(close) as first_price, - (last_price - first_price) / first_price * 100 as change_pct - FROM stock_minute - PREWHERE code = %(code)s -- 使用 PREWHERE 提前过滤(比 WHERE 快) - WHERE timestamp >= %(start)s - AND timestamp <= %(end)s - GROUP BY code - SETTINGS max_threads = 2 -- 限制线程数(避免资源竞争) - """ - - data = client.execute(query, { - 'code': stock_code, - 'start': start_datetime, - 'end': end_datetime - }) - - return data - - -# ============================================================================ -# 数据库层面优化建议 -# ============================================================================ - -""" -1. 确保 stock_minute 表有以下索引: - - PRIMARY KEY (code, timestamp) -- 主键索引 - - INDEX idx_timestamp timestamp TYPE minmax GRANULARITY 3 -- 时间索引 - -2. 表分区策略(如果数据量大): - CREATE TABLE stock_minute ( - code String, - timestamp DateTime, - ... - ) ENGINE = MergeTree() - PARTITION BY toYYYYMM(timestamp) -- 按月分区 - ORDER BY (code, timestamp) - SETTINGS index_granularity = 8192; - -3. 使用物化视图预计算(适用于固定查询模式): - CREATE MATERIALIZED VIEW stock_minute_summary - ENGINE = AggregatingMergeTree() - PARTITION BY toYYYYMMDD(timestamp) - ORDER BY (code, timestamp) - AS SELECT - code, - toStartOfMinute(timestamp) as minute, - anyLast(close) as last_close, - any(close) as first_close, - ... - FROM stock_minute - GROUP BY code, minute; - -4. 检查表统计信息: - SELECT - table, - partition, - rows, - bytes_on_disk - FROM system.parts - WHERE table = 'stock_minute'; -""" - - -# ============================================================================ -# 性能对比测试 -# ============================================================================ - -def benchmark_query_methods(): - """ - 性能对比测试 - - 测试场景:查询 10 只股票的价格数据 - - 预期结果: - - 原方案(串行 N+1):~1000ms - - 方案 1(批量查询):~50ms(20 倍提速) - - 方案 2(并行查询):~200ms(5 倍提速) - - 方案 3(带缓存):~10ms(100 倍提速,第二次请求) - """ - import time - - stock_codes = ['600519.SH', '601088.SH', '600276.SH', '000001.SZ', ...] - - # 测试方案 1:批量查询 - start = time.time() - result1 = get_batch_stock_prices_optimized(client, stock_codes, start_dt, end_dt) - print(f"批量查询耗时: {(time.time() - start) * 1000:.2f}ms") - - # 测试方案 2:并行查询 - start = time.time() - result2 = get_all_stocks_parallel(client, stock_codes, start_dt, end_dt) - print(f"并行查询耗时: {(time.time() - start) * 1000:.2f}ms") - - # 测试原方案(串行) - start = time.time() - result3 = {} - for code in stock_codes: - result3[code] = get_stock_price_original(client, code, start_dt, end_dt) - print(f"串行查询耗时: {(time.time() - start) * 1000:.2f}ms") - - -# ============================================================================ -# 总结与建议 -# ============================================================================ - -""" -推荐实施顺序: - -第一步(立即实施):方案 1 - 批量查询 -- 实现难度:中等 -- 性能提升:80-90% -- 风险:低 -- 时间:1-2 小时 - -第二步(可选):方案 3 - 添加缓存 -- 实现难度:简单 -- 性能提升:额外 20-40% -- 风险:低 -- 时间:30 分钟 - -第三步(长期):方案 4 - 数据库优化 -- 实现难度:中等 -- 性能提升:20-30% -- 风险:中(需要测试) -- 时间:2-4 小时 - -监控指标: -- 查询时间:目标 < 200ms(当前 > 1000ms) -- ClickHouse 查询次数:目标 1-2 次(当前 30+ 次) -- 缓存命中率:目标 > 80%(如果使用缓存) -""" diff --git a/fix_related_stocks_performance.py b/fix_related_stocks_performance.py deleted file mode 100644 index d91eb328..00000000 --- a/fix_related_stocks_performance.py +++ /dev/null @@ -1,465 +0,0 @@ -""" -性能优化补丁 - 修复 /api/event//related-stocks-detail 的 N+1 查询问题 - -使用方法: -1. 将下面的两个函数复制到 app_vx.py 中 -2. 替换原来的 api_event_related_stocks 函数 - -预期效果: -- 查询时间:从 1000-3000ms 降低到 100-300ms -- ClickHouse 查询次数:从 30+ 次降低到 2 次 -- 性能提升:约 80-90% -""" - -def get_batch_stock_prices(client, stock_codes, start_datetime, end_datetime): - """ - 批量获取多只股票的价格数据(只查询一次 ClickHouse) - - Args: - client: ClickHouse 客户端 - stock_codes: 股票代码列表 ['600519.SH', '601088.SH', ...] - start_datetime: 开始时间 - end_datetime: 结束时间 - - Returns: - dict: { - '600519.SH': { - 'first_price': 1850.0, - 'last_price': 1860.0, - 'change_pct': 0.54, - 'change_amount': 10.0, - 'open': 1850.0, - 'high': 1865.0, - 'low': 1848.0, - 'volume': 1234567, - 'amount': 2345678900.0 - }, - ... - } - """ - if not stock_codes: - return {} - - try: - # 批量查询 SQL - 使用窗口函数一次性获取所有股票的数据 - query = """ - SELECT - code, - first_price, - last_price, - (last_price - first_price) / nullIf(first_price, 0) * 100 as change_pct, - last_price - first_price as change_amount, - open_price, - high_price, - low_price, - volume, - amount - FROM ( - SELECT - code, - -- 使用 anyIf 获取第一个和最后一个价格 - anyIf(close, rn_asc = 1) as first_price, - anyIf(close, rn_desc = 1) as last_price, - anyIf(open, rn_desc = 1) as open_price, - -- 使用 max 获取最高价 - max(high) as high_price, - -- 使用 min 获取最低价 - min(low) as low_price, - anyIf(volume, rn_desc = 1) as volume, - anyIf(amt, rn_desc = 1) as amount - FROM ( - SELECT - code, - timestamp, - close, - open, - high, - low, - volume, - amt, - -- 正序行号(用于获取第一个价格) - row_number() OVER (PARTITION BY code ORDER BY timestamp ASC) as rn_asc, - -- 倒序行号(用于获取最后一个价格) - row_number() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn_desc - FROM stock_minute - WHERE code IN %(codes)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ) - GROUP BY code - ) - """ - - # 执行查询 - data = client.execute(query, { - 'codes': tuple(stock_codes), # ClickHouse IN 需要 tuple - 'start': start_datetime, - 'end': end_datetime - }) - - # 格式化结果 - result = {} - for row in data: - code = row[0] - result[code] = { - 'first_price': float(row[1]) if row[1] is not None else None, - 'last_price': float(row[2]) if row[2] is not None else None, - 'change_pct': float(row[3]) if row[3] is not None else None, - 'change_amount': float(row[4]) if row[4] is not None else None, - 'open_price': float(row[5]) if row[5] is not None else None, - 'high_price': float(row[6]) if row[6] is not None else None, - 'low_price': float(row[7]) if row[7] is not None else None, - 'volume': int(row[8]) if row[8] is not None else None, - 'amount': float(row[9]) if row[9] is not None else None, - } - - print(f"✅ 批量查询完成,获取了 {len(result)}/{len(stock_codes)} 只股票的数据") - return result - - except Exception as e: - print(f"❌ 批量查询失败: {e}") - import traceback - traceback.print_exc() - return {} - - -def get_batch_minute_chart_data(client, stock_codes, start_datetime, end_datetime): - """ - 批量获取多只股票的分时图数据 - - Args: - client: ClickHouse 客户端 - stock_codes: 股票代码列表 - start_datetime: 开始时间 - end_datetime: 结束时间 - - Returns: - dict: { - '600519.SH': [ - {'time': '09:30', 'open': 1850.0, 'close': 1851.0, 'volume': 12345, ...}, - {'time': '09:31', 'open': 1851.0, 'close': 1852.0, 'volume': 12346, ...}, - ... - ], - ... - } - """ - if not stock_codes: - return {} - - try: - query = """ - SELECT - code, - timestamp, - open, - high, - low, - close, - volume, - amt - FROM stock_minute - WHERE code IN %(codes)s - AND timestamp >= %(start)s - AND timestamp <= %(end)s - ORDER BY code, timestamp - """ - - data = client.execute(query, { - 'codes': tuple(stock_codes), - 'start': start_datetime, - 'end': end_datetime - }) - - # 按股票代码分组 - result = {} - for row in data: - code = row[0] - if code not in result: - result[code] = [] - - result[code].append({ - 'time': row[1].strftime('%H:%M'), - 'open': float(row[2]) if row[2] is not None else None, - 'high': float(row[3]) if row[3] is not None else None, - 'low': float(row[4]) if row[4] is not None else None, - 'close': float(row[5]) if row[5] is not None else None, - 'volume': float(row[6]) if row[6] is not None else None, - 'amount': float(row[7]) if row[7] is not None else None - }) - - print(f"✅ 批量获取分时数据完成,获取了 {len(result)}/{len(stock_codes)} 只股票的数据") - return result - - except Exception as e: - print(f"❌ 批量获取分时数据失败: {e}") - import traceback - traceback.print_exc() - return {} - - -# ============================================================================ -# 优化后的端点函数(替换原来的 api_event_related_stocks) -# ============================================================================ - -@app.route('/api/event//related-stocks-detail', methods=['GET']) -def api_event_related_stocks(event_id): - """事件相关标的详情接口 - 仅限 Pro/Max 会员(已优化性能)""" - try: - from datetime import datetime, timedelta, time as dt_time - from sqlalchemy import text - import time as time_module - - # 记录开始时间 - start_time = time_module.time() - - event = Event.query.get_or_404(event_id) - related_stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all() - - if not related_stocks: - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'event_id': event_id, - 'event_title': event.title, - 'related_stocks': [], - 'total_count': 0 - } - }) - - # 获取ClickHouse客户端 - client = get_clickhouse_client() - - # 获取事件时间和交易日(与原代码逻辑相同) - event_time = event.start_time if event.start_time else event.created_at - current_time = datetime.now() - - # 定义交易日和时间范围计算函数(与原代码完全一致) - def get_trading_day_and_times(event_datetime): - event_date = event_datetime.date() - event_time_only = event_datetime.time() - - market_open = dt_time(9, 30) - market_close = dt_time(15, 0) - - with engine.connect() as conn: - is_trading_day = conn.execute(text(""" - SELECT 1 - FROM trading_days - WHERE EXCHANGE_DATE = :date - """), {"date": event_date}).fetchone() is not None - - if is_trading_day: - if event_time_only < market_open: - return event_date, market_open, market_close - elif event_time_only > market_close: - next_trading_day = conn.execute(text(""" - SELECT EXCHANGE_DATE - FROM trading_days - WHERE EXCHANGE_DATE > :date - ORDER BY EXCHANGE_DATE LIMIT 1 - """), {"date": event_date}).fetchone() - return (next_trading_day[0].date() if next_trading_day else None, - market_open, market_close) - else: - return event_date, event_time_only, market_close - else: - next_trading_day = conn.execute(text(""" - SELECT EXCHANGE_DATE - FROM trading_days - WHERE EXCHANGE_DATE > :date - ORDER BY EXCHANGE_DATE LIMIT 1 - """), {"date": event_date}).fetchone() - return (next_trading_day[0].date() if next_trading_day else None, - market_open, market_close) - - trading_day, start_time_val, end_time_val = get_trading_day_and_times(event_time) - - if not trading_day: - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'event_id': event_id, - 'event_title': event.title, - 'event_desc': event.description, - 'event_type': event.event_type, - 'event_importance': event.importance, - 'event_status': event.status, - 'event_created_at': event.created_at.strftime("%Y-%m-%d %H:%M:%S"), - 'event_start_time': event.start_time.isoformat() if event.start_time else None, - 'event_end_time': event.end_time.isoformat() if event.end_time else None, - 'keywords': event.keywords, - 'view_count': event.view_count, - 'post_count': event.post_count, - 'follower_count': event.follower_count, - 'related_stocks': [], - 'total_count': 0 - } - }) - - start_datetime = datetime.combine(trading_day, start_time_val) - end_datetime = datetime.combine(trading_day, end_time_val) - - print(f"📊 事件时间: {event_time}, 交易日: {trading_day}, 时间范围: {start_datetime} - {end_datetime}") - - # ✅ 批量查询所有股票的价格数据(关键优化点 1) - stock_codes = [stock.stock_code for stock in related_stocks] - print(f"📈 开始批量查询 {len(stock_codes)} 只股票的价格数据...") - - query_start = time_module.time() - prices_data = get_batch_stock_prices(client, stock_codes, start_datetime, end_datetime) - query_time = (time_module.time() - query_start) * 1000 - print(f"⏱️ 价格查询耗时: {query_time:.2f}ms") - - # ✅ 批量查询所有股票的分时图数据(关键优化点 2) - print(f"📈 开始批量查询 {len(stock_codes)} 只股票的分时数据...") - - chart_start = time_module.time() - minute_data = get_batch_minute_chart_data(client, stock_codes, start_datetime, end_datetime) - chart_time = (time_module.time() - chart_start) * 1000 - print(f"⏱️ 分时数据查询耗时: {chart_time:.2f}ms") - - # 组装返回数据(不再需要循环查询) - stocks_data = [] - for stock in related_stocks: - # 从批量查询结果中获取数据(O(1) 查找) - price_info = prices_data.get(stock.stock_code, {}) - chart_data = minute_data.get(stock.stock_code, []) - - # 获取股票基本信息 - stock_info = StockBasicInfo.query.filter_by(SECCODE=stock.stock_code).first() - if not stock_info: - base_code = stock.stock_code.split('.')[0] - stock_info = StockBasicInfo.query.filter_by(SECCODE=base_code).first() - - # 如果批量查询没有返回数据,尝试使用 TradeData 作为降级方案 - if not price_info or price_info.get('last_price') is None: - try: - latest_trade = None - search_codes = [stock.stock_code, stock.stock_code.split('.')[0]] - - for code in search_codes: - latest_trade = TradeData.query.filter_by(SECCODE=code) \ - .order_by(TradeData.TRADEDATE.desc()).first() - if latest_trade: - break - - if latest_trade: - price_info = { - 'last_price': float(latest_trade.F007N) if latest_trade.F007N else None, - 'first_price': float(latest_trade.F002N) if latest_trade.F002N else None, - 'open_price': float(latest_trade.F003N) if latest_trade.F003N else None, - 'high_price': float(latest_trade.F005N) if latest_trade.F005N else None, - 'low_price': float(latest_trade.F006N) if latest_trade.F006N else None, - 'volume': float(latest_trade.F004N) if latest_trade.F004N else None, - 'amount': float(latest_trade.F011N) if latest_trade.F011N else None, - 'change_pct': float(latest_trade.F010N) if latest_trade.F010N else None, - 'change_amount': float(latest_trade.F009N) if latest_trade.F009N else None, - } - except Exception as fallback_error: - print(f"⚠️ 降级查询失败 {stock.stock_code}: {fallback_error}") - - stock_data = { - 'id': stock.id, - 'stock_code': stock.stock_code, - 'stock_name': stock.stock_name, - 'sector': stock.sector, - 'relation_desc': stock.relation_desc, - 'correlation': stock.correlation, - 'momentum': stock.momentum, - 'listing_date': stock_info.F006D.isoformat() if stock_info and stock_info.F006D else None, - 'market': stock_info.F005V if stock_info else None, - - # 交易数据(从批量查询结果获取) - 'trade_data': { - 'latest_price': price_info.get('last_price'), - 'first_price': price_info.get('first_price'), - 'open_price': price_info.get('open_price'), - 'high_price': price_info.get('high_price'), - 'low_price': price_info.get('low_price'), - 'change_amount': round(price_info['change_amount'], 2) if price_info.get('change_amount') is not None else None, - 'change_pct': round(price_info['change_pct'], 2) if price_info.get('change_pct') is not None else None, - 'volume': price_info.get('volume'), - 'amount': price_info.get('amount'), - 'trade_date': trading_day.isoformat() if trading_day else None, - }, - - # 分时图数据 - 'minute_chart': chart_data - } - - stocks_data.append(stock_data) - - # 计算总耗时 - total_time = (time_module.time() - start_time) * 1000 - print(f"✅ 请求完成,总耗时: {total_time:.2f}ms (价格: {query_time:.2f}ms, 分时: {chart_time:.2f}ms)") - - return jsonify({ - 'code': 200, - 'message': 'success', - 'data': { - 'event_id': event_id, - 'event_title': event.title, - 'event_desc': event.description, - 'event_type': event.event_type, - 'event_importance': event.importance, - 'event_status': event.status, - 'event_created_at': event.created_at.strftime("%Y-%m-%d %H:%M:%S"), - 'event_start_time': event.start_time.isoformat() if event.start_time else None, - 'event_end_time': event.end_time.isoformat() if event.end_time else None, - 'keywords': event.keywords, - 'view_count': event.view_count, - 'post_count': event.post_count, - 'follower_count': event.follower_count, - 'related_stocks': stocks_data, - 'total_count': len(stocks_data), - - # 性能指标(可选,调试用) - 'performance': { - 'total_time_ms': round(total_time, 2), - 'price_query_ms': round(query_time, 2), - 'chart_query_ms': round(chart_time, 2) - } - } - }) - - except Exception as e: - print(f"❌ Error in api_event_related_stocks: {e}") - import traceback - traceback.print_exc() - return jsonify({'code': 500, 'message': str(e)}), 500 - - -# ============================================================================ -# 使用说明 -# ============================================================================ - -""" -1. 将上面的 3 个函数复制到 app_vx.py 中: - - get_batch_stock_prices() - - get_batch_minute_chart_data() - - api_event_related_stocks()(替换原函数) - -2. 重启 Flask 应用: - python app_vx.py - -3. 测试端点: - curl http://localhost:5001/api/event/18058/related-stocks-detail - -4. 观察日志输出: - ✅ 批量查询完成,获取了 10/10 只股票的数据 - ⏱️ 价格查询耗时: 45.23ms - ⏱️ 分时数据查询耗时: 78.56ms - ✅ 请求完成,总耗时: 234.67ms - -5. 性能对比(10 只股票): - - 优化前:1000-3000ms(30+ 次查询) - - 优化后:100-300ms(2 次查询) - - 提升:80-90% - -6. 如果还是慢,检查: - - ClickHouse 表是否有索引:SHOW CREATE TABLE stock_minute; - - 数据量是否过大:SELECT count() FROM stock_minute WHERE code = '600519.SH'; - - 网络延迟:ping ClickHouse 服务器 -"""