import base64
import csv
import io
import os
import time
import urllib
import uuid
from bisect import bisect_left
from functools import wraps
import qrcode
from flask_mail import Mail, Message
from flask_socketio import SocketIO, emit, join_room, leave_room
import pytz
import requests
from celery import Celery
from flask_compress import Compress
from pathlib import Path
import json
from sqlalchemy import Column, Integer, String, Boolean, DateTime, create_engine, text, func, or_
from flask import Flask, render_template, request, jsonify, redirect, url_for, flash, session, render_template_string, \
    current_app, make_response
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user
import random
from werkzeug.security import generate_password_hash, check_password_hash
import re
import string
from datetime import datetime, timedelta, time as dt_time, date
from clickhouse_driver import Client as Cclient
from flask_cors import CORS
from collections import defaultdict
from functools import lru_cache
import jieba
import jieba.analyse
from flask_cors import cross_origin
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.sms.v20210111 import sms_client, models
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from sqlalchemy import text, desc, and_
import pandas as pd
from decimal import Decimal
from apscheduler.schedulers.background import BackgroundScheduler

# Trading-day cache, filled by load_trading_days().
# `trading_days` is kept sorted ascending; `trading_days_set` gives O(1) membership tests.
trading_days = []
trading_days_set = set()


def load_trading_days(path='tdays.csv'):
    """Load the trading-day calendar from a CSV file into the module-level cache.

    The CSV must contain a ``DateTime`` column with dates formatted like
    ``2010/1/4``. The caches are replaced only after the whole file has parsed
    successfully, so a malformed file leaves the previous contents untouched and
    repeated calls never accumulate duplicate entries.

    Args:
        path: CSV file to read; a relative path resolves against the process CWD.
    """
    try:
        # newline='' is what the csv module expects; utf-8-sig tolerates a BOM.
        with open(path, 'r', newline='', encoding='utf-8-sig') as f:
            loaded = [
                datetime.strptime(row['DateTime'], '%Y/%m/%d').date()
                for row in csv.DictReader(f)
            ]
    except Exception as e:
        print(f"加载交易日数据失败: {e}")
        return

    loaded.sort()
    # Mutate in place so any existing references to the cache objects see the update.
    trading_days[:] = loaded
    trading_days_set.clear()
    trading_days_set.update(loaded)
    print(f"成功加载 {len(trading_days)} 个交易日数据")


def row_to_dict(row):
    """Convert a SQLAlchemy ``Row`` into a plain dict (SQLAlchemy 1.4+).

    Returns ``None`` when ``row`` is ``None``.
    """
    if row is None:
        return None
    # Row._mapping exposes the columns as a read-only mapping.
    return dict(row._mapping)


def get_trading_day_near_date(target_date):
    """Return the trading day on or after ``target_date``.

    If ``target_date`` is itself a trading day it is returned unchanged;
    otherwise the next trading day is returned. Dates after the end of the
    loaded calendar fall back to the last known trading day.

    Args:
        target_date: a ``date`` or ``datetime`` (any time component is ignored).

    Returns:
        A ``date``, or ``None`` if no trading-day data could be loaded.
    """
    if not trading_days:
        load_trading_days()
    if not trading_days:
        return None

    if isinstance(target_date, datetime):
        target_date = target_date.date()

    if target_date in trading_days_set:
        return target_date

    # trading_days is sorted, so a binary search finds the first day >= target_date.
    idx = bisect_left(trading_days, target_date)
    if idx < len(trading_days):
        return trading_days[idx]
    return trading_days[-1]


# Load the trading-day calendar at import time.
load_trading_days()

# SECURITY: database credentials are hard-coded in the URLs below; move them to
# environment variables / a secret store and rotate the password.
engine = create_engine(
    "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock?charset=utf8mb4",
    echo=False,
    pool_size=10,
    pool_recycle=3600,
    pool_pre_ping=True,
    pool_timeout=30,
    max_overflow=20
)
engine_med = create_engine(
    "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/med?charset=utf8mb4",
    echo=False,
    pool_size=5,
    pool_recycle=3600,
    pool_pre_ping=True,
    pool_timeout=30,
    max_overflow=10
)
engine_2 = create_engine(
    "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/valuefrontier?charset=utf8mb4",
    echo=False,
    pool_size=5,
    pool_recycle=3600,
    pool_pre_ping=True,
    pool_timeout=30,
    max_overflow=10
)
app = Flask(__name__)

# Temporary in-process stores for verification codes and WeChat QR sessions.
# They are per-process and lost on restart (production should use Redis).
verification_codes = {}
wechat_qr_sessions = {}

# Tencent Cloud SMS configuration.
# SECURITY: hard-coded API secrets; load from the environment / a secret store.
SMS_SECRET_ID = 'AKID2we9TacdTAhCjCSYTErHVimeJo9Yr00s'
SMS_SECRET_KEY = 'pMlBWijlkgT9fz5ziEXdWEnAPTJzRfkf'
SMS_SDK_APP_ID = "1400972398"
SMS_SIGN_NAME = "价值前沿科技"
SMS_TEMPLATE_REGISTER = "2386557"  # registration template
SMS_TEMPLATE_LOGIN = "2386540"  # login template

# WeChat Open Platform configuration.
# SECURITY: hard-coded app secret; load from the environment / a secret store.
WECHAT_APPID = 'wxa8d74c47041b5f87'
WECHAT_APPSECRET = 'eedef95b11787fd7ca7f1acc6c9061bc'
WECHAT_REDIRECT_URI = 'http://valuefrontier.cn/api/auth/wechat/callback'

# Mail service configuration (QQ enterprise mailbox).
MAIL_SERVER = 'smtp.exmail.qq.com'
MAIL_PORT = 465
MAIL_USE_SSL = True
MAIL_USE_TLS = False
MAIL_USERNAME = 'admin@valuefrontier.cn'
# SECURITY: hard-coded mailbox password; load from the environment / a secret store.
MAIL_PASSWORD = 'QYncRu6WUdASvTg4'
MAIL_DEFAULT_SENDER = 'admin@valuefrontier.cn'

# Session and security configuration.
# SECRET_KEY signs session cookies. Prefer a stable key from the environment so that
# sessions survive restarts and are valid across worker processes. The fallback is a
# per-process random key from the OS CSPRNG (the `random` module is predictable and
# must not be used for secrets).
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY') or os.urandom(24).hex()
app.config['SESSION_COOKIE_SECURE'] = False  # set to True once production is served over HTTPS
app.config['SESSION_COOKIE_HTTPONLY'] = True  # keep True: stops JavaScript (e.g. XSS) reading the session cookie
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'  # 'Lax' balances CSRF protection with normal navigation
app.config['SESSION_COOKIE_DOMAIN'] = None  # leave the cookie Domain attribute to Flask's default
app.config['SESSION_COOKIE_PATH'] = '/'  # cookie is valid for every path
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(days=7)  # permanent sessions last 7 days
app.config['REMEMBER_COOKIE_DURATION'] = timedelta(days=30)  # "remember me" lasts 30 days
app.config['REMEMBER_COOKIE_SECURE'] = False  # remember-me cookie is also sent over plain HTTP
# NOTE(review): False lets JavaScript read the 30-day remember-me token, so any XSS can
# steal a long-lived login. Switch to True unless the frontend truly needs this cookie.
app.config['REMEMBER_COOKIE_HTTPONLY'] = False

# Mail configuration for Flask-Mail.
app.config['MAIL_SERVER'] = MAIL_SERVER
app.config['MAIL_PORT'] = MAIL_PORT
app.config['MAIL_USE_SSL'] = MAIL_USE_SSL
app.config['MAIL_USE_TLS'] = MAIL_USE_TLS
app.config['MAIL_USERNAME'] = MAIL_USERNAME
app.config['MAIL_PASSWORD'] = MAIL_PASSWORD
app.config['MAIL_DEFAULT_SENDER'] = MAIL_DEFAULT_SENDER

# Cross-origin access for the frontends. Origins are listed explicitly because
# credentials (cookies) are allowed.
CORS(app,
     origins=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5173",
              "https://valuefrontier.cn", "http://valuefrontier.cn"],
     methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
     allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
     supports_credentials=True,
     expose_headers=["Content-Type", "Authorization"])

# Flask-Login setup.
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = 'login'
login_manager.login_message = '请先登录访问此页面'
login_manager.remember_cookie_duration = timedelta(days=30)  # remember-me duration

# Request body size limit. Flask only enforces it via app.config (the bare module
# constant alone had no effect); larger requests are rejected with HTTP 413.
MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16MB max file size
app.config['MAX_CONTENT_LENGTH'] = MAX_CONTENT_LENGTH

# Flask-Compress reads COMPRESS_ALGORITHM inside Compress(app)/init_app, so it must be
# configured before initialization; setting it afterwards was silently ignored.
app.config['COMPRESS_ALGORITHM'] = ['gzip', 'br']
Compress(app)
app.config['COMPRESS_MIMETYPES'] = [
    'text/html',
    'text/css',
    'text/xml',
    'application/json',
    'application/javascript',
    'application/x-javascript'
]
# SECURITY: hard-coded database credentials; move to an environment variable /
# secret store and rotate the password.
app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock?charset=utf8mb4'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['SQLALCHEMY_ENGINE_OPTIONS'] = {
    'pool_size': 10,
    'pool_recycle': 3600,
    'pool_pre_ping': True,
    'pool_timeout': 30,
    'max_overflow': 20
}

# Cache directory setup (relative to the process working directory).
CACHE_DIR = Path('cache')
CACHE_DIR.mkdir(exist_ok=True)


def beijing_now():
    """Return the current Asia/Shanghai wall-clock time as a naive datetime.

    pytz handles the time-zone conversion; tzinfo is then stripped so the value
    can be stored directly in naive database DATETIME columns.
    """
    beijing_tz = pytz.timezone('Asia/Shanghai')
    return datetime.now(beijing_tz).replace(tzinfo=None)


def login_required(f):
    """Decorator for JSON API views that rejects requests without a logged-in session.

    When ``'user_id'`` is missing from the Flask session, the view is not called
    and ``{'success': False, 'error': '未登录'}`` is returned with HTTP 401.

    NOTE(review): this shadows ``flask_login.login_required`` imported at the top of
    the file, and it checks ``session['user_id']``, not Flask-Login's own session key.
    It only recognises logins that set ``session['user_id']`` explicitly; confirm that
    every login path does so.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if 'user_id' not in session:
            return jsonify({'success': False, 'error': '未登录'}), 401
        return f(*args, **kwargs)

    return decorated_function


# Memory management constants
MAX_MEMORY_PERCENT = 75
MEMORY_CHECK_INTERVAL = 300
MAX_CACHE_ITEMS = 50

db = SQLAlchemy(app)

# Mail service
mail = Mail(app)

# Flask-SocketIO for real-time event push.
socketio = SocketIO(
    app,
    cors_allowed_origins=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5173",
                          "https://valuefrontier.cn", "http://valuefrontier.cn"],
    async_mode='gevent',
    logger=True,
    engineio_logger=False,
    ping_timeout=120,  # seconds without a client response before disconnecting
    ping_interval=25  # seconds between server pings
)


@login_manager.user_loader
def load_user(user_id):
    """Flask-Login user loader: return the ``User`` for ``user_id``, or ``None``.

    Any failure (non-numeric id, database error) is logged and treated as
    "no user", which Flask-Login handles as an anonymous request.
    """
    try:
        # Session.get is the SQLAlchemy 1.4+/2.0 replacement for the legacy Query.get.
        return db.session.get(User, int(user_id))
    except Exception as e:
        app.logger.error(f"用户加载错误: {e}")
        return None


# Global error handlers: API routes (/api/...) always receive a JSON body; other
# paths get the exception object back, which Flask renders as its default error page.
@app.errorhandler(404)
def not_found_error(error):
    """404 handler."""
    if request.path.startswith('/api/'):
        return jsonify({'success': False, 'error': '接口不存在'}), 404
    return error


@app.errorhandler(500)
def internal_error(error):
    """500 handler: roll back the failed DB transaction, then respond."""
    try:
        # The failing request may have left the session in an unusable state.
        db.session.rollback()
    except Exception as rollback_error:
        # A broken DB connection must not mask the error response itself.
        app.logger.error("Rollback failed in 500 handler: %s", rollback_error)
    if request.path.startswith('/api/'):
        return jsonify({'success': False, 'error': '服务器内部错误'}), 500
    return error


@app.errorhandler(405)
def method_not_allowed_error(error):
    """405 handler."""
    if request.path.startswith('/api/'):
        return jsonify({'success': False, 'error': '请求方法不被允许'}), 405
    return error


class Post(db.Model):
    """A post attached to an event."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)

    # Content
    title = db.Column(db.String(200))  # optional title
    content = db.Column(db.Text, nullable=False)
    content_type = db.Column(db.String(20), default='text')  # text / rich_text / link

    # Timestamps
    created_at = db.Column(db.DateTime, default=beijing_now)
    updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now)

    # Counters
    likes_count = db.Column(db.Integer, default=0)
    comments_count = db.Column(db.Integer, default=0)
    view_count = db.Column(db.Integer, default=0)

    # State
    status = db.Column(db.String(20), default='active')  # active / hidden / deleted
    is_top = db.Column(db.Boolean, default=False)  # pinned to the top

    # Relationships
    user = db.relationship('User', backref='posts')
    likes = db.relationship('PostLike', backref='post', lazy='dynamic')
    comments = db.relationship('Comment', backref='post', lazy='dynamic')


class Comment(db.Model):
    """A comment on a post; ``parent_id`` links a reply to the comment it answers."""
    id = db.Column(db.Integer, primary_key=True)
    post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)

    # Content
    content = db.Column(db.Text, nullable=False)
    parent_id = db.Column(db.Integer, db.ForeignKey('comment.id'))

    # Timestamps
    created_at = db.Column(db.DateTime, default=beijing_now)
    updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now)

    # Counters
    likes_count = db.Column(db.Integer, default=0)

    # State
    status = db.Column(db.String(20), default='active')  # active / hidden / deleted

    # Relationships
    user = db.relationship('User', backref='comments')
    # Self-referential adjacency list: remote_side=[id] makes `parent` the many-to-one side.
    replies = db.relationship('Comment', backref=db.backref('parent', remote_side=[id]))
class User(UserMixin, db.Model): """用户模型 - 完全匹配现有数据库表结构""" __tablename__ = 'user' # 主键 id = db.Column(db.Integer, primary_key=True, autoincrement=True) # 基础账号信息 username = db.Column(db.String(80), unique=True, nullable=False) email = db.Column(db.String(120), unique=True, nullable=True) password_hash = db.Column(db.String(255), nullable=True) email_confirmed = db.Column(db.Boolean, nullable=True, default=True) # 时间字段 created_at = db.Column(db.DateTime, nullable=True, default=beijing_now) last_seen = db.Column(db.DateTime, nullable=True, default=beijing_now) # 账号状态 status = db.Column(db.String(20), nullable=True, default='active') # 个人资料信息 nickname = db.Column(db.String(30), nullable=True) avatar_url = db.Column(db.String(200), nullable=True) banner_url = db.Column(db.String(200), nullable=True) bio = db.Column(db.String(200), nullable=True) gender = db.Column(db.String(10), nullable=True) birth_date = db.Column(db.Date, nullable=True) location = db.Column(db.String(100), nullable=True) # 联系方式 phone = db.Column(db.String(20), nullable=True) wechat_id = db.Column(db.String(80), nullable=True) # 微信号 # 实名认证 real_name = db.Column(db.String(30), nullable=True) id_number = db.Column(db.String(18), nullable=True) is_verified = db.Column(db.Boolean, nullable=True, default=False) verify_time = db.Column(db.DateTime, nullable=True) # 投资偏好 trading_experience = db.Column(db.String(200), nullable=True) investment_style = db.Column(db.String(50), nullable=True) risk_preference = db.Column(db.String(20), nullable=True) investment_amount = db.Column(db.String(20), nullable=True) preferred_markets = db.Column(db.String(200), nullable=True) # 社区数据 user_level = db.Column(db.Integer, nullable=True, default=1) reputation_score = db.Column(db.Integer, nullable=True, default=0) contribution_point = db.Column(db.Integer, nullable=True, default=0) post_count = db.Column(db.Integer, nullable=True, default=0) comment_count = db.Column(db.Integer, nullable=True, default=0) follower_count = 
db.Column(db.Integer, nullable=True, default=0) following_count = db.Column(db.Integer, nullable=True, default=0) # 创作者相关 is_creator = db.Column(db.Boolean, nullable=True, default=False) creator_type = db.Column(db.String(20), nullable=True) creator_tags = db.Column(db.String(200), nullable=True) # 通知设置 email_notifications = db.Column(db.Boolean, nullable=True, default=True) sms_notifications = db.Column(db.Boolean, nullable=True, default=False) wechat_notifications = db.Column(db.Boolean, nullable=True, default=False) notification_preferences = db.Column(db.String(500), nullable=True) # 隐私和界面设置 privacy_level = db.Column(db.String(20), nullable=True, default='public') theme_preference = db.Column(db.String(20), nullable=True, default='light') blocked_keywords = db.Column(db.String(500), nullable=True) # 手机验证相关 phone_confirmed = db.Column(db.Boolean, nullable=True, default=False) # 注意:原表中是blob,这里改为Boolean更合理 phone_confirm_time = db.Column(db.DateTime, nullable=True) # 微信登录相关字段 wechat_union_id = db.Column(db.String(100), nullable=True) # 微信UnionID wechat_open_id = db.Column(db.String(100), nullable=True) # 微信OpenID def __init__(self, username, email=None, password=None, phone=None): """初始化用户""" self.username = username if email: self.email = email if phone: self.phone = phone if password: self.set_password(password) self.nickname = username # 默认昵称为用户名 self.created_at = beijing_now() self.last_seen = beijing_now() def set_password(self, password): """设置密码""" if password: self.password_hash = generate_password_hash(password) def check_password(self, password): """验证密码""" if not password or not self.password_hash: return False return check_password_hash(self.password_hash, password) def update_last_seen(self): """更新最后活跃时间""" self.last_seen = beijing_now() db.session.commit() def confirm_email(self): """确认邮箱""" self.email_confirmed = True db.session.commit() def confirm_phone(self): """确认手机号""" self.phone_confirmed = True self.phone_confirm_time = beijing_now() 
db.session.commit() def bind_wechat(self, open_id, union_id=None, wechat_info=None): """绑定微信账号""" self.wechat_open_id = open_id if union_id: self.wechat_union_id = union_id # 如果提供了微信用户信息,更新头像和昵称 if wechat_info: if not self.avatar_url and wechat_info.get('headimgurl'): self.avatar_url = wechat_info['headimgurl'] if not self.nickname and wechat_info.get('nickname'): # 确保昵称编码正确且长度合理 nickname = self._sanitize_nickname(wechat_info['nickname']) self.nickname = nickname db.session.commit() def _sanitize_nickname(self, nickname): """清理和验证昵称""" if not nickname: return '微信用户' try: # 确保是正确的UTF-8字符串 sanitized = str(nickname).strip() # 移除可能的控制字符 import re sanitized = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', sanitized) # 限制长度(避免过长的昵称) if len(sanitized) > 50: sanitized = sanitized[:47] + '...' # 如果清理后为空,使用默认值 if not sanitized: sanitized = '微信用户' return sanitized except Exception as e: return '微信用户' def unbind_wechat(self): """解绑微信账号""" self.wechat_open_id = None self.wechat_union_id = None db.session.commit() def increment_post_count(self): """增加发帖数""" self.post_count = (self.post_count or 0) + 1 db.session.commit() def increment_comment_count(self): """增加评论数""" self.comment_count = (self.comment_count or 0) + 1 db.session.commit() def add_reputation(self, points): """增加声誉分数""" self.reputation_score = (self.reputation_score or 0) + points db.session.commit() def to_dict(self, include_sensitive=False): """转换为字典""" data = { 'id': self.id, 'username': self.username, 'nickname': self.nickname or self.username, 'avatar_url': self.avatar_url, 'banner_url': self.banner_url, 'bio': self.bio, 'gender': self.gender, 'location': self.location, 'user_level': self.user_level or 1, 'reputation_score': self.reputation_score or 0, 'contribution_point': self.contribution_point or 0, 'post_count': self.post_count or 0, 'comment_count': self.comment_count or 0, 'follower_count': self.follower_count or 0, 'following_count': self.following_count or 0, 'is_creator': self.is_creator or False, 
'creator_type': self.creator_type, 'creator_tags': self.creator_tags, 'is_verified': self.is_verified or False, 'created_at': self.created_at.isoformat() if self.created_at else None, 'last_seen': self.last_seen.isoformat() if self.last_seen else None, 'status': self.status, 'has_wechat': bool(self.wechat_open_id), 'is_authenticated': True } # 获取用户订阅信息(从 user_subscriptions 表) subscription = UserSubscription.query.filter_by(user_id=self.id).first() if subscription: data.update({ 'subscription_type': subscription.subscription_type, 'subscription_status': subscription.subscription_status, 'billing_cycle': subscription.billing_cycle, 'start_date': subscription.start_date.isoformat() if subscription.start_date else None, 'end_date': subscription.end_date.isoformat() if subscription.end_date else None, 'auto_renewal': subscription.auto_renewal }) else: # 无订阅时使用默认值 data.update({ 'subscription_type': 'free', 'subscription_status': 'inactive', 'billing_cycle': None, 'start_date': None, 'end_date': None, 'auto_renewal': False }) # 敏感信息只在需要时包含 if include_sensitive: data.update({ 'email': self.email, 'phone': self.phone, 'email_confirmed': self.email_confirmed, 'phone_confirmed': self.phone_confirmed, 'real_name': self.real_name, 'birth_date': self.birth_date.isoformat() if self.birth_date else None, 'trading_experience': self.trading_experience, 'investment_style': self.investment_style, 'risk_preference': self.risk_preference, 'investment_amount': self.investment_amount, 'preferred_markets': self.preferred_markets, 'email_notifications': self.email_notifications, 'sms_notifications': self.sms_notifications, 'wechat_notifications': self.wechat_notifications, 'privacy_level': self.privacy_level, 'theme_preference': self.theme_preference }) return data def to_public_dict(self): """公开信息字典(用于显示给其他用户)""" return { 'id': self.id, 'username': self.username, 'nickname': self.nickname or self.username, 'avatar_url': self.avatar_url, 'bio': self.bio, 'user_level': self.user_level or 1, 
'reputation_score': self.reputation_score or 0, 'post_count': self.post_count or 0, 'follower_count': self.follower_count or 0, 'is_creator': self.is_creator or False, 'creator_type': self.creator_type, 'is_verified': self.is_verified or False, 'created_at': self.created_at.isoformat() if self.created_at else None } @staticmethod def find_by_login_info(login_info): """根据登录信息查找用户(支持用户名、邮箱、手机号)""" return User.query.filter( db.or_( User.username == login_info, User.email == login_info, User.phone == login_info ) ).first() @staticmethod def find_by_wechat_openid(open_id): """根据微信OpenID查找用户""" return User.query.filter_by(wechat_open_id=open_id).first() @staticmethod def find_by_wechat_unionid(union_id): """根据微信UnionID查找用户""" return User.query.filter_by(wechat_union_id=union_id).first() @staticmethod def is_username_taken(username): """检查用户名是否已被使用""" return User.query.filter_by(username=username).first() is not None @staticmethod def is_email_taken(email): """检查邮箱是否已被使用""" return User.query.filter_by(email=email).first() is not None @staticmethod def is_phone_taken(phone): """检查手机号是否已被使用""" return User.query.filter_by(phone=phone).first() is not None def __repr__(self): return f'' # ============================================ # 订阅功能模块(安全版本 - 独立表) # ============================================ class UserSubscription(db.Model): """用户订阅表 - 独立于现有User表""" __tablename__ = 'user_subscriptions' id = db.Column(db.Integer, primary_key=True, autoincrement=True) user_id = db.Column(db.Integer, nullable=False, unique=True, index=True) subscription_type = db.Column(db.String(10), nullable=False, default='free') subscription_status = db.Column(db.String(20), nullable=False, default='active') start_date = db.Column(db.DateTime, nullable=True) end_date = db.Column(db.DateTime, nullable=True) billing_cycle = db.Column(db.String(10), nullable=True) auto_renewal = db.Column(db.Boolean, nullable=False, default=False) created_at = db.Column(db.DateTime, default=beijing_now) updated_at = 
db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) def is_active(self): if self.subscription_status != 'active': return False if self.subscription_type == 'free': return True if self.end_date: try: now = beijing_now() if self.end_date < now: return False except Exception as e: return False return True def days_left(self): if self.subscription_type == 'free' or not self.end_date: return 999 try: now = beijing_now() delta = self.end_date - now return max(0, delta.days) except Exception as e: return 0 def to_dict(self): return { 'type': self.subscription_type, 'status': self.subscription_status, 'is_active': self.is_active(), 'days_left': self.days_left(), 'start_date': self.start_date.isoformat() if self.start_date else None, 'end_date': self.end_date.isoformat() if self.end_date else None, 'billing_cycle': self.billing_cycle, 'auto_renewal': self.auto_renewal } class SubscriptionPlan(db.Model): """订阅套餐表""" __tablename__ = 'subscription_plans' id = db.Column(db.Integer, primary_key=True, autoincrement=True) name = db.Column(db.String(50), nullable=False, unique=True) display_name = db.Column(db.String(100), nullable=False) description = db.Column(db.Text, nullable=True) monthly_price = db.Column(db.Numeric(10, 2), nullable=False) yearly_price = db.Column(db.Numeric(10, 2), nullable=False) features = db.Column(db.Text, nullable=True) pricing_options = db.Column(db.Text, nullable=True) # JSON格式:[{"months": 1, "price": 99}, {"months": 12, "price": 999}] is_active = db.Column(db.Boolean, default=True) sort_order = db.Column(db.Integer, default=0) created_at = db.Column(db.DateTime, default=beijing_now) def to_dict(self): # 解析pricing_options(如果存在) pricing_opts = None if self.pricing_options: try: pricing_opts = json.loads(self.pricing_options) except: pricing_opts = None # 如果没有pricing_options,则从monthly_price和yearly_price生成默认选项 if not pricing_opts: pricing_opts = [ { 'months': 1, 'price': float(self.monthly_price) if self.monthly_price else 0, 'label': '月付', 
'cycle_key': 'monthly' }, { 'months': 12, 'price': float(self.yearly_price) if self.yearly_price else 0, 'label': '年付', 'cycle_key': 'yearly', 'discount_percent': 20 # 年付默认20%折扣 } ] return { 'id': self.id, 'name': self.name, 'display_name': self.display_name, 'description': self.description, 'monthly_price': float(self.monthly_price) if self.monthly_price else 0, 'yearly_price': float(self.yearly_price) if self.yearly_price else 0, 'pricing_options': pricing_opts, # 新增:灵活计费周期选项 'features': json.loads(self.features) if self.features else [], 'is_active': self.is_active, 'sort_order': self.sort_order } class PaymentOrder(db.Model): """支付订单表""" __tablename__ = 'payment_orders' id = db.Column(db.Integer, primary_key=True, autoincrement=True) order_no = db.Column(db.String(32), unique=True, nullable=False) user_id = db.Column(db.Integer, nullable=False) plan_name = db.Column(db.String(20), nullable=False) billing_cycle = db.Column(db.String(10), nullable=False) amount = db.Column(db.Numeric(10, 2), nullable=False) wechat_order_id = db.Column(db.String(64), nullable=True) prepay_id = db.Column(db.String(64), nullable=True) qr_code_url = db.Column(db.String(200), nullable=True) status = db.Column(db.String(20), default='pending') created_at = db.Column(db.DateTime, default=beijing_now) paid_at = db.Column(db.DateTime, nullable=True) expired_at = db.Column(db.DateTime, nullable=True) remark = db.Column(db.String(200), nullable=True) def __init__(self, user_id, plan_name, billing_cycle, amount): self.user_id = user_id self.plan_name = plan_name self.billing_cycle = billing_cycle self.amount = amount import random timestamp = int(beijing_now().timestamp() * 1000000) random_suffix = random.randint(1000, 9999) self.order_no = f"{timestamp}{user_id:04d}{random_suffix}" self.expired_at = beijing_now() + timedelta(minutes=30) def is_expired(self): if not self.expired_at: return False try: now = beijing_now() return now > self.expired_at except Exception as e: return False def 
mark_as_paid(self, wechat_order_id, transaction_id=None): self.status = 'paid' self.paid_at = beijing_now() self.wechat_order_id = wechat_order_id def to_dict(self): return { 'id': self.id, 'order_no': self.order_no, 'user_id': self.user_id, 'plan_name': self.plan_name, 'billing_cycle': self.billing_cycle, 'amount': float(self.amount) if self.amount else 0, 'original_amount': float(self.original_amount) if hasattr(self, 'original_amount') and self.original_amount else None, 'discount_amount': float(self.discount_amount) if hasattr(self, 'discount_amount') and self.discount_amount else 0, 'promo_code': self.promo_code.code if hasattr(self, 'promo_code') and self.promo_code else None, 'is_upgrade': self.is_upgrade if hasattr(self, 'is_upgrade') else False, 'qr_code_url': self.qr_code_url, 'status': self.status, 'is_expired': self.is_expired(), 'created_at': self.created_at.isoformat() if self.created_at else None, 'paid_at': self.paid_at.isoformat() if self.paid_at else None, 'expired_at': self.expired_at.isoformat() if self.expired_at else None, 'remark': self.remark } class PromoCode(db.Model): """优惠码表""" __tablename__ = 'promo_codes' id = db.Column(db.Integer, primary_key=True, autoincrement=True) code = db.Column(db.String(50), unique=True, nullable=False, index=True) description = db.Column(db.String(200), nullable=True) # 折扣类型和值 discount_type = db.Column(db.String(20), nullable=False) # 'percentage' 或 'fixed_amount' discount_value = db.Column(db.Numeric(10, 2), nullable=False) # 适用范围 applicable_plans = db.Column(db.String(200), nullable=True) # JSON格式 applicable_cycles = db.Column(db.String(50), nullable=True) # JSON格式 min_amount = db.Column(db.Numeric(10, 2), nullable=True) # 使用限制 max_uses = db.Column(db.Integer, nullable=True) max_uses_per_user = db.Column(db.Integer, default=1) current_uses = db.Column(db.Integer, default=0) # 有效期 valid_from = db.Column(db.DateTime, nullable=False) valid_until = db.Column(db.DateTime, nullable=False) # 状态 is_active = 
db.Column(db.Boolean, default=True) created_by = db.Column(db.Integer, nullable=True) created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) def to_dict(self): return { 'id': self.id, 'code': self.code, 'description': self.description, 'discount_type': self.discount_type, 'discount_value': float(self.discount_value) if self.discount_value else 0, 'applicable_plans': json.loads(self.applicable_plans) if self.applicable_plans else None, 'applicable_cycles': json.loads(self.applicable_cycles) if self.applicable_cycles else None, 'min_amount': float(self.min_amount) if self.min_amount else None, 'max_uses': self.max_uses, 'max_uses_per_user': self.max_uses_per_user, 'current_uses': self.current_uses, 'valid_from': self.valid_from.isoformat() if self.valid_from else None, 'valid_until': self.valid_until.isoformat() if self.valid_until else None, 'is_active': self.is_active } class PromoCodeUsage(db.Model): """优惠码使用记录表""" __tablename__ = 'promo_code_usage' id = db.Column(db.Integer, primary_key=True, autoincrement=True) promo_code_id = db.Column(db.Integer, db.ForeignKey('promo_codes.id'), nullable=False) user_id = db.Column(db.Integer, nullable=False, index=True) order_id = db.Column(db.Integer, db.ForeignKey('payment_orders.id'), nullable=False) original_amount = db.Column(db.Numeric(10, 2), nullable=False) discount_amount = db.Column(db.Numeric(10, 2), nullable=False) final_amount = db.Column(db.Numeric(10, 2), nullable=False) used_at = db.Column(db.DateTime, default=beijing_now) # 关系 promo_code = db.relationship('PromoCode', backref='usages') order = db.relationship('PaymentOrder', backref='promo_usage') class SubscriptionUpgrade(db.Model): """订阅升级/降级记录表""" __tablename__ = 'subscription_upgrades' id = db.Column(db.Integer, primary_key=True, autoincrement=True) user_id = db.Column(db.Integer, nullable=False, index=True) order_id = db.Column(db.Integer, db.ForeignKey('payment_orders.id'), 
nullable=False) # 原订阅信息 from_plan = db.Column(db.String(20), nullable=False) from_cycle = db.Column(db.String(10), nullable=False) from_end_date = db.Column(db.DateTime, nullable=True) # 新订阅信息 to_plan = db.Column(db.String(20), nullable=False) to_cycle = db.Column(db.String(10), nullable=False) to_end_date = db.Column(db.DateTime, nullable=False) # 价格计算 remaining_value = db.Column(db.Numeric(10, 2), nullable=False) upgrade_amount = db.Column(db.Numeric(10, 2), nullable=False) actual_amount = db.Column(db.Numeric(10, 2), nullable=False) upgrade_type = db.Column(db.String(20), nullable=False) # 'plan_upgrade', 'cycle_change', 'both' created_at = db.Column(db.DateTime, default=beijing_now) # 关系 order = db.relationship('PaymentOrder', backref='upgrade_record') # ============================================ # 模拟盘相关模型 # ============================================ class SimulationAccount(db.Model): """模拟账户""" __tablename__ = 'simulation_accounts' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False, unique=True) account_name = db.Column(db.String(100), default='我的模拟账户') initial_capital = db.Column(db.Numeric(15, 2), default=1000000.00) # 初始资金 available_cash = db.Column(db.Numeric(15, 2), default=1000000.00) # 可用资金 frozen_cash = db.Column(db.Numeric(15, 2), default=0.00) # 冻结资金 position_value = db.Column(db.Numeric(15, 2), default=0.00) # 持仓市值 total_assets = db.Column(db.Numeric(15, 2), default=1000000.00) # 总资产 total_profit = db.Column(db.Numeric(15, 2), default=0.00) # 总盈亏 total_profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 总收益率 daily_profit = db.Column(db.Numeric(15, 2), default=0.00) # 日盈亏 daily_profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 日收益率 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) last_settlement_date = db.Column(db.Date) # 最后结算日期 # 关系 user = db.relationship('User', 
backref='simulation_account') positions = db.relationship('SimulationPosition', backref='account', lazy='dynamic') orders = db.relationship('SimulationOrder', backref='account', lazy='dynamic') transactions = db.relationship('SimulationTransaction', backref='account', lazy='dynamic') def calculate_total_assets(self): """计算总资产""" self.total_assets = self.available_cash + self.frozen_cash + self.position_value self.total_profit = self.total_assets - self.initial_capital self.total_profit_rate = (self.total_profit / self.initial_capital) * 100 if self.initial_capital > 0 else 0 return self.total_assets class SimulationPosition(db.Model): """模拟持仓""" __tablename__ = 'simulation_positions' id = db.Column(db.Integer, primary_key=True) account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) stock_code = db.Column(db.String(20), nullable=False) stock_name = db.Column(db.String(100)) position_qty = db.Column(db.Integer, default=0) # 持仓数量 available_qty = db.Column(db.Integer, default=0) # 可用数量(T+1) frozen_qty = db.Column(db.Integer, default=0) # 冻结数量 avg_cost = db.Column(db.Numeric(10, 3), default=0.00) # 平均成本 current_price = db.Column(db.Numeric(10, 3), default=0.00) # 当前价格 market_value = db.Column(db.Numeric(15, 2), default=0.00) # 市值 profit = db.Column(db.Numeric(15, 2), default=0.00) # 盈亏 profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 盈亏比例 today_profit = db.Column(db.Numeric(15, 2), default=0.00) # 今日盈亏 today_profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 今日盈亏比例 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) __table_args__ = ( db.UniqueConstraint('account_id', 'stock_code', name='unique_account_stock'), ) def update_market_value(self, current_price): """更新市值和盈亏""" self.current_price = current_price self.market_value = self.position_qty * current_price total_cost = self.position_qty * self.avg_cost self.profit = 
self.market_value - total_cost self.profit_rate = (self.profit / total_cost * 100) if total_cost > 0 else 0 return self.market_value class SimulationOrder(db.Model): """模拟订单""" __tablename__ = 'simulation_orders' id = db.Column(db.Integer, primary_key=True) account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) order_no = db.Column(db.String(32), unique=True, nullable=False) stock_code = db.Column(db.String(20), nullable=False) stock_name = db.Column(db.String(100)) order_type = db.Column(db.String(10), nullable=False) # BUY/SELL price_type = db.Column(db.String(10), default='MARKET') # MARKET/LIMIT order_price = db.Column(db.Numeric(10, 3)) # 委托价格 order_qty = db.Column(db.Integer, nullable=False) # 委托数量 filled_qty = db.Column(db.Integer, default=0) # 成交数量 filled_price = db.Column(db.Numeric(10, 3)) # 成交价格 filled_amount = db.Column(db.Numeric(15, 2)) # 成交金额 commission = db.Column(db.Numeric(10, 2), default=0.00) # 手续费 stamp_tax = db.Column(db.Numeric(10, 2), default=0.00) # 印花税 transfer_fee = db.Column(db.Numeric(10, 2), default=0.00) # 过户费 total_fee = db.Column(db.Numeric(10, 2), default=0.00) # 总费用 status = db.Column(db.String(20), default='PENDING') # PENDING/PARTIAL/FILLED/CANCELLED/REJECTED reject_reason = db.Column(db.String(200)) order_time = db.Column(db.DateTime, default=beijing_now) filled_time = db.Column(db.DateTime) cancel_time = db.Column(db.DateTime) def calculate_fees(self): """计算交易费用""" if not self.filled_amount: return 0 # 佣金(万分之2.5,最低5元) self.commission = max(float(self.filled_amount) * 0.00025, 5.0) # 印花税(卖出时收取千分之1) if self.order_type == 'SELL': self.stamp_tax = float(self.filled_amount) * 0.001 else: self.stamp_tax = 0 # 过户费(双向收取,万分之0.2) self.transfer_fee = float(self.filled_amount) * 0.00002 # 总费用 self.total_fee = self.commission + self.stamp_tax + self.transfer_fee return self.total_fee class SimulationTransaction(db.Model): """模拟成交记录""" __tablename__ = 'simulation_transactions' id = db.Column(db.Integer, 
primary_key=True) account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) order_id = db.Column(db.Integer, db.ForeignKey('simulation_orders.id'), nullable=False) transaction_no = db.Column(db.String(32), unique=True, nullable=False) stock_code = db.Column(db.String(20), nullable=False) stock_name = db.Column(db.String(100)) transaction_type = db.Column(db.String(10), nullable=False) # BUY/SELL transaction_price = db.Column(db.Numeric(10, 3), nullable=False) transaction_qty = db.Column(db.Integer, nullable=False) transaction_amount = db.Column(db.Numeric(15, 2), nullable=False) commission = db.Column(db.Numeric(10, 2), default=0.00) stamp_tax = db.Column(db.Numeric(10, 2), default=0.00) transfer_fee = db.Column(db.Numeric(10, 2), default=0.00) total_fee = db.Column(db.Numeric(10, 2), default=0.00) transaction_time = db.Column(db.DateTime, default=beijing_now) settlement_date = db.Column(db.Date) # T+1结算日期 # 关系 order = db.relationship('SimulationOrder', backref='transactions') class SimulationDailyStats(db.Model): """模拟账户日统计""" __tablename__ = 'simulation_daily_stats' id = db.Column(db.Integer, primary_key=True) account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) stat_date = db.Column(db.Date, nullable=False) opening_assets = db.Column(db.Numeric(15, 2)) # 期初资产 closing_assets = db.Column(db.Numeric(15, 2)) # 期末资产 daily_profit = db.Column(db.Numeric(15, 2)) # 日盈亏 daily_profit_rate = db.Column(db.Numeric(10, 4)) # 日收益率 total_profit = db.Column(db.Numeric(15, 2)) # 累计盈亏 total_profit_rate = db.Column(db.Numeric(10, 4)) # 累计收益率 trade_count = db.Column(db.Integer, default=0) # 交易次数 win_count = db.Column(db.Integer, default=0) # 盈利次数 loss_count = db.Column(db.Integer, default=0) # 亏损次数 max_profit = db.Column(db.Numeric(15, 2)) # 最大盈利 max_loss = db.Column(db.Numeric(15, 2)) # 最大亏损 created_at = db.Column(db.DateTime, default=beijing_now) __table_args__ = ( db.UniqueConstraint('account_id', 
'stat_date', name='unique_account_date'), ) def get_user_subscription_safe(user_id): """安全地获取用户订阅信息""" try: subscription = UserSubscription.query.filter_by(user_id=user_id).first() if not subscription: subscription = UserSubscription(user_id=user_id) db.session.add(subscription) db.session.commit() return subscription except Exception as e: # 返回默认免费版本对象 class DefaultSub: def to_dict(self): return { 'type': 'free', 'status': 'active', 'is_active': True, 'days_left': 999, 'billing_cycle': None, 'auto_renewal': False } return DefaultSub() def activate_user_subscription(user_id, plan_type, billing_cycle, extend_from_now=False): """ 激活用户订阅(新版:续费时从当前订阅结束时间开始延长) Args: user_id: 用户ID plan_type: 套餐类型 (pro/max) billing_cycle: 计费周期 (monthly/quarterly/semiannual/yearly) extend_from_now: 废弃参数,保留以兼容(现在自动判断) Returns: UserSubscription 对象 或 None """ try: subscription = UserSubscription.query.filter_by(user_id=user_id).first() if not subscription: # 新用户,创建订阅记录 subscription = UserSubscription(user_id=user_id) db.session.add(subscription) # 更新订阅类型和状态 subscription.subscription_type = plan_type subscription.subscription_status = 'active' subscription.billing_cycle = billing_cycle # 计算订阅周期天数 cycle_days_map = { 'monthly': 30, 'quarterly': 90, # 3个月 'semiannual': 180, # 6个月 'yearly': 365 } days = cycle_days_map.get(billing_cycle, 30) now = beijing_now() # 判断是新购还是续费 if subscription.end_date and subscription.end_date > now: # 续费:从当前订阅结束时间开始延长 start_date = subscription.end_date end_date = start_date + timedelta(days=days) else: # 新购或过期后重新购买:从当前时间开始 start_date = now end_date = now + timedelta(days=days) subscription.start_date = start_date subscription.end_date = end_date subscription.updated_at = now db.session.commit() return subscription except Exception as e: print(f"激活订阅失败: {e}") db.session.rollback() return None def validate_promo_code(code, plan_name, billing_cycle, amount, user_id): """验证优惠码 Returns: tuple: (promo_code_obj, error_message) """ try: promo = 
PromoCode.query.filter_by(code=code.upper(), is_active=True).first() if not promo: return None, "优惠码不存在或已失效" # 检查有效期 now = beijing_now() if now < promo.valid_from: return None, "优惠码尚未生效" if now > promo.valid_until: return None, "优惠码已过期" # 检查使用次数 if promo.max_uses and promo.current_uses >= promo.max_uses: return None, "优惠码已被使用完" # 检查每用户使用次数 if promo.max_uses_per_user: user_usage_count = PromoCodeUsage.query.filter_by( promo_code_id=promo.id, user_id=user_id ).count() if user_usage_count >= promo.max_uses_per_user: return None, f"您已使用过此优惠码(限用{promo.max_uses_per_user}次)" # 检查适用套餐 if promo.applicable_plans: try: applicable = json.loads(promo.applicable_plans) if plan_name not in applicable: return None, "该优惠码不适用于此套餐" except: pass # 检查适用周期 if promo.applicable_cycles: try: applicable = json.loads(promo.applicable_cycles) if billing_cycle not in applicable: return None, "该优惠码不适用于此计费周期" except: pass # 检查最低消费 if promo.min_amount and amount < float(promo.min_amount): return None, f"需满{float(promo.min_amount):.2f}元才可使用此优惠码" return promo, None except Exception as e: return None, f"验证优惠码时出错: {str(e)}" def calculate_discount(promo_code, amount): """计算优惠金额""" try: if promo_code.discount_type == 'percentage': discount = amount * (float(promo_code.discount_value) / 100) else: # fixed_amount discount = float(promo_code.discount_value) # 确保折扣不超过总金额 return min(discount, amount) except: return 0 def calculate_subscription_price_simple(user_id, to_plan_name, to_cycle, promo_code=None): """ 简化版价格计算:续费用户和新用户价格完全一致,不计算剩余价值 Args: user_id: 用户ID to_plan_name: 目标套餐名称 (pro/max) to_cycle: 计费周期 (monthly/quarterly/semiannual/yearly) promo_code: 优惠码(可选) Returns: dict: { 'is_renewal': False/True, # 是否为续费 'subscription_type': 'new'/'renew', # 订阅类型 'current_plan': 'pro', # 当前套餐(如果有) 'current_cycle': 'yearly', # 当前周期(如果有) 'new_plan_price': 2699.00, # 新套餐价格 'original_amount': 2699.00, # 原价 'discount_amount': 0, # 优惠金额 'final_amount': 2699.00, # 实付金额 'promo_code': None, # 使用的优惠码 'promo_error': None # 
优惠码错误信息 } """ try: # 1. 获取当前订阅 current_sub = UserSubscription.query.filter_by(user_id=user_id).first() # 2. 获取目标套餐 to_plan = SubscriptionPlan.query.filter_by(name=to_plan_name, is_active=True).first() if not to_plan: return {'error': '目标套餐不存在'} # 3. 根据计费周期获取价格 # 优先从 pricing_options 获取价格 price = None if to_plan.pricing_options: try: pricing_opts = json.loads(to_plan.pricing_options) # 查找匹配的周期 for opt in pricing_opts: cycle_key = opt.get('cycle_key', '') months = opt.get('months', 0) # 匹配逻辑 if (cycle_key == to_cycle or (to_cycle == 'monthly' and months == 1) or (to_cycle == 'quarterly' and months == 3) or (to_cycle == 'semiannual' and months == 6) or (to_cycle == 'yearly' and months == 12)): price = float(opt.get('price', 0)) break except: pass # 如果 pricing_options 中没有找到,使用旧的 monthly_price/yearly_price if price is None: if to_cycle == 'yearly': price = float(to_plan.yearly_price) if to_plan.yearly_price else 0 else: # 默认月付 price = float(to_plan.monthly_price) if to_plan.monthly_price else 0 if price <= 0: return {'error': f'{to_cycle} 周期价格未配置'} # 4. 
判断订阅类型和计算价格 is_renewal = False is_upgrade = False is_downgrade = False subscription_type = 'new' current_plan = None current_cycle = None remaining_value = 0 final_price = price if current_sub and current_sub.subscription_type in ['pro', 'max']: current_plan = current_sub.subscription_type current_cycle = current_sub.billing_cycle if current_plan == to_plan_name: # 同级续费:延长时长,全价购买 is_renewal = True subscription_type = 'renew' elif current_plan == 'pro' and to_plan_name == 'max': # 升级:Pro → Max,需要计算差价 is_upgrade = True subscription_type = 'upgrade' # 计算当前订阅的剩余价值 if current_sub.end_date and current_sub.end_date > datetime.utcnow(): # 获取当前套餐的原始价格 current_plan_obj = SubscriptionPlan.query.filter_by(name=current_plan, is_active=True).first() if current_plan_obj: current_price = None # 优先从 pricing_options 获取价格 if current_plan_obj.pricing_options: try: pricing_opts = json.loads(current_plan_obj.pricing_options) # 如果 current_cycle 为空或无效,根据剩余天数推断计费周期 if not current_cycle or current_cycle.strip() == '': remaining_days_total = (current_sub.end_date - current_sub.start_date).days if current_sub.start_date else 365 # 根据总天数推断计费周期 if remaining_days_total <= 35: inferred_cycle = 'monthly' elif remaining_days_total <= 100: inferred_cycle = 'quarterly' elif remaining_days_total <= 200: inferred_cycle = 'semiannual' else: inferred_cycle = 'yearly' else: inferred_cycle = current_cycle for opt in pricing_opts: if opt.get('cycle_key') == inferred_cycle: current_price = float(opt.get('price', 0)) current_cycle = inferred_cycle # 更新周期信息 break except: pass # 如果 pricing_options 中没找到,使用 yearly_price 作为默认 if current_price is None or current_price <= 0: current_price = float(current_plan_obj.yearly_price) if current_plan_obj.yearly_price else 0 current_cycle = 'yearly' if current_price and current_price > 0: # 计算剩余天数 remaining_days = (current_sub.end_date - datetime.utcnow()).days # 计算总天数 cycle_days_map = { 'monthly': 30, 'quarterly': 90, 'semiannual': 180, 'yearly': 365 } total_days = 
cycle_days_map.get(current_cycle, 365) # 计算剩余价值 if total_days > 0 and remaining_days > 0: remaining_value = current_price * (remaining_days / total_days) # 实付金额 = 新套餐价格 - 剩余价值 final_price = max(0, price - remaining_value) # 如果剩余价值 >= 新套餐价格,标记为免费升级 if remaining_value >= price: final_price = 0 elif current_plan == 'max' and to_plan_name == 'pro': # 降级:Max → Pro,到期后切换,全价购买 is_downgrade = True subscription_type = 'downgrade' else: # 其他情况视为新购 subscription_type = 'new' # 5. 构建结果 result = { 'is_renewal': is_renewal, 'is_upgrade': is_upgrade, 'is_downgrade': is_downgrade, 'subscription_type': subscription_type, 'current_plan': current_plan, 'current_cycle': current_cycle, 'new_plan_price': price, 'original_price': price, # 新套餐原价 'remaining_value': remaining_value, # 当前订阅剩余价值(仅升级时有效) 'original_amount': price, 'discount_amount': 0, 'final_amount': final_price, 'promo_code': None, 'promo_error': None } # 6. 应用优惠码(基于差价后的金额) if promo_code and promo_code.strip(): # 优惠码作用于差价后的金额 promo, error = validate_promo_code(promo_code, to_plan_name, to_cycle, final_price, user_id) if promo: discount = calculate_discount(promo, final_price) result['discount_amount'] = float(discount) result['final_amount'] = final_price - float(discount) result['promo_code'] = promo.code elif error: result['promo_error'] = error return result except Exception as e: return {'error': f'价格计算失败: {str(e)}'} # 保留旧函数以兼容(标记为废弃) def calculate_upgrade_price(user_id, to_plan_name, to_cycle, promo_code=None): """ 【已废弃】旧版升级价格计算函数,保留以兼容旧代码 新代码请使用 calculate_subscription_price_simple """ # 直接调用新函数 return calculate_subscription_price_simple(user_id, to_plan_name, to_cycle, promo_code) def initialize_subscription_plans_safe(): """安全地初始化订阅套餐""" try: if SubscriptionPlan.query.first(): return pro_plan = SubscriptionPlan( name='pro', display_name='Pro版本', description='适合个人投资者的基础功能套餐', monthly_price=0.01, yearly_price=0.08, features=json.dumps([ "基础股票分析工具", "历史数据查询", "基础财务报表", "简单投资计划记录", "标准客服支持" ]), sort_order=1 ) max_plan = 
SubscriptionPlan( name='max', display_name='Max版本', description='适合专业投资者的全功能套餐', monthly_price=0.1, yearly_price=0.8, features=json.dumps([ "全部Pro版本功能", "高级分析工具", "实时数据推送", "专业财务分析报告", "AI投资建议", "无限投资计划存储", "优先客服支持", "独家研报访问" ]), sort_order=2 ) db.session.add(pro_plan) db.session.add(max_plan) db.session.commit() except Exception as e: pass # -------------------------------------------- # 订阅等级工具函数 # -------------------------------------------- def _get_current_subscription_info(): """获取当前登录用户订阅信息的字典形式,未登录或异常时视为免费用户。""" try: user_id = session.get('user_id') if not user_id: return { 'type': 'free', 'status': 'active', 'is_active': True } sub = get_user_subscription_safe(user_id) data = sub.to_dict() # 标准化字段名 return { 'type': data.get('type') or data.get('subscription_type') or 'free', 'status': data.get('status') or data.get('subscription_status') or 'active', 'is_active': data.get('is_active', True) } except Exception: return { 'type': 'free', 'status': 'active', 'is_active': True } def _subscription_level(sub_type): """将订阅类型映射到等级数值,free=0, pro=1, max=2。""" mapping = {'free': 0, 'pro': 1, 'max': 2} return mapping.get((sub_type or 'free').lower(), 0) def _has_required_level(required: str) -> bool: """判断当前用户是否达到所需订阅级别。""" info = _get_current_subscription_info() if not info.get('is_active', True): return False return _subscription_level(info.get('type')) >= _subscription_level(required) # ============================================ # 订阅相关API接口 # ============================================ @app.route('/api/subscription/plans', methods=['GET']) def get_subscription_plans(): """获取订阅套餐列表""" try: plans = SubscriptionPlan.query.filter_by(is_active=True).order_by(SubscriptionPlan.sort_order).all() return jsonify({ 'success': True, 'data': [plan.to_dict() for plan in plans] }) except Exception as e: # 返回默认套餐(包含pricing_options以兼容新前端) default_plans = [ { 'id': 1, 'name': 'pro', 'display_name': 'Pro版本', 'description': '适合个人投资者的基础功能套餐', 'monthly_price': 198, 'yearly_price': 
2000, 'pricing_options': [ {'months': 1, 'price': 198, 'label': '月付', 'cycle_key': 'monthly'}, {'months': 3, 'price': 534, 'label': '3个月', 'cycle_key': '3months', 'discount_percent': 10}, {'months': 6, 'price': 950, 'label': '半年', 'cycle_key': '6months', 'discount_percent': 20}, {'months': 12, 'price': 2000, 'label': '1年', 'cycle_key': 'yearly', 'discount_percent': 16}, {'months': 24, 'price': 3600, 'label': '2年', 'cycle_key': '2years', 'discount_percent': 24}, {'months': 36, 'price': 5040, 'label': '3年', 'cycle_key': '3years', 'discount_percent': 29} ], 'features': ['基础股票分析工具', '历史数据查询', '基础财务报表', '简单投资计划记录', '标准客服支持'], 'is_active': True, 'sort_order': 1 }, { 'id': 2, 'name': 'max', 'display_name': 'Max版本', 'description': '适合专业投资者的全功能套餐', 'monthly_price': 998, 'yearly_price': 10000, 'pricing_options': [ {'months': 1, 'price': 998, 'label': '月付', 'cycle_key': 'monthly'}, {'months': 3, 'price': 2695, 'label': '3个月', 'cycle_key': '3months', 'discount_percent': 10}, {'months': 6, 'price': 4790, 'label': '半年', 'cycle_key': '6months', 'discount_percent': 20}, {'months': 12, 'price': 10000, 'label': '1年', 'cycle_key': 'yearly', 'discount_percent': 17}, {'months': 24, 'price': 18000, 'label': '2年', 'cycle_key': '2years', 'discount_percent': 25}, {'months': 36, 'price': 25200, 'label': '3年', 'cycle_key': '3years', 'discount_percent': 30} ], 'features': ['全部Pro版本功能', '高级分析工具', '实时数据推送', 'API访问', '优先客服支持'], 'is_active': True, 'sort_order': 2 } ] return jsonify({ 'success': True, 'data': default_plans }) @app.route('/api/subscription/current', methods=['GET']) def get_current_subscription(): """获取当前用户的订阅信息""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 subscription = get_user_subscription_safe(session['user_id']) return jsonify({ 'success': True, 'data': subscription.to_dict() }) except Exception as e: return jsonify({ 'success': True, 'data': { 'type': 'free', 'status': 'active', 'is_active': True, 'days_left': 999 } }) 
@app.route('/api/subscription/info', methods=['GET']) def get_subscription_info(): """获取当前用户的订阅信息 - 前端专用接口""" try: info = _get_current_subscription_info() return jsonify({ 'success': True, 'data': info }) except Exception as e: print(f"获取订阅信息错误: {e}") return jsonify({ 'success': True, 'data': { 'type': 'free', 'status': 'active', 'is_active': True, 'days_left': 999 } }) @app.route('/api/promo-code/validate', methods=['POST']) def validate_promo_code_api(): """验证优惠码""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 data = request.get_json() code = data.get('code', '').strip() plan_name = data.get('plan_name') billing_cycle = data.get('billing_cycle') amount = data.get('amount', 0) if not code or not plan_name or not billing_cycle: return jsonify({'success': False, 'error': '参数不完整'}), 400 # 验证优惠码 promo, error = validate_promo_code(code, plan_name, billing_cycle, amount, session['user_id']) if error: return jsonify({ 'success': False, 'valid': False, 'error': error }) # 计算折扣 discount_amount = calculate_discount(promo, amount) final_amount = amount - discount_amount return jsonify({ 'success': True, 'valid': True, 'promo_code': promo.to_dict(), 'discount_amount': discount_amount, 'final_amount': final_amount }) except Exception as e: return jsonify({ 'success': False, 'error': f'验证失败: {str(e)}' }), 500 @app.route('/api/subscription/calculate-price', methods=['POST']) def calculate_subscription_price(): """ 计算订阅价格(新版:续费和新购价格一致) Request Body: { "to_plan": "pro", "to_cycle": "yearly", "promo_code": "WELCOME2025" // 可选 } Response: { "success": true, "data": { "is_renewal": true, // 是否为续费 "subscription_type": "renew", // new 或 renew "current_plan": "pro", // 当前套餐(如果有) "current_cycle": "monthly", // 当前周期(如果有) "new_plan_price": 2699.00, "original_amount": 2699.00, "discount_amount": 0, "final_amount": 2699.00, "promo_code": null, "promo_error": null } } """ try: if 'user_id' not in session: return jsonify({'success': False, 'error': 
'未登录'}), 401 data = request.get_json() to_plan = data.get('to_plan') to_cycle = data.get('to_cycle') promo_code = (data.get('promo_code') or '').strip() or None if not to_plan or not to_cycle: return jsonify({'success': False, 'error': '参数不完整'}), 400 # 使用新的简化价格计算函数 result = calculate_subscription_price_simple(session['user_id'], to_plan, to_cycle, promo_code) if 'error' in result: return jsonify({ 'success': False, 'error': result['error'] }), 400 return jsonify({ 'success': True, 'data': result }) except Exception as e: return jsonify({ 'success': False, 'error': f'计算失败: {str(e)}' }), 500 @app.route('/api/subscription/free-upgrade', methods=['POST']) @login_required def free_upgrade_subscription(): """ 免费升级订阅(当剩余价值 >= 新套餐价格时) Request Body: { "plan_name": "max", "billing_cycle": "yearly" } """ try: data = request.get_json() plan_name = data.get('plan_name') billing_cycle = data.get('billing_cycle') if not plan_name or not billing_cycle: return jsonify({'success': False, 'error': '参数不完整'}), 400 user_id = current_user.id # 计算价格,验证是否可以免费升级 price_result = calculate_subscription_price_simple(user_id, plan_name, billing_cycle, None) if 'error' in price_result: return jsonify({'success': False, 'error': price_result['error']}), 400 # 检查是否为升级且实付金额为0 if not price_result.get('is_upgrade') or price_result.get('final_amount', 1) > 0: return jsonify({'success': False, 'error': '当前情况不符合免费升级条件'}), 400 # 获取当前订阅 subscription = UserSubscription.query.filter_by(user_id=user_id).first() if not subscription: return jsonify({'success': False, 'error': '未找到订阅记录'}), 404 # 计算新的到期时间(按剩余价值折算) remaining_value = price_result.get('remaining_value', 0) new_plan_price = price_result.get('new_plan_price', 0) if new_plan_price > 0: # 计算可以兑换的新套餐天数 value_ratio = remaining_value / new_plan_price cycle_days_map = { 'monthly': 30, 'quarterly': 90, 'semiannual': 180, 'yearly': 365 } new_cycle_days = cycle_days_map.get(billing_cycle, 365) # 新的到期天数 = 周期天数 × 价值比例 new_days = int(new_cycle_days * value_ratio) 
# 更新订阅信息 subscription.subscription_type = plan_name subscription.billing_cycle = billing_cycle subscription.start_date = datetime.utcnow() subscription.end_date = datetime.utcnow() + timedelta(days=new_days) subscription.subscription_status = 'active' subscription.updated_at = datetime.utcnow() db.session.commit() return jsonify({ 'success': True, 'message': f'升级成功!您的{plan_name.upper()}版本将持续{new_days}天', 'data': { 'subscription_type': plan_name, 'end_date': subscription.end_date.isoformat(), 'days': new_days } }) else: return jsonify({'success': False, 'error': '价格计算异常'}), 500 except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': f'升级失败: {str(e)}'}), 500 @app.route('/api/payment/create-order', methods=['POST']) def create_payment_order(): """ 创建支付订单(新版:简化逻辑,不再记录升级) Request Body: { "plan_name": "pro", "billing_cycle": "yearly", "promo_code": "WELCOME2025" // 可选 } """ try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 data = request.get_json() plan_name = data.get('plan_name') billing_cycle = data.get('billing_cycle') promo_code = (data.get('promo_code') or '').strip() or None if not plan_name or not billing_cycle: return jsonify({'success': False, 'error': '参数不完整'}), 400 # 使用新的简化价格计算 price_result = calculate_subscription_price_simple(session['user_id'], plan_name, billing_cycle, promo_code) if 'error' in price_result: return jsonify({'success': False, 'error': price_result['error']}), 400 amount = price_result['final_amount'] subscription_type = price_result.get('subscription_type', 'new') # new 或 renew # 检查是否为免费升级(金额为0) if amount <= 0 and price_result.get('is_upgrade'): return jsonify({ 'success': False, 'error': '当前剩余价值可直接免费升级,请使用免费升级功能', 'should_free_upgrade': True, 'price_info': price_result }), 400 # 创建订单 try: # 获取原价和折扣金额 original_amount = price_result.get('original_amount', amount) discount_amount = price_result.get('discount_amount', 0) order = PaymentOrder( user_id=session['user_id'], 
plan_name=plan_name, billing_cycle=billing_cycle, amount=amount, original_amount=original_amount, discount_amount=discount_amount ) # 添加订阅类型标记(用于前端展示) order.remark = f"{subscription_type}订阅" if subscription_type == 'renew' else "新购订阅" # 如果使用了优惠码,关联优惠码 if promo_code and price_result.get('promo_code'): promo_obj = PromoCode.query.filter_by(code=promo_code.upper()).first() if promo_obj: order.promo_code_id = promo_obj.id print(f"📦 订单关联优惠码: {promo_obj.code} (ID: {promo_obj.id})") db.session.add(order) db.session.commit() except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': f'订单创建失败: {str(e)}'}), 500 # 尝试调用真实的微信支付API try: from wechat_pay import create_wechat_pay_instance, check_wechat_pay_ready # 检查微信支付是否就绪 is_ready, ready_msg = check_wechat_pay_ready() if not is_ready: # 使用模拟二维码 order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" order.remark = f"演示模式 - {ready_msg}" else: wechat_pay = create_wechat_pay_instance() # 创建微信支付订单 plan_display_name = f"{plan_name.upper()}版本-{billing_cycle}" wechat_result = wechat_pay.create_native_order( order_no=order.order_no, total_fee=float(amount), body=f"VFr-{plan_display_name}", product_id=f"{plan_name}_{billing_cycle}" ) if wechat_result['success']: # 获取微信返回的原始code_url wechat_code_url = wechat_result['code_url'] # 将微信协议URL转换为二维码图片URL import urllib.parse encoded_url = urllib.parse.quote(wechat_code_url, safe='') qr_image_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data={encoded_url}" order.qr_code_url = qr_image_url order.prepay_id = wechat_result.get('prepay_id') order.remark = f"微信支付 - {wechat_code_url}" else: order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" order.remark = f"微信支付失败: {wechat_result.get('error')}" except ImportError as e: order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" 
order.remark = "微信支付模块未配置" except Exception as e: order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" order.remark = f"支付异常: {str(e)}" db.session.commit() return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '订单创建成功' }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': '创建订单失败'}), 500 @app.route('/api/payment/order//status', methods=['GET']) def check_order_status(order_id): """查询订单支付状态""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 # 查找订单 order = PaymentOrder.query.filter_by( id=order_id, user_id=session['user_id'] ).first() if not order: return jsonify({'success': False, 'error': '订单不存在'}), 404 # 如果订单已经是已支付状态,直接返回 if order.status == 'paid': return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '订单已支付', 'payment_success': True }) # 如果订单过期,标记为过期 if order.is_expired(): order.status = 'expired' db.session.commit() return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '订单已过期' }) # 调用微信支付API查询真实状态 try: from wechat_pay import create_wechat_pay_instance wechat_pay = create_wechat_pay_instance() query_result = wechat_pay.query_order(order_no=order.order_no) if query_result['success']: trade_state = query_result.get('trade_state') transaction_id = query_result.get('transaction_id') if trade_state == 'SUCCESS': # 支付成功,更新订单状态 order.mark_as_paid(transaction_id) # 激活用户订阅 activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) # 记录优惠码使用情况 if order.promo_code_id: try: existing_usage = PromoCodeUsage.query.filter_by(order_id=order.id).first() if not existing_usage: usage = PromoCodeUsage( promo_code_id=order.promo_code_id, user_id=order.user_id, order_id=order.id, original_amount=order.original_amount or order.amount, discount_amount=order.discount_amount or 0, final_amount=order.amount ) db.session.add(usage) promo = PromoCode.query.get(order.promo_code_id) if 
promo: promo.current_uses = (promo.current_uses or 0) + 1 print(f"🎫 优惠码使用记录已创建: {promo.code}") except Exception as e: print(f"⚠️ 记录优惠码使用失败: {e}") db.session.commit() return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '支付成功!订阅已激活', 'payment_success': True }) elif trade_state in ['NOTPAY', 'USERPAYING']: # 未支付或支付中 return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '等待支付...', 'payment_success': False }) else: # 支付失败或取消 order.status = 'cancelled' db.session.commit() return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '支付已取消', 'payment_success': False }) else: # 微信查询失败,返回当前状态 return jsonify({ 'success': True, 'data': order.to_dict(), 'message': f"查询失败: {query_result.get('error')}", 'payment_success': False }) except Exception as e: # 查询失败,返回当前订单状态 return jsonify({ 'success': True, 'data': order.to_dict(), 'message': '无法查询支付状态,请稍后重试', 'payment_success': False }) except Exception as e: return jsonify({'success': False, 'error': '查询失败'}), 500 @app.route('/api/payment/order//force-update', methods=['POST']) def force_update_order_status(order_id): """强制更新订单支付状态(调试用)""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 # 查找订单 order = PaymentOrder.query.filter_by( id=order_id, user_id=session['user_id'] ).first() if not order: return jsonify({'success': False, 'error': '订单不存在'}), 404 # 检查微信支付状态 try: from wechat_pay import create_wechat_pay_instance wechat_pay = create_wechat_pay_instance() query_result = wechat_pay.query_order(order_no=order.order_no) if query_result['success'] and query_result.get('trade_state') == 'SUCCESS': # 强制更新为已支付 old_status = order.status order.mark_as_paid(query_result.get('transaction_id')) # 激活用户订阅 activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) # 记录优惠码使用(如果使用了优惠码) if order.promo_code_id: try: # 检查是否已经记录过(防止重复) existing_usage = PromoCodeUsage.query.filter_by(order_id=order.id).first() if not existing_usage: promo_usage = 
PromoCodeUsage( promo_code_id=order.promo_code_id, user_id=order.user_id, order_id=order.id, original_amount=order.original_amount or order.amount, discount_amount=order.discount_amount or 0, final_amount=order.amount ) db.session.add(promo_usage) # 更新优惠码使用次数 promo = PromoCode.query.get(order.promo_code_id) if promo: promo.current_uses = (promo.current_uses or 0) + 1 print(f"🎫 优惠码使用记录已创建: {promo.code}") else: print(f"ℹ️ 优惠码使用记录已存在,跳过") except Exception as e: print(f"⚠️ 记录优惠码使用失败: {e}") db.session.commit() print(f"✅ 订单状态强制更新成功: {old_status} -> paid") return jsonify({ 'success': True, 'message': f'订单状态已从 {old_status} 更新为 paid', 'data': order.to_dict(), 'payment_success': True }) else: return jsonify({ 'success': False, 'error': '微信支付状态不是成功状态,无法强制更新' }) except Exception as e: print(f"❌ 强制更新失败: {e}") return jsonify({ 'success': False, 'error': f'强制更新失败: {str(e)}' }) except Exception as e: print(f"强制更新订单状态失败: {str(e)}") return jsonify({'success': False, 'error': '操作失败'}), 500 @app.route('/api/payment/wechat/callback', methods=['POST']) def wechat_payment_callback(): """微信支付回调处理""" try: # 获取原始XML数据 raw_data = request.get_data() print(f"📥 收到微信支付回调: {raw_data}") # 验证回调数据 try: from wechat_pay import create_wechat_pay_instance wechat_pay = create_wechat_pay_instance() verify_result = wechat_pay.verify_callback(raw_data.decode('utf-8')) if not verify_result['success']: print(f"❌ 微信支付回调验证失败: {verify_result['error']}") return '' callback_data = verify_result['data'] except Exception as e: print(f"❌ 微信支付回调处理异常: {e}") # 简单解析XML(fallback) callback_data = _parse_xml_callback(raw_data.decode('utf-8')) if not callback_data: return '' # 获取关键字段 return_code = callback_data.get('return_code') result_code = callback_data.get('result_code') order_no = callback_data.get('out_trade_no') transaction_id = callback_data.get('transaction_id') print(f"📦 回调数据解析:") print(f" 返回码: {return_code}") print(f" 结果码: {result_code}") print(f" 订单号: {order_no}") print(f" 交易号: {transaction_id}") if not 
order_no: return '' # 查找订单 order = PaymentOrder.query.filter_by(order_no=order_no).first() if not order: print(f"❌ 订单不存在: {order_no}") return '' # 处理支付成功 if return_code == 'SUCCESS' and result_code == 'SUCCESS': print(f"🎉 支付回调成功: 订单 {order_no}") # 检查订单是否已经处理过 if order.status == 'paid': print(f"ℹ️ 订单已处理过: {order_no}") db.session.commit() return '' # 更新订单状态(无论之前是什么状态) old_status = order.status order.mark_as_paid(transaction_id) print(f"📝 订单状态已更新: {old_status} -> paid") # 激活用户订阅 subscription = activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) if subscription: print(f"✅ 用户订阅已激活: 用户{order.user_id}, 套餐{order.plan_name}") else: print(f"⚠️ 订阅激活失败,但订单已标记为已支付") # 记录优惠码使用情况 if order.promo_code_id: try: # 检查是否已经记录过(防止重复) existing_usage = PromoCodeUsage.query.filter_by( order_id=order.id ).first() if not existing_usage: # 创建优惠码使用记录 usage = PromoCodeUsage( promo_code_id=order.promo_code_id, user_id=order.user_id, order_id=order.id, original_amount=order.original_amount or order.amount, discount_amount=order.discount_amount or 0, final_amount=order.amount ) db.session.add(usage) # 更新优惠码使用次数 promo = PromoCode.query.get(order.promo_code_id) if promo: promo.current_uses = (promo.current_uses or 0) + 1 print(f"🎫 优惠码使用记录已创建: {promo.code}, 当前使用次数: {promo.current_uses}") else: print(f"ℹ️ 优惠码使用记录已存在,跳过") except Exception as e: print(f"⚠️ 记录优惠码使用失败: {e}") # 不影响主流程,继续执行 db.session.commit() # 返回成功响应给微信 return '' except Exception as e: db.session.rollback() print(f"❌ 微信支付回调处理失败: {e}") import traceback app.logger.error(f"回调处理错误: {e}", exc_info=True) return '' def _parse_xml_callback(xml_data): """简单的XML回调数据解析""" try: import xml.etree.ElementTree as ET root = ET.fromstring(xml_data) result = {} for child in root: result[child.tag] = child.text return result except Exception as e: print(f"XML解析失败: {e}") return None @app.route('/api/auth/session', methods=['GET']) def get_session_info(): """获取当前登录用户信息""" if 'user_id' in session: user = 
User.query.get(session['user_id']) if user: # 获取用户订阅信息 subscription_info = get_user_subscription_safe(user.id).to_dict() return jsonify({ 'success': True, 'isAuthenticated': True, 'user': { 'id': user.id, 'username': user.username, 'nickname': user.nickname or user.username, 'email': user.email, 'phone': user.phone, 'phone_confirmed': bool(user.phone_confirmed), 'email_confirmed': bool(user.email_confirmed) if hasattr(user, 'email_confirmed') else None, 'avatar_url': user.avatar_url, 'has_wechat': bool(user.wechat_open_id), 'created_at': user.created_at.isoformat() if user.created_at else None, 'last_seen': user.last_seen.isoformat() if user.last_seen else None, # 将订阅字段映射到前端期望的字段名 'subscription_type': subscription_info['type'], 'subscription_status': subscription_info['status'], 'subscription_end_date': subscription_info['end_date'], 'is_subscription_active': subscription_info['is_active'], 'subscription_days_left': subscription_info['days_left'] } }) return jsonify({ 'success': True, 'isAuthenticated': False, 'user': None }) def generate_verification_code(): """生成6位数字验证码""" return ''.join(random.choices(string.digits, k=6)) @app.route('/api/auth/login', methods=['POST']) def login(): """传统登录 - 使用Session""" try: username = request.form.get('username') email = request.form.get('email') phone = request.form.get('phone') password = request.form.get('password') # 验证必要参数 if not password: return jsonify({'success': False, 'error': '密码不能为空'}), 400 # 根据提供的信息查找用户 user = None if username: # 检查username是否为手机号格式 if re.match(r'^1[3-9]\d{9}$', username): # 如果username是手机号格式,先按手机号查找 user = User.query.filter_by(phone=username).first() if not user: # 如果没找到,再按用户名查找 user = User.find_by_login_info(username) else: # 不是手机号格式,按用户名查找 user = User.find_by_login_info(username) elif email: user = User.query.filter_by(email=email).first() elif phone: user = User.query.filter_by(phone=phone).first() else: return jsonify({'success': False, 'error': '请提供用户名、邮箱或手机号'}), 400 if not user: return 
jsonify({'success': False, 'error': '用户不存在'}), 404 # 尝试密码验证 password_valid = user.check_password(password) if not password_valid: # 还可以尝试直接验证 if user.password_hash: from werkzeug.security import check_password_hash direct_check = check_password_hash(user.password_hash, password) return jsonify({'success': False, 'error': '密码错误'}), 401 # 设置session session.permanent = True # 使用永久session session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True # Flask-Login 登录 login_user(user, remember=True) # 更新最后登录时间 user.update_last_seen() return jsonify({ 'success': True, 'message': '登录成功', 'user': { 'id': user.id, 'username': user.username, 'nickname': user.nickname or user.username, 'email': user.email, 'phone': user.phone, 'avatar_url': user.avatar_url, 'has_wechat': bool(user.wechat_open_id) } }) except Exception as e: import traceback app.logger.error(f"回调处理错误: {e}", exc_info=True) return jsonify({'success': False, 'error': '登录处理失败,请重试'}), 500 # 添加OPTIONS请求处理 @app.before_request def handle_preflight(): if request.method == "OPTIONS": response = make_response() response.headers.add("Access-Control-Allow-Origin", "*") response.headers.add('Access-Control-Allow-Headers', "*") response.headers.add('Access-Control-Allow-Methods', "*") return response # 修改密码API @app.route('/api/account/change-password', methods=['POST']) @login_required def change_password(): """修改当前用户密码""" try: data = request.get_json() or request.form current_password = data.get('currentPassword') or data.get('current_password') new_password = data.get('newPassword') or data.get('new_password') is_first_set = data.get('isFirstSet', False) # 是否为首次设置密码 if not new_password: return jsonify({'success': False, 'error': '新密码不能为空'}), 400 if len(new_password) < 6: return jsonify({'success': False, 'error': '新密码至少需要6个字符'}), 400 # 获取当前用户 user = current_user if not user: return jsonify({'success': False, 'error': '用户未登录'}), 401 # 检查是否为微信用户且首次设置密码 is_wechat_user = bool(user.wechat_open_id) 
# 如果是微信用户首次设置密码,或者明确标记为首次设置,则跳过当前密码验证 if is_first_set or (is_wechat_user and not current_password): pass # 跳过当前密码验证 else: # 普通用户或非首次设置,需要验证当前密码 if not current_password: return jsonify({'success': False, 'error': '请输入当前密码'}), 400 if not user.check_password(current_password): return jsonify({'success': False, 'error': '当前密码错误'}), 400 # 设置新密码 user.set_password(new_password) db.session.commit() return jsonify({ 'success': True, 'message': '密码设置成功' if (is_first_set or is_wechat_user) else '密码修改成功' }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 # 检查用户密码状态API @app.route('/api/account/password-status', methods=['GET']) @login_required def get_password_status(): """获取当前用户的密码状态信息""" try: user = current_user if not user: return jsonify({'success': False, 'error': '用户未登录'}), 401 is_wechat_user = bool(user.wechat_open_id) return jsonify({ 'success': True, 'data': { 'isWechatUser': is_wechat_user, 'hasPassword': bool(user.password_hash), 'needsFirstTimeSetup': is_wechat_user # 微信用户需要首次设置 } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 # 检查用户信息完整性API @app.route('/api/account/profile-completeness', methods=['GET']) @login_required def get_profile_completeness(): try: user = current_user if not user: return jsonify({'success': False, 'error': '用户未登录'}), 401 is_wechat_user = bool(user.wechat_open_id) # 检查各项信息 completeness = { 'hasPassword': bool(user.password_hash), 'hasPhone': bool(user.phone), 'hasEmail': bool(user.email and '@' in user.email and not user.email.endswith('@valuefrontier.temp')), 'isWechatUser': is_wechat_user } # 计算完整度 total_items = 3 completed_items = sum([completeness['hasPassword'], completeness['hasPhone'], completeness['hasEmail']]) completeness_percentage = int((completed_items / total_items) * 100) # 智能判断是否需要提醒 needs_attention = False missing_items = [] # 只在用户首次登录或最近登录时提醒 if is_wechat_user: # 检查用户是否是新用户(注册7天内) is_new_user = (datetime.now() - user.created_at).days < 7 # 
检查是否最近没有提醒过(使用session记录) last_reminder = session.get('last_completeness_reminder') should_remind = False if not last_reminder: should_remind = True else: # 每7天最多提醒一次 days_since_reminder = (datetime.now() - datetime.fromisoformat(last_reminder)).days should_remind = days_since_reminder >= 7 # 只对新用户或长时间未完善的用户提醒 if (is_new_user or completeness_percentage < 50) and should_remind: needs_attention = True if not completeness['hasPassword']: missing_items.append('登录密码') if not completeness['hasPhone']: missing_items.append('手机号') if not completeness['hasEmail']: missing_items.append('邮箱') # 记录本次提醒时间 session['last_completeness_reminder'] = datetime.now().isoformat() return jsonify({ 'success': True, 'data': { 'completeness': completeness, 'completenessPercentage': completeness_percentage, 'needsAttention': needs_attention, 'missingItems': missing_items, 'isComplete': completed_items == total_items, 'showReminder': needs_attention # 前端使用这个字段决定是否显示提醒 } }) except Exception as e: print(f"获取资料完整性错误: {e}") return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/auth/logout', methods=['POST']) def logout(): """登出 - 清除Session""" logout_user() # Flask-Login 登出 session.clear() return jsonify({'success': True, 'message': '已登出'}) @app.route('/api/auth/send-verification-code', methods=['POST']) def send_verification_code(): """发送验证码(支持手机号和邮箱)""" try: data = request.get_json() credential = data.get('credential') # 手机号或邮箱 code_type = data.get('type') # 'phone' 或 'email' purpose = data.get('purpose', 'login') # 'login' 或 'register' if not credential or not code_type: return jsonify({'success': False, 'error': '缺少必要参数'}), 400 # 清理格式字符(空格、横线、括号等) if code_type == 'phone': # 移除手机号中的空格、横线、括号、加号等格式字符 credential = re.sub(r'[\s\-\(\)\+]', '', credential) print(f"📱 清理后的手机号: {credential}") elif code_type == 'email': # 邮箱只移除空格 credential = credential.strip() # 生成验证码 verification_code = generate_verification_code() # 存储验证码到session(实际生产环境建议使用Redis) session_key = 
f'verification_code_{code_type}_{credential}_{purpose}' session[session_key] = { 'code': verification_code, 'timestamp': time.time(), 'attempts': 0 } if code_type == 'phone': # 手机号验证码发送 if not re.match(r'^1[3-9]\d{9}$', credential): return jsonify({'success': False, 'error': '手机号格式不正确'}), 400 # 发送真实短信验证码 if send_sms_code(credential, verification_code, SMS_TEMPLATE_LOGIN): print(f"[短信已发送] 验证码到 {credential}: {verification_code}") else: return jsonify({'success': False, 'error': '短信发送失败,请稍后重试'}), 500 elif code_type == 'email': # 邮箱验证码发送 if not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', credential): return jsonify({'success': False, 'error': '邮箱格式不正确'}), 400 # 发送真实邮件验证码 if send_email_code(credential, verification_code): print(f"[邮件已发送] 验证码到 {credential}: {verification_code}") else: return jsonify({'success': False, 'error': '邮件发送失败,请稍后重试'}), 500 else: return jsonify({'success': False, 'error': '不支持的验证码类型'}), 400 return jsonify({ 'success': True, 'message': f'验证码已发送到您的{code_type}' }) except Exception as e: print(f"发送验证码错误: {e}") return jsonify({'success': False, 'error': '发送验证码失败'}), 500 @app.route('/api/auth/login-with-code', methods=['POST']) def login_with_verification_code(): """使用验证码登录/注册(自动注册)""" try: data = request.get_json() credential = data.get('credential') # 手机号或邮箱 verification_code = data.get('verification_code') login_type = data.get('login_type') # 'phone' 或 'email' if not credential or not verification_code or not login_type: return jsonify({'success': False, 'error': '缺少必要参数'}), 400 # 清理格式字符(空格、横线、括号等) if login_type == 'phone': # 移除手机号中的空格、横线、括号、加号等格式字符 original_credential = credential credential = re.sub(r'[\s\-\(\)\+]', '', credential) if original_credential != credential: print(f"📱 登录时清理手机号: {original_credential} -> {credential}") elif login_type == 'email': # 邮箱只移除前后空格 credential = credential.strip() # 检查验证码 session_key = f'verification_code_{login_type}_{credential}_login' stored_code_info = session.get(session_key) if not 
stored_code_info: return jsonify({'success': False, 'error': '验证码已过期或不存在'}), 400 # 检查验证码是否过期(5分钟) if time.time() - stored_code_info['timestamp'] > 300: session.pop(session_key, None) return jsonify({'success': False, 'error': '验证码已过期'}), 400 # 检查尝试次数 if stored_code_info['attempts'] >= 3: session.pop(session_key, None) return jsonify({'success': False, 'error': '验证码错误次数过多'}), 400 # 验证码错误 if stored_code_info['code'] != verification_code: stored_code_info['attempts'] += 1 session[session_key] = stored_code_info return jsonify({'success': False, 'error': '验证码错误'}), 400 # 验证码正确,查找用户 user = None is_new_user = False if login_type == 'phone': user = User.query.filter_by(phone=credential).first() if not user: # 自动注册新用户 is_new_user = True # 生成唯一用户名 base_username = f"user_{credential}" username = base_username counter = 1 while User.query.filter_by(username=username).first(): username = f"{base_username}_{counter}" counter += 1 # 创建新用户 user = User(username=username, phone=credential) user.phone_confirmed = True user.email = f"{username}@valuefrontier.temp" # 临时邮箱 db.session.add(user) db.session.commit() elif login_type == 'email': user = User.query.filter_by(email=credential).first() if not user: # 自动注册新用户 is_new_user = True # 从邮箱生成用户名 email_prefix = credential.split('@')[0] base_username = f"user_{email_prefix}" username = base_username counter = 1 while User.query.filter_by(username=username).first(): username = f"{base_username}_{counter}" counter += 1 # 如果用户不存在,自动创建新用户 if not user: try: # 生成用户名 if login_type == 'phone': # 使用手机号生成用户名 base_username = f"用户{credential[-4:]}" elif login_type == 'email': # 使用邮箱前缀生成用户名 base_username = credential.split('@')[0] else: base_username = "新用户" # 确保用户名唯一 username = base_username counter = 1 while User.is_username_taken(username): username = f"{base_username}_{counter}" counter += 1 # 创建新用户 user = User(username=username) # 设置手机号或邮箱 if login_type == 'phone': user.phone = credential elif login_type == 'email': user.email = credential # 
设置默认密码(使用随机密码,用户后续可以修改) user.set_password(uuid.uuid4().hex) user.status = 'active' user.nickname = username db.session.add(user) db.session.commit() is_new_user = True print(f"✅ 自动创建新用户: {username}, {login_type}: {credential}") except Exception as e: print(f"❌ 创建用户失败: {e}") db.session.rollback() return jsonify({'success': False, 'error': '创建用户失败'}), 500 # 清除验证码 session.pop(session_key, None) # 设置session session.permanent = True session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True # Flask-Login 登录 login_user(user, remember=True) # 更新最后登录时间 user.update_last_seen() # 根据是否为新用户返回不同的消息 message = '注册成功,欢迎加入!' if is_new_user else '登录成功' return jsonify({ 'success': True, 'message': message, 'is_new_user': is_new_user, 'user': { 'id': user.id, 'username': user.username, 'nickname': user.nickname or user.username, 'email': user.email, 'phone': user.phone, 'avatar_url': user.avatar_url, 'has_wechat': bool(user.wechat_open_id) } }) except Exception as e: print(f"验证码登录错误: {e}") db.session.rollback() return jsonify({'success': False, 'error': '登录失败'}), 500 @app.route('/api/auth/register', methods=['POST']) def register(): """用户注册 - 使用Session""" username = request.form.get('username') email = request.form.get('email') password = request.form.get('password') # 验证输入 if not all([username, email, password]): return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 # 检查用户名和邮箱是否已存在 if User.is_username_taken(username): return jsonify({'success': False, 'error': '用户名已存在'}), 400 if User.is_email_taken(email): return jsonify({'success': False, 'error': '邮箱已被使用'}), 400 try: # 创建新用户 user = User(username=username, email=email) user.set_password(password) user.email_confirmed = True # 暂时默认已确认 db.session.add(user) db.session.flush() # 获取 user.id # 自动创建积分账户,初始10000积分 credit_account = UserCreditAccount( user_id=user.id, balance=10000, frozen=0 ) db.session.add(credit_account) db.session.commit() # 自动登录 session.permanent = True session['user_id'] = 
user.id session['username'] = user.username session['logged_in'] = True # Flask-Login 登录 login_user(user, remember=True) return jsonify({ 'success': True, 'message': '注册成功', 'user': { 'id': user.id, 'username': user.username, 'nickname': user.nickname or user.username, 'email': user.email } }), 201 except Exception as e: db.session.rollback() print(f"验证码登录/注册错误: {e}") return jsonify({'success': False, 'error': '登录失败'}), 500 def send_sms_code(phone, code, template_id): """发送短信验证码""" try: cred = credential.Credential(SMS_SECRET_ID, SMS_SECRET_KEY) httpProfile = HttpProfile() httpProfile.endpoint = "sms.tencentcloudapi.com" clientProfile = ClientProfile() clientProfile.httpProfile = httpProfile client = sms_client.SmsClient(cred, "ap-beijing", clientProfile) req = models.SendSmsRequest() params = { "PhoneNumberSet": [phone], "SmsSdkAppId": SMS_SDK_APP_ID, "TemplateId": template_id, "SignName": SMS_SIGN_NAME, "TemplateParamSet": [code, "5"] if template_id == SMS_TEMPLATE_LOGIN else [code] } req.from_json_string(json.dumps(params)) resp = client.SendSms(req) return True except TencentCloudSDKException as err: print(f"SMS Error: {err}") return False def send_email_code(email, code): """发送邮件验证码""" try: print(f"[邮件发送] 准备发送验证码到: {email}") print(f"[邮件配置] 服务器: {MAIL_SERVER}, 端口: {MAIL_PORT}, SSL: {MAIL_USE_SSL}") msg = Message( subject='价值前沿 - 验证码', recipients=[email], body=f'您的验证码是:{code},有效期5分钟。如非本人操作,请忽略此邮件。' ) mail.send(msg) print(f"[邮件发送] 验证码邮件发送成功到: {email}") return True except Exception as e: print(f"[邮件发送错误] 发送到 {email} 失败: {str(e)}") print(f"[邮件发送错误] 错误类型: {type(e).__name__}") return False @app.route('/api/auth/send-sms-code', methods=['POST']) def send_sms_verification(): """发送手机验证码""" data = request.get_json() phone = data.get('phone') if not phone: return jsonify({'error': '手机号不能为空'}), 400 # 注册时验证是否已注册;若用于绑定手机,需要另外接口 # 这里保留原逻辑,新增绑定接口处理不同规则 if User.query.filter_by(phone=phone).first(): return jsonify({'error': '该手机号已注册'}), 400 # 生成验证码 code = 
generate_verification_code() # 发送短信 if send_sms_code(phone, code, SMS_TEMPLATE_REGISTER): # 存储验证码(5分钟有效) verification_codes[f'phone_{phone}'] = { 'code': code, 'expires': time.time() + 300 } return jsonify({'message': '验证码已发送'}), 200 else: return jsonify({'error': '验证码发送失败'}), 500 @app.route('/api/auth/send-email-code', methods=['POST']) def send_email_verification(): """发送邮箱验证码""" data = request.get_json() email = data.get('email') if not email: return jsonify({'error': '邮箱不能为空'}), 400 if User.query.filter_by(email=email).first(): return jsonify({'error': '该邮箱已注册'}), 400 # 生成验证码 code = generate_verification_code() # 发送邮件 if send_email_code(email, code): # 存储验证码(5分钟有效) verification_codes[f'email_{email}'] = { 'code': code, 'expires': time.time() + 300 } return jsonify({'message': '验证码已发送'}), 200 else: return jsonify({'error': '验证码发送失败'}), 500 @app.route('/api/auth/register/phone', methods=['POST']) def register_with_phone(): """手机号注册 - 使用Session""" data = request.get_json() phone = data.get('phone') code = data.get('code') password = data.get('password') username = data.get('username') if not all([phone, code, password, username]): return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 # 验证验证码 stored_code = verification_codes.get(f'phone_{phone}') if not stored_code or stored_code['expires'] < time.time(): return jsonify({'success': False, 'error': '验证码已过期'}), 400 if stored_code['code'] != code: return jsonify({'success': False, 'error': '验证码错误'}), 400 if User.query.filter_by(username=username).first(): return jsonify({'success': False, 'error': '用户名已存在'}), 400 try: # 创建用户 user = User(username=username, phone=phone) user.email = f"{username}@valuefrontier.temp" user.set_password(password) user.phone_confirmed = True db.session.add(user) db.session.flush() # 获取 user.id # 自动创建积分账户,初始10000积分 credit_account = UserCreditAccount( user_id=user.id, balance=10000, frozen=0 ) db.session.add(credit_account) db.session.commit() # 清除验证码 del 
verification_codes[f'phone_{phone}'] # 自动登录 session.permanent = True session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True # Flask-Login 登录 login_user(user, remember=True) return jsonify({ 'success': True, 'message': '注册成功', 'user': { 'id': user.id, 'username': user.username, 'phone': user.phone } }), 201 except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': '注册失败,请重试'}), 500 @app.route('/api/account/phone/send-code', methods=['POST']) def send_sms_bind_code(): """发送绑定手机验证码(需已登录)""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 data = request.get_json() phone = data.get('phone') if not phone: return jsonify({'error': '手机号不能为空'}), 400 # 绑定时要求手机号未被占用 if User.query.filter_by(phone=phone).first(): return jsonify({'error': '该手机号已被其他账号使用'}), 400 code = generate_verification_code() if send_sms_code(phone, code, SMS_TEMPLATE_REGISTER): verification_codes[f'bind_{phone}'] = { 'code': code, 'expires': time.time() + 300 } return jsonify({'message': '验证码已发送'}), 200 else: return jsonify({'error': '验证码发送失败'}), 500 @app.route('/api/account/phone/bind', methods=['POST']) def bind_phone(): """当前登录用户绑定手机号""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 data = request.get_json() phone = data.get('phone') code = data.get('code') if not phone or not code: return jsonify({'error': '手机号和验证码不能为空'}), 400 stored = verification_codes.get(f'bind_{phone}') if not stored or stored['expires'] < time.time(): return jsonify({'error': '验证码已过期'}), 400 if stored['code'] != code: return jsonify({'error': '验证码错误'}), 400 if User.query.filter_by(phone=phone).first(): return jsonify({'error': '该手机号已被其他账号使用'}), 400 try: user = User.query.get(session.get('user_id')) if not user: return jsonify({'error': '用户不存在'}), 404 user.phone = phone user.confirm_phone() # 清除验证码 del verification_codes[f'bind_{phone}'] return jsonify({'message': '绑定成功', 'success': True}), 200 except Exception as e: 
print(f"Bind phone error: {e}") db.session.rollback() return jsonify({'error': '绑定失败,请重试'}), 500 @app.route('/api/account/phone/unbind', methods=['POST']) def unbind_phone(): """解绑手机号(需已登录)""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 try: user = User.query.get(session.get('user_id')) if not user: return jsonify({'error': '用户不存在'}), 404 user.phone = None user.phone_confirmed = False user.phone_confirm_time = None db.session.commit() return jsonify({'message': '解绑成功', 'success': True}), 200 except Exception as e: print(f"Unbind phone error: {e}") db.session.rollback() return jsonify({'error': '解绑失败,请重试'}), 500 @app.route('/api/account/email/send-bind-code', methods=['POST']) def send_email_bind_code(): """发送绑定邮箱验证码(需已登录)""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 data = request.get_json() email = data.get('email') if not email: return jsonify({'error': '邮箱不能为空'}), 400 # 邮箱格式验证 if not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', email): return jsonify({'error': '邮箱格式不正确'}), 400 # 检查邮箱是否已被其他账号使用 if User.query.filter_by(email=email).first(): return jsonify({'error': '该邮箱已被其他账号使用'}), 400 # 生成验证码 code = ''.join(random.choices(string.digits, k=6)) if send_email_code(email, code): # 存储验证码(5分钟有效) verification_codes[f'bind_{email}'] = { 'code': code, 'expires': time.time() + 300 } return jsonify({'message': '验证码已发送'}), 200 else: return jsonify({'error': '验证码发送失败'}), 500 @app.route('/api/account/email/bind', methods=['POST']) def bind_email(): """当前登录用户绑定邮箱""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 data = request.get_json() email = data.get('email') code = data.get('code') if not email or not code: return jsonify({'error': '邮箱和验证码不能为空'}), 400 stored = verification_codes.get(f'bind_{email}') if not stored or stored['expires'] < time.time(): return jsonify({'error': '验证码已过期'}), 400 if stored['code'] != code: return jsonify({'error': '验证码错误'}), 400 if 
User.query.filter_by(email=email).first(): return jsonify({'error': '该邮箱已被其他账号使用'}), 400 try: user = User.query.get(session.get('user_id')) if not user: return jsonify({'error': '用户不存在'}), 404 user.email = email user.confirm_email() db.session.commit() # 清除验证码 del verification_codes[f'bind_{email}'] return jsonify({ 'message': '邮箱绑定成功', 'success': True, 'user': { 'email': user.email, 'email_confirmed': user.email_confirmed } }), 200 except Exception as e: print(f"Bind email error: {e}") db.session.rollback() return jsonify({'error': '绑定失败,请重试'}), 500 @app.route('/api/account/email/unbind', methods=['POST']) def unbind_email(): """解绑邮箱(需已登录)""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 try: user = User.query.get(session.get('user_id')) if not user: return jsonify({'error': '用户不存在'}), 404 user.email = None user.email_confirmed = False db.session.commit() return jsonify({'message': '解绑成功', 'success': True}), 200 except Exception as e: print(f"Unbind email error: {e}") db.session.rollback() return jsonify({'error': '解绑失败,请重试'}), 500 @app.route('/api/auth/register/email', methods=['POST']) def register_with_email(): """邮箱注册 - 使用Session""" data = request.get_json() email = data.get('email') code = data.get('code') password = data.get('password') username = data.get('username') if not all([email, code, password, username]): return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 # 验证验证码 stored_code = verification_codes.get(f'email_{email}') if not stored_code or stored_code['expires'] < time.time(): return jsonify({'success': False, 'error': '验证码已过期'}), 400 if stored_code['code'] != code: return jsonify({'success': False, 'error': '验证码错误'}), 400 if User.query.filter_by(username=username).first(): return jsonify({'success': False, 'error': '用户名已存在'}), 400 try: # 创建用户 user = User(username=username, email=email) user.set_password(password) user.email_confirmed = True db.session.add(user) db.session.flush() # 获取 user.id # 自动创建积分账户,初始10000积分 
credit_account = UserCreditAccount( user_id=user.id, balance=10000, frozen=0 ) db.session.add(credit_account) db.session.commit() # 清除验证码 del verification_codes[f'email_{email}'] # 自动登录 session.permanent = True session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True # Flask-Login 登录 login_user(user, remember=True) return jsonify({ 'success': True, 'message': '注册成功', 'user': { 'id': user.id, 'username': user.username, 'email': user.email } }), 201 except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': '注册失败,请重试'}), 500 def get_wechat_access_token(code): """通过code获取微信access_token""" url = "https://api.weixin.qq.com/sns/oauth2/access_token" params = { 'appid': WECHAT_APPID, 'secret': WECHAT_APPSECRET, 'code': code, 'grant_type': 'authorization_code' } try: response = requests.get(url, params=params, timeout=10) data = response.json() if 'errcode' in data: print(f"WeChat access token error: {data}") return None return data except Exception as e: print(f"WeChat access token request error: {e}") return None def get_wechat_userinfo(access_token, openid): """获取微信用户信息(包含UnionID)""" url = "https://api.weixin.qq.com/sns/userinfo" params = { 'access_token': access_token, 'openid': openid, 'lang': 'zh_CN' } try: response = requests.get(url, params=params, timeout=10) response.encoding = 'utf-8' # 明确设置编码为UTF-8 data = response.json() if 'errcode' in data: print(f"WeChat userinfo error: {data}") return None # 确保nickname字段的编码正确 if 'nickname' in data and data['nickname']: # 确保昵称是正确的UTF-8编码 try: # 检查是否已经是正确的UTF-8字符串 data['nickname'] = data['nickname'].encode('utf-8').decode('utf-8') except (UnicodeEncodeError, UnicodeDecodeError) as e: print(f"Nickname encoding error: {e}, using default") data['nickname'] = '微信用户' return data except Exception as e: print(f"WeChat userinfo request error: {e}") return None @app.route('/api/auth/wechat/qrcode', methods=['GET']) def get_wechat_qrcode(): """返回微信授权URL,前端使用iframe展示""" # 
生成唯一state参数 state = uuid.uuid4().hex # URL编码回调地址 redirect_uri = urllib.parse.quote_plus(WECHAT_REDIRECT_URI) # 构建微信授权URL wechat_auth_url = ( f"https://open.weixin.qq.com/connect/qrconnect?" f"appid={WECHAT_APPID}&redirect_uri={redirect_uri}" f"&response_type=code&scope=snsapi_login&state={state}" "#wechat_redirect" ) # 存储session信息 wechat_qr_sessions[state] = { 'status': 'waiting', 'expires': time.time() + 300, # 5分钟过期 'user_info': None, 'wechat_openid': None, 'wechat_unionid': None } return jsonify({"code":0, "data": { 'auth_url': wechat_auth_url, 'session_id': state, 'expires_in': 300 }}), 200 @app.route('/api/account/wechat/qrcode', methods=['GET']) def get_wechat_bind_qrcode(): """发起微信绑定二维码,会话标记为绑定模式""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 # 生成唯一state参数 state = uuid.uuid4().hex # URL编码回调地址 redirect_uri = urllib.parse.quote_plus(WECHAT_REDIRECT_URI) # 构建微信授权URL wechat_auth_url = ( f"https://open.weixin.qq.com/connect/qrconnect?" f"appid={WECHAT_APPID}&redirect_uri={redirect_uri}" f"&response_type=code&scope=snsapi_login&state={state}" "#wechat_redirect" ) # 存储session信息,标记为绑定模式并记录目标用户 wechat_qr_sessions[state] = { 'status': 'waiting', 'expires': time.time() + 300, 'mode': 'bind', 'bind_user_id': session.get('user_id'), 'user_info': None, 'wechat_openid': None, 'wechat_unionid': None } return jsonify({ 'auth_url': wechat_auth_url, 'session_id': state, 'expires_in': 300 }), 200 @app.route('/api/auth/wechat/check', methods=['POST']) def check_wechat_scan(): """检查微信扫码状态""" data = request.get_json() session_id = data.get('session_id') if not session_id or session_id not in wechat_qr_sessions: return jsonify({'status': 'invalid', 'error': '无效的session'}), 400 session = wechat_qr_sessions[session_id] # 检查是否过期 if time.time() > session['expires']: del wechat_qr_sessions[session_id] return jsonify({'status': 'expired'}), 200 return jsonify({ 'status': session['status'], 'user_info': session.get('user_info'), 'expires_in': 
int(session['expires'] - time.time()) }), 200 @app.route('/api/account/wechat/check', methods=['POST']) def check_wechat_bind_scan(): """检查微信扫码绑定状态""" data = request.get_json() session_id = data.get('session_id') if not session_id or session_id not in wechat_qr_sessions: return jsonify({'status': 'invalid', 'error': '无效的session'}), 400 sess = wechat_qr_sessions[session_id] # 绑定模式限制 if sess.get('mode') != 'bind': return jsonify({'status': 'invalid', 'error': '会话模式错误'}), 400 # 过期处理 if time.time() > sess['expires']: del wechat_qr_sessions[session_id] return jsonify({'status': 'expired'}), 200 return jsonify({ 'status': sess['status'], 'user_info': sess.get('user_info'), 'expires_in': int(sess['expires'] - time.time()) }), 200 @app.route('/api/auth/wechat/callback', methods=['GET']) def wechat_callback(): """微信授权回调处理 - 使用Session""" code = request.args.get('code') state = request.args.get('state') error = request.args.get('error') # 错误处理:用户拒绝授权 if error: if state in wechat_qr_sessions: wechat_qr_sessions[state]['status'] = 'auth_denied' wechat_qr_sessions[state]['error'] = '用户拒绝授权' print(f"❌ 用户拒绝授权: state={state}") return redirect('/auth/signin?error=wechat_auth_denied') # 参数验证 if not code or not state: if state in wechat_qr_sessions: wechat_qr_sessions[state]['status'] = 'auth_failed' wechat_qr_sessions[state]['error'] = '授权参数缺失' return redirect('/auth/signin?error=wechat_auth_failed') # 验证state if state not in wechat_qr_sessions: return redirect('/auth/signin?error=session_expired') session_data = wechat_qr_sessions[state] # 检查过期 if time.time() > session_data['expires']: del wechat_qr_sessions[state] return redirect('/auth/signin?error=session_expired') try: # 步骤1: 用户已扫码并授权(微信回调过来说明用户已完成扫码+授权) session_data['status'] = 'scanned' print(f"✅ 微信扫码回调: state={state}, code={code[:10]}...") # 步骤2: 获取access_token token_data = get_wechat_access_token(code) if not token_data: session_data['status'] = 'auth_failed' session_data['error'] = '获取访问令牌失败' print(f"❌ 获取微信access_token失败: 
state={state}") return redirect('/auth/signin?error=token_failed') # 步骤3: Token获取成功,标记为已授权 session_data['status'] = 'authorized' print(f"✅ 微信授权成功: openid={token_data['openid']}") # 步骤4: 获取用户信息 user_info = get_wechat_userinfo(token_data['access_token'], token_data['openid']) if not user_info: session_data['status'] = 'auth_failed' session_data['error'] = '获取用户信息失败' print(f"❌ 获取微信用户信息失败: openid={token_data['openid']}") return redirect('/auth/signin?error=userinfo_failed') # 查找或创建用户 / 或处理绑定 openid = token_data['openid'] unionid = user_info.get('unionid') or token_data.get('unionid') # 如果是绑定流程 session_item = wechat_qr_sessions.get(state) if session_item and session_item.get('mode') == 'bind': try: target_user_id = session.get('user_id') or session_item.get('bind_user_id') if not target_user_id: return redirect('/auth/signin?error=bind_no_user') target_user = User.query.get(target_user_id) if not target_user: return redirect('/auth/signin?error=bind_user_missing') # 检查该微信是否已被其他账户绑定 existing = None if unionid: existing = User.query.filter_by(wechat_union_id=unionid).first() if not existing: existing = User.query.filter_by(wechat_open_id=openid).first() if existing and existing.id != target_user.id: session_item['status'] = 'bind_conflict' return redirect('/home?bind=conflict') # 执行绑定 target_user.bind_wechat(openid, unionid, wechat_info=user_info) # 标记绑定完成,供前端轮询 session_item['status'] = 'bind_ready' session_item['user_info'] = {'user_id': target_user.id} return redirect('/home?bind=success') except Exception as e: print(f"❌ 微信绑定失败: {e}") db.session.rollback() session_item['status'] = 'bind_failed' return redirect('/home?bind=failed') user = None is_new_user = False if unionid: user = User.query.filter_by(wechat_union_id=unionid).first() if not user: user = User.query.filter_by(wechat_open_id=openid).first() if not user: # 创建新用户 # 先清理微信昵称 raw_nickname = user_info.get('nickname', '微信用户') # 创建临时用户实例以使用清理方法 temp_user = User.__new__(User) sanitized_nickname = 
temp_user._sanitize_nickname(raw_nickname) username = sanitized_nickname counter = 1 while User.is_username_taken(username): username = f"{sanitized_nickname}_{counter}" counter += 1 user = User(username=username) user.nickname = sanitized_nickname user.avatar_url = user_info.get('headimgurl') user.wechat_open_id = openid user.wechat_union_id = unionid user.set_password(uuid.uuid4().hex) user.status = 'active' db.session.add(user) db.session.commit() is_new_user = True print(f"✅ 微信扫码自动创建新用户: {username}, openid: {openid}") # 更新最后登录时间 user.update_last_seen() # 设置session session.permanent = True session['user_id'] = user.id session['username'] = user.username session['logged_in'] = True session['wechat_login'] = True # 标记是微信登录 # Flask-Login 登录 login_user(user, remember=True) # 更新微信session状态,供前端轮询检测 if state in wechat_qr_sessions: session_item = wechat_qr_sessions[state] # 仅处理登录/注册流程,不处理绑定流程 if not session_item.get('mode'): # 更新状态和用户信息 session_item['status'] = 'register_ready' if is_new_user else 'login_ready' session_item['user_info'] = {'user_id': user.id} print(f"✅ 微信扫码状态已更新: {session_item['status']}, user_id: {user.id}") # 直接跳转到首页 return redirect('/home') except Exception as e: print(f"❌ 微信登录失败: {e}") import traceback traceback.print_exc() db.session.rollback() # 更新session状态为失败 if state in wechat_qr_sessions: wechat_qr_sessions[state]['status'] = 'auth_failed' wechat_qr_sessions[state]['error'] = str(e) return redirect('/auth/signin?error=login_failed') @app.route('/api/auth/login/wechat', methods=['POST']) def login_with_wechat(): """微信登录 - 修复版本""" data = request.get_json() session_id = data.get('session_id') if not session_id: return jsonify({'success': False, 'error': 'session_id不能为空'}), 400 # 验证session session = wechat_qr_sessions.get(session_id) if not session: return jsonify({'success': False, 'error': '会话不存在或已过期'}), 400 # 检查session状态 if session['status'] not in ['login_ready', 'register_ready']: return jsonify({'success': False, 'error': '会话状态无效'}), 400 # 
检查是否有用户信息 user_info = session.get('user_info') if not user_info or not user_info.get('user_id'): return jsonify({'success': False, 'error': '用户信息不完整'}), 400 try: user = User.query.get(user_info['user_id']) if not user: return jsonify({'success': False, 'error': '用户不存在'}), 404 # 更新最后登录时间 user.update_last_seen() # ✅ 修复:不立即删除session,而是标记为已完成,避免轮询报错 # 原因:前端可能还在轮询检查状态,立即删除会导致 "无效的session" 错误 # 保留原状态(login_ready/register_ready),前端会正确处理 # wechat_qr_sessions[session_id]['status'] 保持不变 # 设置延迟删除(10秒后自动清理,给前端足够时间完成轮询) import threading def delayed_cleanup(): import time time.sleep(10) if session_id in wechat_qr_sessions: del wechat_qr_sessions[session_id] print(f"✅ 延迟清理微信登录session: {session_id[:8]}...") threading.Thread(target=delayed_cleanup, daemon=True).start() # 生成登录响应 response_data = { 'success': True, 'message': '登录成功' if session['status'] == 'login_ready' else '注册并登录成功', 'user': { 'id': user.id, 'username': user.username, 'nickname': user.nickname or user.username, 'email': user.email, 'avatar_url': user.avatar_url, 'has_wechat': True, 'wechat_open_id': user.wechat_open_id, 'wechat_union_id': user.wechat_union_id, 'created_at': user.created_at.isoformat() if user.created_at else None, 'last_seen': user.last_seen.isoformat() if user.last_seen else None }, 'isNewUser': session['status'] == 'register_ready' # 标记是否为新用户 } # 如果需要token认证,可以在这里生成 # response_data['token'] = generate_token(user.id) return jsonify(response_data), 200 except Exception as e: print(f"❌ 微信登录错误: {e}") import traceback app.logger.error(f"回调处理错误: {e}", exc_info=True) return jsonify({ 'success': False, 'error': '登录失败,请重试' }), 500 @app.route('/api/account/wechat/unbind', methods=['POST']) def unbind_wechat_account(): """解绑当前登录用户的微信""" if not session.get('logged_in'): return jsonify({'error': '未登录'}), 401 try: user = User.query.get(session.get('user_id')) if not user: return jsonify({'error': '用户不存在'}), 404 user.unbind_wechat() return jsonify({'message': '解绑成功', 'success': True}), 200 except Exception as e: 
print(f"Unbind wechat error: {e}") db.session.rollback() return jsonify({'error': '解绑失败,请重试'}), 500 # 评论模型 class EventComment(db.Model): """事件评论""" __tablename__ = 'event_comment' id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=True) author = db.Column(db.String(50), default='匿名用户') content = db.Column(db.Text, nullable=False) parent_id = db.Column(db.Integer, db.ForeignKey('event_comment.id')) likes = db.Column(db.Integer, default=0) created_at = db.Column(db.DateTime, default=beijing_now) status = db.Column(db.String(20), default='active') user = db.relationship('User', backref='event_comments') replies = db.relationship('EventComment', backref=db.backref('parent', remote_side=[id])) def to_dict(self, user_session_id=None, current_user_id=None): # 检查当前用户是否已点赞 user_liked = False if user_session_id: like_record = CommentLike.query.filter_by( comment_id=self.id, session_id=user_session_id ).first() user_liked = like_record is not None # 检查当前用户是否可以删除此评论 can_delete = current_user_id is not None and self.user_id == current_user_id return { 'id': self.id, 'event_id': self.event_id, 'author': self.author, 'content': self.content, 'parent_id': self.parent_id, 'likes': self.likes, 'created_at': self.created_at.isoformat() if self.created_at else None, 'user_liked': user_liked, 'can_delete': can_delete, 'user_id': self.user_id, 'replies': [reply.to_dict(user_session_id, current_user_id) for reply in self.replies if reply.status == 'active'] } class CommentLike(db.Model): """评论点赞记录""" __tablename__ = 'comment_like' id = db.Column(db.Integer, primary_key=True) comment_id = db.Column(db.Integer, db.ForeignKey('event_comment.id'), nullable=False) session_id = db.Column(db.String(100), nullable=False) created_at = db.Column(db.DateTime, default=beijing_now) __table_args__ = (db.UniqueConstraint('comment_id', 'session_id'),) @app.after_request def 
after_request(response): """处理所有响应,添加CORS头部和安全头部""" origin = request.headers.get('Origin') allowed_origins = ['http://localhost:3000', 'http://127.0.0.1:3000', 'http://localhost:5173', 'https://valuefrontier.cn', 'http://valuefrontier.cn'] if origin in allowed_origins: response.headers['Access-Control-Allow-Origin'] = origin response.headers['Access-Control-Allow-Credentials'] = 'true' response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization,X-Requested-With' response.headers['Access-Control-Allow-Methods'] = 'GET,PUT,POST,DELETE,OPTIONS' response.headers['Access-Control-Expose-Headers'] = 'Content-Type,Authorization' # 处理预检请求 if request.method == 'OPTIONS': response.status_code = 200 return response def add_cors_headers(response): """添加CORS头(保留原有函数以兼容)""" origin = request.headers.get('Origin') allowed_origins = ['http://localhost:3000', 'http://127.0.0.1:3000', 'http://localhost:5173', 'https://valuefrontier.cn', 'http://valuefrontier.cn'] if origin in allowed_origins: response.headers['Access-Control-Allow-Origin'] = origin else: response.headers['Access-Control-Allow-Origin'] = 'http://localhost:3000' response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization,X-Requested-With' response.headers['Access-Control-Allow-Methods'] = 'GET,PUT,POST,DELETE,OPTIONS' response.headers['Access-Control-Allow-Credentials'] = 'true' return response class EventFollow(db.Model): """事件关注""" id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) created_at = db.Column(db.DateTime, default=beijing_now) user = db.relationship('User', backref='event_follows') __table_args__ = (db.UniqueConstraint('user_id', 'event_id'),) class FutureEventFollow(db.Model): """未来事件关注""" __tablename__ = 'future_event_follow' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), 
nullable=False) future_event_id = db.Column(db.Integer, nullable=False) # future_events表的id created_at = db.Column(db.DateTime, default=beijing_now) user = db.relationship('User', backref='future_event_follows') __table_args__ = (db.UniqueConstraint('user_id', 'future_event_id'),) # —— 自选股输入统一化与名称补全工具 —— def _normalize_stock_input(raw_input: str): """解析用户输入为标准6位股票代码与可选名称。 支持: - 6位代码: "600519",或带后缀 "600519.SH"/"600519.SZ" - 名称(代码): "贵州茅台(600519)" 或 "贵州茅台(600519)" 返回 (code6, name_or_none) """ if not raw_input: return None, None s = str(raw_input).strip() # 名称(600519) 或 名称(600519) m = re.match(r"^(.+?)[\((]\s*(\d{6})\s*[\))]\s*$", s) if m: name = m.group(1).strip() code = m.group(2) return code, (name if name else None) # 600519 或 600519.SH / 600519.SZ m2 = re.match(r"^(\d{6})(?:\.(?:SH|SZ))?$", s, re.IGNORECASE) if m2: return m2.group(1), None # SH600519 / SZ000001 m3 = re.match(r"^(SH|SZ)(\d{6})$", s, re.IGNORECASE) if m3: return m3.group(2), None return None, None def _query_stock_name_by_code(code6: str): """根据6位代码查询股票名称,查不到返回None。""" try: with engine.connect() as conn: q = text(""" SELECT SECNAME FROM ea_baseinfo WHERE SECCODE = :c LIMIT 1 """) row = conn.execute(q, {'c': code6}).fetchone() if row: return row[0] except Exception: pass return None class Watchlist(db.Model): """用户自选股""" __tablename__ = 'watchlist' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) stock_code = db.Column(db.String(20), nullable=False) stock_name = db.Column(db.String(100), nullable=True) created_at = db.Column(db.DateTime, default=beijing_now) user = db.relationship('User', backref='watchlist') __table_args__ = (db.UniqueConstraint('user_id', 'stock_code'),) @app.route('/api/account/watchlist', methods=['GET']) def get_my_watchlist(): """获取当前用户的自选股列表""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 items = 
Watchlist.query.filter_by(user_id=session['user_id']).order_by(Watchlist.created_at.desc()).all() # 懒更新:统一代码为6位、补全缺失的名称,并去重(同一代码保留一个记录) from collections import defaultdict groups = defaultdict(list) for i in items: code6, _ = _normalize_stock_input(i.stock_code) normalized_code = code6 or (i.stock_code.strip().upper() if isinstance(i.stock_code, str) else i.stock_code) groups[normalized_code].append(i) dirty = False to_delete = [] for code6, group in groups.items(): # 选择保留记录:优先有名称的,其次创建时间早的 def sort_key(x): return (x.stock_name is None, x.created_at or datetime.min) group_sorted = sorted(group, key=sort_key) keep = group_sorted[0] # 规范保留项 if keep.stock_code != code6: keep.stock_code = code6 dirty = True if not keep.stock_name and code6: nm = _query_stock_name_by_code(code6) if nm: keep.stock_name = nm dirty = True # 其余删除 for g in group_sorted[1:]: to_delete.append(g) if to_delete: for g in to_delete: db.session.delete(g) dirty = True if dirty: db.session.commit() return jsonify({'success': True, 'data': [ { 'id': i.id, 'stock_code': i.stock_code, 'stock_name': i.stock_name, 'created_at': i.created_at.isoformat() if i.created_at else None } for i in items ]}) except Exception as e: print(f"Error in get_my_watchlist: {str(e)}") return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/account/watchlist', methods=['POST']) def add_to_watchlist(): """添加到自选股""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 data = request.get_json() or {} raw_code = data.get('stock_code') raw_name = data.get('stock_name') code6, name_from_input = _normalize_stock_input(raw_code) if not code6: return jsonify({'success': False, 'error': '无效的股票标识'}), 400 # 优先使用传入名称,其次从输入解析中获得,最后查库补全 final_name = raw_name or name_from_input or _query_stock_name_by_code(code6) # 查找已存在记录,兼容历史:6位/带后缀 candidates = [code6, f"{code6}.SH", f"{code6}.SZ"] existing = Watchlist.query.filter( Watchlist.user_id == session['user_id'], 
Watchlist.stock_code.in_(candidates) ).first() if existing: # 统一为6位,补全名称 updated = False if existing.stock_code != code6: existing.stock_code = code6 updated = True if (not existing.stock_name) and final_name: existing.stock_name = final_name updated = True if updated: db.session.commit() return jsonify({'success': True, 'data': {'id': existing.id}}) item = Watchlist(user_id=session['user_id'], stock_code=code6, stock_name=final_name) db.session.add(item) db.session.commit() return jsonify({'success': True, 'data': {'id': item.id}}) @app.route('/api/account/watchlist/', methods=['DELETE']) def remove_from_watchlist(stock_code): """从自选股移除""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 code6, _ = _normalize_stock_input(stock_code) candidates = [] if code6: candidates = [code6, f"{code6}.SH", f"{code6}.SZ"] # 包含原始传入(以兼容历史) if stock_code not in candidates: candidates.append(stock_code) item = Watchlist.query.filter( Watchlist.user_id == session['user_id'], Watchlist.stock_code.in_(candidates) ).first() if not item: return jsonify({'success': False, 'error': '未找到自选项'}), 404 db.session.delete(item) db.session.commit() return jsonify({'success': True}) @app.route('/api/account/watchlist/realtime', methods=['GET']) def get_watchlist_realtime(): """获取自选股实时行情数据(基于分钟线)""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 # 获取用户自选股列表 watchlist = Watchlist.query.filter_by(user_id=session['user_id']).all() if not watchlist: return jsonify({'success': True, 'data': []}) # 获取股票代码列表 stock_codes = [] for item in watchlist: code6, _ = _normalize_stock_input(item.stock_code) # 统一内部查询代码 normalized = code6 or str(item.stock_code).strip().upper() stock_codes.append(normalized) # 使用现有的分钟线接口获取最新行情 client = get_clickhouse_client() quotes_data = {} # 获取最新交易日 today = datetime.now().date() # 获取每只股票的最新价格 for code in stock_codes: raw_code = str(code).strip().upper() if '.' 
in raw_code: stock_code_full = raw_code else: stock_code_full = f"{raw_code}.SH" if raw_code.startswith('6') else f"{raw_code}.SZ" # 获取最新分钟线数据(先查近7天,若无数据再兜底倒序取最近一条) query = """ SELECT close, timestamp, high, low, volume, amt FROM stock_minute WHERE code = %(code)s AND timestamp >= %(start)s ORDER BY timestamp DESC LIMIT 1 \ """ # 获取最近7天的分钟数据 start_date = today - timedelta(days=7) result = client.execute(query, { 'code': stock_code_full, 'start': datetime.combine(start_date, dt_time(9, 30)) }) # 若近7天无数据,兜底直接取最近一条 if not result: fallback_query = """ SELECT close, timestamp, high, low, volume, amt FROM stock_minute WHERE code = %(code)s ORDER BY timestamp DESC LIMIT 1 \ """ result = client.execute(fallback_query, {'code': stock_code_full}) if result: latest_data = result[0] latest_ts = latest_data[1] # 获取该bar所属交易日前一个交易日的收盘价 prev_close_query = """ SELECT close FROM stock_minute WHERE code = %(code)s AND timestamp \ < %(start)s ORDER BY timestamp DESC LIMIT 1 \ """ prev_result = client.execute(prev_close_query, { 'code': stock_code_full, 'start': datetime.combine(latest_ts.date(), dt_time(9, 30)) }) prev_close = float(prev_result[0][0]) if prev_result else float(latest_data[0]) # 计算涨跌幅 change = float(latest_data[0]) - prev_close change_percent = (change / prev_close * 100) if prev_close > 0 else 0.0 quotes_data[code] = { 'price': float(latest_data[0]), 'prev_close': float(prev_close), 'change': float(change), 'change_percent': float(change_percent), 'high': float(latest_data[2]), 'low': float(latest_data[3]), 'volume': int(latest_data[4]), 'amount': float(latest_data[5]), 'update_time': latest_ts.strftime('%H:%M:%S') } # 构建响应数据 response_data = [] for item in watchlist: code6, _ = _normalize_stock_input(item.stock_code) quote = quotes_data.get(code6 or item.stock_code, {}) response_data.append({ 'stock_code': code6 or item.stock_code, 'stock_name': item.stock_name or (code6 and _query_stock_name_by_code(code6)) or None, 'current_price': quote.get('price', 0), 
'prev_close': quote.get('prev_close', 0), 'change': quote.get('change', 0), 'change_percent': quote.get('change_percent', 0), 'high': quote.get('high', 0), 'low': quote.get('low', 0), 'volume': quote.get('volume', 0), 'amount': quote.get('amount', 0), 'update_time': quote.get('update_time', ''), # industry 字段在 Watchlist 模型中不存在,先不返回该字段 }) return jsonify({ 'success': True, 'data': response_data }) except Exception as e: print(f"获取实时行情失败: {str(e)}") return jsonify({'success': False, 'error': '获取实时行情失败'}), 500 # 投资计划和复盘相关的模型 class InvestmentPlan(db.Model): __tablename__ = 'investment_plans' id = db.Column(db.Integer, primary_key=True, autoincrement=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) date = db.Column(db.Date, nullable=False) title = db.Column(db.String(200), nullable=False) content = db.Column(db.Text) type = db.Column(db.String(20)) # 'plan' or 'review' stocks = db.Column(db.Text) # JSON array of stock codes tags = db.Column(db.String(500)) # JSON array of tags status = db.Column(db.String(20), default='active') # active, completed, cancelled created_at = db.Column(db.DateTime, default=datetime.utcnow) updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) def to_dict(self): return { 'id': self.id, 'date': self.date.isoformat() if self.date else None, 'title': self.title, 'content': self.content, 'type': self.type, 'stocks': json.loads(self.stocks) if self.stocks else [], 'tags': json.loads(self.tags) if self.tags else [], 'status': self.status, 'created_at': self.created_at.isoformat() if self.created_at else None, 'updated_at': self.updated_at.isoformat() if self.updated_at else None } @app.route('/api/account/investment-plans', methods=['GET']) def get_investment_plans(): """获取投资计划和复盘记录""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 plan_type = request.args.get('type') # 'plan', 'review', or None for all start_date = 
request.args.get('start_date') end_date = request.args.get('end_date') query = InvestmentPlan.query.filter_by(user_id=session['user_id']) if plan_type: query = query.filter_by(type=plan_type) if start_date: query = query.filter(InvestmentPlan.date >= datetime.fromisoformat(start_date).date()) if end_date: query = query.filter(InvestmentPlan.date <= datetime.fromisoformat(end_date).date()) plans = query.order_by(InvestmentPlan.date.desc()).all() return jsonify({ 'success': True, 'data': [plan.to_dict() for plan in plans] }) except Exception as e: print(f"获取投资计划失败: {str(e)}") return jsonify({'success': False, 'error': '获取数据失败'}), 500 @app.route('/api/account/investment-plans', methods=['POST']) def create_investment_plan(): """创建投资计划或复盘记录""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 data = request.get_json() # 验证必要字段 if not data.get('date') or not data.get('title') or not data.get('type'): return jsonify({'success': False, 'error': '缺少必要字段'}), 400 plan = InvestmentPlan( user_id=session['user_id'], date=datetime.fromisoformat(data['date']).date(), title=data['title'], content=data.get('content', ''), type=data['type'], stocks=json.dumps(data.get('stocks', [])), tags=json.dumps(data.get('tags', [])), status=data.get('status', 'active') ) db.session.add(plan) db.session.commit() return jsonify({ 'success': True, 'data': plan.to_dict() }) except Exception as e: db.session.rollback() print(f"创建投资计划失败: {str(e)}") return jsonify({'success': False, 'error': '创建失败'}), 500 @app.route('/api/account/investment-plans/', methods=['PUT']) def update_investment_plan(plan_id): """更新投资计划或复盘记录""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 plan = InvestmentPlan.query.filter_by(id=plan_id, user_id=session['user_id']).first() if not plan: return jsonify({'success': False, 'error': '未找到该记录'}), 404 data = request.get_json() if 'date' in data: plan.date = datetime.fromisoformat(data['date']).date() 
if 'title' in data: plan.title = data['title'] if 'content' in data: plan.content = data['content'] if 'stocks' in data: plan.stocks = json.dumps(data['stocks']) if 'tags' in data: plan.tags = json.dumps(data['tags']) if 'status' in data: plan.status = data['status'] plan.updated_at = datetime.utcnow() db.session.commit() return jsonify({ 'success': True, 'data': plan.to_dict() }) except Exception as e: db.session.rollback() print(f"更新投资计划失败: {str(e)}") return jsonify({'success': False, 'error': '更新失败'}), 500 @app.route('/api/account/investment-plans/', methods=['DELETE']) def delete_investment_plan(plan_id): """删除投资计划或复盘记录""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 plan = InvestmentPlan.query.filter_by(id=plan_id, user_id=session['user_id']).first() if not plan: return jsonify({'success': False, 'error': '未找到该记录'}), 404 db.session.delete(plan) db.session.commit() return jsonify({'success': True}) except Exception as e: db.session.rollback() print(f"删除投资计划失败: {str(e)}") return jsonify({'success': False, 'error': '删除失败'}), 500 @app.route('/api/account/events/following', methods=['GET']) def get_my_following_events(): """获取我关注的事件列表""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 follows = EventFollow.query.filter_by(user_id=session['user_id']).order_by(EventFollow.created_at.desc()).all() event_ids = [f.event_id for f in follows] if not event_ids: return jsonify({'success': True, 'data': []}) events = Event.query.filter(Event.id.in_(event_ids)).all() data = [] for ev in events: data.append({ 'id': ev.id, 'title': ev.title, 'event_type': ev.event_type, 'start_time': ev.start_time.isoformat() if ev.start_time else None, 'hot_score': ev.hot_score, 'follower_count': ev.follower_count, }) return jsonify({'success': True, 'data': data}) @app.route('/api/account/events/comments', methods=['GET']) def get_my_event_comments(): """获取我在事件上的评论(EventComment)""" if 'user_id' not in session: 
return jsonify({'success': False, 'error': '未登录'}), 401 comments = EventComment.query.filter_by(user_id=session['user_id']).order_by(EventComment.created_at.desc()).limit( 100).all() return jsonify({'success': True, 'data': [c.to_dict() for c in comments]}) @app.route('/api/account/events/posts', methods=['GET']) def get_my_event_posts(): """获取我在事件上的帖子(Post)- 用于个人中心显示""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 try: # 查询当前用户的所有 Post(按创建时间倒序) posts = Post.query.filter_by( user_id=session['user_id'], status='active' ).order_by(Post.created_at.desc()).limit(100).all() posts_data = [] for post in posts: # 获取关联的事件信息 event = Event.query.get(post.event_id) event_title = event.title if event else '未知事件' # 获取用户信息 user = User.query.get(post.user_id) author = user.username if user else '匿名用户' # ⚡ 返回格式兼容旧 EventComment.to_dict() posts_data.append({ 'id': post.id, 'event_id': post.event_id, 'event_title': event_title, # ⚡ 新增字段(旧 API 没有) 'user_id': post.user_id, 'author': author, # ⚡ 兼容旧格式(字符串类型) 'content': post.content, 'title': post.title, # Post 独有字段(可选) 'content_type': post.content_type, # Post 独有字段 'likes': post.likes_count, # ⚡ 兼容旧字段名 'created_at': post.created_at.isoformat(), 'updated_at': post.updated_at.isoformat(), 'status': post.status, }) return jsonify({'success': True, 'data': posts_data}) except Exception as e: print(f"获取用户帖子失败: {e}") return jsonify({'success': False, 'error': '获取帖子失败'}), 500 @app.route('/api/account/future-events/following', methods=['GET']) def get_my_following_future_events(): """获取当前用户关注的未来事件""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 try: # 获取用户关注的未来事件ID列表 follows = FutureEventFollow.query.filter_by(user_id=session['user_id']).all() future_event_ids = [f.future_event_id for f in follows] if not future_event_ids: return jsonify({'success': True, 'data': []}) # 查询未来事件详情 sql = """ SELECT * FROM future_events WHERE data_id IN :event_ids ORDER BY calendar_time 
\ """ result = db.session.execute( text(sql), {'event_ids': tuple(future_event_ids)} ) events = [] for row in result: event_data = { 'id': row.data_id, 'title': row.title, 'type': row.type, 'calendar_time': row.calendar_time.isoformat(), 'star': row.star, 'former': row.former, 'forecast': row.forecast, 'fact': row.fact, 'is_following': True, # 这些都是已关注的 'related_stocks': parse_json_field(row.related_stocks), 'concepts': parse_json_field(row.concepts) } events.append(event_data) return jsonify({'success': True, 'data': events}) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 class PostLike(db.Model): """帖子点赞""" id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) created_at = db.Column(db.DateTime, default=beijing_now) user = db.relationship('User', backref='post_likes') __table_args__ = (db.UniqueConstraint('user_id', 'post_id'),) # =========================== # 预测市场系统模型 # =========================== class UserCreditAccount(db.Model): """用户积分账户""" __tablename__ = 'user_credit_account' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False, unique=True) # 积分余额 balance = db.Column(db.Float, default=10000.0, nullable=False) # 初始10000积分 frozen_balance = db.Column(db.Float, default=0.0, nullable=False) # 冻结积分 total_earned = db.Column(db.Float, default=0.0, nullable=False) # 累计获得 total_spent = db.Column(db.Float, default=0.0, nullable=False) # 累计消费 # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) last_daily_bonus_at = db.Column(db.DateTime) # 最后一次领取每日奖励时间 # 关系 user = db.relationship('User', backref=db.backref('credit_account', uselist=False)) def __repr__(self): return f'' class PredictionTopic(db.Model): """预测话题""" __tablename__ = 
'prediction_topic' id = db.Column(db.Integer, primary_key=True) creator_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 基本信息 title = db.Column(db.String(200), nullable=False) description = db.Column(db.Text) category = db.Column(db.String(50), default='stock') # stock/event/market # 市场数据 yes_total_shares = db.Column(db.Integer, default=0, nullable=False) # YES方总份额 no_total_shares = db.Column(db.Integer, default=0, nullable=False) # NO方总份额 yes_price = db.Column(db.Float, default=500.0, nullable=False) # YES方价格(0-1000) no_price = db.Column(db.Float, default=500.0, nullable=False) # NO方价格(0-1000) # 奖池 total_pool = db.Column(db.Float, default=0.0, nullable=False) # 总奖池(2%交易税累积) # 领主信息 yes_lord_id = db.Column(db.Integer, db.ForeignKey('user.id')) # YES方领主 no_lord_id = db.Column(db.Integer, db.ForeignKey('user.id')) # NO方领主 # 状态 status = db.Column(db.String(20), default='active', nullable=False) # active/settled/cancelled result = db.Column(db.String(10)) # yes/no/draw(结算结果) # 时间 deadline = db.Column(db.DateTime, nullable=False) # 截止时间 settled_at = db.Column(db.DateTime) # 结算时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 统计 views_count = db.Column(db.Integer, default=0) comments_count = db.Column(db.Integer, default=0) participants_count = db.Column(db.Integer, default=0) # 关系 creator = db.relationship('User', foreign_keys=[creator_id], backref='created_topics') yes_lord = db.relationship('User', foreign_keys=[yes_lord_id], backref='yes_lord_topics') no_lord = db.relationship('User', foreign_keys=[no_lord_id], backref='no_lord_topics') positions = db.relationship('PredictionPosition', backref='topic', lazy='dynamic') transactions = db.relationship('PredictionTransaction', backref='topic', lazy='dynamic') comments = db.relationship('TopicComment', backref='topic', lazy='dynamic') def __repr__(self): return f'' class 
PredictionPosition(db.Model): """用户持仓""" __tablename__ = 'prediction_position' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) topic_id = db.Column(db.Integer, db.ForeignKey('prediction_topic.id'), nullable=False) # 持仓信息 direction = db.Column(db.String(3), nullable=False) # yes/no shares = db.Column(db.Integer, default=0, nullable=False) # 持有份额 avg_cost = db.Column(db.Float, default=0.0, nullable=False) # 平均成本 total_invested = db.Column(db.Float, default=0.0, nullable=False) # 总投入 # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 关系 user = db.relationship('User', backref='prediction_positions') # 唯一约束:每个用户在每个话题的每个方向只能有一个持仓 __table_args__ = (db.UniqueConstraint('user_id', 'topic_id', 'direction'),) def __repr__(self): return f'' class PredictionTransaction(db.Model): """预测交易记录""" __tablename__ = 'prediction_transaction' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) topic_id = db.Column(db.Integer, db.ForeignKey('prediction_topic.id'), nullable=False) # 交易信息 trade_type = db.Column(db.String(10), nullable=False) # buy/sell direction = db.Column(db.String(3), nullable=False) # yes/no shares = db.Column(db.Integer, nullable=False) # 份额数量 price = db.Column(db.Float, nullable=False) # 成交价格 # 费用 amount = db.Column(db.Float, nullable=False) # 交易金额 tax = db.Column(db.Float, default=0.0, nullable=False) # 手续费(2%) total_cost = db.Column(db.Float, nullable=False) # 总成本(amount + tax) # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) # 关系 user = db.relationship('User', backref='prediction_transactions') def __repr__(self): return f'' class CreditTransaction(db.Model): """积分交易记录""" __tablename__ = 'credit_transaction' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, 
db.ForeignKey('user.id'), nullable=False) # 交易信息 transaction_type = db.Column(db.String(30), nullable=False) # prediction_buy/prediction_sell/daily_bonus/create_topic/settle_win amount = db.Column(db.Float, nullable=False) # 金额(正数=增加,负数=减少) balance_after = db.Column(db.Float, nullable=False) # 交易后余额 # 关联 related_topic_id = db.Column(db.Integer, db.ForeignKey('prediction_topic.id')) # 相关话题 related_transaction_id = db.Column(db.Integer, db.ForeignKey('prediction_transaction.id')) # 相关预测交易 # 描述 description = db.Column(db.String(200)) # 交易描述 # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) # 关系 user = db.relationship('User', backref='credit_transactions') related_topic = db.relationship('PredictionTopic', backref='credit_transactions') def __repr__(self): return f'' class TopicComment(db.Model): """话题评论""" __tablename__ = 'topic_comment' id = db.Column(db.Integer, primary_key=True) topic_id = db.Column(db.Integer, db.ForeignKey('prediction_topic.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 内容 content = db.Column(db.Text, nullable=False) parent_id = db.Column(db.Integer, db.ForeignKey('topic_comment.id')) # 父评论ID(回复功能) # 状态 is_pinned = db.Column(db.Boolean, default=False, nullable=False) # 是否置顶(领主特权) status = db.Column(db.String(20), default='active') # active/hidden/deleted # 统计 likes_count = db.Column(db.Integer, default=0, nullable=False) # 观点IPO 相关 total_investment = db.Column(db.Integer, default=0, nullable=False) # 总投资额 investor_count = db.Column(db.Integer, default=0, nullable=False) # 投资人数 is_verified = db.Column(db.Boolean, default=False, nullable=False) # 是否已验证 verification_result = db.Column(db.String(20)) # 验证结果:correct/incorrect/null position_rank = db.Column(db.Integer) # 评论位置排名(用于首发权拍卖) # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 关系 user = 
db.relationship('User', backref='topic_comments') replies = db.relationship('TopicComment', backref=db.backref('parent', remote_side=[id]), lazy='dynamic') likes = db.relationship('TopicCommentLike', backref='comment', lazy='dynamic') def __repr__(self): return f'' class TopicCommentLike(db.Model): """话题评论点赞""" __tablename__ = 'topic_comment_like' id = db.Column(db.Integer, primary_key=True) comment_id = db.Column(db.Integer, db.ForeignKey('topic_comment.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) # 关系 user = db.relationship('User', backref='topic_comment_likes') # 唯一约束 __table_args__ = (db.UniqueConstraint('comment_id', 'user_id'),) def __repr__(self): return f'' class CommentInvestment(db.Model): """评论投资记录(观点IPO)""" __tablename__ = 'comment_investment' id = db.Column(db.Integer, primary_key=True) comment_id = db.Column(db.Integer, db.ForeignKey('topic_comment.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 投资数据 shares = db.Column(db.Integer, nullable=False) # 投资份额 amount = db.Column(db.Integer, nullable=False) # 投资金额 avg_price = db.Column(db.Float, nullable=False) # 平均价格 # 状态 status = db.Column(db.String(20), default='active', nullable=False) # active/settled # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) # 关系 user = db.relationship('User', backref='comment_investments') comment = db.relationship('TopicComment', backref='investments') def __repr__(self): return f'' class CommentPositionBid(db.Model): """评论位置竞拍记录(首发权拍卖)""" __tablename__ = 'comment_position_bid' id = db.Column(db.Integer, primary_key=True) topic_id = db.Column(db.Integer, db.ForeignKey('prediction_topic.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 竞拍数据 position = db.Column(db.Integer, nullable=False) # 位置:1/2/3 bid_amount = 
db.Column(db.Integer, nullable=False) # 出价金额 status = db.Column(db.String(20), default='pending', nullable=False) # pending/won/lost # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) expires_at = db.Column(db.DateTime, nullable=False) # 竞拍截止时间 # 关系 user = db.relationship('User', backref='comment_position_bids') topic = db.relationship('PredictionTopic', backref='position_bids') def __repr__(self): return f'' class TimeCapsuleTopic(db.Model): """时间胶囊话题(长期预测)""" __tablename__ = 'time_capsule_topic' id = db.Column(db.Integer, primary_key=True) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 话题内容 title = db.Column(db.String(200), nullable=False) description = db.Column(db.Text) encrypted_content = db.Column(db.Text) # 加密的预测内容 encryption_key = db.Column(db.String(500)) # 加密密钥(后端存储) # 时间范围 start_year = db.Column(db.Integer, nullable=False) # 起始年份 end_year = db.Column(db.Integer, nullable=False) # 结束年份 # 状态 status = db.Column(db.String(20), default='active', nullable=False) # active/settled is_decrypted = db.Column(db.Boolean, default=False, nullable=False) # 是否已解密 actual_happened_year = db.Column(db.Integer) # 实际发生年份 # 统计 total_pool = db.Column(db.Integer, default=0, nullable=False) # 总奖池 # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 关系 user = db.relationship('User', backref='time_capsule_topics') time_slots = db.relationship('TimeCapsuleTimeSlot', backref='topic', lazy='dynamic') def __repr__(self): return f'' class TimeCapsuleTimeSlot(db.Model): """时间胶囊时间段""" __tablename__ = 'time_capsule_time_slot' id = db.Column(db.Integer, primary_key=True) topic_id = db.Column(db.Integer, db.ForeignKey('time_capsule_topic.id'), nullable=False) # 时间段 year_start = db.Column(db.Integer, nullable=False) year_end = db.Column(db.Integer, nullable=False) # 竞拍数据 current_holder_id = db.Column(db.Integer, 
db.ForeignKey('user.id')) # 当前持有者 current_price = db.Column(db.Integer, default=100, nullable=False) # 当前价格 total_bids = db.Column(db.Integer, default=0, nullable=False) # 总竞拍次数 # 状态 status = db.Column(db.String(20), default='active', nullable=False) # active/won/expired # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 关系 current_holder = db.relationship('User', foreign_keys=[current_holder_id]) bids = db.relationship('TimeSlotBid', backref='time_slot', lazy='dynamic') def __repr__(self): return f'' class TimeSlotBid(db.Model): """时间段竞拍记录""" __tablename__ = 'time_slot_bid' id = db.Column(db.Integer, primary_key=True) slot_id = db.Column(db.Integer, db.ForeignKey('time_capsule_time_slot.id'), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) # 竞拍数据 bid_amount = db.Column(db.Integer, nullable=False) status = db.Column(db.String(20), default='outbid', nullable=False) # outbid/holding/won # 时间 created_at = db.Column(db.DateTime, default=beijing_now, nullable=False) # 关系 user = db.relationship('User', backref='time_slot_bids') def __repr__(self): return f'' class Event(db.Model): """事件模型""" id = db.Column(db.Integer, primary_key=True) title = db.Column(db.String(200), nullable=False) description = db.Column(db.Text) # 事件类型与状态 event_type = db.Column(db.String(50)) status = db.Column(db.String(20), default='active') # 时间相关 start_time = db.Column(db.DateTime, default=beijing_now) end_time = db.Column(db.DateTime) created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now) # 热度与统计 hot_score = db.Column(db.Float, default=0) view_count = db.Column(db.Integer, default=0) trending_score = db.Column(db.Float, default=0) post_count = db.Column(db.Integer, default=0) follower_count = db.Column(db.Integer, default=0) # 关联信息 related_industries = db.Column(db.JSON) keywords = 
db.Column(db.JSON) files = db.Column(db.JSON) importance = db.Column(db.String(20)) related_avg_chg = db.Column(db.Float, default=0) related_max_chg = db.Column(db.Float, default=0) related_week_chg = db.Column(db.Float, default=0) # 新增字段 invest_score = db.Column(db.Integer) # 超预期得分 expectation_surprise_score = db.Column(db.Integer) # 创建者信息 creator_id = db.Column(db.Integer, db.ForeignKey('user.id')) creator = db.relationship('User', backref='created_events') # 关系 posts = db.relationship('Post', backref='event', lazy='dynamic') followers = db.relationship('EventFollow', backref='event', lazy='dynamic') related_stocks = db.relationship('RelatedStock', backref='event', lazy='dynamic') historical_events = db.relationship('HistoricalEvent', backref='event', lazy='dynamic') related_data = db.relationship('RelatedData', backref='event', lazy='dynamic') related_concepts = db.relationship('RelatedConcepts', backref='event', lazy='dynamic') @property def keywords_list(self): """返回解析后的关键词列表""" if not self.keywords: return [] if isinstance(self.keywords, list): return self.keywords try: # 如果是字符串,尝试解析JSON if isinstance(self.keywords, str): decoded = json.loads(self.keywords) # 处理Unicode编码的情况 if isinstance(decoded, list): return [ keyword.encode('utf-8').decode('unicode_escape') if isinstance(keyword, str) and '\\u' in keyword else keyword for keyword in decoded ] return [] # 如果已经是字典或其他格式,尝试转换为列表 return list(self.keywords) except (json.JSONDecodeError, AttributeError, TypeError): return [] def set_keywords(self, keywords): """设置关键词列表""" if isinstance(keywords, list): self.keywords = json.dumps(keywords, ensure_ascii=False) elif isinstance(keywords, str): try: # 尝试解析JSON字符串 parsed = json.loads(keywords) if isinstance(parsed, list): self.keywords = json.dumps(parsed, ensure_ascii=False) else: self.keywords = json.dumps([keywords], ensure_ascii=False) except json.JSONDecodeError: # 如果不是有效的JSON,将其作为单个关键词 self.keywords = json.dumps([keywords], ensure_ascii=False) class 
RelatedStock(db.Model): """相关标的模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) stock_code = db.Column(db.String(20)) # 股票代码 stock_name = db.Column(db.String(100)) # 股票名称 sector = db.Column(db.String(100)) # 关联类型 relation_desc = db.Column(db.String(1024)) # 关联原因描述 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) correlation = db.Column(db.Float()) momentum = db.Column(db.String(1024)) # 动量 retrieved_sources = db.Column(db.JSON) # 动量 class RelatedData(db.Model): """关联数据模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) title = db.Column(db.String(200)) # 数据标题 data_type = db.Column(db.String(50)) # 数据类型 data_content = db.Column(db.JSON) # 数据内容(JSON格式) description = db.Column(db.Text) # 数据描述 created_at = db.Column(db.DateTime, default=beijing_now) class RelatedConcepts(db.Model): """关联数据模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) concept_code = db.Column(db.String(20)) # 数据标题 concept = db.Column(db.String(100)) # 数据类型 reason = db.Column(db.Text) # 数据描述 image_paths = db.Column(db.JSON) # 数据内容(JSON格式) created_at = db.Column(db.DateTime, default=beijing_now) @property def image_paths_list(self): """返回解析后的图片路径列表""" if not self.image_paths: return [] try: # 如果是字符串,先解析成JSON if isinstance(self.image_paths, str): paths = json.loads(self.image_paths) else: paths = self.image_paths # 确保paths是列表 if not isinstance(paths, list): paths = [paths] # 从每个对象中提取path字段 return [item['path'] if isinstance(item, dict) and 'path' in item else item for item in paths] except Exception as e: print(f"Error processing image paths: {e}") return [] def get_first_image_path(self): """获取第一张图片的完整路径""" paths = self.image_paths_list if not paths: return None # 获取第一个路径 first_path = paths[0] # 返回完整路径 return first_path class 
EventHotHistory(db.Model): """事件热度历史记录""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) score = db.Column(db.Float) # 总分 interaction_score = db.Column(db.Float) # 互动分数 follow_score = db.Column(db.Float) # 关注度分数 view_score = db.Column(db.Float) # 浏览量分数 recent_activity_score = db.Column(db.Float) # 最近活跃度分数 time_decay = db.Column(db.Float) # 时间衰减因子 created_at = db.Column(db.DateTime, default=beijing_now) event = db.relationship('Event', backref='hot_history') class EventTransmissionNode(db.Model): """事件传导节点模型""" __tablename__ = 'event_transmission_nodes' id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) node_type = db.Column(db.Enum('company', 'industry', 'policy', 'technology', 'market', 'event', 'other'), nullable=False) node_name = db.Column(db.String(200), nullable=False) node_description = db.Column(db.Text) importance_score = db.Column(db.Integer, default=50) stock_code = db.Column(db.String(20)) is_main_event = db.Column(db.Boolean, default=False) created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # Relationships event = db.relationship('Event', backref='transmission_nodes') outgoing_edges = db.relationship('EventTransmissionEdge', foreign_keys='EventTransmissionEdge.from_node_id', backref='from_node', cascade='all, delete-orphan') incoming_edges = db.relationship('EventTransmissionEdge', foreign_keys='EventTransmissionEdge.to_node_id', backref='to_node', cascade='all, delete-orphan') __table_args__ = ( db.Index('idx_event_id', 'event_id'), db.Index('idx_node_type', 'node_type'), db.Index('idx_main_event', 'is_main_event'), ) class EventTransmissionEdge(db.Model): """事件传导边模型""" __tablename__ = 'event_transmission_edges' id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) 
from_node_id = db.Column(db.Integer, db.ForeignKey('event_transmission_nodes.id'), nullable=False) to_node_id = db.Column(db.Integer, db.ForeignKey('event_transmission_nodes.id'), nullable=False) transmission_type = db.Column(db.Enum('supply_chain', 'competition', 'policy', 'technology', 'capital_flow', 'expectation', 'cyclic_effect', 'other'), nullable=False) transmission_mechanism = db.Column(db.Text) direction = db.Column(db.Enum('positive', 'negative', 'neutral', 'mixed'), default='neutral') strength = db.Column(db.Integer, default=50) impact = db.Column(db.Text) is_circular = db.Column(db.Boolean, default=False) created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # Relationship event = db.relationship('Event', backref='transmission_edges') __table_args__ = ( db.Index('idx_event_id', 'event_id'), db.Index('idx_strength', 'strength'), db.Index('idx_from_to', 'from_node_id', 'to_node_id'), db.Index('idx_circular', 'is_circular'), ) # 在 paste-2.txt 的模型定义部分添加 class EventSankeyFlow(db.Model): """事件桑基流模型""" __tablename__ = 'event_sankey_flows' id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) # 流的基本信息 source_node = db.Column(db.String(200), nullable=False) source_type = db.Column(db.Enum('event', 'policy', 'technology', 'industry', 'company', 'product'), nullable=False) source_level = db.Column(db.Integer, nullable=False, default=0) target_node = db.Column(db.String(200), nullable=False) target_type = db.Column(db.Enum('policy', 'technology', 'industry', 'company', 'product'), nullable=False) target_level = db.Column(db.Integer, nullable=False, default=1) # 流量信息 flow_value = db.Column(db.Numeric(10, 2), nullable=False) flow_ratio = db.Column(db.Numeric(5, 4), nullable=False) # 传导机制 transmission_path = db.Column(db.String(500)) impact_description = db.Column(db.Text) evidence_strength = db.Column(db.Integer, 
default=50) # 时间戳 created_at = db.Column(db.DateTime, default=beijing_now) updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) # 关系 event = db.relationship('Event', backref='sankey_flows') __table_args__ = ( db.Index('idx_event_id', 'event_id'), db.Index('idx_source_target', 'source_node', 'target_node'), db.Index('idx_levels', 'source_level', 'target_level'), db.Index('idx_flow_value', 'flow_value'), ) class HistoricalEvent(db.Model): """历史事件模型""" id = db.Column(db.Integer, primary_key=True) event_id = db.Column(db.Integer, db.ForeignKey('event.id')) title = db.Column(db.String(200)) content = db.Column(db.Text) event_date = db.Column(db.DateTime) relevance = db.Column(db.Integer) # 相关性 importance = db.Column(db.Integer) # 重要程度 related_stock = db.Column(db.JSON) # 保留JSON字段 created_at = db.Column(db.DateTime, default=beijing_now) # 新增关系 stocks = db.relationship('HistoricalEventStock', backref='historical_event', lazy='dynamic', cascade='all, delete-orphan') class HistoricalEventStock(db.Model): """历史事件相关股票模型""" __tablename__ = 'historical_event_stocks' id = db.Column(db.Integer, primary_key=True) historical_event_id = db.Column(db.Integer, db.ForeignKey('historical_event.id'), nullable=False) stock_code = db.Column(db.String(20), nullable=False) stock_name = db.Column(db.String(50)) relation_desc = db.Column(db.Text) correlation = db.Column(db.Float, default=0.5) sector = db.Column(db.String(100)) created_at = db.Column(db.DateTime, default=beijing_now) __table_args__ = ( db.UniqueConstraint('historical_event_id', 'stock_code', name='unique_event_stock'), ) # === 股票盈利预测(自有表) === class StockForecastData(db.Model): """股票盈利预测数据 源于本地表 stock_forecast_data,由独立离线程序写入。 字段与表结构保持一致,仅用于读取聚合后输出前端报表所需的结构。 """ __tablename__ = 'stock_forecast_data' id = db.Column(db.Integer, primary_key=True) stock_code = db.Column(db.String(6), nullable=False) indicator_name = db.Column(db.String(50), nullable=False) year_2022a = db.Column(db.Numeric(15, 2)) year_2023a 
= db.Column(db.Numeric(15, 2)) year_2024a = db.Column(db.Numeric(15, 2)) year_2025e = db.Column(db.Numeric(15, 2)) year_2026e = db.Column(db.Numeric(15, 2)) year_2027e = db.Column(db.Numeric(15, 2)) process_time = db.Column(db.DateTime, nullable=False) __table_args__ = ( db.UniqueConstraint('stock_code', 'indicator_name', name='unique_stock_indicator'), ) def values_by_year(self): years = ['2022A', '2023A', '2024A', '2025E', '2026E', '2027E'] vals = [self.year_2022a, self.year_2023a, self.year_2024a, self.year_2025e, self.year_2026e, self.year_2027e] def _to_float(x): try: return float(x) if x is not None else None except Exception: return None return years, [_to_float(v) for v in vals] @app.route('/api/events/', methods=['GET']) def get_event_detail(event_id): """获取事件详情""" try: event = Event.query.get_or_404(event_id) # 增加浏览计数 event.view_count += 1 db.session.commit() return jsonify({ 'success': True, 'data': { 'id': event.id, 'title': event.title, 'description': event.description, 'event_type': event.event_type, 'status': event.status, 'start_time': event.start_time.isoformat() if event.start_time else None, 'end_time': event.end_time.isoformat() if event.end_time else None, 'created_at': event.created_at.isoformat() if event.created_at else None, 'hot_score': event.hot_score, 'view_count': event.view_count, 'trending_score': event.trending_score, 'post_count': event.post_count, 'follower_count': event.follower_count, 'related_industries': event.related_industries, 'keywords': event.keywords_list, 'importance': event.importance, 'related_avg_chg': event.related_avg_chg, 'related_max_chg': event.related_max_chg, 'related_week_chg': event.related_week_chg, 'invest_score': event.invest_score, 'expectation_surprise_score': event.expectation_surprise_score, 'creator_id': event.creator_id, 'has_chain_analysis': ( EventTransmissionNode.query.filter_by(event_id=event_id).first() is not None or EventSankeyFlow.query.filter_by(event_id=event_id).first() is not None ), 
'is_following': False, # 需要根据当前用户状态判断 } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//stocks', methods=['GET']) def get_related_stocks(event_id): """获取相关股票列表""" try: # 订阅控制:相关标的需要 Pro 及以上 if not _has_required_level('pro'): return jsonify({'success': False, 'error': '需要Pro订阅', 'required_level': 'pro'}), 403 event = Event.query.get_or_404(event_id) stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all() stocks_data = [] for stock in stocks: if stock.retrieved_sources is not None: stocks_data.append({ 'id': stock.id, 'stock_code': stock.stock_code, 'stock_name': stock.stock_name, 'sector': stock.sector, 'relation_desc': {"data":stock.retrieved_sources}, 'retrieved_sources': stock.retrieved_sources, 'correlation': stock.correlation, 'momentum': stock.momentum, 'created_at': stock.created_at.isoformat() if stock.created_at else None, 'updated_at': stock.updated_at.isoformat() if stock.updated_at else None }) else: stocks_data.append({ 'id': stock.id, 'stock_code': stock.stock_code, 'stock_name': stock.stock_name, 'sector': stock.sector, 'relation_desc': stock.relation_desc, 'correlation': stock.correlation, 'momentum': stock.momentum, 'created_at': stock.created_at.isoformat() if stock.created_at else None, 'updated_at': stock.updated_at.isoformat() if stock.updated_at else None }) return jsonify({ 'success': True, 'data': stocks_data }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//stocks', methods=['POST']) def add_related_stock(event_id): """添加相关股票""" try: event = Event.query.get_or_404(event_id) data = request.get_json() # 验证必要字段 if not data.get('stock_code') or not data.get('relation_desc'): return jsonify({'success': False, 'error': '缺少必要字段'}), 400 # 检查是否已存在 existing = RelatedStock.query.filter_by( event_id=event_id, stock_code=data['stock_code'] ).first() if existing: return jsonify({'success': False, 'error': '该股票已存在'}), 
400 # 创建新的相关股票记录 new_stock = RelatedStock( event_id=event_id, stock_code=data['stock_code'], stock_name=data.get('stock_name', ''), sector=data.get('sector', ''), relation_desc=data['relation_desc'], correlation=data.get('correlation', 0.5), momentum=data.get('momentum', '') ) db.session.add(new_stock) db.session.commit() return jsonify({ 'success': True, 'data': { 'id': new_stock.id, 'stock_code': new_stock.stock_code, 'relation_desc': new_stock.relation_desc } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stocks/', methods=['DELETE']) def delete_related_stock(stock_id): """删除相关股票""" try: stock = RelatedStock.query.get_or_404(stock_id) db.session.delete(stock) db.session.commit() return jsonify({'success': True, 'message': '删除成功'}) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//concepts', methods=['GET']) def get_related_concepts(event_id): """获取相关概念列表""" try: # 订阅控制:相关概念需要 Pro 及以上 if not _has_required_level('pro'): return jsonify({'success': False, 'error': '需要Pro订阅', 'required_level': 'pro'}), 403 event = Event.query.get_or_404(event_id) concepts = event.related_concepts.all() concepts_data = [] for concept in concepts: concepts_data.append({ 'id': concept.id, 'concept_code': concept.concept_code, 'concept': concept.concept, 'reason': concept.reason, 'image_paths': concept.image_paths_list, 'first_image_path': concept.get_first_image_path(), 'created_at': concept.created_at.isoformat() if concept.created_at else None }) return jsonify({ 'success': True, 'data': concepts_data }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//historical', methods=['GET']) def get_historical_events(event_id): """获取历史事件对比""" try: event = Event.query.get_or_404(event_id) historical_events = event.historical_events.order_by(HistoricalEvent.event_date.desc()).all() 
events_data = [] for hist_event in historical_events: events_data.append({ 'id': hist_event.id, 'title': hist_event.title, 'content': hist_event.content, 'event_date': hist_event.event_date.isoformat() if hist_event.event_date else None, 'importance': hist_event.importance, 'relevance': hist_event.relevance, 'created_at': hist_event.created_at.isoformat() if hist_event.created_at else None }) # 订阅控制:免费用户仅返回前2条;Pro/Max返回全部 info = _get_current_subscription_info() sub_type = (info.get('type') or 'free').lower() if sub_type == 'free': return jsonify({ 'success': True, 'data': events_data[:2], 'truncated': len(events_data) > 2, 'required_level': 'pro' }) return jsonify({'success': True, 'data': events_data}) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/historical-events//stocks', methods=['GET']) def get_historical_event_stocks(event_id): """获取历史事件相关股票列表""" try: # 直接查询历史事件,不需要通过主事件 hist_event = HistoricalEvent.query.get_or_404(event_id) stocks = hist_event.stocks.order_by(HistoricalEventStock.correlation.desc()).all() # 获取事件对应的交易日 event_trading_date = None if hist_event.event_date: event_trading_date = get_trading_day_near_date(hist_event.event_date) stocks_data = [] for stock in stocks: stock_data = { 'id': stock.id, 'stock_code': stock.stock_code, 'stock_name': stock.stock_name, 'sector': stock.sector, 'relation_desc': stock.relation_desc, 'correlation': stock.correlation, 'created_at': stock.created_at.isoformat() if stock.created_at else None } # 添加涨幅数据 if event_trading_date: try: # 查询股票在事件对应交易日的数据 with engine.connect() as conn: query = text(""" SELECT close_price, change_pct FROM ea_dailyline WHERE seccode = :stock_code AND date = :trading_date ORDER BY date DESC LIMIT 1 """) result = conn.execute(query, { 'stock_code': stock.stock_code, 'trading_date': event_trading_date }).fetchone() if result: stock_data['event_day_close'] = float(result[0]) if result[0] else None stock_data['event_day_change_pct'] = 
float(result[1]) if result[1] else None else: stock_data['event_day_close'] = None stock_data['event_day_change_pct'] = None except Exception as e: print(f"查询股票{stock.stock_code}在{event_trading_date}的数据失败: {e}") stock_data['event_day_close'] = None stock_data['event_day_change_pct'] = None else: stock_data['event_day_close'] = None stock_data['event_day_change_pct'] = None stocks_data.append(stock_data) return jsonify({ 'success': True, 'data': stocks_data, 'event_trading_date': event_trading_date.isoformat() if event_trading_date else None }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//expectation-score', methods=['GET']) def get_expectation_score(event_id): """获取超预期得分""" try: event = Event.query.get_or_404(event_id) # 如果事件有超预期得分,直接返回 if event.expectation_surprise_score is not None: score = event.expectation_surprise_score else: # 如果没有,根据历史事件计算一个模拟得分 historical_events = event.historical_events.all() if historical_events: # 基于历史事件数量和重要性计算得分 total_importance = sum(ev.importance or 0 for ev in historical_events) avg_importance = total_importance / len(historical_events) if historical_events else 0 score = min(100, max(0, int(avg_importance * 20 + len(historical_events) * 5))) else: # 默认得分 score = 65 return jsonify({ 'success': True, 'data': { 'score': score, 'description': '基于历史事件判断当前事件的超预期情况,满分100分' } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//follow', methods=['POST']) def toggle_event_follow(event_id): """切换事件关注状态(需登录)""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 try: event = Event.query.get_or_404(event_id) user_id = session['user_id'] existing = EventFollow.query.filter_by(user_id=user_id, event_id=event_id).first() if existing: # 取消关注 db.session.delete(existing) event.follower_count = max(0, (event.follower_count or 0) - 1) db.session.commit() return jsonify({'success': True, 'data': 
{'is_following': False, 'follower_count': event.follower_count}}) else: # 关注 follow = EventFollow(user_id=user_id, event_id=event_id) db.session.add(follow) event.follower_count = (event.follower_count or 0) + 1 db.session.commit() return jsonify({'success': True, 'data': {'is_following': True, 'follower_count': event.follower_count}}) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//transmission', methods=['GET']) def get_transmission_chain(event_id): try: # 订阅控制:传导链分析需要 Max 及以上 if not _has_required_level('max'): return jsonify({'success': False, 'error': '需要Max订阅', 'required_level': 'max'}), 403 # 确保数据库连接是活跃的 db.session.execute(text('SELECT 1')) event = Event.query.get_or_404(event_id) nodes = EventTransmissionNode.query.filter_by(event_id=event_id).all() edges = EventTransmissionEdge.query.filter_by(event_id=event_id).all() # 过滤孤立节点 connected_node_ids = set() for edge in edges: connected_node_ids.add(edge.from_node_id) connected_node_ids.add(edge.to_node_id) # 只保留有连接的节点 connected_nodes = [node for node in nodes if node.id in connected_node_ids] # 如果没有主事件节点,也保留主事件节点 main_event_node = next((node for node in nodes if node.is_main_event), None) if main_event_node and main_event_node not in connected_nodes: connected_nodes.append(main_event_node) if not connected_nodes: return jsonify({'success': False, 'message': '暂无传导链分析数据'}) # 节点类型到中文类别的映射 categories = { 'event': "事件", 'industry': "行业", 'company': "公司", 'policy': "政策", 'technology': "技术", 'market': "市场", 'other': "其他" } nodes_data = [] for node in connected_nodes: node_category = categories.get(node.node_type, "其他") nodes_data.append({ 'id': str(node.id), # 转换为字符串以保持一致性 'name': node.node_name, 'category': node_category, 'value': node.importance_score or 20, 'extra': { 'node_type': node.node_type, 'description': node.node_description, 'importance_score': node.importance_score, 'stock_code': node.stock_code, 'is_main_event': 
node.is_main_event } }) edges_data = [] for edge in edges: # 确保边的两端节点都在连接节点列表中 if edge.from_node_id in connected_node_ids and edge.to_node_id in connected_node_ids: edges_data.append({ 'source': str(edge.from_node_id), # 转换为字符串以保持一致性 'target': str(edge.to_node_id), # 转换为字符串以保持一致性 'value': edge.strength or 50, 'extra': { 'transmission_type': edge.transmission_type, 'transmission_mechanism': edge.transmission_mechanism, 'direction': edge.direction, 'strength': edge.strength, 'impact': edge.impact, 'is_circular': edge.is_circular, } }) return jsonify({ 'success': True, 'data': { 'nodes': nodes_data, 'edges': edges_data } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 # 修复股票报价API - 支持GET和POST方法 @app.route('/api/stock/quotes', methods=['GET', 'POST']) def get_stock_quotes(): try: if request.method == 'GET': # GET 请求从查询参数获取数据 codes_str = request.args.get('codes', '') codes = [code.strip() for code in codes_str.split(',') if code.strip()] event_time_str = request.args.get('event_time') else: # POST 请求从 JSON 获取数据 codes = request.json.get('codes', []) event_time_str = request.json.get('event_time') if not codes: return jsonify({'success': False, 'error': '请提供股票代码'}), 400 # 处理事件时间 if event_time_str: try: event_time = datetime.fromisoformat(event_time_str.replace('Z', '+00:00')) except: event_time = datetime.now() else: event_time = datetime.now() current_time = datetime.now() client = get_clickhouse_client() # Get stock names from MySQL stock_names = {} with engine.connect() as conn: for code in codes: codez = code.split('.')[0] result = conn.execute(text( "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" ), {"code": codez}).fetchone() if result: stock_names[code] = result[0] else: stock_names[code] = f"股票{codez}" def get_trading_day_and_times(event_datetime): event_date = event_datetime.date() event_time = event_datetime.time() # Trading hours market_open = dt_time(9, 30) market_close = dt_time(15, 0) with engine.connect() as conn: 
# First check if the event date itself is a trading day is_trading_day = conn.execute(text(""" SELECT 1 FROM trading_days WHERE EXCHANGE_DATE = :date """), {"date": event_date}).fetchone() is not None if is_trading_day: # If it's a trading day, determine time period based on event time if event_time < market_open: # Before market opens - use full trading day return event_date, market_open, market_close elif event_time > market_close: # After market closes - get next trading day next_trading_day = conn.execute(text(""" SELECT EXCHANGE_DATE FROM trading_days WHERE EXCHANGE_DATE > :date ORDER BY EXCHANGE_DATE LIMIT 1 """), {"date": event_date}).fetchone() # Convert to date object if we found a next trading day return (next_trading_day[0].date() if next_trading_day else None, market_open, market_close) else: # During trading hours return event_date, event_time, market_close else: # If not a trading day, get next trading day next_trading_day = conn.execute(text(""" SELECT EXCHANGE_DATE FROM trading_days WHERE EXCHANGE_DATE > :date ORDER BY EXCHANGE_DATE LIMIT 1 """), {"date": event_date}).fetchone() # Convert to date object if we found a next trading day return (next_trading_day[0].date() if next_trading_day else None, market_open, market_close) trading_day, start_time, end_time = get_trading_day_and_times(event_time) if not trading_day: return jsonify({ 'success': True, 'data': {code: {'name': name, 'price': None, 'change': None} for code, name in stock_names.items()} }) # For historical dates, ensure we're using actual data start_datetime = datetime.combine(trading_day, start_time) end_datetime = datetime.combine(trading_day, end_time) # If the trading day is in the future relative to current time, # return only names without data if trading_day > current_time.date(): return jsonify({ 'success': True, 'data': {code: {'name': name, 'price': None, 'change': None} for code, name in stock_names.items()} }) results = {} print(f"处理股票代码: {codes}, 交易日: {trading_day}, 时间范围: 
{start_datetime} - {end_datetime}") for code in codes: try: print(f"正在查询股票 {code} 的价格数据...") # Get the first price and last price for the trading period data = client.execute(""" WITH first_price AS (SELECT close FROM stock_minute WHERE code = %(code)s AND timestamp >= %(start)s AND timestamp <= %(end)s ORDER BY timestamp LIMIT 1 ), last_price AS ( SELECT close FROM stock_minute WHERE code = %(code)s AND timestamp >= %(start)s AND timestamp <= %(end)s ORDER BY timestamp DESC LIMIT 1 ) SELECT last_price.close as last_price, (last_price.close - first_price.close) / first_price.close * 100 as change FROM last_price CROSS JOIN first_price WHERE EXISTS (SELECT 1 FROM first_price) AND EXISTS (SELECT 1 FROM last_price) """, { 'code': code, 'start': start_datetime, 'end': end_datetime }) print(f"股票 {code} 查询结果: {data}") if data and data[0] and data[0][0] is not None: price = float(data[0][0]) if data[0][0] is not None else None change = float(data[0][1]) if data[0][1] is not None else None results[code] = { 'price': price, 'change': change, 'name': stock_names.get(code, f'股票{code.split(".")[0]}') } else: results[code] = { 'price': None, 'change': None, 'name': stock_names.get(code, f'股票{code.split(".")[0]}') } except Exception as e: print(f"Error processing stock {code}: {e}") results[code] = { 'price': None, 'change': None, 'name': stock_names.get(code, f'股票{code.split(".")[0]}') } # 返回标准格式 return jsonify({'success': True, 'data': results}) except Exception as e: print(f"Stock quotes API error: {e}") return jsonify({'success': False, 'error': str(e)}), 500 def get_clickhouse_client(): return Cclient( host='222.128.1.157', port=18000, user='default', password='Zzl33818!', database='stock' ) @app.route('/api/account/calendar/events', methods=['GET', 'POST']) def account_calendar_events(): """返回当前用户的投资计划与关注的未来事件(合并)。 GET: 可按日期范围/月份过滤;POST: 新增投资计划(写入 InvestmentPlan)。 """ try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 if 
request.method == 'POST': data = request.get_json() or {} title = data.get('title') event_date_str = data.get('event_date') or data.get('date') plan_type = data.get('type') or 'plan' description = data.get('description') or data.get('content') or '' stocks = data.get('stocks') or [] if not title or not event_date_str: return jsonify({'success': False, 'error': '缺少必填字段'}), 400 try: event_date = datetime.fromisoformat(event_date_str).date() except Exception: return jsonify({'success': False, 'error': '日期格式错误'}), 400 plan = InvestmentPlan( user_id=session['user_id'], date=event_date, title=title, content=description, type=plan_type, stocks=json.dumps(stocks), tags=json.dumps(data.get('tags', [])), status=data.get('status', 'active') ) db.session.add(plan) db.session.commit() return jsonify({'success': True, 'data': { 'id': plan.id, 'title': plan.title, 'event_date': plan.date.isoformat(), 'type': plan.type, 'description': plan.content, 'stocks': json.loads(plan.stocks) if plan.stocks else [], 'source': 'plan' }}) # GET # 解析过滤参数:date 或 (year, month) 或 (start_date, end_date) date_str = request.args.get('date') year = request.args.get('year', type=int) month = request.args.get('month', type=int) start_date_str = request.args.get('start_date') end_date_str = request.args.get('end_date') start_date = None end_date = None if date_str: try: d = datetime.fromisoformat(date_str).date() start_date = d end_date = d except Exception: pass elif year and month: # 月份范围 start_date = datetime(year, month, 1).date() if month == 12: end_date = datetime(year + 1, 1, 1).date() - timedelta(days=1) else: end_date = datetime(year, month + 1, 1).date() - timedelta(days=1) elif start_date_str and end_date_str: try: start_date = datetime.fromisoformat(start_date_str).date() end_date = datetime.fromisoformat(end_date_str).date() except Exception: start_date = None end_date = None # 查询投资计划 plans_query = InvestmentPlan.query.filter_by(user_id=session['user_id']) if start_date and end_date: 
plans_query = plans_query.filter(InvestmentPlan.date >= start_date, InvestmentPlan.date <= end_date) elif start_date: plans_query = plans_query.filter(InvestmentPlan.date == start_date) plans = plans_query.order_by(InvestmentPlan.date.asc()).all() plan_events = [{ 'id': p.id, 'title': p.title, 'event_date': p.date.isoformat(), 'type': p.type or 'plan', 'description': p.content, 'importance': 3, 'stocks': json.loads(p.stocks) if p.stocks else [], 'source': 'plan' } for p in plans] # 查询关注的未来事件 follows = FutureEventFollow.query.filter_by(user_id=session['user_id']).all() future_event_ids = [f.future_event_id for f in follows] future_events = [] if future_event_ids: base_sql = """ SELECT data_id, \ title, \ type, \ calendar_time, \ star, \ former, \ forecast, \ fact, \ related_stocks, \ concepts FROM future_events WHERE data_id IN :event_ids \ """ params = {'event_ids': tuple(future_event_ids)} # 日期过滤(按 calendar_time 的日期) if start_date and end_date: base_sql += " AND DATE(calendar_time) BETWEEN :start_date AND :end_date" params.update({'start_date': start_date, 'end_date': end_date}) elif start_date: base_sql += " AND DATE(calendar_time) = :start_date" params.update({'start_date': start_date}) base_sql += " ORDER BY calendar_time" result = db.session.execute(text(base_sql), params) for row in result: # related_stocks 形如 [[code,name,reason,score], ...] 
rs = parse_json_field(row.related_stocks) stock_tags = [] try: for it in rs: if isinstance(it, (list, tuple)) and len(it) >= 2: stock_tags.append(f"{it[0]} {it[1]}") elif isinstance(it, str): stock_tags.append(it) except Exception: pass future_events.append({ 'id': row.data_id, 'title': row.title, 'event_date': (row.calendar_time.date().isoformat() if row.calendar_time else None), 'type': 'future_event', 'importance': int(row.star) if getattr(row, 'star', None) is not None else 3, 'description': row.former or '', 'stocks': stock_tags, 'is_following': True, 'source': 'future' }) return jsonify({'success': True, 'data': plan_events + future_events}) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/account/calendar/events/', methods=['DELETE']) def delete_account_calendar_event(event_id): """删除用户创建的投资计划事件(不影响关注的未来事件)。""" try: if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 plan = InvestmentPlan.query.filter_by(id=event_id, user_id=session['user_id']).first() if not plan: return jsonify({'success': False, 'error': '未找到该记录'}), 404 db.session.delete(plan) db.session.commit() return jsonify({'success': True}) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//kline') def get_stock_kline(stock_code): chart_type = request.args.get('type', 'minute') event_time = request.args.get('event_time') try: event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now() except ValueError: return jsonify({'error': 'Invalid event_time format'}), 400 # 获取股票名称 with engine.connect() as conn: result = conn.execute(text( "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" ), {"code": stock_code.split('.')[0]}).fetchone() stock_name = result[0] if result else 'Unknown' if chart_type == 'daily': return get_daily_kline(stock_code, event_datetime, stock_name) elif chart_type == 'minute': return 
get_minute_kline(stock_code, event_datetime, stock_name) elif chart_type == 'timeline': return get_timeline_data(stock_code, event_datetime, stock_name) else: # 对于未知的类型,返回错误 return jsonify({'error': f'Unsupported chart type: {chart_type}'}), 400 @app.route('/api/stock//latest-minute', methods=['GET']) def get_latest_minute_data(stock_code): """获取最新交易日的分钟频数据""" client = get_clickhouse_client() # 确保股票代码包含后缀 if '.' not in stock_code: stock_code = f"{stock_code}.SH" if stock_code.startswith('6') else f"{stock_code}.SZ" # 获取股票名称 with engine.connect() as conn: result = conn.execute(text( "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" ), {"code": stock_code.split('.')[0]}).fetchone() stock_name = result[0] if result else 'Unknown' # 查找最近30天内有数据的最新交易日 target_date = None current_date = datetime.now().date() for i in range(30): check_date = current_date - timedelta(days=i) trading_day = get_trading_day_near_date(check_date) if trading_day and trading_day <= current_date: # 检查这个交易日是否有分钟数据 test_data = client.execute(""" SELECT COUNT(*) FROM stock_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s LIMIT 1 """, { 'code': stock_code, 'start': datetime.combine(trading_day, dt_time(9, 30)), 'end': datetime.combine(trading_day, dt_time(15, 0)) }) if test_data and test_data[0][0] > 0: target_date = trading_day break if not target_date: return jsonify({ 'error': 'No data available', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': current_date.strftime('%Y-%m-%d'), 'type': 'minute' }) # 获取目标日期的完整交易时段数据 data = client.execute(""" SELECT timestamp, open, high, low, close, volume, amt FROM stock_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': stock_code, 'start': datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)) }) kline_data = [{ 'time': row[0].strftime('%H:%M'), 'open': float(row[1]), 'high': float(row[2]), 'low': float(row[3]), 
'close': float(row[4]), 'volume': float(row[5]), 'amount': float(row[6]) } for row in data] return jsonify({ 'code': stock_code, 'name': stock_name, 'data': kline_data, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'minute', 'is_latest': True }) @app.route('/api/stock//forecast-report', methods=['GET']) def get_stock_forecast_report(stock_code): """基于 stock_forecast_data 输出报表所需数据结构 返回: - income_profit_trend: 营业收入/归母净利润趋势 - growth_bars: 增长率柱状图数据(基于营业收入同比) - eps_trend: EPS 折线 - pe_peg_axes: PE/PEG 双轴 - detail_table: 详细数据表格(与附件结构一致) """ try: # 读取该股票所有指标 rows = StockForecastData.query.filter_by(stock_code=stock_code).all() if not rows: return jsonify({'success': False, 'error': 'no_data'}), 404 # 将指标映射为字典 indicators = {} for r in rows: years, vals = r.values_by_year() indicators[r.indicator_name] = dict(zip(years, vals)) def safe(x): return x if x is not None else None years = ['2022A', '2023A', '2024A', '2025E', '2026E', '2027E'] # 营业收入与净利润趋势 income = indicators.get('营业总收入(百万元)', {}) profit = indicators.get('归母净利润(百万元)', {}) income_profit_trend = { 'years': years, 'income': [safe(income.get(y)) for y in years], 'profit': [safe(profit.get(y)) for y in years] } # 增长率柱状(若表内已有"增长率(%)",直接使用;否则按营业收入同比计算) growth = indicators.get('增长率(%)') if growth is None: # 计算同比: (curr - prev)/prev*100 growth_vals = [] prev = None for y in years: curr = income.get(y) if prev is not None and prev not in (None, 0) and curr is not None: growth_vals.append(round((float(curr) - float(prev)) / float(prev) * 100, 2)) else: growth_vals.append(None) prev = curr else: growth_vals = [safe(growth.get(y)) for y in years] growth_bars = { 'years': years, 'revenue_growth_pct': growth_vals, 'net_profit_growth_pct': None # 如后续需要可扩展 } # EPS 趋势 eps = indicators.get('EPS(稀释)') or indicators.get('EPS(元/股)') or {} eps_trend = { 'years': years, 'eps': [safe(eps.get(y)) for y in years] } # PE / PEG 双轴 pe = indicators.get('PE') or {} peg = indicators.get('PEG') or {} pe_peg_axes = { 'years': years, 'pe': 
[safe(pe.get(y)) for y in years], 'peg': [safe(peg.get(y)) for y in years] } # 详细数据表格(列顺序固定) def fmt(val): try: return None if val is None else round(float(val), 2) except Exception: return None detail_rows = [ { '指标': '营业总收入(百万元)', **{y: fmt(income.get(y)) for y in years}, }, { '指标': '增长率(%)', **{y: fmt(v) for y, v in zip(years, growth_vals)}, }, { '指标': '归母净利润(百万元)', **{y: fmt(profit.get(y)) for y in years}, }, { '指标': 'EPS(稀释)', **{y: fmt(eps.get(y)) for y in years}, }, { '指标': 'PE', **{y: fmt(pe.get(y)) for y in years}, }, { '指标': 'PEG', **{y: fmt(peg.get(y)) for y in years}, }, ] return jsonify({ 'success': True, 'data': { 'income_profit_trend': income_profit_trend, 'growth_bars': growth_bars, 'eps_trend': eps_trend, 'pe_peg_axes': pe_peg_axes, 'detail_table': { 'years': years, 'rows': detail_rows } } }) except Exception as e: app.logger.error(f"forecast report error: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//basic-info', methods=['GET']) def get_stock_basic_info(stock_code): """获取股票基本信息(来自ea_baseinfo表)""" try: with engine.connect() as conn: query = text(""" SELECT SECCODE, SECNAME, ORGNAME, F001V as en_name, F002V as en_short_name, F003V as legal_representative, F004V as reg_address, F005V as office_address, F006V as post_code, F007N as reg_capital, F009V as currency, F010D as establish_date, F011V as website, F012V as email, F013V as tel, F014V as fax, F015V as main_business, F016V as business_scope, F017V as company_intro, F018V as secretary, F019V as secretary_tel, F020V as secretary_fax, F021V as secretary_email, F024V as listing_status, F026V as province, F028V as city, F030V as industry_l1, F032V as industry_l2, F034V as sw_industry_l1, F036V as sw_industry_l2, F038V as sw_industry_l3, F039V as accounting_firm, F040V as law_firm, F041V as chairman, F042V as general_manager, F043V as independent_directors, F050V as credit_code, F054V as company_size, UPDATE_DATE FROM ea_baseinfo WHERE SECCODE = 
:stock_code LIMIT 1 """) result = conn.execute(query, {'stock_code': stock_code}).fetchone() if not result: return jsonify({ 'success': False, 'error': f'未找到股票代码 {stock_code} 的基本信息' }), 404 # 转换为字典 basic_info = {} result_dict = row_to_dict(result) for key, value in result_dict.items(): if isinstance(value, datetime): basic_info[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): basic_info[key] = float(value) else: basic_info[key] = value return jsonify({ 'success': True, 'data': basic_info }) except Exception as e: app.logger.error(f"Error getting stock basic info: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//announcements', methods=['GET']) def get_stock_announcements(stock_code): """获取股票公告列表""" try: limit = request.args.get('limit', 50, type=int) with engine.connect() as conn: query = text(""" SELECT F001D as announce_date, F002V as title, F003V as url, F004V as format, F005N as file_size, F006V as info_type, UPDATE_DATE FROM ea_baseinfolist WHERE SECCODE = :stock_code ORDER BY F001D DESC LIMIT :limit """) result = conn.execute(query, {'stock_code': stock_code, 'limit': limit}).fetchall() announcements = [] for row in result: announcement = {} for key, value in row_to_dict(row).items(): if value is None: announcement[key] = None elif isinstance(value, datetime): announcement[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): announcement[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): announcement[key] = float(value) else: announcement[key] = value announcements.append(announcement) return jsonify({ 'success': True, 'data': announcements, 'total': len(announcements) }) except Exception as e: app.logger.error(f"Error getting stock announcements: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//disclosure-schedule', methods=['GET']) def get_stock_disclosure_schedule(stock_code): """获取股票财报预披露时间表""" 
try: with engine.connect() as conn: query = text(""" SELECT distinct F001D as report_period, F002D as scheduled_date, F003D as change_date1, F004D as change_date2, F005D as change_date3, F006D as actual_date, F007D as change_date4, F008D as change_date5, MODTIME as mod_time FROM ea_pretime WHERE SECCODE = :stock_code ORDER BY F001D DESC LIMIT 20 """) result = conn.execute(query, {'stock_code': stock_code}).fetchall() schedules = [] for row in result: schedule = {} for key, value in row_to_dict(row).items(): if value is None: schedule[key] = None elif isinstance(value, datetime): schedule[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): schedule[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): schedule[key] = float(value) else: schedule[key] = value # 计算最新的预约日期 latest_scheduled = schedule.get('scheduled_date') for change_field in ['change_date5', 'change_date4', 'change_date3', 'change_date2', 'change_date1']: if schedule.get(change_field): latest_scheduled = schedule[change_field] break schedule['latest_scheduled_date'] = latest_scheduled schedule['is_disclosed'] = bool(schedule.get('actual_date')) # 格式化报告期名称 if schedule.get('report_period'): period_date = schedule['report_period'] if period_date.endswith('-03-31'): schedule['report_name'] = f"{period_date[:4]}年一季报" elif period_date.endswith('-06-30'): schedule['report_name'] = f"{period_date[:4]}年中报" elif period_date.endswith('-09-30'): schedule['report_name'] = f"{period_date[:4]}年三季报" elif period_date.endswith('-12-31'): schedule['report_name'] = f"{period_date[:4]}年年报" else: schedule['report_name'] = period_date schedules.append(schedule) return jsonify({ 'success': True, 'data': schedules, 'total': len(schedules) }) except Exception as e: app.logger.error(f"Error getting disclosure schedule: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//actual-control', methods=['GET']) def get_stock_actual_control(stock_code): 
"""获取股票实际控制人信息""" try: with engine.connect() as conn: query = text(""" SELECT DECLAREDATE as declare_date, ENDDATE as end_date, F001V as direct_holder_id, F002V as direct_holder_name, F003V as actual_controller_id, F004V as actual_controller_name, F005N as holding_shares, F006N as holding_ratio, F007V as control_type_code, F008V as control_type, F012V as direct_controller_id, F013V as direct_controller_name, F014V as controller_type, ORGNAME as org_name, SECCODE as sec_code, SECNAME as sec_name FROM ea_actualcon WHERE SECCODE = :stock_code ORDER BY ENDDATE DESC, DECLAREDATE DESC LIMIT 20 """) result = conn.execute(query, {'stock_code': stock_code}).fetchall() control_info = [] for row in result: control_record = {} for key, value in row_to_dict(row).items(): if value is None: control_record[key] = None elif isinstance(value, datetime): control_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): control_record[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): control_record[key] = float(value) else: control_record[key] = value control_info.append(control_record) return jsonify({ 'success': True, 'data': control_info, 'total': len(control_info) }) except Exception as e: app.logger.error(f"Error getting actual control info: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//concentration', methods=['GET']) def get_stock_concentration(stock_code): """获取股票股权集中度信息""" try: with engine.connect() as conn: query = text(""" SELECT ENDDATE as end_date, F001V as stat_item, F002N as holding_shares, F003N as holding_ratio, F004N as ratio_change, ORGNAME as org_name, SECCODE as sec_code, SECNAME as sec_name FROM ea_concentration WHERE SECCODE = :stock_code ORDER BY ENDDATE DESC LIMIT 20 """) result = conn.execute(query, {'stock_code': stock_code}).fetchall() concentration_info = [] for row in result: concentration_record = {} for key, value in row_to_dict(row).items(): if value is None: 
concentration_record[key] = None elif isinstance(value, datetime): concentration_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): concentration_record[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): concentration_record[key] = float(value) else: concentration_record[key] = value concentration_info.append(concentration_record) return jsonify({ 'success': True, 'data': concentration_info, 'total': len(concentration_info) }) except Exception as e: app.logger.error(f"Error getting concentration info: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//management', methods=['GET']) def get_stock_management(stock_code): """获取股票管理层信息""" try: # 获取是否只显示在职人员参数 active_only = request.args.get('active_only', 'true').lower() == 'true' with engine.connect() as conn: base_query = """ SELECT DECLAREDATE as declare_date, \ F001V as person_id, \ F002V as name, \ F007D as start_date, \ F008D as end_date, \ F009V as position_name, \ F010V as gender, \ F011V as education, \ F012V as birth_year, \ F013V as nationality, \ F014V as position_category_code, \ F015V as position_category, \ F016V as position_code, \ F017V as highest_degree, \ F019V as resume, \ F020C as is_active, \ ORGNAME as org_name, \ SECCODE as sec_code, \ SECNAME as sec_name FROM ea_management WHERE SECCODE = :stock_code \ """ if active_only: base_query += " AND F020C = '1'" base_query += " ORDER BY DECLAREDATE DESC, F007D DESC" query = text(base_query) result = conn.execute(query, {'stock_code': stock_code}).fetchall() management_info = [] for row in result: management_record = {} for key, value in row_to_dict(row).items(): if value is None: management_record[key] = None elif isinstance(value, datetime): management_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): management_record[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): management_record[key] = float(value) else: 
management_record[key] = value management_info.append(management_record) return jsonify({ 'success': True, 'data': management_info, 'total': len(management_info) }) except Exception as e: app.logger.error(f"Error getting management info: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//top-circulation-shareholders', methods=['GET']) def get_stock_top_circulation_shareholders(stock_code): """获取股票十大流通股东信息""" try: limit = request.args.get('limit', 10, type=int) with engine.connect() as conn: query = text(""" SELECT DECLAREDATE as declare_date, ENDDATE as end_date, F001N as shareholder_rank, F002V as shareholder_id, F003V as shareholder_name, F004V as shareholder_type, F005N as holding_shares, F006N as total_share_ratio, F007N as circulation_share_ratio, F011V as share_nature, F012N as b_shares, F013N as h_shares, F014N as other_shares, ORGNAME as org_name, SECCODE as sec_code, SECNAME as sec_name FROM ea_tencirculation WHERE SECCODE = :stock_code ORDER BY ENDDATE DESC, F001N ASC LIMIT :limit """) result = conn.execute(query, {'stock_code': stock_code, 'limit': limit}).fetchall() shareholders_info = [] for row in result: shareholder_record = {} for key, value in row_to_dict(row).items(): if value is None: shareholder_record[key] = None elif isinstance(value, datetime): shareholder_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): shareholder_record[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): shareholder_record[key] = float(value) else: shareholder_record[key] = value shareholders_info.append(shareholder_record) return jsonify({ 'success': True, 'data': shareholders_info, 'total': len(shareholders_info) }) except Exception as e: app.logger.error(f"Error getting top circulation shareholders: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//top-shareholders', methods=['GET']) def 
get_stock_top_shareholders(stock_code): """获取股票十大股东信息""" try: limit = request.args.get('limit', 10, type=int) with engine.connect() as conn: query = text(""" SELECT DECLAREDATE as declare_date, ENDDATE as end_date, F001N as shareholder_rank, F002V as shareholder_name, F003V as shareholder_id, F004V as shareholder_type, F005N as holding_shares, F006N as total_share_ratio, F007N as circulation_share_ratio, F011V as share_nature, F016N as restricted_shares, F017V as concert_party_group, F018N as circulation_shares, ORGNAME as org_name, SECCODE as sec_code, SECNAME as sec_name FROM ea_tenshareholder WHERE SECCODE = :stock_code ORDER BY ENDDATE DESC, F001N ASC LIMIT :limit """) result = conn.execute(query, {'stock_code': stock_code, 'limit': limit}).fetchall() shareholders_info = [] for row in result: shareholder_record = {} for key, value in row_to_dict(row).items(): if value is None: shareholder_record[key] = None elif isinstance(value, datetime): shareholder_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): shareholder_record[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): shareholder_record[key] = float(value) else: shareholder_record[key] = value shareholders_info.append(shareholder_record) return jsonify({ 'success': True, 'data': shareholders_info, 'total': len(shareholders_info) }) except Exception as e: app.logger.error(f"Error getting top shareholders: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//branches', methods=['GET']) def get_stock_branches(stock_code): """获取股票分支机构信息""" try: with engine.connect() as conn: query = text(""" SELECT CRECODE as cre_code, F001V as branch_name, F002V as register_capital, F003V as business_status, F004D as register_date, F005N as related_company_count, F006V as legal_person, ORGNAME as org_name, SECCODE as sec_code, SECNAME as sec_name FROM ea_branch WHERE SECCODE = :stock_code ORDER BY F004D DESC """) result = 
conn.execute(query, {'stock_code': stock_code}).fetchall() branches_info = [] for row in result: branch_record = {} for key, value in row_to_dict(row).items(): if value is None: branch_record[key] = None elif isinstance(value, datetime): branch_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): branch_record[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): branch_record[key] = float(value) else: branch_record[key] = value branches_info.append(branch_record) return jsonify({ 'success': True, 'data': branches_info, 'total': len(branches_info) }) except Exception as e: app.logger.error(f"Error getting branches info: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/stock//patents', methods=['GET']) def get_stock_patents(stock_code): """获取股票专利信息""" try: limit = request.args.get('limit', 50, type=int) patent_type = request.args.get('type', None) # 专利类型筛选 with engine.connect() as conn: base_query = """ SELECT CRECODE as cre_code, \ F001V as patent_name, \ F002V as application_number, \ F003V as publication_number, \ F004V as classification_number, \ F005D as publication_date, \ F006D as application_date, \ F007V as patent_type, \ F008V as applicant, \ F009V as inventor, \ ID as id, \ ORGNAME as org_name, \ SECCODE as sec_code, \ SECNAME as sec_name FROM ea_patent WHERE SECCODE = :stock_code \ """ params = {'stock_code': stock_code, 'limit': limit} if patent_type: base_query += " AND F007V = :patent_type" params['patent_type'] = patent_type base_query += " ORDER BY F006D DESC, F005D DESC LIMIT :limit" query = text(base_query) result = conn.execute(query, params).fetchall() patents_info = [] for row in result: patent_record = {} for key, value in row_to_dict(row).items(): if value is None: patent_record[key] = None elif isinstance(value, datetime): patent_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') elif isinstance(value, date): patent_record[key] = value.strftime('%Y-%m-%d') 
elif isinstance(value, Decimal): patent_record[key] = float(value) else: patent_record[key] = value patents_info.append(patent_record) return jsonify({ 'success': True, 'data': patents_info, 'total': len(patents_info) }) except Exception as e: app.logger.error(f"Error getting patents info: {e}", exc_info=True) return jsonify({'success': False, 'error': str(e)}), 500 def get_daily_kline(stock_code, event_datetime, stock_name): """处理日K线数据""" stock_code = stock_code.split('.')[0] with engine.connect() as conn: # 获取事件日期前后的数据(前365天/1年,后30天) kline_sql = """ WITH date_range AS (SELECT TRADEDATE \ FROM ea_trade \ WHERE SECCODE = :stock_code \ AND TRADEDATE BETWEEN DATE_SUB(:trade_date, INTERVAL 365 DAY) \ AND DATE_ADD(:trade_date, INTERVAL 30 DAY) \ GROUP BY TRADEDATE \ ORDER BY TRADEDATE) SELECT t.TRADEDATE, CAST(t.F003N AS FLOAT) as open, CAST(t.F007N AS FLOAT) as close, CAST(t.F005N AS FLOAT) as high, CAST(t.F006N AS FLOAT) as low, CAST(t.F004N AS FLOAT) as volume FROM ea_trade t JOIN date_range d \ ON t.TRADEDATE = d.TRADEDATE WHERE t.SECCODE = :stock_code ORDER BY t.TRADEDATE \ """ result = conn.execute(text(kline_sql), { "stock_code": stock_code, "trade_date": event_datetime.date() }).fetchall() if not result: return jsonify({ 'error': 'No data available', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'daily' }) kline_data = [{ 'time': row.TRADEDATE.strftime('%Y-%m-%d'), 'open': float(row.open), 'high': float(row.high), 'low': float(row.low), 'close': float(row.close), 'volume': float(row.volume) } for row in result] return jsonify({ 'code': stock_code, 'name': stock_name, 'data': kline_data, 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'daily', 'is_history': True }) def get_minute_kline(stock_code, event_datetime, stock_name): """处理分钟K线数据""" client = get_clickhouse_client() target_date = get_trading_day_near_date(event_datetime.date()) is_after_market = event_datetime.time() 
> dt_time(15, 0) # 核心逻辑改动:先判断当前日期是否是交易日,以及是否已收盘 if target_date and is_after_market: # 如果是交易日且已收盘,查找下一个交易日 next_trade_date = get_trading_day_near_date(target_date + timedelta(days=1)) if next_trade_date: target_date = next_trade_date if not target_date: return jsonify({ 'error': 'No data available', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'minute' }) # 获取目标日期的完整交易时段数据 data = client.execute(""" SELECT timestamp, open, high, low, close, volume, amt FROM stock_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': stock_code, 'start': datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)) }) kline_data = [{ 'time': row[0].strftime('%H:%M'), 'open': float(row[1]), 'high': float(row[2]), 'low': float(row[3]), 'close': float(row[4]), 'volume': float(row[5]), 'amount': float(row[6]) } for row in data] return jsonify({ 'code': stock_code, 'name': stock_name, 'data': kline_data, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'minute', 'is_history': target_date < event_datetime.date() }) def get_timeline_data(stock_code, event_datetime, stock_name): """处理分时均价线数据(timeline)。 规则: - 若事件时间在交易日的15:00之后,则展示下一个交易日的分时数据; - 若事件日非交易日,优先展示下一个交易日;如无,则回退到最近一个交易日; - 数据区间固定为 09:30-15:00。 """ client = get_clickhouse_client() target_date = get_trading_day_near_date(event_datetime.date()) is_after_market = event_datetime.time() > dt_time(15, 0) # 与分钟K逻辑保持一致的日期选择规则 if target_date and is_after_market: next_trade_date = get_trading_day_near_date(target_date + timedelta(days=1)) if next_trade_date: target_date = next_trade_date if not target_date: return jsonify({ 'error': 'No data available', 'code': stock_code, 'name': stock_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'timeline' }) # 获取昨收盘价 prev_close_query = """ SELECT close FROM stock_minute WHERE code = %(code)s AND 
timestamp \ < %(start)s ORDER BY timestamp DESC LIMIT 1 \ """ prev_close_result = client.execute(prev_close_query, { 'code': stock_code, 'start': datetime.combine(target_date, dt_time(9, 30)) }) prev_close = float(prev_close_result[0][0]) if prev_close_result else None data = client.execute( """ SELECT timestamp, close, volume FROM stock_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': stock_code, 'start': datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)), } ) timeline_data = [] total_amount = 0 total_volume = 0 for row in data: price = float(row[1]) volume = float(row[2]) total_amount += price * volume total_volume += volume avg_price = total_amount / total_volume if total_volume > 0 else price # 计算涨跌幅 change_percent = ((price - prev_close) / prev_close * 100) if prev_close else 0 timeline_data.append({ 'time': row[0].strftime('%H:%M'), 'price': price, 'avg_price': avg_price, 'volume': volume, 'change_percent': change_percent, }) return jsonify({ 'code': stock_code, 'name': stock_name, 'data': timeline_data, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'timeline', 'is_history': target_date < event_datetime.date(), 'prev_close': prev_close, }) # ==================== 指数行情API(与股票逻辑一致,数据表为 index_minute) ==================== @app.route('/api/index//kline') def get_index_kline(index_code): chart_type = request.args.get('type', 'minute') event_time = request.args.get('event_time') try: event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now() except ValueError: return jsonify({'error': 'Invalid event_time format'}), 400 # 指数名称(暂无索引表,先返回代码本身) index_name = index_code if chart_type == 'minute': return get_index_minute_kline(index_code, event_datetime, index_name) elif chart_type == 'timeline': return get_index_timeline_data(index_code, event_datetime, index_name) elif chart_type == 'daily': return 
get_index_daily_kline(index_code, event_datetime, index_name) else: return jsonify({'error': f'Unsupported chart type: {chart_type}'}), 400 def get_index_minute_kline(index_code, event_datetime, index_name): client = get_clickhouse_client() target_date = get_trading_day_near_date(event_datetime.date()) if not target_date: return jsonify({ 'error': 'No data available', 'code': index_code, 'name': index_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'minute' }) data = client.execute( """ SELECT timestamp, open, high, low, close, volume, amt FROM index_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': index_code, 'start': datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)), } ) kline_data = [{ 'time': row[0].strftime('%H:%M'), 'open': float(row[1]), 'high': float(row[2]), 'low': float(row[3]), 'close': float(row[4]), 'volume': float(row[5]), 'amount': float(row[6]), } for row in data] return jsonify({ 'code': index_code, 'name': index_name, 'data': kline_data, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'minute', 'is_history': target_date < event_datetime.date(), }) def get_index_timeline_data(index_code, event_datetime, index_name): client = get_clickhouse_client() target_date = get_trading_day_near_date(event_datetime.date()) if not target_date: return jsonify({ 'error': 'No data available', 'code': index_code, 'name': index_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'timeline' }) data = client.execute( """ SELECT timestamp, close, volume FROM index_minute WHERE code = %(code)s AND timestamp BETWEEN %(start)s AND %(end)s ORDER BY timestamp """, { 'code': index_code, 'start': datetime.combine(target_date, dt_time(9, 30)), 'end': datetime.combine(target_date, dt_time(15, 0)), } ) timeline = [] total_amount = 0 total_volume = 0 for row in data: price = float(row[1]) volume = 
float(row[2]) total_amount += price * volume total_volume += volume avg_price = total_amount / total_volume if total_volume > 0 else price timeline.append({ 'time': row[0].strftime('%H:%M'), 'price': price, 'avg_price': avg_price, 'volume': volume, }) return jsonify({ 'code': index_code, 'name': index_name, 'data': timeline, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'timeline', 'is_history': target_date < event_datetime.date(), }) def get_index_daily_kline(index_code, event_datetime, index_name): """从 MySQL 的 stock.ea_exchangetrade 获取指数日线 注意:表中 INDEXCODE 无后缀,例如 000001.SH -> 000001 字段: F003N 开市指数 -> open F004N 最高指数 -> high F005N 最低指数 -> low F006N 最近指数 -> close(作为当日收盘或最近价使用) F007N 昨日收市指数 -> prev_close """ # 去掉后缀 code_no_suffix = index_code.split('.')[0] # 选择展示的最后交易日 target_date = get_trading_day_near_date(event_datetime.date()) if not target_date: return jsonify({ 'error': 'No data available', 'code': index_code, 'name': index_name, 'data': [], 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), 'type': 'daily' }) # 取最近一段时间的日线(倒序再反转为升序) with engine.connect() as conn: rows = conn.execute(text( """ SELECT TRADEDATE, F003N, F004N, F005N, F006N, F007N FROM ea_exchangetrade WHERE INDEXCODE = :code AND TRADEDATE <= :end_dt ORDER BY TRADEDATE DESC LIMIT 180 """ ), { 'code': code_no_suffix, 'end_dt': datetime.combine(target_date, dt_time(23, 59, 59)) }).fetchall() # 反转为时间升序 rows = list(reversed(rows)) daily = [] for i, r in enumerate(rows): trade_dt = r[0] open_v = r[1] high_v = r[2] low_v = r[3] last_v = r[4] prev_close_v = r[5] # 正确的前收盘价逻辑:使用前一个交易日的F006N(收盘价) calculated_prev_close = None if i > 0 and rows[i - 1][4] is not None: # 使用前一个交易日的收盘价作为前收盘价 calculated_prev_close = float(rows[i - 1][4]) else: # 第一条记录,尝试使用F007N字段作为备选 if prev_close_v is not None and prev_close_v > 0: calculated_prev_close = float(prev_close_v) daily.append({ 'time': trade_dt.strftime('%Y-%m-%d') if hasattr(trade_dt, 'strftime') else str(trade_dt), 'open': float(open_v) if open_v is 
not None else None, 'high': float(high_v) if high_v is not None else None, 'low': float(low_v) if low_v is not None else None, 'close': float(last_v) if last_v is not None else None, 'prev_close': calculated_prev_close, }) return jsonify({ 'code': index_code, 'name': index_name, 'data': daily, 'trade_date': target_date.strftime('%Y-%m-%d'), 'type': 'daily', 'is_history': target_date < event_datetime.date(), }) # ==================== 日历API ==================== @app.route('/api/v1/calendar/event-counts', methods=['GET']) def get_event_counts(): """获取日历事件数量统计""" try: # 获取月份参数 year = request.args.get('year', datetime.now().year, type=int) month = request.args.get('month', datetime.now().month, type=int) # 计算月份的开始和结束日期 start_date = datetime(year, month, 1) if month == 12: end_date = datetime(year + 1, 1, 1) else: end_date = datetime(year, month + 1, 1) # 查询事件数量 query = """ SELECT DATE(calendar_time) as date, COUNT(*) as count FROM future_events WHERE calendar_time BETWEEN :start_date AND :end_date AND type = 'event' GROUP BY DATE(calendar_time) """ result = db.session.execute(text(query), { 'start_date': start_date, 'end_date': end_date }) # 格式化结果 events = [] for day in result: events.append({ 'date': day.date.isoformat(), 'count': day.count, 'className': get_event_class(day.count) }) return jsonify({ 'success': True, 'data': events }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/calendar/events', methods=['GET']) def get_calendar_events(): """获取指定日期的事件列表""" date_str = request.args.get('date') event_type = request.args.get('type', 'all') if not date_str: return jsonify({ 'success': False, 'error': 'Date parameter required' }), 400 try: date = datetime.strptime(date_str, '%Y-%m-%d') except ValueError: return jsonify({ 'success': False, 'error': 'Invalid date format' }), 400 # 修复SQL语法:去掉函数名后的空格,去掉参数前的空格 query = """ SELECT * FROM future_events WHERE DATE(calendar_time) = :date """ params = {'date': date} if 
event_type != 'all': query += " AND type = :type" params['type'] = event_type query += " ORDER BY calendar_time" result = db.session.execute(text(query), params) events = [] user_following_ids = set() if 'user_id' in session: follows = FutureEventFollow.query.filter_by(user_id=session['user_id']).all() user_following_ids = {f.future_event_id for f in follows} for row in result: event_data = { 'id': row.data_id, 'title': row.title, 'type': row.type, 'calendar_time': row.calendar_time.isoformat(), 'star': row.star, 'former': row.former, 'forecast': row.forecast, 'fact': row.fact, 'is_following': row.data_id in user_following_ids } # 解析相关股票和概念 if row.related_stocks: try: if isinstance(row.related_stocks, str): if row.related_stocks.startswith('['): event_data['related_stocks'] = json.loads(row.related_stocks) else: event_data['related_stocks'] = row.related_stocks.split(',') else: event_data['related_stocks'] = row.related_stocks except: event_data['related_stocks'] = [] else: event_data['related_stocks'] = [] if row.concepts: try: if isinstance(row.concepts, str): if row.concepts.startswith('['): event_data['concepts'] = json.loads(row.concepts) else: event_data['concepts'] = row.concepts.split(',') else: event_data['concepts'] = row.concepts except: event_data['concepts'] = [] else: event_data['concepts'] = [] events.append(event_data) return jsonify({ 'success': True, 'data': events }) @app.route('/api/v1/calendar/events/', methods=['GET']) def get_calendar_event_detail(event_id): """获取日历事件详情""" try: sql = """ SELECT * FROM future_events WHERE data_id = :event_id \ """ result = db.session.execute(text(sql), {'event_id': event_id}).first() if not result: return jsonify({ 'success': False, 'error': 'Event not found' }), 404 event_data = { 'id': result.data_id, 'title': result.title, 'type': result.type, 'calendar_time': result.calendar_time.isoformat(), 'star': result.star, 'former': result.former, 'forecast': result.forecast, 'fact': result.fact, 'related_stocks': 
parse_json_field(result.related_stocks), 'concepts': parse_json_field(result.concepts) } # 检查当前用户是否关注了该未来事件 if 'user_id' in session: is_following = FutureEventFollow.query.filter_by( user_id=session['user_id'], future_event_id=event_id ).first() is not None event_data['is_following'] = is_following else: event_data['is_following'] = False return jsonify({ 'success': True, 'data': event_data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/v1/calendar/events//follow', methods=['POST']) def toggle_future_event_follow(event_id): """切换未来事件关注状态(需登录)""" if 'user_id' not in session: return jsonify({'success': False, 'error': '未登录'}), 401 try: # 检查未来事件是否存在 sql = """ SELECT data_id \ FROM future_events \ WHERE data_id = :event_id \ """ result = db.session.execute(text(sql), {'event_id': event_id}).first() if not result: return jsonify({'success': False, 'error': '未来事件不存在'}), 404 user_id = session['user_id'] # 检查是否已关注 existing = FutureEventFollow.query.filter_by( user_id=user_id, future_event_id=event_id ).first() if existing: # 取消关注 db.session.delete(existing) db.session.commit() return jsonify({ 'success': True, 'data': {'is_following': False} }) else: # 关注 follow = FutureEventFollow( user_id=user_id, future_event_id=event_id ) db.session.add(follow) db.session.commit() return jsonify({ 'success': True, 'data': {'is_following': True} }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 def get_event_class(count): """根据事件数量返回CSS类名""" if count >= 10: return 'event-high' elif count >= 5: return 'event-medium' elif count > 0: return 'event-low' return '' def parse_json_field(field_value): """解析JSON字段""" if not field_value: return [] try: if isinstance(field_value, str): if field_value.startswith('['): return json.loads(field_value) else: return field_value.split(',') else: return field_value except: return [] # ==================== 行业API ==================== 
@app.route('/api/classifications', methods=['GET']) def get_classifications(): """获取申银万国行业分类树形结构""" try: # 查询申银万国行业分类的所有数据 sql = """ SELECT f003v as code, f004v as level1, f005v as level2, f006v as level3,f007v as level4 FROM ea_sector WHERE f002v = '申银万国行业分类' AND f003v IS NOT NULL AND f004v IS NOT NULL ORDER BY f003v """ result = db.session.execute(text(sql)).all() # 构建树形结构 tree_dict = {} for row in result: code = row.code level1 = row.level1 level2 = row.level2 level3 = row.level3 # 跳过空数据 if not level1: continue # 第一层 if level1 not in tree_dict: # 获取第一层的code(取前3位或前缀) level1_code = code[:3] if len(code) >= 3 else code tree_dict[level1] = { 'value': level1_code, 'label': level1, 'children_dict': {} } # 第二层 if level2: if level2 not in tree_dict[level1]['children_dict']: # 获取第二层的code(取前6位) level2_code = code[:6] if len(code) >= 6 else code tree_dict[level1]['children_dict'][level2] = { 'value': level2_code, 'label': level2, 'children_dict': {} } # 第三层 if level3: if level3 not in tree_dict[level1]['children_dict'][level2]['children_dict']: tree_dict[level1]['children_dict'][level2]['children_dict'][level3] = { 'value': code, 'label': level3 } # 转换为最终格式 result_list = [] for level1_name, level1_data in tree_dict.items(): level1_node = { 'value': level1_data['value'], 'label': level1_data['label'] } # 处理第二层 if level1_data['children_dict']: level1_children = [] for level2_name, level2_data in level1_data['children_dict'].items(): level2_node = { 'value': level2_data['value'], 'label': level2_data['label'] } # 处理第三层 if level2_data['children_dict']: level2_children = [] for level3_name, level3_data in level2_data['children_dict'].items(): level2_children.append({ 'value': level3_data['value'], 'label': level3_data['label'] }) if level2_children: level2_node['children'] = level2_children level1_children.append(level2_node) if level1_children: level1_node['children'] = level1_children result_list.append(level1_node) return jsonify({ 'success': True, 'data': result_list }) 
except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/stocklist', methods=['GET']) def get_stock_list(): """获取股票列表""" try: sql = """ SELECT DISTINCT SECCODE as code, SECNAME as name FROM ea_stocklist ORDER BY SECCODE """ result = db.session.execute(text(sql)).all() stocks = [{'code': row.code, 'name': row.name} for row in result] return jsonify(stocks) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/events', methods=['GET'], strict_slashes=False) def api_get_events(): """ 获取事件列表API - 支持筛选、排序、分页,兼容前端调用 """ try: # 分页参数 page = max(1, request.args.get('page', 1, type=int)) per_page = min(100, max(1, request.args.get('per_page', 10, type=int))) # 基础筛选参数 event_type = request.args.get('type', 'all') event_status = request.args.get('status', 'active') importance = request.args.get('importance', 'all') # 日期筛选参数 start_date = request.args.get('start_date') end_date = request.args.get('end_date') date_range = request.args.get('date_range') recent_days = request.args.get('recent_days', type=int) # 行业筛选参数(只支持申银万国行业分类) industry_code = request.args.get('industry_code') # 申万行业代码,如 "S370502" # 概念/标签筛选参数 tag = request.args.get('tag') tags = request.args.get('tags') keywords = request.args.get('keywords') # 搜索参数 search_query = request.args.get('q') search_type = request.args.get('search_type', 'topic') search_fields = request.args.get('search_fields', 'title,description').split(',') # 排序参数 sort_by = request.args.get('sort', 'new') return_type = request.args.get('return_type', 'avg') order = request.args.get('order', 'desc') # 收益率筛选参数 min_avg_return = request.args.get('min_avg_return', type=float) max_avg_return = request.args.get('max_avg_return', type=float) min_max_return = request.args.get('min_max_return', type=float) max_max_return = request.args.get('max_max_return', type=float) min_week_return = request.args.get('min_week_return', type=float) max_week_return = 
request.args.get('max_week_return', type=float) # 其他筛选参数 min_hot_score = request.args.get('min_hot_score', type=float) max_hot_score = request.args.get('max_hot_score', type=float) min_view_count = request.args.get('min_view_count', type=int) creator_id = request.args.get('creator_id', type=int) # 返回格式参数 include_creator = request.args.get('include_creator', 'true').lower() == 'true' include_stats = request.args.get('include_stats', 'true').lower() == 'true' include_related_data = request.args.get('include_related_data', 'false').lower() == 'true' # ==================== 构建查询 ==================== query = Event.query if event_status != 'all': query = query.filter_by(status=event_status) if event_type != 'all': query = query.filter_by(event_type=event_type) # 支持多个重要性级别筛选,用逗号分隔(如 importance=S,A) if importance != 'all': if ',' in importance: # 多个重要性级别 importance_list = [imp.strip() for imp in importance.split(',') if imp.strip()] query = query.filter(Event.importance.in_(importance_list)) else: # 单个重要性级别 query = query.filter_by(importance=importance) if creator_id: query = query.filter_by(creator_id=creator_id) # 新增:行业代码过滤(申银万国行业分类) if industry_code: # related_industries 格式: [{"申银万国行业分类": "S370502"}, ...] 
# 支持多个行业代码,用逗号分隔 json_path = '$[*]."申银万国行业分类"' # 如果包含逗号,说明是多个行业代码 if ',' in industry_code: codes = [code.strip() for code in industry_code.split(',') if code.strip()] # 使用 OR 条件匹配任意一个行业代码 conditions = [] for code in codes: conditions.append( text("JSON_CONTAINS(JSON_EXTRACT(related_industries, :json_path), :code)") .bindparams(json_path=json_path, code=json.dumps(code)) ) query = query.filter(db.or_(*conditions)) else: # 单个行业代码 query = query.filter( text("JSON_CONTAINS(JSON_EXTRACT(related_industries, :json_path), :industry_code)") ).params(json_path=json_path, industry_code=json.dumps(industry_code)) # 新增:关键词/全文搜索过滤(MySQL JSON) if search_query: like_pattern = f"%{search_query}%" # 子查询:查找关联股票中匹配的事件ID stock_subquery = db.session.query(RelatedStock.event_id).filter( db.or_( RelatedStock.stock_name.ilike(like_pattern), RelatedStock.relation_desc.ilike(like_pattern) ) ).distinct() # 主查询:搜索事件标题、描述、关键词或关联股票 query = query.filter( db.or_( Event.title.ilike(like_pattern), Event.description.ilike(like_pattern), text(f"JSON_SEARCH(keywords, 'one', '%{search_query}%') IS NOT NULL"), Event.id.in_(stock_subquery) ) ) if recent_days: from datetime import datetime, timedelta cutoff_date = datetime.now() - timedelta(days=recent_days) query = query.filter(Event.created_at >= cutoff_date) else: if date_range and ' 至 ' in date_range: try: start_date_str, end_date_str = date_range.split(' 至 ') start_date = start_date_str.strip() end_date = end_date_str.strip() except ValueError: pass if start_date: from datetime import datetime try: if len(start_date) == 10: start_datetime = datetime.strptime(start_date, '%Y-%m-%d') else: start_datetime = datetime.strptime(start_date, '%Y-%m-%d %H:%M:%S') query = query.filter(Event.created_at >= start_datetime) except ValueError: pass if end_date: from datetime import datetime try: if len(end_date) == 10: end_datetime = datetime.strptime(end_date, '%Y-%m-%d') end_datetime = end_datetime.replace(hour=23, minute=59, second=59) else: end_datetime = 
datetime.strptime(end_date, '%Y-%m-%d %H:%M:%S') query = query.filter(Event.created_at <= end_datetime) except ValueError: pass if min_view_count is not None: query = query.filter(Event.view_count >= min_view_count) # 排序 from sqlalchemy import desc, asc, case order_func = desc if order.lower() == 'desc' else asc if sort_by == 'hot': query = query.order_by(order_func(Event.hot_score)) elif sort_by == 'new': query = query.order_by(order_func(Event.created_at)) elif sort_by == 'returns': if return_type == 'avg': query = query.order_by(order_func(Event.related_avg_chg)) elif return_type == 'max': query = query.order_by(order_func(Event.related_max_chg)) elif return_type == 'week': query = query.order_by(order_func(Event.related_week_chg)) elif sort_by == 'importance': importance_order = case( (Event.importance == 'S', 1), (Event.importance == 'A', 2), (Event.importance == 'B', 3), (Event.importance == 'C', 4), else_=5 ) if order.lower() == 'desc': query = query.order_by(importance_order) else: query = query.order_by(desc(importance_order)) elif sort_by == 'view_count': query = query.order_by(order_func(Event.view_count)) # 分页 paginated = query.paginate(page=page, per_page=per_page, error_out=False) events_data = [] for event in paginated.items: event_dict = { 'id': event.id, 'title': event.title, 'description': event.description, 'event_type': event.event_type, 'importance': event.importance, 'status': event.status, 'created_at': event.created_at.isoformat() if event.created_at else None, 'updated_at': event.updated_at.isoformat() if event.updated_at else None, 'start_time': event.start_time.isoformat() if event.start_time else None, 'end_time': event.end_time.isoformat() if event.end_time else None, } if include_stats: event_dict.update({ 'hot_score': event.hot_score, 'view_count': event.view_count, 'post_count': event.post_count, 'follower_count': event.follower_count, 'related_avg_chg': event.related_avg_chg, 'related_max_chg': event.related_max_chg, 
'related_week_chg': event.related_week_chg, 'invest_score': event.invest_score, 'trending_score': event.trending_score, }) if include_creator: event_dict['creator'] = { 'id': event.creator.id if event.creator else None, 'username': event.creator.username if event.creator else 'Anonymous' } event_dict['keywords'] = event.keywords_list if hasattr(event, 'keywords_list') else event.keywords event_dict['related_industries'] = event.related_industries if include_related_data: pass events_data.append(event_dict) applied_filters = {} if event_type != 'all': applied_filters['type'] = event_type if importance != 'all': applied_filters['importance'] = importance if start_date: applied_filters['start_date'] = start_date if end_date: applied_filters['end_date'] = end_date if industry_code: applied_filters['industry_code'] = industry_code if tag: applied_filters['tag'] = tag if tags: applied_filters['tags'] = tags if search_query: applied_filters['search_query'] = search_query applied_filters['search_type'] = search_type return jsonify({ 'success': True, 'data': { 'events': events_data, 'pagination': { 'page': paginated.page, 'per_page': paginated.per_page, 'total': paginated.total, 'pages': paginated.pages, 'has_prev': paginated.has_prev, 'has_next': paginated.has_next }, 'filters': { 'applied_filters': applied_filters, 'total_count': paginated.total } } }) except Exception as e: app.logger.error(f"获取事件列表出错: {str(e)}", exc_info=True) return jsonify({ 'success': False, 'error': str(e), 'error_type': type(e).__name__ }), 500 @app.route('/api/events/hot', methods=['GET']) def get_hot_events(): """获取热点事件""" try: from datetime import datetime, timedelta days = request.args.get('days', 3, type=int) limit = request.args.get('limit', 4, type=int) since_date = datetime.now() - timedelta(days=days) hot_events = Event.query.filter( Event.status == 'active', Event.created_at >= since_date, Event.related_avg_chg != None, Event.related_avg_chg > 0 
).order_by(Event.related_avg_chg.desc()).limit(limit).all() if len(hot_events) < limit: additional_events = Event.query.filter( Event.status == 'active', Event.created_at >= since_date, ~Event.id.in_([event.id for event in hot_events]) ).order_by(Event.hot_score.desc()).limit(limit - len(hot_events)).all() hot_events.extend(additional_events) events_data = [] for event in hot_events: events_data.append({ 'id': event.id, 'title': event.title, 'description': event.description, 'importance': event.importance, 'created_at': event.created_at.isoformat() if event.created_at else None, 'related_avg_chg': event.related_avg_chg, 'creator': { 'username': event.creator.username if event.creator else 'Anonymous' } }) return jsonify({'success': True, 'data': events_data}) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events/keywords/popular', methods=['GET']) def get_popular_keywords(): """获取热门关键词""" try: limit = request.args.get('limit', 20, type=int) sql = ''' WITH RECURSIVE \ numbers AS (SELECT 0 as n \ UNION ALL \ SELECT n + 1 \ FROM numbers \ WHERE n < 100), \ json_array AS (SELECT JSON_UNQUOTE(JSON_EXTRACT(e.keywords, CONCAT('$[', n.n, ']'))) as keyword, \ COUNT(*) as count FROM event e CROSS JOIN numbers n WHERE e.status = 'active' AND JSON_EXTRACT(e.keywords \ , CONCAT('$[' \ , n.n \ , ']')) IS NOT NULL GROUP BY JSON_UNQUOTE(JSON_EXTRACT(e.keywords, CONCAT('$[', n.n, ']'))) HAVING keyword IS NOT NULL ) SELECT keyword, count FROM json_array ORDER BY count DESC, keyword LIMIT :limit \ ''' result = db.session.execute(text(sql), {'limit': limit}).all() keywords_data = [{'keyword': row.keyword, 'count': row.count} for row in result] return jsonify({'success': True, 'data': keywords_data}) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/events//sankey-data') def get_event_sankey_data(event_id): """ 获取事件桑基图数据 (最终优化版) - 处理重名节点 - 检测并打破循环依赖 """ flows = 
EventSankeyFlow.query.filter_by(event_id=event_id).order_by( EventSankeyFlow.source_level, EventSankeyFlow.target_level ).all() if not flows: return jsonify({'success': False, 'message': '暂无桑基图数据'}) nodes_map = {} links = [] type_colors = { 'event': '#ff4757', 'policy': '#10ac84', 'technology': '#ee5a6f', 'industry': '#00d2d3', 'company': '#54a0ff', 'product': '#ffd93d' } # --- 1. 识别并处理重名节点 (与上一版相同) --- all_node_keys = set() name_counts = {} for flow in flows: source_key = f"{flow.source_node}|{flow.source_level}" target_key = f"{flow.target_node}|{flow.target_level}" all_node_keys.add(source_key) all_node_keys.add(target_key) name_counts.setdefault(flow.source_node, set()).add(flow.source_level) name_counts.setdefault(flow.target_node, set()).add(flow.target_level) duplicate_names = {name for name, levels in name_counts.items() if len(levels) > 1} for flow in flows: source_key = f"{flow.source_node}|{flow.source_level}" if source_key not in nodes_map: display_name = f"{flow.source_node} (L{flow.source_level})" if flow.source_node in duplicate_names else flow.source_node nodes_map[source_key] = {'name': display_name, 'type': flow.source_type, 'level': flow.source_level, 'color': type_colors.get(flow.source_type)} target_key = f"{flow.target_node}|{flow.target_level}" if target_key not in nodes_map: display_name = f"{flow.target_node} (L{flow.target_level})" if flow.target_node in duplicate_names else flow.target_node nodes_map[target_key] = {'name': display_name, 'type': flow.target_type, 'level': flow.target_level, 'color': type_colors.get(flow.target_type)} links.append({ 'source_key': source_key, 'target_key': target_key, 'value': float(flow.flow_value), 'ratio': float(flow.flow_ratio), 'transmission_path': flow.transmission_path, 'impact_description': flow.impact_description, 'evidence_strength': flow.evidence_strength }) # --- 2. 
循环检测与处理 --- # 构建邻接表 adj = defaultdict(list) for link in links: adj[link['source_key']].append(link['target_key']) # 深度优先搜索(DFS)来检测循环 path = set() # 记录当前递归路径上的节点 visited = set() # 记录所有访问过的节点 back_edges = set() # 记录导致循环的"回流边" def detect_cycle_util(node): path.add(node) visited.add(node) for neighbour in adj.get(node, []): if neighbour in path: # 发现了循环,记录这条回流边 (target, source) back_edges.add((neighbour, node)) elif neighbour not in visited: detect_cycle_util(neighbour) path.remove(node) # 从所有节点开始检测 for node_key in list(adj.keys()): if node_key not in visited: detect_cycle_util(node_key) # 过滤掉导致循环的边 if back_edges: print(f"检测到并移除了 {len(back_edges)} 条循环边: {back_edges}") valid_links_no_cycle = [] for link in links: if (link['source_key'], link['target_key']) not in back_edges and \ (link['target_key'], link['source_key']) not in back_edges: # 移除非严格意义上的双向边 valid_links_no_cycle.append(link) # --- 3. 构建最终的 JSON 响应 (与上一版相似) --- node_list = [] node_index_map = {} sorted_node_keys = sorted(nodes_map.keys(), key=lambda k: (nodes_map[k]['level'], nodes_map[k]['name'])) for i, key in enumerate(sorted_node_keys): node_list.append(nodes_map[key]) node_index_map[key] = i final_links = [] for link in valid_links_no_cycle: source_idx = node_index_map.get(link['source_key']) target_idx = node_index_map.get(link['target_key']) if source_idx is not None and target_idx is not None: # 移除临时的 key,只保留 ECharts 需要的字段 link.pop('source_key', None) link.pop('target_key', None) link['source'] = source_idx link['target'] = target_idx final_links.append(link) # ... (统计信息计算部分保持不变) ... 
stats = { 'total_nodes': len(node_list), 'total_flows': len(final_links), 'total_flow_value': sum(link['value'] for link in final_links), 'max_level': max((node['level'] for node in node_list), default=0), 'node_type_counts': {ntype: sum(1 for n in node_list if n['type'] == ntype) for ntype in type_colors} } return jsonify({ 'success': True, 'data': {'nodes': node_list, 'links': final_links, 'stats': stats} }) # 优化后的传导链分析 API @app.route('/api/events//chain-analysis') def get_event_chain_analysis(event_id): """获取事件传导链分析数据""" nodes = EventTransmissionNode.query.filter_by(event_id=event_id).all() if not nodes: return jsonify({'success': False, 'message': '暂无传导链分析数据'}) edges = EventTransmissionEdge.query.filter_by(event_id=event_id).all() # 过滤孤立节点 connected_node_ids = set() for edge in edges: connected_node_ids.add(edge.from_node_id) connected_node_ids.add(edge.to_node_id) # 只保留有连接的节点 connected_nodes = [node for node in nodes if node.id in connected_node_ids] if not connected_nodes: return jsonify({'success': False, 'message': '所有节点都是孤立的,暂无传导关系'}) # 节点分类,用于力导向图的图例 categories = { 'event': "事件", 'industry': "行业", 'company': "公司", 'policy': "政策", 'technology': "技术", 'market': "市场", 'other': "其他" } # 计算每个节点的连接数 node_connection_count = {} for node in connected_nodes: count = sum(1 for edge in edges if edge.from_node_id == node.id or edge.to_node_id == node.id) node_connection_count[node.id] = count nodes_data = [] for node in connected_nodes: connection_count = node_connection_count[node.id] nodes_data.append({ 'id': str(node.id), 'name': node.node_name, 'value': node.importance_score, # 用于控制节点大小的基础值 'category': categories.get(node.node_type), 'extra': { 'node_type': node.node_type, 'description': node.node_description, 'importance_score': node.importance_score, 'stock_code': node.stock_code, 'is_main_event': node.is_main_event, 'connection_count': connection_count, # 添加连接数信息 } }) edges_data = [] for edge in edges: # 确保边的两端节点都在连接节点列表中 if edge.from_node_id in 
connected_node_ids and edge.to_node_id in connected_node_ids: edges_data.append({ 'source': str(edge.from_node_id), 'target': str(edge.to_node_id), 'value': edge.strength, # 用于控制边的宽度 'extra': { 'transmission_type': edge.transmission_type, 'transmission_mechanism': edge.transmission_mechanism, 'direction': edge.direction, 'strength': edge.strength, 'impact': edge.impact, 'is_circular': edge.is_circular, } }) # 重新计算统计信息(基于连接的节点和边) stats = { 'total_nodes': len(connected_nodes), 'total_edges': len(edges_data), 'node_types': {cat: sum(1 for n in connected_nodes if n.node_type == node_type) for node_type, cat in categories.items()}, 'edge_types': {edge.transmission_type: sum(1 for e in edges_data if e['extra']['transmission_type'] == edge.transmission_type) for edge in edges}, 'avg_importance': sum(node.importance_score for node in connected_nodes) / len( connected_nodes) if connected_nodes else 0, 'avg_strength': sum(edge.strength for edge in edges) / len(edges) if edges else 0 } return jsonify({ 'success': True, 'data': { 'nodes': nodes_data, 'edges': edges_data, 'categories': list(categories.values()), 'stats': stats } }) @app.route('/api/events//chain-node/', methods=['GET']) @cross_origin() def get_chain_node_detail(event_id, node_id): """获取传导链节点及其直接关联节点的详细信息""" node = db.session.get(EventTransmissionNode, node_id) if not node or node.event_id != event_id: return jsonify({'success': False, 'message': '节点不存在'}) # 验证节点是否为孤立节点 total_connections = (EventTransmissionEdge.query.filter_by(from_node_id=node_id).count() + EventTransmissionEdge.query.filter_by(to_node_id=node_id).count()) if total_connections == 0 and not node.is_main_event: return jsonify({'success': False, 'message': '该节点为孤立节点,无连接关系'}) # 找出影响当前节点的父节点 parents_info = [] incoming_edges = EventTransmissionEdge.query.filter_by(to_node_id=node_id).all() for edge in incoming_edges: parent = db.session.get(EventTransmissionNode, edge.from_node_id) if parent: parents_info.append({ 'id': parent.id, 'name': 
parent.node_name, 'type': parent.node_type, 'direction': edge.direction, 'strength': edge.strength, 'transmission_type': edge.transmission_type, 'transmission_mechanism': edge.transmission_mechanism, # 修复字段名 'is_circular': edge.is_circular, 'impact': edge.impact }) # 找出被当前节点影响的子节点 children_info = [] outgoing_edges = EventTransmissionEdge.query.filter_by(from_node_id=node_id).all() for edge in outgoing_edges: child = db.session.get(EventTransmissionNode, edge.to_node_id) if child: children_info.append({ 'id': child.id, 'name': child.node_name, 'type': child.node_type, 'direction': edge.direction, 'strength': edge.strength, 'transmission_type': edge.transmission_type, 'transmission_mechanism': edge.transmission_mechanism, # 修复字段名 'is_circular': edge.is_circular, 'impact': edge.impact }) node_data = { 'id': node.id, 'name': node.node_name, 'type': node.node_type, 'description': node.node_description, 'importance_score': node.importance_score, 'stock_code': node.stock_code, 'is_main_event': node.is_main_event, 'total_connections': total_connections, 'incoming_connections': len(incoming_edges), 'outgoing_connections': len(outgoing_edges) } return jsonify({ 'success': True, 'data': { 'node': node_data, 'parents': parents_info, 'children': children_info } }) @app.route('/api/events//posts', methods=['GET']) def get_event_posts(event_id): """获取事件下的帖子""" try: sort_type = request.args.get('sort', 'latest') page = request.args.get('page', 1, type=int) per_page = request.args.get('per_page', 20, type=int) # 查询事件下的帖子 query = Post.query.filter_by(event_id=event_id, status='active') if sort_type == 'hot': query = query.order_by(Post.likes_count.desc(), Post.created_at.desc()) else: # latest query = query.order_by(Post.created_at.desc()) # 分页 pagination = query.paginate(page=page, per_page=per_page, error_out=False) posts = pagination.items posts_data = [] for post in posts: post_dict = { 'id': post.id, 'event_id': post.event_id, 'user_id': post.user_id, 'title': post.title, 
'content': post.content, 'content_type': post.content_type, 'created_at': post.created_at.isoformat(), 'updated_at': post.updated_at.isoformat(), 'likes_count': post.likes_count, 'comments_count': post.comments_count, 'view_count': post.view_count, 'is_top': post.is_top, 'user': { 'id': post.user.id, 'username': post.user.username, 'avatar_url': post.user.avatar_url } if post.user else None, 'liked': False # 后续可以根据当前用户判断 } posts_data.append(post_dict) return jsonify({ 'success': True, 'data': posts_data, 'pagination': { 'page': page, 'per_page': per_page, 'total': pagination.total, 'pages': pagination.pages } }) except Exception as e: print(f"获取帖子失败: {e}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/events//posts', methods=['POST']) @login_required def create_event_post(event_id): """在事件下创建帖子""" try: data = request.get_json() content = data.get('content', '').strip() title = data.get('title', '').strip() content_type = data.get('content_type', 'text') if not content: return jsonify({ 'success': False, 'message': '帖子内容不能为空' }), 400 # 创建新帖子 post = Post( event_id=event_id, user_id=current_user.id, title=title, content=content, content_type=content_type ) db.session.add(post) # 更新事件的帖子数 event = Event.query.get(event_id) if event: event.post_count = Post.query.filter_by(event_id=event_id, status='active').count() # 更新用户发帖数 current_user.post_count = (current_user.post_count or 0) + 1 db.session.commit() return jsonify({ 'success': True, 'data': { 'id': post.id, 'event_id': post.event_id, 'user_id': post.user_id, 'title': post.title, 'content': post.content, 'content_type': post.content_type, 'created_at': post.created_at.isoformat(), 'user': { 'id': current_user.id, 'username': current_user.username, 'avatar_url': current_user.avatar_url } }, 'message': '帖子发布成功' }) except Exception as e: db.session.rollback() print(f"创建帖子失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 @app.route('/api/posts//comments', 
methods=['GET']) def get_post_comments(post_id): """获取帖子的评论""" try: sort_type = request.args.get('sort', 'latest') # 查询帖子的顶级评论(非回复) query = Comment.query.filter_by(post_id=post_id, parent_id=None, status='active') if sort_type == 'hot': comments = query.order_by(Comment.likes_count.desc(), Comment.created_at.desc()).all() else: # latest comments = query.order_by(Comment.created_at.desc()).all() comments_data = [] for comment in comments: comment_dict = { 'id': comment.id, 'post_id': comment.post_id, 'user_id': comment.user_id, 'content': comment.content, 'created_at': comment.created_at.isoformat(), 'updated_at': comment.updated_at.isoformat(), 'likes_count': comment.likes_count, 'user': { 'id': comment.user.id, 'username': comment.user.username, 'avatar_url': comment.user.avatar_url } if comment.user else None, 'replies': [] # 加载回复 } # 加载回复 replies = Comment.query.filter_by(parent_id=comment.id, status='active').order_by(Comment.created_at).all() for reply in replies: reply_dict = { 'id': reply.id, 'post_id': reply.post_id, 'user_id': reply.user_id, 'content': reply.content, 'parent_id': reply.parent_id, 'created_at': reply.created_at.isoformat(), 'likes_count': reply.likes_count, 'user': { 'id': reply.user.id, 'username': reply.user.username, 'avatar_url': reply.user.avatar_url } if reply.user else None } comment_dict['replies'].append(reply_dict) comments_data.append(comment_dict) return jsonify({ 'success': True, 'data': comments_data }) except Exception as e: print(f"获取评论失败: {e}") return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/posts//comments', methods=['POST']) @login_required def create_post_comment(post_id): """在帖子下创建评论""" try: data = request.get_json() content = data.get('content', '').strip() parent_id = data.get('parent_id') if not content: return jsonify({ 'success': False, 'message': '评论内容不能为空' }), 400 # 创建新评论 comment = Comment( post_id=post_id, user_id=current_user.id, content=content, parent_id=parent_id ) 
db.session.add(comment) # 更新帖子评论数 post = Post.query.get(post_id) if post: post.comments_count = Comment.query.filter_by(post_id=post_id, status='active').count() # 更新用户评论数 current_user.comment_count = (current_user.comment_count or 0) + 1 db.session.commit() return jsonify({ 'success': True, 'data': { 'id': comment.id, 'post_id': comment.post_id, 'user_id': comment.user_id, 'content': comment.content, 'parent_id': comment.parent_id, 'created_at': comment.created_at.isoformat(), 'user': { 'id': current_user.id, 'username': current_user.username, 'avatar_url': current_user.avatar_url } }, 'message': '评论发布成功' }) except Exception as e: db.session.rollback() print(f"创建评论失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 # 兼容旧的评论接口,转换为帖子模式 @app.route('/api/events//comments', methods=['GET']) def get_event_comments(event_id): """获取事件评论(兼容旧接口)""" # 将事件评论转换为获取事件下所有帖子的评论 return get_event_posts(event_id) @app.route('/api/events//comments', methods=['POST']) @login_required def add_event_comment(event_id): """添加事件评论(兼容旧接口)""" try: data = request.get_json() content = data.get('content', '').strip() parent_id = data.get('parent_id') if not content: return jsonify({ 'success': False, 'message': '评论内容不能为空' }), 400 # 如果有 parent_id,说明是回复,需要找到对应的帖子 if parent_id: # 这是一个回复,需要将其转换为对应帖子的评论 # 首先需要找到 parent_id 对应的帖子 # 这里假设旧的 parent_id 是之前的 EventComment id # 需要在数据迁移时处理这个映射关系 return jsonify({ 'success': False, 'message': '回复功能正在升级中,请稍后再试' }), 503 # 如果没有 parent_id,说明是顶级评论,创建为新帖子 post = Post( event_id=event_id, user_id=current_user.id, content=content, content_type='text' ) db.session.add(post) # 更新事件的帖子数 event = Event.query.get(event_id) if event: event.post_count = Post.query.filter_by(event_id=event_id, status='active').count() # 更新用户发帖数 current_user.post_count = (current_user.post_count or 0) + 1 db.session.commit() # 返回兼容旧接口的数据格式 return jsonify({ 'success': True, 'data': { 'id': post.id, 'event_id': post.event_id, 'user_id': post.user_id, 'author': 
current_user.username, 'content': post.content, 'parent_id': None, 'likes': 0, 'created_at': post.created_at.isoformat(), 'status': 'active', 'user': { 'id': current_user.id, 'username': current_user.username, 'avatar_url': current_user.avatar_url }, 'replies': [] }, 'message': '评论发布成功' }) except Exception as e: db.session.rollback() print(f"添加事件评论失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 # ==================== WebSocket 事件处理器(实时事件推送) ==================== @socketio.on('connect') def handle_connect(): """客户端连接事件""" print(f'\n[WebSocket DEBUG] ========== 客户端连接 ==========') print(f'[WebSocket DEBUG] Socket ID: {request.sid}') print(f'[WebSocket DEBUG] Remote Address: {request.remote_addr if hasattr(request, "remote_addr") else "N/A"}') print(f'[WebSocket] 客户端已连接: {request.sid}') emit('connection_response', { 'status': 'connected', 'sid': request.sid, 'message': '已连接到事件推送服务' }) print(f'[WebSocket DEBUG] ✓ 已发送 connection_response') print(f'[WebSocket DEBUG] ========== 连接完成 ==========\n') @socketio.on('subscribe_events') def handle_subscribe(data): """ 客户端订阅事件推送 data: { 'event_type': 'all' | 'policy' | 'market' | 'tech' | ..., 'importance': 'all' | 'S' | 'A' | 'B' | 'C', 'filters': {...} # 可选的其他筛选条件 } """ try: print(f'\n[WebSocket DEBUG] ========== 收到订阅请求 ==========') print(f'[WebSocket DEBUG] Socket ID: {request.sid}') print(f'[WebSocket DEBUG] 订阅数据: {data}') event_type = data.get('event_type', 'all') importance = data.get('importance', 'all') print(f'[WebSocket DEBUG] 事件类型: {event_type}') print(f'[WebSocket DEBUG] 重要性: {importance}') # 加入对应的房间 room_name = f"events_{event_type}" print(f'[WebSocket DEBUG] 准备加入房间: {room_name}') join_room(room_name) print(f'[WebSocket DEBUG] ✓ 已加入房间: {room_name}') print(f'[WebSocket] 客户端 {request.sid} 订阅了房间: {room_name}') response_data = { 'success': True, 'room': room_name, 'event_type': event_type, 'importance': importance, 'message': f'已订阅 {event_type} 类型的事件推送' } print(f'[WebSocket DEBUG] 准备发送 
subscription_confirmed: {response_data}') emit('subscription_confirmed', response_data) print(f'[WebSocket DEBUG] ✓ 已发送 subscription_confirmed') print(f'[WebSocket DEBUG] ========== 订阅完成 ==========\n') except Exception as e: print(f'[WebSocket ERROR] 订阅失败: {e}') import traceback traceback.print_exc() emit('subscription_error', { 'success': False, 'error': str(e) }) @socketio.on('unsubscribe_events') def handle_unsubscribe(data): """取消订阅事件推送""" try: print(f'\n[WebSocket DEBUG] ========== 收到取消订阅请求 ==========') print(f'[WebSocket DEBUG] Socket ID: {request.sid}') print(f'[WebSocket DEBUG] 数据: {data}') event_type = data.get('event_type', 'all') room_name = f"events_{event_type}" print(f'[WebSocket DEBUG] 准备离开房间: {room_name}') leave_room(room_name) print(f'[WebSocket DEBUG] ✓ 已离开房间: {room_name}') print(f'[WebSocket] 客户端 {request.sid} 取消订阅房间: {room_name}') emit('unsubscription_confirmed', { 'success': True, 'room': room_name, 'message': f'已取消订阅 {event_type} 类型的事件推送' }) print(f'[WebSocket DEBUG] ========== 取消订阅完成 ==========\n') except Exception as e: print(f'[WebSocket ERROR] 取消订阅失败: {e}') import traceback traceback.print_exc() emit('unsubscription_error', { 'success': False, 'error': str(e) }) @socketio.on('disconnect') def handle_disconnect(): """客户端断开连接事件""" print(f'\n[WebSocket DEBUG] ========== 客户端断开 ==========') print(f'[WebSocket DEBUG] Socket ID: {request.sid}') print(f'[WebSocket] 客户端已断开: {request.sid}') print(f'[WebSocket DEBUG] ========== 断开完成 ==========\n') # ==================== WebSocket 辅助函数 ==================== def broadcast_new_event(event): """ 广播新事件到所有订阅的客户端 在创建新事件时调用此函数 Args: event: Event 模型实例 """ try: print(f'\n[WebSocket DEBUG] ========== 广播新事件 ==========') print(f'[WebSocket DEBUG] 事件ID: {event.id}') print(f'[WebSocket DEBUG] 事件标题: {event.title}') print(f'[WebSocket DEBUG] 事件类型: {event.event_type}') print(f'[WebSocket DEBUG] 重要性: {event.importance}') event_data = { 'id': event.id, 'title': event.title, 'description': event.description, 'event_type': 
event.event_type, 'importance': event.importance, 'status': event.status, 'created_at': event.created_at.isoformat() if event.created_at else None, 'hot_score': event.hot_score, 'view_count': event.view_count, 'related_avg_chg': event.related_avg_chg, 'related_max_chg': event.related_max_chg, 'keywords': event.keywords_list if hasattr(event, 'keywords_list') else event.keywords, } print(f'[WebSocket DEBUG] 准备发送的数据: {event_data}') # 发送到所有订阅者(all 房间) print(f'[WebSocket DEBUG] 正在发送到房间: events_all') socketio.emit('new_event', event_data, room='events_all', namespace='/') print(f'[WebSocket DEBUG] ✓ 已发送到 events_all') # 发送到特定类型订阅者 if event.event_type: room_name = f"events_{event.event_type}" print(f'[WebSocket DEBUG] 正在发送到房间: {room_name}') socketio.emit('new_event', event_data, room=room_name, namespace='/') print(f'[WebSocket DEBUG] ✓ 已发送到 {room_name}') print(f'[WebSocket] 已推送新事件到房间: events_all, {room_name}') else: print(f'[WebSocket] 已推送新事件到房间: events_all') print(f'[WebSocket DEBUG] ========== 广播完成 ==========\n') except Exception as e: print(f'[WebSocket ERROR] 推送新事件失败: {e}') import traceback traceback.print_exc() # ==================== WebSocket 轮询机制(检测新事件) ==================== # 内存变量:记录近24小时内已知的事件ID集合和最大ID known_event_ids_in_24h = set() # 近24小时内已知的所有事件ID last_max_event_id = 0 # 已知的最大事件ID def poll_new_events(): """ 定期轮询数据库,检查是否有新事件 每 30 秒执行一次 新的设计思路(修复 created_at 不是入库时间的问题): 1. 查询近24小时内的所有活跃事件(按 created_at,因为这是事件发生时间) 2. 通过对比事件ID(自增ID)来判断是否为新插入的事件 3. 推送 ID > last_max_event_id 的事件 4. 
更新已知事件ID集合和最大ID """ global known_event_ids_in_24h, last_max_event_id try: with app.app_context(): from datetime import datetime, timedelta current_time = datetime.now() print(f'\n[轮询 DEBUG] ========== 开始轮询 ==========') print(f'[轮询 DEBUG] 当前时间: {current_time.strftime("%Y-%m-%d %H:%M:%S")}') print(f'[轮询 DEBUG] 已知事件ID数量: {len(known_event_ids_in_24h)}') print(f'[轮询 DEBUG] 当前最大事件ID: {last_max_event_id}') # 查询近24小时内的所有活跃事件(按事件发生时间 created_at) time_24h_ago = current_time - timedelta(hours=24) print(f'[轮询 DEBUG] 查询时间范围: 近24小时({time_24h_ago.strftime("%Y-%m-%d %H:%M:%S")} ~ 现在)') # 查询所有近24小时内的活跃事件 events_in_24h = Event.query.filter( Event.created_at >= time_24h_ago, Event.status == 'active' ).order_by(Event.id.asc()).all() print(f'[轮询 DEBUG] 数据库查询结果: 找到 {len(events_in_24h)} 个近24小时内的事件') # 找出新插入的事件(ID > last_max_event_id) new_events = [ event for event in events_in_24h if event.id > last_max_event_id ] print(f'[轮询 DEBUG] 新事件数量(ID > {last_max_event_id}): {len(new_events)} 个') if new_events: print(f'[轮询] 发现 {len(new_events)} 个新事件') for event in new_events: print(f'[轮询 DEBUG] 新事件详情:') print(f'[轮询 DEBUG] - ID: {event.id}') print(f'[轮询 DEBUG] - 标题: {event.title}') print(f'[轮询 DEBUG] - 事件发生时间(created_at): {event.created_at}') print(f'[轮询 DEBUG] - 事件类型: {event.event_type}') # 推送新事件 print(f'[轮询 DEBUG] 准备推送事件 ID={event.id}') broadcast_new_event(event) print(f'[轮询] ✓ 已推送事件 ID={event.id}, 标题={event.title}') # 更新已知事件ID集合(所有近24小时内的事件ID) known_event_ids_in_24h = set(event.id for event in events_in_24h) # 更新最大事件ID new_max_id = max(event.id for event in events_in_24h) print(f'[轮询 DEBUG] 更新最大事件ID: {last_max_event_id} -> {new_max_id}') last_max_event_id = new_max_id print(f'[轮询 DEBUG] 更新后已知事件ID数量: {len(known_event_ids_in_24h)}') else: print(f'[轮询 DEBUG] 没有新事件需要推送') # 即使没有新事件,也要更新已知事件集合(清理超过24小时的) if events_in_24h: known_event_ids_in_24h = set(event.id for event in events_in_24h) current_max_id = max(event.id for event in events_in_24h) if current_max_id != last_max_event_id: print(f'[轮询 DEBUG] 
更新最大事件ID: {last_max_event_id} -> {current_max_id}') last_max_event_id = current_max_id print(f'[轮询 DEBUG] ========== 轮询结束 ==========\n') except Exception as e: print(f'[轮询 ERROR] 检查新事件时出错: {e}') import traceback traceback.print_exc() def initialize_event_polling(): """ 初始化事件轮询机制 在应用启动时调用 """ global known_event_ids_in_24h, last_max_event_id try: from datetime import datetime, timedelta with app.app_context(): current_time = datetime.now() time_24h_ago = current_time - timedelta(hours=24) print(f'\n[轮询] ========== 初始化事件轮询 ==========') print(f'[轮询] 当前时间: {current_time.strftime("%Y-%m-%d %H:%M:%S")}') # 查询近24小时内的所有活跃事件 events_in_24h = Event.query.filter( Event.created_at >= time_24h_ago, Event.status == 'active' ).order_by(Event.id.asc()).all() # 初始化已知事件ID集合 known_event_ids_in_24h = set(event.id for event in events_in_24h) # 初始化最大事件ID if events_in_24h: last_max_event_id = max(event.id for event in events_in_24h) print(f'[轮询] 近24小时内共有 {len(events_in_24h)} 个活跃事件') print(f'[轮询] 初始最大事件ID: {last_max_event_id}') print(f'[轮询] 事件ID范围: {min(event.id for event in events_in_24h)} ~ {last_max_event_id}') else: last_max_event_id = 0 print(f'[轮询] 近24小时内没有活跃事件') print(f'[轮询] 初始最大事件ID: 0') # 统计数据库中的事件总数 total_events = Event.query.filter_by(status='active').count() print(f'[轮询] 数据库中共有 {total_events} 个活跃事件(所有时间)') print(f'[轮询] 只会推送 ID > {last_max_event_id} 的新事件') print(f'[轮询] ========== 初始化完成 ==========\n') # 创建后台调度器 scheduler = BackgroundScheduler() # 每 30 秒执行一次轮询 scheduler.add_job( func=poll_new_events, trigger='interval', seconds=30, id='poll_new_events', name='检查新事件并推送', replace_existing=True ) scheduler.start() print('[轮询] 调度器已启动,每 30 秒检查一次新事件') except Exception as e: print(f'[轮询] 初始化失败: {e}') # ==================== 结束 WebSocket 部分 ==================== @app.route('/api/posts//like', methods=['POST']) @login_required def like_post(post_id): """点赞/取消点赞帖子""" try: post = Post.query.get_or_404(post_id) # 检查是否已经点赞 existing_like = PostLike.query.filter_by( post_id=post_id, 
user_id=current_user.id ).first() if existing_like: # 取消点赞 db.session.delete(existing_like) post.likes_count = max(0, post.likes_count - 1) message = '取消点赞成功' liked = False else: # 添加点赞 new_like = PostLike(post_id=post_id, user_id=current_user.id) db.session.add(new_like) post.likes_count += 1 message = '点赞成功' liked = True db.session.commit() return jsonify({ 'success': True, 'message': message, 'likes_count': post.likes_count, 'liked': liked }) except Exception as e: db.session.rollback() print(f"点赞失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 @app.route('/api/comments//like', methods=['POST']) @login_required def like_comment(comment_id): """点赞/取消点赞评论""" try: comment = Comment.query.get_or_404(comment_id) # 检查是否已经点赞(需要创建 CommentLike 关联到新的 Comment 模型) # 暂时使用简单的计数器 comment.likes_count += 1 db.session.commit() return jsonify({ 'success': True, 'message': '点赞成功', 'likes_count': comment.likes_count }) except Exception as e: db.session.rollback() print(f"点赞失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 @app.route('/api/posts/', methods=['DELETE']) @login_required def delete_post(post_id): """删除帖子""" try: post = Post.query.get_or_404(post_id) # 检查权限:只能删除自己的帖子 if post.user_id != current_user.id: return jsonify({ 'success': False, 'message': '您只能删除自己的帖子' }), 403 # 软删除 post.status = 'deleted' # 更新事件的帖子数 event = Event.query.get(post.event_id) if event: event.post_count = Post.query.filter_by(event_id=post.event_id, status='active').count() # 更新用户发帖数 if current_user.post_count > 0: current_user.post_count -= 1 db.session.commit() return jsonify({ 'success': True, 'message': '帖子删除成功' }) except Exception as e: db.session.rollback() print(f"删除帖子失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 @app.route('/api/comments/', methods=['DELETE']) @login_required def delete_comment(comment_id): """删除评论""" try: comment = Comment.query.get_or_404(comment_id) # 检查权限:只能删除自己的评论 if comment.user_id != current_user.id: 
return jsonify({ 'success': False, 'message': '您只能删除自己的评论' }), 403 # 软删除 comment.status = 'deleted' comment.content = '[该评论已被删除]' # 更新帖子评论数 post = Post.query.get(comment.post_id) if post: post.comments_count = Comment.query.filter_by(post_id=comment.post_id, status='active').count() # 更新用户评论数 if current_user.comment_count > 0: current_user.comment_count -= 1 db.session.commit() return jsonify({ 'success': True, 'message': '评论删除成功' }) except Exception as e: db.session.rollback() print(f"删除评论失败: {e}") return jsonify({ 'success': False, 'message': str(e) }), 500 def format_decimal(value): """格式化decimal类型数据""" if value is None: return None if isinstance(value, Decimal): return float(value) return float(value) def format_date(date_obj): """格式化日期""" if date_obj is None: return None if isinstance(date_obj, datetime): return date_obj.strftime('%Y-%m-%d') return str(date_obj) def remove_cycles_from_sankey_flows(flows_data): """ 移除Sankey图数据中的循环边,确保数据是DAG(有向无环图) 使用拓扑排序算法检测循环,优先保留flow_ratio高的边 Args: flows_data: list of flow objects with 'source', 'target', 'flow_metrics' keys Returns: list of flows without cycles """ if not flows_data: return flows_data # 按flow_ratio降序排序,优先保留重要的边 sorted_flows = sorted( flows_data, key=lambda x: x.get('flow_metrics', {}).get('flow_ratio', 0) or 0, reverse=True ) # 构建图的邻接表和入度表 def build_graph(flows): graph = {} # node -> list of successors in_degree = {} # node -> in-degree count all_nodes = set() for flow in flows: source = flow['source']['node_name'] target = flow['target']['node_name'] all_nodes.add(source) all_nodes.add(target) if source not in graph: graph[source] = [] graph[source].append(target) if target not in in_degree: in_degree[target] = 0 in_degree[target] += 1 if source not in in_degree: in_degree[source] = 0 return graph, in_degree, all_nodes # 使用Kahn算法检测是否有环 def has_cycle(graph, in_degree, all_nodes): # 找到所有入度为0的节点 queue = [node for node in all_nodes if in_degree.get(node, 0) == 0] visited_count = 0 while queue: node = 
queue.pop(0) visited_count += 1 # 访问所有邻居 for neighbor in graph.get(node, []): in_degree[neighbor] -= 1 if in_degree[neighbor] == 0: queue.append(neighbor) # 如果访问的节点数等于总节点数,说明没有环 return visited_count < len(all_nodes) # 逐个添加边,如果添加后产生环则跳过 result_flows = [] for flow in sorted_flows: # 尝试添加这条边 temp_flows = result_flows + [flow] # 检查是否产生环 graph, in_degree, all_nodes = build_graph(temp_flows) # 复制in_degree用于检测(因为检测过程会修改它) in_degree_copy = in_degree.copy() if not has_cycle(graph, in_degree_copy, all_nodes): # 没有产生环,可以添加 result_flows.append(flow) else: # 产生环,跳过这条边 print(f"Skipping edge that creates cycle: {flow['source']['node_name']} -> {flow['target']['node_name']}") removed_count = len(flows_data) - len(result_flows) if removed_count > 0: print(f"Removed {removed_count} edges to eliminate cycles in Sankey diagram") return result_flows def get_report_type(date_str): """获取报告期类型""" if not date_str: return '' if isinstance(date_str, str): date = datetime.strptime(date_str, '%Y-%m-%d') else: date = date_str month = date.month year = date.year if month == 3: return f"{year}年一季报" elif month == 6: return f"{year}年中报" elif month == 9: return f"{year}年三季报" elif month == 12: return f"{year}年年报" else: return str(date_str) @app.route('/api/financial/stock-info/', methods=['GET']) def get_stock_info(seccode): """获取股票基本信息和最新财务摘要""" try: # 获取最新的财务数据 query = text(""" SELECT distinct a.SECCODE, a.SECNAME, a.ENDDATE, a.F003N as eps, a.F004N as basic_eps, a.F005N as diluted_eps, a.F006N as deducted_eps, a.F007N as undistributed_profit_ps, a.F008N as bvps, a.F010N as capital_reserve_ps, a.F014N as roe, a.F067N as roe_weighted, a.F016N as roa, a.F078N as gross_margin, a.F017N as net_margin, a.F089N as revenue, a.F101N as net_profit, a.F102N as parent_net_profit, a.F118N as total_assets, a.F121N as total_liabilities, a.F128N as total_equity, a.F052N as revenue_growth, a.F053N as profit_growth, a.F054N as equity_growth, a.F056N as asset_growth, a.F122N as share_capital FROM ea_financialindex a 
WHERE a.SECCODE = :seccode ORDER BY a.ENDDATE DESC LIMIT 1 """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode}).fetchone() if not result: return jsonify({ 'success': False, 'message': f'未找到股票代码 {seccode} 的财务数据' }), 404 # 获取最近的业绩预告 forecast_query = text(""" SELECT distinct F001D as report_date, F003V as forecast_type, F004V as content, F007N as profit_lower, F008N as profit_upper, F009N as change_lower, F010N as change_upper FROM ea_forecast WHERE SECCODE = :seccode AND F006C = 'T' ORDER BY F001D DESC LIMIT 1 """) with engine.connect() as conn: forecast_result = conn.execute(forecast_query, {'seccode': seccode}).fetchone() data = { 'stock_code': result.SECCODE, 'stock_name': result.SECNAME, 'latest_period': format_date(result.ENDDATE), 'report_type': get_report_type(result.ENDDATE), 'key_metrics': { 'eps': format_decimal(result.eps), 'basic_eps': format_decimal(result.basic_eps), 'diluted_eps': format_decimal(result.diluted_eps), 'deducted_eps': format_decimal(result.deducted_eps), 'bvps': format_decimal(result.bvps), 'roe': format_decimal(result.roe), 'roe_weighted': format_decimal(result.roe_weighted), 'roa': format_decimal(result.roa), 'gross_margin': format_decimal(result.gross_margin), 'net_margin': format_decimal(result.net_margin), }, 'financial_summary': { 'revenue': format_decimal(result.revenue), 'net_profit': format_decimal(result.net_profit), 'parent_net_profit': format_decimal(result.parent_net_profit), 'total_assets': format_decimal(result.total_assets), 'total_liabilities': format_decimal(result.total_liabilities), 'total_equity': format_decimal(result.total_equity), 'share_capital': format_decimal(result.share_capital), }, 'growth_rates': { 'revenue_growth': format_decimal(result.revenue_growth), 'profit_growth': format_decimal(result.profit_growth), 'equity_growth': format_decimal(result.equity_growth), 'asset_growth': format_decimal(result.asset_growth), } } # 添加业绩预告信息 if forecast_result: data['latest_forecast'] = 
{ 'report_date': format_date(forecast_result.report_date), 'forecast_type': forecast_result.forecast_type, 'content': forecast_result.content, 'profit_range': { 'lower': format_decimal(forecast_result.profit_lower), 'upper': format_decimal(forecast_result.profit_upper), }, 'change_range': { 'lower': format_decimal(forecast_result.change_lower), 'upper': format_decimal(forecast_result.change_upper), } } return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/balance-sheet/', methods=['GET']) def get_balance_sheet(seccode): """获取完整的资产负债表数据""" try: limit = request.args.get('limit', 12, type=int) query = text(""" SELECT distinct ENDDATE, DECLAREDATE, -- 流动资产 F006N as cash, -- 货币资金 F007N as trading_financial_assets, -- 交易性金融资产 F008N as notes_receivable, -- 应收票据 F009N as accounts_receivable, -- 应收账款 F010N as prepayments, -- 预付款项 F011N as other_receivables, -- 其他应收款 F013N as interest_receivable, -- 应收利息 F014N as dividends_receivable, -- 应收股利 F015N as inventory, -- 存货 F016N as consumable_biological_assets, -- 消耗性生物资产 F017N as non_current_assets_due_within_one_year, -- 一年内到期的非流动资产 F018N as other_current_assets, -- 其他流动资产 F019N as total_current_assets, -- 流动资产合计 -- 非流动资产 F020N as available_for_sale_financial_assets, -- 可供出售金融资产 F021N as held_to_maturity_investments, -- 持有至到期投资 F022N as long_term_receivables, -- 长期应收款 F023N as long_term_equity_investments, -- 长期股权投资 F024N as investment_property, -- 投资性房地产 F025N as fixed_assets, -- 固定资产 F026N as construction_in_progress, -- 在建工程 F027N as engineering_materials, -- 工程物资 F029N as productive_biological_assets, -- 生产性生物资产 F030N as oil_and_gas_assets, -- 油气资产 F031N as intangible_assets, -- 无形资产 F032N as development_expenditure, -- 开发支出 F033N as goodwill, -- 商誉 F034N as long_term_deferred_expenses, -- 长期待摊费用 F035N as deferred_tax_assets, -- 递延所得税资产 F036N as other_non_current_assets, -- 其他非流动资产 F037N as total_non_current_assets, -- 
非流动资产合计 F038N as total_assets, -- 资产总计 -- 流动负债 F039N as short_term_borrowings, -- 短期借款 F040N as trading_financial_liabilities, -- 交易性金融负债 F041N as notes_payable, -- 应付票据 F042N as accounts_payable, -- 应付账款 F043N as advance_receipts, -- 预收款项 F044N as employee_compensation_payable, -- 应付职工薪酬 F045N as taxes_payable, -- 应交税费 F046N as interest_payable, -- 应付利息 F047N as dividends_payable, -- 应付股利 F048N as other_payables, -- 其他应付款 F050N as non_current_liabilities_due_within_one_year, -- 一年内到期的非流动负债 F051N as other_current_liabilities, -- 其他流动负债 F052N as total_current_liabilities, -- 流动负债合计 -- 非流动负债 F053N as long_term_borrowings, -- 长期借款 F054N as bonds_payable, -- 应付债券 F055N as long_term_payables, -- 长期应付款 F056N as special_payables, -- 专项应付款 F057N as estimated_liabilities, -- 预计负债 F058N as deferred_tax_liabilities, -- 递延所得税负债 F059N as other_non_current_liabilities, -- 其他非流动负债 F060N as total_non_current_liabilities, -- 非流动负债合计 F061N as total_liabilities, -- 负债合计 -- 所有者权益 F062N as share_capital, -- 股本 F063N as capital_reserve, -- 资本公积 F064N as surplus_reserve, -- 盈余公积 F065N as undistributed_profit, -- 未分配利润 F066N as treasury_stock, -- 库存股 F067N as minority_interests, -- 少数股东权益 F070N as total_equity, -- 所有者权益合计 F071N as total_liabilities_and_equity, -- 负债和所有者权益合计 F073N as parent_company_equity, -- 归属于母公司所有者权益 F074N as other_comprehensive_income, -- 其他综合收益 -- 新会计准则科目 F110N as other_debt_investments, -- 其他债权投资 F111N as other_equity_investments, -- 其他权益工具投资 F112N as other_non_current_financial_assets, -- 其他非流动金融资产 F115N as contract_liabilities, -- 合同负债 F119N as contract_assets, -- 合同资产 F120N as receivables_financing, -- 应收款项融资 F121N as right_of_use_assets, -- 使用权资产 F122N as lease_liabilities -- 租赁负债 FROM ea_asset WHERE SECCODE = :seccode and F002V = '071001' ORDER BY ENDDATE DESC LIMIT :limit """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'limit': limit}) data = [] for row in result: # 安全计算关键比率,避免 Decimal 与 None 运算错误 def to_float(v): try: 
return float(v) if v is not None else None except Exception: return None ta = to_float(row.total_assets) tl = to_float(row.total_liabilities) tca = to_float(row.total_current_assets) tcl = to_float(row.total_current_liabilities) inv = to_float(row.inventory) or 0.0 asset_liability_ratio_val = None if ta is not None and ta != 0 and tl is not None: asset_liability_ratio_val = (tl / ta) * 100 current_ratio_val = None if tcl is not None and tcl != 0 and tca is not None: current_ratio_val = tca / tcl quick_ratio_val = None if tcl is not None and tcl != 0 and tca is not None: quick_ratio_val = (tca - inv) / tcl period_data = { 'period': format_date(row.ENDDATE), 'declare_date': format_date(row.DECLAREDATE), 'report_type': get_report_type(row.ENDDATE), # 资产部分 'assets': { 'current_assets': { 'cash': format_decimal(row.cash), 'trading_financial_assets': format_decimal(row.trading_financial_assets), 'notes_receivable': format_decimal(row.notes_receivable), 'accounts_receivable': format_decimal(row.accounts_receivable), 'prepayments': format_decimal(row.prepayments), 'other_receivables': format_decimal(row.other_receivables), 'inventory': format_decimal(row.inventory), 'contract_assets': format_decimal(row.contract_assets), 'other_current_assets': format_decimal(row.other_current_assets), 'total': format_decimal(row.total_current_assets), }, 'non_current_assets': { 'long_term_equity_investments': format_decimal(row.long_term_equity_investments), 'investment_property': format_decimal(row.investment_property), 'fixed_assets': format_decimal(row.fixed_assets), 'construction_in_progress': format_decimal(row.construction_in_progress), 'intangible_assets': format_decimal(row.intangible_assets), 'goodwill': format_decimal(row.goodwill), 'right_of_use_assets': format_decimal(row.right_of_use_assets), 'deferred_tax_assets': format_decimal(row.deferred_tax_assets), 'other_non_current_assets': format_decimal(row.other_non_current_assets), 'total': 
format_decimal(row.total_non_current_assets), }, 'total': format_decimal(row.total_assets), }, # 负债部分 'liabilities': { 'current_liabilities': { 'short_term_borrowings': format_decimal(row.short_term_borrowings), 'notes_payable': format_decimal(row.notes_payable), 'accounts_payable': format_decimal(row.accounts_payable), 'advance_receipts': format_decimal(row.advance_receipts), 'contract_liabilities': format_decimal(row.contract_liabilities), 'employee_compensation_payable': format_decimal(row.employee_compensation_payable), 'taxes_payable': format_decimal(row.taxes_payable), 'other_payables': format_decimal(row.other_payables), 'non_current_liabilities_due_within_one_year': format_decimal( row.non_current_liabilities_due_within_one_year), 'total': format_decimal(row.total_current_liabilities), }, 'non_current_liabilities': { 'long_term_borrowings': format_decimal(row.long_term_borrowings), 'bonds_payable': format_decimal(row.bonds_payable), 'lease_liabilities': format_decimal(row.lease_liabilities), 'deferred_tax_liabilities': format_decimal(row.deferred_tax_liabilities), 'other_non_current_liabilities': format_decimal(row.other_non_current_liabilities), 'total': format_decimal(row.total_non_current_liabilities), }, 'total': format_decimal(row.total_liabilities), }, # 股东权益部分 'equity': { 'share_capital': format_decimal(row.share_capital), 'capital_reserve': format_decimal(row.capital_reserve), 'surplus_reserve': format_decimal(row.surplus_reserve), 'undistributed_profit': format_decimal(row.undistributed_profit), 'treasury_stock': format_decimal(row.treasury_stock), 'other_comprehensive_income': format_decimal(row.other_comprehensive_income), 'parent_company_equity': format_decimal(row.parent_company_equity), 'minority_interests': format_decimal(row.minority_interests), 'total': format_decimal(row.total_equity), }, # 关键比率 'key_ratios': { 'asset_liability_ratio': format_decimal(asset_liability_ratio_val), 'current_ratio': format_decimal(current_ratio_val), 
'quick_ratio': format_decimal(quick_ratio_val), } } data.append(period_data) return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/income-statement/', methods=['GET']) def get_income_statement(seccode): """获取完整的利润表数据""" try: limit = request.args.get('limit', 12, type=int) query = text(""" SELECT distinct ENDDATE, STARTDATE, DECLAREDATE, -- 营业收入部分 F006N as revenue, -- 营业收入 F035N as total_operating_revenue, -- 营业总收入 F051N as other_income, -- 其他收入 -- 营业成本部分 F007N as cost, -- 营业成本 F008N as taxes_and_surcharges, -- 税金及附加 F009N as selling_expenses, -- 销售费用 F010N as admin_expenses, -- 管理费用 F056N as rd_expenses, -- 研发费用 F012N as financial_expenses, -- 财务费用 F062N as interest_expense, -- 利息费用 F063N as interest_income, -- 利息收入 F013N as asset_impairment_loss, -- 资产减值损失(营业总成本) F057N as credit_impairment_loss, -- 信用减值损失(营业总成本) F036N as total_operating_cost, -- 营业总成本 -- 其他收益 F014N as fair_value_change_income, -- 公允价值变动净收益 F015N as investment_income, -- 投资收益 F016N as investment_income_from_associates, -- 对联营企业和合营企业的投资收益 F037N as exchange_income, -- 汇兑收益 F058N as net_exposure_hedging_income, -- 净敞口套期收益 F059N as asset_disposal_income, -- 资产处置收益 -- 利润部分 F018N as operating_profit, -- 营业利润 F019N as subsidy_income, -- 补贴收入 F020N as non_operating_income, -- 营业外收入 F021N as non_operating_expenses, -- 营业外支出 F022N as non_current_asset_disposal_loss, -- 非流动资产处置损失 F024N as total_profit, -- 利润总额 F025N as income_tax_expense, -- 所得税 F027N as net_profit, -- 净利润 F028N as parent_net_profit, -- 归属于母公司所有者的净利润 F029N as minority_profit, -- 少数股东损益 -- 持续经营 F060N as continuing_operations_net_profit, -- 持续经营净利润 F061N as discontinued_operations_net_profit, -- 终止经营净利润 -- 每股收益 F031N as basic_eps, -- 基本每股收益 F032N as diluted_eps, -- 稀释每股收益 -- 综合收益 F038N as other_comprehensive_income_after_tax, -- 其他综合收益的税后净额 F039N as total_comprehensive_income, -- 综合收益总额 F040N as parent_company_comprehensive_income, -- 
归属于母公司的综合收益 F041N as minority_comprehensive_income -- 归属于少数股东的综合收益 FROM ea_profit WHERE SECCODE = :seccode and F002V = '071001' ORDER BY ENDDATE DESC LIMIT :limit """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'limit': limit}) data = [] for row in result: # 计算一些衍生指标 gross_profit = (row.revenue - row.cost) if row.revenue and row.cost else None gross_margin = (gross_profit / row.revenue * 100) if row.revenue and gross_profit else None operating_margin = ( row.operating_profit / row.revenue * 100) if row.revenue and row.operating_profit else None net_margin = (row.net_profit / row.revenue * 100) if row.revenue and row.net_profit else None # 三费合计 three_expenses = 0 if row.selling_expenses: three_expenses += row.selling_expenses if row.admin_expenses: three_expenses += row.admin_expenses if row.financial_expenses: three_expenses += row.financial_expenses # 四费合计(加研发) four_expenses = three_expenses if row.rd_expenses: four_expenses += row.rd_expenses period_data = { 'period': format_date(row.ENDDATE), 'start_date': format_date(row.STARTDATE), 'declare_date': format_date(row.DECLAREDATE), 'report_type': get_report_type(row.ENDDATE), # 收入部分 'revenue': { 'operating_revenue': format_decimal(row.revenue), 'total_operating_revenue': format_decimal(row.total_operating_revenue), 'other_income': format_decimal(row.other_income), }, # 成本费用部分 'costs': { 'operating_cost': format_decimal(row.cost), 'taxes_and_surcharges': format_decimal(row.taxes_and_surcharges), 'selling_expenses': format_decimal(row.selling_expenses), 'admin_expenses': format_decimal(row.admin_expenses), 'rd_expenses': format_decimal(row.rd_expenses), 'financial_expenses': format_decimal(row.financial_expenses), 'interest_expense': format_decimal(row.interest_expense), 'interest_income': format_decimal(row.interest_income), 'asset_impairment_loss': format_decimal(row.asset_impairment_loss), 'credit_impairment_loss': format_decimal(row.credit_impairment_loss), 
'total_operating_cost': format_decimal(row.total_operating_cost), 'three_expenses_total': format_decimal(three_expenses), 'four_expenses_total': format_decimal(four_expenses), }, # 其他收益 'other_gains': { 'fair_value_change': format_decimal(row.fair_value_change_income), 'investment_income': format_decimal(row.investment_income), 'investment_income_from_associates': format_decimal(row.investment_income_from_associates), 'exchange_income': format_decimal(row.exchange_income), 'asset_disposal_income': format_decimal(row.asset_disposal_income), }, # 利润 'profit': { 'gross_profit': format_decimal(gross_profit), 'operating_profit': format_decimal(row.operating_profit), 'total_profit': format_decimal(row.total_profit), 'net_profit': format_decimal(row.net_profit), 'parent_net_profit': format_decimal(row.parent_net_profit), 'minority_profit': format_decimal(row.minority_profit), 'continuing_operations_net_profit': format_decimal(row.continuing_operations_net_profit), 'discontinued_operations_net_profit': format_decimal(row.discontinued_operations_net_profit), }, # 非经营项目 'non_operating': { 'subsidy_income': format_decimal(row.subsidy_income), 'non_operating_income': format_decimal(row.non_operating_income), 'non_operating_expenses': format_decimal(row.non_operating_expenses), }, # 每股收益 'per_share': { 'basic_eps': format_decimal(row.basic_eps), 'diluted_eps': format_decimal(row.diluted_eps), }, # 综合收益 'comprehensive_income': { 'other_comprehensive_income': format_decimal(row.other_comprehensive_income_after_tax), 'total_comprehensive_income': format_decimal(row.total_comprehensive_income), 'parent_comprehensive_income': format_decimal(row.parent_company_comprehensive_income), 'minority_comprehensive_income': format_decimal(row.minority_comprehensive_income), }, # 关键比率 'margins': { 'gross_margin': format_decimal(gross_margin), 'operating_margin': format_decimal(operating_margin), 'net_margin': format_decimal(net_margin), 'expense_ratio': format_decimal(four_expenses / 
row.revenue * 100) if row.revenue else None, 'rd_ratio': format_decimal( row.rd_expenses / row.revenue * 100) if row.revenue and row.rd_expenses else None, } } data.append(period_data) return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/cashflow/', methods=['GET']) def get_cashflow(seccode): """获取完整的现金流量表数据""" try: limit = request.args.get('limit', 12, type=int) query = text(""" SELECT distinct ENDDATE, STARTDATE, DECLAREDATE, -- 经营活动现金流 F006N as cash_from_sales, -- 销售商品、提供劳务收到的现金 F007N as tax_refunds, -- 收到的税费返还 F008N as other_operating_cash_received, -- 收到其他与经营活动有关的现金 F009N as total_operating_cash_inflow, -- 经营活动现金流入小计 F010N as cash_paid_for_goods, -- 购买商品、接受劳务支付的现金 F011N as cash_paid_to_employees, -- 支付给职工以及为职工支付的现金 F012N as taxes_paid, -- 支付的各项税费 F013N as other_operating_cash_paid, -- 支付其他与经营活动有关的现金 F014N as total_operating_cash_outflow, -- 经营活动现金流出小计 F015N as net_operating_cash_flow, -- 经营活动产生的现金流量净额 -- 投资活动现金流 F016N as cash_from_investment_recovery, -- 收回投资收到的现金 F017N as cash_from_investment_income, -- 取得投资收益收到的现金 F018N as cash_from_asset_disposal, -- 处置固定资产、无形资产和其他长期资产收回的现金净额 F019N as cash_from_subsidiary_disposal, -- 处置子公司及其他营业单位收到的现金净额 F020N as other_investment_cash_received, -- 收到其他与投资活动有关的现金 F021N as total_investment_cash_inflow, -- 投资活动现金流入小计 F022N as cash_paid_for_assets, -- 购建固定资产、无形资产和其他长期资产支付的现金 F023N as cash_paid_for_investments, -- 投资支付的现金 F024N as cash_paid_for_subsidiaries, -- 取得子公司及其他营业单位支付的现金净额 F025N as other_investment_cash_paid, -- 支付其他与投资活动有关的现金 F026N as total_investment_cash_outflow, -- 投资活动现金流出小计 F027N as net_investment_cash_flow, -- 投资活动产生的现金流量净额 -- 筹资活动现金流 F028N as cash_from_capital, -- 吸收投资收到的现金 F029N as cash_from_borrowings, -- 取得借款收到的现金 F030N as other_financing_cash_received, -- 收到其他与筹资活动有关的现金 F031N as total_financing_cash_inflow, -- 筹资活动现金流入小计 F032N as cash_paid_for_debt, -- 偿还债务支付的现金 F033N as cash_paid_for_distribution, -- 
分配股利、利润或偿付利息支付的现金 F034N as other_financing_cash_paid, -- 支付其他与筹资活动有关的现金 F035N as total_financing_cash_outflow, -- 筹资活动现金流出小计 F036N as net_financing_cash_flow, -- 筹资活动产生的现金流量净额 -- 汇率变动影响 F037N as exchange_rate_effect, -- 汇率变动对现金及现金等价物的影响 F038N as other_cash_effect, -- 其他原因对现金的影响 -- 现金净增加额 F039N as net_cash_increase, -- 现金及现金等价物净增加额 F040N as beginning_cash_balance, -- 期初现金及现金等价物余额 F041N as ending_cash_balance, -- 期末现金及现金等价物余额 -- 补充资料部分 F044N as net_profit, -- 净利润 F045N as asset_impairment, -- 资产减值准备 F096N as credit_impairment, -- 信用减值损失 F046N as depreciation, -- 固定资产折旧、油气资产折耗、生产性生物资产折旧 F097N as right_of_use_asset_depreciation, -- 使用权资产折旧/摊销 F047N as intangible_amortization, -- 无形资产摊销 F048N as long_term_expense_amortization, -- 长期待摊费用摊销 F049N as loss_on_disposal, -- 处置固定资产、无形资产和其他长期资产的损失 F050N as fixed_asset_scrap_loss, -- 固定资产报废损失 F051N as fair_value_change_loss, -- 公允价值变动损失 F052N as financial_expenses, -- 财务费用 F053N as investment_loss, -- 投资损失 F054N as deferred_tax_asset_decrease, -- 递延所得税资产减少 F055N as deferred_tax_liability_increase, -- 递延所得税负债增加 F056N as inventory_decrease, -- 存货的减少 F057N as operating_receivables_decrease, -- 经营性应收项目的减少 F058N as operating_payables_increase, -- 经营性应付项目的增加 F059N as other, -- 其他 F060N as net_operating_cash_flow_indirect, -- 经营活动产生的现金流量净额(间接法) -- 特殊行业科目(金融) F072N as customer_deposit_increase, -- 客户存款和同业存放款项净增加额 F073N as central_bank_borrowing_increase, -- 向中央银行借款净增加额 F081N as interest_and_commission_received, -- 收取利息、手续费及佣金的现金 F087N as interest_and_commission_paid -- 支付利息、手续费及佣金的现金 FROM ea_cashflow WHERE SECCODE = :seccode and F002V = '071001' ORDER BY ENDDATE DESC LIMIT :limit """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'limit': limit}) data = [] for row in result: # 计算一些衍生指标 free_cash_flow = None if row.net_operating_cash_flow and row.cash_paid_for_assets: free_cash_flow = row.net_operating_cash_flow - row.cash_paid_for_assets period_data = { 'period': format_date(row.ENDDATE), 'start_date': 
format_date(row.STARTDATE), 'declare_date': format_date(row.DECLAREDATE), 'report_type': get_report_type(row.ENDDATE), # 经营活动现金流 'operating_activities': { 'inflow': { 'cash_from_sales': format_decimal(row.cash_from_sales), 'tax_refunds': format_decimal(row.tax_refunds), 'other': format_decimal(row.other_operating_cash_received), 'total': format_decimal(row.total_operating_cash_inflow), }, 'outflow': { 'cash_for_goods': format_decimal(row.cash_paid_for_goods), 'cash_for_employees': format_decimal(row.cash_paid_to_employees), 'taxes_paid': format_decimal(row.taxes_paid), 'other': format_decimal(row.other_operating_cash_paid), 'total': format_decimal(row.total_operating_cash_outflow), }, 'net_flow': format_decimal(row.net_operating_cash_flow), }, # 投资活动现金流 'investment_activities': { 'inflow': { 'investment_recovery': format_decimal(row.cash_from_investment_recovery), 'investment_income': format_decimal(row.cash_from_investment_income), 'asset_disposal': format_decimal(row.cash_from_asset_disposal), 'subsidiary_disposal': format_decimal(row.cash_from_subsidiary_disposal), 'other': format_decimal(row.other_investment_cash_received), 'total': format_decimal(row.total_investment_cash_inflow), }, 'outflow': { 'asset_purchase': format_decimal(row.cash_paid_for_assets), 'investments': format_decimal(row.cash_paid_for_investments), 'subsidiaries': format_decimal(row.cash_paid_for_subsidiaries), 'other': format_decimal(row.other_investment_cash_paid), 'total': format_decimal(row.total_investment_cash_outflow), }, 'net_flow': format_decimal(row.net_investment_cash_flow), }, # 筹资活动现金流 'financing_activities': { 'inflow': { 'capital': format_decimal(row.cash_from_capital), 'borrowings': format_decimal(row.cash_from_borrowings), 'other': format_decimal(row.other_financing_cash_received), 'total': format_decimal(row.total_financing_cash_inflow), }, 'outflow': { 'debt_repayment': format_decimal(row.cash_paid_for_debt), 'distribution': format_decimal(row.cash_paid_for_distribution), 
'other': format_decimal(row.other_financing_cash_paid), 'total': format_decimal(row.total_financing_cash_outflow), }, 'net_flow': format_decimal(row.net_financing_cash_flow), }, # 现金变动 'cash_changes': { 'exchange_rate_effect': format_decimal(row.exchange_rate_effect), 'other_effect': format_decimal(row.other_cash_effect), 'net_increase': format_decimal(row.net_cash_increase), 'beginning_balance': format_decimal(row.beginning_cash_balance), 'ending_balance': format_decimal(row.ending_cash_balance), }, # 补充资料(间接法) 'indirect_method': { 'net_profit': format_decimal(row.net_profit), 'adjustments': { 'asset_impairment': format_decimal(row.asset_impairment), 'credit_impairment': format_decimal(row.credit_impairment), 'depreciation': format_decimal(row.depreciation), 'intangible_amortization': format_decimal(row.intangible_amortization), 'financial_expenses': format_decimal(row.financial_expenses), 'investment_loss': format_decimal(row.investment_loss), 'inventory_decrease': format_decimal(row.inventory_decrease), 'receivables_decrease': format_decimal(row.operating_receivables_decrease), 'payables_increase': format_decimal(row.operating_payables_increase), }, 'net_operating_cash_flow': format_decimal(row.net_operating_cash_flow_indirect), }, # 关键指标 'key_metrics': { 'free_cash_flow': format_decimal(free_cash_flow), 'cash_flow_to_profit_ratio': format_decimal( row.net_operating_cash_flow / row.net_profit) if row.net_profit and row.net_operating_cash_flow else None, 'capex': format_decimal(row.cash_paid_for_assets), } } data.append(period_data) return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/financial-metrics/', methods=['GET']) def get_financial_metrics(seccode): """获取完整的财务指标数据""" try: limit = request.args.get('limit', 12, type=int) query = text(""" SELECT distinct ENDDATE, STARTDATE, -- 每股指标 F003N as eps, -- 每股收益 F004N as basic_eps, -- 基本每股收益 F005N as diluted_eps, 
-- 稀释每股收益 F006N as deducted_eps, -- 扣除非经常性损益每股收益 F007N as undistributed_profit_ps, -- 每股未分配利润 F008N as bvps, -- 每股净资产 F009N as adjusted_bvps, -- 调整后每股净资产 F010N as capital_reserve_ps, -- 每股资本公积金 F059N as cash_flow_ps, -- 每股现金流量 F060N as operating_cash_flow_ps, -- 每股经营现金流量 -- 盈利能力指标 F011N as operating_profit_margin, -- 营业利润率 F012N as tax_rate, -- 营业税金率 F013N as cost_ratio, -- 营业成本率 F014N as roe, -- 净资产收益率 F066N as roe_deducted, -- 净资产收益率(扣除非经常性损益) F067N as roe_weighted, -- 净资产收益率-加权 F068N as roe_weighted_deducted, -- 净资产收益率-加权(扣除非经常性损益) F015N as investment_return, -- 投资收益率 F016N as roa, -- 总资产报酬率 F017N as net_profit_margin, -- 净利润率 F078N as gross_margin, -- 毛利率 F020N as cost_profit_ratio, -- 成本费用利润率 -- 费用率指标 F018N as admin_expense_ratio, -- 管理费用率 F019N as financial_expense_ratio, -- 财务费用率 F021N as three_expense_ratio, -- 三费比重 F091N as selling_expense, -- 销售费用 F092N as admin_expense, -- 管理费用 F093N as financial_expense, -- 财务费用 F094N as three_expense_total, -- 三费合计 F130N as rd_expense, -- 研发费用 F131N as rd_expense_ratio, -- 研发费用率 F132N as selling_expense_ratio, -- 销售费用率 F133N as four_expense_ratio, -- 四费费用率 -- 运营能力指标 F022N as receivable_turnover, -- 应收账款周转率 F023N as inventory_turnover, -- 存货周转率 F024N as working_capital_turnover, -- 运营资金周转率 F025N as total_asset_turnover, -- 总资产周转率 F026N as fixed_asset_turnover, -- 固定资产周转率 F027N as receivable_days, -- 应收账款周转天数 F028N as inventory_days, -- 存货周转天数 F029N as current_asset_turnover, -- 流动资产周转率 F030N as current_asset_days, -- 流动资产周转天数 F031N as total_asset_days, -- 总资产周转天数 F032N as equity_turnover, -- 股东权益周转率 -- 偿债能力指标 F041N as asset_liability_ratio, -- 资产负债率 F042N as current_ratio, -- 流动比率 F043N as quick_ratio, -- 速动比率 F044N as cash_ratio, -- 现金比率 F045N as interest_coverage, -- 利息保障倍数 F049N as conservative_quick_ratio, -- 保守速动比率 F050N as cash_to_maturity_debt_ratio, -- 现金到期债务比率 F051N as tangible_asset_debt_ratio, -- 有形资产净值债务率 -- 成长能力指标 F052N as revenue_growth, -- 营业收入增长率 F053N as net_profit_growth, -- 净利润增长率 F054N as 
equity_growth, -- 净资产增长率 F055N as fixed_asset_growth, -- 固定资产增长率 F056N as total_asset_growth, -- 总资产增长率 F057N as investment_income_growth, -- 投资收益增长率 F058N as operating_profit_growth, -- 营业利润增长率 F141N as deducted_profit_growth, -- 扣除非经常性损益后的净利润同比变化率 F142N as parent_profit_growth, -- 归属于母公司所有者的净利润同比变化率 F143N as operating_cash_flow_growth, -- 经营活动产生的现金流净额同比变化率 -- 现金流量指标 F061N as operating_cash_to_short_debt, -- 经营净现金比率(短期债务) F062N as operating_cash_to_total_debt, -- 经营净现金比率(全部债务) F063N as operating_cash_to_profit_ratio, -- 经营活动现金净流量与净利润比率 F064N as cash_revenue_ratio, -- 营业收入现金含量 F065N as cash_recovery_rate, -- 全部资产现金回收率 F082N as cash_to_profit_ratio, -- 净利含金量 -- 财务结构指标 F033N as current_asset_ratio, -- 流动资产比率 F034N as cash_ratio_structure, -- 货币资金比率 F036N as inventory_ratio, -- 存货比率 F037N as fixed_asset_ratio, -- 固定资产比率 F038N as liability_structure_ratio, -- 负债结构比 F039N as equity_ratio, -- 产权比率 F040N as net_asset_ratio, -- 净资产比率 F046N as working_capital, -- 营运资金 F047N as non_current_liability_ratio, -- 非流动负债比率 F048N as current_liability_ratio, -- 流动负债比率 -- 非经常性损益 F076N as deducted_net_profit, -- 扣除非经常性损益后的净利润 F077N as non_recurring_items, -- 非经常性损益合计 F083N as non_recurring_ratio, -- 非经常性损益占比 -- 综合指标 F085N as ebit, -- 基本获利能力(EBIT) F086N as receivable_to_asset_ratio, -- 应收账款占比 F087N as inventory_to_asset_ratio -- 存货占比 FROM ea_financialindex WHERE SECCODE = :seccode ORDER BY ENDDATE DESC LIMIT :limit """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'limit': limit}) data = [] for row in result: period_data = { 'period': format_date(row.ENDDATE), 'start_date': format_date(row.STARTDATE), 'report_type': get_report_type(row.ENDDATE), # 每股指标 'per_share_metrics': { 'eps': format_decimal(row.eps), 'basic_eps': format_decimal(row.basic_eps), 'diluted_eps': format_decimal(row.diluted_eps), 'deducted_eps': format_decimal(row.deducted_eps), 'bvps': format_decimal(row.bvps), 'adjusted_bvps': format_decimal(row.adjusted_bvps), 
'undistributed_profit_ps': format_decimal(row.undistributed_profit_ps), 'capital_reserve_ps': format_decimal(row.capital_reserve_ps), 'cash_flow_ps': format_decimal(row.cash_flow_ps), 'operating_cash_flow_ps': format_decimal(row.operating_cash_flow_ps), }, # 盈利能力 'profitability': { 'roe': format_decimal(row.roe), 'roe_deducted': format_decimal(row.roe_deducted), 'roe_weighted': format_decimal(row.roe_weighted), 'roa': format_decimal(row.roa), 'gross_margin': format_decimal(row.gross_margin), 'net_profit_margin': format_decimal(row.net_profit_margin), 'operating_profit_margin': format_decimal(row.operating_profit_margin), 'cost_profit_ratio': format_decimal(row.cost_profit_ratio), 'ebit': format_decimal(row.ebit), }, # 费用率 'expense_ratios': { 'selling_expense_ratio': format_decimal(row.selling_expense_ratio), 'admin_expense_ratio': format_decimal(row.admin_expense_ratio), 'financial_expense_ratio': format_decimal(row.financial_expense_ratio), 'rd_expense_ratio': format_decimal(row.rd_expense_ratio), 'three_expense_ratio': format_decimal(row.three_expense_ratio), 'four_expense_ratio': format_decimal(row.four_expense_ratio), }, # 运营能力 'operational_efficiency': { 'receivable_turnover': format_decimal(row.receivable_turnover), 'receivable_days': format_decimal(row.receivable_days), 'inventory_turnover': format_decimal(row.inventory_turnover), 'inventory_days': format_decimal(row.inventory_days), 'total_asset_turnover': format_decimal(row.total_asset_turnover), 'total_asset_days': format_decimal(row.total_asset_days), 'fixed_asset_turnover': format_decimal(row.fixed_asset_turnover), 'current_asset_turnover': format_decimal(row.current_asset_turnover), 'working_capital_turnover': format_decimal(row.working_capital_turnover), }, # 偿债能力 'solvency': { 'current_ratio': format_decimal(row.current_ratio), 'quick_ratio': format_decimal(row.quick_ratio), 'cash_ratio': format_decimal(row.cash_ratio), 'conservative_quick_ratio': format_decimal(row.conservative_quick_ratio), 
'asset_liability_ratio': format_decimal(row.asset_liability_ratio), 'interest_coverage': format_decimal(row.interest_coverage), 'cash_to_maturity_debt_ratio': format_decimal(row.cash_to_maturity_debt_ratio), 'tangible_asset_debt_ratio': format_decimal(row.tangible_asset_debt_ratio), }, # 成长能力 'growth': { 'revenue_growth': format_decimal(row.revenue_growth), 'net_profit_growth': format_decimal(row.net_profit_growth), 'deducted_profit_growth': format_decimal(row.deducted_profit_growth), 'parent_profit_growth': format_decimal(row.parent_profit_growth), 'equity_growth': format_decimal(row.equity_growth), 'total_asset_growth': format_decimal(row.total_asset_growth), 'fixed_asset_growth': format_decimal(row.fixed_asset_growth), 'operating_profit_growth': format_decimal(row.operating_profit_growth), 'operating_cash_flow_growth': format_decimal(row.operating_cash_flow_growth), }, # 现金流量 'cash_flow_quality': { 'operating_cash_to_profit_ratio': format_decimal(row.operating_cash_to_profit_ratio), 'cash_to_profit_ratio': format_decimal(row.cash_to_profit_ratio), 'cash_revenue_ratio': format_decimal(row.cash_revenue_ratio), 'cash_recovery_rate': format_decimal(row.cash_recovery_rate), 'operating_cash_to_short_debt': format_decimal(row.operating_cash_to_short_debt), 'operating_cash_to_total_debt': format_decimal(row.operating_cash_to_total_debt), }, # 财务结构 'financial_structure': { 'current_asset_ratio': format_decimal(row.current_asset_ratio), 'fixed_asset_ratio': format_decimal(row.fixed_asset_ratio), 'inventory_ratio': format_decimal(row.inventory_ratio), 'receivable_to_asset_ratio': format_decimal(row.receivable_to_asset_ratio), 'current_liability_ratio': format_decimal(row.current_liability_ratio), 'non_current_liability_ratio': format_decimal(row.non_current_liability_ratio), 'equity_ratio': format_decimal(row.equity_ratio), }, # 非经常性损益 'non_recurring': { 'deducted_net_profit': format_decimal(row.deducted_net_profit), 'non_recurring_items': 
format_decimal(row.non_recurring_items), 'non_recurring_ratio': format_decimal(row.non_recurring_ratio), } } data.append(period_data) return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/main-business/', methods=['GET']) def get_main_business(seccode): """获取主营业务构成数据(包括产品和行业分类)""" try: limit = request.args.get('periods', 4, type=int) # 获取最近几期的数据 # 获取最近的报告期 period_query = text(""" SELECT DISTINCT ENDDATE FROM ea_mainproduct WHERE SECCODE = :seccode ORDER BY ENDDATE DESC LIMIT :limit """) with engine.connect() as conn: periods = conn.execute(period_query, {'seccode': seccode, 'limit': limit}).fetchall() # 产品分类数据 product_data = [] for period in periods: query = text(""" SELECT distinct ENDDATE, F002V as category, F003V as content, F005N as revenue, F006N as cost, F007N as profit FROM ea_mainproduct WHERE SECCODE = :seccode AND ENDDATE = :enddate ORDER BY F005N DESC """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'enddate': period[0]}) # Convert result to list to allow multiple iterations rows = list(result) period_products = [] total_revenue = 0 for row in rows: if row.revenue: total_revenue += row.revenue for row in rows: product = { 'category': row.category, 'content': row.content, 'revenue': format_decimal(row.revenue), 'cost': format_decimal(row.cost), 'profit': format_decimal(row.profit), 'profit_margin': format_decimal( (row.profit / row.revenue * 100) if row.revenue and row.profit else None), 'revenue_ratio': format_decimal( (row.revenue / total_revenue * 100) if total_revenue and row.revenue else None) } period_products.append(product) if period_products: product_data.append({ 'period': format_date(period[0]), 'report_type': get_report_type(period[0]), 'total_revenue': format_decimal(total_revenue), 'products': period_products }) # 行业分类数据(从ea_mainind表) industry_data = [] for period in periods: query = text(""" SELECT 
distinct ENDDATE, F002V as business_content, F007N as main_revenue, F008N as main_cost, F009N as main_profit, F010N as gross_margin, F012N as revenue_ratio FROM ea_mainind WHERE SECCODE = :seccode AND ENDDATE = :enddate ORDER BY F007N DESC """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'enddate': period[0]}) # Convert result to list to allow multiple iterations rows = list(result) period_industries = [] for row in rows: industry = { 'content': row.business_content, 'revenue': format_decimal(row.main_revenue), 'cost': format_decimal(row.main_cost), 'profit': format_decimal(row.main_profit), 'gross_margin': format_decimal(row.gross_margin), 'revenue_ratio': format_decimal(row.revenue_ratio) } period_industries.append(industry) if period_industries: industry_data.append({ 'period': format_date(period[0]), 'report_type': get_report_type(period[0]), 'industries': period_industries }) return jsonify({ 'success': True, 'data': { 'product_classification': product_data, 'industry_classification': industry_data } }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/forecast/', methods=['GET']) def get_forecast(seccode): """获取业绩预告和预披露时间""" try: # 获取业绩预告 forecast_query = text(""" SELECT distinct DECLAREDATE, F001D as report_date, F002V as forecast_type_code, F003V as forecast_type, F004V as content, F005V as reason, F006C as latest_flag, F007N as profit_lower, F008N as profit_upper, F009N as change_lower, F010N as change_upper, UPDATE_DATE FROM ea_forecast WHERE SECCODE = :seccode ORDER BY F001D DESC, UPDATE_DATE DESC LIMIT 10 """) with engine.connect() as conn: forecast_result = conn.execute(forecast_query, {'seccode': seccode}) forecast_data = [] for row in forecast_result: forecast = { 'declare_date': format_date(row.DECLAREDATE), 'report_date': format_date(row.report_date), 'report_type': get_report_type(row.report_date), 'forecast_type': row.forecast_type, 
'forecast_type_code': row.forecast_type_code, 'content': row.content, 'reason': row.reason, 'is_latest': row.latest_flag == 'T', 'profit_range': { 'lower': format_decimal(row.profit_lower), 'upper': format_decimal(row.profit_upper), }, 'change_range': { 'lower': format_decimal(row.change_lower), 'upper': format_decimal(row.change_upper), }, 'update_date': format_date(row.UPDATE_DATE) } forecast_data.append(forecast) # 获取预披露时间 pretime_query = text(""" SELECT distinct F001D as report_period, F002D as scheduled_date, F003D as change_date_1, F004D as change_date_2, F005D as change_date_3, F006D as actual_date, F007D as change_date_4, F008D as change_date_5, UPDATE_DATE FROM ea_pretime WHERE SECCODE = :seccode ORDER BY F001D DESC LIMIT 8 """) with engine.connect() as conn: pretime_result = conn.execute(pretime_query, {'seccode': seccode}) pretime_data = [] for row in pretime_result: # 收集所有变更日期 change_dates = [] for date in [row.change_date_1, row.change_date_2, row.change_date_3, row.change_date_4, row.change_date_5]: if date: change_dates.append(format_date(date)) pretime = { 'report_period': format_date(row.report_period), 'report_type': get_report_type(row.report_period), 'scheduled_date': format_date(row.scheduled_date), 'actual_date': format_date(row.actual_date), 'change_dates': change_dates, 'update_date': format_date(row.UPDATE_DATE), 'status': 'completed' if row.actual_date else 'pending' } pretime_data.append(pretime) return jsonify({ 'success': True, 'data': { 'forecasts': forecast_data, 'disclosure_schedule': pretime_data } }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/industry-rank/', methods=['GET']) def get_industry_rank(seccode): """获取行业排名数据""" try: limit = request.args.get('limit', 4, type=int) query = text(""" SELECT distinct F001V as industry_level, F002V as level_description, F003D as report_date, INDNAME as industry_name, -- 每股收益 F004N as eps, F005N as eps_industry_avg, F006N as 
eps_rank, -- 扣除后每股收益 F007N as deducted_eps, F008N as deducted_eps_industry_avg, F009N as deducted_eps_rank, -- 每股净资产 F010N as bvps, F011N as bvps_industry_avg, F012N as bvps_rank, -- 净资产收益率 F013N as roe, F014N as roe_industry_avg, F015N as roe_rank, -- 每股未分配利润 F016N as undistributed_profit_ps, F017N as undistributed_profit_ps_industry_avg, F018N as undistributed_profit_ps_rank, -- 每股经营现金流量 F019N as operating_cash_flow_ps, F020N as operating_cash_flow_ps_industry_avg, F021N as operating_cash_flow_ps_rank, -- 营业收入增长率 F022N as revenue_growth, F023N as revenue_growth_industry_avg, F024N as revenue_growth_rank, -- 净利润增长率 F025N as profit_growth, F026N as profit_growth_industry_avg, F027N as profit_growth_rank, -- 营业利润率 F028N as operating_margin, F029N as operating_margin_industry_avg, F030N as operating_margin_rank, -- 资产负债率 F031N as debt_ratio, F032N as debt_ratio_industry_avg, F033N as debt_ratio_rank, -- 应收账款周转率 F034N as receivable_turnover, F035N as receivable_turnover_industry_avg, F036N as receivable_turnover_rank, UPDATE_DATE FROM ea_finindexrank WHERE SECCODE = :seccode ORDER BY F003D DESC, F001V ASC LIMIT :limit_total """) # 获取多个报告期的数据 with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'limit_total': limit * 4}) # 按报告期和行业级别组织数据 data_by_period = {} for row in result: period = format_date(row.report_date) if period not in data_by_period: data_by_period[period] = [] rank_data = { 'industry_level': row.industry_level, 'level_description': row.level_description, 'industry_name': row.industry_name, 'metrics': { 'eps': { 'value': format_decimal(row.eps), 'industry_avg': format_decimal(row.eps_industry_avg), 'rank': int(row.eps_rank) if row.eps_rank else None }, 'deducted_eps': { 'value': format_decimal(row.deducted_eps), 'industry_avg': format_decimal(row.deducted_eps_industry_avg), 'rank': int(row.deducted_eps_rank) if row.deducted_eps_rank else None }, 'bvps': { 'value': format_decimal(row.bvps), 'industry_avg': 
format_decimal(row.bvps_industry_avg), 'rank': int(row.bvps_rank) if row.bvps_rank else None }, 'roe': { 'value': format_decimal(row.roe), 'industry_avg': format_decimal(row.roe_industry_avg), 'rank': int(row.roe_rank) if row.roe_rank else None }, 'operating_cash_flow_ps': { 'value': format_decimal(row.operating_cash_flow_ps), 'industry_avg': format_decimal(row.operating_cash_flow_ps_industry_avg), 'rank': int(row.operating_cash_flow_ps_rank) if row.operating_cash_flow_ps_rank else None }, 'revenue_growth': { 'value': format_decimal(row.revenue_growth), 'industry_avg': format_decimal(row.revenue_growth_industry_avg), 'rank': int(row.revenue_growth_rank) if row.revenue_growth_rank else None }, 'profit_growth': { 'value': format_decimal(row.profit_growth), 'industry_avg': format_decimal(row.profit_growth_industry_avg), 'rank': int(row.profit_growth_rank) if row.profit_growth_rank else None }, 'operating_margin': { 'value': format_decimal(row.operating_margin), 'industry_avg': format_decimal(row.operating_margin_industry_avg), 'rank': int(row.operating_margin_rank) if row.operating_margin_rank else None }, 'debt_ratio': { 'value': format_decimal(row.debt_ratio), 'industry_avg': format_decimal(row.debt_ratio_industry_avg), 'rank': int(row.debt_ratio_rank) if row.debt_ratio_rank else None }, 'receivable_turnover': { 'value': format_decimal(row.receivable_turnover), 'industry_avg': format_decimal(row.receivable_turnover_industry_avg), 'rank': int(row.receivable_turnover_rank) if row.receivable_turnover_rank else None } } } data_by_period[period].append(rank_data) # 转换为列表格式 data = [] for period, ranks in data_by_period.items(): data.append({ 'period': period, 'report_type': get_report_type(period), 'rankings': ranks }) return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/financial/comparison/', methods=['GET']) def get_period_comparison(seccode): """获取不同报告期的对比数据""" try: periods 
= request.args.get('periods', 8, type=int) # 获取多期财务数据进行对比 query = text(""" SELECT distinct fi.ENDDATE, fi.F089N as revenue, fi.F101N as net_profit, fi.F102N as parent_net_profit, fi.F078N as gross_margin, fi.F017N as net_margin, fi.F014N as roe, fi.F016N as roa, fi.F052N as revenue_growth, fi.F053N as profit_growth, fi.F003N as eps, fi.F060N as operating_cash_flow_ps, fi.F042N as current_ratio, fi.F041N as debt_ratio, fi.F105N as operating_cash_flow, fi.F118N as total_assets, fi.F121N as total_liabilities, fi.F128N as total_equity FROM ea_financialindex fi WHERE fi.SECCODE = :seccode ORDER BY fi.ENDDATE DESC LIMIT :periods """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'periods': periods}) data = [] for row in result: period_data = { 'period': format_date(row.ENDDATE), 'report_type': get_report_type(row.ENDDATE), 'performance': { 'revenue': format_decimal(row.revenue), 'net_profit': format_decimal(row.net_profit), 'parent_net_profit': format_decimal(row.parent_net_profit), 'operating_cash_flow': format_decimal(row.operating_cash_flow), }, 'profitability': { 'gross_margin': format_decimal(row.gross_margin), 'net_margin': format_decimal(row.net_margin), 'roe': format_decimal(row.roe), 'roa': format_decimal(row.roa), }, 'growth': { 'revenue_growth': format_decimal(row.revenue_growth), 'profit_growth': format_decimal(row.profit_growth), }, 'per_share': { 'eps': format_decimal(row.eps), 'operating_cash_flow_ps': format_decimal(row.operating_cash_flow_ps), }, 'financial_health': { 'current_ratio': format_decimal(row.current_ratio), 'debt_ratio': format_decimal(row.debt_ratio), 'total_assets': format_decimal(row.total_assets), 'total_liabilities': format_decimal(row.total_liabilities), 'total_equity': format_decimal(row.total_equity), } } data.append(period_data) # 计算同比和环比变化 for i in range(len(data)): if i > 0: # 环比 data[i]['qoq_changes'] = { 'revenue': calculate_change(data[i]['performance']['revenue'], data[i - 
1]['performance']['revenue']), 'net_profit': calculate_change(data[i]['performance']['net_profit'], data[i - 1]['performance']['net_profit']), } # 同比(找到去年同期) current_period = data[i]['period'] yoy_period = get_yoy_period(current_period) for j in range(len(data)): if data[j]['period'] == yoy_period: data[i]['yoy_changes'] = { 'revenue': calculate_change(data[i]['performance']['revenue'], data[j]['performance']['revenue']), 'net_profit': calculate_change(data[i]['performance']['net_profit'], data[j]['performance']['net_profit']), } break return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 # 辅助函数 def calculate_change(current, previous): """计算变化率""" if previous and current: return format_decimal((current - previous) / abs(previous) * 100) return None def get_yoy_period(date_str): """获取去年同期""" if not date_str: return None try: date = datetime.strptime(date_str, '%Y-%m-%d') yoy_date = date.replace(year=date.year - 1) return yoy_date.strftime('%Y-%m-%d') except: return None @app.route('/api/market/trade/', methods=['GET']) def get_trade_data(seccode): """获取股票交易数据(日K线)""" try: days = request.args.get('days', 60, type=int) end_date = request.args.get('end_date', datetime.now().strftime('%Y-%m-%d')) query = text(""" SELECT TRADEDATE, SECNAME, F002N as pre_close, F003N as open, F004N as volume, F005N as high, F006N as low, F007N as close, F008N as trades_count, F009N as change_amount, F010N as change_percent, F011N as amount, F012N as turnover_rate, F013N as amplitude, F020N as total_shares, F021N as float_shares, F026N as pe_ratio FROM ea_trade WHERE SECCODE = :seccode AND TRADEDATE <= :end_date ORDER BY TRADEDATE DESC LIMIT :days """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'end_date': end_date, 'days': days}) data = [] for row in result: data.append({ 'date': format_date(row.TRADEDATE), 'stock_name': row.SECNAME, 'open': format_decimal(row.open), 'high': 
format_decimal(row.high), 'low': format_decimal(row.low), 'close': format_decimal(row.close), 'pre_close': format_decimal(row.pre_close), 'volume': format_decimal(row.volume), 'amount': format_decimal(row.amount), 'change_amount': format_decimal(row.change_amount), 'change_percent': format_decimal(row.change_percent), 'turnover_rate': format_decimal(row.turnover_rate), 'amplitude': format_decimal(row.amplitude), 'trades_count': format_decimal(row.trades_count), 'pe_ratio': format_decimal(row.pe_ratio), 'total_shares': format_decimal(row.total_shares), 'float_shares': format_decimal(row.float_shares), }) # 倒序,让最早的日期在前 data.reverse() # 计算统计数据 if data: prices = [d['close'] for d in data if d['close']] stats = { 'highest': max(prices) if prices else None, 'lowest': min(prices) if prices else None, 'average': sum(prices) / len(prices) if prices else None, 'latest_price': data[-1]['close'] if data else None, 'total_volume': sum([d['volume'] for d in data if d['volume']]) if data else None, 'total_amount': sum([d['amount'] for d in data if d['amount']]) if data else None, } else: stats = {} return jsonify({ 'success': True, 'data': data, 'stats': stats }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/funding/', methods=['GET']) def get_funding_data(seccode): """获取融资融券数据""" try: days = request.args.get('days', 30, type=int) query = text(""" SELECT TRADEDATE, SECNAME, F001N as financing_balance, F002N as financing_buy, F003N as financing_repay, F004N as securities_balance, F006N as securities_sell, F007N as securities_repay, F008N as securities_balance_amount, F009N as total_balance FROM ea_funding WHERE SECCODE = :seccode ORDER BY TRADEDATE DESC LIMIT :days """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'days': days}) data = [] for row in result: data.append({ 'date': format_date(row.TRADEDATE), 'stock_name': row.SECNAME, 'financing': { 'balance': 
format_decimal(row.financing_balance), 'buy': format_decimal(row.financing_buy), 'repay': format_decimal(row.financing_repay), 'net': format_decimal( row.financing_buy - row.financing_repay) if row.financing_buy and row.financing_repay else None }, 'securities': { 'balance': format_decimal(row.securities_balance), 'sell': format_decimal(row.securities_sell), 'repay': format_decimal(row.securities_repay), 'balance_amount': format_decimal(row.securities_balance_amount) }, 'total_balance': format_decimal(row.total_balance) }) data.reverse() return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/bigdeal/', methods=['GET']) def get_bigdeal_data(seccode): """获取大宗交易数据""" try: days = request.args.get('days', 30, type=int) query = text(""" SELECT TRADEDATE, SECNAME, F001V as exchange, F002V as buyer_dept, F003V as seller_dept, F004N as price, F005N as volume, F006N as amount, F007N as seq_no FROM ea_bigdeal WHERE SECCODE = :seccode ORDER BY TRADEDATE DESC, F007N LIMIT :days """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode, 'days': days}) data = [] for row in result: data.append({ 'date': format_date(row.TRADEDATE), 'stock_name': row.SECNAME, 'exchange': row.exchange, 'buyer_dept': row.buyer_dept, 'seller_dept': row.seller_dept, 'price': format_decimal(row.price), 'volume': format_decimal(row.volume), 'amount': format_decimal(row.amount), 'seq_no': int(row.seq_no) if row.seq_no else None }) # 按日期分组统计 daily_stats = {} for item in data: date = item['date'] if date not in daily_stats: daily_stats[date] = { 'date': date, 'count': 0, 'total_volume': 0, 'total_amount': 0, 'avg_price': 0, 'deals': [] } daily_stats[date]['count'] += 1 daily_stats[date]['total_volume'] += item['volume'] or 0 daily_stats[date]['total_amount'] += item['amount'] or 0 daily_stats[date]['deals'].append(item) # 计算平均价格 for date in daily_stats: if 
daily_stats[date]['total_volume'] > 0: daily_stats[date]['avg_price'] = daily_stats[date]['total_amount'] / daily_stats[date]['total_volume'] return jsonify({ 'success': True, 'data': data, 'daily_stats': list(daily_stats.values()) }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/unusual/', methods=['GET']) def get_unusual_data(seccode): """获取龙虎榜数据""" try: days = request.args.get('days', 30, type=int) query = text(""" SELECT TRADEDATE, SECNAME, F001V as info_type_code, F002V as info_type, F003C as trade_type, F004N as rank_no, F005V as dept_name, F006N as buy_amount, F007N as sell_amount, F008N as net_amount FROM ea_unusual WHERE SECCODE = :seccode ORDER BY TRADEDATE DESC, F004N LIMIT 100 """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode}) data = [] for row in result: data.append({ 'date': format_date(row.TRADEDATE), 'stock_name': row.SECNAME, 'info_type': row.info_type, 'info_type_code': row.info_type_code, 'trade_type': 'buy' if row.trade_type == 'B' else 'sell' if row.trade_type == 'S' else 'unknown', 'rank': int(row.rank_no) if row.rank_no else None, 'dept_name': row.dept_name, 'buy_amount': format_decimal(row.buy_amount), 'sell_amount': format_decimal(row.sell_amount), 'net_amount': format_decimal(row.net_amount) }) # 按日期分组 grouped_data = {} for item in data: date = item['date'] if date not in grouped_data: grouped_data[date] = { 'date': date, 'info_types': set(), 'buyers': [], 'sellers': [], 'total_buy': 0, 'total_sell': 0, 'net_amount': 0 } grouped_data[date]['info_types'].add(item['info_type']) if item['trade_type'] == 'buy': grouped_data[date]['buyers'].append(item) grouped_data[date]['total_buy'] += item['buy_amount'] or 0 elif item['trade_type'] == 'sell': grouped_data[date]['sellers'].append(item) grouped_data[date]['total_sell'] += item['sell_amount'] or 0 grouped_data[date]['net_amount'] = grouped_data[date]['total_buy'] - grouped_data[date]['total_sell'] # 
转换set为list for date in grouped_data: grouped_data[date]['info_types'] = list(grouped_data[date]['info_types']) return jsonify({ 'success': True, 'data': data, 'grouped_data': list(grouped_data.values()) }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/pledge/', methods=['GET']) def get_pledge_data(seccode): """获取股权质押数据""" try: query = text(""" SELECT ENDDATE, STARTDATE, SECNAME, F001N as unrestricted_pledge, F002N as restricted_pledge, F003N as total_shares_a, F004N as pledge_count, F005N as pledge_ratio FROM ea_pledgeratio WHERE SECCODE = :seccode ORDER BY ENDDATE DESC LIMIT 12 """) with engine.connect() as conn: result = conn.execute(query, {'seccode': seccode}) data = [] for row in result: total_pledge = (row.unrestricted_pledge or 0) + (row.restricted_pledge or 0) data.append({ 'end_date': format_date(row.ENDDATE), 'start_date': format_date(row.STARTDATE), 'stock_name': row.SECNAME, 'unrestricted_pledge': format_decimal(row.unrestricted_pledge), 'restricted_pledge': format_decimal(row.restricted_pledge), 'total_pledge': format_decimal(total_pledge), 'total_shares': format_decimal(row.total_shares_a), 'pledge_count': int(row.pledge_count) if row.pledge_count else None, 'pledge_ratio': format_decimal(row.pledge_ratio) }) return jsonify({ 'success': True, 'data': data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/summary/', methods=['GET']) def get_market_summary(seccode): """获取市场数据汇总""" try: # 获取最新交易数据 trade_query = text(""" SELECT * FROM ea_trade WHERE SECCODE = :seccode ORDER BY TRADEDATE DESC LIMIT 1 """) # 获取最新融资融券数据 funding_query = text(""" SELECT * FROM ea_funding WHERE SECCODE = :seccode ORDER BY TRADEDATE DESC LIMIT 1 """) # 获取最新质押数据 pledge_query = text(""" SELECT * FROM ea_pledgeratio WHERE SECCODE = :seccode ORDER BY ENDDATE DESC LIMIT 1 """) with engine.connect() as conn: trade_result = conn.execute(trade_query, {'seccode': 
seccode}).fetchone() with engine.connect() as conn: funding_result = conn.execute(funding_query, {'seccode': seccode}).fetchone() with engine.connect() as conn: pledge_result = conn.execute(pledge_query, {'seccode': seccode}).fetchone() summary = { 'stock_code': seccode, 'stock_name': trade_result.SECNAME if trade_result else None, 'latest_trade': { 'date': format_date(trade_result.TRADEDATE) if trade_result else None, 'close': format_decimal(trade_result.F007N) if trade_result else None, 'change_percent': format_decimal(trade_result.F010N) if trade_result else None, 'volume': format_decimal(trade_result.F004N) if trade_result else None, 'amount': format_decimal(trade_result.F011N) if trade_result else None, 'pe_ratio': format_decimal(trade_result.F026N) if trade_result else None, 'turnover_rate': format_decimal(trade_result.F012N) if trade_result else None, } if trade_result else None, 'latest_funding': { 'date': format_date(funding_result.TRADEDATE) if funding_result else None, 'financing_balance': format_decimal(funding_result.F001N) if funding_result else None, 'securities_balance': format_decimal(funding_result.F004N) if funding_result else None, 'total_balance': format_decimal(funding_result.F009N) if funding_result else None, } if funding_result else None, 'latest_pledge': { 'date': format_date(pledge_result.ENDDATE) if pledge_result else None, 'pledge_ratio': format_decimal(pledge_result.F005N) if pledge_result else None, 'pledge_count': int(pledge_result.F004N) if pledge_result and pledge_result.F004N else None, } if pledge_result else None } return jsonify({ 'success': True, 'data': summary }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/stocks/search', methods=['GET']) def search_stocks(): """搜索股票(支持股票代码、股票简称、拼音首字母)""" try: query = request.args.get('q', '').strip() limit = request.args.get('limit', 20, type=int) if not query: return jsonify({ 'success': False, 'error': '请输入搜索关键词' }), 400 with 
engine.connect() as conn: test_sql = text(""" SELECT SECCODE, SECNAME, F001V, F003V, F010V, F011V FROM ea_stocklist WHERE SECCODE = '300750' OR F001V LIKE '%ndsd%' LIMIT 5 """) test_result = conn.execute(test_sql).fetchall() # 构建搜索SQL - 支持股票代码、股票简称、拼音简称搜索 search_sql = text(""" SELECT DISTINCT SECCODE as stock_code, SECNAME as stock_name, F001V as pinyin_abbr, F003V as security_type, F005V as exchange, F011V as listing_status FROM ea_stocklist WHERE ( UPPER(SECCODE) LIKE UPPER(:query_pattern) OR UPPER(SECNAME) LIKE UPPER(:query_pattern) OR UPPER(F001V) LIKE UPPER(:query_pattern) ) -- 基本过滤条件:只搜索正常的A股和B股 AND (F011V = '正常上市' OR F010V = '013001') -- 正常上市状态 AND F003V IN ('A股', 'B股') -- 只搜索A股和B股 ORDER BY CASE WHEN UPPER(SECCODE) = UPPER(:exact_query) THEN 1 WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2 WHEN UPPER(F001V) = UPPER(:exact_query) THEN 3 WHEN UPPER(SECCODE) LIKE UPPER(:prefix_pattern) THEN 4 WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 5 WHEN UPPER(F001V) LIKE UPPER(:prefix_pattern) THEN 6 ELSE 7 END, SECCODE LIMIT :limit """) result = conn.execute(search_sql, { 'query_pattern': f'%{query}%', 'exact_query': query, 'prefix_pattern': f'{query}%', 'limit': limit }).fetchall() stocks = [] for row in result: # 获取当前价格 current_price, _ = get_latest_price_from_clickhouse(row.stock_code) stocks.append({ 'stock_code': row.stock_code, 'stock_name': row.stock_name, 'current_price': current_price or 0, # 添加当前价格 'pinyin_abbr': row.pinyin_abbr, 'security_type': row.security_type, 'exchange': row.exchange, 'listing_status': row.listing_status }) return jsonify({ 'success': True, 'data': stocks, 'count': len(stocks) }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/heatmap', methods=['GET']) def get_market_heatmap(): """获取市场热力图数据(基于市值和涨跌幅)""" try: # 获取交易日期参数 trade_date = request.args.get('date') # 前端显示用的limit,但统计数据会基于全部股票 display_limit = request.args.get('limit', 500, type=int) with engine.connect() as conn: 
# 如果没有指定日期,获取最新交易日 if not trade_date: latest_date_result = conn.execute(text(""" SELECT MAX(TRADEDATE) as latest_date FROM ea_trade """)).fetchone() trade_date = latest_date_result.latest_date if latest_date_result else None if not trade_date: return jsonify({ 'success': False, 'error': '无法获取交易数据' }), 404 # 获取全部股票数据用于统计 all_stocks_sql = text(""" SELECT t.SECCODE as stock_code, t.SECNAME as stock_name, t.F010N as change_percent, -- 涨跌幅 t.F007N as close_price, -- 收盘价 t.F021N * t.F007N / 100000000 as market_cap, -- 市值(亿元) t.F011N / 100000000 as amount, -- 成交额(亿元) t.F012N as turnover_rate, -- 换手率 b.F034V as industry, -- 申万行业分类一级名称 b.F026V as province -- 所属省份 FROM ea_trade t LEFT JOIN ea_baseinfo b ON t.SECCODE = b.SECCODE WHERE t.TRADEDATE = :trade_date AND t.F010N IS NOT NULL -- 仅统计当日有涨跌幅数据的股票 ORDER BY market_cap DESC """) all_result = conn.execute(all_stocks_sql, { 'trade_date': trade_date }).fetchall() # 计算统计数据(基于全部股票) total_market_cap = 0 total_amount = 0 rising_count = 0 falling_count = 0 flat_count = 0 all_data = [] for row in all_result: # F010N 已在 SQL 中确保非空 change_percent = float(row.change_percent) market_cap = float(row.market_cap) if row.market_cap else 0 amount = float(row.amount) if row.amount else 0 total_market_cap += market_cap total_amount += amount if change_percent > 0: rising_count += 1 elif change_percent < 0: falling_count += 1 else: flat_count += 1 all_data.append({ 'stock_code': row.stock_code, 'stock_name': row.stock_name, 'change_percent': change_percent, 'close_price': float(row.close_price) if row.close_price else 0, 'market_cap': market_cap, 'amount': amount, 'turnover_rate': float(row.turnover_rate) if row.turnover_rate else 0, 'industry': row.industry, 'province': row.province }) # 只返回前display_limit条用于热力图显示 heatmap_data = all_data[:display_limit] return jsonify({ 'success': True, 'data': heatmap_data, 'trade_date': trade_date.strftime('%Y-%m-%d') if hasattr(trade_date, 'strftime') else str(trade_date), 'count': len(all_data), # 全部股票数量 
'display_count': len(heatmap_data), # 显示的股票数量 'statistics': { 'total_market_cap': round(total_market_cap, 2), # 总市值(亿元) 'total_amount': round(total_amount, 2), # 总成交额(亿元) 'rising_count': rising_count, # 上涨家数 'falling_count': falling_count, # 下跌家数 'flat_count': flat_count # 平盘家数 } }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/statistics', methods=['GET']) def get_market_statistics(): """获取市场统计数据(从ea_blocktrading表)""" try: # 获取交易日期参数 trade_date = request.args.get('date') with engine.connect() as conn: # 如果没有指定日期,获取最新交易日 if not trade_date: latest_date_result = conn.execute(text(""" SELECT MAX(TRADEDATE) as latest_date FROM ea_blocktrading """)).fetchone() trade_date = latest_date_result.latest_date if latest_date_result else None if not trade_date: return jsonify({ 'success': False, 'error': '无法获取统计数据' }), 404 # 获取沪深两市的统计数据 stats_sql = text(""" SELECT EXCHANGECODE, EXCHANGENAME, F001V as indicator_code, F002V as indicator_name, F003N as indicator_value, F004V as unit, TRADEDATE FROM ea_blocktrading WHERE TRADEDATE = :trade_date AND EXCHANGECODE IN ('012001', '012002') -- 只获取上交所和深交所的数据 AND F001V IN ( '250006', '250014', -- 深交所股票总市值、上交所市价总值 '250007', '250015', -- 深交所股票流通市值、上交所流通市值 '250008', -- 深交所股票成交金额 '250010', '250019', -- 深交所股票平均市盈率、上交所平均市盈率 '250050', '250001' -- 上交所上市公司家数、深交所上市公司数 ) """) result = conn.execute(stats_sql, { 'trade_date': trade_date }).fetchall() # 整理数据 statistics = {} for row in result: key = f"{row.EXCHANGECODE}_{row.indicator_code}" statistics[key] = { 'exchange_code': row.EXCHANGECODE, 'exchange_name': row.EXCHANGENAME, 'indicator_code': row.indicator_code, 'indicator_name': row.indicator_name, 'value': float(row.indicator_value) if row.indicator_value else 0, 'unit': row.unit } # 汇总数据 summary = { 'total_market_cap': 0, # 总市值 'total_float_cap': 0, # 流通市值 'total_amount': 0, # 成交额 'sh_pe_ratio': 0, # 上交所市盈率 'sz_pe_ratio': 0, # 深交所市盈率 'sh_companies': 0, # 上交所上市公司数 'sz_companies': 0 # 
深交所上市公司数 } # 计算汇总值 if '012001_250014' in statistics: # 上交所市价总值 summary['total_market_cap'] += statistics['012001_250014']['value'] if '012002_250006' in statistics: # 深交所股票总市值 summary['total_market_cap'] += statistics['012002_250006']['value'] if '012001_250015' in statistics: # 上交所流通市值 summary['total_float_cap'] += statistics['012001_250015']['value'] if '012002_250007' in statistics: # 深交所股票流通市值 summary['total_float_cap'] += statistics['012002_250007']['value'] # 成交额需要获取上交所的数据 # 获取上交所成交金额 sh_amount_result = conn.execute(text(""" SELECT F003N FROM ea_blocktrading WHERE TRADEDATE = :trade_date AND EXCHANGECODE = '012001' AND F002V LIKE '%成交金额%' LIMIT 1 """), {'trade_date': trade_date}).fetchone() sh_amount = float(sh_amount_result.F003N) if sh_amount_result and sh_amount_result.F003N else 0 sz_amount = statistics['012002_250008']['value'] if '012002_250008' in statistics else 0 summary['total_amount'] = sh_amount + sz_amount if '012001_250019' in statistics: # 上交所平均市盈率 summary['sh_pe_ratio'] = statistics['012001_250019']['value'] if '012002_250010' in statistics: # 深交所股票平均市盈率 summary['sz_pe_ratio'] = statistics['012002_250010']['value'] if '012001_250050' in statistics: # 上交所上市公司家数 summary['sh_companies'] = int(statistics['012001_250050']['value']) if '012002_250001' in statistics: # 深交所上市公司数 summary['sz_companies'] = int(statistics['012002_250001']['value']) # 获取可用的交易日期列表 available_dates_result = conn.execute(text(""" SELECT DISTINCT TRADEDATE FROM ea_blocktrading WHERE EXCHANGECODE IN ('012001', '012002') ORDER BY TRADEDATE DESC LIMIT 30 """)).fetchall() available_dates = [str(row.TRADEDATE) for row in available_dates_result] return jsonify({ 'success': True, 'trade_date': str(trade_date), 'summary': summary, 'details': list(statistics.values()), 'available_dates': available_dates }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/concepts/daily-top', methods=['GET']) def get_daily_top_concepts(): 
"""获取每日涨幅靠前的概念板块""" try: # 获取交易日期参数 trade_date = request.args.get('date') limit = request.args.get('limit', 6, type=int) # 构建概念中心API的URL concept_api_url = 'http://222.128.1.157:16801/search' # 准备请求数据 request_data = { 'query': '', 'size': limit, 'page': 1, 'sort_by': 'change_pct' } if trade_date: request_data['trade_date'] = trade_date # 调用概念中心API response = requests.post(concept_api_url, json=request_data, timeout=10) if response.status_code == 200: data = response.json() top_concepts = [] for concept in data.get('results', []): top_concepts.append({ 'concept_id': concept.get('concept_id'), 'concept_name': concept.get('concept'), 'description': concept.get('description'), 'change_percent': concept.get('price_info', {}).get('avg_change_pct', 0), 'stock_count': concept.get('stock_count', 0), 'stocks': concept.get('stocks', [])[:5] # 只返回前5只股票 }) return jsonify({ 'success': True, 'data': top_concepts, 'trade_date': data.get('price_date'), 'count': len(top_concepts) }) else: return jsonify({ 'success': False, 'error': '获取概念数据失败' }), 500 except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/market/rise-analysis/', methods=['GET']) def get_rise_analysis(seccode): """获取股票涨幅分析数据""" try: # 获取日期范围参数 start_date = request.args.get('start_date') end_date = request.args.get('end_date') query = text(""" SELECT stock_code, stock_name, trade_date, rise_rate, close_price, volume, amount, main_business, rise_reason_brief, rise_reason_detail, news_summary, announcements, guba_sentiment, analysis_time FROM stock_rise_analysis WHERE stock_code = :stock_code """) params = {'stock_code': seccode} # 添加日期筛选 if start_date and end_date: query = text(""" SELECT stock_code, stock_name, trade_date, rise_rate, close_price, volume, amount, main_business, rise_reason_brief, rise_reason_detail, news_summary, announcements, guba_sentiment, analysis_time FROM stock_rise_analysis WHERE stock_code = :stock_code AND trade_date BETWEEN :start_date AND :end_date 
ORDER BY trade_date DESC """) params['start_date'] = start_date params['end_date'] = end_date else: query = text(""" SELECT stock_code, stock_name, trade_date, rise_rate, close_price, volume, amount, main_business, rise_reason_brief, rise_reason_detail, news_summary, announcements, guba_sentiment, analysis_time FROM stock_rise_analysis WHERE stock_code = :stock_code ORDER BY trade_date DESC LIMIT 100 """) with engine.connect() as conn: result = conn.execute(query, params).fetchall() # 格式化数据 rise_analysis_data = [] for row in result: rise_analysis_data.append({ 'stock_code': row.stock_code, 'stock_name': row.stock_name, 'trade_date': format_date(row.trade_date), 'rise_rate': format_decimal(row.rise_rate), 'close_price': format_decimal(row.close_price), 'volume': format_decimal(row.volume), 'amount': format_decimal(row.amount), 'main_business': row.main_business, 'rise_reason_brief': row.rise_reason_brief, 'rise_reason_detail': row.rise_reason_detail, 'news_summary': row.news_summary, 'announcements': row.announcements, 'guba_sentiment': row.guba_sentiment, 'analysis_time': row.analysis_time.strftime('%Y-%m-%d %H:%M:%S') if row.analysis_time else None }) return jsonify({ 'success': True, 'data': rise_analysis_data, 'count': len(rise_analysis_data) }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 # ============================================ # 公司分析相关接口 # ============================================ @app.route('/api/company/comprehensive-analysis/', methods=['GET']) def get_comprehensive_analysis(company_code): """获取公司综合分析数据""" try: # 获取公司定性分析 qualitative_query = text(""" SELECT one_line_intro, investment_highlights, business_model_desc, company_story, positioning_analysis, unique_value_proposition, business_logic_explanation, revenue_driver_analysis, customer_value_analysis, strategy_description, strategic_initiatives, created_at, updated_at FROM company_analysis WHERE company_code = :company_code """) with engine.connect() as 
conn: qualitative_result = conn.execute(qualitative_query, {'company_code': company_code}).fetchone() # 获取业务板块分析 segments_query = text(""" SELECT segment_name, segment_description, competitive_position, future_potential, key_customers, value_chain_position, created_at, updated_at FROM business_segment_analysis WHERE company_code = :company_code ORDER BY created_at DESC """) with engine.connect() as conn: segments_result = conn.execute(segments_query, {'company_code': company_code}).fetchall() # 获取竞争地位数据 - 最新一期 competitive_query = text(""" SELECT market_position_score, technology_score, brand_score, operation_score, finance_score, innovation_score, risk_score, growth_score, industry_avg_comparison, main_competitors, competitive_advantages, competitive_disadvantages, industry_rank, total_companies, report_period, updated_at FROM company_competitive_position WHERE company_code = :company_code ORDER BY report_period DESC LIMIT 1 """) with engine.connect() as conn: competitive_result = conn.execute(competitive_query, {'company_code': company_code}).fetchone() # 获取业务结构数据 - 最新一期 business_structure_query = text(""" SELECT business_name, parent_business, business_level, revenue, revenue_unit, revenue_ratio, profit, profit_unit, profit_ratio, revenue_growth, profit_growth, gross_margin, customer_count, market_share, report_period FROM company_business_structure WHERE company_code = :company_code AND report_period = (SELECT MAX(report_period) FROM company_business_structure WHERE company_code = :company_code) ORDER BY revenue_ratio DESC """) with engine.connect() as conn: business_structure_result = conn.execute(business_structure_query, {'company_code': company_code}).fetchall() # 构建返回数据 response_data = { 'company_code': company_code, 'qualitative_analysis': None, 'business_segments': [], 'competitive_position': None, 'business_structure': [] } # 处理定性分析数据 if qualitative_result: response_data['qualitative_analysis'] = { 'core_positioning': { 'one_line_intro': 
qualitative_result.one_line_intro, 'investment_highlights': qualitative_result.investment_highlights, 'business_model_desc': qualitative_result.business_model_desc, 'company_story': qualitative_result.company_story }, 'business_understanding': { 'positioning_analysis': qualitative_result.positioning_analysis, 'unique_value_proposition': qualitative_result.unique_value_proposition, 'business_logic_explanation': qualitative_result.business_logic_explanation, 'revenue_driver_analysis': qualitative_result.revenue_driver_analysis, 'customer_value_analysis': qualitative_result.customer_value_analysis }, 'strategy': { 'strategy_description': qualitative_result.strategy_description, 'strategic_initiatives': qualitative_result.strategic_initiatives }, 'updated_at': qualitative_result.updated_at.strftime( '%Y-%m-%d %H:%M:%S') if qualitative_result.updated_at else None } # 处理业务板块数据 for segment in segments_result: response_data['business_segments'].append({ 'segment_name': segment.segment_name, 'segment_description': segment.segment_description, 'competitive_position': segment.competitive_position, 'future_potential': segment.future_potential, 'key_customers': segment.key_customers, 'value_chain_position': segment.value_chain_position, 'updated_at': segment.updated_at.strftime('%Y-%m-%d %H:%M:%S') if segment.updated_at else None }) # 处理竞争地位数据 if competitive_result: response_data['competitive_position'] = { 'scores': { 'market_position': competitive_result.market_position_score, 'technology': competitive_result.technology_score, 'brand': competitive_result.brand_score, 'operation': competitive_result.operation_score, 'finance': competitive_result.finance_score, 'innovation': competitive_result.innovation_score, 'risk': competitive_result.risk_score, 'growth': competitive_result.growth_score }, 'analysis': { 'industry_avg_comparison': competitive_result.industry_avg_comparison, 'main_competitors': competitive_result.main_competitors, 'competitive_advantages': 
competitive_result.competitive_advantages, 'competitive_disadvantages': competitive_result.competitive_disadvantages }, 'ranking': { 'industry_rank': competitive_result.industry_rank, 'total_companies': competitive_result.total_companies, 'rank_percentage': round( (competitive_result.industry_rank / competitive_result.total_companies * 100), 2) if competitive_result.industry_rank and competitive_result.total_companies else None }, 'report_period': competitive_result.report_period, 'updated_at': competitive_result.updated_at.strftime( '%Y-%m-%d %H:%M:%S') if competitive_result.updated_at else None } # 处理业务结构数据 for business in business_structure_result: response_data['business_structure'].append({ 'business_name': business.business_name, 'parent_business': business.parent_business, 'business_level': business.business_level, 'revenue': format_decimal(business.revenue), 'revenue_unit': business.revenue_unit, 'profit': format_decimal(business.profit), 'profit_unit': business.profit_unit, 'financial_metrics': { 'revenue': format_decimal(business.revenue), 'revenue_ratio': format_decimal(business.revenue_ratio), 'profit': format_decimal(business.profit), 'profit_ratio': format_decimal(business.profit_ratio), 'gross_margin': format_decimal(business.gross_margin) }, 'growth_metrics': { 'revenue_growth': format_decimal(business.revenue_growth), 'profit_growth': format_decimal(business.profit_growth) }, 'market_metrics': { 'customer_count': business.customer_count, 'market_share': format_decimal(business.market_share) }, 'report_period': business.report_period }) return jsonify({ 'success': True, 'data': response_data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/company/value-chain-analysis/', methods=['GET']) def get_value_chain_analysis(company_code): """获取公司产业链分析数据""" try: # 获取产业链节点数据 nodes_query = text(""" SELECT node_name, node_type, node_level, node_description, importance_score, market_share, dependency_degree, 
created_at FROM company_value_chain_nodes WHERE company_code = :company_code ORDER BY node_level ASC, importance_score DESC """) with engine.connect() as conn: nodes_result = conn.execute(nodes_query, {'company_code': company_code}).fetchall() # 获取产业链流向数据 flows_query = text(""" SELECT source_node, source_type, source_level, target_node, target_type, target_level, flow_value, flow_ratio, flow_type, relationship_desc, transaction_volume FROM company_value_chain_flows WHERE company_code = :company_code ORDER BY flow_ratio DESC """) with engine.connect() as conn: flows_result = conn.execute(flows_query, {'company_code': company_code}).fetchall() # 构建节点数据结构 nodes_by_level = {} all_nodes = [] for node in nodes_result: node_data = { 'node_name': node.node_name, 'node_type': node.node_type, 'node_level': node.node_level, 'node_description': node.node_description, 'importance_score': node.importance_score, 'market_share': format_decimal(node.market_share), 'dependency_degree': format_decimal(node.dependency_degree), 'created_at': node.created_at.strftime('%Y-%m-%d %H:%M:%S') if node.created_at else None } all_nodes.append(node_data) # 按层级分组 level_key = f"level_{node.node_level}" if level_key not in nodes_by_level: nodes_by_level[level_key] = [] nodes_by_level[level_key].append(node_data) # 构建流向数据 flows_data = [] for flow in flows_result: flows_data.append({ 'source': { 'node_name': flow.source_node, 'node_type': flow.source_type, 'node_level': flow.source_level }, 'target': { 'node_name': flow.target_node, 'node_type': flow.target_type, 'node_level': flow.target_level }, 'flow_metrics': { 'flow_value': format_decimal(flow.flow_value), 'flow_ratio': format_decimal(flow.flow_ratio), 'flow_type': flow.flow_type }, 'relationship_info': { 'relationship_desc': flow.relationship_desc, 'transaction_volume': flow.transaction_volume } }) # 移除循环边,确保Sankey图数据是DAG(有向无环图) flows_data = remove_cycles_from_sankey_flows(flows_data) # 统计各层级节点数量 level_stats = {} for level_key, nodes in 
nodes_by_level.items(): level_stats[level_key] = { 'count': len(nodes), 'avg_importance': round(sum(node['importance_score'] or 0 for node in nodes) / len(nodes), 2) if nodes else 0 } response_data = { 'company_code': company_code, 'value_chain_structure': { 'nodes_by_level': nodes_by_level, 'level_statistics': level_stats, 'total_nodes': len(all_nodes) }, 'value_chain_flows': flows_data, 'analysis_summary': { 'total_flows': len(flows_data), 'upstream_nodes': len([n for n in all_nodes if n['node_level'] < 0]), 'company_nodes': len([n for n in all_nodes if n['node_level'] == 0]), 'downstream_nodes': len([n for n in all_nodes if n['node_level'] > 0]) } } return jsonify({ 'success': True, 'data': response_data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/company/value-chain/related-companies', methods=['GET']) def get_related_companies_by_node(): """ 根据产业链节点名称查询相关公司(结合nodes和flows表) 参数: node_name - 节点名称(如 "中芯国际"、"EDA/IP"等) 返回: 包含该节点的所有公司列表,附带节点层级、类型、关系描述等信息 """ try: node_name = request.args.get('node_name') if not node_name: return jsonify({ 'success': False, 'error': '缺少必需参数 node_name' }), 400 # 查询包含该节点的所有公司及其节点信息 query = text(""" SELECT DISTINCT n.company_code as stock_code, s.SECNAME as stock_name, s.ORGNAME as company_name, n.node_level, n.node_type, n.node_description, n.importance_score, n.market_share, n.dependency_degree FROM company_value_chain_nodes n LEFT JOIN ea_stocklist s ON n.company_code = s.SECCODE WHERE n.node_name = :node_name ORDER BY n.importance_score DESC, n.company_code """) with engine.connect() as conn: nodes_result = conn.execute(query, {'node_name': node_name}).fetchall() # 构建返回数据 companies = [] for row in nodes_result: company_data = { 'stock_code': row.stock_code, 'stock_name': row.stock_name or row.stock_code, 'company_name': row.company_name, 'node_info': { 'node_level': row.node_level, 'node_type': row.node_type, 'node_description': row.node_description, 'importance_score': 
row.importance_score, 'market_share': format_decimal(row.market_share), 'dependency_degree': format_decimal(row.dependency_degree) }, 'relationships': [] } # 查询该节点在该公司产业链中的流向关系 flows_query = text(""" SELECT source_node, source_type, source_level, target_node, target_type, target_level, flow_type, relationship_desc, flow_value, flow_ratio FROM company_value_chain_flows WHERE company_code = :company_code AND (source_node = :node_name OR target_node = :node_name) ORDER BY flow_ratio DESC LIMIT 5 """) with engine.connect() as conn: flows_result = conn.execute(flows_query, { 'company_code': row.stock_code, 'node_name': node_name }).fetchall() # 添加流向关系信息 for flow in flows_result: # 判断节点在流向中的角色 is_source = (flow.source_node == node_name) relationship = { 'role': 'source' if is_source else 'target', 'connected_node': flow.target_node if is_source else flow.source_node, 'connected_type': flow.target_type if is_source else flow.source_type, 'connected_level': flow.target_level if is_source else flow.source_level, 'flow_type': flow.flow_type, 'relationship_desc': flow.relationship_desc, 'flow_ratio': format_decimal(flow.flow_ratio) } company_data['relationships'].append(relationship) companies.append(company_data) return jsonify({ 'success': True, 'data': companies, 'total': len(companies), 'node_name': node_name }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/company/key-factors-timeline/', methods=['GET']) def get_key_factors_timeline(company_code): """获取公司关键因素和时间线数据""" try: # 获取请求参数 report_period = request.args.get('report_period') # 可选的报告期筛选 event_limit = request.args.get('event_limit', 50, type=int) # 时间线事件数量限制 # 获取关键因素类别 categories_query = text(""" SELECT id, category_name, category_desc, display_order FROM company_key_factor_categories WHERE company_code = :company_code ORDER BY display_order ASC, created_at ASC """) with engine.connect() as conn: categories_result = conn.execute(categories_query, {'company_code': 
company_code}).fetchall() # 获取关键因素详情 factors_query = text(""" SELECT kf.category_id, kf.factor_name, kf.factor_type, kf.factor_value, kf.factor_unit, kf.factor_desc, kf.impact_direction, kf.impact_weight, kf.report_period, kf.year_on_year, kf.data_source, kf.created_at, kf.updated_at FROM company_key_factors kf WHERE kf.company_code = :company_code """) params = {'company_code': company_code} # 如果指定了报告期,添加筛选条件 if report_period: factors_query = text(""" SELECT kf.category_id, kf.factor_name, kf.factor_type, kf.factor_value, kf.factor_unit, kf.factor_desc, kf.impact_direction, kf.impact_weight, kf.report_period, kf.year_on_year, kf.data_source, kf.created_at, kf.updated_at FROM company_key_factors kf WHERE kf.company_code = :company_code AND kf.report_period = :report_period ORDER BY kf.impact_weight DESC, kf.updated_at DESC """) params['report_period'] = report_period else: factors_query = text(""" SELECT kf.category_id, kf.factor_name, kf.factor_type, kf.factor_value, kf.factor_unit, kf.factor_desc, kf.impact_direction, kf.impact_weight, kf.report_period, kf.year_on_year, kf.data_source, kf.created_at, kf.updated_at FROM company_key_factors kf WHERE kf.company_code = :company_code ORDER BY kf.report_period DESC, kf.impact_weight DESC, kf.updated_at DESC """) with engine.connect() as conn: factors_result = conn.execute(factors_query, params).fetchall() # 获取发展时间线事件 timeline_query = text(""" SELECT event_date, event_type, event_title, event_desc, impact_score, is_positive, related_products, related_partners, financial_impact, created_at FROM company_timeline_events WHERE company_code = :company_code ORDER BY event_date DESC LIMIT :limit """) with engine.connect() as conn: timeline_result = conn.execute(timeline_query, {'company_code': company_code, 'limit': event_limit}).fetchall() # 构建关键因素数据结构 key_factors_data = {} factors_by_category = {} # 先建立类别索引 categories_map = {} for category in categories_result: categories_map[category.id] = { 'category_name': 
category.category_name, 'category_desc': category.category_desc, 'display_order': category.display_order, 'factors': [] } # 将因素分组到类别中 for factor in factors_result: factor_data = { 'factor_name': factor.factor_name, 'factor_type': factor.factor_type, 'factor_value': factor.factor_value, 'factor_unit': factor.factor_unit, 'factor_desc': factor.factor_desc, 'impact_direction': factor.impact_direction, 'impact_weight': factor.impact_weight, 'report_period': factor.report_period, 'year_on_year': format_decimal(factor.year_on_year), 'data_source': factor.data_source, 'updated_at': factor.updated_at.strftime('%Y-%m-%d %H:%M:%S') if factor.updated_at else None } category_id = factor.category_id if category_id and category_id in categories_map: categories_map[category_id]['factors'].append(factor_data) # 构建时间线数据 timeline_data = [] for event in timeline_result: timeline_data.append({ 'event_date': event.event_date.strftime('%Y-%m-%d') if event.event_date else None, 'event_type': event.event_type, 'event_title': event.event_title, 'event_desc': event.event_desc, 'impact_metrics': { 'impact_score': event.impact_score, 'is_positive': event.is_positive }, 'related_info': { 'related_products': event.related_products, 'related_partners': event.related_partners, 'financial_impact': event.financial_impact }, 'created_at': event.created_at.strftime('%Y-%m-%d %H:%M:%S') if event.created_at else None }) # 统计信息 total_factors = len(factors_result) positive_events = len([e for e in timeline_result if e.is_positive]) negative_events = len(timeline_result) - positive_events response_data = { 'company_code': company_code, 'key_factors': { 'categories': list(categories_map.values()), 'total_factors': total_factors, 'report_period': report_period }, 'development_timeline': { 'events': timeline_data, 'statistics': { 'total_events': len(timeline_data), 'positive_events': positive_events, 'negative_events': negative_events, 'event_types': list(set(event.event_type for event in timeline_result if 
event.event_type)) } } } return jsonify({ 'success': True, 'data': response_data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 # ============================================ # 模拟盘服务函数 # ============================================ def get_or_create_simulation_account(user_id): """获取或创建模拟账户""" account = SimulationAccount.query.filter_by(user_id=user_id).first() if not account: account = SimulationAccount( user_id=user_id, account_name=f'模拟账户_{user_id}', initial_capital=1000000.00, available_cash=1000000.00 ) db.session.add(account) db.session.commit() return account def is_trading_time(): """判断是否为交易时间""" now = beijing_now() # 检查是否为工作日 if now.weekday() >= 5: # 周六日 return False # 检查是否为交易时间 current_time = now.time() morning_start = dt_time(9, 30) morning_end = dt_time(11, 30) afternoon_start = dt_time(13, 0) afternoon_end = dt_time(15, 0) if (morning_start <= current_time <= morning_end) or \ (afternoon_start <= current_time <= afternoon_end): return True return False def get_latest_price_from_clickhouse(stock_code): """从ClickHouse获取最新价格(优先分钟数据,备选日线数据)""" try: client = get_clickhouse_client() # 确保stock_code包含后缀 if '.' not in stock_code: stock_code = f"{stock_code}.SH" if stock_code.startswith('6') else f"{stock_code}.SZ" # 1. 首先尝试获取最新的分钟数据(近30天) minute_query = """ SELECT close, timestamp FROM stock_minute WHERE code = %(code)s AND timestamp >= today() - 30 ORDER BY timestamp DESC LIMIT 1 \ """ result = client.execute(minute_query, {'code': stock_code}) if result: return float(result[0][0]), result[0][1] # 2. 如果没有分钟数据,获取最新的日线收盘价 daily_query = """ SELECT close, date FROM stock_daily WHERE code = %(code)s AND date >= today() - 90 ORDER BY date DESC LIMIT 1 \ """ daily_result = client.execute(daily_query, {'code': stock_code}) if daily_result: return float(daily_result[0][0]), daily_result[0][1] # 3. 
如果还是没有,尝试从其他表获取(如果有的话) fallback_query = """ SELECT close_price, trade_date FROM stock_minute_kline WHERE stock_code = %(code6)s AND trade_date >= today() - 30 ORDER BY trade_date DESC, trade_time DESC LIMIT 1 \ """ # 提取6位代码 code6 = stock_code.split('.')[0] fallback_result = client.execute(fallback_query, {'code6': code6}) if fallback_result: return float(fallback_result[0][0]), fallback_result[0][1] print(f"警告: 无法获取股票 {stock_code} 的价格数据") return None, None except Exception as e: print(f"获取最新价格失败 {stock_code}: {e}") return None, None def get_next_minute_price(stock_code, order_time): """获取下单后一分钟内的收盘价作为成交价""" try: client = get_clickhouse_client() # 确保stock_code包含后缀 if '.' not in stock_code: stock_code = f"{stock_code}.SH" if stock_code.startswith('6') else f"{stock_code}.SZ" # 获取下单后一分钟内的数据 query = """ SELECT close, timestamp FROM stock_minute WHERE code = %(code)s AND timestamp \ > %(order_time)s AND timestamp <= %(end_time)s ORDER BY timestamp ASC LIMIT 1 \ """ end_time = order_time + timedelta(minutes=1) result = client.execute(query, { 'code': stock_code, 'order_time': order_time, 'end_time': end_time }) if result: return float(result[0][0]), result[0][1] # 如果一分钟内没有数据,获取最近的数据 query = """ SELECT close, timestamp FROM stock_minute WHERE code = %(code)s AND timestamp \ > %(order_time)s ORDER BY timestamp ASC LIMIT 1 \ """ result = client.execute(query, { 'code': stock_code, 'order_time': order_time }) if result: return float(result[0][0]), result[0][1] # 如果没有后续分钟数据,使用最新可用价格 print(f"没有找到下单后的分钟数据,使用最新价格: {stock_code}") return get_latest_price_from_clickhouse(stock_code) except Exception as e: print(f"获取成交价格失败: {e}") # 出错时也尝试获取最新价格 return get_latest_price_from_clickhouse(stock_code) def validate_and_get_stock_info(stock_input): """验证股票输入并获取标准代码和名称 支持输入格式: - 股票代码:600519 或 600519.SH - 股票名称:贵州茅台 - 拼音首字母:gzmt - 名称(代码):贵州茅台(600519) 返回: (stock_code_with_suffix, stock_code_6digit, stock_name) 或 (None, None, None) """ # 先尝试标准化输入 code6, name_from_input = 
_normalize_stock_input(stock_input) if code6: # 如果能解析出6位代码,查询股票名称 stock_name = name_from_input or _query_stock_name_by_code(code6) stock_code_full = f"{code6}.SH" if code6.startswith('6') else f"{code6}.SZ" return stock_code_full, code6, stock_name # 如果不是标准代码格式,尝试搜索 with engine.connect() as conn: search_sql = text(""" SELECT DISTINCT SECCODE as stock_code, SECNAME as stock_name FROM ea_stocklist WHERE ( UPPER(SECCODE) = UPPER(:exact_match) OR UPPER(SECNAME) = UPPER(:exact_match) OR UPPER(F001V) = UPPER(:exact_match) ) AND F011V = '正常上市' AND F003V IN ('A股', 'B股') LIMIT 1 """) result = conn.execute(search_sql, { 'exact_match': stock_input.upper() }).fetchone() if result: code6 = result.stock_code stock_name = result.stock_name stock_code_full = f"{code6}.SH" if code6.startswith('6') else f"{code6}.SZ" return stock_code_full, code6, stock_name return None, None, None def execute_simulation_order(order): """执行模拟订单(优化版)""" try: # 标准化股票代码 stock_code_full, code6, stock_name = validate_and_get_stock_info(order.stock_code) if not stock_code_full: order.status = 'REJECTED' order.reject_reason = '无效的股票代码' db.session.commit() return False # 更新订单的股票信息 order.stock_code = stock_code_full order.stock_name = stock_name # 获取成交价格(下单后一分钟的收盘价) filled_price, filled_time = get_next_minute_price(stock_code_full, order.order_time) if not filled_price: # 如果无法获取价格,订单保持PENDING状态,等待后台处理 order.status = 'PENDING' db.session.commit() return True # 返回True表示下单成功,但未成交 # 更新订单信息 order.filled_qty = order.order_qty order.filled_price = filled_price order.filled_amount = filled_price * order.order_qty order.filled_time = filled_time or beijing_now() # 计算费用 order.calculate_fees() # 获取账户 account = SimulationAccount.query.get(order.account_id) if order.order_type == 'BUY': # 买入操作 total_cost = float(order.filled_amount) + float(order.total_fee) # 检查资金是否充足 if float(account.available_cash) < total_cost: order.status = 'REJECTED' order.reject_reason = '可用资金不足' db.session.commit() return False # 扣除资金 
account.available_cash -= Decimal(str(total_cost)) # 更新或创建持仓 position = SimulationPosition.query.filter_by( account_id=account.id, stock_code=order.stock_code ).first() if position: # 更新持仓 total_cost_before = float(position.avg_cost) * position.position_qty total_cost_after = total_cost_before + float(order.filled_amount) total_qty_after = position.position_qty + order.filled_qty position.avg_cost = Decimal(str(total_cost_after / total_qty_after)) position.position_qty = total_qty_after # 今日买入,T+1才可用 position.frozen_qty += order.filled_qty else: # 创建新持仓 position = SimulationPosition( account_id=account.id, stock_code=order.stock_code, stock_name=order.stock_name, position_qty=order.filled_qty, available_qty=0, # T+1 frozen_qty=order.filled_qty, # 今日买入冻结 avg_cost=order.filled_price, current_price=order.filled_price ) db.session.add(position) # 更新持仓市值 position.update_market_value(order.filled_price) else: # SELL # 卖出操作 print(f"🔍 调试:查找持仓,账户ID: {account.id}, 股票代码: {order.stock_code}") # 先尝试用完整格式查找 position = SimulationPosition.query.filter_by( account_id=account.id, stock_code=order.stock_code ).first() # 如果没找到,尝试用6位数字格式查找 if not position and '.' 
in order.stock_code: code6 = order.stock_code.split('.')[0] print(f"🔍 调试:尝试用6位格式查找: {code6}") position = SimulationPosition.query.filter_by( account_id=account.id, stock_code=code6 ).first() print(f"🔍 调试:找到持仓: {position}") if position: print( f"🔍 调试:持仓详情 - 股票代码: {position.stock_code}, 持仓数量: {position.position_qty}, 可用数量: {position.available_qty}") # 检查持仓是否存在 if not position: order.status = 'REJECTED' order.reject_reason = '持仓不存在' db.session.commit() return False # 检查总持仓数量是否足够(包括冻结的) total_holdings = position.position_qty if total_holdings < order.order_qty: order.status = 'REJECTED' order.reject_reason = f'持仓数量不足,当前持仓: {total_holdings} 股,需要: {order.order_qty} 股' db.session.commit() return False # 如果可用数量不足,但总持仓足够,则从冻结数量中解冻 if position.available_qty < order.order_qty: # 计算需要解冻的数量 need_to_unfreeze = order.order_qty - position.available_qty if position.frozen_qty >= need_to_unfreeze: # 解冻部分冻结数量 position.frozen_qty -= need_to_unfreeze position.available_qty += need_to_unfreeze print(f"解冻 {need_to_unfreeze} 股用于卖出") else: order.status = 'REJECTED' order.reject_reason = f'可用数量不足,可用: {position.available_qty} 股,冻结: {position.frozen_qty} 股,需要: {order.order_qty} 股' db.session.commit() return False # 更新持仓 position.position_qty -= order.filled_qty position.available_qty -= order.filled_qty # 增加资金 account.available_cash += Decimal(str(float(order.filled_amount) - float(order.total_fee))) # 如果全部卖出,删除持仓记录 if position.position_qty == 0: db.session.delete(position) # 创建成交记录 transaction = SimulationTransaction( account_id=account.id, order_id=order.id, transaction_no=f"T{int(beijing_now().timestamp() * 1000000)}", stock_code=order.stock_code, stock_name=order.stock_name, transaction_type=order.order_type, transaction_price=order.filled_price, transaction_qty=order.filled_qty, transaction_amount=order.filled_amount, commission=order.commission, stamp_tax=order.stamp_tax, transfer_fee=order.transfer_fee, total_fee=order.total_fee, transaction_time=order.filled_time, 
settlement_date=(order.filled_time + timedelta(days=1)).date() ) db.session.add(transaction) # 更新订单状态 order.status = 'FILLED' # 更新账户总资产 update_account_assets(account) db.session.commit() return True except Exception as e: print(f"执行订单失败: {e}") db.session.rollback() return False def update_account_assets(account): """更新账户资产(轻量级版本,不实时获取价格)""" try: # 只计算已有的持仓市值,不实时获取价格 # 价格更新由后台脚本负责 positions = SimulationPosition.query.filter_by(account_id=account.id).all() total_market_value = sum(position.market_value or Decimal('0') for position in positions) account.position_value = total_market_value account.calculate_total_assets() db.session.commit() except Exception as e: print(f"更新账户资产失败: {e}") db.session.rollback() def update_all_positions_price(): """更新所有持仓的最新价格(定时任务调用)""" try: positions = SimulationPosition.query.all() for position in positions: latest_price, _ = get_latest_price_from_clickhouse(position.stock_code) if latest_price: # 记录昨日收盘价(用于计算今日盈亏) yesterday_close = position.current_price # 更新市值 position.update_market_value(latest_price) # 计算今日盈亏 position.today_profit = (Decimal(str(latest_price)) - yesterday_close) * position.position_qty position.today_profit_rate = ((Decimal( str(latest_price)) - yesterday_close) / yesterday_close * 100) if yesterday_close > 0 else 0 db.session.commit() except Exception as e: print(f"更新持仓价格失败: {e}") db.session.rollback() def process_t1_settlement(): """处理T+1结算(每日收盘后运行)""" try: # 获取所有需要结算的持仓 positions = SimulationPosition.query.filter(SimulationPosition.frozen_qty > 0).all() for position in positions: # 将冻结数量转为可用数量 position.available_qty += position.frozen_qty position.frozen_qty = 0 db.session.commit() except Exception as e: print(f"T+1结算失败: {e}") db.session.rollback() # ============================================ # 模拟盘API接口 # ============================================ @app.route('/api/simulation/account', methods=['GET']) @login_required def get_simulation_account(): """获取模拟账户信息""" try: account = 
get_or_create_simulation_account(current_user.id) # 更新账户资产 update_account_assets(account) return jsonify({ 'success': True, 'data': { 'account_id': account.id, 'account_name': account.account_name, 'initial_capital': float(account.initial_capital), 'available_cash': float(account.available_cash), 'frozen_cash': float(account.frozen_cash), 'position_value': float(account.position_value), 'total_assets': float(account.total_assets), 'total_profit': float(account.total_profit), 'total_profit_rate': float(account.total_profit_rate), 'daily_profit': float(account.daily_profit), 'daily_profit_rate': float(account.daily_profit_rate), 'created_at': account.created_at.isoformat(), 'updated_at': account.updated_at.isoformat() } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/simulation/positions', methods=['GET']) @login_required def get_simulation_positions(): """获取模拟持仓列表(优化版本,使用缓存的价格数据)""" try: account = get_or_create_simulation_account(current_user.id) # 直接获取持仓数据,不实时更新价格(由后台脚本负责) positions = SimulationPosition.query.filter_by(account_id=account.id).all() positions_data = [] for position in positions: positions_data.append({ 'id': position.id, 'stock_code': position.stock_code, 'stock_name': position.stock_name, 'position_qty': position.position_qty, 'available_qty': position.available_qty, 'frozen_qty': position.frozen_qty, 'avg_cost': float(position.avg_cost), 'current_price': float(position.current_price or 0), 'market_value': float(position.market_value or 0), 'profit': float(position.profit or 0), 'profit_rate': float(position.profit_rate or 0), 'today_profit': float(position.today_profit or 0), 'today_profit_rate': float(position.today_profit_rate or 0), 'updated_at': position.updated_at.isoformat() }) return jsonify({ 'success': True, 'data': positions_data }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/simulation/orders', methods=['GET']) @login_required def 
get_simulation_orders(): """获取模拟订单列表""" try: account = get_or_create_simulation_account(current_user.id) # 获取查询参数 status = request.args.get('status') # 订单状态筛选 date_str = request.args.get('date') # 日期筛选 limit = request.args.get('limit', 50, type=int) query = SimulationOrder.query.filter_by(account_id=account.id) if status: query = query.filter_by(status=status) if date_str: try: date = datetime.strptime(date_str, '%Y-%m-%d').date() start_time = datetime.combine(date, dt_time(0, 0, 0)) end_time = datetime.combine(date, dt_time(23, 59, 59)) query = query.filter(SimulationOrder.order_time.between(start_time, end_time)) except ValueError: pass orders = query.order_by(SimulationOrder.order_time.desc()).limit(limit).all() orders_data = [] for order in orders: orders_data.append({ 'id': order.id, 'order_no': order.order_no, 'stock_code': order.stock_code, 'stock_name': order.stock_name, 'order_type': order.order_type, 'price_type': order.price_type, 'order_price': float(order.order_price) if order.order_price else None, 'order_qty': order.order_qty, 'filled_qty': order.filled_qty, 'filled_price': float(order.filled_price) if order.filled_price else None, 'filled_amount': float(order.filled_amount) if order.filled_amount else None, 'commission': float(order.commission), 'stamp_tax': float(order.stamp_tax), 'transfer_fee': float(order.transfer_fee), 'total_fee': float(order.total_fee), 'status': order.status, 'reject_reason': order.reject_reason, 'order_time': order.order_time.isoformat(), 'filled_time': order.filled_time.isoformat() if order.filled_time else None }) return jsonify({ 'success': True, 'data': orders_data }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/simulation/place-order', methods=['POST']) @login_required def place_simulation_order(): """下单""" try: # 移除交易时间检查,允许7x24小时下单 # 非交易时间下的单子会保持PENDING状态,等待行情数据 data = request.get_json() stock_code = data.get('stock_code') order_type = data.get('order_type') # 
BUY/SELL order_qty = data.get('order_qty') price_type = data.get('price_type', 'MARKET') # 目前只支持市价单 # 标准化股票代码格式 if stock_code and '.' not in stock_code: # 如果没有后缀,根据股票代码添加后缀 if stock_code.startswith('6'): stock_code = f"{stock_code}.SH" elif stock_code.startswith('0') or stock_code.startswith('3'): stock_code = f"{stock_code}.SZ" # 参数验证 if not all([stock_code, order_type, order_qty]): return jsonify({'success': False, 'error': '缺少必要参数'}), 400 if order_type not in ['BUY', 'SELL']: return jsonify({'success': False, 'error': '订单类型错误'}), 400 order_qty = int(order_qty) if order_qty <= 0 or order_qty % 100 != 0: return jsonify({'success': False, 'error': '下单数量必须为100的整数倍'}), 400 # 获取账户 account = get_or_create_simulation_account(current_user.id) # 获取股票信息 stock_name = None with engine.connect() as conn: result = conn.execute(text( "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" ), {"code": stock_code.split('.')[0]}).fetchone() if result: stock_name = result[0] # 创建订单 order = SimulationOrder( account_id=account.id, order_no=f"O{int(beijing_now().timestamp() * 1000000)}", stock_code=stock_code, stock_name=stock_name, order_type=order_type, price_type=price_type, order_qty=order_qty, status='PENDING' ) db.session.add(order) db.session.commit() # 执行订单 print(f"🔍 调试:开始执行订单,股票代码: {order.stock_code}, 订单类型: {order.order_type}") success = execute_simulation_order(order) print(f"🔍 调试:订单执行结果: {success}, 订单状态: {order.status}") if success: # 重新查询订单状态,因为可能在execute_simulation_order中被修改 db.session.refresh(order) if order.status == 'FILLED': return jsonify({ 'success': True, 'message': '订单执行成功,已成交', 'data': { 'order_no': order.order_no, 'status': 'FILLED', 'filled_price': float(order.filled_price) if order.filled_price else None, 'filled_qty': order.filled_qty, 'filled_amount': float(order.filled_amount) if order.filled_amount else None, 'total_fee': float(order.total_fee) } }) elif order.status == 'PENDING': return jsonify({ 'success': True, 'message': '订单提交成功,等待行情数据成交', 'data': { 
'order_no': order.order_no, 'status': 'PENDING', 'order_qty': order.order_qty, 'order_price': float(order.order_price) if order.order_price else None } }) else: return jsonify({ 'success': False, 'error': order.reject_reason or '订单状态异常' }), 400 else: return jsonify({ 'success': False, 'error': order.reject_reason or '订单执行失败' }), 400 except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/simulation/cancel-order/', methods=['POST']) @login_required def cancel_simulation_order(order_id): """撤销订单""" try: account = get_or_create_simulation_account(current_user.id) order = SimulationOrder.query.filter_by( id=order_id, account_id=account.id, status='PENDING' ).first() if not order: return jsonify({'success': False, 'error': '订单不存在或无法撤销'}), 404 order.status = 'CANCELLED' order.cancel_time = beijing_now() db.session.commit() return jsonify({ 'success': True, 'message': '订单已撤销' }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/simulation/transactions', methods=['GET']) @login_required def get_simulation_transactions(): """获取成交记录""" try: account = get_or_create_simulation_account(current_user.id) # 获取查询参数 date_str = request.args.get('date') limit = request.args.get('limit', 100, type=int) query = SimulationTransaction.query.filter_by(account_id=account.id) if date_str: try: date = datetime.strptime(date_str, '%Y-%m-%d').date() start_time = datetime.combine(date, dt_time(0, 0, 0)) end_time = datetime.combine(date, dt_time(23, 59, 59)) query = query.filter(SimulationTransaction.transaction_time.between(start_time, end_time)) except ValueError: pass transactions = query.order_by(SimulationTransaction.transaction_time.desc()).limit(limit).all() transactions_data = [] for trans in transactions: transactions_data.append({ 'id': trans.id, 'transaction_no': trans.transaction_no, 'stock_code': trans.stock_code, 'stock_name': trans.stock_name, 
'transaction_type': trans.transaction_type, 'transaction_price': float(trans.transaction_price), 'transaction_qty': trans.transaction_qty, 'transaction_amount': float(trans.transaction_amount), 'commission': float(trans.commission), 'stamp_tax': float(trans.stamp_tax), 'transfer_fee': float(trans.transfer_fee), 'total_fee': float(trans.total_fee), 'transaction_time': trans.transaction_time.isoformat(), 'settlement_date': trans.settlement_date.isoformat() if trans.settlement_date else None }) return jsonify({ 'success': True, 'data': transactions_data }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 def get_simulation_statistics(): """获取模拟交易统计""" try: account = get_or_create_simulation_account(current_user.id) # 获取统计时间范围 days = request.args.get('days', 30, type=int) end_date = beijing_now().date() start_date = end_date - timedelta(days=days) # 查询日统计数据 daily_stats = SimulationDailyStats.query.filter( SimulationDailyStats.account_id == account.id, SimulationDailyStats.stat_date >= start_date, SimulationDailyStats.stat_date <= end_date ).order_by(SimulationDailyStats.stat_date).all() # 查询总体统计 total_transactions = SimulationTransaction.query.filter_by(account_id=account.id).count() win_transactions = SimulationTransaction.query.filter( SimulationTransaction.account_id == account.id, SimulationTransaction.transaction_type == 'SELL' ).all() win_count = 0 total_profit = Decimal('0') for trans in win_transactions: # 查找对应的买入记录计算盈亏 position = SimulationPosition.query.filter_by( account_id=account.id, stock_code=trans.stock_code ).first() if position and trans.transaction_price > position.avg_cost: win_count += 1 profit = (trans.transaction_price - position.avg_cost) * trans.transaction_qty if position else 0 total_profit += profit # 构建日收益曲线 daily_returns = [] for stat in daily_stats: daily_returns.append({ 'date': stat.stat_date.isoformat(), 'daily_profit': float(stat.daily_profit), 'daily_profit_rate': float(stat.daily_profit_rate), 
'total_profit': float(stat.total_profit), 'total_profit_rate': float(stat.total_profit_rate), 'closing_assets': float(stat.closing_assets) }) return jsonify({ 'success': True, 'data': { 'summary': { 'total_transactions': total_transactions, 'win_count': win_count, 'win_rate': (win_count / len(win_transactions) * 100) if win_transactions else 0, 'total_profit': float(total_profit), 'average_profit_per_trade': float(total_profit / len(win_transactions)) if win_transactions else 0 }, 'daily_returns': daily_returns } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/simulation/t1-settlement', methods=['POST']) @login_required def trigger_t1_settlement(): """手动触发T+1结算""" try: # 导入后台处理器的函数 from simulation_background_processor import process_t1_settlement # 执行T+1结算 process_t1_settlement() return jsonify({ 'success': True, 'message': 'T+1结算执行成功' }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/simulation/debug-positions', methods=['GET']) @login_required def debug_positions(): """调试接口:查看持仓数据""" try: account = get_or_create_simulation_account(current_user.id) positions = SimulationPosition.query.filter_by(account_id=account.id).all() positions_data = [] for position in positions: positions_data.append({ 'stock_code': position.stock_code, 'stock_name': position.stock_name, 'position_qty': position.position_qty, 'available_qty': position.available_qty, 'frozen_qty': position.frozen_qty, 'avg_cost': float(position.avg_cost), 'current_price': float(position.current_price or 0) }) return jsonify({ 'success': True, 'data': positions_data }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/simulation/debug-transactions', methods=['GET']) @login_required def debug_transactions(): """调试接口:查看成交记录数据""" try: account = get_or_create_simulation_account(current_user.id) transactions = 
SimulationTransaction.query.filter_by(account_id=account.id).all() transactions_data = [] for trans in transactions: transactions_data.append({ 'id': trans.id, 'transaction_no': trans.transaction_no, 'stock_code': trans.stock_code, 'stock_name': trans.stock_name, 'transaction_type': trans.transaction_type, 'transaction_price': float(trans.transaction_price), 'transaction_qty': trans.transaction_qty, 'transaction_amount': float(trans.transaction_amount), 'commission': float(trans.commission), 'stamp_tax': float(trans.stamp_tax), 'transfer_fee': float(trans.transfer_fee), 'total_fee': float(trans.total_fee), 'transaction_time': trans.transaction_time.isoformat(), 'settlement_date': trans.settlement_date.isoformat() if trans.settlement_date else None }) return jsonify({ 'success': True, 'data': transactions_data, 'count': len(transactions_data) }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/simulation/daily-settlement', methods=['POST']) @login_required def trigger_daily_settlement(): """手动触发日结算""" try: # 导入后台处理器的函数 from simulation_background_processor import generate_daily_stats # 执行日结算 generate_daily_stats() return jsonify({ 'success': True, 'message': '日结算执行成功' }) except Exception as e: return jsonify({ 'success': False, 'error': str(e) }), 500 @app.route('/api/simulation/reset', methods=['POST']) @login_required def reset_simulation_account(): """重置模拟账户""" try: account = SimulationAccount.query.filter_by(user_id=current_user.id).first() if account: # 删除所有相关数据 SimulationPosition.query.filter_by(account_id=account.id).delete() SimulationOrder.query.filter_by(account_id=account.id).delete() SimulationTransaction.query.filter_by(account_id=account.id).delete() SimulationDailyStats.query.filter_by(account_id=account.id).delete() # 重置账户数据 account.available_cash = account.initial_capital account.frozen_cash = Decimal('0') account.position_value = Decimal('0') account.total_assets = account.initial_capital 
account.total_profit = Decimal('0') account.total_profit_rate = Decimal('0') account.daily_profit = Decimal('0') account.daily_profit_rate = Decimal('0') account.updated_at = beijing_now() db.session.commit() return jsonify({ 'success': True, 'message': '模拟账户已重置' }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 # =========================== # 预测市场 API 路由 # 请将此文件内容插入到 app.py 的 `if __name__ == '__main__':` 之前 # =========================== # --- 积分系统 API --- @app.route('/api/prediction/credit/account', methods=['GET']) @login_required def get_credit_account(): """获取用户积分账户""" try: account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() # 如果账户不存在,自动创建 if not account: account = UserCreditAccount(user_id=current_user.id) db.session.add(account) db.session.commit() return jsonify({ 'success': True, 'data': { 'balance': float(account.balance), 'frozen_balance': float(account.frozen_balance), 'available_balance': float(account.balance - account.frozen_balance), 'total_earned': float(account.total_earned), 'total_spent': float(account.total_spent), 'last_daily_bonus_at': account.last_daily_bonus_at.isoformat() if account.last_daily_bonus_at else None } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/credit/daily-bonus', methods=['POST']) @login_required def claim_daily_bonus(): """领取每日奖励(100积分)""" try: account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() if not account: account = UserCreditAccount(user_id=current_user.id) db.session.add(account) # 检查是否已领取今日奖励 today = beijing_now().date() if account.last_daily_bonus_at and account.last_daily_bonus_at.date() == today: return jsonify({ 'success': False, 'error': '今日奖励已领取' }), 400 # 发放奖励 bonus_amount = 100.0 account.balance += bonus_amount account.total_earned += bonus_amount account.last_daily_bonus_at = beijing_now() # 记录交易 transaction = CreditTransaction( 
user_id=current_user.id, transaction_type='daily_bonus', amount=bonus_amount, balance_after=account.balance, description='每日登录奖励' ) db.session.add(transaction) db.session.commit() return jsonify({ 'success': True, 'message': f'领取成功,获得 {bonus_amount} 积分', 'data': { 'bonus_amount': bonus_amount, 'new_balance': float(account.balance) } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 # --- 预测话题 API --- @app.route('/api/prediction/topics', methods=['POST']) @login_required def create_prediction_topic(): """创建预测话题(消耗100积分)""" try: data = request.get_json() title = data.get('title', '').strip() description = data.get('description', '').strip() category = data.get('category', 'stock') deadline_str = data.get('deadline') # 验证参数 if not title or len(title) < 5: return jsonify({'success': False, 'error': '标题至少5个字符'}), 400 if not deadline_str: return jsonify({'success': False, 'error': '请设置截止时间'}), 400 # 解析截止时间(移除时区信息以匹配数据库格式) deadline = datetime.fromisoformat(deadline_str.replace('Z', '+00:00')) # 移除时区信息,转换为naive datetime if deadline.tzinfo is not None: deadline = deadline.replace(tzinfo=None) if deadline <= beijing_now(): return jsonify({'success': False, 'error': '截止时间必须在未来'}), 400 # 检查积分账户 account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() if not account or account.balance < 100: return jsonify({'success': False, 'error': '积分不足(需要100积分)'}), 400 # 扣除创建费用 create_cost = 100.0 account.balance -= create_cost account.total_spent += create_cost # 创建话题 topic = PredictionTopic( creator_id=current_user.id, title=title, description=description, category=category, deadline=deadline ) db.session.add(topic) # 记录积分交易 transaction = CreditTransaction( user_id=current_user.id, transaction_type='create_topic', amount=-create_cost, balance_after=account.balance, description=f'创建预测话题:{title}' ) db.session.add(transaction) db.session.commit() return jsonify({ 'success': True, 'message': '话题创建成功', 'data': { 
'topic_id': topic.id, 'title': topic.title, 'new_balance': float(account.balance) } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/topics', methods=['GET']) def get_prediction_topics(): """获取预测话题列表""" try: status = request.args.get('status', 'active') category = request.args.get('category') sort_by = request.args.get('sort_by', 'created_at') page = request.args.get('page', 1, type=int) per_page = request.args.get('per_page', 20, type=int) # 构建查询 query = PredictionTopic.query if status: query = query.filter_by(status=status) if category: query = query.filter_by(category=category) # 排序 if sort_by == 'hot': query = query.order_by(desc(PredictionTopic.views_count)) elif sort_by == 'participants': query = query.order_by(desc(PredictionTopic.participants_count)) else: query = query.order_by(desc(PredictionTopic.created_at)) # 分页 pagination = query.paginate(page=page, per_page=per_page, error_out=False) topics = pagination.items # 格式化返回数据 topics_data = [] for topic in topics: # 计算市场倾向 total_shares = topic.yes_total_shares + topic.no_total_shares yes_prob = (topic.yes_total_shares / total_shares * 100) if total_shares > 0 else 50.0 # 处理datetime,确保移除时区信息 deadline = topic.deadline if hasattr(deadline, 'replace') and deadline.tzinfo is not None: deadline = deadline.replace(tzinfo=None) created_at = topic.created_at if hasattr(created_at, 'replace') and created_at.tzinfo is not None: created_at = created_at.replace(tzinfo=None) topics_data.append({ 'id': topic.id, 'title': topic.title, 'description': topic.description, 'category': topic.category, 'status': topic.status, 'yes_price': float(topic.yes_price), 'no_price': float(topic.no_price), 'yes_probability': round(yes_prob, 1), 'total_pool': float(topic.total_pool), 'yes_lord': { 'id': topic.yes_lord.id, 'username': topic.yes_lord.username, 'nickname': topic.yes_lord.nickname or topic.yes_lord.username, 'avatar_url': 
topic.yes_lord.avatar_url } if topic.yes_lord else None, 'no_lord': { 'id': topic.no_lord.id, 'username': topic.no_lord.username, 'nickname': topic.no_lord.nickname or topic.no_lord.username, 'avatar_url': topic.no_lord.avatar_url } if topic.no_lord else None, 'deadline': deadline.isoformat() if deadline else None, 'created_at': created_at.isoformat() if created_at else None, 'views_count': topic.views_count, 'comments_count': topic.comments_count, 'participants_count': topic.participants_count, 'creator': { 'id': topic.creator.id, 'username': topic.creator.username, 'nickname': topic.creator.nickname or topic.creator.username } }) return jsonify({ 'success': True, 'data': topics_data, 'pagination': { 'page': page, 'per_page': per_page, 'total': pagination.total, 'pages': pagination.pages, 'has_next': pagination.has_next, 'has_prev': pagination.has_prev } }) except Exception as e: import traceback print(f"[ERROR] 获取话题列表失败: {str(e)}") print(traceback.format_exc()) return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/topics/', methods=['GET']) def get_prediction_topic_detail(topic_id): """获取预测话题详情""" try: # 刷新会话,确保获取最新数据 db.session.expire_all() topic = PredictionTopic.query.get_or_404(topic_id) # 增加浏览量 topic.views_count += 1 db.session.commit() # 计算市场倾向 total_shares = topic.yes_total_shares + topic.no_total_shares yes_prob = (topic.yes_total_shares / total_shares * 100) if total_shares > 0 else 50.0 # 获取 TOP 5 持仓(YES 和 NO 各5个) yes_top_positions = PredictionPosition.query.filter_by( topic_id=topic_id, direction='yes' ).order_by(desc(PredictionPosition.shares)).limit(5).all() no_top_positions = PredictionPosition.query.filter_by( topic_id=topic_id, direction='no' ).order_by(desc(PredictionPosition.shares)).limit(5).all() def format_position(position): return { 'user': { 'id': position.user.id, 'username': position.user.username, 'nickname': position.user.nickname or position.user.username, 'avatar_url': position.user.avatar_url }, 
'shares': position.shares, 'avg_cost': float(position.avg_cost), 'total_invested': float(position.total_invested), 'is_lord': (topic.yes_lord_id == position.user_id and position.direction == 'yes') or (topic.no_lord_id == position.user_id and position.direction == 'no') } return jsonify({ 'success': True, 'data': { 'id': topic.id, 'title': topic.title, 'description': topic.description, 'category': topic.category, 'status': topic.status, 'result': topic.result, 'yes_price': float(topic.yes_price), 'no_price': float(topic.no_price), 'yes_total_shares': topic.yes_total_shares, 'no_total_shares': topic.no_total_shares, 'yes_probability': round(yes_prob, 1), 'no_probability': round(100 - yes_prob, 1), 'total_pool': float(topic.total_pool), 'yes_lord': { 'id': topic.yes_lord.id, 'username': topic.yes_lord.username, 'nickname': topic.yes_lord.nickname or topic.yes_lord.username, 'avatar_url': topic.yes_lord.avatar_url } if topic.yes_lord else None, 'no_lord': { 'id': topic.no_lord.id, 'username': topic.no_lord.username, 'nickname': topic.no_lord.nickname or topic.no_lord.username, 'avatar_url': topic.no_lord.avatar_url } if topic.no_lord else None, 'yes_top_positions': [format_position(p) for p in yes_top_positions], 'no_top_positions': [format_position(p) for p in no_top_positions], 'deadline': topic.deadline.isoformat(), 'settled_at': topic.settled_at.isoformat() if topic.settled_at else None, 'created_at': topic.created_at.isoformat(), 'views_count': topic.views_count, 'comments_count': topic.comments_count, 'participants_count': topic.participants_count, 'creator': { 'id': topic.creator.id, 'username': topic.creator.username, 'nickname': topic.creator.nickname or topic.creator.username, 'avatar_url': topic.creator.avatar_url } } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/topics//settle', methods=['POST']) @login_required def settle_prediction_topic(topic_id): """结算预测话题(仅创建者可操作)""" try: topic = 
PredictionTopic.query.get_or_404(topic_id) # 验证权限 if topic.creator_id != current_user.id: return jsonify({'success': False, 'error': '只有创建者可以结算'}), 403 # 验证状态 if topic.status != 'active': return jsonify({'success': False, 'error': '话题已结算或已取消'}), 400 # 验证截止时间 if beijing_now() < topic.deadline: return jsonify({'success': False, 'error': '未到截止时间'}), 400 # 获取结算结果 data = request.get_json() result = data.get('result') # 'yes', 'no', 'draw' if result not in ['yes', 'no', 'draw']: return jsonify({'success': False, 'error': '无效的结算结果'}), 400 # 更新话题状态 topic.status = 'settled' topic.result = result topic.settled_at = beijing_now() # 获取获胜方的所有持仓 if result == 'draw': # 平局:所有人按投入比例分配奖池 all_positions = PredictionPosition.query.filter_by(topic_id=topic_id).all() total_invested = sum(p.total_invested for p in all_positions) for position in all_positions: if total_invested > 0: share_ratio = position.total_invested / total_invested prize = topic.total_pool * share_ratio # 发放奖金 account = UserCreditAccount.query.filter_by(user_id=position.user_id).first() if account: account.balance += prize account.total_earned += prize # 记录交易 transaction = CreditTransaction( user_id=position.user_id, transaction_type='settle_win', amount=prize, balance_after=account.balance, related_topic_id=topic_id, description=f'预测平局,获得奖池分红:{topic.title}' ) db.session.add(transaction) else: # YES 或 NO 获胜 winning_direction = result winning_positions = PredictionPosition.query.filter_by( topic_id=topic_id, direction=winning_direction ).all() if winning_positions: total_winning_shares = sum(p.shares for p in winning_positions) for position in winning_positions: # 按份额比例分配奖池 share_ratio = position.shares / total_winning_shares prize = topic.total_pool * share_ratio # 发放奖金 account = UserCreditAccount.query.filter_by(user_id=position.user_id).first() if account: account.balance += prize account.total_earned += prize # 记录交易 transaction = CreditTransaction( user_id=position.user_id, transaction_type='settle_win', 
amount=prize, balance_after=account.balance, related_topic_id=topic_id, description=f'预测正确,获得奖金:{topic.title}' ) db.session.add(transaction) db.session.commit() return jsonify({ 'success': True, 'message': f'话题已结算,结果为:{result}', 'data': { 'topic_id': topic.id, 'result': result, 'total_pool': float(topic.total_pool), 'settled_at': topic.settled_at.isoformat() } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 # --- 交易 API --- @app.route('/api/prediction/trade/buy', methods=['POST']) @login_required def buy_prediction_shares(): """买入预测份额""" try: data = request.get_json() topic_id = data.get('topic_id') direction = data.get('direction') # 'yes' or 'no' shares = data.get('shares', 0) # 验证参数 if not topic_id or direction not in ['yes', 'no'] or shares <= 0: return jsonify({'success': False, 'error': '参数错误'}), 400 if shares > 1000: return jsonify({'success': False, 'error': '单次最多买入1000份额'}), 400 # 获取话题 topic = PredictionTopic.query.get_or_404(topic_id) if topic.status != 'active': return jsonify({'success': False, 'error': '话题已结算或已取消'}), 400 if beijing_now() >= topic.deadline: return jsonify({'success': False, 'error': '话题已截止'}), 400 # 获取积分账户 account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() if not account: account = UserCreditAccount(user_id=current_user.id) db.session.add(account) db.session.flush() # 计算价格 current_price = topic.yes_price if direction == 'yes' else topic.no_price # 简化的AMM定价:price = (对应方份额 / 总份额) * 1000 total_shares = topic.yes_total_shares + topic.no_total_shares if total_shares > 0: if direction == 'yes': current_price = (topic.yes_total_shares / total_shares) * 1000 else: current_price = (topic.no_total_shares / total_shares) * 1000 else: current_price = 500.0 # 初始价格 # 买入后价格会上涨,使用平均价格 after_total = total_shares + shares if direction == 'yes': after_yes_shares = topic.yes_total_shares + shares after_price = (after_yes_shares / after_total) * 1000 else: after_no_shares = 
topic.no_total_shares + shares after_price = (after_no_shares / after_total) * 1000 avg_price = (current_price + after_price) / 2 # 计算费用 amount = avg_price * shares tax = amount * 0.02 # 2% 手续费 total_cost = amount + tax # 检查余额 if account.balance < total_cost: return jsonify({'success': False, 'error': '积分不足'}), 400 # 扣除费用 account.balance -= total_cost account.total_spent += total_cost # 更新话题数据 if direction == 'yes': topic.yes_total_shares += shares topic.yes_price = after_price else: topic.no_total_shares += shares topic.no_price = after_price topic.total_pool += tax # 手续费进入奖池 # 更新或创建持仓 position = PredictionPosition.query.filter_by( user_id=current_user.id, topic_id=topic_id, direction=direction ).first() if position: # 更新平均成本 old_cost = position.avg_cost * position.shares new_cost = avg_price * shares position.shares += shares position.avg_cost = (old_cost + new_cost) / position.shares position.total_invested += total_cost else: position = PredictionPosition( user_id=current_user.id, topic_id=topic_id, direction=direction, shares=shares, avg_cost=avg_price, total_invested=total_cost ) db.session.add(position) topic.participants_count += 1 # 更新领主 if direction == 'yes': # 找到YES方持仓最多的用户 top_yes = db.session.query(PredictionPosition).filter_by( topic_id=topic_id, direction='yes' ).order_by(desc(PredictionPosition.shares)).first() if top_yes: topic.yes_lord_id = top_yes.user_id else: # 找到NO方持仓最多的用户 top_no = db.session.query(PredictionPosition).filter_by( topic_id=topic_id, direction='no' ).order_by(desc(PredictionPosition.shares)).first() if top_no: topic.no_lord_id = top_no.user_id # 记录交易 transaction = PredictionTransaction( user_id=current_user.id, topic_id=topic_id, trade_type='buy', direction=direction, shares=shares, price=avg_price, amount=amount, tax=tax, total_cost=total_cost ) db.session.add(transaction) # 记录积分交易 credit_transaction = CreditTransaction( user_id=current_user.id, transaction_type='prediction_buy', amount=-total_cost, 
balance_after=account.balance, related_topic_id=topic_id, related_transaction_id=transaction.id, description=f'买入 {direction.upper()} 份额:{topic.title}' ) db.session.add(credit_transaction) db.session.commit() return jsonify({ 'success': True, 'message': '买入成功', 'data': { 'transaction_id': transaction.id, 'shares': shares, 'price': round(avg_price, 2), 'total_cost': round(total_cost, 2), 'tax': round(tax, 2), 'new_balance': float(account.balance), 'new_position': { 'shares': position.shares, 'avg_cost': float(position.avg_cost) } } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/positions', methods=['GET']) @login_required def get_user_positions(): """获取用户的所有持仓""" try: positions = PredictionPosition.query.filter_by(user_id=current_user.id).all() positions_data = [] for position in positions: topic = position.topic # 计算当前市值(如果话题还在进行中) current_value = 0 profit = 0 profit_rate = 0 if topic.status == 'active': current_price = topic.yes_price if position.direction == 'yes' else topic.no_price current_value = current_price * position.shares profit = current_value - position.total_invested profit_rate = (profit / position.total_invested * 100) if position.total_invested > 0 else 0 positions_data.append({ 'id': position.id, 'topic': { 'id': topic.id, 'title': topic.title, 'status': topic.status, 'result': topic.result, 'deadline': topic.deadline.isoformat() }, 'direction': position.direction, 'shares': position.shares, 'avg_cost': float(position.avg_cost), 'total_invested': float(position.total_invested), 'current_value': round(current_value, 2), 'profit': round(profit, 2), 'profit_rate': round(profit_rate, 2), 'created_at': position.created_at.isoformat(), 'is_lord': (topic.yes_lord_id == current_user.id and position.direction == 'yes') or (topic.no_lord_id == current_user.id and position.direction == 'no') }) return jsonify({ 'success': True, 'data': positions_data, 'count': 
len(positions_data) }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 # --- 评论 API --- @app.route('/api/prediction/topics//comments', methods=['POST']) @login_required def create_topic_comment(topic_id): """发表话题评论""" try: topic = PredictionTopic.query.get_or_404(topic_id) data = request.get_json() content = data.get('content', '').strip() parent_id = data.get('parent_id') if not content or len(content) < 2: return jsonify({'success': False, 'error': '评论内容至少2个字符'}), 400 # 创建评论 comment = TopicComment( topic_id=topic_id, user_id=current_user.id, content=content, parent_id=parent_id ) # 如果是领主评论,自动置顶 is_lord = (topic.yes_lord_id == current_user.id) or (topic.no_lord_id == current_user.id) if is_lord: comment.is_pinned = True db.session.add(comment) # 更新话题评论数 topic.comments_count += 1 db.session.commit() return jsonify({ 'success': True, 'message': '评论成功', 'data': { 'comment_id': comment.id, 'content': comment.content, 'is_pinned': comment.is_pinned, 'created_at': comment.created_at.isoformat() } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/topics//comments', methods=['GET']) def get_topic_comments(topic_id): """获取话题评论列表""" try: topic = PredictionTopic.query.get_or_404(topic_id) page = request.args.get('page', 1, type=int) per_page = request.args.get('per_page', 20, type=int) # 置顶评论在前,然后按时间倒序 query = TopicComment.query.filter_by( topic_id=topic_id, status='active', parent_id=None # 只获取顶级评论 ).order_by( desc(TopicComment.is_pinned), desc(TopicComment.created_at) ) pagination = query.paginate(page=page, per_page=per_page, error_out=False) comments = pagination.items def format_comment(comment): # 获取回复 replies = TopicComment.query.filter_by( parent_id=comment.id, status='active' ).order_by(TopicComment.created_at).limit(5).all() return { 'id': comment.id, 'content': comment.content, 'is_pinned': comment.is_pinned, 'likes_count': comment.likes_count, 
'created_at': comment.created_at.isoformat(), 'user': { 'id': comment.user.id, 'username': comment.user.username, 'nickname': comment.user.nickname or comment.user.username, 'avatar_url': comment.user.avatar_url }, 'is_lord': (topic.yes_lord_id == comment.user_id) or (topic.no_lord_id == comment.user_id), 'replies': [{ 'id': reply.id, 'content': reply.content, 'created_at': reply.created_at.isoformat(), 'user': { 'id': reply.user.id, 'username': reply.user.username, 'nickname': reply.user.nickname or reply.user.username, 'avatar_url': reply.user.avatar_url } } for reply in replies] } comments_data = [format_comment(comment) for comment in comments] return jsonify({ 'success': True, 'data': comments_data, 'pagination': { 'page': page, 'per_page': per_page, 'total': pagination.total, 'pages': pagination.pages } }) except Exception as e: return jsonify({'success': False, 'error': str(e)}), 500 @app.route('/api/prediction/comments//like', methods=['POST']) @login_required def like_topic_comment(comment_id): """点赞/取消点赞评论""" try: comment = TopicComment.query.get_or_404(comment_id) # 检查是否已点赞 existing_like = TopicCommentLike.query.filter_by( comment_id=comment_id, user_id=current_user.id ).first() if existing_like: # 取消点赞 db.session.delete(existing_like) comment.likes_count = max(0, comment.likes_count - 1) action = 'unliked' else: # 点赞 like = TopicCommentLike( comment_id=comment_id, user_id=current_user.id ) db.session.add(like) comment.likes_count += 1 action = 'liked' db.session.commit() return jsonify({ 'success': True, 'action': action, 'likes_count': comment.likes_count }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 # ==================== 观点IPO API ==================== @app.route('/api/prediction/comments//invest', methods=['POST']) @login_required def invest_comment(comment_id): """投资评论(观点IPO)""" try: data = request.json shares = data.get('shares', 1) # 获取评论 comment = TopicComment.query.get_or_404(comment_id) 
# 检查评论是否已结算 if comment.is_verified: return jsonify({'success': False, 'error': '该评论已结算,无法继续投资'}), 400 # 检查是否是自己的评论 if comment.user_id == current_user.id: return jsonify({'success': False, 'error': '不能投资自己的评论'}), 400 # 计算投资金额(简化:每份100积分基础价格 + 已有投资额/10) base_price = 100 price_increase = comment.total_investment / 10 if comment.total_investment > 0 else 0 price_per_share = base_price + price_increase amount = int(price_per_share * shares) # 获取用户积分账户 account = UserCreditAccount.query.filter_by(user_id=current_user.id).first() if not account: return jsonify({'success': False, 'error': '账户不存在'}), 404 # 检查余额 if account.balance < amount: return jsonify({'success': False, 'error': '积分不足'}), 400 # 扣减积分 account.balance -= amount # 检查是否已有投资记录 existing_investment = CommentInvestment.query.filter_by( comment_id=comment_id, user_id=current_user.id, status='active' ).first() if existing_investment: # 更新投资记录 total_shares = existing_investment.shares + shares total_amount = existing_investment.amount + amount existing_investment.shares = total_shares existing_investment.amount = total_amount existing_investment.avg_price = total_amount / total_shares else: # 创建新投资记录 investment = CommentInvestment( comment_id=comment_id, user_id=current_user.id, shares=shares, amount=amount, avg_price=price_per_share ) db.session.add(investment) comment.investor_count += 1 # 更新评论统计 comment.total_investment += amount # 记录积分交易 transaction = CreditTransaction( user_id=current_user.id, type='comment_investment', amount=-amount, balance_after=account.balance, description=f'投资评论 #{comment_id}' ) db.session.add(transaction) db.session.commit() return jsonify({ 'success': True, 'data': { 'shares': shares, 'amount': amount, 'price_per_share': price_per_share, 'total_investment': comment.total_investment, 'investor_count': comment.investor_count, 'new_balance': account.balance } }) except Exception as e: db.session.rollback() return jsonify({'success': False, 'error': str(e)}), 500 
@app.route('/api/prediction/comments/<int:comment_id>/investments', methods=['GET'])
def get_comment_investments(comment_id):
    """List the active investments in a comment together with investor info."""
    try:
        investments = CommentInvestment.query.filter_by(
            comment_id=comment_id,
            status='active'
        ).all()

        result = []
        for inv in investments:
            user = User.query.get(inv.user_id)
            result.append({
                'id': inv.id,
                'user_id': inv.user_id,
                'user_name': user.username if user else '未知用户',
                # The comment endpoints read ``avatar_url``; ``avatar`` does not
                # appear to exist on User — confirm against the model.
                'user_avatar': user.avatar_url if user else None,
                'shares': inv.shares,
                'amount': inv.amount,
                'avg_price': inv.avg_price,
                'created_at': inv.created_at.strftime('%Y-%m-%d %H:%M:%S')
            })

        return jsonify({
            'success': True,
            'data': result
        })

    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500


@app.route('/api/prediction/comments/<int:comment_id>/verify', methods=['POST'])
@login_required
def verify_comment(comment_id):
    """Admin only: mark a comment's prediction as correct or incorrect.

    On ``correct`` (and a positive total investment), active investors split
    1.5x the comment's total investment pro rata by shares, and the author
    receives 20% of the total investment.
    """
    try:
        # Admin check (simplified: user_id 1 is treated as the administrator).
        if current_user.id != 1:
            return jsonify({'success': False, 'error': '无权限操作'}), 403

        data = request.get_json(silent=True) or {}
        result = data.get('result')  # 'correct' or 'incorrect'

        if result not in ['correct', 'incorrect']:
            return jsonify({'success': False, 'error': '无效的验证结果'}), 400

        comment = db.session.get(TopicComment, comment_id)
        if comment is None:
            return jsonify({'success': False, 'error': '评论不存在'}), 404

        if comment.is_verified:
            return jsonify({'success': False, 'error': '该评论已验证'}), 400

        comment.is_verified = True
        comment.verification_result = result

        if result == 'correct' and comment.total_investment > 0:
            investments = CommentInvestment.query.filter_by(
                comment_id=comment_id,
                status='active'
            ).all()

            # Total payout is 1.5x the money invested.
            total_reward = int(comment.total_investment * 1.5)
            total_shares = sum(inv.shares for inv in investments)

            # Guard the pro-rata split against a zero share total.
            if total_shares > 0:
                for inv in investments:
                    investor_reward = int((inv.shares / total_shares) * total_reward)

                    account = UserCreditAccount.query.filter_by(user_id=inv.user_id).first()
                    if account:
                        account.balance += investor_reward

                        transaction = CreditTransaction(
                            user_id=inv.user_id,
                            transaction_type='comment_investment_profit',
                            amount=investor_reward,
                            balance_after=account.balance,
                            description=f'评论投资收益 #{comment_id}'
                        )
                        db.session.add(transaction)

                    inv.status = 'settled'

            # The author also earns 20% of the total investment.
            author_reward = int(comment.total_investment * 0.2)
            author_account = UserCreditAccount.query.filter_by(user_id=comment.user_id).first()
            if author_account:
                author_account.balance += author_reward

                transaction = CreditTransaction(
                    user_id=comment.user_id,
                    transaction_type='comment_author_bonus',
                    amount=author_reward,
                    balance_after=author_account.balance,
                    description=f'评论作者奖励 #{comment_id}'
                )
                db.session.add(transaction)

        # NOTE(review): on 'incorrect' the investments keep status 'active'
        # forever — confirm whether they should be closed out here.
        db.session.commit()

        return jsonify({
            'success': True,
            'data': {
                'comment_id': comment_id,
                'verification_result': result,
                'total_investment': comment.total_investment
            }
        })

    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500


@app.route('/api/prediction/topics/<int:topic_id>/bid-position', methods=['POST'])
@login_required
def bid_comment_position(topic_id):
    """Bid for one of a topic's first three comment slots (first-post auction).

    The bid is moved from the caller's balance into ``frozen``. The caller's
    own previous pending bid on the same slot is released and marked 'lost'.
    """
    try:
        data = request.get_json(silent=True) or {}
        position = data.get('position')  # 1 / 2 / 3
        bid_amount = data.get('bid_amount')

        if position not in [1, 2, 3]:
            return jsonify({'success': False, 'error': '无效的位置'}), 400

        # A missing / non-integer amount used to crash the comparison below (500).
        if isinstance(bid_amount, bool) or not isinstance(bid_amount, int):
            return jsonify({'success': False, 'error': '无效的出价'}), 400

        if bid_amount < 500:
            return jsonify({'success': False, 'error': '最低出价500积分'}), 400

        # Resolve the topic before any balance is touched.
        topic = db.session.get(PredictionTopic, topic_id)
        if topic is None:
            return jsonify({'success': False, 'error': '话题不存在'}), 404

        account = UserCreditAccount.query.filter_by(user_id=current_user.id).first()
        if not account or account.balance < bid_amount:
            return jsonify({'success': False, 'error': '积分不足'}), 400

        # The new bid must beat the current highest pending bid for the slot.
        current_highest = CommentPositionBid.query.filter_by(
            topic_id=topic_id,
            position=position,
            status='pending'
        ).order_by(CommentPositionBid.bid_amount.desc()).first()

        if current_highest and bid_amount <= current_highest.bid_amount:
            return jsonify({
                'success': False,
                'error': f'出价必须高于当前最高价 {current_highest.bid_amount}'
            }), 400

        # Freeze the bid amount.
        account.balance -= bid_amount
        account.frozen += bid_amount

        # Release the caller's own previous pending bid on this slot.
        # NOTE(review): other users who are outbid keep their credits frozen
        # here; confirm they are released elsewhere (e.g. at settlement).
        user_previous_bid = CommentPositionBid.query.filter_by(
            topic_id=topic_id,
            position=position,
            user_id=current_user.id,
            status='pending'
        ).first()

        if user_previous_bid:
            account.frozen -= user_previous_bid.bid_amount
            account.balance += user_previous_bid.bid_amount
            user_previous_bid.status = 'lost'

        bid = CommentPositionBid(
            topic_id=topic_id,
            user_id=current_user.id,
            position=position,
            bid_amount=bid_amount,
            expires_at=topic.deadline  # the auction closes with the topic
        )
        db.session.add(bid)

        transaction = CreditTransaction(
            user_id=current_user.id,
            transaction_type='position_bid',
            amount=-bid_amount,
            balance_after=account.balance,
            description=f'竞拍评论位置 #{position} (话题#{topic_id})'
        )
        db.session.add(transaction)

        db.session.commit()

        return jsonify({
            'success': True,
            'data': {
                'bid_id': bid.id,
                'position': position,
                'bid_amount': bid_amount,
                'new_balance': account.balance,
                'frozen': account.frozen
            }
        })

    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500


@app.route('/api/prediction/topics/<int:topic_id>/position-bids', methods=['GET'])
def get_position_bids(topic_id):
    """Return the top five pending bids for each of the three comment slots."""
    try:
        result = {}
        for position in [1, 2, 3]:
            bids = CommentPositionBid.query.filter_by(
                topic_id=topic_id,
                position=position,
                status='pending'
            ).order_by(CommentPositionBid.bid_amount.desc()).limit(5).all()

            position_bids = []
            for bid in bids:
                user = User.query.get(bid.user_id)
                position_bids.append({
                    'id': bid.id,
                    'user_id': bid.user_id,
                    'user_name': user.username if user else '未知用户',
                    'user_avatar': user.avatar_url if user else None,
                    'bid_amount': bid.bid_amount,
                    'created_at': bid.created_at.strftime('%Y-%m-%d %H:%M:%S')
                })

            result[f'position_{position}'] = position_bids

        return jsonify({
            'success': True,
            'data': result
        })

    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500


# ==================== Time Capsule API ====================

@app.route('/api/time-capsule/topics', methods=['POST'])
@login_required
def create_time_capsule_topic():
    """Create a time-capsule topic, charging the creator 100 credits.

    One auctionable time slot is created per year in ``[start_year, end_year]``;
    the span is limited to 100 years. The creation fee seeds the prize pool.
    """
    try:
        data = request.get_json(silent=True) or {}
        title = data.get('title')
        description = data.get('description', '')
        encrypted_content = data.get('encrypted_content')
        encryption_key = data.get('encryption_key')
        start_year = data.get('start_year')
        end_year = data.get('end_year')

        if not title or not encrypted_content or not encryption_key:
            return jsonify({'success': False, 'error': '缺少必要参数'}), 400

        # Years must be real ints: strings compared lexically here and then
        # crashed range() with a 500.
        years_are_ints = all(
            isinstance(y, int) and not isinstance(y, bool) for y in (start_year, end_year)
        )
        if not years_are_ints or not start_year or not end_year or end_year <= start_year:
            return jsonify({'success': False, 'error': '无效的时间范围'}), 400

        # One row is inserted per year, so an unbounded span is a DoS vector.
        if end_year - start_year > 100:
            return jsonify({'success': False, 'error': '时间范围不能超过100年'}), 400

        account = UserCreditAccount.query.filter_by(user_id=current_user.id).first()
        if not account or account.balance < 100:
            return jsonify({'success': False, 'error': '积分不足,需要100积分'}), 400

        account.balance -= 100

        topic = TimeCapsuleTopic(
            user_id=current_user.id,
            title=title,
            description=description,
            encrypted_content=encrypted_content,
            encryption_key=encryption_key,
            start_year=start_year,
            end_year=end_year,
            total_pool=100  # the creation fee seeds the pool
        )
        db.session.add(topic)
        db.session.flush()  # assigns topic.id for the slots below

        # One single-year slot per year in the range.
        for year in range(start_year, end_year + 1):
            db.session.add(TimeCapsuleTimeSlot(
                topic_id=topic.id,
                year_start=year,
                year_end=year
            ))

        transaction = CreditTransaction(
            user_id=current_user.id,
            transaction_type='time_capsule_create',
            amount=-100,
            balance_after=account.balance,
            description=f'创建时间胶囊话题 #{topic.id}'
        )
        db.session.add(transaction)

        db.session.commit()

        return jsonify({
            'success': True,
            'data': {
                'topic_id': topic.id,
                'title': topic.title,
                'new_balance': account.balance
            }
        })

    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500


@app.route('/api/time-capsule/topics', methods=['GET'])
def get_time_capsule_topics():
    """List time-capsule topics with the given ``status`` (default 'active'), newest first."""
    try:
        status = request.args.get('status', 'active')

        topics = TimeCapsuleTopic.query.filter_by(status=status).order_by(
            TimeCapsuleTopic.created_at.desc()
        ).all()

        result = []
        for topic in topics:
            user = User.query.get(topic.user_id)

            slots = TimeCapsuleTimeSlot.query.filter_by(topic_id=topic.id).all()
            active_slots = sum(1 for s in slots if s.status == 'active')

            result.append({
                'id': topic.id,
                'title': topic.title,
                'description': topic.description,
                'start_year': topic.start_year,
                'end_year': topic.end_year,
                'total_pool': topic.total_pool,
                'total_slots': len(slots),
                'active_slots': active_slots,
                'is_decrypted': topic.is_decrypted,
                'status': topic.status,
                'author_id': topic.user_id,
                'author_name': user.username if user else '未知用户',
                'author_avatar': user.avatar_url if user else None,
                'created_at': topic.created_at.strftime('%Y-%m-%d %H:%M:%S')
            })

        return jsonify({
            'success': True,
            'data': result
        })

    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500


@app.route('/api/time-capsule/topics/<int:topic_id>', methods=['GET'])
def get_time_capsule_topic(topic_id):
    """Return one time-capsule topic with all of its time slots.

    The stored content is only included once the topic is decrypted.
    """
    try:
        topic = db.session.get(TimeCapsuleTopic, topic_id)
        if topic is None:
            return jsonify({'success': False, 'error': '话题不存在'}), 404

        user = User.query.get(topic.user_id)

        slots = TimeCapsuleTimeSlot.query.filter_by(topic_id=topic_id).order_by(
            TimeCapsuleTimeSlot.year_start
        ).all()

        slots_data = []
        for slot in slots:
            holder = User.query.get(slot.current_holder_id) if slot.current_holder_id else None
            slots_data.append({
                'id': slot.id,
                'year_start': slot.year_start,
                'year_end': slot.year_end,
                'current_price': slot.current_price,
                'total_bids': slot.total_bids,
                'status': slot.status,
                'current_holder_id': slot.current_holder_id,
                'current_holder_name': holder.username if holder else None,
                'current_holder_avatar': holder.avatar_url if holder else None
            })

        result = {
            'id': topic.id,
            'title': topic.title,
            'description': topic.description,
            'start_year': topic.start_year,
            'end_year': topic.end_year,
            'total_pool': topic.total_pool,
            'is_decrypted': topic.is_decrypted,
            'decrypted_content': topic.encrypted_content if topic.is_decrypted else None,
            'actual_happened_year': topic.actual_happened_year,
            'status': topic.status,
            'author_id': topic.user_id,
            'author_name': user.username if user else '未知用户',
            'author_avatar': user.avatar_url if user else None,
            'time_slots': slots_data,
            'created_at': topic.created_at.strftime('%Y-%m-%d %H:%M:%S')
        }

        return jsonify({
            'success': True,
            'data': result
        })

    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500


@app.route('/api/time-capsule/slots/<int:slot_id>/bid', methods=['POST'])
@login_required
def bid_time_slot(slot_id):
    """Outbid the current holder of a time slot.

    The bid must be at least 50 credits above the slot's current price. The
    previous holder is refunded their price, and the topic pool grows by the
    difference between this bid and the price it replaces (100 if unheld).
    """
    try:
        data = request.get_json(silent=True) or {}
        bid_amount = data.get('bid_amount')

        # A missing / non-integer amount used to crash the comparison below (500).
        if isinstance(bid_amount, bool) or not isinstance(bid_amount, int):
            return jsonify({'success': False, 'error': '无效的出价'}), 400

        slot = db.session.get(TimeCapsuleTimeSlot, slot_id)
        if slot is None:
            return jsonify({'success': False, 'error': '时间段不存在'}), 404

        if slot.status != 'active':
            return jsonify({'success': False, 'error': '该时间段已结束竞拍'}), 400

        min_bid = slot.current_price + 50  # must beat the current price by 50
        if bid_amount < min_bid:
            return jsonify({
                'success': False,
                'error': f'出价必须至少为 {min_bid} 积分'
            }), 400

        account = UserCreditAccount.query.filter_by(user_id=current_user.id).first()
        if not account or account.balance < bid_amount:
            return jsonify({'success': False, 'error': '积分不足'}), 400

        # Capture the pre-bid state: the pool increase used to be computed after
        # these fields were overwritten, which made it always 0.
        previous_holder_id = slot.current_holder_id
        previous_price = slot.current_price

        account.balance -= bid_amount

        # Refund the previous holder and close out their bid.
        if previous_holder_id:
            prev_holder_account = UserCreditAccount.query.filter_by(user_id=previous_holder_id).first()
            if prev_holder_account:
                prev_holder_account.balance += previous_price

            prev_bid = TimeSlotBid.query.filter_by(
                slot_id=slot_id,
                user_id=previous_holder_id,
                status='holding'
            ).first()
            if prev_bid:
                prev_bid.status = 'outbid'

        bid = TimeSlotBid(
            slot_id=slot_id,
            user_id=current_user.id,
            bid_amount=bid_amount,
            status='holding'
        )
        db.session.add(bid)

        slot.current_holder_id = current_user.id
        slot.current_price = bid_amount
        slot.total_bids += 1

        topic = db.session.get(TimeCapsuleTopic, slot.topic_id)
        price_increase = bid_amount - (previous_price if previous_holder_id else 100)
        topic.total_pool += price_increase

        transaction = CreditTransaction(
            user_id=current_user.id,
            transaction_type='time_slot_bid',
            amount=-bid_amount,
            balance_after=account.balance,
            description=f'竞拍时间段 {slot.year_start}-{slot.year_end}'
        )
        db.session.add(transaction)

        db.session.commit()

        return jsonify({
            'success': True,
            'data': {
                'slot_id': slot_id,
                'bid_amount': bid_amount,
                'new_balance': account.balance,
                'total_pool': topic.total_pool
            }
        })

    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500


@app.route('/api/time-capsule/topics/<int:topic_id>/decrypt', methods=['POST'])
@login_required
def decrypt_time_capsule(topic_id):
    """Mark a topic as decrypted (admin or author only) and return its key.

    Decryption itself happens on the client using the returned key.
    """
    try:
        topic = db.session.get(TimeCapsuleTopic, topic_id)
        if topic is None:
            return jsonify({'success': False, 'error': '话题不存在'}), 404

        # Only the administrator (user 1) or the author may decrypt.
        if current_user.id != 1 and current_user.id != topic.user_id:
            return jsonify({'success': False, 'error': '无权限操作'}), 403

        if topic.is_decrypted:
            return jsonify({'success': False, 'error': '该话题已解密'}), 400

        topic.is_decrypted = True
        db.session.commit()

        return jsonify({
            'success': True,
            'data': {
                'encrypted_content': topic.encrypted_content,
                'encryption_key': topic.encryption_key
            }
        })

    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500


@app.route('/api/time-capsule/topics/<int:topic_id>/settle', methods=['POST'])
@login_required
def settle_time_capsule(topic_id):
    """Admin only: settle a time-capsule topic for the year the event happened.

    The holder of that year's slot wins the whole pool. That slot becomes
    'won'; every other slot (including an unheld slot for the winning year)
    becomes 'expired'.
    """
    try:
        if current_user.id != 1:
            return jsonify({'success': False, 'error': '无权限操作'}), 403

        data = request.get_json(silent=True) or {}
        happened_year = data.get('happened_year')

        topic = db.session.get(TimeCapsuleTopic, topic_id)
        if topic is None:
            return jsonify({'success': False, 'error': '话题不存在'}), 404

        if topic.status == 'settled':
            return jsonify({'success': False, 'error': '该话题已结算'}), 400

        topic.status = 'settled'
        topic.actual_happened_year = happened_year

        winning_slot = TimeCapsuleTimeSlot.query.filter_by(
            topic_id=topic_id,
            year_start=happened_year
        ).first()

        won_slot_id = None
        if winning_slot and winning_slot.current_holder_id:
            # The winner takes the entire pool.
            winner_account = UserCreditAccount.query.filter_by(user_id=winning_slot.current_holder_id).first()
            if winner_account:
                winner_account.balance += topic.total_pool

                transaction = CreditTransaction(
                    user_id=winning_slot.current_holder_id,
                    transaction_type='time_capsule_win',
                    amount=topic.total_pool,
                    balance_after=winner_account.balance,
                    description=f'时间胶囊中奖 #{topic_id}'
                )
                db.session.add(transaction)

            winning_bid = TimeSlotBid.query.filter_by(
                slot_id=winning_slot.id,
                user_id=winning_slot.current_holder_id,
                status='holding'
            ).first()
            if winning_bid:
                winning_bid.status = 'won'

            winning_slot.status = 'won'
            won_slot_id = winning_slot.id

        # Expire every slot that was not won. An unheld slot for the winning year
        # used to be skipped here and stayed 'active', still accepting bids.
        other_slots_query = TimeCapsuleTimeSlot.query.filter(
            TimeCapsuleTimeSlot.topic_id == topic_id
        )
        if won_slot_id is not None:
            other_slots_query = other_slots_query.filter(TimeCapsuleTimeSlot.id != won_slot_id)

        for slot in other_slots_query.all():
            slot.status = 'expired'

        db.session.commit()

        return jsonify({
            'success': True,
            'data': {
                'topic_id': topic_id,
                'happened_year': happened_year,
                'winner_id': winning_slot.current_holder_id if winning_slot else None,
                'prize': topic.total_pool
            }
        })

    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500


if __name__ == '__main__':
    # Create the database tables.
    with app.app_context():
        try:
            db.create_all()

            # Safely initialise the subscription plans.
            initialize_subscription_plans_safe()
        except Exception as e:
            app.logger.error(f"数据库初始化失败: {e}")

    # Start the event polling mechanism (WebSocket push).
    initialize_event_polling()

    # socketio.run instead of app.run so that WebSocket is supported.
    socketio.run(app, host='0.0.0.0', port=5001, debug=False, allow_unsafe_werkzeug=True)