diff --git a/__pycache__/app.cpython-310.pyc b/__pycache__/app.cpython-310.pyc deleted file mode 100644 index 31faf388..00000000 Binary files a/__pycache__/app.cpython-310.pyc and /dev/null differ diff --git a/__pycache__/config.cpython-311.pyc b/__pycache__/config.cpython-311.pyc deleted file mode 100755 index c1e8c3f8..00000000 Binary files a/__pycache__/config.cpython-311.pyc and /dev/null differ diff --git a/__pycache__/mcp_server.cpython-310.pyc b/__pycache__/mcp_server.cpython-310.pyc deleted file mode 100644 index ef714330..00000000 Binary files a/__pycache__/mcp_server.cpython-310.pyc and /dev/null differ diff --git a/__pycache__/wechat_pay.cpython-310.pyc b/__pycache__/wechat_pay.cpython-310.pyc deleted file mode 100644 index 2a421f05..00000000 Binary files a/__pycache__/wechat_pay.cpython-310.pyc and /dev/null differ diff --git a/__pycache__/wechat_pay_config.cpython-310.pyc b/__pycache__/wechat_pay_config.cpython-310.pyc deleted file mode 100644 index cabda701..00000000 Binary files a/__pycache__/wechat_pay_config.cpython-310.pyc and /dev/null differ diff --git a/app.py b/app.py index b5922c1a..e681df2f 100755 --- a/app.py +++ b/app.py @@ -68,6 +68,17 @@ def load_trading_days(): print(f"加载交易日数据失败: {e}") +def row_to_dict(row): + """ + 将 SQLAlchemy Row 对象转换为字典 + 兼容 SQLAlchemy 1.4+ 版本 + """ + if row is None: + return None + # 使用 _mapping 属性来访问列数据 + return dict(row._mapping) + + def get_trading_day_near_date(target_date): """ 获取距离目标日期最近的交易日 @@ -5642,7 +5653,8 @@ def get_stock_basic_info(stock_code): # 转换为字典 basic_info = {} - for key, value in zip(result.keys(), result): + result_dict = row_to_dict(result) + for key, value in result_dict.items(): if isinstance(value, datetime): basic_info[key] = value.strftime('%Y-%m-%d') elif isinstance(value, Decimal): @@ -5685,7 +5697,7 @@ def get_stock_announcements(stock_code): announcements = [] for row in result: announcement = {} - for key, value in zip(row.keys(), row): + for key, value in row_to_dict(row).items(): if value is None: announcement[key] = None elif isinstance(value, datetime): @@ -5734,7 +5746,7 @@ def get_stock_disclosure_schedule(stock_code): schedules = [] for row in result: schedule = {} - for key, value in zip(row.keys(), row): + for key, value in row_to_dict(row).items(): if value is None: schedule[key] = None elif isinstance(value, datetime): @@ -5815,7 +5827,7 @@ def get_stock_actual_control(stock_code): control_info = [] for row in result: control_record = {} - for key, value in zip(row.keys(), row): + for key, value in row_to_dict(row).items(): if value is None: control_record[key] = None elif isinstance(value, datetime): @@ -5864,7 +5876,7 @@ def get_stock_concentration(stock_code): concentration_info = [] for row in result: concentration_record = {} - for key, value in zip(row.keys(), row): + for key, value in row_to_dict(row).items(): if value is None: concentration_record[key] = None elif isinstance(value, datetime): @@ -5933,7 +5945,7 @@ def get_stock_management(stock_code): management_info = [] for row in result: management_record = {} - for key, value in zip(row.keys(), row): + for key, value in row_to_dict(row).items(): if value is None: management_record[key] = None elif isinstance(value, datetime): @@ -5992,7 +6004,7 @@ def get_stock_top_circulation_shareholders(stock_code): shareholders_info = [] for row in result: shareholder_record = {} - for key, value in zip(row.keys(), row): + for key, value in row_to_dict(row).items(): if value is None: shareholder_record[key] = None elif isinstance(value, datetime): @@ -6051,7 +6063,7 @@ def get_stock_top_shareholders(stock_code): shareholders_info = [] for row in result: shareholder_record = {} - for key, value in zip(row.keys(), row): + for key, value in row_to_dict(row).items(): if value is None: shareholder_record[key] = None elif isinstance(value, datetime): @@ -6102,7 +6114,7 @@ def get_stock_branches(stock_code): branches_info = [] for row in result: branch_record = {} - for key, value in zip(row.keys(), row): + for key, value in row_to_dict(row).items(): if value is None: branch_record[key] = None elif isinstance(value, datetime): @@ -6169,7 +6181,7 @@ def get_stock_patents(stock_code): patents_info = [] for row in result: patent_record = {} - for key, value in zip(row.keys(), row): + for key, value in row_to_dict(row).items(): if value is None: patent_record[key] = None elif isinstance(value, datetime): @@ -8644,7 +8656,8 @@ def get_stock_info(seccode): ORDER BY a.ENDDATE DESC LIMIT 1 """) - result = engine.execute(query, seccode=seccode).fetchone() + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode}).fetchone() if not result: return jsonify({ @@ -8667,7 +8680,8 @@ def get_stock_info(seccode): ORDER BY F001D DESC LIMIT 1 """) - forecast_result = engine.execute(forecast_query, seccode=seccode).fetchone() + with engine.connect() as conn: + forecast_result = conn.execute(forecast_query, {'seccode': seccode}).fetchone() data = { 'stock_code': result.SECCODE, @@ -8828,7 +8842,8 @@ def get_balance_sheet(seccode): ORDER BY ENDDATE DESC LIMIT :limit """) - result = engine.execute(query, seccode=seccode, limit=limit) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode, 'limit': limit}) data = [] for row in result: @@ -9018,7 +9033,8 @@ def get_income_statement(seccode): ORDER BY ENDDATE DESC LIMIT :limit """) - result = engine.execute(query, seccode=seccode, limit=limit) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode, 'limit': limit}) data = [] for row in result: @@ -9227,7 +9243,8 @@ def get_cashflow(seccode): ORDER BY ENDDATE DESC LIMIT :limit """) - result = engine.execute(query, seccode=seccode, limit=limit) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode, 'limit': limit}) data = [] for row in result: @@ -9462,7 +9479,8 @@ def get_financial_metrics(seccode): ORDER BY ENDDATE DESC LIMIT :limit """) - result = engine.execute(query, seccode=seccode, limit=limit) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode, 'limit': limit}) data = [] for row in result: @@ -9602,7 +9620,8 @@ def get_main_business(seccode): ORDER BY ENDDATE DESC LIMIT :limit """) - periods = engine.execute(period_query, seccode=seccode, limit=limit).fetchall() + with engine.connect() as conn: + periods = conn.execute(period_query, {'seccode': seccode, 'limit': limit}).fetchall() # 产品分类数据 product_data = [] @@ -9620,7 +9639,8 @@ def get_main_business(seccode): ORDER BY F005N DESC """) - result = engine.execute(query, seccode=seccode, enddate=period[0]) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode, 'enddate': period[0]}) # Convert result to list to allow multiple iterations rows = list(result) @@ -9669,7 +9689,8 @@ def get_main_business(seccode): ORDER BY F007N DESC """) - result = engine.execute(query, seccode=seccode, enddate=period[0]) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode, 'enddate': period[0]}) # Convert result to list to allow multiple iterations rows = list(result) @@ -9730,7 +9751,8 @@ def get_forecast(seccode): ORDER BY F001D DESC, UPDATE_DATE DESC LIMIT 10 """) - forecast_result = engine.execute(forecast_query, seccode=seccode) + with engine.connect() as conn: + forecast_result = conn.execute(forecast_query, {'seccode': seccode}) forecast_data = [] for row in forecast_result: @@ -9771,7 +9793,8 @@ def get_forecast(seccode): ORDER BY F001D DESC LIMIT 8 """) - pretime_result = engine.execute(pretime_query, seccode=seccode) + with engine.connect() as conn: + pretime_result = conn.execute(pretime_query, {'seccode': seccode}) pretime_data = [] for row in pretime_result: @@ -9870,7 +9893,8 @@ def get_industry_rank(seccode): """) # 获取多个报告期的数据 - result = engine.execute(query, seccode=seccode, limit_total=limit * 4) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode, 'limit_total': limit * 4}) # 按报告期和行业级别组织数据 data_by_period = {} @@ -9990,7 +10014,8 @@ def get_period_comparison(seccode): ORDER BY fi.ENDDATE DESC LIMIT :periods """) - result = engine.execute(query, seccode=seccode, periods=periods) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode, 'periods': periods}) data = [] for row in result: @@ -10114,7 +10139,8 @@ def get_trade_data(seccode): LIMIT :days """) - result = engine.execute(query, seccode=seccode, end_date=end_date, days=days) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode, 'end_date': end_date, 'days': days}) data = [] for row in result: @@ -10190,7 +10216,8 @@ def get_funding_data(seccode): ORDER BY TRADEDATE DESC LIMIT :days """) - result = engine.execute(query, seccode=seccode, days=days) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode, 'days': days}) data = [] for row in result: @@ -10248,7 +10275,8 @@ def get_bigdeal_data(seccode): ORDER BY TRADEDATE DESC, F007N LIMIT :days """) - result = engine.execute(query, seccode=seccode, days=days) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode, 'days': days}) data = [] for row in result: @@ -10322,7 +10350,8 @@ def get_unusual_data(seccode): ORDER BY TRADEDATE DESC, F004N LIMIT 100 """) - result = engine.execute(query, seccode=seccode) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode}) data = [] for row in result: @@ -10400,7 +10429,8 @@ def get_pledge_data(seccode): ORDER BY ENDDATE DESC LIMIT 12 """) - result = engine.execute(query, seccode=seccode) + with engine.connect() as conn: + result = conn.execute(query, {'seccode': seccode}) data = [] for row in result: @@ -10457,9 +10487,12 @@ def get_market_summary(seccode): ORDER BY ENDDATE DESC LIMIT 1 """) - trade_result = engine.execute(trade_query, seccode=seccode).fetchone() - funding_result = engine.execute(funding_query, seccode=seccode).fetchone() - pledge_result = engine.execute(pledge_query, seccode=seccode).fetchone() + with engine.connect() as conn: + trade_result = conn.execute(trade_query, {'seccode': seccode}).fetchone() + with engine.connect() as conn: + funding_result = conn.execute(funding_query, {'seccode': seccode}).fetchone() + with engine.connect() as conn: + pledge_result = conn.execute(pledge_query, {'seccode': seccode}).fetchone() summary = { 'stock_code': seccode, @@ -10954,7 +10987,8 @@ def get_rise_analysis(seccode): ORDER BY trade_date DESC LIMIT 100 """) - result = engine.execute(query, **params).fetchall() + with engine.connect() as conn: + result = conn.execute(query, params).fetchall() # 格式化数据 rise_analysis_data = [] @@ -11016,7 +11050,8 @@ def get_comprehensive_analysis(company_code): WHERE company_code = :company_code """) - qualitative_result = engine.execute(qualitative_query, company_code=company_code).fetchone() + with engine.connect() as conn: + qualitative_result = conn.execute(qualitative_query, {'company_code': company_code}).fetchone() # 获取业务板块分析 segments_query = text(""" @@ -11033,7 +11068,8 @@ def get_comprehensive_analysis(company_code): ORDER BY created_at DESC """) - segments_result = engine.execute(segments_query, company_code=company_code).fetchall() + with engine.connect() as conn: + segments_result = conn.execute(segments_query, {'company_code': company_code}).fetchall() # 获取竞争地位数据 - 最新一期 competitive_query = text(""" @@ -11058,7 +11094,8 @@ def get_comprehensive_analysis(company_code): ORDER BY report_period DESC LIMIT 1 """) - competitive_result = engine.execute(competitive_query, company_code=company_code).fetchone() + with engine.connect() as conn: + competitive_result = conn.execute(competitive_query, {'company_code': company_code}).fetchone() # 获取业务结构数据 - 最新一期 business_structure_query = text(""" @@ -11085,7 +11122,8 @@ def get_comprehensive_analysis(company_code): ORDER BY revenue_ratio DESC """) - business_structure_result = engine.execute(business_structure_query, company_code=company_code).fetchall() + with engine.connect() as conn: + business_structure_result = conn.execute(business_structure_query, {'company_code': company_code}).fetchall() # 构建返回数据 response_data = { @@ -11222,7 +11260,8 @@ def get_value_chain_analysis(company_code): ORDER BY node_level ASC, importance_score DESC """) - nodes_result = engine.execute(nodes_query, company_code=company_code).fetchall() + with engine.connect() as conn: + nodes_result = conn.execute(nodes_query, {'company_code': company_code}).fetchall() # 获取产业链流向数据 flows_query = text(""" @@ -11242,7 +11281,8 @@ def get_value_chain_analysis(company_code): ORDER BY flow_ratio DESC """) - flows_result = engine.execute(flows_query, company_code=company_code).fetchall() + with engine.connect() as conn: + flows_result = conn.execute(flows_query, {'company_code': company_code}).fetchall() # 构建节点数据结构 nodes_by_level = {} @@ -11352,7 +11392,8 @@ def get_key_factors_timeline(company_code): ORDER BY display_order ASC, created_at ASC """) - categories_result = engine.execute(categories_query, company_code=company_code).fetchall() + with engine.connect() as conn: + categories_result = conn.execute(categories_query, {'company_code': company_code}).fetchall() # 获取关键因素详情 factors_query = text(""" @@ -11417,7 +11458,8 @@ def get_key_factors_timeline(company_code): ORDER BY kf.report_period DESC, kf.impact_weight DESC, kf.updated_at DESC """) - factors_result = engine.execute(factors_query, **params).fetchall() + with engine.connect() as conn: + factors_result = conn.execute(factors_query, params).fetchall() # 获取发展时间线事件 timeline_query = text(""" @@ -11436,9 +11478,8 @@ def get_key_factors_timeline(company_code): ORDER BY event_date DESC LIMIT :limit """) - timeline_result = engine.execute(timeline_query, - company_code=company_code, - limit=event_limit).fetchall() + with engine.connect() as conn: + timeline_result = conn.execute(timeline_query, {'company_code': company_code, 'limit': event_limit}).fetchall() # 构建关键因素数据结构 key_factors_data = {} diff --git a/app.py.backup b/app.py.backup new file mode 100644 index 00000000..b5922c1a --- /dev/null +++ b/app.py.backup @@ -0,0 +1,12556 @@ +import base64 +import csv +import io +import os +import time +import urllib +import uuid +from functools import wraps +import qrcode +from flask_mail import Mail, Message +from flask_socketio import SocketIO, emit, join_room, leave_room +import pytz +import requests +from celery import Celery +from flask_compress import Compress +from pathlib import Path +import json +from sqlalchemy import Column, Integer, String, Boolean, DateTime, create_engine, text, func, or_ +from flask import Flask, render_template, request, jsonify, redirect, url_for, flash, session, render_template_string, \ + current_app, make_response +from flask_sqlalchemy import SQLAlchemy +from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user +import random +from werkzeug.security import generate_password_hash, check_password_hash +import re +import string +from datetime import datetime, timedelta, time as dt_time, date +from clickhouse_driver import Client as Cclient +from flask_cors import CORS + +from collections import defaultdict +from functools import lru_cache +import jieba +import jieba.analyse +from flask_cors import cross_origin +from tencentcloud.common import credential +from tencentcloud.common.profile.client_profile import ClientProfile +from tencentcloud.common.profile.http_profile import HttpProfile +from tencentcloud.sms.v20210111 import sms_client, models +from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException +from sqlalchemy import text, desc, and_ +import pandas as pd +from decimal import Decimal +from apscheduler.schedulers.background import BackgroundScheduler + +# 交易日数据缓存 +trading_days = [] +trading_days_set = set() + + +def load_trading_days(): + """加载交易日数据""" + global trading_days, trading_days_set + try: + with open('tdays.csv', 'r') as f: + reader = csv.DictReader(f) + for row in reader: + date_str = row['DateTime'] + # 解析日期 (格式: 2010/1/4) + date = datetime.strptime(date_str, '%Y/%m/%d').date() + trading_days.append(date) + trading_days_set.add(date) + + # 排序交易日 + trading_days.sort() + print(f"成功加载 {len(trading_days)} 个交易日数据") + except Exception as e: + print(f"加载交易日数据失败: {e}") + + +def get_trading_day_near_date(target_date): + """ + 获取距离目标日期最近的交易日 + 如果目标日期是交易日,返回该日期 + 如果不是,返回下一个交易日 + """ + if not trading_days: + load_trading_days() + + if not trading_days: + return None + + # 如果目标日期是datetime,转换为date + if isinstance(target_date, datetime): + target_date = target_date.date() + + # 检查目标日期是否是交易日 + if target_date in trading_days_set: + return target_date + + # 查找下一个交易日 + for trading_day in trading_days: + if trading_day >= target_date: + return trading_day + + # 如果没有找到,返回最后一个交易日 + return trading_days[-1] if trading_days else None + + +# 应用启动时加载交易日数据 +load_trading_days() + +engine = create_engine( + "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock?charset=utf8mb4", + echo=False, + pool_size=10, + pool_recycle=3600, + pool_pre_ping=True, + pool_timeout=30, + max_overflow=20 +) +engine_med = create_engine( + "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/med?charset=utf8mb4", + echo=False, + pool_size=5, + pool_recycle=3600, + pool_pre_ping=True, + pool_timeout=30, + max_overflow=10 +) +engine_2 = create_engine( + "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/valuefrontier?charset=utf8mb4", + echo=False, + pool_size=5, + pool_recycle=3600, + pool_pre_ping=True, + pool_timeout=30, + max_overflow=10 +) +app = Flask(__name__) +# 存储验证码的临时字典(生产环境应使用Redis) +verification_codes = {} +wechat_qr_sessions = {} +# 腾讯云短信配置 +SMS_SECRET_ID = 'AKID2we9TacdTAhCjCSYTErHVimeJo9Yr00s' +SMS_SECRET_KEY = 'pMlBWijlkgT9fz5ziEXdWEnAPTJzRfkf' +SMS_SDK_APP_ID = "1400972398" +SMS_SIGN_NAME = "价值前沿科技" +SMS_TEMPLATE_REGISTER = "2386557" # 注册模板 +SMS_TEMPLATE_LOGIN = "2386540" # 登录模板 + +# 微信开放平台配置 +WECHAT_APPID = 'wxa8d74c47041b5f87' +WECHAT_APPSECRET = 'eedef95b11787fd7ca7f1acc6c9061bc' +WECHAT_REDIRECT_URI = 'http://valuefrontier.cn/api/auth/wechat/callback' + +# 邮件服务配置(QQ企业邮箱) +MAIL_SERVER = 'smtp.exmail.qq.com' +MAIL_PORT = 465 +MAIL_USE_SSL = True +MAIL_USE_TLS = False +MAIL_USERNAME = 'admin@valuefrontier.cn' +MAIL_PASSWORD = 'QYncRu6WUdASvTg4' +MAIL_DEFAULT_SENDER = 'admin@valuefrontier.cn' + +# Session和安全配置 +app.config['SECRET_KEY'] = ''.join(random.choices(string.ascii_letters + string.digits, k=32)) +app.config['SESSION_COOKIE_SECURE'] = False # 如果生产环境使用HTTPS,应设为True +app.config['SESSION_COOKIE_HTTPONLY'] = True # 生产环境应设为True,防止XSS攻击 +app.config['SESSION_COOKIE_SAMESITE'] = 'Lax' # 使用'Lax'以平衡安全性和功能性 +app.config['SESSION_COOKIE_DOMAIN'] = None # 不限制域名 +app.config['SESSION_COOKIE_PATH'] = '/' # 设置cookie路径 +app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(days=7) # session持续7天 +app.config['REMEMBER_COOKIE_DURATION'] = timedelta(days=30) # 记住登录30天 +app.config['REMEMBER_COOKIE_SECURE'] = False # 记住登录cookie不要求HTTPS +app.config['REMEMBER_COOKIE_HTTPONLY'] = False # 允许JavaScript访问 + +# 配置邮件 +app.config['MAIL_SERVER'] = MAIL_SERVER +app.config['MAIL_PORT'] = MAIL_PORT +app.config['MAIL_USE_SSL'] = MAIL_USE_SSL +app.config['MAIL_USE_TLS'] = MAIL_USE_TLS +app.config['MAIL_USERNAME'] = MAIL_USERNAME +app.config['MAIL_PASSWORD'] = MAIL_PASSWORD +app.config['MAIL_DEFAULT_SENDER'] = MAIL_DEFAULT_SENDER + +# 允许前端跨域访问 - 修复CORS配置 +try: + CORS(app, + origins=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5173", "https://valuefrontier.cn", + "http://valuefrontier.cn"], # 明确指定允许的源 + methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-Requested-With"], + supports_credentials=True, # 允许携带凭据 + expose_headers=["Content-Type", "Authorization"]) +except ImportError: + pass # 如果未安装flask_cors则跳过 + +# 初始化 Flask-Login +login_manager = LoginManager() +login_manager.init_app(app) +login_manager.login_view = 'login' +login_manager.login_message = '请先登录访问此页面' +login_manager.remember_cookie_duration = timedelta(days=30) # 记住登录持续时间 +Compress(app) +MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB max file size +# Configure Flask-Compress +app.config['COMPRESS_ALGORITHM'] = ['gzip', 'br'] +app.config['COMPRESS_MIMETYPES'] = [ + 'text/html', + 'text/css', + 'text/xml', + 'application/json', + 'application/javascript', + 'application/x-javascript' +] +app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock?charset=utf8mb4' +app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False +app.config['SQLALCHEMY_ENGINE_OPTIONS'] = { + 'pool_size': 10, + 'pool_recycle': 3600, + 'pool_pre_ping': True, + 'pool_timeout': 30, + 'max_overflow': 20 +} +# Cache directory setup +CACHE_DIR = Path('cache') +CACHE_DIR.mkdir(exist_ok=True) + + +def beijing_now(): + # 使用 pytz 处理时区,但返回 naive datetime(适合数据库存储) + beijing_tz = pytz.timezone('Asia/Shanghai') + return datetime.now(beijing_tz).replace(tzinfo=None) + + +# 检查用户是否登录的装饰器 +def login_required(f): + @wraps(f) + def decorated_function(*args, **kwargs): + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + return f(*args, **kwargs) + + return decorated_function + + +# Memory management constants +MAX_MEMORY_PERCENT = 75 +MEMORY_CHECK_INTERVAL = 300 +MAX_CACHE_ITEMS = 50 +db = SQLAlchemy(app) + +# 初始化邮件服务 +mail = Mail(app) + +# 初始化 Flask-SocketIO(用于实时事件推送) +socketio = SocketIO( + app, + cors_allowed_origins=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5173", + "https://valuefrontier.cn", "http://valuefrontier.cn"], + async_mode='gevent', + logger=True, + engineio_logger=False, + ping_timeout=120, # 心跳超时时间(秒),客户端120秒内无响应才断开 + ping_interval=25 # 心跳检测间隔(秒),每25秒发送一次ping +) + + +@login_manager.user_loader +def load_user(user_id): + """Flask-Login 用户加载回调""" + try: + return User.query.get(int(user_id)) + except Exception as e: + app.logger.error(f"用户加载错误: {e}") + return None + + +# 全局错误处理器 - 确保API接口始终返回JSON +@app.errorhandler(404) +def not_found_error(error): + """404错误处理""" + if request.path.startswith('/api/'): + return jsonify({'success': False, 'error': '接口不存在'}), 404 + return error + + +@app.errorhandler(500) +def internal_error(error): + """500错误处理""" + db.session.rollback() + if request.path.startswith('/api/'): + return jsonify({'success': False, 'error': '服务器内部错误'}), 500 + return error + + +@app.errorhandler(405) +def method_not_allowed_error(error): + """405错误处理""" + if request.path.startswith('/api/'): + return jsonify({'success': False, 'error': '请求方法不被允许'}), 405 + return error + + +class Post(db.Model): + """帖子模型""" + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + + # 内容 + title = db.Column(db.String(200)) # 标题(可选) + content = db.Column(db.Text, nullable=False) # 内容 + content_type = db.Column(db.String(20), default='text') # 内容类型:text/rich_text/link + + # 时间 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + # 统计 + likes_count = db.Column(db.Integer, default=0) + comments_count = db.Column(db.Integer, default=0) + view_count = db.Column(db.Integer, default=0) + + # 状态 + status = db.Column(db.String(20), default='active') # active/hidden/deleted + is_top = db.Column(db.Boolean, default=False) # 是否置顶 + + # 关系 + user = db.relationship('User', backref='posts') + likes = db.relationship('PostLike', backref='post', lazy='dynamic') + comments = db.relationship('Comment', backref='post', lazy='dynamic') + + +class Comment(db.Model): + """帖子评论模型""" + id = db.Column(db.Integer, primary_key=True) + post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + + # 内容 + content = db.Column(db.Text, nullable=False) + parent_id = db.Column(db.Integer, db.ForeignKey('comment.id')) + + # 时间 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + # 统计 + likes_count = db.Column(db.Integer, default=0) + + # 状态 + status = db.Column(db.String(20), default='active') # active/hidden/deleted + + # 关系 + user = db.relationship('User', backref='comments') + replies = db.relationship('Comment', backref=db.backref('parent', remote_side=[id])) + + +class User(UserMixin, db.Model): + """用户模型 - 完全匹配现有数据库表结构""" + __tablename__ = 'user' + + # 主键 + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + + # 基础账号信息 + username = db.Column(db.String(80), unique=True, nullable=False) + email = db.Column(db.String(120), unique=True, nullable=True) + password_hash = db.Column(db.String(255), nullable=True) + email_confirmed = db.Column(db.Boolean, nullable=True, default=True) + + # 时间字段 + created_at = db.Column(db.DateTime, nullable=True, default=beijing_now) + last_seen = db.Column(db.DateTime, nullable=True, default=beijing_now) + + # 账号状态 + status = db.Column(db.String(20), nullable=True, default='active') + + # 个人资料信息 + nickname = db.Column(db.String(30), nullable=True) + avatar_url = db.Column(db.String(200), nullable=True) + banner_url = db.Column(db.String(200), nullable=True) + bio = db.Column(db.String(200), nullable=True) + gender = db.Column(db.String(10), nullable=True) + birth_date = db.Column(db.Date, nullable=True) + location = db.Column(db.String(100), nullable=True) + + # 联系方式 + phone = db.Column(db.String(20), nullable=True) + wechat_id = db.Column(db.String(80), nullable=True) # 微信号 + + # 实名认证 + real_name = db.Column(db.String(30), nullable=True) + id_number = db.Column(db.String(18), nullable=True) + is_verified = db.Column(db.Boolean, nullable=True, default=False) + verify_time = db.Column(db.DateTime, nullable=True) + + # 投资偏好 + trading_experience = db.Column(db.String(200), nullable=True) + investment_style = db.Column(db.String(50), nullable=True) + risk_preference = db.Column(db.String(20), nullable=True) + investment_amount = db.Column(db.String(20), nullable=True) + preferred_markets = db.Column(db.String(200), nullable=True) + + # 社区数据 + user_level = db.Column(db.Integer, nullable=True, default=1) + reputation_score = db.Column(db.Integer, nullable=True, default=0) + contribution_point = db.Column(db.Integer, nullable=True, default=0) + post_count = db.Column(db.Integer, nullable=True, default=0) + comment_count = db.Column(db.Integer, nullable=True, default=0) + follower_count = db.Column(db.Integer, nullable=True, default=0) + following_count = db.Column(db.Integer, nullable=True, default=0) + + # 创作者相关 + is_creator = db.Column(db.Boolean, nullable=True, default=False) + creator_type = db.Column(db.String(20), nullable=True) + creator_tags = db.Column(db.String(200), nullable=True) + + # 通知设置 + email_notifications = db.Column(db.Boolean, nullable=True, default=True) + sms_notifications = db.Column(db.Boolean, nullable=True, default=False) + wechat_notifications = db.Column(db.Boolean, nullable=True, default=False) + notification_preferences = db.Column(db.String(500), nullable=True) + + # 隐私和界面设置 + privacy_level = db.Column(db.String(20), nullable=True, default='public') + theme_preference = db.Column(db.String(20), nullable=True, default='light') + blocked_keywords = db.Column(db.String(500), nullable=True) + + # 手机验证相关 + phone_confirmed = db.Column(db.Boolean, nullable=True, default=False) # 注意:原表中是blob,这里改为Boolean更合理 + phone_confirm_time = db.Column(db.DateTime, nullable=True) + + # 微信登录相关字段 + wechat_union_id = db.Column(db.String(100), nullable=True) # 微信UnionID + wechat_open_id = db.Column(db.String(100), nullable=True) # 微信OpenID + + def __init__(self, username, email=None, password=None, phone=None): + """初始化用户""" + self.username = username + if email: + self.email = email + if phone: + self.phone = phone + if password: + self.set_password(password) + self.nickname = username # 默认昵称为用户名 + self.created_at = beijing_now() + self.last_seen = beijing_now() + + def set_password(self, password): + """设置密码""" + if password: + self.password_hash = generate_password_hash(password) + + def check_password(self, password): + """验证密码""" + if not password or not self.password_hash: + return False + return check_password_hash(self.password_hash, password) + + def update_last_seen(self): + """更新最后活跃时间""" + self.last_seen = beijing_now() + db.session.commit() + + def confirm_email(self): + """确认邮箱""" + self.email_confirmed = True + db.session.commit() + + def confirm_phone(self): + """确认手机号""" + self.phone_confirmed = True + self.phone_confirm_time = beijing_now() + db.session.commit() + + def bind_wechat(self, open_id, union_id=None, wechat_info=None): + """绑定微信账号""" + self.wechat_open_id = open_id + if union_id: + self.wechat_union_id = union_id + + # 如果提供了微信用户信息,更新头像和昵称 + if wechat_info: + if not self.avatar_url and wechat_info.get('headimgurl'): + self.avatar_url = wechat_info['headimgurl'] + if not self.nickname and wechat_info.get('nickname'): + # 确保昵称编码正确且长度合理 + nickname = self._sanitize_nickname(wechat_info['nickname']) + self.nickname = nickname + + db.session.commit() + + def _sanitize_nickname(self, nickname): + """清理和验证昵称""" + if not nickname: + return '微信用户' + + try: + # 确保是正确的UTF-8字符串 + sanitized = str(nickname).strip() + + # 移除可能的控制字符 + import re + sanitized = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', sanitized) + + # 限制长度(避免过长的昵称) + if len(sanitized) > 50: + sanitized = sanitized[:47] + '...' + + # 如果清理后为空,使用默认值 + if not sanitized: + sanitized = '微信用户' + + return sanitized + except Exception as e: + return '微信用户' + + def unbind_wechat(self): + """解绑微信账号""" + self.wechat_open_id = None + self.wechat_union_id = None + db.session.commit() + + def increment_post_count(self): + """增加发帖数""" + self.post_count = (self.post_count or 0) + 1 + db.session.commit() + + def increment_comment_count(self): + """增加评论数""" + self.comment_count = (self.comment_count or 0) + 1 + db.session.commit() + + def add_reputation(self, points): + """增加声誉分数""" + self.reputation_score = (self.reputation_score or 0) + points + db.session.commit() + + def to_dict(self, include_sensitive=False): + """转换为字典""" + data = { + 'id': self.id, + 'username': self.username, + 'nickname': self.nickname or self.username, + 'avatar_url': self.avatar_url, + 'banner_url': self.banner_url, + 'bio': self.bio, + 'gender': self.gender, + 'location': self.location, + 'user_level': self.user_level or 1, + 'reputation_score': self.reputation_score or 0, + 'contribution_point': self.contribution_point or 0, + 'post_count': self.post_count or 0, + 'comment_count': self.comment_count or 0, + 'follower_count': self.follower_count or 0, + 'following_count': self.following_count or 0, + 'is_creator': self.is_creator or False, + 'creator_type': self.creator_type, + 'creator_tags': self.creator_tags, + 'is_verified': self.is_verified or False, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'last_seen': self.last_seen.isoformat() if self.last_seen else None, + 'status': self.status, + 'has_wechat': bool(self.wechat_open_id), + 'is_authenticated': True + } + + # 敏感信息只在需要时包含 + if include_sensitive: + data.update({ + 'email': self.email, + 'phone': self.phone, + 'email_confirmed': self.email_confirmed, + 'phone_confirmed': self.phone_confirmed, + 'real_name': self.real_name, + 'birth_date': self.birth_date.isoformat() if self.birth_date else None, + 'trading_experience': self.trading_experience, + 'investment_style': self.investment_style, + 'risk_preference': self.risk_preference, + 'investment_amount': self.investment_amount, + 'preferred_markets': self.preferred_markets, + 'email_notifications': self.email_notifications, + 'sms_notifications': self.sms_notifications, + 'wechat_notifications': self.wechat_notifications, + 'privacy_level': self.privacy_level, + 'theme_preference': self.theme_preference + }) + + return data + + def to_public_dict(self): + """公开信息字典(用于显示给其他用户)""" + return { + 'id': self.id, + 'username': self.username, + 'nickname': self.nickname or self.username, + 'avatar_url': self.avatar_url, + 'bio': self.bio, + 'user_level': self.user_level or 1, + 'reputation_score': self.reputation_score or 0, + 'post_count': self.post_count or 0, + 'follower_count': self.follower_count or 0, + 'is_creator': self.is_creator or False, + 'creator_type': self.creator_type, + 'is_verified': self.is_verified or False, + 'created_at': self.created_at.isoformat() if self.created_at else None + } + + @staticmethod + def find_by_login_info(login_info): + """根据登录信息查找用户(支持用户名、邮箱、手机号)""" + return User.query.filter( + db.or_( + User.username == login_info, + User.email == login_info, + User.phone == login_info + ) + ).first() + + @staticmethod + def find_by_wechat_openid(open_id): + """根据微信OpenID查找用户""" + return User.query.filter_by(wechat_open_id=open_id).first() + + @staticmethod + def find_by_wechat_unionid(union_id): + """根据微信UnionID查找用户""" + return User.query.filter_by(wechat_union_id=union_id).first() + + @staticmethod + def is_username_taken(username): + """检查用户名是否已被使用""" + return User.query.filter_by(username=username).first() is not None + + @staticmethod + def is_email_taken(email): + """检查邮箱是否已被使用""" + return User.query.filter_by(email=email).first() is not None + + @staticmethod + def is_phone_taken(phone): + """检查手机号是否已被使用""" + return User.query.filter_by(phone=phone).first() is not None + + def __repr__(self): + return f'' + + +# ============================================ +# 订阅功能模块(安全版本 - 独立表) +# ============================================ +class UserSubscription(db.Model): + """用户订阅表 - 独立于现有User表""" + __tablename__ = 'user_subscriptions' + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + user_id = db.Column(db.Integer, nullable=False, unique=True, index=True) + subscription_type = db.Column(db.String(10), nullable=False, default='free') + subscription_status = db.Column(db.String(20), nullable=False, default='active') + start_date = db.Column(db.DateTime, nullable=True) + end_date = db.Column(db.DateTime, nullable=True) + billing_cycle = db.Column(db.String(10), nullable=True) + auto_renewal = db.Column(db.Boolean, nullable=False, default=False) + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + def is_active(self): + if self.subscription_status != 'active': + return False + if self.subscription_type == 'free': + return True + if self.end_date: + try: + now = beijing_now() + if self.end_date < now: + return False + except Exception as e: + return False + return True + + def days_left(self): + if self.subscription_type == 'free' or not self.end_date: + return 999 + try: + now = beijing_now() + delta = self.end_date - now + return max(0, delta.days) + except Exception as e: + return 0 + + def to_dict(self): + return { + 'type': self.subscription_type, + 'status': self.subscription_status, + 'is_active': self.is_active(), + 'days_left': self.days_left(), + 'start_date': self.start_date.isoformat() if self.start_date else None, + 'end_date': self.end_date.isoformat() if self.end_date else None, + 'billing_cycle': self.billing_cycle, + 'auto_renewal': self.auto_renewal + } + + +class SubscriptionPlan(db.Model): + """订阅套餐表""" + __tablename__ = 'subscription_plans' + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + name = db.Column(db.String(50), nullable=False, unique=True) + display_name = db.Column(db.String(100), nullable=False) + description = db.Column(db.Text, nullable=True) + monthly_price = db.Column(db.Numeric(10, 2), nullable=False) + yearly_price = db.Column(db.Numeric(10, 2), nullable=False) + features = db.Column(db.Text, nullable=True) + pricing_options = db.Column(db.Text, nullable=True) # JSON格式:[{"months": 1, "price": 99}, {"months": 12, "price": 999}] + is_active = db.Column(db.Boolean, default=True) + sort_order = db.Column(db.Integer, default=0) + created_at = db.Column(db.DateTime, default=beijing_now) + + def to_dict(self): + # 解析pricing_options(如果存在) + pricing_opts = None + if self.pricing_options: + try: + pricing_opts = json.loads(self.pricing_options) + except: + pricing_opts = None + + # 如果没有pricing_options,则从monthly_price和yearly_price生成默认选项 + if not pricing_opts: + pricing_opts = [ + { + 'months': 1, + 'price': float(self.monthly_price) if self.monthly_price else 0, + 'label': '月付', + 'cycle_key': 'monthly' + }, + { + 'months': 12, + 'price': float(self.yearly_price) if self.yearly_price else 0, + 'label': '年付', + 'cycle_key': 'yearly', + 'discount_percent': 20 # 年付默认20%折扣 + } + ] + + return { + 'id': self.id, + 'name': self.name, + 'display_name': self.display_name, + 'description': self.description, + 'monthly_price': float(self.monthly_price) if self.monthly_price else 0, + 'yearly_price': float(self.yearly_price) if self.yearly_price else 0, + 'pricing_options': pricing_opts, # 新增:灵活计费周期选项 + 'features': json.loads(self.features) if self.features else [], + 'is_active': self.is_active, + 'sort_order': self.sort_order + } + + +class PaymentOrder(db.Model): + """支付订单表""" + __tablename__ = 'payment_orders' + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + order_no = db.Column(db.String(32), unique=True, nullable=False) + user_id = db.Column(db.Integer, nullable=False) + plan_name = db.Column(db.String(20), nullable=False) + billing_cycle = db.Column(db.String(10), nullable=False) + amount = db.Column(db.Numeric(10, 2), nullable=False) + wechat_order_id = db.Column(db.String(64), nullable=True) + prepay_id = db.Column(db.String(64), nullable=True) + qr_code_url = db.Column(db.String(200), nullable=True) + status = db.Column(db.String(20), default='pending') + created_at = db.Column(db.DateTime, default=beijing_now) + paid_at = db.Column(db.DateTime, nullable=True) + expired_at = db.Column(db.DateTime, nullable=True) + remark = db.Column(db.String(200), nullable=True) + + def __init__(self, user_id, plan_name, billing_cycle, amount): + self.user_id = user_id + self.plan_name = plan_name + self.billing_cycle = billing_cycle + self.amount = amount + import random + timestamp = int(beijing_now().timestamp() * 1000000) + random_suffix = random.randint(1000, 9999) + self.order_no = f"{timestamp}{user_id:04d}{random_suffix}" + self.expired_at = beijing_now() + timedelta(minutes=30) + + def is_expired(self): + if not self.expired_at: + return False + try: + now = beijing_now() + return now > self.expired_at + except Exception as e: + return False + + def mark_as_paid(self, wechat_order_id, transaction_id=None): + self.status = 'paid' + self.paid_at = beijing_now() + self.wechat_order_id = wechat_order_id + + def to_dict(self): + return { + 'id': self.id, + 'order_no': self.order_no, + 'user_id': self.user_id, + 'plan_name': self.plan_name, + 'billing_cycle': self.billing_cycle, + 'amount': float(self.amount) if self.amount else 0, + 'original_amount': float(self.original_amount) if hasattr(self, 'original_amount') and self.original_amount else None, + 'discount_amount': float(self.discount_amount) if hasattr(self, 'discount_amount') and self.discount_amount else 0, + 'promo_code': self.promo_code.code if hasattr(self, 'promo_code') and self.promo_code else None, + 'is_upgrade': self.is_upgrade if hasattr(self, 'is_upgrade') else False, + 'qr_code_url': self.qr_code_url, + 'status': self.status, + 'is_expired': self.is_expired(), + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'paid_at': self.paid_at.isoformat() if self.paid_at else None, + 'expired_at': self.expired_at.isoformat() if self.expired_at else None, + 'remark': self.remark + } + + +class PromoCode(db.Model): + """优惠码表""" + __tablename__ = 'promo_codes' + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + code = db.Column(db.String(50), unique=True, nullable=False, index=True) + description = db.Column(db.String(200), nullable=True) + + # 折扣类型和值 + discount_type = db.Column(db.String(20), nullable=False) # 'percentage' 或 'fixed_amount' + discount_value = db.Column(db.Numeric(10, 2), nullable=False) + + # 适用范围 + applicable_plans = db.Column(db.String(200), nullable=True) # JSON格式 + applicable_cycles = db.Column(db.String(50), nullable=True) # JSON格式 + min_amount = db.Column(db.Numeric(10, 2), nullable=True) + + # 使用限制 + max_uses = db.Column(db.Integer, nullable=True) + max_uses_per_user = db.Column(db.Integer, default=1) + current_uses = db.Column(db.Integer, default=0) + + # 有效期 + valid_from = db.Column(db.DateTime, nullable=False) + valid_until = db.Column(db.DateTime, nullable=False) + + # 状态 + is_active = db.Column(db.Boolean, default=True) + created_by = db.Column(db.Integer, nullable=True) + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + def to_dict(self): + return { + 'id': self.id, + 'code': self.code, + 'description': self.description, + 'discount_type': self.discount_type, + 'discount_value': float(self.discount_value) if self.discount_value else 0, + 'applicable_plans': json.loads(self.applicable_plans) if self.applicable_plans else None, + 'applicable_cycles': json.loads(self.applicable_cycles) if self.applicable_cycles else None, + 'min_amount': float(self.min_amount) if self.min_amount else None, + 'max_uses': self.max_uses, + 'max_uses_per_user': self.max_uses_per_user, + 'current_uses': self.current_uses, + 'valid_from': self.valid_from.isoformat() if self.valid_from else None, + 'valid_until': self.valid_until.isoformat() if self.valid_until else None, + 'is_active': self.is_active + } + + +class PromoCodeUsage(db.Model): + """优惠码使用记录表""" + __tablename__ = 'promo_code_usage' + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + promo_code_id = db.Column(db.Integer, db.ForeignKey('promo_codes.id'), nullable=False) + user_id = db.Column(db.Integer, nullable=False, index=True) + order_id = db.Column(db.Integer, db.ForeignKey('payment_orders.id'), nullable=False) + + original_amount = db.Column(db.Numeric(10, 2), nullable=False) + discount_amount = db.Column(db.Numeric(10, 2), nullable=False) + final_amount = db.Column(db.Numeric(10, 2), nullable=False) + + used_at = db.Column(db.DateTime, default=beijing_now) + + # 关系 + promo_code = db.relationship('PromoCode', backref='usages') + order = db.relationship('PaymentOrder', backref='promo_usage') + + +class SubscriptionUpgrade(db.Model): + """订阅升级/降级记录表""" + __tablename__ = 'subscription_upgrades' + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + user_id = db.Column(db.Integer, nullable=False, index=True) + order_id = db.Column(db.Integer, db.ForeignKey('payment_orders.id'), nullable=False) + + # 原订阅信息 + from_plan = db.Column(db.String(20), nullable=False) + from_cycle = db.Column(db.String(10), nullable=False) + from_end_date = db.Column(db.DateTime, nullable=True) + + # 新订阅信息 + to_plan = db.Column(db.String(20), nullable=False) + to_cycle = db.Column(db.String(10), nullable=False) + to_end_date = db.Column(db.DateTime, nullable=False) + + # 价格计算 + remaining_value = db.Column(db.Numeric(10, 2), nullable=False) + upgrade_amount = db.Column(db.Numeric(10, 2), nullable=False) + actual_amount = db.Column(db.Numeric(10, 2), nullable=False) + + upgrade_type = db.Column(db.String(20), nullable=False) # 'plan_upgrade', 'cycle_change', 'both' + created_at = db.Column(db.DateTime, default=beijing_now) + + # 关系 + order = db.relationship('PaymentOrder', backref='upgrade_record') + + +# ============================================ +# 模拟盘相关模型 +# ============================================ +class SimulationAccount(db.Model): + """模拟账户""" + __tablename__ = 'simulation_accounts' + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False, unique=True) + account_name = db.Column(db.String(100), default='我的模拟账户') + initial_capital = db.Column(db.Numeric(15, 2), default=1000000.00) # 初始资金 + available_cash = db.Column(db.Numeric(15, 2), default=1000000.00) # 可用资金 + frozen_cash = db.Column(db.Numeric(15, 2), default=0.00) # 冻结资金 + position_value = db.Column(db.Numeric(15, 2), default=0.00) # 持仓市值 + total_assets = db.Column(db.Numeric(15, 2), default=1000000.00) # 总资产 + total_profit = db.Column(db.Numeric(15, 2), default=0.00) # 总盈亏 + total_profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 总收益率 + daily_profit = db.Column(db.Numeric(15, 2), default=0.00) # 日盈亏 + daily_profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 日收益率 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + last_settlement_date = db.Column(db.Date) # 最后结算日期 + + # 关系 + user = db.relationship('User', backref='simulation_account') + positions = db.relationship('SimulationPosition', backref='account', lazy='dynamic') + orders = db.relationship('SimulationOrder', backref='account', lazy='dynamic') + transactions = db.relationship('SimulationTransaction', backref='account', lazy='dynamic') + + def calculate_total_assets(self): + """计算总资产""" + self.total_assets = self.available_cash + self.frozen_cash + self.position_value + self.total_profit = self.total_assets - self.initial_capital + self.total_profit_rate = (self.total_profit / self.initial_capital) * 100 if self.initial_capital > 0 else 0 + return self.total_assets + + +class SimulationPosition(db.Model): + """模拟持仓""" + __tablename__ = 'simulation_positions' + + id = db.Column(db.Integer, primary_key=True) + account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) + stock_code = db.Column(db.String(20), nullable=False) + stock_name = db.Column(db.String(100)) + position_qty = db.Column(db.Integer, default=0) # 持仓数量 + available_qty = db.Column(db.Integer, default=0) # 可用数量(T+1) + frozen_qty = db.Column(db.Integer, default=0) # 冻结数量 + avg_cost = db.Column(db.Numeric(10, 3), default=0.00) # 平均成本 + current_price = db.Column(db.Numeric(10, 3), default=0.00) # 当前价格 + market_value = db.Column(db.Numeric(15, 2), default=0.00) # 市值 + profit = db.Column(db.Numeric(15, 2), default=0.00) # 盈亏 + profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 盈亏比例 + today_profit = db.Column(db.Numeric(15, 2), default=0.00) # 今日盈亏 + today_profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 今日盈亏比例 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + __table_args__ = ( + db.UniqueConstraint('account_id', 'stock_code', name='unique_account_stock'), + ) + + def update_market_value(self, current_price): + """更新市值和盈亏""" + self.current_price = current_price + self.market_value = self.position_qty * current_price + total_cost = self.position_qty * self.avg_cost + self.profit = self.market_value - total_cost + self.profit_rate = (self.profit / total_cost * 100) if total_cost > 0 else 0 + return self.market_value + + +class SimulationOrder(db.Model): + """模拟订单""" + __tablename__ = 'simulation_orders' + + id = db.Column(db.Integer, primary_key=True) + account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) + order_no = db.Column(db.String(32), unique=True, nullable=False) + stock_code = db.Column(db.String(20), nullable=False) + stock_name = db.Column(db.String(100)) + order_type = db.Column(db.String(10), nullable=False) # BUY/SELL + price_type = db.Column(db.String(10), default='MARKET') # MARKET/LIMIT + order_price = db.Column(db.Numeric(10, 3)) # 委托价格 + order_qty = db.Column(db.Integer, nullable=False) # 委托数量 + filled_qty = db.Column(db.Integer, default=0) # 成交数量 + filled_price = db.Column(db.Numeric(10, 3)) # 成交价格 + filled_amount = db.Column(db.Numeric(15, 2)) # 成交金额 + commission = db.Column(db.Numeric(10, 2), default=0.00) # 手续费 + stamp_tax = db.Column(db.Numeric(10, 2), default=0.00) # 印花税 + transfer_fee = db.Column(db.Numeric(10, 2), default=0.00) # 过户费 + total_fee = db.Column(db.Numeric(10, 2), default=0.00) # 总费用 + status = db.Column(db.String(20), default='PENDING') # PENDING/PARTIAL/FILLED/CANCELLED/REJECTED + reject_reason = db.Column(db.String(200)) + order_time = db.Column(db.DateTime, default=beijing_now) + filled_time = db.Column(db.DateTime) + cancel_time = db.Column(db.DateTime) + + def calculate_fees(self): + """计算交易费用""" + if not self.filled_amount: + return 0 + + # 佣金(万分之2.5,最低5元) + self.commission = max(float(self.filled_amount) * 0.00025, 5.0) + + # 印花税(卖出时收取千分之1) + if self.order_type == 'SELL': + self.stamp_tax = float(self.filled_amount) * 0.001 + else: + self.stamp_tax = 0 + + # 过户费(双向收取,万分之0.2) + self.transfer_fee = float(self.filled_amount) * 0.00002 + + # 总费用 + self.total_fee = self.commission + self.stamp_tax + self.transfer_fee + + return self.total_fee + + +class SimulationTransaction(db.Model): + """模拟成交记录""" + __tablename__ = 'simulation_transactions' + + id = db.Column(db.Integer, primary_key=True) + account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) + order_id = db.Column(db.Integer, db.ForeignKey('simulation_orders.id'), nullable=False) + transaction_no = db.Column(db.String(32), unique=True, nullable=False) + stock_code = db.Column(db.String(20), nullable=False) + stock_name = db.Column(db.String(100)) + transaction_type = db.Column(db.String(10), nullable=False) # BUY/SELL + transaction_price = db.Column(db.Numeric(10, 3), nullable=False) + transaction_qty = db.Column(db.Integer, nullable=False) + transaction_amount = db.Column(db.Numeric(15, 2), nullable=False) + commission = db.Column(db.Numeric(10, 2), default=0.00) + stamp_tax = db.Column(db.Numeric(10, 2), default=0.00) + transfer_fee = db.Column(db.Numeric(10, 2), default=0.00) + total_fee = db.Column(db.Numeric(10, 2), default=0.00) + transaction_time = db.Column(db.DateTime, default=beijing_now) + settlement_date = db.Column(db.Date) # T+1结算日期 + + # 关系 + order = db.relationship('SimulationOrder', backref='transactions') + + +class SimulationDailyStats(db.Model): + """模拟账户日统计""" + __tablename__ = 'simulation_daily_stats' + + id = db.Column(db.Integer, primary_key=True) + account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) + stat_date = db.Column(db.Date, nullable=False) + opening_assets = db.Column(db.Numeric(15, 2)) # 期初资产 + closing_assets = db.Column(db.Numeric(15, 2)) # 期末资产 + daily_profit = db.Column(db.Numeric(15, 2)) # 日盈亏 + daily_profit_rate = db.Column(db.Numeric(10, 4)) # 日收益率 + total_profit = db.Column(db.Numeric(15, 2)) # 累计盈亏 + total_profit_rate = db.Column(db.Numeric(10, 4)) # 累计收益率 + trade_count = db.Column(db.Integer, default=0) # 交易次数 + win_count = db.Column(db.Integer, default=0) # 盈利次数 + loss_count = db.Column(db.Integer, default=0) # 亏损次数 + max_profit = db.Column(db.Numeric(15, 2)) # 最大盈利 + max_loss = db.Column(db.Numeric(15, 2)) # 最大亏损 + created_at = db.Column(db.DateTime, default=beijing_now) + + __table_args__ = ( + db.UniqueConstraint('account_id', 'stat_date', name='unique_account_date'), + ) + + +def get_user_subscription_safe(user_id): + """安全地获取用户订阅信息""" + try: + subscription = UserSubscription.query.filter_by(user_id=user_id).first() + if not subscription: + subscription = UserSubscription(user_id=user_id) + db.session.add(subscription) + db.session.commit() + return subscription + except Exception as e: + # 返回默认免费版本对象 + class DefaultSub: + def to_dict(self): + return { + 'type': 'free', + 'status': 'active', + 'is_active': True, + 'days_left': 999, + 'billing_cycle': None, + 'auto_renewal': False + } + + return DefaultSub() + + +def activate_user_subscription(user_id, plan_type, billing_cycle, extend_from_now=False): + """激活用户订阅 + + Args: + user_id: 用户ID + plan_type: 套餐类型 + billing_cycle: 计费周期 + extend_from_now: 是否从当前时间开始延长(用于升级场景) + """ + try: + subscription = UserSubscription.query.filter_by(user_id=user_id).first() + if not subscription: + subscription = UserSubscription(user_id=user_id) + db.session.add(subscription) + + subscription.subscription_type = plan_type + subscription.subscription_status = 'active' + subscription.billing_cycle = billing_cycle + + if not extend_from_now or not subscription.start_date: + subscription.start_date = beijing_now() + + if billing_cycle == 'monthly': + subscription.end_date = beijing_now() + timedelta(days=30) + else: # yearly + subscription.end_date = beijing_now() + timedelta(days=365) + + subscription.updated_at = beijing_now() + db.session.commit() + return subscription + except Exception as e: + return None + + +def validate_promo_code(code, plan_name, billing_cycle, amount, user_id): + """验证优惠码 + + Returns: + tuple: (promo_code_obj, error_message) + """ + try: + promo = PromoCode.query.filter_by(code=code.upper(), is_active=True).first() + + if not promo: + return None, "优惠码不存在或已失效" + + # 检查有效期 + now = beijing_now() + if now < promo.valid_from: + return None, "优惠码尚未生效" + if now > promo.valid_until: + return None, "优惠码已过期" + + # 检查使用次数 + if promo.max_uses and promo.current_uses >= promo.max_uses: + return None, "优惠码已被使用完" + + # 检查每用户使用次数 + if promo.max_uses_per_user: + user_usage_count = PromoCodeUsage.query.filter_by( + promo_code_id=promo.id, + user_id=user_id + ).count() + if user_usage_count >= promo.max_uses_per_user: + return None, f"您已使用过此优惠码(限用{promo.max_uses_per_user}次)" + + # 检查适用套餐 + if promo.applicable_plans: + try: + applicable = json.loads(promo.applicable_plans) + if plan_name not in applicable: + return None, "该优惠码不适用于此套餐" + except: + pass + + # 检查适用周期 + if promo.applicable_cycles: + try: + applicable = json.loads(promo.applicable_cycles) + if billing_cycle not in applicable: + return None, "该优惠码不适用于此计费周期" + except: + pass + + # 检查最低消费 + if promo.min_amount and amount < float(promo.min_amount): + return None, f"需满{float(promo.min_amount):.2f}元才可使用此优惠码" + + return promo, None + except Exception as e: + return None, f"验证优惠码时出错: {str(e)}" + + +def calculate_discount(promo_code, amount): + """计算优惠金额""" + try: + if promo_code.discount_type == 'percentage': + discount = amount * (float(promo_code.discount_value) / 100) + else: # fixed_amount + discount = float(promo_code.discount_value) + + # 确保折扣不超过总金额 + return min(discount, amount) + except: + return 0 + + +def calculate_remaining_value(subscription, current_plan): + """计算当前订阅的剩余价值""" + try: + if not subscription or not subscription.end_date: + return 0 + + now = beijing_now() + if subscription.end_date <= now: + return 0 + + days_left = (subscription.end_date - now).days + + if subscription.billing_cycle == 'monthly': + daily_value = float(current_plan.monthly_price) / 30 + else: # yearly + daily_value = float(current_plan.yearly_price) / 365 + + return daily_value * days_left + except: + return 0 + + +def calculate_upgrade_price(user_id, to_plan_name, to_cycle, promo_code=None): + """计算升级所需价格 + + Returns: + dict: 包含价格计算结果的字典 + """ + try: + # 1. 获取当前订阅 + current_sub = UserSubscription.query.filter_by(user_id=user_id).first() + + # 2. 获取目标套餐 + to_plan = SubscriptionPlan.query.filter_by(name=to_plan_name, is_active=True).first() + if not to_plan: + return {'error': '目标套餐不存在'} + + # 3. 计算目标套餐价格 + new_price = float(to_plan.yearly_price if to_cycle == 'yearly' else to_plan.monthly_price) + + # 4. 如果是新订阅(非升级) + if not current_sub or current_sub.subscription_type == 'free': + result = { + 'is_upgrade': False, + 'new_plan_price': new_price, + 'remaining_value': 0, + 'upgrade_amount': new_price, + 'original_amount': new_price, + 'discount_amount': 0, + 'final_amount': new_price, + 'promo_code': None + } + + # 应用优惠码 + if promo_code: + promo, error = validate_promo_code(promo_code, to_plan_name, to_cycle, new_price, user_id) + if promo: + discount = calculate_discount(promo, new_price) + result['discount_amount'] = discount + result['final_amount'] = new_price - discount + result['promo_code'] = promo.code + elif error: + result['promo_error'] = error + + return result + + # 5. 升级场景:计算剩余价值 + current_plan = SubscriptionPlan.query.filter_by(name=current_sub.subscription_type, is_active=True).first() + if not current_plan: + return {'error': '当前套餐信息不存在'} + + remaining_value = calculate_remaining_value(current_sub, current_plan) + + # 6. 计算升级差价 + upgrade_amount = max(0, new_price - remaining_value) + + # 7. 判断升级类型 + upgrade_type = 'new' + if current_sub.subscription_type != to_plan_name and current_sub.billing_cycle != to_cycle: + upgrade_type = 'both' + elif current_sub.subscription_type != to_plan_name: + upgrade_type = 'plan_upgrade' + elif current_sub.billing_cycle != to_cycle: + upgrade_type = 'cycle_change' + + result = { + 'is_upgrade': True, + 'upgrade_type': upgrade_type, + 'current_plan': current_sub.subscription_type, + 'current_cycle': current_sub.billing_cycle, + 'current_end_date': current_sub.end_date.isoformat() if current_sub.end_date else None, + 'new_plan_price': new_price, + 'remaining_value': remaining_value, + 'upgrade_amount': upgrade_amount, + 'original_amount': upgrade_amount, + 'discount_amount': 0, + 'final_amount': upgrade_amount, + 'promo_code': None + } + + # 8. 应用优惠码 + if promo_code and upgrade_amount > 0: + promo, error = validate_promo_code(promo_code, to_plan_name, to_cycle, upgrade_amount, user_id) + if promo: + discount = calculate_discount(promo, upgrade_amount) + result['discount_amount'] = discount + result['final_amount'] = upgrade_amount - discount + result['promo_code'] = promo.code + elif error: + result['promo_error'] = error + + return result + except Exception as e: + return {'error': str(e)} + + +def initialize_subscription_plans_safe(): + """安全地初始化订阅套餐""" + try: + if SubscriptionPlan.query.first(): + return + + pro_plan = SubscriptionPlan( + name='pro', + display_name='Pro版本', + description='适合个人投资者的基础功能套餐', + monthly_price=0.01, + yearly_price=0.08, + features=json.dumps([ + "基础股票分析工具", + "历史数据查询", + "基础财务报表", + "简单投资计划记录", + "标准客服支持" + ]), + sort_order=1 + ) + + max_plan = SubscriptionPlan( + name='max', + display_name='Max版本', + description='适合专业投资者的全功能套餐', + monthly_price=0.1, + yearly_price=0.8, + features=json.dumps([ + "全部Pro版本功能", + "高级分析工具", + "实时数据推送", + "专业财务分析报告", + "AI投资建议", + "无限投资计划存储", + "优先客服支持", + "独家研报访问" + ]), + sort_order=2 + ) + + db.session.add(pro_plan) + db.session.add(max_plan) + db.session.commit() + except Exception as e: + pass + + +# -------------------------------------------- +# 订阅等级工具函数 +# -------------------------------------------- +def _get_current_subscription_info(): + """获取当前登录用户订阅信息的字典形式,未登录或异常时视为免费用户。""" + try: + user_id = session.get('user_id') + if not user_id: + return { + 'type': 'free', + 'status': 'active', + 'is_active': True + } + sub = get_user_subscription_safe(user_id) + data = sub.to_dict() + # 标准化字段名 + return { + 'type': data.get('type') or data.get('subscription_type') or 'free', + 'status': data.get('status') or data.get('subscription_status') or 'active', + 'is_active': data.get('is_active', True) + } + except Exception: + return { + 'type': 'free', + 'status': 'active', + 'is_active': True + } + + +def _subscription_level(sub_type): + """将订阅类型映射到等级数值,free=0, pro=1, max=2。""" + mapping = {'free': 0, 'pro': 1, 'max': 2} + return mapping.get((sub_type or 'free').lower(), 0) + + +def _has_required_level(required: str) -> bool: + """判断当前用户是否达到所需订阅级别。""" + info = _get_current_subscription_info() + if not info.get('is_active', True): + return False + return _subscription_level(info.get('type')) >= _subscription_level(required) + + +# ============================================ +# 订阅相关API接口 +# ============================================ + +@app.route('/api/subscription/plans', methods=['GET']) +def get_subscription_plans(): + """获取订阅套餐列表""" + try: + plans = SubscriptionPlan.query.filter_by(is_active=True).order_by(SubscriptionPlan.sort_order).all() + return jsonify({ + 'success': True, + 'data': [plan.to_dict() for plan in plans] + }) + except Exception as e: + # 返回默认套餐(包含pricing_options以兼容新前端) + default_plans = [ + { + 'id': 1, + 'name': 'pro', + 'display_name': 'Pro版本', + 'description': '适合个人投资者的基础功能套餐', + 'monthly_price': 198, + 'yearly_price': 2000, + 'pricing_options': [ + {'months': 1, 'price': 198, 'label': '月付', 'cycle_key': 'monthly'}, + {'months': 3, 'price': 534, 'label': '3个月', 'cycle_key': '3months', 'discount_percent': 10}, + {'months': 6, 'price': 950, 'label': '半年', 'cycle_key': '6months', 'discount_percent': 20}, + {'months': 12, 'price': 2000, 'label': '1年', 'cycle_key': 'yearly', 'discount_percent': 16}, + {'months': 24, 'price': 3600, 'label': '2年', 'cycle_key': '2years', 'discount_percent': 24}, + {'months': 36, 'price': 5040, 'label': '3年', 'cycle_key': '3years', 'discount_percent': 29} + ], + 'features': ['基础股票分析工具', '历史数据查询', '基础财务报表', '简单投资计划记录', '标准客服支持'], + 'is_active': True, + 'sort_order': 1 + }, + { + 'id': 2, + 'name': 'max', + 'display_name': 'Max版本', + 'description': '适合专业投资者的全功能套餐', + 'monthly_price': 998, + 'yearly_price': 10000, + 'pricing_options': [ + {'months': 1, 'price': 998, 'label': '月付', 'cycle_key': 'monthly'}, + {'months': 3, 'price': 2695, 'label': '3个月', 'cycle_key': '3months', 'discount_percent': 10}, + {'months': 6, 'price': 4790, 'label': '半年', 'cycle_key': '6months', 'discount_percent': 20}, + {'months': 12, 'price': 10000, 'label': '1年', 'cycle_key': 'yearly', 'discount_percent': 17}, + {'months': 24, 'price': 18000, 'label': '2年', 'cycle_key': '2years', 'discount_percent': 25}, + {'months': 36, 'price': 25200, 'label': '3年', 'cycle_key': '3years', 'discount_percent': 30} + ], + 'features': ['全部Pro版本功能', '高级分析工具', '实时数据推送', 'API访问', '优先客服支持'], + 'is_active': True, + 'sort_order': 2 + } + ] + return jsonify({ + 'success': True, + 'data': default_plans + }) + + +@app.route('/api/subscription/current', methods=['GET']) +def get_current_subscription(): + """获取当前用户的订阅信息""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + subscription = get_user_subscription_safe(session['user_id']) + return jsonify({ + 'success': True, + 'data': subscription.to_dict() + }) + except Exception as e: + return jsonify({ + 'success': True, + 'data': { + 'type': 'free', + 'status': 'active', + 'is_active': True, + 'days_left': 999 + } + }) + + +@app.route('/api/subscription/info', methods=['GET']) +def get_subscription_info(): + """获取当前用户的订阅信息 - 前端专用接口""" + try: + info = _get_current_subscription_info() + return jsonify({ + 'success': True, + 'data': info + }) + except Exception as e: + print(f"获取订阅信息错误: {e}") + return jsonify({ + 'success': True, + 'data': { + 'type': 'free', + 'status': 'active', + 'is_active': True, + 'days_left': 999 + } + }) + + +@app.route('/api/promo-code/validate', methods=['POST']) +def validate_promo_code_api(): + """验证优惠码""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + data = request.get_json() + code = data.get('code', '').strip() + plan_name = data.get('plan_name') + billing_cycle = data.get('billing_cycle') + amount = data.get('amount', 0) + + if not code or not plan_name or not billing_cycle: + return jsonify({'success': False, 'error': '参数不完整'}), 400 + + # 验证优惠码 + promo, error = validate_promo_code(code, plan_name, billing_cycle, amount, session['user_id']) + + if error: + return jsonify({ + 'success': False, + 'valid': False, + 'error': error + }) + + # 计算折扣 + discount_amount = calculate_discount(promo, amount) + final_amount = amount - discount_amount + + return jsonify({ + 'success': True, + 'valid': True, + 'promo_code': promo.to_dict(), + 'discount_amount': discount_amount, + 'final_amount': final_amount + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': f'验证失败: {str(e)}' + }), 500 + + +@app.route('/api/subscription/calculate-price', methods=['POST']) +def calculate_subscription_price(): + """计算订阅价格(支持升级和优惠码)""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + data = request.get_json() + to_plan = data.get('to_plan') + to_cycle = data.get('to_cycle') + promo_code = data.get('promo_code', '').strip() or None + + if not to_plan or not to_cycle: + return jsonify({'success': False, 'error': '参数不完整'}), 400 + + # 计算价格 + result = calculate_upgrade_price(session['user_id'], to_plan, to_cycle, promo_code) + + if 'error' in result: + return jsonify({ + 'success': False, + 'error': result['error'] + }), 400 + + return jsonify({ + 'success': True, + 'data': result + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': f'计算失败: {str(e)}' + }), 500 + + +@app.route('/api/payment/create-order', methods=['POST']) +def create_payment_order(): + """创建支付订单(支持升级和优惠码)""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + data = request.get_json() + plan_name = data.get('plan_name') + billing_cycle = data.get('billing_cycle') + promo_code = data.get('promo_code', '').strip() or None + + if not plan_name or not billing_cycle: + return jsonify({'success': False, 'error': '参数不完整'}), 400 + + # 计算价格(包括升级和优惠码) + price_result = calculate_upgrade_price(session['user_id'], plan_name, billing_cycle, promo_code) + + if 'error' in price_result: + return jsonify({'success': False, 'error': price_result['error']}), 400 + + amount = price_result['final_amount'] + original_amount = price_result['original_amount'] + discount_amount = price_result['discount_amount'] + is_upgrade = price_result.get('is_upgrade', False) + + # 创建订单 + try: + order = PaymentOrder( + user_id=session['user_id'], + plan_name=plan_name, + billing_cycle=billing_cycle, + amount=amount + ) + + # 添加扩展字段(使用动态属性) + if hasattr(order, 'original_amount') or True: # 兼容性检查 + order.original_amount = original_amount + order.discount_amount = discount_amount + order.is_upgrade = is_upgrade + + # 如果使用了优惠码,关联优惠码 + if promo_code and price_result.get('promo_code'): + promo_obj = PromoCode.query.filter_by(code=promo_code.upper()).first() + if promo_obj: + order.promo_code_id = promo_obj.id + + # 如果是升级,记录原套餐信息 + if is_upgrade: + order.upgrade_from_plan = price_result.get('current_plan') + + db.session.add(order) + db.session.commit() + + # 如果是升级订单,创建升级记录 + if is_upgrade and price_result.get('upgrade_type'): + try: + upgrade_record = SubscriptionUpgrade( + user_id=session['user_id'], + order_id=order.id, + from_plan=price_result['current_plan'], + from_cycle=price_result['current_cycle'], + from_end_date=datetime.fromisoformat(price_result['current_end_date']) if price_result.get('current_end_date') else None, + to_plan=plan_name, + to_cycle=billing_cycle, + to_end_date=beijing_now() + timedelta(days=365 if billing_cycle == 'yearly' else 30), + remaining_value=price_result['remaining_value'], + upgrade_amount=price_result['upgrade_amount'], + actual_amount=amount, + upgrade_type=price_result['upgrade_type'] + ) + db.session.add(upgrade_record) + db.session.commit() + except Exception as e: + print(f"创建升级记录失败: {e}") + # 不影响主流程 + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': f'订单创建失败: {str(e)}'}), 500 + + # 尝试调用真实的微信支付API + try: + from wechat_pay import create_wechat_pay_instance, check_wechat_pay_ready + + # 检查微信支付是否就绪 + is_ready, ready_msg = check_wechat_pay_ready() + if not is_ready: + # 使用模拟二维码 + order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" + order.remark = f"演示模式 - {ready_msg}" + else: + wechat_pay = create_wechat_pay_instance() + + # 创建微信支付订单 + plan_display_name = f"{plan_name.upper()}版本-{billing_cycle}" + wechat_result = wechat_pay.create_native_order( + order_no=order.order_no, + total_fee=float(amount), + body=f"VFr-{plan_display_name}", + product_id=f"{plan_name}_{billing_cycle}" + ) + + if wechat_result['success']: + + # 获取微信返回的原始code_url + wechat_code_url = wechat_result['code_url'] + + # 将微信协议URL转换为二维码图片URL + import urllib.parse + encoded_url = urllib.parse.quote(wechat_code_url, safe='') + qr_image_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data={encoded_url}" + + order.qr_code_url = qr_image_url + order.prepay_id = wechat_result.get('prepay_id') + order.remark = f"微信支付 - {wechat_code_url}" + + else: + order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" + order.remark = f"微信支付失败: {wechat_result.get('error')}" + + except ImportError as e: + order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" + order.remark = "微信支付模块未配置" + except Exception as e: + order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" + order.remark = f"支付异常: {str(e)}" + + db.session.commit() + + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '订单创建成功' + }) + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': '创建订单失败'}), 500 + + +@app.route('/api/payment/order//status', methods=['GET']) +def check_order_status(order_id): + """查询订单支付状态""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + # 查找订单 + order = PaymentOrder.query.filter_by( + id=order_id, + user_id=session['user_id'] + ).first() + + if not order: + return jsonify({'success': False, 'error': '订单不存在'}), 404 + + # 如果订单已经是已支付状态,直接返回 + if order.status == 'paid': + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '订单已支付', + 'payment_success': True + }) + + # 如果订单过期,标记为过期 + if order.is_expired(): + order.status = 'expired' + db.session.commit() + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '订单已过期' + }) + + # 调用微信支付API查询真实状态 + try: + from wechat_pay import create_wechat_pay_instance + wechat_pay = create_wechat_pay_instance() + + query_result = wechat_pay.query_order(order_no=order.order_no) + + if query_result['success']: + trade_state = query_result.get('trade_state') + transaction_id = query_result.get('transaction_id') + + if trade_state == 'SUCCESS': + # 支付成功,更新订单状态 + order.mark_as_paid(transaction_id) + + # 激活用户订阅 + activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) + + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '支付成功!订阅已激活', + 'payment_success': True + }) + elif trade_state in ['NOTPAY', 'USERPAYING']: + # 未支付或支付中 + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '等待支付...', + 'payment_success': False + }) + else: + # 支付失败或取消 + order.status = 'cancelled' + db.session.commit() + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '支付已取消', + 'payment_success': False + }) + else: + # 微信查询失败,返回当前状态 + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': f"查询失败: {query_result.get('error')}", + 'payment_success': False + }) + + except Exception as e: + # 查询失败,返回当前订单状态 + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '无法查询支付状态,请稍后重试', + 'payment_success': False + }) + + except Exception as e: + return jsonify({'success': False, 'error': '查询失败'}), 500 + + +@app.route('/api/payment/order//force-update', methods=['POST']) +def force_update_order_status(order_id): + """强制更新订单支付状态(调试用)""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + # 查找订单 + order = PaymentOrder.query.filter_by( + id=order_id, + user_id=session['user_id'] + ).first() + + if not order: + return jsonify({'success': False, 'error': '订单不存在'}), 404 + + # 检查微信支付状态 + try: + from wechat_pay import create_wechat_pay_instance + wechat_pay = create_wechat_pay_instance() + + query_result = wechat_pay.query_order(order_no=order.order_no) + + if query_result['success'] and query_result.get('trade_state') == 'SUCCESS': + # 强制更新为已支付 + old_status = order.status + order.mark_as_paid(query_result.get('transaction_id')) + + # 激活用户订阅 + activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) + + # 记录优惠码使用(如果使用了优惠码) + if hasattr(order, 'promo_code_id') and order.promo_code_id: + try: + promo_usage = PromoCodeUsage( + promo_code_id=order.promo_code_id, + user_id=order.user_id, + order_id=order.id, + original_amount=order.original_amount if hasattr(order, 'original_amount') else order.amount, + discount_amount=order.discount_amount if hasattr(order, 'discount_amount') else 0, + final_amount=order.amount + ) + db.session.add(promo_usage) + + # 更新优惠码使用次数 + promo = PromoCode.query.get(order.promo_code_id) + if promo: + promo.current_uses = (promo.current_uses or 0) + 1 + except Exception as e: + print(f"记录优惠码使用失败: {e}") + + db.session.commit() + + print(f"✅ 订单状态强制更新成功: {old_status} -> paid") + + return jsonify({ + 'success': True, + 'message': f'订单状态已从 {old_status} 更新为 paid', + 'data': order.to_dict(), + 'payment_success': True + }) + else: + return jsonify({ + 'success': False, + 'error': '微信支付状态不是成功状态,无法强制更新' + }) + + except Exception as e: + print(f"❌ 强制更新失败: {e}") + return jsonify({ + 'success': False, + 'error': f'强制更新失败: {str(e)}' + }) + + except Exception as e: + print(f"强制更新订单状态失败: {str(e)}") + return jsonify({'success': False, 'error': '操作失败'}), 500 + + +@app.route('/api/payment/wechat/callback', methods=['POST']) +def wechat_payment_callback(): + """微信支付回调处理""" + try: + # 获取原始XML数据 + raw_data = request.get_data() + print(f"📥 收到微信支付回调: {raw_data}") + + # 验证回调数据 + try: + from wechat_pay import create_wechat_pay_instance + wechat_pay = create_wechat_pay_instance() + verify_result = wechat_pay.verify_callback(raw_data.decode('utf-8')) + + if not verify_result['success']: + print(f"❌ 微信支付回调验证失败: {verify_result['error']}") + return '' + + callback_data = verify_result['data'] + + except Exception as e: + print(f"❌ 微信支付回调处理异常: {e}") + # 简单解析XML(fallback) + callback_data = _parse_xml_callback(raw_data.decode('utf-8')) + if not callback_data: + return '' + + # 获取关键字段 + return_code = callback_data.get('return_code') + result_code = callback_data.get('result_code') + order_no = callback_data.get('out_trade_no') + transaction_id = callback_data.get('transaction_id') + + print(f"📦 回调数据解析:") + print(f" 返回码: {return_code}") + print(f" 结果码: {result_code}") + print(f" 订单号: {order_no}") + print(f" 交易号: {transaction_id}") + + if not order_no: + return '' + + # 查找订单 + order = PaymentOrder.query.filter_by(order_no=order_no).first() + if not order: + print(f"❌ 订单不存在: {order_no}") + return '' + + # 处理支付成功 + if return_code == 'SUCCESS' and result_code == 'SUCCESS': + print(f"🎉 支付回调成功: 订单 {order_no}") + + # 检查订单是否已经处理过 + if order.status == 'paid': + print(f"ℹ️ 订单已处理过: {order_no}") + db.session.commit() + return '' + + # 更新订单状态(无论之前是什么状态) + old_status = order.status + order.mark_as_paid(transaction_id) + print(f"📝 订单状态已更新: {old_status} -> paid") + + # 激活用户订阅 + subscription = activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) + + if subscription: + print(f"✅ 用户订阅已激活: 用户{order.user_id}, 套餐{order.plan_name}") + else: + print(f"⚠️ 订阅激活失败,但订单已标记为已支付") + + db.session.commit() + + # 返回成功响应给微信 + return '' + + except Exception as e: + db.session.rollback() + print(f"❌ 微信支付回调处理失败: {e}") + import traceback + app.logger.error(f"回调处理错误: {e}", exc_info=True) + return '' + + +def _parse_xml_callback(xml_data): + """简单的XML回调数据解析""" + try: + import xml.etree.ElementTree as ET + root = ET.fromstring(xml_data) + result = {} + for child in root: + result[child.tag] = child.text + return result + except Exception as e: + print(f"XML解析失败: {e}") + return None + + +@app.route('/api/auth/session', methods=['GET']) +def get_session_info(): + """获取当前登录用户信息""" + if 'user_id' in session: + user = User.query.get(session['user_id']) + if user: + # 获取用户订阅信息 + subscription_info = get_user_subscription_safe(user.id).to_dict() + + return jsonify({ + 'success': True, + 'isAuthenticated': True, + 'user': { + 'id': user.id, + 'username': user.username, + 'nickname': user.nickname or user.username, + 'email': user.email, + 'phone': user.phone, + 'phone_confirmed': bool(user.phone_confirmed), + 'email_confirmed': bool(user.email_confirmed) if hasattr(user, 'email_confirmed') else None, + 'avatar_url': user.avatar_url, + 'has_wechat': bool(user.wechat_open_id), + 'created_at': user.created_at.isoformat() if user.created_at else None, + 'last_seen': user.last_seen.isoformat() if user.last_seen else None, + # 将订阅字段映射到前端期望的字段名 + 'subscription_type': subscription_info['type'], + 'subscription_status': subscription_info['status'], + 'subscription_end_date': subscription_info['end_date'], + 'is_subscription_active': subscription_info['is_active'], + 'subscription_days_left': subscription_info['days_left'] + } + }) + + return jsonify({ + 'success': True, + 'isAuthenticated': False, + 'user': None + }) + + +def generate_verification_code(): + """生成6位数字验证码""" + return ''.join(random.choices(string.digits, k=6)) + + +@app.route('/api/auth/login', methods=['POST']) +def login(): + """传统登录 - 使用Session""" + try: + + username = request.form.get('username') + email = request.form.get('email') + phone = request.form.get('phone') + password = request.form.get('password') + + # 验证必要参数 + if not password: + return jsonify({'success': False, 'error': '密码不能为空'}), 400 + + # 根据提供的信息查找用户 + user = None + if username: + # 检查username是否为手机号格式 + if re.match(r'^1[3-9]\d{9}$', username): + # 如果username是手机号格式,先按手机号查找 + user = User.query.filter_by(phone=username).first() + if not user: + # 如果没找到,再按用户名查找 + user = User.find_by_login_info(username) + else: + # 不是手机号格式,按用户名查找 + user = User.find_by_login_info(username) + elif email: + user = User.query.filter_by(email=email).first() + elif phone: + user = User.query.filter_by(phone=phone).first() + else: + return jsonify({'success': False, 'error': '请提供用户名、邮箱或手机号'}), 400 + + if not user: + return jsonify({'success': False, 'error': '用户不存在'}), 404 + + # 尝试密码验证 + password_valid = user.check_password(password) + + if not password_valid: + # 还可以尝试直接验证 + if user.password_hash: + from werkzeug.security import check_password_hash + direct_check = check_password_hash(user.password_hash, password) + return jsonify({'success': False, 'error': '密码错误'}), 401 + + # 设置session + session.permanent = True # 使用永久session + session['user_id'] = user.id + session['username'] = user.username + session['logged_in'] = True + + # Flask-Login 登录 + login_user(user, remember=True) + + # 更新最后登录时间 + user.update_last_seen() + + return jsonify({ + 'success': True, + 'message': '登录成功', + 'user': { + 'id': user.id, + 'username': user.username, + 'nickname': user.nickname or user.username, + 'email': user.email, + 'phone': user.phone, + 'avatar_url': user.avatar_url, + 'has_wechat': bool(user.wechat_open_id) + } + }) + + except Exception as e: + import traceback + app.logger.error(f"回调处理错误: {e}", exc_info=True) + return jsonify({'success': False, 'error': '登录处理失败,请重试'}), 500 + + +# 添加OPTIONS请求处理 +@app.before_request +def handle_preflight(): + if request.method == "OPTIONS": + response = make_response() + response.headers.add("Access-Control-Allow-Origin", "*") + response.headers.add('Access-Control-Allow-Headers', "*") + response.headers.add('Access-Control-Allow-Methods', "*") + return response + + +# 修改密码API +@app.route('/api/account/change-password', methods=['POST']) +@login_required +def change_password(): + """修改当前用户密码""" + try: + data = request.get_json() or request.form + current_password = data.get('currentPassword') or data.get('current_password') + new_password = data.get('newPassword') or data.get('new_password') + is_first_set = data.get('isFirstSet', False) # 是否为首次设置密码 + + if not new_password: + return jsonify({'success': False, 'error': '新密码不能为空'}), 400 + + if len(new_password) < 6: + return jsonify({'success': False, 'error': '新密码至少需要6个字符'}), 400 + + # 获取当前用户 + user = current_user + if not user: + return jsonify({'success': False, 'error': '用户未登录'}), 401 + + # 检查是否为微信用户且首次设置密码 + is_wechat_user = bool(user.wechat_open_id) + + # 如果是微信用户首次设置密码,或者明确标记为首次设置,则跳过当前密码验证 + if is_first_set or (is_wechat_user and not current_password): + pass # 跳过当前密码验证 + else: + # 普通用户或非首次设置,需要验证当前密码 + if not current_password: + return jsonify({'success': False, 'error': '请输入当前密码'}), 400 + + if not user.check_password(current_password): + return jsonify({'success': False, 'error': '当前密码错误'}), 400 + + # 设置新密码 + user.set_password(new_password) + db.session.commit() + + return jsonify({ + 'success': True, + 'message': '密码设置成功' if (is_first_set or is_wechat_user) else '密码修改成功' + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +# 检查用户密码状态API +@app.route('/api/account/password-status', methods=['GET']) +@login_required +def get_password_status(): + """获取当前用户的密码状态信息""" + try: + user = current_user + if not user: + return jsonify({'success': False, 'error': '用户未登录'}), 401 + + is_wechat_user = bool(user.wechat_open_id) + + return jsonify({ + 'success': True, + 'data': { + 'isWechatUser': is_wechat_user, + 'hasPassword': bool(user.password_hash), + 'needsFirstTimeSetup': is_wechat_user # 微信用户需要首次设置 + } + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +# 检查用户信息完整性API +@app.route('/api/account/profile-completeness', methods=['GET']) +@login_required +def get_profile_completeness(): + try: + user = current_user + if not user: + return jsonify({'success': False, 'error': '用户未登录'}), 401 + + is_wechat_user = bool(user.wechat_open_id) + + # 检查各项信息 + completeness = { + 'hasPassword': bool(user.password_hash), + 'hasPhone': bool(user.phone), + 'hasEmail': bool(user.email and '@' in user.email and not user.email.endswith('@valuefrontier.temp')), + 'isWechatUser': is_wechat_user + } + + # 计算完整度 + total_items = 3 + completed_items = sum([completeness['hasPassword'], completeness['hasPhone'], completeness['hasEmail']]) + completeness_percentage = int((completed_items / total_items) * 100) + + # 智能判断是否需要提醒 + needs_attention = False + missing_items = [] + + # 只在用户首次登录或最近登录时提醒 + if is_wechat_user: + # 检查用户是否是新用户(注册7天内) + is_new_user = (datetime.now() - user.created_at).days < 7 + + # 检查是否最近没有提醒过(使用session记录) + last_reminder = session.get('last_completeness_reminder') + should_remind = False + + if not last_reminder: + should_remind = True + else: + # 每7天最多提醒一次 + days_since_reminder = (datetime.now() - datetime.fromisoformat(last_reminder)).days + should_remind = days_since_reminder >= 7 + + # 只对新用户或长时间未完善的用户提醒 + if (is_new_user or completeness_percentage < 50) and should_remind: + needs_attention = True + if not completeness['hasPassword']: + missing_items.append('登录密码') + if not completeness['hasPhone']: + missing_items.append('手机号') + if not completeness['hasEmail']: + missing_items.append('邮箱') + + # 记录本次提醒时间 + session['last_completeness_reminder'] = datetime.now().isoformat() + + return jsonify({ + 'success': True, + 'data': { + 'completeness': completeness, + 'completenessPercentage': completeness_percentage, + 'needsAttention': needs_attention, + 'missingItems': missing_items, + 'isComplete': completed_items == total_items, + 'showReminder': needs_attention # 前端使用这个字段决定是否显示提醒 + } + }) + + except Exception as e: + print(f"获取资料完整性错误: {e}") + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/auth/logout', methods=['POST']) +def logout(): + """登出 - 清除Session""" + logout_user() # Flask-Login 登出 + session.clear() + return jsonify({'success': True, 'message': '已登出'}) + + +@app.route('/api/auth/send-verification-code', methods=['POST']) +def send_verification_code(): + """发送验证码(支持手机号和邮箱)""" + try: + data = request.get_json() + credential = data.get('credential') # 手机号或邮箱 + code_type = data.get('type') # 'phone' 或 'email' + purpose = data.get('purpose', 'login') # 'login' 或 'register' + + if not credential or not code_type: + return jsonify({'success': False, 'error': '缺少必要参数'}), 400 + + # 清理格式字符(空格、横线、括号等) + if code_type == 'phone': + # 移除手机号中的空格、横线、括号、加号等格式字符 + credential = re.sub(r'[\s\-\(\)\+]', '', credential) + print(f"📱 清理后的手机号: {credential}") + elif code_type == 'email': + # 邮箱只移除空格 + credential = credential.strip() + + # 生成验证码 + verification_code = generate_verification_code() + + # 存储验证码到session(实际生产环境建议使用Redis) + session_key = f'verification_code_{code_type}_{credential}_{purpose}' + session[session_key] = { + 'code': verification_code, + 'timestamp': time.time(), + 'attempts': 0 + } + + if code_type == 'phone': + # 手机号验证码发送 + if not re.match(r'^1[3-9]\d{9}$', credential): + return jsonify({'success': False, 'error': '手机号格式不正确'}), 400 + + # 发送真实短信验证码 + if send_sms_code(credential, verification_code, SMS_TEMPLATE_LOGIN): + print(f"[短信已发送] 验证码到 {credential}: {verification_code}") + else: + return jsonify({'success': False, 'error': '短信发送失败,请稍后重试'}), 500 + + elif code_type == 'email': + # 邮箱验证码发送 + if not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', credential): + return jsonify({'success': False, 'error': '邮箱格式不正确'}), 400 + + # 发送真实邮件验证码 + if send_email_code(credential, verification_code): + print(f"[邮件已发送] 验证码到 {credential}: {verification_code}") + else: + return jsonify({'success': False, 'error': '邮件发送失败,请稍后重试'}), 500 + + else: + return jsonify({'success': False, 'error': '不支持的验证码类型'}), 400 + + return jsonify({ + 'success': True, + 'message': f'验证码已发送到您的{code_type}' + }) + + except Exception as e: + print(f"发送验证码错误: {e}") + return jsonify({'success': False, 'error': '发送验证码失败'}), 500 + + +@app.route('/api/auth/login-with-code', methods=['POST']) +def login_with_verification_code(): + """使用验证码登录/注册(自动注册)""" + try: + data = request.get_json() + credential = data.get('credential') # 手机号或邮箱 + verification_code = data.get('verification_code') + login_type = data.get('login_type') # 'phone' 或 'email' + + if not credential or not verification_code or not login_type: + return jsonify({'success': False, 'error': '缺少必要参数'}), 400 + + # 清理格式字符(空格、横线、括号等) + if login_type == 'phone': + # 移除手机号中的空格、横线、括号、加号等格式字符 + original_credential = credential + credential = re.sub(r'[\s\-\(\)\+]', '', credential) + if original_credential != credential: + print(f"📱 登录时清理手机号: {original_credential} -> {credential}") + elif login_type == 'email': + # 邮箱只移除前后空格 + credential = credential.strip() + + # 检查验证码 + session_key = f'verification_code_{login_type}_{credential}_login' + stored_code_info = session.get(session_key) + + if not stored_code_info: + return jsonify({'success': False, 'error': '验证码已过期或不存在'}), 400 + + # 检查验证码是否过期(5分钟) + if time.time() - stored_code_info['timestamp'] > 300: + session.pop(session_key, None) + return jsonify({'success': False, 'error': '验证码已过期'}), 400 + + # 检查尝试次数 + if stored_code_info['attempts'] >= 3: + session.pop(session_key, None) + return jsonify({'success': False, 'error': '验证码错误次数过多'}), 400 + + # 验证码错误 + if stored_code_info['code'] != verification_code: + stored_code_info['attempts'] += 1 + session[session_key] = stored_code_info + return jsonify({'success': False, 'error': '验证码错误'}), 400 + + # 验证码正确,查找用户 + user = None + is_new_user = False + + if login_type == 'phone': + user = User.query.filter_by(phone=credential).first() + if not user: + # 自动注册新用户 + is_new_user = True + # 生成唯一用户名 + base_username = f"user_{credential}" + username = base_username + counter = 1 + while User.query.filter_by(username=username).first(): + username = f"{base_username}_{counter}" + counter += 1 + + # 创建新用户 + user = User(username=username, phone=credential) + user.phone_confirmed = True + user.email = f"{username}@valuefrontier.temp" # 临时邮箱 + db.session.add(user) + db.session.commit() + + elif login_type == 'email': + user = User.query.filter_by(email=credential).first() + if not user: + # 自动注册新用户 + is_new_user = True + # 从邮箱生成用户名 + email_prefix = credential.split('@')[0] + base_username = f"user_{email_prefix}" + username = base_username + counter = 1 + while User.query.filter_by(username=username).first(): + username = f"{base_username}_{counter}" + counter += 1 + + # 如果用户不存在,自动创建新用户 + if not user: + try: + # 生成用户名 + if login_type == 'phone': + # 使用手机号生成用户名 + base_username = f"用户{credential[-4:]}" + elif login_type == 'email': + # 使用邮箱前缀生成用户名 + base_username = credential.split('@')[0] + else: + base_username = "新用户" + + # 确保用户名唯一 + username = base_username + counter = 1 + while User.is_username_taken(username): + username = f"{base_username}_{counter}" + counter += 1 + + # 创建新用户 + user = User(username=username) + + # 设置手机号或邮箱 + if login_type == 'phone': + user.phone = credential + elif login_type == 'email': + user.email = credential + + # 设置默认密码(使用随机密码,用户后续可以修改) + user.set_password(uuid.uuid4().hex) + user.status = 'active' + user.nickname = username + + db.session.add(user) + db.session.commit() + + is_new_user = True + print(f"✅ 自动创建新用户: {username}, {login_type}: {credential}") + + except Exception as e: + print(f"❌ 创建用户失败: {e}") + db.session.rollback() + return jsonify({'success': False, 'error': '创建用户失败'}), 500 + + # 清除验证码 + session.pop(session_key, None) + + # 设置session + session.permanent = True + session['user_id'] = user.id + session['username'] = user.username + session['logged_in'] = True + + # Flask-Login 登录 + login_user(user, remember=True) + + # 更新最后登录时间 + user.update_last_seen() + + # 根据是否为新用户返回不同的消息 + message = '注册成功,欢迎加入!' if is_new_user else '登录成功' + + return jsonify({ + 'success': True, + 'message': message, + 'is_new_user': is_new_user, + 'user': { + 'id': user.id, + 'username': user.username, + 'nickname': user.nickname or user.username, + 'email': user.email, + 'phone': user.phone, + 'avatar_url': user.avatar_url, + 'has_wechat': bool(user.wechat_open_id) + } + }) + + except Exception as e: + print(f"验证码登录错误: {e}") + db.session.rollback() + return jsonify({'success': False, 'error': '登录失败'}), 500 + + +@app.route('/api/auth/register', methods=['POST']) +def register(): + """用户注册 - 使用Session""" + username = request.form.get('username') + email = request.form.get('email') + password = request.form.get('password') + + # 验证输入 + if not all([username, email, password]): + return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 + + # 检查用户名和邮箱是否已存在 + if User.is_username_taken(username): + return jsonify({'success': False, 'error': '用户名已存在'}), 400 + + if User.is_email_taken(email): + return jsonify({'success': False, 'error': '邮箱已被使用'}), 400 + + try: + # 创建新用户 + user = User(username=username, email=email) + user.set_password(password) + user.email_confirmed = True # 暂时默认已确认 + + db.session.add(user) + db.session.commit() + + # 自动登录 + session.permanent = True + session['user_id'] = user.id + session['username'] = user.username + session['logged_in'] = True + + # Flask-Login 登录 + login_user(user, remember=True) + + return jsonify({ + 'success': True, + 'message': '注册成功', + 'user': { + 'id': user.id, + 'username': user.username, + 'nickname': user.nickname or user.username, + 'email': user.email + } + }), 201 + + except Exception as e: + db.session.rollback() + print(f"验证码登录/注册错误: {e}") + return jsonify({'success': False, 'error': '登录失败'}), 500 + + +def send_sms_code(phone, code, template_id): + """发送短信验证码""" + try: + cred = credential.Credential(SMS_SECRET_ID, SMS_SECRET_KEY) + httpProfile = HttpProfile() + httpProfile.endpoint = "sms.tencentcloudapi.com" + + clientProfile = ClientProfile() + clientProfile.httpProfile = httpProfile + client = sms_client.SmsClient(cred, "ap-beijing", clientProfile) + + req = models.SendSmsRequest() + params = { + "PhoneNumberSet": [phone], + "SmsSdkAppId": SMS_SDK_APP_ID, + "TemplateId": template_id, + "SignName": SMS_SIGN_NAME, + "TemplateParamSet": [code, "5"] if template_id == SMS_TEMPLATE_LOGIN else [code] + } + req.from_json_string(json.dumps(params)) + + resp = client.SendSms(req) + return True + except TencentCloudSDKException as err: + print(f"SMS Error: {err}") + return False + + +def send_email_code(email, code): + """发送邮件验证码""" + try: + print(f"[邮件发送] 准备发送验证码到: {email}") + print(f"[邮件配置] 服务器: {MAIL_SERVER}, 端口: {MAIL_PORT}, SSL: {MAIL_USE_SSL}") + + msg = Message( + subject='价值前沿 - 验证码', + recipients=[email], + body=f'您的验证码是:{code},有效期5分钟。如非本人操作,请忽略此邮件。' + ) + mail.send(msg) + print(f"[邮件发送] 验证码邮件发送成功到: {email}") + return True + except Exception as e: + print(f"[邮件发送错误] 发送到 {email} 失败: {str(e)}") + print(f"[邮件发送错误] 错误类型: {type(e).__name__}") + return False + + +@app.route('/api/auth/send-sms-code', methods=['POST']) +def send_sms_verification(): + """发送手机验证码""" + data = request.get_json() + phone = data.get('phone') + + if not phone: + return jsonify({'error': '手机号不能为空'}), 400 + + # 注册时验证是否已注册;若用于绑定手机,需要另外接口 + # 这里保留原逻辑,新增绑定接口处理不同规则 + if User.query.filter_by(phone=phone).first(): + return jsonify({'error': '该手机号已注册'}), 400 + + # 生成验证码 + code = generate_verification_code() + + # 发送短信 + if send_sms_code(phone, code, SMS_TEMPLATE_REGISTER): + # 存储验证码(5分钟有效) + verification_codes[f'phone_{phone}'] = { + 'code': code, + 'expires': time.time() + 300 + } + return jsonify({'message': '验证码已发送'}), 200 + else: + return jsonify({'error': '验证码发送失败'}), 500 + + +@app.route('/api/auth/send-email-code', methods=['POST']) +def send_email_verification(): + """发送邮箱验证码""" + data = request.get_json() + email = data.get('email') + + if not email: + return jsonify({'error': '邮箱不能为空'}), 400 + + if User.query.filter_by(email=email).first(): + return jsonify({'error': '该邮箱已注册'}), 400 + + # 生成验证码 + code = generate_verification_code() + + # 发送邮件 + if send_email_code(email, code): + # 存储验证码(5分钟有效) + verification_codes[f'email_{email}'] = { + 'code': code, + 'expires': time.time() + 300 + } + return jsonify({'message': '验证码已发送'}), 200 + else: + return jsonify({'error': '验证码发送失败'}), 500 + + +@app.route('/api/auth/register/phone', methods=['POST']) +def register_with_phone(): + """手机号注册 - 使用Session""" + data = request.get_json() + phone = data.get('phone') + code = data.get('code') + password = data.get('password') + username = data.get('username') + + if not all([phone, code, password, username]): + return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 + + # 验证验证码 + stored_code = verification_codes.get(f'phone_{phone}') + if not stored_code or stored_code['expires'] < time.time(): + return jsonify({'success': False, 'error': '验证码已过期'}), 400 + + if stored_code['code'] != code: + return jsonify({'success': False, 'error': '验证码错误'}), 400 + + if User.query.filter_by(username=username).first(): + return jsonify({'success': False, 'error': '用户名已存在'}), 400 + + try: + # 创建用户 + user = User(username=username, phone=phone) + user.email = f"{username}@valuefrontier.temp" + user.set_password(password) + user.phone_confirmed = True + + db.session.add(user) + db.session.commit() + + # 清除验证码 + del verification_codes[f'phone_{phone}'] + + # 自动登录 + session.permanent = True + session['user_id'] = user.id + session['username'] = user.username + session['logged_in'] = True + + # Flask-Login 登录 + login_user(user, remember=True) + + return jsonify({ + 'success': True, + 'message': '注册成功', + 'user': { + 'id': user.id, + 'username': user.username, + 'phone': user.phone + } + }), 201 + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': '注册失败,请重试'}), 500 + + +@app.route('/api/account/phone/send-code', methods=['POST']) +def send_sms_bind_code(): + """发送绑定手机验证码(需已登录)""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + data = request.get_json() + phone = data.get('phone') + if not phone: + return jsonify({'error': '手机号不能为空'}), 400 + + # 绑定时要求手机号未被占用 + if User.query.filter_by(phone=phone).first(): + return jsonify({'error': '该手机号已被其他账号使用'}), 400 + + code = generate_verification_code() + if send_sms_code(phone, code, SMS_TEMPLATE_REGISTER): + verification_codes[f'bind_{phone}'] = { + 'code': code, + 'expires': time.time() + 300 + } + return jsonify({'message': '验证码已发送'}), 200 + else: + return jsonify({'error': '验证码发送失败'}), 500 + + +@app.route('/api/account/phone/bind', methods=['POST']) +def bind_phone(): + """当前登录用户绑定手机号""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + data = request.get_json() + phone = data.get('phone') + code = data.get('code') + + if not phone or not code: + return jsonify({'error': '手机号和验证码不能为空'}), 400 + + stored = verification_codes.get(f'bind_{phone}') + if not stored or stored['expires'] < time.time(): + return jsonify({'error': '验证码已过期'}), 400 + if stored['code'] != code: + return jsonify({'error': '验证码错误'}), 400 + + if User.query.filter_by(phone=phone).first(): + return jsonify({'error': '该手机号已被其他账号使用'}), 400 + + try: + user = User.query.get(session.get('user_id')) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + user.phone = phone + user.confirm_phone() + # 清除验证码 + del verification_codes[f'bind_{phone}'] + + return jsonify({'message': '绑定成功', 'success': True}), 200 + except Exception as e: + print(f"Bind phone error: {e}") + db.session.rollback() + return jsonify({'error': '绑定失败,请重试'}), 500 + + +@app.route('/api/account/phone/unbind', methods=['POST']) +def unbind_phone(): + """解绑手机号(需已登录)""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + try: + user = User.query.get(session.get('user_id')) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + user.phone = None + user.phone_confirmed = False + user.phone_confirm_time = None + db.session.commit() + return jsonify({'message': '解绑成功', 'success': True}), 200 + except Exception as e: + print(f"Unbind phone error: {e}") + db.session.rollback() + return jsonify({'error': '解绑失败,请重试'}), 500 + + +@app.route('/api/account/email/send-bind-code', methods=['POST']) +def send_email_bind_code(): + """发送绑定邮箱验证码(需已登录)""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + data = request.get_json() + email = data.get('email') + + if not email: + return jsonify({'error': '邮箱不能为空'}), 400 + + # 邮箱格式验证 + if not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', email): + return jsonify({'error': '邮箱格式不正确'}), 400 + + # 检查邮箱是否已被其他账号使用 + if User.query.filter_by(email=email).first(): + return jsonify({'error': '该邮箱已被其他账号使用'}), 400 + + # 生成验证码 + code = ''.join(random.choices(string.digits, k=6)) + + if send_email_code(email, code): + # 存储验证码(5分钟有效) + verification_codes[f'bind_{email}'] = { + 'code': code, + 'expires': time.time() + 300 + } + return jsonify({'message': '验证码已发送'}), 200 + else: + return jsonify({'error': '验证码发送失败'}), 500 + + +@app.route('/api/account/email/bind', methods=['POST']) +def bind_email(): + """当前登录用户绑定邮箱""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + data = request.get_json() + email = data.get('email') + code = data.get('code') + + if not email or not code: + return jsonify({'error': '邮箱和验证码不能为空'}), 400 + + stored = verification_codes.get(f'bind_{email}') + if not stored or stored['expires'] < time.time(): + return jsonify({'error': '验证码已过期'}), 400 + if stored['code'] != code: + return jsonify({'error': '验证码错误'}), 400 + + if User.query.filter_by(email=email).first(): + return jsonify({'error': '该邮箱已被其他账号使用'}), 400 + + try: + user = User.query.get(session.get('user_id')) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + user.email = email + user.confirm_email() + db.session.commit() + + # 清除验证码 + del verification_codes[f'bind_{email}'] + + return jsonify({ + 'message': '邮箱绑定成功', + 'success': True, + 'user': { + 'email': user.email, + 'email_confirmed': user.email_confirmed + } + }), 200 + except Exception as e: + print(f"Bind email error: {e}") + db.session.rollback() + return jsonify({'error': '绑定失败,请重试'}), 500 + + +@app.route('/api/account/email/unbind', methods=['POST']) +def unbind_email(): + """解绑邮箱(需已登录)""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + try: + user = User.query.get(session.get('user_id')) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + user.email = None + user.email_confirmed = False + db.session.commit() + return jsonify({'message': '解绑成功', 'success': True}), 200 + except Exception as e: + print(f"Unbind email error: {e}") + db.session.rollback() + return jsonify({'error': '解绑失败,请重试'}), 500 + + +@app.route('/api/auth/register/email', methods=['POST']) +def register_with_email(): + """邮箱注册 - 使用Session""" + data = request.get_json() + email = data.get('email') + code = data.get('code') + password = data.get('password') + username = data.get('username') + + if not all([email, code, password, username]): + return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 + + # 验证验证码 + stored_code = verification_codes.get(f'email_{email}') + if not stored_code or stored_code['expires'] < time.time(): + return jsonify({'success': False, 'error': '验证码已过期'}), 400 + + if stored_code['code'] != code: + return jsonify({'success': False, 'error': '验证码错误'}), 400 + + if User.query.filter_by(username=username).first(): + return jsonify({'success': False, 'error': '用户名已存在'}), 400 + + try: + # 创建用户 + user = User(username=username, email=email) + user.set_password(password) + user.email_confirmed = True + + db.session.add(user) + db.session.commit() + + # 清除验证码 + del verification_codes[f'email_{email}'] + + # 自动登录 + session.permanent = True + session['user_id'] = user.id + session['username'] = user.username + session['logged_in'] = True + + # Flask-Login 登录 + login_user(user, remember=True) + + return jsonify({ + 'success': True, + 'message': '注册成功', + 'user': { + 'id': user.id, + 'username': user.username, + 'email': user.email + } + }), 201 + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': '注册失败,请重试'}), 500 + + +def get_wechat_access_token(code): + """通过code获取微信access_token""" + url = "https://api.weixin.qq.com/sns/oauth2/access_token" + params = { + 'appid': WECHAT_APPID, + 'secret': WECHAT_APPSECRET, + 'code': code, + 'grant_type': 'authorization_code' + } + + try: + response = requests.get(url, params=params, timeout=10) + data = response.json() + + if 'errcode' in data: + print(f"WeChat access token error: {data}") + return None + + return data + except Exception as e: + print(f"WeChat access token request error: {e}") + return None + + +def get_wechat_userinfo(access_token, openid): + """获取微信用户信息(包含UnionID)""" + url = "https://api.weixin.qq.com/sns/userinfo" + params = { + 'access_token': access_token, + 'openid': openid, + 'lang': 'zh_CN' + } + + try: + response = requests.get(url, params=params, timeout=10) + response.encoding = 'utf-8' # 明确设置编码为UTF-8 + data = response.json() + + if 'errcode' in data: + print(f"WeChat userinfo error: {data}") + return None + + # 确保nickname字段的编码正确 + if 'nickname' in data and data['nickname']: + # 确保昵称是正确的UTF-8编码 + try: + # 检查是否已经是正确的UTF-8字符串 + data['nickname'] = data['nickname'].encode('utf-8').decode('utf-8') + except (UnicodeEncodeError, UnicodeDecodeError) as e: + print(f"Nickname encoding error: {e}, using default") + data['nickname'] = '微信用户' + + return data + except Exception as e: + print(f"WeChat userinfo request error: {e}") + return None + + +@app.route('/api/auth/wechat/qrcode', methods=['GET']) +def get_wechat_qrcode(): + """返回微信授权URL,前端使用iframe展示""" + # 生成唯一state参数 + state = uuid.uuid4().hex + + # URL编码回调地址 + redirect_uri = urllib.parse.quote_plus(WECHAT_REDIRECT_URI) + + # 构建微信授权URL + wechat_auth_url = ( + f"https://open.weixin.qq.com/connect/qrconnect?" + f"appid={WECHAT_APPID}&redirect_uri={redirect_uri}" + f"&response_type=code&scope=snsapi_login&state={state}" + "#wechat_redirect" + ) + + # 存储session信息 + wechat_qr_sessions[state] = { + 'status': 'waiting', + 'expires': time.time() + 300, # 5分钟过期 + 'user_info': None, + 'wechat_openid': None, + 'wechat_unionid': None + } + + return jsonify({"code":0, + "data": + { + 'auth_url': wechat_auth_url, + 'session_id': state, + 'expires_in': 300 + }}), 200 + + +@app.route('/api/account/wechat/qrcode', methods=['GET']) +def get_wechat_bind_qrcode(): + """发起微信绑定二维码,会话标记为绑定模式""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + # 生成唯一state参数 + state = uuid.uuid4().hex + + # URL编码回调地址 + redirect_uri = urllib.parse.quote_plus(WECHAT_REDIRECT_URI) + + # 构建微信授权URL + wechat_auth_url = ( + f"https://open.weixin.qq.com/connect/qrconnect?" + f"appid={WECHAT_APPID}&redirect_uri={redirect_uri}" + f"&response_type=code&scope=snsapi_login&state={state}" + "#wechat_redirect" + ) + + # 存储session信息,标记为绑定模式并记录目标用户 + wechat_qr_sessions[state] = { + 'status': 'waiting', + 'expires': time.time() + 300, + 'mode': 'bind', + 'bind_user_id': session.get('user_id'), + 'user_info': None, + 'wechat_openid': None, + 'wechat_unionid': None + } + + return jsonify({ + 'auth_url': wechat_auth_url, + 'session_id': state, + 'expires_in': 300 + }), 200 + + +@app.route('/api/auth/wechat/check', methods=['POST']) +def check_wechat_scan(): + """检查微信扫码状态""" + data = request.get_json() + session_id = data.get('session_id') + + if not session_id or session_id not in wechat_qr_sessions: + return jsonify({'status': 'invalid', 'error': '无效的session'}), 400 + + session = wechat_qr_sessions[session_id] + + # 检查是否过期 + if time.time() > session['expires']: + del wechat_qr_sessions[session_id] + return jsonify({'status': 'expired'}), 200 + + return jsonify({ + 'status': session['status'], + 'user_info': session.get('user_info'), + 'expires_in': int(session['expires'] - time.time()) + }), 200 + + +@app.route('/api/account/wechat/check', methods=['POST']) +def check_wechat_bind_scan(): + """检查微信扫码绑定状态""" + data = request.get_json() + session_id = data.get('session_id') + + if not session_id or session_id not in wechat_qr_sessions: + return jsonify({'status': 'invalid', 'error': '无效的session'}), 400 + + sess = wechat_qr_sessions[session_id] + + # 绑定模式限制 + if sess.get('mode') != 'bind': + return jsonify({'status': 'invalid', 'error': '会话模式错误'}), 400 + + # 过期处理 + if time.time() > sess['expires']: + del wechat_qr_sessions[session_id] + return jsonify({'status': 'expired'}), 200 + + return jsonify({ + 'status': sess['status'], + 'user_info': sess.get('user_info'), + 'expires_in': int(sess['expires'] - time.time()) + }), 200 + + +@app.route('/api/auth/wechat/callback', methods=['GET']) +def wechat_callback(): + """微信授权回调处理 - 使用Session""" + code = request.args.get('code') + state = request.args.get('state') + error = request.args.get('error') + + # 错误处理:用户拒绝授权 + if error: + if state in wechat_qr_sessions: + wechat_qr_sessions[state]['status'] = 'auth_denied' + wechat_qr_sessions[state]['error'] = '用户拒绝授权' + print(f"❌ 用户拒绝授权: state={state}") + return redirect('/auth/signin?error=wechat_auth_denied') + + # 参数验证 + if not code or not state: + if state in wechat_qr_sessions: + wechat_qr_sessions[state]['status'] = 'auth_failed' + wechat_qr_sessions[state]['error'] = '授权参数缺失' + return redirect('/auth/signin?error=wechat_auth_failed') + + # 验证state + if state not in wechat_qr_sessions: + return redirect('/auth/signin?error=session_expired') + + session_data = wechat_qr_sessions[state] + + # 检查过期 + if time.time() > session_data['expires']: + del wechat_qr_sessions[state] + return redirect('/auth/signin?error=session_expired') + + try: + # 步骤1: 用户已扫码并授权(微信回调过来说明用户已完成扫码+授权) + session_data['status'] = 'scanned' + print(f"✅ 微信扫码回调: state={state}, code={code[:10]}...") + + # 步骤2: 获取access_token + token_data = get_wechat_access_token(code) + if not token_data: + session_data['status'] = 'auth_failed' + session_data['error'] = '获取访问令牌失败' + print(f"❌ 获取微信access_token失败: state={state}") + return redirect('/auth/signin?error=token_failed') + + # 步骤3: Token获取成功,标记为已授权 + session_data['status'] = 'authorized' + print(f"✅ 微信授权成功: openid={token_data['openid']}") + + # 步骤4: 获取用户信息 + user_info = get_wechat_userinfo(token_data['access_token'], token_data['openid']) + if not user_info: + session_data['status'] = 'auth_failed' + session_data['error'] = '获取用户信息失败' + print(f"❌ 获取微信用户信息失败: openid={token_data['openid']}") + return redirect('/auth/signin?error=userinfo_failed') + + # 查找或创建用户 / 或处理绑定 + openid = token_data['openid'] + unionid = user_info.get('unionid') or token_data.get('unionid') + + # 如果是绑定流程 + session_item = wechat_qr_sessions.get(state) + if session_item and session_item.get('mode') == 'bind': + try: + target_user_id = session.get('user_id') or session_item.get('bind_user_id') + if not target_user_id: + return redirect('/auth/signin?error=bind_no_user') + + target_user = User.query.get(target_user_id) + if not target_user: + return redirect('/auth/signin?error=bind_user_missing') + + # 检查该微信是否已被其他账户绑定 + existing = None + if unionid: + existing = User.query.filter_by(wechat_union_id=unionid).first() + if not existing: + existing = User.query.filter_by(wechat_open_id=openid).first() + + if existing and existing.id != target_user.id: + session_item['status'] = 'bind_conflict' + return redirect('/home?bind=conflict') + + # 执行绑定 + target_user.bind_wechat(openid, unionid, wechat_info=user_info) + + # 标记绑定完成,供前端轮询 + session_item['status'] = 'bind_ready' + session_item['user_info'] = {'user_id': target_user.id} + + return redirect('/home?bind=success') + except Exception as e: + print(f"❌ 微信绑定失败: {e}") + db.session.rollback() + session_item['status'] = 'bind_failed' + return redirect('/home?bind=failed') + + user = None + is_new_user = False + + if unionid: + user = User.query.filter_by(wechat_union_id=unionid).first() + if not user: + user = User.query.filter_by(wechat_open_id=openid).first() + + if not user: + # 创建新用户 + # 先清理微信昵称 + raw_nickname = user_info.get('nickname', '微信用户') + # 创建临时用户实例以使用清理方法 + temp_user = User.__new__(User) + sanitized_nickname = temp_user._sanitize_nickname(raw_nickname) + + username = sanitized_nickname + counter = 1 + while User.is_username_taken(username): + username = f"{sanitized_nickname}_{counter}" + counter += 1 + + user = User(username=username) + user.nickname = sanitized_nickname + user.avatar_url = user_info.get('headimgurl') + user.wechat_open_id = openid + user.wechat_union_id = unionid + user.set_password(uuid.uuid4().hex) + user.status = 'active' + + db.session.add(user) + db.session.commit() + + is_new_user = True + print(f"✅ 微信扫码自动创建新用户: {username}, openid: {openid}") + + # 更新最后登录时间 + user.update_last_seen() + + # 设置session + session.permanent = True + session['user_id'] = user.id + session['username'] = user.username + session['logged_in'] = True + session['wechat_login'] = True # 标记是微信登录 + + # Flask-Login 登录 + login_user(user, remember=True) + + # 更新微信session状态,供前端轮询检测 + if state in wechat_qr_sessions: + session_item = wechat_qr_sessions[state] + # 仅处理登录/注册流程,不处理绑定流程 + if not session_item.get('mode'): + # 更新状态和用户信息 + session_item['status'] = 'register_ready' if is_new_user else 'login_ready' + session_item['user_info'] = {'user_id': user.id} + print(f"✅ 微信扫码状态已更新: {session_item['status']}, user_id: {user.id}") + + # 直接跳转到首页 + return redirect('/home') + + except Exception as e: + print(f"❌ 微信登录失败: {e}") + import traceback + traceback.print_exc() + db.session.rollback() + + # 更新session状态为失败 + if state in wechat_qr_sessions: + wechat_qr_sessions[state]['status'] = 'auth_failed' + wechat_qr_sessions[state]['error'] = str(e) + + return redirect('/auth/signin?error=login_failed') + + +@app.route('/api/auth/login/wechat', methods=['POST']) +def login_with_wechat(): + """微信登录 - 修复版本""" + data = request.get_json() + session_id = data.get('session_id') + + if not session_id: + return jsonify({'success': False, 'error': 'session_id不能为空'}), 400 + + # 验证session + session = wechat_qr_sessions.get(session_id) + if not session: + return jsonify({'success': False, 'error': '会话不存在或已过期'}), 400 + + # 检查session状态 + if session['status'] not in ['login_ready', 'register_ready']: + return jsonify({'success': False, 'error': '会话状态无效'}), 400 + + # 检查是否有用户信息 + user_info = session.get('user_info') + if not user_info or not user_info.get('user_id'): + return jsonify({'success': False, 'error': '用户信息不完整'}), 400 + + try: + user = User.query.get(user_info['user_id']) + if not user: + return jsonify({'success': False, 'error': '用户不存在'}), 404 + + # 更新最后登录时间 + user.update_last_seen() + + # 清除session + del wechat_qr_sessions[session_id] + + # 生成登录响应 + response_data = { + 'success': True, + 'message': '登录成功' if session['status'] == 'login_ready' else '注册并登录成功', + 'user': { + 'id': user.id, + 'username': user.username, + 'nickname': user.nickname or user.username, + 'email': user.email, + 'avatar_url': user.avatar_url, + 'has_wechat': True, + 'wechat_open_id': user.wechat_open_id, + 'wechat_union_id': user.wechat_union_id, + 'created_at': user.created_at.isoformat() if user.created_at else None, + 'last_seen': user.last_seen.isoformat() if user.last_seen else None + } + } + + # 如果需要token认证,可以在这里生成 + # response_data['token'] = generate_token(user.id) + + return jsonify(response_data), 200 + + except Exception as e: + print(f"❌ 微信登录错误: {e}") + import traceback + app.logger.error(f"回调处理错误: {e}", exc_info=True) + return jsonify({ + 'success': False, + 'error': '登录失败,请重试' + }), 500 + + +@app.route('/api/account/wechat/unbind', methods=['POST']) +def unbind_wechat_account(): + """解绑当前登录用户的微信""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + try: + user = User.query.get(session.get('user_id')) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + user.unbind_wechat() + return jsonify({'message': '解绑成功', 'success': True}), 200 + except Exception as e: + print(f"Unbind wechat error: {e}") + db.session.rollback() + return jsonify({'error': '解绑失败,请重试'}), 500 + + +# 评论模型 +class EventComment(db.Model): + """事件评论""" + __tablename__ = 'event_comment' + + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, nullable=False) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=True) + author = db.Column(db.String(50), default='匿名用户') + content = db.Column(db.Text, nullable=False) + parent_id = db.Column(db.Integer, db.ForeignKey('event_comment.id')) + likes = db.Column(db.Integer, default=0) + created_at = db.Column(db.DateTime, default=beijing_now) + status = db.Column(db.String(20), default='active') + + user = db.relationship('User', backref='event_comments') + replies = db.relationship('EventComment', backref=db.backref('parent', remote_side=[id])) + + def to_dict(self, user_session_id=None, current_user_id=None): + # 检查当前用户是否已点赞 + user_liked = False + if user_session_id: + like_record = CommentLike.query.filter_by( + comment_id=self.id, + session_id=user_session_id + ).first() + user_liked = like_record is not None + + # 检查当前用户是否可以删除此评论 + can_delete = current_user_id is not None and self.user_id == current_user_id + + return { + 'id': self.id, + 'event_id': self.event_id, + 'author': self.author, + 'content': self.content, + 'parent_id': self.parent_id, + 'likes': self.likes, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'user_liked': user_liked, + 'can_delete': can_delete, + 'user_id': self.user_id, + 'replies': [reply.to_dict(user_session_id, current_user_id) for reply in self.replies if + reply.status == 'active'] + } + + +class CommentLike(db.Model): + """评论点赞记录""" + __tablename__ = 'comment_like' + + id = db.Column(db.Integer, primary_key=True) + comment_id = db.Column(db.Integer, db.ForeignKey('event_comment.id'), nullable=False) + session_id = db.Column(db.String(100), nullable=False) + created_at = db.Column(db.DateTime, default=beijing_now) + + __table_args__ = (db.UniqueConstraint('comment_id', 'session_id'),) + + +@app.after_request +def after_request(response): + """处理所有响应,添加CORS头部和安全头部""" + origin = request.headers.get('Origin') + allowed_origins = ['http://localhost:3000', 'http://127.0.0.1:3000', 'http://localhost:5173', + 'https://valuefrontier.cn', 'http://valuefrontier.cn'] + + if origin in allowed_origins: + response.headers['Access-Control-Allow-Origin'] = origin + response.headers['Access-Control-Allow-Credentials'] = 'true' + response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization,X-Requested-With' + response.headers['Access-Control-Allow-Methods'] = 'GET,PUT,POST,DELETE,OPTIONS' + response.headers['Access-Control-Expose-Headers'] = 'Content-Type,Authorization' + + # 处理预检请求 + if request.method == 'OPTIONS': + response.status_code = 200 + + return response + + +def add_cors_headers(response): + """添加CORS头(保留原有函数以兼容)""" + origin = request.headers.get('Origin') + allowed_origins = ['http://localhost:3000', 'http://127.0.0.1:3000', 'http://localhost:5173', + 'https://valuefrontier.cn', 'http://valuefrontier.cn'] + + if origin in allowed_origins: + response.headers['Access-Control-Allow-Origin'] = origin + else: + response.headers['Access-Control-Allow-Origin'] = 'http://localhost:3000' + + response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization,X-Requested-With' + response.headers['Access-Control-Allow-Methods'] = 'GET,PUT,POST,DELETE,OPTIONS' + response.headers['Access-Control-Allow-Credentials'] = 'true' + return response + + +class EventFollow(db.Model): + """事件关注""" + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) + created_at = db.Column(db.DateTime, default=beijing_now) + + user = db.relationship('User', backref='event_follows') + + __table_args__ = (db.UniqueConstraint('user_id', 'event_id'),) + + +class FutureEventFollow(db.Model): + """未来事件关注""" + __tablename__ = 'future_event_follow' + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + future_event_id = db.Column(db.Integer, nullable=False) # future_events表的id + created_at = db.Column(db.DateTime, default=beijing_now) + + user = db.relationship('User', backref='future_event_follows') + + __table_args__ = (db.UniqueConstraint('user_id', 'future_event_id'),) + + +# —— 自选股输入统一化与名称补全工具 —— +def _normalize_stock_input(raw_input: str): + """解析用户输入为标准6位股票代码与可选名称。 + + 支持: + - 6位代码: "600519",或带后缀 "600519.SH"/"600519.SZ" + - 名称(代码): "贵州茅台(600519)" 或 "贵州茅台(600519)" + 返回 (code6, name_or_none) + """ + if not raw_input: + return None, None + s = str(raw_input).strip() + + # 名称(600519) 或 名称(600519) + m = re.match(r"^(.+?)[\((]\s*(\d{6})\s*[\))]\s*$", s) + if m: + name = m.group(1).strip() + code = m.group(2) + return code, (name if name else None) + + # 600519 或 600519.SH / 600519.SZ + m2 = re.match(r"^(\d{6})(?:\.(?:SH|SZ))?$", s, re.IGNORECASE) + if m2: + return m2.group(1), None + + # SH600519 / SZ000001 + m3 = re.match(r"^(SH|SZ)(\d{6})$", s, re.IGNORECASE) + if m3: + return m3.group(2), None + + return None, None + + +def _query_stock_name_by_code(code6: str): + """根据6位代码查询股票名称,查不到返回None。""" + try: + with engine.connect() as conn: + q = text(""" + SELECT SECNAME + FROM ea_baseinfo + WHERE SECCODE = :c LIMIT 1 + """) + row = conn.execute(q, {'c': code6}).fetchone() + if row: + return row[0] + except Exception: + pass + return None + + +class Watchlist(db.Model): + """用户自选股""" + __tablename__ = 'watchlist' + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + stock_code = db.Column(db.String(20), nullable=False) + stock_name = db.Column(db.String(100), nullable=True) + created_at = db.Column(db.DateTime, default=beijing_now) + + user = db.relationship('User', backref='watchlist') + + __table_args__ = (db.UniqueConstraint('user_id', 'stock_code'),) + + +@app.route('/api/account/watchlist', methods=['GET']) +def get_my_watchlist(): + """获取当前用户的自选股列表""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + items = Watchlist.query.filter_by(user_id=session['user_id']).order_by(Watchlist.created_at.desc()).all() + + # 懒更新:统一代码为6位、补全缺失的名称,并去重(同一代码保留一个记录) + from collections import defaultdict + groups = defaultdict(list) + for i in items: + code6, _ = _normalize_stock_input(i.stock_code) + normalized_code = code6 or (i.stock_code.strip().upper() if isinstance(i.stock_code, str) else i.stock_code) + groups[normalized_code].append(i) + + dirty = False + to_delete = [] + for code6, group in groups.items(): + # 选择保留记录:优先有名称的,其次创建时间早的 + def sort_key(x): + return (x.stock_name is None, x.created_at or datetime.min) + + group_sorted = sorted(group, key=sort_key) + keep = group_sorted[0] + # 规范保留项 + if keep.stock_code != code6: + keep.stock_code = code6 + dirty = True + if not keep.stock_name and code6: + nm = _query_stock_name_by_code(code6) + if nm: + keep.stock_name = nm + dirty = True + # 其余删除 + for g in group_sorted[1:]: + to_delete.append(g) + + if to_delete: + for g in to_delete: + db.session.delete(g) + dirty = True + + if dirty: + db.session.commit() + + return jsonify({'success': True, 'data': [ + { + 'id': i.id, + 'stock_code': i.stock_code, + 'stock_name': i.stock_name, + 'created_at': i.created_at.isoformat() if i.created_at else None + } for i in items + ]}) + except Exception as e: + print(f"Error in get_my_watchlist: {str(e)}") + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/account/watchlist', methods=['POST']) +def add_to_watchlist(): + """添加到自选股""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + data = request.get_json() or {} + raw_code = data.get('stock_code') + raw_name = data.get('stock_name') + + code6, name_from_input = _normalize_stock_input(raw_code) + if not code6: + return jsonify({'success': False, 'error': '无效的股票标识'}), 400 + + # 优先使用传入名称,其次从输入解析中获得,最后查库补全 + final_name = raw_name or name_from_input or _query_stock_name_by_code(code6) + + # 查找已存在记录,兼容历史:6位/带后缀 + candidates = [code6, f"{code6}.SH", f"{code6}.SZ"] + existing = Watchlist.query.filter( + Watchlist.user_id == session['user_id'], + Watchlist.stock_code.in_(candidates) + ).first() + if existing: + # 统一为6位,补全名称 + updated = False + if existing.stock_code != code6: + existing.stock_code = code6 + updated = True + if (not existing.stock_name) and final_name: + existing.stock_name = final_name + updated = True + if updated: + db.session.commit() + return jsonify({'success': True, 'data': {'id': existing.id}}) + + item = Watchlist(user_id=session['user_id'], stock_code=code6, stock_name=final_name) + db.session.add(item) + db.session.commit() + return jsonify({'success': True, 'data': {'id': item.id}}) + + +@app.route('/api/account/watchlist/', methods=['DELETE']) +def remove_from_watchlist(stock_code): + """从自选股移除""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + code6, _ = _normalize_stock_input(stock_code) + candidates = [] + if code6: + candidates = [code6, f"{code6}.SH", f"{code6}.SZ"] + # 包含原始传入(以兼容历史) + if stock_code not in candidates: + candidates.append(stock_code) + + item = Watchlist.query.filter( + Watchlist.user_id == session['user_id'], + Watchlist.stock_code.in_(candidates) + ).first() + if not item: + return jsonify({'success': False, 'error': '未找到自选项'}), 404 + db.session.delete(item) + db.session.commit() + return jsonify({'success': True}) + + +@app.route('/api/account/watchlist/realtime', methods=['GET']) +def get_watchlist_realtime(): + """获取自选股实时行情数据(基于分钟线)""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + # 获取用户自选股列表 + watchlist = Watchlist.query.filter_by(user_id=session['user_id']).all() + if not watchlist: + return jsonify({'success': True, 'data': []}) + + # 获取股票代码列表 + stock_codes = [] + for item in watchlist: + code6, _ = _normalize_stock_input(item.stock_code) + # 统一内部查询代码 + normalized = code6 or str(item.stock_code).strip().upper() + stock_codes.append(normalized) + + # 使用现有的分钟线接口获取最新行情 + client = get_clickhouse_client() + quotes_data = {} + + # 获取最新交易日 + today = datetime.now().date() + + # 获取每只股票的最新价格 + for code in stock_codes: + raw_code = str(code).strip().upper() + if '.' in raw_code: + stock_code_full = raw_code + else: + stock_code_full = f"{raw_code}.SH" if raw_code.startswith('6') else f"{raw_code}.SZ" + + # 获取最新分钟线数据(先查近7天,若无数据再兜底倒序取最近一条) + query = """ + SELECT + close, timestamp, high, low, volume, amt + FROM stock_minute + WHERE code = %(code)s + AND timestamp >= %(start)s + ORDER BY timestamp DESC + LIMIT 1 \ + """ + + # 获取最近7天的分钟数据 + start_date = today - timedelta(days=7) + + result = client.execute(query, { + 'code': stock_code_full, + 'start': datetime.combine(start_date, dt_time(9, 30)) + }) + + # 若近7天无数据,兜底直接取最近一条 + if not result: + fallback_query = """ + SELECT + close, timestamp, high, low, volume, amt + FROM stock_minute + WHERE code = %(code)s + ORDER BY timestamp DESC + LIMIT 1 \ + """ + result = client.execute(fallback_query, {'code': stock_code_full}) + + if result: + latest_data = result[0] + latest_ts = latest_data[1] + + # 获取该bar所属交易日前一个交易日的收盘价 + prev_close_query = """ + SELECT close + FROM stock_minute + WHERE code = %(code)s + AND timestamp \ + < %(start)s + ORDER BY timestamp DESC + LIMIT 1 \ + """ + + prev_result = client.execute(prev_close_query, { + 'code': stock_code_full, + 'start': datetime.combine(latest_ts.date(), dt_time(9, 30)) + }) + + prev_close = float(prev_result[0][0]) if prev_result else float(latest_data[0]) + + # 计算涨跌幅 + change = float(latest_data[0]) - prev_close + change_percent = (change / prev_close * 100) if prev_close > 0 else 0.0 + + quotes_data[code] = { + 'price': float(latest_data[0]), + 'prev_close': float(prev_close), + 'change': float(change), + 'change_percent': float(change_percent), + 'high': float(latest_data[2]), + 'low': float(latest_data[3]), + 'volume': int(latest_data[4]), + 'amount': float(latest_data[5]), + 'update_time': latest_ts.strftime('%H:%M:%S') + } + + # 构建响应数据 + response_data = [] + for item in watchlist: + code6, _ = _normalize_stock_input(item.stock_code) + quote = quotes_data.get(code6 or item.stock_code, {}) + response_data.append({ + 'stock_code': code6 or item.stock_code, + 'stock_name': item.stock_name or (code6 and _query_stock_name_by_code(code6)) or None, + 'current_price': quote.get('price', 0), + 'prev_close': quote.get('prev_close', 0), + 'change': quote.get('change', 0), + 'change_percent': quote.get('change_percent', 0), + 'high': quote.get('high', 0), + 'low': quote.get('low', 0), + 'volume': quote.get('volume', 0), + 'amount': quote.get('amount', 0), + 'update_time': quote.get('update_time', ''), + # industry 字段在 Watchlist 模型中不存在,先不返回该字段 + }) + + return jsonify({ + 'success': True, + 'data': response_data + }) + + except Exception as e: + print(f"获取实时行情失败: {str(e)}") + return jsonify({'success': False, 'error': '获取实时行情失败'}), 500 + + +# 投资计划和复盘相关的模型 +class InvestmentPlan(db.Model): + __tablename__ = 'investment_plans' + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + date = db.Column(db.Date, nullable=False) + title = db.Column(db.String(200), nullable=False) + content = db.Column(db.Text) + type = db.Column(db.String(20)) # 'plan' or 'review' + stocks = db.Column(db.Text) # JSON array of stock codes + tags = db.Column(db.String(500)) # JSON array of tags + status = db.Column(db.String(20), default='active') # active, completed, cancelled + created_at = db.Column(db.DateTime, default=datetime.utcnow) + updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + def to_dict(self): + return { + 'id': self.id, + 'date': self.date.isoformat() if self.date else None, + 'title': self.title, + 'content': self.content, + 'type': self.type, + 'stocks': json.loads(self.stocks) if self.stocks else [], + 'tags': json.loads(self.tags) if self.tags else [], + 'status': self.status, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'updated_at': self.updated_at.isoformat() if self.updated_at else None + } + + +@app.route('/api/account/investment-plans', methods=['GET']) +def get_investment_plans(): + """获取投资计划和复盘记录""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + plan_type = request.args.get('type') # 'plan', 'review', or None for all + start_date = request.args.get('start_date') + end_date = request.args.get('end_date') + + query = InvestmentPlan.query.filter_by(user_id=session['user_id']) + + if plan_type: + query = query.filter_by(type=plan_type) + + if start_date: + query = query.filter(InvestmentPlan.date >= datetime.fromisoformat(start_date).date()) + + if end_date: + query = query.filter(InvestmentPlan.date <= datetime.fromisoformat(end_date).date()) + + plans = query.order_by(InvestmentPlan.date.desc()).all() + + return jsonify({ + 'success': True, + 'data': [plan.to_dict() for plan in plans] + }) + + except Exception as e: + print(f"获取投资计划失败: {str(e)}") + return jsonify({'success': False, 'error': '获取数据失败'}), 500 + + +@app.route('/api/account/investment-plans', methods=['POST']) +def create_investment_plan(): + """创建投资计划或复盘记录""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + data = request.get_json() + + # 验证必要字段 + if not data.get('date') or not data.get('title') or not data.get('type'): + return jsonify({'success': False, 'error': '缺少必要字段'}), 400 + + plan = InvestmentPlan( + user_id=session['user_id'], + date=datetime.fromisoformat(data['date']).date(), + title=data['title'], + content=data.get('content', ''), + type=data['type'], + stocks=json.dumps(data.get('stocks', [])), + tags=json.dumps(data.get('tags', [])), + status=data.get('status', 'active') + ) + + db.session.add(plan) + db.session.commit() + + return jsonify({ + 'success': True, + 'data': plan.to_dict() + }) + + except Exception as e: + db.session.rollback() + print(f"创建投资计划失败: {str(e)}") + return jsonify({'success': False, 'error': '创建失败'}), 500 + + +@app.route('/api/account/investment-plans/', methods=['PUT']) +def update_investment_plan(plan_id): + """更新投资计划或复盘记录""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + plan = InvestmentPlan.query.filter_by(id=plan_id, user_id=session['user_id']).first() + if not plan: + return jsonify({'success': False, 'error': '未找到该记录'}), 404 + + data = request.get_json() + + if 'date' in data: + plan.date = datetime.fromisoformat(data['date']).date() + if 'title' in data: + plan.title = data['title'] + if 'content' in data: + plan.content = data['content'] + if 'stocks' in data: + plan.stocks = json.dumps(data['stocks']) + if 'tags' in data: + plan.tags = json.dumps(data['tags']) + if 'status' in data: + plan.status = data['status'] + + plan.updated_at = datetime.utcnow() + db.session.commit() + + return jsonify({ + 'success': True, + 'data': plan.to_dict() + }) + + except Exception as e: + db.session.rollback() + print(f"更新投资计划失败: {str(e)}") + return jsonify({'success': False, 'error': '更新失败'}), 500 + + +@app.route('/api/account/investment-plans/', methods=['DELETE']) +def delete_investment_plan(plan_id): + """删除投资计划或复盘记录""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + plan = InvestmentPlan.query.filter_by(id=plan_id, user_id=session['user_id']).first() + if not plan: + return jsonify({'success': False, 'error': '未找到该记录'}), 404 + + db.session.delete(plan) + db.session.commit() + + return jsonify({'success': True}) + + except Exception as e: + db.session.rollback() + print(f"删除投资计划失败: {str(e)}") + return jsonify({'success': False, 'error': '删除失败'}), 500 + + +@app.route('/api/account/events/following', methods=['GET']) +def get_my_following_events(): + """获取我关注的事件列表""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + follows = EventFollow.query.filter_by(user_id=session['user_id']).order_by(EventFollow.created_at.desc()).all() + event_ids = [f.event_id for f in follows] + if not event_ids: + return jsonify({'success': True, 'data': []}) + + events = Event.query.filter(Event.id.in_(event_ids)).all() + data = [] + for ev in events: + data.append({ + 'id': ev.id, + 'title': ev.title, + 'event_type': ev.event_type, + 'start_time': ev.start_time.isoformat() if ev.start_time else None, + 'hot_score': ev.hot_score, + 'follower_count': ev.follower_count, + }) + return jsonify({'success': True, 'data': data}) + + +@app.route('/api/account/events/comments', methods=['GET']) +def get_my_event_comments(): + """获取我在事件上的评论(EventComment)""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + comments = EventComment.query.filter_by(user_id=session['user_id']).order_by(EventComment.created_at.desc()).limit( + 100).all() + return jsonify({'success': True, 'data': [c.to_dict() for c in comments]}) + + +@app.route('/api/account/future-events/following', methods=['GET']) +def get_my_following_future_events(): + """获取当前用户关注的未来事件""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + try: + # 获取用户关注的未来事件ID列表 + follows = FutureEventFollow.query.filter_by(user_id=session['user_id']).all() + future_event_ids = [f.future_event_id for f in follows] + + if not future_event_ids: + return jsonify({'success': True, 'data': []}) + + # 查询未来事件详情 + sql = """ + SELECT * + FROM future_events + WHERE data_id IN :event_ids + ORDER BY calendar_time \ + """ + + result = db.session.execute( + text(sql), + {'event_ids': tuple(future_event_ids)} + ) + + events = [] + for row in result: + event_data = { + 'id': row.data_id, + 'title': row.title, + 'type': row.type, + 'calendar_time': row.calendar_time.isoformat(), + 'star': row.star, + 'former': row.former, + 'forecast': row.forecast, + 'fact': row.fact, + 'is_following': True, # 这些都是已关注的 + 'related_stocks': parse_json_field(row.related_stocks), + 'concepts': parse_json_field(row.concepts) + } + events.append(event_data) + + return jsonify({'success': True, 'data': events}) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +class PostLike(db.Model): + """帖子点赞""" + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) + created_at = db.Column(db.DateTime, default=beijing_now) + + user = db.relationship('User', backref='post_likes') + + __table_args__ = (db.UniqueConstraint('user_id', 'post_id'),) + + +class Event(db.Model): + """事件模型""" + id = db.Column(db.Integer, primary_key=True) + title = db.Column(db.String(200), nullable=False) + description = db.Column(db.Text) + + # 事件类型与状态 + event_type = db.Column(db.String(50)) + status = db.Column(db.String(20), default='active') + + # 时间相关 + start_time = db.Column(db.DateTime, default=beijing_now) + end_time = db.Column(db.DateTime) + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now) + + # 热度与统计 + hot_score = db.Column(db.Float, default=0) + view_count = db.Column(db.Integer, default=0) + trending_score = db.Column(db.Float, default=0) + post_count = db.Column(db.Integer, default=0) + follower_count = db.Column(db.Integer, default=0) + + # 关联信息 + related_industries = db.Column(db.JSON) + keywords = db.Column(db.JSON) + files = db.Column(db.JSON) + importance = db.Column(db.String(20)) + related_avg_chg = db.Column(db.Float, default=0) + related_max_chg = db.Column(db.Float, default=0) + related_week_chg = db.Column(db.Float, default=0) + + # 新增字段 + invest_score = db.Column(db.Integer) # 超预期得分 + expectation_surprise_score = db.Column(db.Integer) + # 创建者信息 + creator_id = db.Column(db.Integer, db.ForeignKey('user.id')) + creator = db.relationship('User', backref='created_events') + + # 关系 + posts = db.relationship('Post', backref='event', lazy='dynamic') + followers = db.relationship('EventFollow', backref='event', lazy='dynamic') + related_stocks = db.relationship('RelatedStock', backref='event', lazy='dynamic') + historical_events = db.relationship('HistoricalEvent', backref='event', lazy='dynamic') + related_data = db.relationship('RelatedData', backref='event', lazy='dynamic') + related_concepts = db.relationship('RelatedConcepts', backref='event', lazy='dynamic') + + @property + def keywords_list(self): + """返回解析后的关键词列表""" + if not self.keywords: + return [] + + if isinstance(self.keywords, list): + return self.keywords + + try: + # 如果是字符串,尝试解析JSON + if isinstance(self.keywords, str): + decoded = json.loads(self.keywords) + # 处理Unicode编码的情况 + if isinstance(decoded, list): + return [ + keyword.encode('utf-8').decode('unicode_escape') + if isinstance(keyword, str) and '\\u' in keyword + else keyword + for keyword in decoded + ] + return [] + + # 如果已经是字典或其他格式,尝试转换为列表 + return list(self.keywords) + except (json.JSONDecodeError, AttributeError, TypeError): + return [] + + def set_keywords(self, keywords): + """设置关键词列表""" + if isinstance(keywords, list): + self.keywords = json.dumps(keywords, ensure_ascii=False) + elif isinstance(keywords, str): + try: + # 尝试解析JSON字符串 + parsed = json.loads(keywords) + if isinstance(parsed, list): + self.keywords = json.dumps(parsed, ensure_ascii=False) + else: + self.keywords = json.dumps([keywords], ensure_ascii=False) + except json.JSONDecodeError: + # 如果不是有效的JSON,将其作为单个关键词 + self.keywords = json.dumps([keywords], ensure_ascii=False) + + +class RelatedStock(db.Model): + """相关标的模型""" + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id')) + stock_code = db.Column(db.String(20)) # 股票代码 + stock_name = db.Column(db.String(100)) # 股票名称 + sector = db.Column(db.String(100)) # 关联类型 + relation_desc = db.Column(db.String(1024)) # 关联原因描述 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + correlation = db.Column(db.Float()) + momentum = db.Column(db.String(1024)) # 动量 + retrieved_sources = db.Column(db.JSON) # 动量 + + +class RelatedData(db.Model): + """关联数据模型""" + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id')) + title = db.Column(db.String(200)) # 数据标题 + data_type = db.Column(db.String(50)) # 数据类型 + data_content = db.Column(db.JSON) # 数据内容(JSON格式) + description = db.Column(db.Text) # 数据描述 + created_at = db.Column(db.DateTime, default=beijing_now) + + +class RelatedConcepts(db.Model): + """关联数据模型""" + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id')) + concept_code = db.Column(db.String(20)) # 数据标题 + concept = db.Column(db.String(100)) # 数据类型 + reason = db.Column(db.Text) # 数据描述 + image_paths = db.Column(db.JSON) # 数据内容(JSON格式) + created_at = db.Column(db.DateTime, default=beijing_now) + + @property + def image_paths_list(self): + """返回解析后的图片路径列表""" + if not self.image_paths: + return [] + + try: + # 如果是字符串,先解析成JSON + if isinstance(self.image_paths, str): + paths = json.loads(self.image_paths) + else: + paths = self.image_paths + + # 确保paths是列表 + if not isinstance(paths, list): + paths = [paths] + + # 从每个对象中提取path字段 + return [item['path'] if isinstance(item, dict) and 'path' in item + else item for item in paths] + except Exception as e: + print(f"Error processing image paths: {e}") + return [] + + def get_first_image_path(self): + """获取第一张图片的完整路径""" + paths = self.image_paths_list + if not paths: + return None + + # 获取第一个路径 + first_path = paths[0] + # 返回完整路径 + return first_path + + +class EventHotHistory(db.Model): + """事件热度历史记录""" + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id')) + score = db.Column(db.Float) # 总分 + interaction_score = db.Column(db.Float) # 互动分数 + follow_score = db.Column(db.Float) # 关注度分数 + view_score = db.Column(db.Float) # 浏览量分数 + recent_activity_score = db.Column(db.Float) # 最近活跃度分数 + time_decay = db.Column(db.Float) # 时间衰减因子 + created_at = db.Column(db.DateTime, default=beijing_now) + + event = db.relationship('Event', backref='hot_history') + + +class EventTransmissionNode(db.Model): + """事件传导节点模型""" + __tablename__ = 'event_transmission_nodes' + + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) + node_type = db.Column(db.Enum('company', 'industry', 'policy', 'technology', + 'market', 'event', 'other'), nullable=False) + node_name = db.Column(db.String(200), nullable=False) + node_description = db.Column(db.Text) + importance_score = db.Column(db.Integer, default=50) + stock_code = db.Column(db.String(20)) + is_main_event = db.Column(db.Boolean, default=False) + + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + # Relationships + event = db.relationship('Event', backref='transmission_nodes') + outgoing_edges = db.relationship('EventTransmissionEdge', + foreign_keys='EventTransmissionEdge.from_node_id', + backref='from_node', cascade='all, delete-orphan') + incoming_edges = db.relationship('EventTransmissionEdge', + foreign_keys='EventTransmissionEdge.to_node_id', + backref='to_node', cascade='all, delete-orphan') + + __table_args__ = ( + db.Index('idx_event_id', 'event_id'), + db.Index('idx_node_type', 'node_type'), + db.Index('idx_main_event', 'is_main_event'), + ) + + +class EventTransmissionEdge(db.Model): + """事件传导边模型""" + __tablename__ = 'event_transmission_edges' + + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) + from_node_id = db.Column(db.Integer, db.ForeignKey('event_transmission_nodes.id'), nullable=False) + to_node_id = db.Column(db.Integer, db.ForeignKey('event_transmission_nodes.id'), nullable=False) + + transmission_type = db.Column(db.Enum('supply_chain', 'competition', 'policy', + 'technology', 'capital_flow', 'expectation', + 'cyclic_effect', 'other'), nullable=False) + transmission_mechanism = db.Column(db.Text) + direction = db.Column(db.Enum('positive', 'negative', 'neutral', 'mixed'), default='neutral') + strength = db.Column(db.Integer, default=50) + impact = db.Column(db.Text) + is_circular = db.Column(db.Boolean, default=False) + + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + # Relationship + event = db.relationship('Event', backref='transmission_edges') + + __table_args__ = ( + db.Index('idx_event_id', 'event_id'), + db.Index('idx_strength', 'strength'), + db.Index('idx_from_to', 'from_node_id', 'to_node_id'), + db.Index('idx_circular', 'is_circular'), + ) + + +# 在 paste-2.txt 的模型定义部分添加 +class EventSankeyFlow(db.Model): + """事件桑基流模型""" + __tablename__ = 'event_sankey_flows' + + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) + + # 流的基本信息 + source_node = db.Column(db.String(200), nullable=False) + source_type = db.Column(db.Enum('event', 'policy', 'technology', 'industry', + 'company', 'product'), nullable=False) + source_level = db.Column(db.Integer, nullable=False, default=0) + + target_node = db.Column(db.String(200), nullable=False) + target_type = db.Column(db.Enum('policy', 'technology', 'industry', + 'company', 'product'), nullable=False) + target_level = db.Column(db.Integer, nullable=False, default=1) + + # 流量信息 + flow_value = db.Column(db.Numeric(10, 2), nullable=False) + flow_ratio = db.Column(db.Numeric(5, 4), nullable=False) + + # 传导机制 + transmission_path = db.Column(db.String(500)) + impact_description = db.Column(db.Text) + evidence_strength = db.Column(db.Integer, default=50) + + # 时间戳 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + # 关系 + event = db.relationship('Event', backref='sankey_flows') + + __table_args__ = ( + db.Index('idx_event_id', 'event_id'), + db.Index('idx_source_target', 'source_node', 'target_node'), + db.Index('idx_levels', 'source_level', 'target_level'), + db.Index('idx_flow_value', 'flow_value'), + ) + + +class HistoricalEvent(db.Model): + """历史事件模型""" + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id')) + title = db.Column(db.String(200)) + content = db.Column(db.Text) + event_date = db.Column(db.DateTime) + relevance = db.Column(db.Integer) # 相关性 + importance = db.Column(db.Integer) # 重要程度 + related_stock = db.Column(db.JSON) # 保留JSON字段 + created_at = db.Column(db.DateTime, default=beijing_now) + + # 新增关系 + stocks = db.relationship('HistoricalEventStock', backref='historical_event', lazy='dynamic', + cascade='all, delete-orphan') + + +class HistoricalEventStock(db.Model): + """历史事件相关股票模型""" + __tablename__ = 'historical_event_stocks' + + id = db.Column(db.Integer, primary_key=True) + historical_event_id = db.Column(db.Integer, db.ForeignKey('historical_event.id'), nullable=False) + stock_code = db.Column(db.String(20), nullable=False) + stock_name = db.Column(db.String(50)) + relation_desc = db.Column(db.Text) + correlation = db.Column(db.Float, default=0.5) + sector = db.Column(db.String(100)) + created_at = db.Column(db.DateTime, default=beijing_now) + + __table_args__ = ( + db.UniqueConstraint('historical_event_id', 'stock_code', name='unique_event_stock'), + ) + + +# === 股票盈利预测(自有表) === +class StockForecastData(db.Model): + """股票盈利预测数据 + + 源于本地表 stock_forecast_data,由独立离线程序写入。 + 字段与表结构保持一致,仅用于读取聚合后输出前端报表所需的结构。 + """ + __tablename__ = 'stock_forecast_data' + + id = db.Column(db.Integer, primary_key=True) + stock_code = db.Column(db.String(6), nullable=False) + indicator_name = db.Column(db.String(50), nullable=False) + year_2022a = db.Column(db.Numeric(15, 2)) + year_2023a = db.Column(db.Numeric(15, 2)) + year_2024a = db.Column(db.Numeric(15, 2)) + year_2025e = db.Column(db.Numeric(15, 2)) + year_2026e = db.Column(db.Numeric(15, 2)) + year_2027e = db.Column(db.Numeric(15, 2)) + process_time = db.Column(db.DateTime, nullable=False) + + __table_args__ = ( + db.UniqueConstraint('stock_code', 'indicator_name', name='unique_stock_indicator'), + ) + + def values_by_year(self): + years = ['2022A', '2023A', '2024A', '2025E', '2026E', '2027E'] + vals = [self.year_2022a, self.year_2023a, self.year_2024a, self.year_2025e, self.year_2026e, self.year_2027e] + + def _to_float(x): + try: + return float(x) if x is not None else None + except Exception: + return None + + return years, [_to_float(v) for v in vals] + + +@app.route('/api/events/', methods=['GET']) +def get_event_detail(event_id): + """获取事件详情""" + try: + event = Event.query.get_or_404(event_id) + + # 增加浏览计数 + event.view_count += 1 + db.session.commit() + + return jsonify({ + 'success': True, + 'data': { + 'id': event.id, + 'title': event.title, + 'description': event.description, + 'event_type': event.event_type, + 'status': event.status, + 'start_time': event.start_time.isoformat() if event.start_time else None, + 'end_time': event.end_time.isoformat() if event.end_time else None, + 'created_at': event.created_at.isoformat() if event.created_at else None, + 'hot_score': event.hot_score, + 'view_count': event.view_count, + 'trending_score': event.trending_score, + 'post_count': event.post_count, + 'follower_count': event.follower_count, + 'related_industries': event.related_industries, + 'keywords': event.keywords_list, + 'importance': event.importance, + 'related_avg_chg': event.related_avg_chg, + 'related_max_chg': event.related_max_chg, + 'related_week_chg': event.related_week_chg, + 'invest_score': event.invest_score, + 'expectation_surprise_score': event.expectation_surprise_score, + 'creator_id': event.creator_id, + 'has_chain_analysis': ( + EventTransmissionNode.query.filter_by(event_id=event_id).first() is not None or + EventSankeyFlow.query.filter_by(event_id=event_id).first() is not None + ), + 'is_following': False, # 需要根据当前用户状态判断 + } + }) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//stocks', methods=['GET']) +def get_related_stocks(event_id): + """获取相关股票列表""" + try: + # 订阅控制:相关标的需要 Pro 及以上 + if not _has_required_level('pro'): + return jsonify({'success': False, 'error': '需要Pro订阅', 'required_level': 'pro'}), 403 + event = Event.query.get_or_404(event_id) + stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all() + + stocks_data = [] + for stock in stocks: + if stock.retrieved_sources is not None: + stocks_data.append({ + 'id': stock.id, + 'stock_code': stock.stock_code, + 'stock_name': stock.stock_name, + 'sector': stock.sector, + 'relation_desc': {"data":stock.retrieved_sources}, + 'retrieved_sources': stock.retrieved_sources, + 'correlation': stock.correlation, + 'momentum': stock.momentum, + 'created_at': stock.created_at.isoformat() if stock.created_at else None, + 'updated_at': stock.updated_at.isoformat() if stock.updated_at else None + }) + else: + stocks_data.append({ + 'id': stock.id, + 'stock_code': stock.stock_code, + 'stock_name': stock.stock_name, + 'sector': stock.sector, + 'relation_desc': stock.relation_desc, + 'correlation': stock.correlation, + 'momentum': stock.momentum, + 'created_at': stock.created_at.isoformat() if stock.created_at else None, + 'updated_at': stock.updated_at.isoformat() if stock.updated_at else None + }) + + return jsonify({ + 'success': True, + 'data': stocks_data + }) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//stocks', methods=['POST']) +def add_related_stock(event_id): + """添加相关股票""" + try: + event = Event.query.get_or_404(event_id) + data = request.get_json() + + # 验证必要字段 + if not data.get('stock_code') or not data.get('relation_desc'): + return jsonify({'success': False, 'error': '缺少必要字段'}), 400 + + # 检查是否已存在 + existing = RelatedStock.query.filter_by( + event_id=event_id, + stock_code=data['stock_code'] + ).first() + + if existing: + return jsonify({'success': False, 'error': '该股票已存在'}), 400 + + # 创建新的相关股票记录 + new_stock = RelatedStock( + event_id=event_id, + stock_code=data['stock_code'], + stock_name=data.get('stock_name', ''), + sector=data.get('sector', ''), + relation_desc=data['relation_desc'], + correlation=data.get('correlation', 0.5), + momentum=data.get('momentum', '') + ) + + db.session.add(new_stock) + db.session.commit() + + return jsonify({ + 'success': True, + 'data': { + 'id': new_stock.id, + 'stock_code': new_stock.stock_code, + 'relation_desc': new_stock.relation_desc + } + }) + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stocks/', methods=['DELETE']) +def delete_related_stock(stock_id): + """删除相关股票""" + try: + stock = RelatedStock.query.get_or_404(stock_id) + db.session.delete(stock) + db.session.commit() + + return jsonify({'success': True, 'message': '删除成功'}) + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//concepts', methods=['GET']) +def get_related_concepts(event_id): + """获取相关概念列表""" + try: + # 订阅控制:相关概念需要 Pro 及以上 + if not _has_required_level('pro'): + return jsonify({'success': False, 'error': '需要Pro订阅', 'required_level': 'pro'}), 403 + event = Event.query.get_or_404(event_id) + concepts = event.related_concepts.all() + + concepts_data = [] + for concept in concepts: + concepts_data.append({ + 'id': concept.id, + 'concept_code': concept.concept_code, + 'concept': concept.concept, + 'reason': concept.reason, + 'image_paths': concept.image_paths_list, + 'first_image_path': concept.get_first_image_path(), + 'created_at': concept.created_at.isoformat() if concept.created_at else None + }) + + return jsonify({ + 'success': True, + 'data': concepts_data + }) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//historical', methods=['GET']) +def get_historical_events(event_id): + """获取历史事件对比""" + try: + event = Event.query.get_or_404(event_id) + historical_events = event.historical_events.order_by(HistoricalEvent.event_date.desc()).all() + + events_data = [] + for hist_event in historical_events: + events_data.append({ + 'id': hist_event.id, + 'title': hist_event.title, + 'content': hist_event.content, + 'event_date': hist_event.event_date.isoformat() if hist_event.event_date else None, + 'importance': hist_event.importance, + 'relevance': hist_event.relevance, + 'created_at': hist_event.created_at.isoformat() if hist_event.created_at else None + }) + + # 订阅控制:免费用户仅返回前2条;Pro/Max返回全部 + info = _get_current_subscription_info() + sub_type = (info.get('type') or 'free').lower() + if sub_type == 'free': + return jsonify({ + 'success': True, + 'data': events_data[:2], + 'truncated': len(events_data) > 2, + 'required_level': 'pro' + }) + return jsonify({'success': True, 'data': events_data}) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/historical-events//stocks', methods=['GET']) +def get_historical_event_stocks(event_id): + """获取历史事件相关股票列表""" + try: + # 直接查询历史事件,不需要通过主事件 + hist_event = HistoricalEvent.query.get_or_404(event_id) + stocks = hist_event.stocks.order_by(HistoricalEventStock.correlation.desc()).all() + + # 获取事件对应的交易日 + event_trading_date = None + if hist_event.event_date: + event_trading_date = get_trading_day_near_date(hist_event.event_date) + + stocks_data = [] + for stock in stocks: + stock_data = { + 'id': stock.id, + 'stock_code': stock.stock_code, + 'stock_name': stock.stock_name, + 'sector': stock.sector, + 'relation_desc': stock.relation_desc, + 'correlation': stock.correlation, + 'created_at': stock.created_at.isoformat() if stock.created_at else None + } + + # 添加涨幅数据 + if event_trading_date: + try: + # 查询股票在事件对应交易日的数据 + with engine.connect() as conn: + query = text(""" + SELECT close_price, change_pct + FROM ea_dailyline + WHERE seccode = :stock_code + AND date = :trading_date + ORDER BY date DESC + LIMIT 1 + """) + + result = conn.execute(query, { + 'stock_code': stock.stock_code, + 'trading_date': event_trading_date + }).fetchone() + + if result: + stock_data['event_day_close'] = float(result[0]) if result[0] else None + stock_data['event_day_change_pct'] = float(result[1]) if result[1] else None + else: + stock_data['event_day_close'] = None + stock_data['event_day_change_pct'] = None + except Exception as e: + print(f"查询股票{stock.stock_code}在{event_trading_date}的数据失败: {e}") + stock_data['event_day_close'] = None + stock_data['event_day_change_pct'] = None + else: + stock_data['event_day_close'] = None + stock_data['event_day_change_pct'] = None + + stocks_data.append(stock_data) + + return jsonify({ + 'success': True, + 'data': stocks_data, + 'event_trading_date': event_trading_date.isoformat() if event_trading_date else None + }) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//expectation-score', methods=['GET']) +def get_expectation_score(event_id): + """获取超预期得分""" + try: + event = Event.query.get_or_404(event_id) + + # 如果事件有超预期得分,直接返回 + if event.expectation_surprise_score is not None: + score = event.expectation_surprise_score + else: + # 如果没有,根据历史事件计算一个模拟得分 + historical_events = event.historical_events.all() + if historical_events: + # 基于历史事件数量和重要性计算得分 + total_importance = sum(ev.importance or 0 for ev in historical_events) + avg_importance = total_importance / len(historical_events) if historical_events else 0 + score = min(100, max(0, int(avg_importance * 20 + len(historical_events) * 5))) + else: + # 默认得分 + score = 65 + + return jsonify({ + 'success': True, + 'data': { + 'score': score, + 'description': '基于历史事件判断当前事件的超预期情况,满分100分' + } + }) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//follow', methods=['POST']) +def toggle_event_follow(event_id): + """切换事件关注状态(需登录)""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + try: + event = Event.query.get_or_404(event_id) + user_id = session['user_id'] + + existing = EventFollow.query.filter_by(user_id=user_id, event_id=event_id).first() + if existing: + # 取消关注 + db.session.delete(existing) + event.follower_count = max(0, (event.follower_count or 0) - 1) + db.session.commit() + return jsonify({'success': True, 'data': {'is_following': False, 'follower_count': event.follower_count}}) + else: + # 关注 + follow = EventFollow(user_id=user_id, event_id=event_id) + db.session.add(follow) + event.follower_count = (event.follower_count or 0) + 1 + db.session.commit() + return jsonify({'success': True, 'data': {'is_following': True, 'follower_count': event.follower_count}}) + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//transmission', methods=['GET']) +def get_transmission_chain(event_id): + try: + # 订阅控制:传导链分析需要 Max 及以上 + if not _has_required_level('max'): + return jsonify({'success': False, 'error': '需要Max订阅', 'required_level': 'max'}), 403 + # 确保数据库连接是活跃的 + db.session.execute(text('SELECT 1')) + + event = Event.query.get_or_404(event_id) + nodes = EventTransmissionNode.query.filter_by(event_id=event_id).all() + edges = EventTransmissionEdge.query.filter_by(event_id=event_id).all() + + # 过滤孤立节点 + connected_node_ids = set() + for edge in edges: + connected_node_ids.add(edge.from_node_id) + connected_node_ids.add(edge.to_node_id) + + # 只保留有连接的节点 + connected_nodes = [node for node in nodes if node.id in connected_node_ids] + + # 如果没有主事件节点,也保留主事件节点 + main_event_node = next((node for node in nodes if node.is_main_event), None) + if main_event_node and main_event_node not in connected_nodes: + connected_nodes.append(main_event_node) + + if not connected_nodes: + return jsonify({'success': False, 'message': '暂无传导链分析数据'}) + + # 节点类型到中文类别的映射 + categories = { + 'event': "事件", 'industry': "行业", 'company': "公司", + 'policy': "政策", 'technology': "技术", 'market': "市场", 'other': "其他" + } + + nodes_data = [] + for node in connected_nodes: + node_category = categories.get(node.node_type, "其他") + nodes_data.append({ + 'id': str(node.id), # 转换为字符串以保持一致性 + 'name': node.node_name, + 'category': node_category, + 'value': node.importance_score or 20, + 'extra': { + 'node_type': node.node_type, + 'description': node.node_description, + 'importance_score': node.importance_score, + 'stock_code': node.stock_code, + 'is_main_event': node.is_main_event + } + }) + + edges_data = [] + for edge in edges: + # 确保边的两端节点都在连接节点列表中 + if edge.from_node_id in connected_node_ids and edge.to_node_id in connected_node_ids: + edges_data.append({ + 'source': str(edge.from_node_id), # 转换为字符串以保持一致性 + 'target': str(edge.to_node_id), # 转换为字符串以保持一致性 + 'value': edge.strength or 50, + 'extra': { + 'transmission_type': edge.transmission_type, + 'transmission_mechanism': edge.transmission_mechanism, + 'direction': edge.direction, + 'strength': edge.strength, + 'impact': edge.impact, + 'is_circular': edge.is_circular, + } + }) + + return jsonify({ + 'success': True, + 'data': { + 'nodes': nodes_data, + 'edges': edges_data + } + }) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +# 修复股票报价API - 支持GET和POST方法 +@app.route('/api/stock/quotes', methods=['GET', 'POST']) +def get_stock_quotes(): + try: + if request.method == 'GET': + # GET 请求从查询参数获取数据 + codes_str = request.args.get('codes', '') + codes = [code.strip() for code in codes_str.split(',') if code.strip()] + event_time_str = request.args.get('event_time') + else: + # POST 请求从 JSON 获取数据 + codes = request.json.get('codes', []) + event_time_str = request.json.get('event_time') + + if not codes: + return jsonify({'success': False, 'error': '请提供股票代码'}), 400 + + # 处理事件时间 + if event_time_str: + try: + event_time = datetime.fromisoformat(event_time_str.replace('Z', '+00:00')) + except: + event_time = datetime.now() + else: + event_time = datetime.now() + + current_time = datetime.now() + client = get_clickhouse_client() + + # Get stock names from MySQL + stock_names = {} + with engine.connect() as conn: + for code in codes: + codez = code.split('.')[0] + result = conn.execute(text( + "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" + ), {"code": codez}).fetchone() + if result: + stock_names[code] = result[0] + else: + stock_names[code] = f"股票{codez}" + + def get_trading_day_and_times(event_datetime): + event_date = event_datetime.date() + event_time = event_datetime.time() + + # Trading hours + market_open = dt_time(9, 30) + market_close = dt_time(15, 0) + + with engine.connect() as conn: + # First check if the event date itself is a trading day + is_trading_day = conn.execute(text(""" + SELECT 1 + FROM trading_days + WHERE EXCHANGE_DATE = :date + """), {"date": event_date}).fetchone() is not None + + if is_trading_day: + # If it's a trading day, determine time period based on event time + if event_time < market_open: + # Before market opens - use full trading day + return event_date, market_open, market_close + elif event_time > market_close: + # After market closes - get next trading day + next_trading_day = conn.execute(text(""" + SELECT EXCHANGE_DATE + FROM trading_days + WHERE EXCHANGE_DATE > :date + ORDER BY EXCHANGE_DATE LIMIT 1 + """), {"date": event_date}).fetchone() + # Convert to date object if we found a next trading day + return (next_trading_day[0].date() if next_trading_day else None, + market_open, market_close) + else: + # During trading hours + return event_date, event_time, market_close + else: + # If not a trading day, get next trading day + next_trading_day = conn.execute(text(""" + SELECT EXCHANGE_DATE + FROM trading_days + WHERE EXCHANGE_DATE > :date + ORDER BY EXCHANGE_DATE LIMIT 1 + """), {"date": event_date}).fetchone() + # Convert to date object if we found a next trading day + return (next_trading_day[0].date() if next_trading_day else None, + market_open, market_close) + + trading_day, start_time, end_time = get_trading_day_and_times(event_time) + + if not trading_day: + return jsonify({ + 'success': True, + 'data': {code: {'name': name, 'price': None, 'change': None} + for code, name in stock_names.items()} + }) + + # For historical dates, ensure we're using actual data + start_datetime = datetime.combine(trading_day, start_time) + end_datetime = datetime.combine(trading_day, end_time) + + # If the trading day is in the future relative to current time, + # return only names without data + if trading_day > current_time.date(): + return jsonify({ + 'success': True, + 'data': {code: {'name': name, 'price': None, 'change': None} + for code, name in stock_names.items()} + }) + + results = {} + print(f"处理股票代码: {codes}, 交易日: {trading_day}, 时间范围: {start_datetime} - {end_datetime}") + + for code in codes: + try: + print(f"正在查询股票 {code} 的价格数据...") + # Get the first price and last price for the trading period + data = client.execute(""" + WITH first_price AS (SELECT close + FROM stock_minute + WHERE code = %(code)s + AND timestamp >= %(start)s + AND timestamp <= %(end)s + ORDER BY timestamp + LIMIT 1 + ), + last_price AS ( + SELECT close + FROM stock_minute + WHERE code = %(code)s + AND timestamp >= %(start)s + AND timestamp <= %(end)s + ORDER BY timestamp DESC + LIMIT 1 + ) + SELECT last_price.close as last_price, + (last_price.close - first_price.close) / first_price.close * 100 as change + FROM last_price + CROSS JOIN first_price + WHERE EXISTS (SELECT 1 FROM first_price) + AND EXISTS (SELECT 1 FROM last_price) + """, { + 'code': code, + 'start': start_datetime, + 'end': end_datetime + }) + + print(f"股票 {code} 查询结果: {data}") + if data and data[0] and data[0][0] is not None: + price = float(data[0][0]) if data[0][0] is not None else None + change = float(data[0][1]) if data[0][1] is not None else None + + results[code] = { + 'price': price, + 'change': change, + 'name': stock_names.get(code, f'股票{code.split(".")[0]}') + } + else: + results[code] = { + 'price': None, + 'change': None, + 'name': stock_names.get(code, f'股票{code.split(".")[0]}') + } + except Exception as e: + print(f"Error processing stock {code}: {e}") + results[code] = { + 'price': None, + 'change': None, + 'name': stock_names.get(code, f'股票{code.split(".")[0]}') + } + + # 返回标准格式 + return jsonify({'success': True, 'data': results}) + + except Exception as e: + print(f"Stock quotes API error: {e}") + return jsonify({'success': False, 'error': str(e)}), 500 + + +def get_clickhouse_client(): + return Cclient( + host='222.128.1.157', + port=18000, + user='default', + password='Zzl33818!', + database='stock' + ) + + +@app.route('/api/account/calendar/events', methods=['GET', 'POST']) +def account_calendar_events(): + """返回当前用户的投资计划与关注的未来事件(合并)。 + GET: 可按日期范围/月份过滤;POST: 新增投资计划(写入 InvestmentPlan)。 + """ + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + if request.method == 'POST': + data = request.get_json() or {} + title = data.get('title') + event_date_str = data.get('event_date') or data.get('date') + plan_type = data.get('type') or 'plan' + description = data.get('description') or data.get('content') or '' + stocks = data.get('stocks') or [] + + if not title or not event_date_str: + return jsonify({'success': False, 'error': '缺少必填字段'}), 400 + + try: + event_date = datetime.fromisoformat(event_date_str).date() + except Exception: + return jsonify({'success': False, 'error': '日期格式错误'}), 400 + + plan = InvestmentPlan( + user_id=session['user_id'], + date=event_date, + title=title, + content=description, + type=plan_type, + stocks=json.dumps(stocks), + tags=json.dumps(data.get('tags', [])), + status=data.get('status', 'active') + ) + db.session.add(plan) + db.session.commit() + + return jsonify({'success': True, 'data': { + 'id': plan.id, + 'title': plan.title, + 'event_date': plan.date.isoformat(), + 'type': plan.type, + 'description': plan.content, + 'stocks': json.loads(plan.stocks) if plan.stocks else [], + 'source': 'plan' + }}) + + # GET + # 解析过滤参数:date 或 (year, month) 或 (start_date, end_date) + date_str = request.args.get('date') + year = request.args.get('year', type=int) + month = request.args.get('month', type=int) + start_date_str = request.args.get('start_date') + end_date_str = request.args.get('end_date') + + start_date = None + end_date = None + if date_str: + try: + d = datetime.fromisoformat(date_str).date() + start_date = d + end_date = d + except Exception: + pass + elif year and month: + # 月份范围 + start_date = datetime(year, month, 1).date() + if month == 12: + end_date = datetime(year + 1, 1, 1).date() - timedelta(days=1) + else: + end_date = datetime(year, month + 1, 1).date() - timedelta(days=1) + elif start_date_str and end_date_str: + try: + start_date = datetime.fromisoformat(start_date_str).date() + end_date = datetime.fromisoformat(end_date_str).date() + except Exception: + start_date = None + end_date = None + + # 查询投资计划 + plans_query = InvestmentPlan.query.filter_by(user_id=session['user_id']) + if start_date and end_date: + plans_query = plans_query.filter(InvestmentPlan.date >= start_date, InvestmentPlan.date <= end_date) + elif start_date: + plans_query = plans_query.filter(InvestmentPlan.date == start_date) + plans = plans_query.order_by(InvestmentPlan.date.asc()).all() + + plan_events = [{ + 'id': p.id, + 'title': p.title, + 'event_date': p.date.isoformat(), + 'type': p.type or 'plan', + 'description': p.content, + 'importance': 3, + 'stocks': json.loads(p.stocks) if p.stocks else [], + 'source': 'plan' + } for p in plans] + + # 查询关注的未来事件 + follows = FutureEventFollow.query.filter_by(user_id=session['user_id']).all() + future_event_ids = [f.future_event_id for f in follows] + + future_events = [] + if future_event_ids: + base_sql = """ + SELECT data_id, \ + title, \ + type, \ + calendar_time, \ + star, \ + former, \ + forecast, \ + fact, \ + related_stocks, \ + concepts + FROM future_events + WHERE data_id IN :event_ids \ + """ + + params = {'event_ids': tuple(future_event_ids)} + # 日期过滤(按 calendar_time 的日期) + if start_date and end_date: + base_sql += " AND DATE(calendar_time) BETWEEN :start_date AND :end_date" + params.update({'start_date': start_date, 'end_date': end_date}) + elif start_date: + base_sql += " AND DATE(calendar_time) = :start_date" + params.update({'start_date': start_date}) + + base_sql += " ORDER BY calendar_time" + + result = db.session.execute(text(base_sql), params) + for row in result: + # related_stocks 形如 [[code,name,reason,score], ...] + rs = parse_json_field(row.related_stocks) + stock_tags = [] + try: + for it in rs: + if isinstance(it, (list, tuple)) and len(it) >= 2: + stock_tags.append(f"{it[0]} {it[1]}") + elif isinstance(it, str): + stock_tags.append(it) + except Exception: + pass + + future_events.append({ + 'id': row.data_id, + 'title': row.title, + 'event_date': (row.calendar_time.date().isoformat() if row.calendar_time else None), + 'type': 'future_event', + 'importance': int(row.star) if getattr(row, 'star', None) is not None else 3, + 'description': row.former or '', + 'stocks': stock_tags, + 'is_following': True, + 'source': 'future' + }) + + return jsonify({'success': True, 'data': plan_events + future_events}) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/account/calendar/events/', methods=['DELETE']) +def delete_account_calendar_event(event_id): + """删除用户创建的投资计划事件(不影响关注的未来事件)。""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + plan = InvestmentPlan.query.filter_by(id=event_id, user_id=session['user_id']).first() + if not plan: + return jsonify({'success': False, 'error': '未找到该记录'}), 404 + db.session.delete(plan) + db.session.commit() + return jsonify({'success': True}) + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//kline') +def get_stock_kline(stock_code): + chart_type = request.args.get('type', 'minute') + event_time = request.args.get('event_time') + + try: + event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now() + except ValueError: + return jsonify({'error': 'Invalid event_time format'}), 400 + + # 获取股票名称 + with engine.connect() as conn: + result = conn.execute(text( + "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" + ), {"code": stock_code.split('.')[0]}).fetchone() + stock_name = result[0] if result else 'Unknown' + + if chart_type == 'daily': + return get_daily_kline(stock_code, event_datetime, stock_name) + elif chart_type == 'minute': + return get_minute_kline(stock_code, event_datetime, stock_name) + elif chart_type == 'timeline': + return get_timeline_data(stock_code, event_datetime, stock_name) + else: + # 对于未知的类型,返回错误 + return jsonify({'error': f'Unsupported chart type: {chart_type}'}), 400 + + +@app.route('/api/stock//latest-minute', methods=['GET']) +def get_latest_minute_data(stock_code): + """获取最新交易日的分钟频数据""" + client = get_clickhouse_client() + + # 确保股票代码包含后缀 + if '.' not in stock_code: + stock_code = f"{stock_code}.SH" if stock_code.startswith('6') else f"{stock_code}.SZ" + + # 获取股票名称 + with engine.connect() as conn: + result = conn.execute(text( + "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" + ), {"code": stock_code.split('.')[0]}).fetchone() + stock_name = result[0] if result else 'Unknown' + + # 查找最近30天内有数据的最新交易日 + target_date = None + current_date = datetime.now().date() + + for i in range(30): + check_date = current_date - timedelta(days=i) + trading_day = get_trading_day_near_date(check_date) + + if trading_day and trading_day <= current_date: + # 检查这个交易日是否有分钟数据 + test_data = client.execute(""" + SELECT COUNT(*) + FROM stock_minute + WHERE code = %(code)s + AND timestamp BETWEEN %(start)s AND %(end)s + LIMIT 1 + """, { + 'code': stock_code, + 'start': datetime.combine(trading_day, dt_time(9, 30)), + 'end': datetime.combine(trading_day, dt_time(15, 0)) + }) + + if test_data and test_data[0][0] > 0: + target_date = trading_day + break + + if not target_date: + return jsonify({ + 'error': 'No data available', + 'code': stock_code, + 'name': stock_name, + 'data': [], + 'trade_date': current_date.strftime('%Y-%m-%d'), + 'type': 'minute' + }) + + # 获取目标日期的完整交易时段数据 + data = client.execute(""" + SELECT + timestamp, + open, + high, + low, + close, + volume, + amt + FROM stock_minute + WHERE code = %(code)s + AND timestamp BETWEEN %(start)s AND %(end)s + ORDER BY timestamp + """, { + 'code': stock_code, + 'start': datetime.combine(target_date, dt_time(9, 30)), + 'end': datetime.combine(target_date, dt_time(15, 0)) + }) + + kline_data = [{ + 'time': row[0].strftime('%H:%M'), + 'open': float(row[1]), + 'high': float(row[2]), + 'low': float(row[3]), + 'close': float(row[4]), + 'volume': float(row[5]), + 'amount': float(row[6]) + } for row in data] + + return jsonify({ + 'code': stock_code, + 'name': stock_name, + 'data': kline_data, + 'trade_date': target_date.strftime('%Y-%m-%d'), + 'type': 'minute', + 'is_latest': True + }) + + +@app.route('/api/stock//forecast-report', methods=['GET']) +def get_stock_forecast_report(stock_code): + """基于 stock_forecast_data 输出报表所需数据结构 + + 返回: + - income_profit_trend: 营业收入/归母净利润趋势 + - growth_bars: 增长率柱状图数据(基于营业收入同比) + - eps_trend: EPS 折线 + - pe_peg_axes: PE/PEG 双轴 + - detail_table: 详细数据表格(与附件结构一致) + """ + try: + # 读取该股票所有指标 + rows = StockForecastData.query.filter_by(stock_code=stock_code).all() + if not rows: + return jsonify({'success': False, 'error': 'no_data'}), 404 + + # 将指标映射为字典 + indicators = {} + for r in rows: + years, vals = r.values_by_year() + indicators[r.indicator_name] = dict(zip(years, vals)) + + def safe(x): + return x if x is not None else None + + years = ['2022A', '2023A', '2024A', '2025E', '2026E', '2027E'] + + # 营业收入与净利润趋势 + income = indicators.get('营业总收入(百万元)', {}) + profit = indicators.get('归母净利润(百万元)', {}) + income_profit_trend = { + 'years': years, + 'income': [safe(income.get(y)) for y in years], + 'profit': [safe(profit.get(y)) for y in years] + } + + # 增长率柱状(若表内已有"增长率(%)",直接使用;否则按营业收入同比计算) + growth = indicators.get('增长率(%)') + if growth is None: + # 计算同比: (curr - prev)/prev*100 + growth_vals = [] + prev = None + for y in years: + curr = income.get(y) + if prev is not None and prev not in (None, 0) and curr is not None: + growth_vals.append(round((float(curr) - float(prev)) / float(prev) * 100, 2)) + else: + growth_vals.append(None) + prev = curr + else: + growth_vals = [safe(growth.get(y)) for y in years] + growth_bars = { + 'years': years, + 'revenue_growth_pct': growth_vals, + 'net_profit_growth_pct': None # 如后续需要可扩展 + } + + # EPS 趋势 + eps = indicators.get('EPS(稀释)') or indicators.get('EPS(元/股)') or {} + eps_trend = { + 'years': years, + 'eps': [safe(eps.get(y)) for y in years] + } + + # PE / PEG 双轴 + pe = indicators.get('PE') or {} + peg = indicators.get('PEG') or {} + pe_peg_axes = { + 'years': years, + 'pe': [safe(pe.get(y)) for y in years], + 'peg': [safe(peg.get(y)) for y in years] + } + + # 详细数据表格(列顺序固定) + def fmt(val): + try: + return None if val is None else round(float(val), 2) + except Exception: + return None + + detail_rows = [ + { + '指标': '营业总收入(百万元)', + **{y: fmt(income.get(y)) for y in years}, + }, + { + '指标': '增长率(%)', + **{y: fmt(v) for y, v in zip(years, growth_vals)}, + }, + { + '指标': '归母净利润(百万元)', + **{y: fmt(profit.get(y)) for y in years}, + }, + { + '指标': 'EPS(稀释)', + **{y: fmt(eps.get(y)) for y in years}, + }, + { + '指标': 'PE', + **{y: fmt(pe.get(y)) for y in years}, + }, + { + '指标': 'PEG', + **{y: fmt(peg.get(y)) for y in years}, + }, + ] + + return jsonify({ + 'success': True, + 'data': { + 'income_profit_trend': income_profit_trend, + 'growth_bars': growth_bars, + 'eps_trend': eps_trend, + 'pe_peg_axes': pe_peg_axes, + 'detail_table': { + 'years': years, + 'rows': detail_rows + } + } + }) + except Exception as e: + app.logger.error(f"forecast report error: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//basic-info', methods=['GET']) +def get_stock_basic_info(stock_code): + """获取股票基本信息(来自ea_baseinfo表)""" + try: + with engine.connect() as conn: + query = text(""" + SELECT SECCODE, + SECNAME, + ORGNAME, + F001V as en_name, + F002V as en_short_name, + F003V as legal_representative, + F004V as reg_address, + F005V as office_address, + F006V as post_code, + F007N as reg_capital, + F009V as currency, + F010D as establish_date, + F011V as website, + F012V as email, + F013V as tel, + F014V as fax, + F015V as main_business, + F016V as business_scope, + F017V as company_intro, + F018V as secretary, + F019V as secretary_tel, + F020V as secretary_fax, + F021V as secretary_email, + F024V as listing_status, + F026V as province, + F028V as city, + F030V as industry_l1, + F032V as industry_l2, + F034V as sw_industry_l1, + F036V as sw_industry_l2, + F038V as sw_industry_l3, + F039V as accounting_firm, + F040V as law_firm, + F041V as chairman, + F042V as general_manager, + F043V as independent_directors, + F050V as credit_code, + F054V as company_size, + UPDATE_DATE + FROM ea_baseinfo + WHERE SECCODE = :stock_code LIMIT 1 + """) + + result = conn.execute(query, {'stock_code': stock_code}).fetchone() + + if not result: + return jsonify({ + 'success': False, + 'error': f'未找到股票代码 {stock_code} 的基本信息' + }), 404 + + # 转换为字典 + basic_info = {} + for key, value in zip(result.keys(), result): + if isinstance(value, datetime): + basic_info[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + basic_info[key] = float(value) + else: + basic_info[key] = value + + return jsonify({ + 'success': True, + 'data': basic_info + }) + + except Exception as e: + app.logger.error(f"Error getting stock basic info: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//announcements', methods=['GET']) +def get_stock_announcements(stock_code): + """获取股票公告列表""" + try: + limit = request.args.get('limit', 50, type=int) + + with engine.connect() as conn: + query = text(""" + SELECT F001D as announce_date, + F002V as title, + F003V as url, + F004V as format, + F005N as file_size, + F006V as info_type, + UPDATE_DATE + FROM ea_baseinfolist + WHERE SECCODE = :stock_code + ORDER BY F001D DESC LIMIT :limit + """) + + result = conn.execute(query, {'stock_code': stock_code, 'limit': limit}).fetchall() + + announcements = [] + for row in result: + announcement = {} + for key, value in zip(row.keys(), row): + if value is None: + announcement[key] = None + elif isinstance(value, datetime): + announcement[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + announcement[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + announcement[key] = float(value) + else: + announcement[key] = value + announcements.append(announcement) + + return jsonify({ + 'success': True, + 'data': announcements, + 'total': len(announcements) + }) + + except Exception as e: + app.logger.error(f"Error getting stock announcements: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//disclosure-schedule', methods=['GET']) +def get_stock_disclosure_schedule(stock_code): + """获取股票财报预披露时间表""" + try: + with engine.connect() as conn: + query = text(""" + SELECT distinct F001D as report_period, + F002D as scheduled_date, + F003D as change_date1, + F004D as change_date2, + F005D as change_date3, + F006D as actual_date, + F007D as change_date4, + F008D as change_date5, + MODTIME as mod_time + FROM ea_pretime + WHERE SECCODE = :stock_code + ORDER BY F001D DESC LIMIT 20 + """) + + result = conn.execute(query, {'stock_code': stock_code}).fetchall() + + schedules = [] + for row in result: + schedule = {} + for key, value in zip(row.keys(), row): + if value is None: + schedule[key] = None + elif isinstance(value, datetime): + schedule[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + schedule[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + schedule[key] = float(value) + else: + schedule[key] = value + + # 计算最新的预约日期 + latest_scheduled = schedule.get('scheduled_date') + for change_field in ['change_date5', 'change_date4', 'change_date3', 'change_date2', 'change_date1']: + if schedule.get(change_field): + latest_scheduled = schedule[change_field] + break + + schedule['latest_scheduled_date'] = latest_scheduled + schedule['is_disclosed'] = bool(schedule.get('actual_date')) + + # 格式化报告期名称 + if schedule.get('report_period'): + period_date = schedule['report_period'] + if period_date.endswith('-03-31'): + schedule['report_name'] = f"{period_date[:4]}年一季报" + elif period_date.endswith('-06-30'): + schedule['report_name'] = f"{period_date[:4]}年中报" + elif period_date.endswith('-09-30'): + schedule['report_name'] = f"{period_date[:4]}年三季报" + elif period_date.endswith('-12-31'): + schedule['report_name'] = f"{period_date[:4]}年年报" + else: + schedule['report_name'] = period_date + + schedules.append(schedule) + + return jsonify({ + 'success': True, + 'data': schedules, + 'total': len(schedules) + }) + + except Exception as e: + app.logger.error(f"Error getting disclosure schedule: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//actual-control', methods=['GET']) +def get_stock_actual_control(stock_code): + """获取股票实际控制人信息""" + try: + with engine.connect() as conn: + query = text(""" + SELECT DECLAREDATE as declare_date, + ENDDATE as end_date, + F001V as direct_holder_id, + F002V as direct_holder_name, + F003V as actual_controller_id, + F004V as actual_controller_name, + F005N as holding_shares, + F006N as holding_ratio, + F007V as control_type_code, + F008V as control_type, + F012V as direct_controller_id, + F013V as direct_controller_name, + F014V as controller_type, + ORGNAME as org_name, + SECCODE as sec_code, + SECNAME as sec_name + FROM ea_actualcon + WHERE SECCODE = :stock_code + ORDER BY ENDDATE DESC, DECLAREDATE DESC LIMIT 20 + """) + + result = conn.execute(query, {'stock_code': stock_code}).fetchall() + + control_info = [] + for row in result: + control_record = {} + for key, value in zip(row.keys(), row): + if value is None: + control_record[key] = None + elif isinstance(value, datetime): + control_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + control_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + control_record[key] = float(value) + else: + control_record[key] = value + + control_info.append(control_record) + + return jsonify({ + 'success': True, + 'data': control_info, + 'total': len(control_info) + }) + + except Exception as e: + app.logger.error(f"Error getting actual control info: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//concentration', methods=['GET']) +def get_stock_concentration(stock_code): + """获取股票股权集中度信息""" + try: + with engine.connect() as conn: + query = text(""" + SELECT ENDDATE as end_date, + F001V as stat_item, + F002N as holding_shares, + F003N as holding_ratio, + F004N as ratio_change, + ORGNAME as org_name, + SECCODE as sec_code, + SECNAME as sec_name + FROM ea_concentration + WHERE SECCODE = :stock_code + ORDER BY ENDDATE DESC LIMIT 20 + """) + + result = conn.execute(query, {'stock_code': stock_code}).fetchall() + + concentration_info = [] + for row in result: + concentration_record = {} + for key, value in zip(row.keys(), row): + if value is None: + concentration_record[key] = None + elif isinstance(value, datetime): + concentration_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + concentration_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + concentration_record[key] = float(value) + else: + concentration_record[key] = value + + concentration_info.append(concentration_record) + + return jsonify({ + 'success': True, + 'data': concentration_info, + 'total': len(concentration_info) + }) + + except Exception as e: + app.logger.error(f"Error getting concentration info: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//management', methods=['GET']) +def get_stock_management(stock_code): + """获取股票管理层信息""" + try: + # 获取是否只显示在职人员参数 + active_only = request.args.get('active_only', 'true').lower() == 'true' + + with engine.connect() as conn: + base_query = """ + SELECT DECLAREDATE as declare_date, \ + F001V as person_id, \ + F002V as name, \ + F007D as start_date, \ + F008D as end_date, \ + F009V as position_name, \ + F010V as gender, \ + F011V as education, \ + F012V as birth_year, \ + F013V as nationality, \ + F014V as position_category_code, \ + F015V as position_category, \ + F016V as position_code, \ + F017V as highest_degree, \ + F019V as resume, \ + F020C as is_active, \ + ORGNAME as org_name, \ + SECCODE as sec_code, \ + SECNAME as sec_name + FROM ea_management + WHERE SECCODE = :stock_code \ + """ + + if active_only: + base_query += " AND F020C = '1'" + + base_query += " ORDER BY DECLAREDATE DESC, F007D DESC" + + query = text(base_query) + + result = conn.execute(query, {'stock_code': stock_code}).fetchall() + + management_info = [] + for row in result: + management_record = {} + for key, value in zip(row.keys(), row): + if value is None: + management_record[key] = None + elif isinstance(value, datetime): + management_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + management_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + management_record[key] = float(value) + else: + management_record[key] = value + + management_info.append(management_record) + + return jsonify({ + 'success': True, + 'data': management_info, + 'total': len(management_info) + }) + + except Exception as e: + app.logger.error(f"Error getting management info: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//top-circulation-shareholders', methods=['GET']) +def get_stock_top_circulation_shareholders(stock_code): + """获取股票十大流通股东信息""" + try: + limit = request.args.get('limit', 10, type=int) + + with engine.connect() as conn: + query = text(""" + SELECT DECLAREDATE as declare_date, + ENDDATE as end_date, + F001N as shareholder_rank, + F002V as shareholder_id, + F003V as shareholder_name, + F004V as shareholder_type, + F005N as holding_shares, + F006N as total_share_ratio, + F007N as circulation_share_ratio, + F011V as share_nature, + F012N as b_shares, + F013N as h_shares, + F014N as other_shares, + ORGNAME as org_name, + SECCODE as sec_code, + SECNAME as sec_name + FROM ea_tencirculation + WHERE SECCODE = :stock_code + ORDER BY ENDDATE DESC, F001N ASC LIMIT :limit + """) + + result = conn.execute(query, {'stock_code': stock_code, 'limit': limit}).fetchall() + + shareholders_info = [] + for row in result: + shareholder_record = {} + for key, value in zip(row.keys(), row): + if value is None: + shareholder_record[key] = None + elif isinstance(value, datetime): + shareholder_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + shareholder_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + shareholder_record[key] = float(value) + else: + shareholder_record[key] = value + + shareholders_info.append(shareholder_record) + + return jsonify({ + 'success': True, + 'data': shareholders_info, + 'total': len(shareholders_info) + }) + + except Exception as e: + app.logger.error(f"Error getting top circulation shareholders: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//top-shareholders', methods=['GET']) +def get_stock_top_shareholders(stock_code): + """获取股票十大股东信息""" + try: + limit = request.args.get('limit', 10, type=int) + + with engine.connect() as conn: + query = text(""" + SELECT DECLAREDATE as declare_date, + ENDDATE as end_date, + F001N as shareholder_rank, + F002V as shareholder_name, + F003V as shareholder_id, + F004V as shareholder_type, + F005N as holding_shares, + F006N as total_share_ratio, + F007N as circulation_share_ratio, + F011V as share_nature, + F016N as restricted_shares, + F017V as concert_party_group, + F018N as circulation_shares, + ORGNAME as org_name, + SECCODE as sec_code, + SECNAME as sec_name + FROM ea_tenshareholder + WHERE SECCODE = :stock_code + ORDER BY ENDDATE DESC, F001N ASC LIMIT :limit + """) + + result = conn.execute(query, {'stock_code': stock_code, 'limit': limit}).fetchall() + + shareholders_info = [] + for row in result: + shareholder_record = {} + for key, value in zip(row.keys(), row): + if value is None: + shareholder_record[key] = None + elif isinstance(value, datetime): + shareholder_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + shareholder_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + shareholder_record[key] = float(value) + else: + shareholder_record[key] = value + + shareholders_info.append(shareholder_record) + + return jsonify({ + 'success': True, + 'data': shareholders_info, + 'total': len(shareholders_info) + }) + + except Exception as e: + app.logger.error(f"Error getting top shareholders: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//branches', methods=['GET']) +def get_stock_branches(stock_code): + """获取股票分支机构信息""" + try: + with engine.connect() as conn: + query = text(""" + SELECT CRECODE as cre_code, + F001V as branch_name, + F002V as register_capital, + F003V as business_status, + F004D as register_date, + F005N as related_company_count, + F006V as legal_person, + ORGNAME as org_name, + SECCODE as sec_code, + SECNAME as sec_name + FROM ea_branch + WHERE SECCODE = :stock_code + ORDER BY F004D DESC + """) + + result = conn.execute(query, {'stock_code': stock_code}).fetchall() + + branches_info = [] + for row in result: + branch_record = {} + for key, value in zip(row.keys(), row): + if value is None: + branch_record[key] = None + elif isinstance(value, datetime): + branch_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + branch_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + branch_record[key] = float(value) + else: + branch_record[key] = value + + branches_info.append(branch_record) + + return jsonify({ + 'success': True, + 'data': branches_info, + 'total': len(branches_info) + }) + + except Exception as e: + app.logger.error(f"Error getting branches info: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//patents', methods=['GET']) +def get_stock_patents(stock_code): + """获取股票专利信息""" + try: + limit = request.args.get('limit', 50, type=int) + patent_type = request.args.get('type', None) # 专利类型筛选 + + with engine.connect() as conn: + base_query = """ + SELECT CRECODE as cre_code, \ + F001V as patent_name, \ + F002V as application_number, \ + F003V as publication_number, \ + F004V as classification_number, \ + F005D as publication_date, \ + F006D as application_date, \ + F007V as patent_type, \ + F008V as applicant, \ + F009V as inventor, \ + ID as id, \ + ORGNAME as org_name, \ + SECCODE as sec_code, \ + SECNAME as sec_name + FROM ea_patent + WHERE SECCODE = :stock_code \ + """ + + params = {'stock_code': stock_code, 'limit': limit} + + if patent_type: + base_query += " AND F007V = :patent_type" + params['patent_type'] = patent_type + + base_query += " ORDER BY F006D DESC, F005D DESC LIMIT :limit" + + query = text(base_query) + + result = conn.execute(query, params).fetchall() + + patents_info = [] + for row in result: + patent_record = {} + for key, value in zip(row.keys(), row): + if value is None: + patent_record[key] = None + elif isinstance(value, datetime): + patent_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + patent_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + patent_record[key] = float(value) + else: + patent_record[key] = value + + patents_info.append(patent_record) + + return jsonify({ + 'success': True, + 'data': patents_info, + 'total': len(patents_info) + }) + + except Exception as e: + app.logger.error(f"Error getting patents info: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +def get_daily_kline(stock_code, event_datetime, stock_name): + """处理日K线数据""" + stock_code = stock_code.split('.')[0] + + with engine.connect() as conn: + # 获取事件日期前后的数据 + kline_sql = """ + WITH date_range AS (SELECT TRADEDATE \ + FROM ea_trade \ + WHERE SECCODE = :stock_code \ + AND TRADEDATE BETWEEN DATE_SUB(:trade_date, INTERVAL 60 DAY) \ + AND DATE_ADD(:trade_date, INTERVAL 30 DAY) \ + GROUP BY TRADEDATE \ + ORDER BY TRADEDATE) + SELECT t.TRADEDATE, + CAST(t.F003N AS FLOAT) as open, + CAST(t.F007N AS FLOAT) as close, + CAST(t.F005N AS FLOAT) as high, + CAST(t.F006N AS FLOAT) as low, + CAST(t.F004N AS FLOAT) as volume + FROM ea_trade t + JOIN date_range d \ + ON t.TRADEDATE = d.TRADEDATE + WHERE t.SECCODE = :stock_code + ORDER BY t.TRADEDATE \ + """ + + result = conn.execute(text(kline_sql), { + "stock_code": stock_code, + "trade_date": event_datetime.date() + }).fetchall() + + if not result: + return jsonify({ + 'error': 'No data available', + 'code': stock_code, + 'name': stock_name, + 'data': [], + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'daily' + }) + + kline_data = [{ + 'time': row.TRADEDATE.strftime('%Y-%m-%d'), + 'open': float(row.open), + 'high': float(row.high), + 'low': float(row.low), + 'close': float(row.close), + 'volume': float(row.volume) + } for row in result] + + return jsonify({ + 'code': stock_code, + 'name': stock_name, + 'data': kline_data, + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'daily', + 'is_history': True + }) + + +def get_minute_kline(stock_code, event_datetime, stock_name): + """处理分钟K线数据""" + client = get_clickhouse_client() + + target_date = get_trading_day_near_date(event_datetime.date()) + is_after_market = event_datetime.time() > dt_time(15, 0) + + # 核心逻辑改动:先判断当前日期是否是交易日,以及是否已收盘 + if target_date and is_after_market: + # 如果是交易日且已收盘,查找下一个交易日 + next_trade_date = get_trading_day_near_date(target_date + timedelta(days=1)) + if next_trade_date: + target_date = next_trade_date + + if not target_date: + return jsonify({ + 'error': 'No data available', + 'code': stock_code, + 'name': stock_name, + 'data': [], + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'minute' + }) + + # 获取目标日期的完整交易时段数据 + data = client.execute(""" + SELECT + timestamp, open, high, low, close, volume, amt + FROM stock_minute + WHERE code = %(code)s + AND timestamp BETWEEN %(start)s + AND %(end)s + ORDER BY timestamp + """, { + 'code': stock_code, + 'start': datetime.combine(target_date, dt_time(9, 30)), + 'end': datetime.combine(target_date, dt_time(15, 0)) + }) + + kline_data = [{ + 'time': row[0].strftime('%H:%M'), + 'open': float(row[1]), + 'high': float(row[2]), + 'low': float(row[3]), + 'close': float(row[4]), + 'volume': float(row[5]), + 'amount': float(row[6]) + } for row in data] + + return jsonify({ + 'code': stock_code, + 'name': stock_name, + 'data': kline_data, + 'trade_date': target_date.strftime('%Y-%m-%d'), + 'type': 'minute', + 'is_history': target_date < event_datetime.date() + }) + + +def get_timeline_data(stock_code, event_datetime, stock_name): + """处理分时均价线数据(timeline)。 + 规则: + - 若事件时间在交易日的15:00之后,则展示下一个交易日的分时数据; + - 若事件日非交易日,优先展示下一个交易日;如无,则回退到最近一个交易日; + - 数据区间固定为 09:30-15:00。 + """ + client = get_clickhouse_client() + + target_date = get_trading_day_near_date(event_datetime.date()) + is_after_market = event_datetime.time() > dt_time(15, 0) + + # 与分钟K逻辑保持一致的日期选择规则 + if target_date and is_after_market: + next_trade_date = get_trading_day_near_date(target_date + timedelta(days=1)) + if next_trade_date: + target_date = next_trade_date + + if not target_date: + return jsonify({ + 'error': 'No data available', + 'code': stock_code, + 'name': stock_name, + 'data': [], + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'timeline' + }) + + # 获取昨收盘价 + prev_close_query = """ + SELECT close + FROM stock_minute + WHERE code = %(code)s + AND timestamp \ + < %(start)s + ORDER BY timestamp DESC + LIMIT 1 \ + """ + + prev_close_result = client.execute(prev_close_query, { + 'code': stock_code, + 'start': datetime.combine(target_date, dt_time(9, 30)) + }) + + prev_close = float(prev_close_result[0][0]) if prev_close_result else None + + data = client.execute( + """ + SELECT + timestamp, close, volume + FROM stock_minute + WHERE code = %(code)s + AND timestamp BETWEEN %(start)s + AND %(end)s + ORDER BY timestamp + """, + { + 'code': stock_code, + 'start': datetime.combine(target_date, dt_time(9, 30)), + 'end': datetime.combine(target_date, dt_time(15, 0)), + } + ) + + timeline_data = [] + total_amount = 0 + total_volume = 0 + for row in data: + price = float(row[1]) + volume = float(row[2]) + total_amount += price * volume + total_volume += volume + avg_price = total_amount / total_volume if total_volume > 0 else price + + # 计算涨跌幅 + change_percent = ((price - prev_close) / prev_close * 100) if prev_close else 0 + + timeline_data.append({ + 'time': row[0].strftime('%H:%M'), + 'price': price, + 'avg_price': avg_price, + 'volume': volume, + 'change_percent': change_percent, + }) + + return jsonify({ + 'code': stock_code, + 'name': stock_name, + 'data': timeline_data, + 'trade_date': target_date.strftime('%Y-%m-%d'), + 'type': 'timeline', + 'is_history': target_date < event_datetime.date(), + 'prev_close': prev_close, + }) + + +# ==================== 指数行情API(与股票逻辑一致,数据表为 index_minute) ==================== +@app.route('/api/index//kline') +def get_index_kline(index_code): + chart_type = request.args.get('type', 'minute') + event_time = request.args.get('event_time') + + try: + event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now() + except ValueError: + return jsonify({'error': 'Invalid event_time format'}), 400 + + # 指数名称(暂无索引表,先返回代码本身) + index_name = index_code + + if chart_type == 'minute': + return get_index_minute_kline(index_code, event_datetime, index_name) + elif chart_type == 'timeline': + return get_index_timeline_data(index_code, event_datetime, index_name) + elif chart_type == 'daily': + return get_index_daily_kline(index_code, event_datetime, index_name) + else: + return jsonify({'error': f'Unsupported chart type: {chart_type}'}), 400 + + +def get_index_minute_kline(index_code, event_datetime, index_name): + client = get_clickhouse_client() + target_date = get_trading_day_near_date(event_datetime.date()) + + if not target_date: + return jsonify({ + 'error': 'No data available', + 'code': index_code, + 'name': index_name, + 'data': [], + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'minute' + }) + + data = client.execute( + """ + SELECT timestamp, open, high, low, close, volume, amt + FROM index_minute + WHERE code = %(code)s + AND timestamp BETWEEN %(start)s + AND %(end)s + ORDER BY timestamp + """, + { + 'code': index_code, + 'start': datetime.combine(target_date, dt_time(9, 30)), + 'end': datetime.combine(target_date, dt_time(15, 0)), + } + ) + + kline_data = [{ + 'time': row[0].strftime('%H:%M'), + 'open': float(row[1]), + 'high': float(row[2]), + 'low': float(row[3]), + 'close': float(row[4]), + 'volume': float(row[5]), + 'amount': float(row[6]), + } for row in data] + + return jsonify({ + 'code': index_code, + 'name': index_name, + 'data': kline_data, + 'trade_date': target_date.strftime('%Y-%m-%d'), + 'type': 'minute', + 'is_history': target_date < event_datetime.date(), + }) + + +def get_index_timeline_data(index_code, event_datetime, index_name): + client = get_clickhouse_client() + target_date = get_trading_day_near_date(event_datetime.date()) + + if not target_date: + return jsonify({ + 'error': 'No data available', + 'code': index_code, + 'name': index_name, + 'data': [], + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'timeline' + }) + + data = client.execute( + """ + SELECT timestamp, close, volume + FROM index_minute + WHERE code = %(code)s + AND timestamp BETWEEN %(start)s + AND %(end)s + ORDER BY timestamp + """, + { + 'code': index_code, + 'start': datetime.combine(target_date, dt_time(9, 30)), + 'end': datetime.combine(target_date, dt_time(15, 0)), + } + ) + + timeline = [] + total_amount = 0 + total_volume = 0 + for row in data: + price = float(row[1]) + volume = float(row[2]) + total_amount += price * volume + total_volume += volume + avg_price = total_amount / total_volume if total_volume > 0 else price + timeline.append({ + 'time': row[0].strftime('%H:%M'), + 'price': price, + 'avg_price': avg_price, + 'volume': volume, + }) + + return jsonify({ + 'code': index_code, + 'name': index_name, + 'data': timeline, + 'trade_date': target_date.strftime('%Y-%m-%d'), + 'type': 'timeline', + 'is_history': target_date < event_datetime.date(), + }) + + +def get_index_daily_kline(index_code, event_datetime, index_name): + """从 MySQL 的 stock.ea_exchangetrade 获取指数日线 + 注意:表中 INDEXCODE 无后缀,例如 000001.SH -> 000001 + 字段: + F003N 开市指数 -> open + F004N 最高指数 -> high + F005N 最低指数 -> low + F006N 最近指数 -> close(作为当日收盘或最近价使用) + F007N 昨日收市指数 -> prev_close + """ + # 去掉后缀 + code_no_suffix = index_code.split('.')[0] + + # 选择展示的最后交易日 + target_date = get_trading_day_near_date(event_datetime.date()) + if not target_date: + return jsonify({ + 'error': 'No data available', + 'code': index_code, + 'name': index_name, + 'data': [], + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'daily' + }) + + # 取最近一段时间的日线(倒序再反转为升序) + with engine.connect() as conn: + rows = conn.execute(text( + """ + SELECT TRADEDATE, F003N, F004N, F005N, F006N, F007N + FROM ea_exchangetrade + WHERE INDEXCODE = :code + AND TRADEDATE <= :end_dt + ORDER BY TRADEDATE DESC LIMIT 180 + """ + ), { + 'code': code_no_suffix, + 'end_dt': datetime.combine(target_date, dt_time(23, 59, 59)) + }).fetchall() + + # 反转为时间升序 + rows = list(reversed(rows)) + + daily = [] + for i, r in enumerate(rows): + trade_dt = r[0] + open_v = r[1] + high_v = r[2] + low_v = r[3] + last_v = r[4] + prev_close_v = r[5] + + # 正确的前收盘价逻辑:使用前一个交易日的F006N(收盘价) + calculated_prev_close = None + if i > 0 and rows[i - 1][4] is not None: + # 使用前一个交易日的收盘价作为前收盘价 + calculated_prev_close = float(rows[i - 1][4]) + else: + # 第一条记录,尝试使用F007N字段作为备选 + if prev_close_v is not None and prev_close_v > 0: + calculated_prev_close = float(prev_close_v) + + daily.append({ + 'time': trade_dt.strftime('%Y-%m-%d') if hasattr(trade_dt, 'strftime') else str(trade_dt), + 'open': float(open_v) if open_v is not None else None, + 'high': float(high_v) if high_v is not None else None, + 'low': float(low_v) if low_v is not None else None, + 'close': float(last_v) if last_v is not None else None, + 'prev_close': calculated_prev_close, + }) + + return jsonify({ + 'code': index_code, + 'name': index_name, + 'data': daily, + 'trade_date': target_date.strftime('%Y-%m-%d'), + 'type': 'daily', + 'is_history': target_date < event_datetime.date(), + }) + + +# ==================== 日历API ==================== +@app.route('/api/v1/calendar/event-counts', methods=['GET']) +def get_event_counts(): + """获取日历事件数量统计""" + try: + # 获取月份参数 + year = request.args.get('year', datetime.now().year, type=int) + month = request.args.get('month', datetime.now().month, type=int) + + # 计算月份的开始和结束日期 + start_date = datetime(year, month, 1) + if month == 12: + end_date = datetime(year + 1, 1, 1) + else: + end_date = datetime(year, month + 1, 1) + + # 查询事件数量 + query = """ + SELECT DATE(calendar_time) as date, COUNT(*) as count + FROM future_events + WHERE calendar_time BETWEEN :start_date AND :end_date + AND type = 'event' + GROUP BY DATE(calendar_time) +""" + + result = db.session.execute(text(query), { + 'start_date': start_date, + 'end_date': end_date + }) + + # 格式化结果 + events = [] + for day in result: + events.append({ + 'date': day.date.isoformat(), + 'count': day.count, + 'className': get_event_class(day.count) + }) + + return jsonify({ + 'success': True, + 'data': events + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/v1/calendar/events', methods=['GET']) +def get_calendar_events(): + """获取指定日期的事件列表""" + date_str = request.args.get('date') + event_type = request.args.get('type', 'all') + + if not date_str: + return jsonify({ + 'success': False, + 'error': 'Date parameter required' + }), 400 + + try: + date = datetime.strptime(date_str, '%Y-%m-%d') + except ValueError: + return jsonify({ + 'success': False, + 'error': 'Invalid date format' + }), 400 + + # 修复SQL语法:去掉函数名后的空格,去掉参数前的空格 + query = """ + SELECT * + FROM future_events + WHERE DATE(calendar_time) = :date + """ + + params = {'date': date} + + if event_type != 'all': + query += " AND type = :type" + params['type'] = event_type + + query += " ORDER BY calendar_time" + + result = db.session.execute(text(query), params) + + events = [] + user_following_ids = set() + if 'user_id' in session: + follows = FutureEventFollow.query.filter_by(user_id=session['user_id']).all() + user_following_ids = {f.future_event_id for f in follows} + + for row in result: + event_data = { + 'id': row.data_id, + 'title': row.title, + 'type': row.type, + 'calendar_time': row.calendar_time.isoformat(), + 'star': row.star, + 'former': row.former, + 'forecast': row.forecast, + 'fact': row.fact, + 'is_following': row.data_id in user_following_ids + } + + # 解析相关股票和概念 + if row.related_stocks: + try: + if isinstance(row.related_stocks, str): + if row.related_stocks.startswith('['): + event_data['related_stocks'] = json.loads(row.related_stocks) + else: + event_data['related_stocks'] = row.related_stocks.split(',') + else: + event_data['related_stocks'] = row.related_stocks + except: + event_data['related_stocks'] = [] + else: + event_data['related_stocks'] = [] + + if row.concepts: + try: + if isinstance(row.concepts, str): + if row.concepts.startswith('['): + event_data['concepts'] = json.loads(row.concepts) + else: + event_data['concepts'] = row.concepts.split(',') + else: + event_data['concepts'] = row.concepts + except: + event_data['concepts'] = [] + else: + event_data['concepts'] = [] + + events.append(event_data) + + return jsonify({ + 'success': True, + 'data': events + }) + +@app.route('/api/v1/calendar/events/', methods=['GET']) +def get_calendar_event_detail(event_id): + """获取日历事件详情""" + try: + sql = """ + SELECT * + FROM future_events + WHERE data_id = :event_id \ + """ + + result = db.session.execute(text(sql), {'event_id': event_id}).first() + + if not result: + return jsonify({ + 'success': False, + 'error': 'Event not found' + }), 404 + + event_data = { + 'id': result.data_id, + 'title': result.title, + 'type': result.type, + 'calendar_time': result.calendar_time.isoformat(), + 'star': result.star, + 'former': result.former, + 'forecast': result.forecast, + 'fact': result.fact, + 'related_stocks': parse_json_field(result.related_stocks), + 'concepts': parse_json_field(result.concepts) + } + + # 检查当前用户是否关注了该未来事件 + if 'user_id' in session: + is_following = FutureEventFollow.query.filter_by( + user_id=session['user_id'], + future_event_id=event_id + ).first() is not None + event_data['is_following'] = is_following + else: + event_data['is_following'] = False + + return jsonify({ + 'success': True, + 'data': event_data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/v1/calendar/events//follow', methods=['POST']) +def toggle_future_event_follow(event_id): + """切换未来事件关注状态(需登录)""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + try: + # 检查未来事件是否存在 + sql = """ + SELECT data_id \ + FROM future_events \ + WHERE data_id = :event_id \ + """ + result = db.session.execute(text(sql), {'event_id': event_id}).first() + + if not result: + return jsonify({'success': False, 'error': '未来事件不存在'}), 404 + + user_id = session['user_id'] + + # 检查是否已关注 + existing = FutureEventFollow.query.filter_by( + user_id=user_id, + future_event_id=event_id + ).first() + + if existing: + # 取消关注 + db.session.delete(existing) + db.session.commit() + return jsonify({ + 'success': True, + 'data': {'is_following': False} + }) + else: + # 关注 + follow = FutureEventFollow( + user_id=user_id, + future_event_id=event_id + ) + db.session.add(follow) + db.session.commit() + return jsonify({ + 'success': True, + 'data': {'is_following': True} + }) + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +def get_event_class(count): + """根据事件数量返回CSS类名""" + if count >= 10: + return 'event-high' + elif count >= 5: + return 'event-medium' + elif count > 0: + return 'event-low' + return '' + + +def parse_json_field(field_value): + """解析JSON字段""" + if not field_value: + return [] + try: + if isinstance(field_value, str): + if field_value.startswith('['): + return json.loads(field_value) + else: + return field_value.split(',') + else: + return field_value + except: + return [] + + +# ==================== 行业API ==================== +@app.route('/api/classifications', methods=['GET']) +def get_classifications(): + """获取申银万国行业分类树形结构""" + try: + # 查询申银万国行业分类的所有数据 + sql = """ + SELECT f003v as code, f004v as level1, f005v as level2, f006v as level3,f007v as level4 + FROM ea_sector + WHERE f002v = '申银万国行业分类' + AND f003v IS NOT NULL + AND f004v IS NOT NULL + ORDER BY f003v + """ + + result = db.session.execute(text(sql)).all() + + # 构建树形结构 + tree_dict = {} + + for row in result: + code = row.code + level1 = row.level1 + level2 = row.level2 + level3 = row.level3 + + # 跳过空数据 + if not level1: + continue + + # 第一层 + if level1 not in tree_dict: + # 获取第一层的code(取前3位或前缀) + level1_code = code[:3] if len(code) >= 3 else code + tree_dict[level1] = { + 'value': level1_code, + 'label': level1, + 'children_dict': {} + } + + # 第二层 + if level2: + if level2 not in tree_dict[level1]['children_dict']: + # 获取第二层的code(取前6位) + level2_code = code[:6] if len(code) >= 6 else code + tree_dict[level1]['children_dict'][level2] = { + 'value': level2_code, + 'label': level2, + 'children_dict': {} + } + + # 第三层 + if level3: + if level3 not in tree_dict[level1]['children_dict'][level2]['children_dict']: + tree_dict[level1]['children_dict'][level2]['children_dict'][level3] = { + 'value': code, + 'label': level3 + } + + # 转换为最终格式 + result_list = [] + for level1_name, level1_data in tree_dict.items(): + level1_node = { + 'value': level1_data['value'], + 'label': level1_data['label'] + } + + # 处理第二层 + if level1_data['children_dict']: + level1_children = [] + for level2_name, level2_data in level1_data['children_dict'].items(): + level2_node = { + 'value': level2_data['value'], + 'label': level2_data['label'] + } + + # 处理第三层 + if level2_data['children_dict']: + level2_children = [] + for level3_name, level3_data in level2_data['children_dict'].items(): + level2_children.append({ + 'value': level3_data['value'], + 'label': level3_data['label'] + }) + if level2_children: + level2_node['children'] = level2_children + + level1_children.append(level2_node) + + if level1_children: + level1_node['children'] = level1_children + + result_list.append(level1_node) + + return jsonify({ + 'success': True, + 'data': result_list + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/stocklist', methods=['GET']) +def get_stock_list(): + """获取股票列表""" + try: + sql = """ + SELECT DISTINCT SECCODE as code, SECNAME as name + FROM ea_stocklist + ORDER BY SECCODE + """ + + result = db.session.execute(text(sql)).all() + + stocks = [{'code': row.code, 'name': row.name} for row in result] + + return jsonify(stocks) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/events', methods=['GET'], strict_slashes=False) +def api_get_events(): + """ + 获取事件列表API - 支持筛选、排序、分页,兼容前端调用 + """ + try: + # 分页参数 + page = max(1, request.args.get('page', 1, type=int)) + per_page = min(100, max(1, request.args.get('per_page', 10, type=int))) + + # 基础筛选参数 + event_type = request.args.get('type', 'all') + event_status = request.args.get('status', 'active') + importance = request.args.get('importance', 'all') + + # 日期筛选参数 + start_date = request.args.get('start_date') + end_date = request.args.get('end_date') + date_range = request.args.get('date_range') + recent_days = request.args.get('recent_days', type=int) + + # 行业筛选参数(只支持申银万国行业分类) + industry_code = request.args.get('industry_code') # 申万行业代码,如 "S370502" + + # 概念/标签筛选参数 + tag = request.args.get('tag') + tags = request.args.get('tags') + keywords = request.args.get('keywords') + + # 搜索参数 + search_query = request.args.get('q') + search_type = request.args.get('search_type', 'topic') + search_fields = request.args.get('search_fields', 'title,description').split(',') + + # 排序参数 + sort_by = request.args.get('sort', 'new') + return_type = request.args.get('return_type', 'avg') + order = request.args.get('order', 'desc') + + # 收益率筛选参数 + min_avg_return = request.args.get('min_avg_return', type=float) + max_avg_return = request.args.get('max_avg_return', type=float) + min_max_return = request.args.get('min_max_return', type=float) + max_max_return = request.args.get('max_max_return', type=float) + min_week_return = request.args.get('min_week_return', type=float) + max_week_return = request.args.get('max_week_return', type=float) + + # 其他筛选参数 + min_hot_score = request.args.get('min_hot_score', type=float) + max_hot_score = request.args.get('max_hot_score', type=float) + min_view_count = request.args.get('min_view_count', type=int) + creator_id = request.args.get('creator_id', type=int) + + # 返回格式参数 + include_creator = request.args.get('include_creator', 'true').lower() == 'true' + include_stats = request.args.get('include_stats', 'true').lower() == 'true' + include_related_data = request.args.get('include_related_data', 'false').lower() == 'true' + + # ==================== 构建查询 ==================== + query = Event.query + if event_status != 'all': + query = query.filter_by(status=event_status) + if event_type != 'all': + query = query.filter_by(event_type=event_type) + # 支持多个重要性级别筛选,用逗号分隔(如 importance=S,A) + if importance != 'all': + if ',' in importance: + # 多个重要性级别 + importance_list = [imp.strip() for imp in importance.split(',') if imp.strip()] + query = query.filter(Event.importance.in_(importance_list)) + else: + # 单个重要性级别 + query = query.filter_by(importance=importance) + if creator_id: + query = query.filter_by(creator_id=creator_id) + # 新增:行业代码过滤(申银万国行业分类) + if industry_code: + # related_industries 格式: [{"申银万国行业分类": "S370502"}, ...] + # 支持多个行业代码,用逗号分隔 + json_path = '$[*]."申银万国行业分类"' + + # 如果包含逗号,说明是多个行业代码 + if ',' in industry_code: + codes = [code.strip() for code in industry_code.split(',') if code.strip()] + # 使用 OR 条件匹配任意一个行业代码 + conditions = [] + for code in codes: + conditions.append( + text("JSON_CONTAINS(JSON_EXTRACT(related_industries, :json_path), :code)") + .bindparams(json_path=json_path, code=json.dumps(code)) + ) + query = query.filter(db.or_(*conditions)) + else: + # 单个行业代码 + query = query.filter( + text("JSON_CONTAINS(JSON_EXTRACT(related_industries, :json_path), :industry_code)") + ).params(json_path=json_path, industry_code=json.dumps(industry_code)) + # 新增:关键词/全文搜索过滤(MySQL JSON) + if search_query: + like_pattern = f"%{search_query}%" + query = query.filter( + db.or_( + Event.title.ilike(like_pattern), + Event.description.ilike(like_pattern), + text(f"JSON_SEARCH(keywords, 'one', '%{search_query}%') IS NOT NULL") + ) + ) + if recent_days: + from datetime import datetime, timedelta + cutoff_date = datetime.now() - timedelta(days=recent_days) + query = query.filter(Event.created_at >= cutoff_date) + else: + if date_range and ' 至 ' in date_range: + try: + start_date_str, end_date_str = date_range.split(' 至 ') + start_date = start_date_str.strip() + end_date = end_date_str.strip() + except ValueError: + pass + if start_date: + from datetime import datetime + try: + if len(start_date) == 10: + start_datetime = datetime.strptime(start_date, '%Y-%m-%d') + else: + start_datetime = datetime.strptime(start_date, '%Y-%m-%d %H:%M:%S') + query = query.filter(Event.created_at >= start_datetime) + except ValueError: + pass + if end_date: + from datetime import datetime + try: + if len(end_date) == 10: + end_datetime = datetime.strptime(end_date, '%Y-%m-%d') + end_datetime = end_datetime.replace(hour=23, minute=59, second=59) + else: + end_datetime = datetime.strptime(end_date, '%Y-%m-%d %H:%M:%S') + query = query.filter(Event.created_at <= end_datetime) + except ValueError: + pass + if min_view_count is not None: + query = query.filter(Event.view_count >= min_view_count) + # 排序 + from sqlalchemy import desc, asc, case + order_func = desc if order.lower() == 'desc' else asc + if sort_by == 'hot': + query = query.order_by(order_func(Event.hot_score)) + elif sort_by == 'new': + query = query.order_by(order_func(Event.created_at)) + elif sort_by == 'returns': + if return_type == 'avg': + query = query.order_by(order_func(Event.related_avg_chg)) + elif return_type == 'max': + query = query.order_by(order_func(Event.related_max_chg)) + elif return_type == 'week': + query = query.order_by(order_func(Event.related_week_chg)) + elif sort_by == 'importance': + importance_order = case( + (Event.importance == 'S', 1), + (Event.importance == 'A', 2), + (Event.importance == 'B', 3), + (Event.importance == 'C', 4), + else_=5 + ) + if order.lower() == 'desc': + query = query.order_by(importance_order) + else: + query = query.order_by(desc(importance_order)) + elif sort_by == 'view_count': + query = query.order_by(order_func(Event.view_count)) + # 分页 + paginated = query.paginate(page=page, per_page=per_page, error_out=False) + events_data = [] + for event in paginated.items: + event_dict = { + 'id': event.id, + 'title': event.title, + 'description': event.description, + 'event_type': event.event_type, + 'importance': event.importance, + 'status': event.status, + 'created_at': event.created_at.isoformat() if event.created_at else None, + 'updated_at': event.updated_at.isoformat() if event.updated_at else None, + 'start_time': event.start_time.isoformat() if event.start_time else None, + 'end_time': event.end_time.isoformat() if event.end_time else None, + } + if include_stats: + event_dict.update({ + 'hot_score': event.hot_score, + 'view_count': event.view_count, + 'post_count': event.post_count, + 'follower_count': event.follower_count, + 'related_avg_chg': event.related_avg_chg, + 'related_max_chg': event.related_max_chg, + 'related_week_chg': event.related_week_chg, + 'invest_score': event.invest_score, + 'trending_score': event.trending_score, + }) + if include_creator: + event_dict['creator'] = { + 'id': event.creator.id if event.creator else None, + 'username': event.creator.username if event.creator else 'Anonymous' + } + event_dict['keywords'] = event.keywords_list if hasattr(event, 'keywords_list') else event.keywords + event_dict['related_industries'] = event.related_industries + if include_related_data: + pass + events_data.append(event_dict) + applied_filters = {} + if event_type != 'all': + applied_filters['type'] = event_type + if importance != 'all': + applied_filters['importance'] = importance + if start_date: + applied_filters['start_date'] = start_date + if end_date: + applied_filters['end_date'] = end_date + if industry_code: + applied_filters['industry_code'] = industry_code + if tag: + applied_filters['tag'] = tag + if tags: + applied_filters['tags'] = tags + if search_query: + applied_filters['search_query'] = search_query + applied_filters['search_type'] = search_type + return jsonify({ + 'success': True, + 'data': { + 'events': events_data, + 'pagination': { + 'page': paginated.page, + 'per_page': paginated.per_page, + 'total': paginated.total, + 'pages': paginated.pages, + 'has_prev': paginated.has_prev, + 'has_next': paginated.has_next + }, + 'filters': { + 'applied_filters': applied_filters, + 'total_count': paginated.total + } + } + }) + except Exception as e: + app.logger.error(f"获取事件列表出错: {str(e)}", exc_info=True) + return jsonify({ + 'success': False, + 'error': str(e), + 'error_type': type(e).__name__ + }), 500 + + +@app.route('/api/events/hot', methods=['GET']) +def get_hot_events(): + """获取热点事件""" + try: + from datetime import datetime, timedelta + days = request.args.get('days', 3, type=int) + limit = request.args.get('limit', 4, type=int) + since_date = datetime.now() - timedelta(days=days) + hot_events = Event.query.filter( + Event.status == 'active', + Event.created_at >= since_date, + Event.related_avg_chg != None, + Event.related_avg_chg > 0 + ).order_by(Event.related_avg_chg.desc()).limit(limit).all() + if len(hot_events) < limit: + additional_events = Event.query.filter( + Event.status == 'active', + Event.created_at >= since_date, + ~Event.id.in_([event.id for event in hot_events]) + ).order_by(Event.hot_score.desc()).limit(limit - len(hot_events)).all() + hot_events.extend(additional_events) + events_data = [] + for event in hot_events: + events_data.append({ + 'id': event.id, + 'title': event.title, + 'description': event.description, + 'importance': event.importance, + 'created_at': event.created_at.isoformat() if event.created_at else None, + 'related_avg_chg': event.related_avg_chg, + 'creator': { + 'username': event.creator.username if event.creator else 'Anonymous' + } + }) + return jsonify({'success': True, 'data': events_data}) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events/keywords/popular', methods=['GET']) +def get_popular_keywords(): + """获取热门关键词""" + try: + limit = request.args.get('limit', 20, type=int) + sql = ''' + WITH RECURSIVE \ + numbers AS (SELECT 0 as n \ + UNION ALL \ + SELECT n + 1 \ + FROM numbers \ + WHERE n < 100), \ + json_array AS (SELECT JSON_UNQUOTE(JSON_EXTRACT(e.keywords, CONCAT('$[', n.n, ']'))) as keyword, \ + COUNT(*) as count + FROM event e + CROSS JOIN numbers n + WHERE + e.status = 'active' + AND JSON_EXTRACT(e.keywords \ + , CONCAT('$[' \ + , n.n \ + , ']')) IS NOT NULL + GROUP BY JSON_UNQUOTE(JSON_EXTRACT(e.keywords, CONCAT('$[', n.n, ']'))) + HAVING keyword IS NOT NULL + ) + SELECT keyword, count + FROM json_array + ORDER BY count DESC, keyword LIMIT :limit \ + ''' + result = db.session.execute(text(sql), {'limit': limit}).all() + keywords_data = [{'keyword': row.keyword, 'count': row.count} for row in result] + return jsonify({'success': True, 'data': keywords_data}) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//sankey-data') +def get_event_sankey_data(event_id): + """ + 获取事件桑基图数据 (最终优化版) + - 处理重名节点 + - 检测并打破循环依赖 + """ + flows = EventSankeyFlow.query.filter_by(event_id=event_id).order_by( + EventSankeyFlow.source_level, EventSankeyFlow.target_level + ).all() + + if not flows: + return jsonify({'success': False, 'message': '暂无桑基图数据'}) + + nodes_map = {} + links = [] + type_colors = { + 'event': '#ff4757', 'policy': '#10ac84', 'technology': '#ee5a6f', + 'industry': '#00d2d3', 'company': '#54a0ff', 'product': '#ffd93d' + } + + # --- 1. 识别并处理重名节点 (与上一版相同) --- + all_node_keys = set() + name_counts = {} + for flow in flows: + source_key = f"{flow.source_node}|{flow.source_level}" + target_key = f"{flow.target_node}|{flow.target_level}" + all_node_keys.add(source_key) + all_node_keys.add(target_key) + name_counts.setdefault(flow.source_node, set()).add(flow.source_level) + name_counts.setdefault(flow.target_node, set()).add(flow.target_level) + + duplicate_names = {name for name, levels in name_counts.items() if len(levels) > 1} + + for flow in flows: + source_key = f"{flow.source_node}|{flow.source_level}" + if source_key not in nodes_map: + display_name = f"{flow.source_node} (L{flow.source_level})" if flow.source_node in duplicate_names else flow.source_node + nodes_map[source_key] = {'name': display_name, 'type': flow.source_type, 'level': flow.source_level, + 'color': type_colors.get(flow.source_type)} + + target_key = f"{flow.target_node}|{flow.target_level}" + if target_key not in nodes_map: + display_name = f"{flow.target_node} (L{flow.target_level})" if flow.target_node in duplicate_names else flow.target_node + nodes_map[target_key] = {'name': display_name, 'type': flow.target_type, 'level': flow.target_level, + 'color': type_colors.get(flow.target_type)} + + links.append({ + 'source_key': source_key, 'target_key': target_key, 'value': float(flow.flow_value), + 'ratio': float(flow.flow_ratio), 'transmission_path': flow.transmission_path, + 'impact_description': flow.impact_description, 'evidence_strength': flow.evidence_strength + }) + + # --- 2. 循环检测与处理 --- + # 构建邻接表 + adj = defaultdict(list) + for link in links: + adj[link['source_key']].append(link['target_key']) + + # 深度优先搜索(DFS)来检测循环 + path = set() # 记录当前递归路径上的节点 + visited = set() # 记录所有访问过的节点 + back_edges = set() # 记录导致循环的"回流边" + + def detect_cycle_util(node): + path.add(node) + visited.add(node) + for neighbour in adj.get(node, []): + if neighbour in path: + # 发现了循环,记录这条回流边 (target, source) + back_edges.add((neighbour, node)) + elif neighbour not in visited: + detect_cycle_util(neighbour) + path.remove(node) + + # 从所有节点开始检测 + for node_key in list(adj.keys()): + if node_key not in visited: + detect_cycle_util(node_key) + + # 过滤掉导致循环的边 + if back_edges: + print(f"检测到并移除了 {len(back_edges)} 条循环边: {back_edges}") + + valid_links_no_cycle = [] + for link in links: + if (link['source_key'], link['target_key']) not in back_edges and \ + (link['target_key'], link['source_key']) not in back_edges: # 移除非严格意义上的双向边 + valid_links_no_cycle.append(link) + + # --- 3. 构建最终的 JSON 响应 (与上一版相似) --- + node_list = [] + node_index_map = {} + sorted_node_keys = sorted(nodes_map.keys(), key=lambda k: (nodes_map[k]['level'], nodes_map[k]['name'])) + + for i, key in enumerate(sorted_node_keys): + node_list.append(nodes_map[key]) + node_index_map[key] = i + + final_links = [] + for link in valid_links_no_cycle: + source_idx = node_index_map.get(link['source_key']) + target_idx = node_index_map.get(link['target_key']) + if source_idx is not None and target_idx is not None: + # 移除临时的 key,只保留 ECharts 需要的字段 + link.pop('source_key', None) + link.pop('target_key', None) + link['source'] = source_idx + link['target'] = target_idx + final_links.append(link) + + # ... (统计信息计算部分保持不变) ... + stats = { + 'total_nodes': len(node_list), 'total_flows': len(final_links), + 'total_flow_value': sum(link['value'] for link in final_links), + 'max_level': max((node['level'] for node in node_list), default=0), + 'node_type_counts': {ntype: sum(1 for n in node_list if n['type'] == ntype) for ntype in type_colors} + } + + return jsonify({ + 'success': True, + 'data': {'nodes': node_list, 'links': final_links, 'stats': stats} + }) + + +# 优化后的传导链分析 API +@app.route('/api/events//chain-analysis') +def get_event_chain_analysis(event_id): + """获取事件传导链分析数据""" + nodes = EventTransmissionNode.query.filter_by(event_id=event_id).all() + if not nodes: + return jsonify({'success': False, 'message': '暂无传导链分析数据'}) + + edges = EventTransmissionEdge.query.filter_by(event_id=event_id).all() + + # 过滤孤立节点 + connected_node_ids = set() + for edge in edges: + connected_node_ids.add(edge.from_node_id) + connected_node_ids.add(edge.to_node_id) + + # 只保留有连接的节点 + connected_nodes = [node for node in nodes if node.id in connected_node_ids] + + if not connected_nodes: + return jsonify({'success': False, 'message': '所有节点都是孤立的,暂无传导关系'}) + + # 节点分类,用于力导向图的图例 + categories = { + 'event': "事件", 'industry': "行业", 'company': "公司", + 'policy': "政策", 'technology': "技术", 'market': "市场", 'other': "其他" + } + + # 计算每个节点的连接数 + node_connection_count = {} + for node in connected_nodes: + count = sum(1 for edge in edges + if edge.from_node_id == node.id or edge.to_node_id == node.id) + node_connection_count[node.id] = count + + nodes_data = [] + for node in connected_nodes: + connection_count = node_connection_count[node.id] + + nodes_data.append({ + 'id': str(node.id), + 'name': node.node_name, + 'value': node.importance_score, # 用于控制节点大小的基础值 + 'category': categories.get(node.node_type), + 'extra': { + 'node_type': node.node_type, + 'description': node.node_description, + 'importance_score': node.importance_score, + 'stock_code': node.stock_code, + 'is_main_event': node.is_main_event, + 'connection_count': connection_count, # 添加连接数信息 + } + }) + + edges_data = [] + for edge in edges: + # 确保边的两端节点都在连接节点列表中 + if edge.from_node_id in connected_node_ids and edge.to_node_id in connected_node_ids: + edges_data.append({ + 'source': str(edge.from_node_id), + 'target': str(edge.to_node_id), + 'value': edge.strength, # 用于控制边的宽度 + 'extra': { + 'transmission_type': edge.transmission_type, + 'transmission_mechanism': edge.transmission_mechanism, + 'direction': edge.direction, + 'strength': edge.strength, + 'impact': edge.impact, + 'is_circular': edge.is_circular, + } + }) + + # 重新计算统计信息(基于连接的节点和边) + stats = { + 'total_nodes': len(connected_nodes), + 'total_edges': len(edges_data), + 'node_types': {cat: sum(1 for n in connected_nodes if n.node_type == node_type) + for node_type, cat in categories.items()}, + 'edge_types': {edge.transmission_type: sum(1 for e in edges_data + if e['extra']['transmission_type'] == edge.transmission_type) for + edge in edges}, + 'avg_importance': sum(node.importance_score for node in connected_nodes) / len( + connected_nodes) if connected_nodes else 0, + 'avg_strength': sum(edge.strength for edge in edges) / len(edges) if edges else 0 + } + + return jsonify({ + 'success': True, + 'data': { + 'nodes': nodes_data, + 'edges': edges_data, + 'categories': list(categories.values()), + 'stats': stats + } + }) + + +@app.route('/api/events//chain-node/', methods=['GET']) +@cross_origin() +def get_chain_node_detail(event_id, node_id): + """获取传导链节点及其直接关联节点的详细信息""" + node = db.session.get(EventTransmissionNode, node_id) + if not node or node.event_id != event_id: + return jsonify({'success': False, 'message': '节点不存在'}) + + # 验证节点是否为孤立节点 + total_connections = (EventTransmissionEdge.query.filter_by(from_node_id=node_id).count() + + EventTransmissionEdge.query.filter_by(to_node_id=node_id).count()) + + if total_connections == 0 and not node.is_main_event: + return jsonify({'success': False, 'message': '该节点为孤立节点,无连接关系'}) + + # 找出影响当前节点的父节点 + parents_info = [] + incoming_edges = EventTransmissionEdge.query.filter_by(to_node_id=node_id).all() + for edge in incoming_edges: + parent = db.session.get(EventTransmissionNode, edge.from_node_id) + if parent: + parents_info.append({ + 'id': parent.id, + 'name': parent.node_name, + 'type': parent.node_type, + 'direction': edge.direction, + 'strength': edge.strength, + 'transmission_type': edge.transmission_type, + 'transmission_mechanism': edge.transmission_mechanism, # 修复字段名 + 'is_circular': edge.is_circular, + 'impact': edge.impact + }) + + # 找出被当前节点影响的子节点 + children_info = [] + outgoing_edges = EventTransmissionEdge.query.filter_by(from_node_id=node_id).all() + for edge in outgoing_edges: + child = db.session.get(EventTransmissionNode, edge.to_node_id) + if child: + children_info.append({ + 'id': child.id, + 'name': child.node_name, + 'type': child.node_type, + 'direction': edge.direction, + 'strength': edge.strength, + 'transmission_type': edge.transmission_type, + 'transmission_mechanism': edge.transmission_mechanism, # 修复字段名 + 'is_circular': edge.is_circular, + 'impact': edge.impact + }) + + node_data = { + 'id': node.id, + 'name': node.node_name, + 'type': node.node_type, + 'description': node.node_description, + 'importance_score': node.importance_score, + 'stock_code': node.stock_code, + 'is_main_event': node.is_main_event, + 'total_connections': total_connections, + 'incoming_connections': len(incoming_edges), + 'outgoing_connections': len(outgoing_edges) + } + + return jsonify({ + 'success': True, + 'data': { + 'node': node_data, + 'parents': parents_info, + 'children': children_info + } + }) + + +@app.route('/api/events//posts', methods=['GET']) +def get_event_posts(event_id): + """获取事件下的帖子""" + try: + sort_type = request.args.get('sort', 'latest') + page = request.args.get('page', 1, type=int) + per_page = request.args.get('per_page', 20, type=int) + + # 查询事件下的帖子 + query = Post.query.filter_by(event_id=event_id, status='active') + + if sort_type == 'hot': + query = query.order_by(Post.likes_count.desc(), Post.created_at.desc()) + else: # latest + query = query.order_by(Post.created_at.desc()) + + # 分页 + pagination = query.paginate(page=page, per_page=per_page, error_out=False) + posts = pagination.items + + posts_data = [] + for post in posts: + post_dict = { + 'id': post.id, + 'event_id': post.event_id, + 'user_id': post.user_id, + 'title': post.title, + 'content': post.content, + 'content_type': post.content_type, + 'created_at': post.created_at.isoformat(), + 'updated_at': post.updated_at.isoformat(), + 'likes_count': post.likes_count, + 'comments_count': post.comments_count, + 'view_count': post.view_count, + 'is_top': post.is_top, + 'user': { + 'id': post.user.id, + 'username': post.user.username, + 'avatar_url': post.user.avatar_url + } if post.user else None, + 'liked': False # 后续可以根据当前用户判断 + } + posts_data.append(post_dict) + + return jsonify({ + 'success': True, + 'data': posts_data, + 'pagination': { + 'page': page, + 'per_page': per_page, + 'total': pagination.total, + 'pages': pagination.pages + } + }) + + except Exception as e: + print(f"获取帖子失败: {e}") + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/events//posts', methods=['POST']) +@login_required +def create_event_post(event_id): + """在事件下创建帖子""" + try: + data = request.get_json() + content = data.get('content', '').strip() + title = data.get('title', '').strip() + content_type = data.get('content_type', 'text') + + if not content: + return jsonify({ + 'success': False, + 'message': '帖子内容不能为空' + }), 400 + + # 创建新帖子 + post = Post( + event_id=event_id, + user_id=current_user.id, + title=title, + content=content, + content_type=content_type + ) + + db.session.add(post) + + # 更新事件的帖子数 + event = Event.query.get(event_id) + if event: + event.post_count = Post.query.filter_by(event_id=event_id, status='active').count() + + # 更新用户发帖数 + current_user.post_count = (current_user.post_count or 0) + 1 + + db.session.commit() + + return jsonify({ + 'success': True, + 'data': { + 'id': post.id, + 'event_id': post.event_id, + 'user_id': post.user_id, + 'title': post.title, + 'content': post.content, + 'content_type': post.content_type, + 'created_at': post.created_at.isoformat(), + 'user': { + 'id': current_user.id, + 'username': current_user.username, + 'avatar_url': current_user.avatar_url + } + }, + 'message': '帖子发布成功' + }) + + except Exception as e: + db.session.rollback() + print(f"创建帖子失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +@app.route('/api/posts//comments', methods=['GET']) +def get_post_comments(post_id): + """获取帖子的评论""" + try: + sort_type = request.args.get('sort', 'latest') + + # 查询帖子的顶级评论(非回复) + query = Comment.query.filter_by(post_id=post_id, parent_id=None, status='active') + + if sort_type == 'hot': + comments = query.order_by(Comment.likes_count.desc(), Comment.created_at.desc()).all() + else: # latest + comments = query.order_by(Comment.created_at.desc()).all() + + comments_data = [] + for comment in comments: + comment_dict = { + 'id': comment.id, + 'post_id': comment.post_id, + 'user_id': comment.user_id, + 'content': comment.content, + 'created_at': comment.created_at.isoformat(), + 'updated_at': comment.updated_at.isoformat(), + 'likes_count': comment.likes_count, + 'user': { + 'id': comment.user.id, + 'username': comment.user.username, + 'avatar_url': comment.user.avatar_url + } if comment.user else None, + 'replies': [] # 加载回复 + } + + # 加载回复 + replies = Comment.query.filter_by(parent_id=comment.id, status='active').order_by(Comment.created_at).all() + for reply in replies: + reply_dict = { + 'id': reply.id, + 'post_id': reply.post_id, + 'user_id': reply.user_id, + 'content': reply.content, + 'parent_id': reply.parent_id, + 'created_at': reply.created_at.isoformat(), + 'likes_count': reply.likes_count, + 'user': { + 'id': reply.user.id, + 'username': reply.user.username, + 'avatar_url': reply.user.avatar_url + } if reply.user else None + } + comment_dict['replies'].append(reply_dict) + + comments_data.append(comment_dict) + + return jsonify({ + 'success': True, + 'data': comments_data + }) + + except Exception as e: + print(f"获取评论失败: {e}") + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/posts//comments', methods=['POST']) +@login_required +def create_post_comment(post_id): + """在帖子下创建评论""" + try: + data = request.get_json() + content = data.get('content', '').strip() + parent_id = data.get('parent_id') + + if not content: + return jsonify({ + 'success': False, + 'message': '评论内容不能为空' + }), 400 + + # 创建新评论 + comment = Comment( + post_id=post_id, + user_id=current_user.id, + content=content, + parent_id=parent_id + ) + + db.session.add(comment) + + # 更新帖子评论数 + post = Post.query.get(post_id) + if post: + post.comments_count = Comment.query.filter_by(post_id=post_id, status='active').count() + + # 更新用户评论数 + current_user.comment_count = (current_user.comment_count or 0) + 1 + + db.session.commit() + + return jsonify({ + 'success': True, + 'data': { + 'id': comment.id, + 'post_id': comment.post_id, + 'user_id': comment.user_id, + 'content': comment.content, + 'parent_id': comment.parent_id, + 'created_at': comment.created_at.isoformat(), + 'user': { + 'id': current_user.id, + 'username': current_user.username, + 'avatar_url': current_user.avatar_url + } + }, + 'message': '评论发布成功' + }) + + except Exception as e: + db.session.rollback() + print(f"创建评论失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +# 兼容旧的评论接口,转换为帖子模式 +@app.route('/api/events//comments', methods=['GET']) +def get_event_comments(event_id): + """获取事件评论(兼容旧接口)""" + # 将事件评论转换为获取事件下所有帖子的评论 + return get_event_posts(event_id) + + +@app.route('/api/events//comments', methods=['POST']) +@login_required +def add_event_comment(event_id): + """添加事件评论(兼容旧接口)""" + try: + data = request.get_json() + content = data.get('content', '').strip() + parent_id = data.get('parent_id') + + if not content: + return jsonify({ + 'success': False, + 'message': '评论内容不能为空' + }), 400 + + # 如果有 parent_id,说明是回复,需要找到对应的帖子 + if parent_id: + # 这是一个回复,需要将其转换为对应帖子的评论 + # 首先需要找到 parent_id 对应的帖子 + # 这里假设旧的 parent_id 是之前的 EventComment id + # 需要在数据迁移时处理这个映射关系 + return jsonify({ + 'success': False, + 'message': '回复功能正在升级中,请稍后再试' + }), 503 + + # 如果没有 parent_id,说明是顶级评论,创建为新帖子 + post = Post( + event_id=event_id, + user_id=current_user.id, + content=content, + content_type='text' + ) + + db.session.add(post) + + # 更新事件的帖子数 + event = Event.query.get(event_id) + if event: + event.post_count = Post.query.filter_by(event_id=event_id, status='active').count() + + # 更新用户发帖数 + current_user.post_count = (current_user.post_count or 0) + 1 + + db.session.commit() + + # 返回兼容旧接口的数据格式 + return jsonify({ + 'success': True, + 'data': { + 'id': post.id, + 'event_id': post.event_id, + 'user_id': post.user_id, + 'author': current_user.username, + 'content': post.content, + 'parent_id': None, + 'likes': 0, + 'created_at': post.created_at.isoformat(), + 'status': 'active', + 'user': { + 'id': current_user.id, + 'username': current_user.username, + 'avatar_url': current_user.avatar_url + }, + 'replies': [] + }, + 'message': '评论发布成功' + }) + + except Exception as e: + db.session.rollback() + print(f"添加事件评论失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +# ==================== WebSocket 事件处理器(实时事件推送) ==================== + +@socketio.on('connect') +def handle_connect(): + """客户端连接事件""" + print(f'\n[WebSocket DEBUG] ========== 客户端连接 ==========') + print(f'[WebSocket DEBUG] Socket ID: {request.sid}') + print(f'[WebSocket DEBUG] Remote Address: {request.remote_addr if hasattr(request, "remote_addr") else "N/A"}') + print(f'[WebSocket] 客户端已连接: {request.sid}') + + emit('connection_response', { + 'status': 'connected', + 'sid': request.sid, + 'message': '已连接到事件推送服务' + }) + print(f'[WebSocket DEBUG] ✓ 已发送 connection_response') + print(f'[WebSocket DEBUG] ========== 连接完成 ==========\n') + + +@socketio.on('subscribe_events') +def handle_subscribe(data): + """ + 客户端订阅事件推送 + data: { + 'event_type': 'all' | 'policy' | 'market' | 'tech' | ..., + 'importance': 'all' | 'S' | 'A' | 'B' | 'C', + 'filters': {...} # 可选的其他筛选条件 + } + """ + try: + print(f'\n[WebSocket DEBUG] ========== 收到订阅请求 ==========') + print(f'[WebSocket DEBUG] Socket ID: {request.sid}') + print(f'[WebSocket DEBUG] 订阅数据: {data}') + + event_type = data.get('event_type', 'all') + importance = data.get('importance', 'all') + + print(f'[WebSocket DEBUG] 事件类型: {event_type}') + print(f'[WebSocket DEBUG] 重要性: {importance}') + + # 加入对应的房间 + room_name = f"events_{event_type}" + print(f'[WebSocket DEBUG] 准备加入房间: {room_name}') + join_room(room_name) + print(f'[WebSocket DEBUG] ✓ 已加入房间: {room_name}') + + print(f'[WebSocket] 客户端 {request.sid} 订阅了房间: {room_name}') + + response_data = { + 'success': True, + 'room': room_name, + 'event_type': event_type, + 'importance': importance, + 'message': f'已订阅 {event_type} 类型的事件推送' + } + print(f'[WebSocket DEBUG] 准备发送 subscription_confirmed: {response_data}') + emit('subscription_confirmed', response_data) + print(f'[WebSocket DEBUG] ✓ 已发送 subscription_confirmed') + print(f'[WebSocket DEBUG] ========== 订阅完成 ==========\n') + + except Exception as e: + print(f'[WebSocket ERROR] 订阅失败: {e}') + import traceback + traceback.print_exc() + emit('subscription_error', { + 'success': False, + 'error': str(e) + }) + + +@socketio.on('unsubscribe_events') +def handle_unsubscribe(data): + """取消订阅事件推送""" + try: + print(f'\n[WebSocket DEBUG] ========== 收到取消订阅请求 ==========') + print(f'[WebSocket DEBUG] Socket ID: {request.sid}') + print(f'[WebSocket DEBUG] 数据: {data}') + + event_type = data.get('event_type', 'all') + room_name = f"events_{event_type}" + + print(f'[WebSocket DEBUG] 准备离开房间: {room_name}') + leave_room(room_name) + print(f'[WebSocket DEBUG] ✓ 已离开房间: {room_name}') + + print(f'[WebSocket] 客户端 {request.sid} 取消订阅房间: {room_name}') + + emit('unsubscription_confirmed', { + 'success': True, + 'room': room_name, + 'message': f'已取消订阅 {event_type} 类型的事件推送' + }) + print(f'[WebSocket DEBUG] ========== 取消订阅完成 ==========\n') + + except Exception as e: + print(f'[WebSocket ERROR] 取消订阅失败: {e}') + import traceback + traceback.print_exc() + emit('unsubscription_error', { + 'success': False, + 'error': str(e) + }) + + +@socketio.on('disconnect') +def handle_disconnect(): + """客户端断开连接事件""" + print(f'\n[WebSocket DEBUG] ========== 客户端断开 ==========') + print(f'[WebSocket DEBUG] Socket ID: {request.sid}') + print(f'[WebSocket] 客户端已断开: {request.sid}') + print(f'[WebSocket DEBUG] ========== 断开完成 ==========\n') + + +# ==================== WebSocket 辅助函数 ==================== + +def broadcast_new_event(event): + """ + 广播新事件到所有订阅的客户端 + 在创建新事件时调用此函数 + + Args: + event: Event 模型实例 + """ + try: + print(f'\n[WebSocket DEBUG] ========== 广播新事件 ==========') + print(f'[WebSocket DEBUG] 事件ID: {event.id}') + print(f'[WebSocket DEBUG] 事件标题: {event.title}') + print(f'[WebSocket DEBUG] 事件类型: {event.event_type}') + print(f'[WebSocket DEBUG] 重要性: {event.importance}') + + event_data = { + 'id': event.id, + 'title': event.title, + 'description': event.description, + 'event_type': event.event_type, + 'importance': event.importance, + 'status': event.status, + 'created_at': event.created_at.isoformat() if event.created_at else None, + 'hot_score': event.hot_score, + 'view_count': event.view_count, + 'related_avg_chg': event.related_avg_chg, + 'related_max_chg': event.related_max_chg, + 'keywords': event.keywords_list if hasattr(event, 'keywords_list') else event.keywords, + } + + print(f'[WebSocket DEBUG] 准备发送的数据: {event_data}') + + # 发送到所有订阅者(all 房间) + print(f'[WebSocket DEBUG] 正在发送到房间: events_all') + socketio.emit('new_event', event_data, room='events_all', namespace='/') + print(f'[WebSocket DEBUG] ✓ 已发送到 events_all') + + # 发送到特定类型订阅者 + if event.event_type: + room_name = f"events_{event.event_type}" + print(f'[WebSocket DEBUG] 正在发送到房间: {room_name}') + socketio.emit('new_event', event_data, room=room_name, namespace='/') + print(f'[WebSocket DEBUG] ✓ 已发送到 {room_name}') + print(f'[WebSocket] 已推送新事件到房间: events_all, {room_name}') + else: + print(f'[WebSocket] 已推送新事件到房间: events_all') + + print(f'[WebSocket DEBUG] ========== 广播完成 ==========\n') + + except Exception as e: + print(f'[WebSocket ERROR] 推送新事件失败: {e}') + import traceback + traceback.print_exc() + + +# ==================== WebSocket 轮询机制(检测新事件) ==================== + +# 内存变量:记录近24小时内已知的事件ID集合和最大ID +known_event_ids_in_24h = set() # 近24小时内已知的所有事件ID +last_max_event_id = 0 # 已知的最大事件ID + +def poll_new_events(): + """ + 定期轮询数据库,检查是否有新事件 + 每 30 秒执行一次 + + 新的设计思路(修复 created_at 不是入库时间的问题): + 1. 查询近24小时内的所有活跃事件(按 created_at,因为这是事件发生时间) + 2. 通过对比事件ID(自增ID)来判断是否为新插入的事件 + 3. 推送 ID > last_max_event_id 的事件 + 4. 更新已知事件ID集合和最大ID + """ + global known_event_ids_in_24h, last_max_event_id + + try: + with app.app_context(): + from datetime import datetime, timedelta + + current_time = datetime.now() + print(f'\n[轮询 DEBUG] ========== 开始轮询 ==========') + print(f'[轮询 DEBUG] 当前时间: {current_time.strftime("%Y-%m-%d %H:%M:%S")}') + print(f'[轮询 DEBUG] 已知事件ID数量: {len(known_event_ids_in_24h)}') + print(f'[轮询 DEBUG] 当前最大事件ID: {last_max_event_id}') + + # 查询近24小时内的所有活跃事件(按事件发生时间 created_at) + time_24h_ago = current_time - timedelta(hours=24) + print(f'[轮询 DEBUG] 查询时间范围: 近24小时({time_24h_ago.strftime("%Y-%m-%d %H:%M:%S")} ~ 现在)') + + # 查询所有近24小时内的活跃事件 + events_in_24h = Event.query.filter( + Event.created_at >= time_24h_ago, + Event.status == 'active' + ).order_by(Event.id.asc()).all() + + print(f'[轮询 DEBUG] 数据库查询结果: 找到 {len(events_in_24h)} 个近24小时内的事件') + + # 找出新插入的事件(ID > last_max_event_id) + new_events = [ + event for event in events_in_24h + if event.id > last_max_event_id + ] + + print(f'[轮询 DEBUG] 新事件数量(ID > {last_max_event_id}): {len(new_events)} 个') + + if new_events: + print(f'[轮询] 发现 {len(new_events)} 个新事件') + + for event in new_events: + print(f'[轮询 DEBUG] 新事件详情:') + print(f'[轮询 DEBUG] - ID: {event.id}') + print(f'[轮询 DEBUG] - 标题: {event.title}') + print(f'[轮询 DEBUG] - 事件发生时间(created_at): {event.created_at}') + print(f'[轮询 DEBUG] - 事件类型: {event.event_type}') + + # 推送新事件 + print(f'[轮询 DEBUG] 准备推送事件 ID={event.id}') + broadcast_new_event(event) + print(f'[轮询] ✓ 已推送事件 ID={event.id}, 标题={event.title}') + + # 更新已知事件ID集合(所有近24小时内的事件ID) + known_event_ids_in_24h = set(event.id for event in events_in_24h) + + # 更新最大事件ID + new_max_id = max(event.id for event in events_in_24h) + print(f'[轮询 DEBUG] 更新最大事件ID: {last_max_event_id} -> {new_max_id}') + last_max_event_id = new_max_id + + print(f'[轮询 DEBUG] 更新后已知事件ID数量: {len(known_event_ids_in_24h)}') + + else: + print(f'[轮询 DEBUG] 没有新事件需要推送') + + # 即使没有新事件,也要更新已知事件集合(清理超过24小时的) + if events_in_24h: + known_event_ids_in_24h = set(event.id for event in events_in_24h) + current_max_id = max(event.id for event in events_in_24h) + if current_max_id != last_max_event_id: + print(f'[轮询 DEBUG] 更新最大事件ID: {last_max_event_id} -> {current_max_id}') + last_max_event_id = current_max_id + + print(f'[轮询 DEBUG] ========== 轮询结束 ==========\n') + + except Exception as e: + print(f'[轮询 ERROR] 检查新事件时出错: {e}') + import traceback + traceback.print_exc() + + +def initialize_event_polling(): + """ + 初始化事件轮询机制 + 在应用启动时调用 + """ + global known_event_ids_in_24h, last_max_event_id + + try: + from datetime import datetime, timedelta + + with app.app_context(): + current_time = datetime.now() + time_24h_ago = current_time - timedelta(hours=24) + + print(f'\n[轮询] ========== 初始化事件轮询 ==========') + print(f'[轮询] 当前时间: {current_time.strftime("%Y-%m-%d %H:%M:%S")}') + + # 查询近24小时内的所有活跃事件 + events_in_24h = Event.query.filter( + Event.created_at >= time_24h_ago, + Event.status == 'active' + ).order_by(Event.id.asc()).all() + + # 初始化已知事件ID集合 + known_event_ids_in_24h = set(event.id for event in events_in_24h) + + # 初始化最大事件ID + if events_in_24h: + last_max_event_id = max(event.id for event in events_in_24h) + print(f'[轮询] 近24小时内共有 {len(events_in_24h)} 个活跃事件') + print(f'[轮询] 初始最大事件ID: {last_max_event_id}') + print(f'[轮询] 事件ID范围: {min(event.id for event in events_in_24h)} ~ {last_max_event_id}') + else: + last_max_event_id = 0 + print(f'[轮询] 近24小时内没有活跃事件') + print(f'[轮询] 初始最大事件ID: 0') + + # 统计数据库中的事件总数 + total_events = Event.query.filter_by(status='active').count() + print(f'[轮询] 数据库中共有 {total_events} 个活跃事件(所有时间)') + print(f'[轮询] 只会推送 ID > {last_max_event_id} 的新事件') + print(f'[轮询] ========== 初始化完成 ==========\n') + + # 创建后台调度器 + scheduler = BackgroundScheduler() + # 每 30 秒执行一次轮询 + scheduler.add_job( + func=poll_new_events, + trigger='interval', + seconds=30, + id='poll_new_events', + name='检查新事件并推送', + replace_existing=True + ) + scheduler.start() + print('[轮询] 调度器已启动,每 30 秒检查一次新事件') + + except Exception as e: + print(f'[轮询] 初始化失败: {e}') + + +# ==================== 结束 WebSocket 部分 ==================== + + +@app.route('/api/posts//like', methods=['POST']) +@login_required +def like_post(post_id): + """点赞/取消点赞帖子""" + try: + post = Post.query.get_or_404(post_id) + + # 检查是否已经点赞 + existing_like = PostLike.query.filter_by( + post_id=post_id, + user_id=current_user.id + ).first() + + if existing_like: + # 取消点赞 + db.session.delete(existing_like) + post.likes_count = max(0, post.likes_count - 1) + message = '取消点赞成功' + liked = False + else: + # 添加点赞 + new_like = PostLike(post_id=post_id, user_id=current_user.id) + db.session.add(new_like) + post.likes_count += 1 + message = '点赞成功' + liked = True + + db.session.commit() + + return jsonify({ + 'success': True, + 'message': message, + 'likes_count': post.likes_count, + 'liked': liked + }) + + except Exception as e: + db.session.rollback() + print(f"点赞失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +@app.route('/api/comments//like', methods=['POST']) +@login_required +def like_comment(comment_id): + """点赞/取消点赞评论""" + try: + comment = Comment.query.get_or_404(comment_id) + + # 检查是否已经点赞(需要创建 CommentLike 关联到新的 Comment 模型) + # 暂时使用简单的计数器 + comment.likes_count += 1 + db.session.commit() + + return jsonify({ + 'success': True, + 'message': '点赞成功', + 'likes_count': comment.likes_count + }) + + except Exception as e: + db.session.rollback() + print(f"点赞失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +@app.route('/api/posts/', methods=['DELETE']) +@login_required +def delete_post(post_id): + """删除帖子""" + try: + post = Post.query.get_or_404(post_id) + + # 检查权限:只能删除自己的帖子 + if post.user_id != current_user.id: + return jsonify({ + 'success': False, + 'message': '您只能删除自己的帖子' + }), 403 + + # 软删除 + post.status = 'deleted' + + # 更新事件的帖子数 + event = Event.query.get(post.event_id) + if event: + event.post_count = Post.query.filter_by(event_id=post.event_id, status='active').count() + + # 更新用户发帖数 + if current_user.post_count > 0: + current_user.post_count -= 1 + + db.session.commit() + + return jsonify({ + 'success': True, + 'message': '帖子删除成功' + }) + + except Exception as e: + db.session.rollback() + print(f"删除帖子失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +@app.route('/api/comments/', methods=['DELETE']) +@login_required +def delete_comment(comment_id): + """删除评论""" + try: + comment = Comment.query.get_or_404(comment_id) + + # 检查权限:只能删除自己的评论 + if comment.user_id != current_user.id: + return jsonify({ + 'success': False, + 'message': '您只能删除自己的评论' + }), 403 + + # 软删除 + comment.status = 'deleted' + comment.content = '[该评论已被删除]' + + # 更新帖子评论数 + post = Post.query.get(comment.post_id) + if post: + post.comments_count = Comment.query.filter_by(post_id=comment.post_id, status='active').count() + + # 更新用户评论数 + if current_user.comment_count > 0: + current_user.comment_count -= 1 + + db.session.commit() + + return jsonify({ + 'success': True, + 'message': '评论删除成功' + }) + + except Exception as e: + db.session.rollback() + print(f"删除评论失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +def format_decimal(value): + """格式化decimal类型数据""" + if value is None: + return None + if isinstance(value, Decimal): + return float(value) + return float(value) + + +def format_date(date_obj): + """格式化日期""" + if date_obj is None: + return None + if isinstance(date_obj, datetime): + return date_obj.strftime('%Y-%m-%d') + return str(date_obj) + + +def remove_cycles_from_sankey_flows(flows_data): + """ + 移除Sankey图数据中的循环边,确保数据是DAG(有向无环图) + 使用拓扑排序算法检测循环,优先保留flow_ratio高的边 + + Args: + flows_data: list of flow objects with 'source', 'target', 'flow_metrics' keys + + Returns: + list of flows without cycles + """ + if not flows_data: + return flows_data + + # 按flow_ratio降序排序,优先保留重要的边 + sorted_flows = sorted( + flows_data, + key=lambda x: x.get('flow_metrics', {}).get('flow_ratio', 0) or 0, + reverse=True + ) + + # 构建图的邻接表和入度表 + def build_graph(flows): + graph = {} # node -> list of successors + in_degree = {} # node -> in-degree count + all_nodes = set() + + for flow in flows: + source = flow['source']['node_name'] + target = flow['target']['node_name'] + all_nodes.add(source) + all_nodes.add(target) + + if source not in graph: + graph[source] = [] + graph[source].append(target) + + if target not in in_degree: + in_degree[target] = 0 + in_degree[target] += 1 + + if source not in in_degree: + in_degree[source] = 0 + + return graph, in_degree, all_nodes + + # 使用Kahn算法检测是否有环 + def has_cycle(graph, in_degree, all_nodes): + # 找到所有入度为0的节点 + queue = [node for node in all_nodes if in_degree.get(node, 0) == 0] + visited_count = 0 + + while queue: + node = queue.pop(0) + visited_count += 1 + + # 访问所有邻居 + for neighbor in graph.get(node, []): + in_degree[neighbor] -= 1 + if in_degree[neighbor] == 0: + queue.append(neighbor) + + # 如果访问的节点数等于总节点数,说明没有环 + return visited_count < len(all_nodes) + + # 逐个添加边,如果添加后产生环则跳过 + result_flows = [] + + for flow in sorted_flows: + # 尝试添加这条边 + temp_flows = result_flows + [flow] + + # 检查是否产生环 + graph, in_degree, all_nodes = build_graph(temp_flows) + + # 复制in_degree用于检测(因为检测过程会修改它) + in_degree_copy = in_degree.copy() + + if not has_cycle(graph, in_degree_copy, all_nodes): + # 没有产生环,可以添加 + result_flows.append(flow) + else: + # 产生环,跳过这条边 + print(f"Skipping edge that creates cycle: {flow['source']['node_name']} -> {flow['target']['node_name']}") + + removed_count = len(flows_data) - len(result_flows) + if removed_count > 0: + print(f"Removed {removed_count} edges to eliminate cycles in Sankey diagram") + + return result_flows + + +def get_report_type(date_str): + """获取报告期类型""" + if not date_str: + return '' + if isinstance(date_str, str): + date = datetime.strptime(date_str, '%Y-%m-%d') + else: + date = date_str + + month = date.month + year = date.year + + if month == 3: + return f"{year}年一季报" + elif month == 6: + return f"{year}年中报" + elif month == 9: + return f"{year}年三季报" + elif month == 12: + return f"{year}年年报" + else: + return str(date_str) + + +@app.route('/api/financial/stock-info/', methods=['GET']) +def get_stock_info(seccode): + """获取股票基本信息和最新财务摘要""" + try: + # 获取最新的财务数据 + query = text(""" + SELECT distinct a.SECCODE, + a.SECNAME, + a.ENDDATE, + a.F003N as eps, + a.F004N as basic_eps, + a.F005N as diluted_eps, + a.F006N as deducted_eps, + a.F007N as undistributed_profit_ps, + a.F008N as bvps, + a.F010N as capital_reserve_ps, + a.F014N as roe, + a.F067N as roe_weighted, + a.F016N as roa, + a.F078N as gross_margin, + a.F017N as net_margin, + a.F089N as revenue, + a.F101N as net_profit, + a.F102N as parent_net_profit, + a.F118N as total_assets, + a.F121N as total_liabilities, + a.F128N as total_equity, + a.F052N as revenue_growth, + a.F053N as profit_growth, + a.F054N as equity_growth, + a.F056N as asset_growth, + a.F122N as share_capital + FROM ea_financialindex a + WHERE a.SECCODE = :seccode + ORDER BY a.ENDDATE DESC LIMIT 1 + """) + + result = engine.execute(query, seccode=seccode).fetchone() + + if not result: + return jsonify({ + 'success': False, + 'message': f'未找到股票代码 {seccode} 的财务数据' + }), 404 + + # 获取最近的业绩预告 + forecast_query = text(""" + SELECT distinct F001D as report_date, + F003V as forecast_type, + F004V as content, + F007N as profit_lower, + F008N as profit_upper, + F009N as change_lower, + F010N as change_upper + FROM ea_forecast + WHERE SECCODE = :seccode + AND F006C = 'T' + ORDER BY F001D DESC LIMIT 1 + """) + + forecast_result = engine.execute(forecast_query, seccode=seccode).fetchone() + + data = { + 'stock_code': result.SECCODE, + 'stock_name': result.SECNAME, + 'latest_period': format_date(result.ENDDATE), + 'report_type': get_report_type(result.ENDDATE), + 'key_metrics': { + 'eps': format_decimal(result.eps), + 'basic_eps': format_decimal(result.basic_eps), + 'diluted_eps': format_decimal(result.diluted_eps), + 'deducted_eps': format_decimal(result.deducted_eps), + 'bvps': format_decimal(result.bvps), + 'roe': format_decimal(result.roe), + 'roe_weighted': format_decimal(result.roe_weighted), + 'roa': format_decimal(result.roa), + 'gross_margin': format_decimal(result.gross_margin), + 'net_margin': format_decimal(result.net_margin), + }, + 'financial_summary': { + 'revenue': format_decimal(result.revenue), + 'net_profit': format_decimal(result.net_profit), + 'parent_net_profit': format_decimal(result.parent_net_profit), + 'total_assets': format_decimal(result.total_assets), + 'total_liabilities': format_decimal(result.total_liabilities), + 'total_equity': format_decimal(result.total_equity), + 'share_capital': format_decimal(result.share_capital), + }, + 'growth_rates': { + 'revenue_growth': format_decimal(result.revenue_growth), + 'profit_growth': format_decimal(result.profit_growth), + 'equity_growth': format_decimal(result.equity_growth), + 'asset_growth': format_decimal(result.asset_growth), + } + } + + # 添加业绩预告信息 + if forecast_result: + data['latest_forecast'] = { + 'report_date': format_date(forecast_result.report_date), + 'forecast_type': forecast_result.forecast_type, + 'content': forecast_result.content, + 'profit_range': { + 'lower': format_decimal(forecast_result.profit_lower), + 'upper': format_decimal(forecast_result.profit_upper), + }, + 'change_range': { + 'lower': format_decimal(forecast_result.change_lower), + 'upper': format_decimal(forecast_result.change_upper), + } + } + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/balance-sheet/', methods=['GET']) +def get_balance_sheet(seccode): + """获取完整的资产负债表数据""" + try: + limit = request.args.get('limit', 12, type=int) + + query = text(""" + SELECT distinct ENDDATE, + DECLAREDATE, + -- 流动资产 + F006N as cash, -- 货币资金 + F007N as trading_financial_assets, -- 交易性金融资产 + F008N as notes_receivable, -- 应收票据 + F009N as accounts_receivable, -- 应收账款 + F010N as prepayments, -- 预付款项 + F011N as other_receivables, -- 其他应收款 + F013N as interest_receivable, -- 应收利息 + F014N as dividends_receivable, -- 应收股利 + F015N as inventory, -- 存货 + F016N as consumable_biological_assets, -- 消耗性生物资产 + F017N as non_current_assets_due_within_one_year, -- 一年内到期的非流动资产 + F018N as other_current_assets, -- 其他流动资产 + F019N as total_current_assets, -- 流动资产合计 + + -- 非流动资产 + F020N as available_for_sale_financial_assets, -- 可供出售金融资产 + F021N as held_to_maturity_investments, -- 持有至到期投资 + F022N as long_term_receivables, -- 长期应收款 + F023N as long_term_equity_investments, -- 长期股权投资 + F024N as investment_property, -- 投资性房地产 + F025N as fixed_assets, -- 固定资产 + F026N as construction_in_progress, -- 在建工程 + F027N as engineering_materials, -- 工程物资 + F029N as productive_biological_assets, -- 生产性生物资产 + F030N as oil_and_gas_assets, -- 油气资产 + F031N as intangible_assets, -- 无形资产 + F032N as development_expenditure, -- 开发支出 + F033N as goodwill, -- 商誉 + F034N as long_term_deferred_expenses, -- 长期待摊费用 + F035N as deferred_tax_assets, -- 递延所得税资产 + F036N as other_non_current_assets, -- 其他非流动资产 + F037N as total_non_current_assets, -- 非流动资产合计 + F038N as total_assets, -- 资产总计 + + -- 流动负债 + F039N as short_term_borrowings, -- 短期借款 + F040N as trading_financial_liabilities, -- 交易性金融负债 + F041N as notes_payable, -- 应付票据 + F042N as accounts_payable, -- 应付账款 + F043N as advance_receipts, -- 预收款项 + F044N as employee_compensation_payable, -- 应付职工薪酬 + F045N as taxes_payable, -- 应交税费 + F046N as interest_payable, -- 应付利息 + F047N as dividends_payable, -- 应付股利 + F048N as other_payables, -- 其他应付款 + F050N as non_current_liabilities_due_within_one_year, -- 一年内到期的非流动负债 + F051N as other_current_liabilities, -- 其他流动负债 + F052N as total_current_liabilities, -- 流动负债合计 + + -- 非流动负债 + F053N as long_term_borrowings, -- 长期借款 + F054N as bonds_payable, -- 应付债券 + F055N as long_term_payables, -- 长期应付款 + F056N as special_payables, -- 专项应付款 + F057N as estimated_liabilities, -- 预计负债 + F058N as deferred_tax_liabilities, -- 递延所得税负债 + F059N as other_non_current_liabilities, -- 其他非流动负债 + F060N as total_non_current_liabilities, -- 非流动负债合计 + F061N as total_liabilities, -- 负债合计 + + -- 所有者权益 + F062N as share_capital, -- 股本 + F063N as capital_reserve, -- 资本公积 + F064N as surplus_reserve, -- 盈余公积 + F065N as undistributed_profit, -- 未分配利润 + F066N as treasury_stock, -- 库存股 + F067N as minority_interests, -- 少数股东权益 + F070N as total_equity, -- 所有者权益合计 + F071N as total_liabilities_and_equity, -- 负债和所有者权益合计 + F073N as parent_company_equity, -- 归属于母公司所有者权益 + F074N as other_comprehensive_income, -- 其他综合收益 + + -- 新会计准则科目 + F110N as other_debt_investments, -- 其他债权投资 + F111N as other_equity_investments, -- 其他权益工具投资 + F112N as other_non_current_financial_assets, -- 其他非流动金融资产 + F115N as contract_liabilities, -- 合同负债 + F119N as contract_assets, -- 合同资产 + F120N as receivables_financing, -- 应收款项融资 + F121N as right_of_use_assets, -- 使用权资产 + F122N as lease_liabilities -- 租赁负债 + FROM ea_asset + WHERE SECCODE = :seccode + and F002V = '071001' + ORDER BY ENDDATE DESC LIMIT :limit + """) + + result = engine.execute(query, seccode=seccode, limit=limit) + data = [] + + for row in result: + # 安全计算关键比率,避免 Decimal 与 None 运算错误 + def to_float(v): + try: + return float(v) if v is not None else None + except Exception: + return None + + ta = to_float(row.total_assets) + tl = to_float(row.total_liabilities) + tca = to_float(row.total_current_assets) + tcl = to_float(row.total_current_liabilities) + inv = to_float(row.inventory) or 0.0 + + asset_liability_ratio_val = None + if ta is not None and ta != 0 and tl is not None: + asset_liability_ratio_val = (tl / ta) * 100 + + current_ratio_val = None + if tcl is not None and tcl != 0 and tca is not None: + current_ratio_val = tca / tcl + + quick_ratio_val = None + if tcl is not None and tcl != 0 and tca is not None: + quick_ratio_val = (tca - inv) / tcl + + period_data = { + 'period': format_date(row.ENDDATE), + 'declare_date': format_date(row.DECLAREDATE), + 'report_type': get_report_type(row.ENDDATE), + + # 资产部分 + 'assets': { + 'current_assets': { + 'cash': format_decimal(row.cash), + 'trading_financial_assets': format_decimal(row.trading_financial_assets), + 'notes_receivable': format_decimal(row.notes_receivable), + 'accounts_receivable': format_decimal(row.accounts_receivable), + 'prepayments': format_decimal(row.prepayments), + 'other_receivables': format_decimal(row.other_receivables), + 'inventory': format_decimal(row.inventory), + 'contract_assets': format_decimal(row.contract_assets), + 'other_current_assets': format_decimal(row.other_current_assets), + 'total': format_decimal(row.total_current_assets), + }, + 'non_current_assets': { + 'long_term_equity_investments': format_decimal(row.long_term_equity_investments), + 'investment_property': format_decimal(row.investment_property), + 'fixed_assets': format_decimal(row.fixed_assets), + 'construction_in_progress': format_decimal(row.construction_in_progress), + 'intangible_assets': format_decimal(row.intangible_assets), + 'goodwill': format_decimal(row.goodwill), + 'right_of_use_assets': format_decimal(row.right_of_use_assets), + 'deferred_tax_assets': format_decimal(row.deferred_tax_assets), + 'other_non_current_assets': format_decimal(row.other_non_current_assets), + 'total': format_decimal(row.total_non_current_assets), + }, + 'total': format_decimal(row.total_assets), + }, + + # 负债部分 + 'liabilities': { + 'current_liabilities': { + 'short_term_borrowings': format_decimal(row.short_term_borrowings), + 'notes_payable': format_decimal(row.notes_payable), + 'accounts_payable': format_decimal(row.accounts_payable), + 'advance_receipts': format_decimal(row.advance_receipts), + 'contract_liabilities': format_decimal(row.contract_liabilities), + 'employee_compensation_payable': format_decimal(row.employee_compensation_payable), + 'taxes_payable': format_decimal(row.taxes_payable), + 'other_payables': format_decimal(row.other_payables), + 'non_current_liabilities_due_within_one_year': format_decimal( + row.non_current_liabilities_due_within_one_year), + 'total': format_decimal(row.total_current_liabilities), + }, + 'non_current_liabilities': { + 'long_term_borrowings': format_decimal(row.long_term_borrowings), + 'bonds_payable': format_decimal(row.bonds_payable), + 'lease_liabilities': format_decimal(row.lease_liabilities), + 'deferred_tax_liabilities': format_decimal(row.deferred_tax_liabilities), + 'other_non_current_liabilities': format_decimal(row.other_non_current_liabilities), + 'total': format_decimal(row.total_non_current_liabilities), + }, + 'total': format_decimal(row.total_liabilities), + }, + + # 股东权益部分 + 'equity': { + 'share_capital': format_decimal(row.share_capital), + 'capital_reserve': format_decimal(row.capital_reserve), + 'surplus_reserve': format_decimal(row.surplus_reserve), + 'undistributed_profit': format_decimal(row.undistributed_profit), + 'treasury_stock': format_decimal(row.treasury_stock), + 'other_comprehensive_income': format_decimal(row.other_comprehensive_income), + 'parent_company_equity': format_decimal(row.parent_company_equity), + 'minority_interests': format_decimal(row.minority_interests), + 'total': format_decimal(row.total_equity), + }, + + # 关键比率 + 'key_ratios': { + 'asset_liability_ratio': format_decimal(asset_liability_ratio_val), + 'current_ratio': format_decimal(current_ratio_val), + 'quick_ratio': format_decimal(quick_ratio_val), + } + } + data.append(period_data) + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/income-statement/', methods=['GET']) +def get_income_statement(seccode): + """获取完整的利润表数据""" + try: + limit = request.args.get('limit', 12, type=int) + + query = text(""" + SELECT distinct ENDDATE, + STARTDATE, + DECLAREDATE, + -- 营业收入部分 + F006N as revenue, -- 营业收入 + F035N as total_operating_revenue, -- 营业总收入 + F051N as other_income, -- 其他收入 + + -- 营业成本部分 + F007N as cost, -- 营业成本 + F008N as taxes_and_surcharges, -- 税金及附加 + F009N as selling_expenses, -- 销售费用 + F010N as admin_expenses, -- 管理费用 + F056N as rd_expenses, -- 研发费用 + F012N as financial_expenses, -- 财务费用 + F062N as interest_expense, -- 利息费用 + F063N as interest_income, -- 利息收入 + F013N as asset_impairment_loss, -- 资产减值损失(营业总成本) + F057N as credit_impairment_loss, -- 信用减值损失(营业总成本) + F036N as total_operating_cost, -- 营业总成本 + + -- 其他收益 + F014N as fair_value_change_income, -- 公允价值变动净收益 + F015N as investment_income, -- 投资收益 + F016N as investment_income_from_associates, -- 对联营企业和合营企业的投资收益 + F037N as exchange_income, -- 汇兑收益 + F058N as net_exposure_hedging_income, -- 净敞口套期收益 + F059N as asset_disposal_income, -- 资产处置收益 + + -- 利润部分 + F018N as operating_profit, -- 营业利润 + F019N as subsidy_income, -- 补贴收入 + F020N as non_operating_income, -- 营业外收入 + F021N as non_operating_expenses, -- 营业外支出 + F022N as non_current_asset_disposal_loss, -- 非流动资产处置损失 + F024N as total_profit, -- 利润总额 + F025N as income_tax_expense, -- 所得税 + F027N as net_profit, -- 净利润 + F028N as parent_net_profit, -- 归属于母公司所有者的净利润 + F029N as minority_profit, -- 少数股东损益 + + -- 持续经营 + F060N as continuing_operations_net_profit, -- 持续经营净利润 + F061N as discontinued_operations_net_profit, -- 终止经营净利润 + + -- 每股收益 + F031N as basic_eps, -- 基本每股收益 + F032N as diluted_eps, -- 稀释每股收益 + + -- 综合收益 + F038N as other_comprehensive_income_after_tax, -- 其他综合收益的税后净额 + F039N as total_comprehensive_income, -- 综合收益总额 + F040N as parent_company_comprehensive_income, -- 归属于母公司的综合收益 + F041N as minority_comprehensive_income -- 归属于少数股东的综合收益 + FROM ea_profit + WHERE SECCODE = :seccode + and F002V = '071001' + ORDER BY ENDDATE DESC LIMIT :limit + """) + + result = engine.execute(query, seccode=seccode, limit=limit) + data = [] + + for row in result: + # 计算一些衍生指标 + gross_profit = (row.revenue - row.cost) if row.revenue and row.cost else None + gross_margin = (gross_profit / row.revenue * 100) if row.revenue and gross_profit else None + operating_margin = ( + row.operating_profit / row.revenue * 100) if row.revenue and row.operating_profit else None + net_margin = (row.net_profit / row.revenue * 100) if row.revenue and row.net_profit else None + + # 三费合计 + three_expenses = 0 + if row.selling_expenses: + three_expenses += row.selling_expenses + if row.admin_expenses: + three_expenses += row.admin_expenses + if row.financial_expenses: + three_expenses += row.financial_expenses + + # 四费合计(加研发) + four_expenses = three_expenses + if row.rd_expenses: + four_expenses += row.rd_expenses + + period_data = { + 'period': format_date(row.ENDDATE), + 'start_date': format_date(row.STARTDATE), + 'declare_date': format_date(row.DECLAREDATE), + 'report_type': get_report_type(row.ENDDATE), + + # 收入部分 + 'revenue': { + 'operating_revenue': format_decimal(row.revenue), + 'total_operating_revenue': format_decimal(row.total_operating_revenue), + 'other_income': format_decimal(row.other_income), + }, + + # 成本费用部分 + 'costs': { + 'operating_cost': format_decimal(row.cost), + 'taxes_and_surcharges': format_decimal(row.taxes_and_surcharges), + 'selling_expenses': format_decimal(row.selling_expenses), + 'admin_expenses': format_decimal(row.admin_expenses), + 'rd_expenses': format_decimal(row.rd_expenses), + 'financial_expenses': format_decimal(row.financial_expenses), + 'interest_expense': format_decimal(row.interest_expense), + 'interest_income': format_decimal(row.interest_income), + 'asset_impairment_loss': format_decimal(row.asset_impairment_loss), + 'credit_impairment_loss': format_decimal(row.credit_impairment_loss), + 'total_operating_cost': format_decimal(row.total_operating_cost), + 'three_expenses_total': format_decimal(three_expenses), + 'four_expenses_total': format_decimal(four_expenses), + }, + + # 其他收益 + 'other_gains': { + 'fair_value_change': format_decimal(row.fair_value_change_income), + 'investment_income': format_decimal(row.investment_income), + 'investment_income_from_associates': format_decimal(row.investment_income_from_associates), + 'exchange_income': format_decimal(row.exchange_income), + 'asset_disposal_income': format_decimal(row.asset_disposal_income), + }, + + # 利润 + 'profit': { + 'gross_profit': format_decimal(gross_profit), + 'operating_profit': format_decimal(row.operating_profit), + 'total_profit': format_decimal(row.total_profit), + 'net_profit': format_decimal(row.net_profit), + 'parent_net_profit': format_decimal(row.parent_net_profit), + 'minority_profit': format_decimal(row.minority_profit), + 'continuing_operations_net_profit': format_decimal(row.continuing_operations_net_profit), + 'discontinued_operations_net_profit': format_decimal(row.discontinued_operations_net_profit), + }, + + # 非经营项目 + 'non_operating': { + 'subsidy_income': format_decimal(row.subsidy_income), + 'non_operating_income': format_decimal(row.non_operating_income), + 'non_operating_expenses': format_decimal(row.non_operating_expenses), + }, + + # 每股收益 + 'per_share': { + 'basic_eps': format_decimal(row.basic_eps), + 'diluted_eps': format_decimal(row.diluted_eps), + }, + + # 综合收益 + 'comprehensive_income': { + 'other_comprehensive_income': format_decimal(row.other_comprehensive_income_after_tax), + 'total_comprehensive_income': format_decimal(row.total_comprehensive_income), + 'parent_comprehensive_income': format_decimal(row.parent_company_comprehensive_income), + 'minority_comprehensive_income': format_decimal(row.minority_comprehensive_income), + }, + + # 关键比率 + 'margins': { + 'gross_margin': format_decimal(gross_margin), + 'operating_margin': format_decimal(operating_margin), + 'net_margin': format_decimal(net_margin), + 'expense_ratio': format_decimal(four_expenses / row.revenue * 100) if row.revenue else None, + 'rd_ratio': format_decimal( + row.rd_expenses / row.revenue * 100) if row.revenue and row.rd_expenses else None, + } + } + data.append(period_data) + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/cashflow/', methods=['GET']) +def get_cashflow(seccode): + """获取完整的现金流量表数据""" + try: + limit = request.args.get('limit', 12, type=int) + + query = text(""" + SELECT distinct ENDDATE, + STARTDATE, + DECLAREDATE, + -- 经营活动现金流 + F006N as cash_from_sales, -- 销售商品、提供劳务收到的现金 + F007N as tax_refunds, -- 收到的税费返还 + F008N as other_operating_cash_received, -- 收到其他与经营活动有关的现金 + F009N as total_operating_cash_inflow, -- 经营活动现金流入小计 + F010N as cash_paid_for_goods, -- 购买商品、接受劳务支付的现金 + F011N as cash_paid_to_employees, -- 支付给职工以及为职工支付的现金 + F012N as taxes_paid, -- 支付的各项税费 + F013N as other_operating_cash_paid, -- 支付其他与经营活动有关的现金 + F014N as total_operating_cash_outflow, -- 经营活动现金流出小计 + F015N as net_operating_cash_flow, -- 经营活动产生的现金流量净额 + + -- 投资活动现金流 + F016N as cash_from_investment_recovery, -- 收回投资收到的现金 + F017N as cash_from_investment_income, -- 取得投资收益收到的现金 + F018N as cash_from_asset_disposal, -- 处置固定资产、无形资产和其他长期资产收回的现金净额 + F019N as cash_from_subsidiary_disposal, -- 处置子公司及其他营业单位收到的现金净额 + F020N as other_investment_cash_received, -- 收到其他与投资活动有关的现金 + F021N as total_investment_cash_inflow, -- 投资活动现金流入小计 + F022N as cash_paid_for_assets, -- 购建固定资产、无形资产和其他长期资产支付的现金 + F023N as cash_paid_for_investments, -- 投资支付的现金 + F024N as cash_paid_for_subsidiaries, -- 取得子公司及其他营业单位支付的现金净额 + F025N as other_investment_cash_paid, -- 支付其他与投资活动有关的现金 + F026N as total_investment_cash_outflow, -- 投资活动现金流出小计 + F027N as net_investment_cash_flow, -- 投资活动产生的现金流量净额 + + -- 筹资活动现金流 + F028N as cash_from_capital, -- 吸收投资收到的现金 + F029N as cash_from_borrowings, -- 取得借款收到的现金 + F030N as other_financing_cash_received, -- 收到其他与筹资活动有关的现金 + F031N as total_financing_cash_inflow, -- 筹资活动现金流入小计 + F032N as cash_paid_for_debt, -- 偿还债务支付的现金 + F033N as cash_paid_for_distribution, -- 分配股利、利润或偿付利息支付的现金 + F034N as other_financing_cash_paid, -- 支付其他与筹资活动有关的现金 + F035N as total_financing_cash_outflow, -- 筹资活动现金流出小计 + F036N as net_financing_cash_flow, -- 筹资活动产生的现金流量净额 + + -- 汇率变动影响 + F037N as exchange_rate_effect, -- 汇率变动对现金及现金等价物的影响 + F038N as other_cash_effect, -- 其他原因对现金的影响 + + -- 现金净增加额 + F039N as net_cash_increase, -- 现金及现金等价物净增加额 + F040N as beginning_cash_balance, -- 期初现金及现金等价物余额 + F041N as ending_cash_balance, -- 期末现金及现金等价物余额 + + -- 补充资料部分 + F044N as net_profit, -- 净利润 + F045N as asset_impairment, -- 资产减值准备 + F096N as credit_impairment, -- 信用减值损失 + F046N as depreciation, -- 固定资产折旧、油气资产折耗、生产性生物资产折旧 + F097N as right_of_use_asset_depreciation, -- 使用权资产折旧/摊销 + F047N as intangible_amortization, -- 无形资产摊销 + F048N as long_term_expense_amortization, -- 长期待摊费用摊销 + F049N as loss_on_disposal, -- 处置固定资产、无形资产和其他长期资产的损失 + F050N as fixed_asset_scrap_loss, -- 固定资产报废损失 + F051N as fair_value_change_loss, -- 公允价值变动损失 + F052N as financial_expenses, -- 财务费用 + F053N as investment_loss, -- 投资损失 + F054N as deferred_tax_asset_decrease, -- 递延所得税资产减少 + F055N as deferred_tax_liability_increase, -- 递延所得税负债增加 + F056N as inventory_decrease, -- 存货的减少 + F057N as operating_receivables_decrease, -- 经营性应收项目的减少 + F058N as operating_payables_increase, -- 经营性应付项目的增加 + F059N as other, -- 其他 + F060N as net_operating_cash_flow_indirect, -- 经营活动产生的现金流量净额(间接法) + + -- 特殊行业科目(金融) + F072N as customer_deposit_increase, -- 客户存款和同业存放款项净增加额 + F073N as central_bank_borrowing_increase, -- 向中央银行借款净增加额 + F081N as interest_and_commission_received, -- 收取利息、手续费及佣金的现金 + F087N as interest_and_commission_paid -- 支付利息、手续费及佣金的现金 + FROM ea_cashflow + WHERE SECCODE = :seccode + and F002V = '071001' + ORDER BY ENDDATE DESC LIMIT :limit + """) + + result = engine.execute(query, seccode=seccode, limit=limit) + data = [] + + for row in result: + # 计算一些衍生指标 + free_cash_flow = None + if row.net_operating_cash_flow and row.cash_paid_for_assets: + free_cash_flow = row.net_operating_cash_flow - row.cash_paid_for_assets + + period_data = { + 'period': format_date(row.ENDDATE), + 'start_date': format_date(row.STARTDATE), + 'declare_date': format_date(row.DECLAREDATE), + 'report_type': get_report_type(row.ENDDATE), + + # 经营活动现金流 + 'operating_activities': { + 'inflow': { + 'cash_from_sales': format_decimal(row.cash_from_sales), + 'tax_refunds': format_decimal(row.tax_refunds), + 'other': format_decimal(row.other_operating_cash_received), + 'total': format_decimal(row.total_operating_cash_inflow), + }, + 'outflow': { + 'cash_for_goods': format_decimal(row.cash_paid_for_goods), + 'cash_for_employees': format_decimal(row.cash_paid_to_employees), + 'taxes_paid': format_decimal(row.taxes_paid), + 'other': format_decimal(row.other_operating_cash_paid), + 'total': format_decimal(row.total_operating_cash_outflow), + }, + 'net_flow': format_decimal(row.net_operating_cash_flow), + }, + + # 投资活动现金流 + 'investment_activities': { + 'inflow': { + 'investment_recovery': format_decimal(row.cash_from_investment_recovery), + 'investment_income': format_decimal(row.cash_from_investment_income), + 'asset_disposal': format_decimal(row.cash_from_asset_disposal), + 'subsidiary_disposal': format_decimal(row.cash_from_subsidiary_disposal), + 'other': format_decimal(row.other_investment_cash_received), + 'total': format_decimal(row.total_investment_cash_inflow), + }, + 'outflow': { + 'asset_purchase': format_decimal(row.cash_paid_for_assets), + 'investments': format_decimal(row.cash_paid_for_investments), + 'subsidiaries': format_decimal(row.cash_paid_for_subsidiaries), + 'other': format_decimal(row.other_investment_cash_paid), + 'total': format_decimal(row.total_investment_cash_outflow), + }, + 'net_flow': format_decimal(row.net_investment_cash_flow), + }, + + # 筹资活动现金流 + 'financing_activities': { + 'inflow': { + 'capital': format_decimal(row.cash_from_capital), + 'borrowings': format_decimal(row.cash_from_borrowings), + 'other': format_decimal(row.other_financing_cash_received), + 'total': format_decimal(row.total_financing_cash_inflow), + }, + 'outflow': { + 'debt_repayment': format_decimal(row.cash_paid_for_debt), + 'distribution': format_decimal(row.cash_paid_for_distribution), + 'other': format_decimal(row.other_financing_cash_paid), + 'total': format_decimal(row.total_financing_cash_outflow), + }, + 'net_flow': format_decimal(row.net_financing_cash_flow), + }, + + # 现金变动 + 'cash_changes': { + 'exchange_rate_effect': format_decimal(row.exchange_rate_effect), + 'other_effect': format_decimal(row.other_cash_effect), + 'net_increase': format_decimal(row.net_cash_increase), + 'beginning_balance': format_decimal(row.beginning_cash_balance), + 'ending_balance': format_decimal(row.ending_cash_balance), + }, + + # 补充资料(间接法) + 'indirect_method': { + 'net_profit': format_decimal(row.net_profit), + 'adjustments': { + 'asset_impairment': format_decimal(row.asset_impairment), + 'credit_impairment': format_decimal(row.credit_impairment), + 'depreciation': format_decimal(row.depreciation), + 'intangible_amortization': format_decimal(row.intangible_amortization), + 'financial_expenses': format_decimal(row.financial_expenses), + 'investment_loss': format_decimal(row.investment_loss), + 'inventory_decrease': format_decimal(row.inventory_decrease), + 'receivables_decrease': format_decimal(row.operating_receivables_decrease), + 'payables_increase': format_decimal(row.operating_payables_increase), + }, + 'net_operating_cash_flow': format_decimal(row.net_operating_cash_flow_indirect), + }, + + # 关键指标 + 'key_metrics': { + 'free_cash_flow': format_decimal(free_cash_flow), + 'cash_flow_to_profit_ratio': format_decimal( + row.net_operating_cash_flow / row.net_profit) if row.net_profit and row.net_operating_cash_flow else None, + 'capex': format_decimal(row.cash_paid_for_assets), + } + } + data.append(period_data) + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/financial-metrics/', methods=['GET']) +def get_financial_metrics(seccode): + """获取完整的财务指标数据""" + try: + limit = request.args.get('limit', 12, type=int) + + query = text(""" + SELECT distinct ENDDATE, + STARTDATE, + -- 每股指标 + F003N as eps, -- 每股收益 + F004N as basic_eps, -- 基本每股收益 + F005N as diluted_eps, -- 稀释每股收益 + F006N as deducted_eps, -- 扣除非经常性损益每股收益 + F007N as undistributed_profit_ps, -- 每股未分配利润 + F008N as bvps, -- 每股净资产 + F009N as adjusted_bvps, -- 调整后每股净资产 + F010N as capital_reserve_ps, -- 每股资本公积金 + F059N as cash_flow_ps, -- 每股现金流量 + F060N as operating_cash_flow_ps, -- 每股经营现金流量 + + -- 盈利能力指标 + F011N as operating_profit_margin, -- 营业利润率 + F012N as tax_rate, -- 营业税金率 + F013N as cost_ratio, -- 营业成本率 + F014N as roe, -- 净资产收益率 + F066N as roe_deducted, -- 净资产收益率(扣除非经常性损益) + F067N as roe_weighted, -- 净资产收益率-加权 + F068N as roe_weighted_deducted, -- 净资产收益率-加权(扣除非经常性损益) + F015N as investment_return, -- 投资收益率 + F016N as roa, -- 总资产报酬率 + F017N as net_profit_margin, -- 净利润率 + F078N as gross_margin, -- 毛利率 + F020N as cost_profit_ratio, -- 成本费用利润率 + + -- 费用率指标 + F018N as admin_expense_ratio, -- 管理费用率 + F019N as financial_expense_ratio, -- 财务费用率 + F021N as three_expense_ratio, -- 三费比重 + F091N as selling_expense, -- 销售费用 + F092N as admin_expense, -- 管理费用 + F093N as financial_expense, -- 财务费用 + F094N as three_expense_total, -- 三费合计 + F130N as rd_expense, -- 研发费用 + F131N as rd_expense_ratio, -- 研发费用率 + F132N as selling_expense_ratio, -- 销售费用率 + F133N as four_expense_ratio, -- 四费费用率 + + -- 运营能力指标 + F022N as receivable_turnover, -- 应收账款周转率 + F023N as inventory_turnover, -- 存货周转率 + F024N as working_capital_turnover, -- 运营资金周转率 + F025N as total_asset_turnover, -- 总资产周转率 + F026N as fixed_asset_turnover, -- 固定资产周转率 + F027N as receivable_days, -- 应收账款周转天数 + F028N as inventory_days, -- 存货周转天数 + F029N as current_asset_turnover, -- 流动资产周转率 + F030N as current_asset_days, -- 流动资产周转天数 + F031N as total_asset_days, -- 总资产周转天数 + F032N as equity_turnover, -- 股东权益周转率 + + -- 偿债能力指标 + F041N as asset_liability_ratio, -- 资产负债率 + F042N as current_ratio, -- 流动比率 + F043N as quick_ratio, -- 速动比率 + F044N as cash_ratio, -- 现金比率 + F045N as interest_coverage, -- 利息保障倍数 + F049N as conservative_quick_ratio, -- 保守速动比率 + F050N as cash_to_maturity_debt_ratio, -- 现金到期债务比率 + F051N as tangible_asset_debt_ratio, -- 有形资产净值债务率 + + -- 成长能力指标 + F052N as revenue_growth, -- 营业收入增长率 + F053N as net_profit_growth, -- 净利润增长率 + F054N as equity_growth, -- 净资产增长率 + F055N as fixed_asset_growth, -- 固定资产增长率 + F056N as total_asset_growth, -- 总资产增长率 + F057N as investment_income_growth, -- 投资收益增长率 + F058N as operating_profit_growth, -- 营业利润增长率 + F141N as deducted_profit_growth, -- 扣除非经常性损益后的净利润同比变化率 + F142N as parent_profit_growth, -- 归属于母公司所有者的净利润同比变化率 + F143N as operating_cash_flow_growth, -- 经营活动产生的现金流净额同比变化率 + + -- 现金流量指标 + F061N as operating_cash_to_short_debt, -- 经营净现金比率(短期债务) + F062N as operating_cash_to_total_debt, -- 经营净现金比率(全部债务) + F063N as operating_cash_to_profit_ratio, -- 经营活动现金净流量与净利润比率 + F064N as cash_revenue_ratio, -- 营业收入现金含量 + F065N as cash_recovery_rate, -- 全部资产现金回收率 + F082N as cash_to_profit_ratio, -- 净利含金量 + + -- 财务结构指标 + F033N as current_asset_ratio, -- 流动资产比率 + F034N as cash_ratio_structure, -- 货币资金比率 + F036N as inventory_ratio, -- 存货比率 + F037N as fixed_asset_ratio, -- 固定资产比率 + F038N as liability_structure_ratio, -- 负债结构比 + F039N as equity_ratio, -- 产权比率 + F040N as net_asset_ratio, -- 净资产比率 + F046N as working_capital, -- 营运资金 + F047N as non_current_liability_ratio, -- 非流动负债比率 + F048N as current_liability_ratio, -- 流动负债比率 + + -- 非经常性损益 + F076N as deducted_net_profit, -- 扣除非经常性损益后的净利润 + F077N as non_recurring_items, -- 非经常性损益合计 + F083N as non_recurring_ratio, -- 非经常性损益占比 + + -- 综合指标 + F085N as ebit, -- 基本获利能力(EBIT) + F086N as receivable_to_asset_ratio, -- 应收账款占比 + F087N as inventory_to_asset_ratio -- 存货占比 + FROM ea_financialindex + WHERE SECCODE = :seccode + ORDER BY ENDDATE DESC LIMIT :limit + """) + + result = engine.execute(query, seccode=seccode, limit=limit) + data = [] + + for row in result: + period_data = { + 'period': format_date(row.ENDDATE), + 'start_date': format_date(row.STARTDATE), + 'report_type': get_report_type(row.ENDDATE), + + # 每股指标 + 'per_share_metrics': { + 'eps': format_decimal(row.eps), + 'basic_eps': format_decimal(row.basic_eps), + 'diluted_eps': format_decimal(row.diluted_eps), + 'deducted_eps': format_decimal(row.deducted_eps), + 'bvps': format_decimal(row.bvps), + 'adjusted_bvps': format_decimal(row.adjusted_bvps), + 'undistributed_profit_ps': format_decimal(row.undistributed_profit_ps), + 'capital_reserve_ps': format_decimal(row.capital_reserve_ps), + 'cash_flow_ps': format_decimal(row.cash_flow_ps), + 'operating_cash_flow_ps': format_decimal(row.operating_cash_flow_ps), + }, + + # 盈利能力 + 'profitability': { + 'roe': format_decimal(row.roe), + 'roe_deducted': format_decimal(row.roe_deducted), + 'roe_weighted': format_decimal(row.roe_weighted), + 'roa': format_decimal(row.roa), + 'gross_margin': format_decimal(row.gross_margin), + 'net_profit_margin': format_decimal(row.net_profit_margin), + 'operating_profit_margin': format_decimal(row.operating_profit_margin), + 'cost_profit_ratio': format_decimal(row.cost_profit_ratio), + 'ebit': format_decimal(row.ebit), + }, + + # 费用率 + 'expense_ratios': { + 'selling_expense_ratio': format_decimal(row.selling_expense_ratio), + 'admin_expense_ratio': format_decimal(row.admin_expense_ratio), + 'financial_expense_ratio': format_decimal(row.financial_expense_ratio), + 'rd_expense_ratio': format_decimal(row.rd_expense_ratio), + 'three_expense_ratio': format_decimal(row.three_expense_ratio), + 'four_expense_ratio': format_decimal(row.four_expense_ratio), + }, + + # 运营能力 + 'operational_efficiency': { + 'receivable_turnover': format_decimal(row.receivable_turnover), + 'receivable_days': format_decimal(row.receivable_days), + 'inventory_turnover': format_decimal(row.inventory_turnover), + 'inventory_days': format_decimal(row.inventory_days), + 'total_asset_turnover': format_decimal(row.total_asset_turnover), + 'total_asset_days': format_decimal(row.total_asset_days), + 'fixed_asset_turnover': format_decimal(row.fixed_asset_turnover), + 'current_asset_turnover': format_decimal(row.current_asset_turnover), + 'working_capital_turnover': format_decimal(row.working_capital_turnover), + }, + + # 偿债能力 + 'solvency': { + 'current_ratio': format_decimal(row.current_ratio), + 'quick_ratio': format_decimal(row.quick_ratio), + 'cash_ratio': format_decimal(row.cash_ratio), + 'conservative_quick_ratio': format_decimal(row.conservative_quick_ratio), + 'asset_liability_ratio': format_decimal(row.asset_liability_ratio), + 'interest_coverage': format_decimal(row.interest_coverage), + 'cash_to_maturity_debt_ratio': format_decimal(row.cash_to_maturity_debt_ratio), + 'tangible_asset_debt_ratio': format_decimal(row.tangible_asset_debt_ratio), + }, + + # 成长能力 + 'growth': { + 'revenue_growth': format_decimal(row.revenue_growth), + 'net_profit_growth': format_decimal(row.net_profit_growth), + 'deducted_profit_growth': format_decimal(row.deducted_profit_growth), + 'parent_profit_growth': format_decimal(row.parent_profit_growth), + 'equity_growth': format_decimal(row.equity_growth), + 'total_asset_growth': format_decimal(row.total_asset_growth), + 'fixed_asset_growth': format_decimal(row.fixed_asset_growth), + 'operating_profit_growth': format_decimal(row.operating_profit_growth), + 'operating_cash_flow_growth': format_decimal(row.operating_cash_flow_growth), + }, + + # 现金流量 + 'cash_flow_quality': { + 'operating_cash_to_profit_ratio': format_decimal(row.operating_cash_to_profit_ratio), + 'cash_to_profit_ratio': format_decimal(row.cash_to_profit_ratio), + 'cash_revenue_ratio': format_decimal(row.cash_revenue_ratio), + 'cash_recovery_rate': format_decimal(row.cash_recovery_rate), + 'operating_cash_to_short_debt': format_decimal(row.operating_cash_to_short_debt), + 'operating_cash_to_total_debt': format_decimal(row.operating_cash_to_total_debt), + }, + + # 财务结构 + 'financial_structure': { + 'current_asset_ratio': format_decimal(row.current_asset_ratio), + 'fixed_asset_ratio': format_decimal(row.fixed_asset_ratio), + 'inventory_ratio': format_decimal(row.inventory_ratio), + 'receivable_to_asset_ratio': format_decimal(row.receivable_to_asset_ratio), + 'current_liability_ratio': format_decimal(row.current_liability_ratio), + 'non_current_liability_ratio': format_decimal(row.non_current_liability_ratio), + 'equity_ratio': format_decimal(row.equity_ratio), + }, + + # 非经常性损益 + 'non_recurring': { + 'deducted_net_profit': format_decimal(row.deducted_net_profit), + 'non_recurring_items': format_decimal(row.non_recurring_items), + 'non_recurring_ratio': format_decimal(row.non_recurring_ratio), + } + } + data.append(period_data) + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/main-business/', methods=['GET']) +def get_main_business(seccode): + """获取主营业务构成数据(包括产品和行业分类)""" + try: + limit = request.args.get('periods', 4, type=int) # 获取最近几期的数据 + + # 获取最近的报告期 + period_query = text(""" + SELECT DISTINCT ENDDATE + FROM ea_mainproduct + WHERE SECCODE = :seccode + ORDER BY ENDDATE DESC LIMIT :limit + """) + + periods = engine.execute(period_query, seccode=seccode, limit=limit).fetchall() + + # 产品分类数据 + product_data = [] + for period in periods: + query = text(""" + SELECT distinct ENDDATE, + F002V as category, + F003V as content, + F005N as revenue, + F006N as cost, + F007N as profit + FROM ea_mainproduct + WHERE SECCODE = :seccode + AND ENDDATE = :enddate + ORDER BY F005N DESC + """) + + result = engine.execute(query, seccode=seccode, enddate=period[0]) + # Convert result to list to allow multiple iterations + rows = list(result) + + period_products = [] + total_revenue = 0 + for row in rows: + if row.revenue: + total_revenue += row.revenue + + for row in rows: + product = { + 'category': row.category, + 'content': row.content, + 'revenue': format_decimal(row.revenue), + 'cost': format_decimal(row.cost), + 'profit': format_decimal(row.profit), + 'profit_margin': format_decimal( + (row.profit / row.revenue * 100) if row.revenue and row.profit else None), + 'revenue_ratio': format_decimal( + (row.revenue / total_revenue * 100) if total_revenue and row.revenue else None) + } + period_products.append(product) + + if period_products: + product_data.append({ + 'period': format_date(period[0]), + 'report_type': get_report_type(period[0]), + 'total_revenue': format_decimal(total_revenue), + 'products': period_products + }) + + # 行业分类数据(从ea_mainind表) + industry_data = [] + for period in periods: + query = text(""" + SELECT distinct ENDDATE, + F002V as business_content, + F007N as main_revenue, + F008N as main_cost, + F009N as main_profit, + F010N as gross_margin, + F012N as revenue_ratio + FROM ea_mainind + WHERE SECCODE = :seccode + AND ENDDATE = :enddate + ORDER BY F007N DESC + """) + + result = engine.execute(query, seccode=seccode, enddate=period[0]) + # Convert result to list to allow multiple iterations + rows = list(result) + + period_industries = [] + for row in rows: + industry = { + 'content': row.business_content, + 'revenue': format_decimal(row.main_revenue), + 'cost': format_decimal(row.main_cost), + 'profit': format_decimal(row.main_profit), + 'gross_margin': format_decimal(row.gross_margin), + 'revenue_ratio': format_decimal(row.revenue_ratio) + } + period_industries.append(industry) + + if period_industries: + industry_data.append({ + 'period': format_date(period[0]), + 'report_type': get_report_type(period[0]), + 'industries': period_industries + }) + + return jsonify({ + 'success': True, + 'data': { + 'product_classification': product_data, + 'industry_classification': industry_data + } + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/forecast/', methods=['GET']) +def get_forecast(seccode): + """获取业绩预告和预披露时间""" + try: + # 获取业绩预告 + forecast_query = text(""" + SELECT distinct DECLAREDATE, + F001D as report_date, + F002V as forecast_type_code, + F003V as forecast_type, + F004V as content, + F005V as reason, + F006C as latest_flag, + F007N as profit_lower, + F008N as profit_upper, + F009N as change_lower, + F010N as change_upper, + UPDATE_DATE + FROM ea_forecast + WHERE SECCODE = :seccode + ORDER BY F001D DESC, UPDATE_DATE DESC LIMIT 10 + """) + + forecast_result = engine.execute(forecast_query, seccode=seccode) + forecast_data = [] + + for row in forecast_result: + forecast = { + 'declare_date': format_date(row.DECLAREDATE), + 'report_date': format_date(row.report_date), + 'report_type': get_report_type(row.report_date), + 'forecast_type': row.forecast_type, + 'forecast_type_code': row.forecast_type_code, + 'content': row.content, + 'reason': row.reason, + 'is_latest': row.latest_flag == 'T', + 'profit_range': { + 'lower': format_decimal(row.profit_lower), + 'upper': format_decimal(row.profit_upper), + }, + 'change_range': { + 'lower': format_decimal(row.change_lower), + 'upper': format_decimal(row.change_upper), + }, + 'update_date': format_date(row.UPDATE_DATE) + } + forecast_data.append(forecast) + + # 获取预披露时间 + pretime_query = text(""" + SELECT distinct F001D as report_period, + F002D as scheduled_date, + F003D as change_date_1, + F004D as change_date_2, + F005D as change_date_3, + F006D as actual_date, + F007D as change_date_4, + F008D as change_date_5, + UPDATE_DATE + FROM ea_pretime + WHERE SECCODE = :seccode + ORDER BY F001D DESC LIMIT 8 + """) + + pretime_result = engine.execute(pretime_query, seccode=seccode) + pretime_data = [] + + for row in pretime_result: + # 收集所有变更日期 + change_dates = [] + for date in [row.change_date_1, row.change_date_2, row.change_date_3, + row.change_date_4, row.change_date_5]: + if date: + change_dates.append(format_date(date)) + + pretime = { + 'report_period': format_date(row.report_period), + 'report_type': get_report_type(row.report_period), + 'scheduled_date': format_date(row.scheduled_date), + 'actual_date': format_date(row.actual_date), + 'change_dates': change_dates, + 'update_date': format_date(row.UPDATE_DATE), + 'status': 'completed' if row.actual_date else 'pending' + } + pretime_data.append(pretime) + + return jsonify({ + 'success': True, + 'data': { + 'forecasts': forecast_data, + 'disclosure_schedule': pretime_data + } + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/industry-rank/', methods=['GET']) +def get_industry_rank(seccode): + """获取行业排名数据""" + try: + limit = request.args.get('limit', 4, type=int) + + query = text(""" + SELECT distinct F001V as industry_level, + F002V as level_description, + F003D as report_date, + INDNAME as industry_name, + -- 每股收益 + F004N as eps, + F005N as eps_industry_avg, + F006N as eps_rank, + -- 扣除后每股收益 + F007N as deducted_eps, + F008N as deducted_eps_industry_avg, + F009N as deducted_eps_rank, + -- 每股净资产 + F010N as bvps, + F011N as bvps_industry_avg, + F012N as bvps_rank, + -- 净资产收益率 + F013N as roe, + F014N as roe_industry_avg, + F015N as roe_rank, + -- 每股未分配利润 + F016N as undistributed_profit_ps, + F017N as undistributed_profit_ps_industry_avg, + F018N as undistributed_profit_ps_rank, + -- 每股经营现金流量 + F019N as operating_cash_flow_ps, + F020N as operating_cash_flow_ps_industry_avg, + F021N as operating_cash_flow_ps_rank, + -- 营业收入增长率 + F022N as revenue_growth, + F023N as revenue_growth_industry_avg, + F024N as revenue_growth_rank, + -- 净利润增长率 + F025N as profit_growth, + F026N as profit_growth_industry_avg, + F027N as profit_growth_rank, + -- 营业利润率 + F028N as operating_margin, + F029N as operating_margin_industry_avg, + F030N as operating_margin_rank, + -- 资产负债率 + F031N as debt_ratio, + F032N as debt_ratio_industry_avg, + F033N as debt_ratio_rank, + -- 应收账款周转率 + F034N as receivable_turnover, + F035N as receivable_turnover_industry_avg, + F036N as receivable_turnover_rank, + UPDATE_DATE + FROM ea_finindexrank + WHERE SECCODE = :seccode + ORDER BY F003D DESC, F001V ASC LIMIT :limit_total + """) + + # 获取多个报告期的数据 + result = engine.execute(query, seccode=seccode, limit_total=limit * 4) + + # 按报告期和行业级别组织数据 + data_by_period = {} + for row in result: + period = format_date(row.report_date) + if period not in data_by_period: + data_by_period[period] = [] + + rank_data = { + 'industry_level': row.industry_level, + 'level_description': row.level_description, + 'industry_name': row.industry_name, + 'metrics': { + 'eps': { + 'value': format_decimal(row.eps), + 'industry_avg': format_decimal(row.eps_industry_avg), + 'rank': int(row.eps_rank) if row.eps_rank else None + }, + 'deducted_eps': { + 'value': format_decimal(row.deducted_eps), + 'industry_avg': format_decimal(row.deducted_eps_industry_avg), + 'rank': int(row.deducted_eps_rank) if row.deducted_eps_rank else None + }, + 'bvps': { + 'value': format_decimal(row.bvps), + 'industry_avg': format_decimal(row.bvps_industry_avg), + 'rank': int(row.bvps_rank) if row.bvps_rank else None + }, + 'roe': { + 'value': format_decimal(row.roe), + 'industry_avg': format_decimal(row.roe_industry_avg), + 'rank': int(row.roe_rank) if row.roe_rank else None + }, + 'operating_cash_flow_ps': { + 'value': format_decimal(row.operating_cash_flow_ps), + 'industry_avg': format_decimal(row.operating_cash_flow_ps_industry_avg), + 'rank': int(row.operating_cash_flow_ps_rank) if row.operating_cash_flow_ps_rank else None + }, + 'revenue_growth': { + 'value': format_decimal(row.revenue_growth), + 'industry_avg': format_decimal(row.revenue_growth_industry_avg), + 'rank': int(row.revenue_growth_rank) if row.revenue_growth_rank else None + }, + 'profit_growth': { + 'value': format_decimal(row.profit_growth), + 'industry_avg': format_decimal(row.profit_growth_industry_avg), + 'rank': int(row.profit_growth_rank) if row.profit_growth_rank else None + }, + 'operating_margin': { + 'value': format_decimal(row.operating_margin), + 'industry_avg': format_decimal(row.operating_margin_industry_avg), + 'rank': int(row.operating_margin_rank) if row.operating_margin_rank else None + }, + 'debt_ratio': { + 'value': format_decimal(row.debt_ratio), + 'industry_avg': format_decimal(row.debt_ratio_industry_avg), + 'rank': int(row.debt_ratio_rank) if row.debt_ratio_rank else None + }, + 'receivable_turnover': { + 'value': format_decimal(row.receivable_turnover), + 'industry_avg': format_decimal(row.receivable_turnover_industry_avg), + 'rank': int(row.receivable_turnover_rank) if row.receivable_turnover_rank else None + } + } + } + data_by_period[period].append(rank_data) + + # 转换为列表格式 + data = [] + for period, ranks in data_by_period.items(): + data.append({ + 'period': period, + 'report_type': get_report_type(period), + 'rankings': ranks + }) + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/comparison/', methods=['GET']) +def get_period_comparison(seccode): + """获取不同报告期的对比数据""" + try: + periods = request.args.get('periods', 8, type=int) + + # 获取多期财务数据进行对比 + query = text(""" + SELECT distinct fi.ENDDATE, + fi.F089N as revenue, + fi.F101N as net_profit, + fi.F102N as parent_net_profit, + fi.F078N as gross_margin, + fi.F017N as net_margin, + fi.F014N as roe, + fi.F016N as roa, + fi.F052N as revenue_growth, + fi.F053N as profit_growth, + fi.F003N as eps, + fi.F060N as operating_cash_flow_ps, + fi.F042N as current_ratio, + fi.F041N as debt_ratio, + fi.F105N as operating_cash_flow, + fi.F118N as total_assets, + fi.F121N as total_liabilities, + fi.F128N as total_equity + FROM ea_financialindex fi + WHERE fi.SECCODE = :seccode + ORDER BY fi.ENDDATE DESC LIMIT :periods + """) + + result = engine.execute(query, seccode=seccode, periods=periods) + + data = [] + for row in result: + period_data = { + 'period': format_date(row.ENDDATE), + 'report_type': get_report_type(row.ENDDATE), + 'performance': { + 'revenue': format_decimal(row.revenue), + 'net_profit': format_decimal(row.net_profit), + 'parent_net_profit': format_decimal(row.parent_net_profit), + 'operating_cash_flow': format_decimal(row.operating_cash_flow), + }, + 'profitability': { + 'gross_margin': format_decimal(row.gross_margin), + 'net_margin': format_decimal(row.net_margin), + 'roe': format_decimal(row.roe), + 'roa': format_decimal(row.roa), + }, + 'growth': { + 'revenue_growth': format_decimal(row.revenue_growth), + 'profit_growth': format_decimal(row.profit_growth), + }, + 'per_share': { + 'eps': format_decimal(row.eps), + 'operating_cash_flow_ps': format_decimal(row.operating_cash_flow_ps), + }, + 'financial_health': { + 'current_ratio': format_decimal(row.current_ratio), + 'debt_ratio': format_decimal(row.debt_ratio), + 'total_assets': format_decimal(row.total_assets), + 'total_liabilities': format_decimal(row.total_liabilities), + 'total_equity': format_decimal(row.total_equity), + } + } + data.append(period_data) + + # 计算同比和环比变化 + for i in range(len(data)): + if i > 0: # 环比 + data[i]['qoq_changes'] = { + 'revenue': calculate_change(data[i]['performance']['revenue'], + data[i - 1]['performance']['revenue']), + 'net_profit': calculate_change(data[i]['performance']['net_profit'], + data[i - 1]['performance']['net_profit']), + } + + # 同比(找到去年同期) + current_period = data[i]['period'] + yoy_period = get_yoy_period(current_period) + for j in range(len(data)): + if data[j]['period'] == yoy_period: + data[i]['yoy_changes'] = { + 'revenue': calculate_change(data[i]['performance']['revenue'], + data[j]['performance']['revenue']), + 'net_profit': calculate_change(data[i]['performance']['net_profit'], + data[j]['performance']['net_profit']), + } + break + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +# 辅助函数 +def calculate_change(current, previous): + """计算变化率""" + if previous and current: + return format_decimal((current - previous) / abs(previous) * 100) + return None + + +def get_yoy_period(date_str): + """获取去年同期""" + if not date_str: + return None + try: + date = datetime.strptime(date_str, '%Y-%m-%d') + yoy_date = date.replace(year=date.year - 1) + return yoy_date.strftime('%Y-%m-%d') + except: + return None + + +@app.route('/api/market/trade/', methods=['GET']) +def get_trade_data(seccode): + """获取股票交易数据(日K线)""" + try: + days = request.args.get('days', 60, type=int) + end_date = request.args.get('end_date', datetime.now().strftime('%Y-%m-%d')) + + query = text(""" + SELECT TRADEDATE, + SECNAME, + F002N as pre_close, + F003N as open, + F004N as volume, + F005N as high, + F006N as low, + F007N as close, + F008N as trades_count, + F009N as change_amount, + F010N as change_percent, + F011N as amount, + F012N as turnover_rate, + F013N as amplitude, + F020N as total_shares, + F021N as float_shares, + F026N as pe_ratio + FROM ea_trade + WHERE SECCODE = :seccode + AND TRADEDATE <= :end_date + ORDER BY TRADEDATE DESC + LIMIT :days + """) + + result = engine.execute(query, seccode=seccode, end_date=end_date, days=days) + + data = [] + for row in result: + data.append({ + 'date': format_date(row.TRADEDATE), + 'stock_name': row.SECNAME, + 'open': format_decimal(row.open), + 'high': format_decimal(row.high), + 'low': format_decimal(row.low), + 'close': format_decimal(row.close), + 'pre_close': format_decimal(row.pre_close), + 'volume': format_decimal(row.volume), + 'amount': format_decimal(row.amount), + 'change_amount': format_decimal(row.change_amount), + 'change_percent': format_decimal(row.change_percent), + 'turnover_rate': format_decimal(row.turnover_rate), + 'amplitude': format_decimal(row.amplitude), + 'trades_count': format_decimal(row.trades_count), + 'pe_ratio': format_decimal(row.pe_ratio), + 'total_shares': format_decimal(row.total_shares), + 'float_shares': format_decimal(row.float_shares), + }) + + # 倒序,让最早的日期在前 + data.reverse() + + # 计算统计数据 + if data: + prices = [d['close'] for d in data if d['close']] + stats = { + 'highest': max(prices) if prices else None, + 'lowest': min(prices) if prices else None, + 'average': sum(prices) / len(prices) if prices else None, + 'latest_price': data[-1]['close'] if data else None, + 'total_volume': sum([d['volume'] for d in data if d['volume']]) if data else None, + 'total_amount': sum([d['amount'] for d in data if d['amount']]) if data else None, + } + else: + stats = {} + + return jsonify({ + 'success': True, + 'data': data, + 'stats': stats + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/funding/', methods=['GET']) +def get_funding_data(seccode): + """获取融资融券数据""" + try: + days = request.args.get('days', 30, type=int) + + query = text(""" + SELECT TRADEDATE, + SECNAME, + F001N as financing_balance, + F002N as financing_buy, + F003N as financing_repay, + F004N as securities_balance, + F006N as securities_sell, + F007N as securities_repay, + F008N as securities_balance_amount, + F009N as total_balance + FROM ea_funding + WHERE SECCODE = :seccode + ORDER BY TRADEDATE DESC LIMIT :days + """) + + result = engine.execute(query, seccode=seccode, days=days) + + data = [] + for row in result: + data.append({ + 'date': format_date(row.TRADEDATE), + 'stock_name': row.SECNAME, + 'financing': { + 'balance': format_decimal(row.financing_balance), + 'buy': format_decimal(row.financing_buy), + 'repay': format_decimal(row.financing_repay), + 'net': format_decimal( + row.financing_buy - row.financing_repay) if row.financing_buy and row.financing_repay else None + }, + 'securities': { + 'balance': format_decimal(row.securities_balance), + 'sell': format_decimal(row.securities_sell), + 'repay': format_decimal(row.securities_repay), + 'balance_amount': format_decimal(row.securities_balance_amount) + }, + 'total_balance': format_decimal(row.total_balance) + }) + + data.reverse() + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/bigdeal/', methods=['GET']) +def get_bigdeal_data(seccode): + """获取大宗交易数据""" + try: + days = request.args.get('days', 30, type=int) + + query = text(""" + SELECT TRADEDATE, + SECNAME, + F001V as exchange, + F002V as buyer_dept, + F003V as seller_dept, + F004N as price, + F005N as volume, + F006N as amount, + F007N as seq_no + FROM ea_bigdeal + WHERE SECCODE = :seccode + ORDER BY TRADEDATE DESC, F007N LIMIT :days + """) + + result = engine.execute(query, seccode=seccode, days=days) + + data = [] + for row in result: + data.append({ + 'date': format_date(row.TRADEDATE), + 'stock_name': row.SECNAME, + 'exchange': row.exchange, + 'buyer_dept': row.buyer_dept, + 'seller_dept': row.seller_dept, + 'price': format_decimal(row.price), + 'volume': format_decimal(row.volume), + 'amount': format_decimal(row.amount), + 'seq_no': int(row.seq_no) if row.seq_no else None + }) + + # 按日期分组统计 + daily_stats = {} + for item in data: + date = item['date'] + if date not in daily_stats: + daily_stats[date] = { + 'date': date, + 'count': 0, + 'total_volume': 0, + 'total_amount': 0, + 'avg_price': 0, + 'deals': [] + } + daily_stats[date]['count'] += 1 + daily_stats[date]['total_volume'] += item['volume'] or 0 + daily_stats[date]['total_amount'] += item['amount'] or 0 + daily_stats[date]['deals'].append(item) + + # 计算平均价格 + for date in daily_stats: + if daily_stats[date]['total_volume'] > 0: + daily_stats[date]['avg_price'] = daily_stats[date]['total_amount'] / daily_stats[date]['total_volume'] + + return jsonify({ + 'success': True, + 'data': data, + 'daily_stats': list(daily_stats.values()) + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/unusual/', methods=['GET']) +def get_unusual_data(seccode): + """获取龙虎榜数据""" + try: + days = request.args.get('days', 30, type=int) + + query = text(""" + SELECT TRADEDATE, + SECNAME, + F001V as info_type_code, + F002V as info_type, + F003C as trade_type, + F004N as rank_no, + F005V as dept_name, + F006N as buy_amount, + F007N as sell_amount, + F008N as net_amount + FROM ea_unusual + WHERE SECCODE = :seccode + ORDER BY TRADEDATE DESC, F004N LIMIT 100 + """) + + result = engine.execute(query, seccode=seccode) + + data = [] + for row in result: + data.append({ + 'date': format_date(row.TRADEDATE), + 'stock_name': row.SECNAME, + 'info_type': row.info_type, + 'info_type_code': row.info_type_code, + 'trade_type': 'buy' if row.trade_type == 'B' else 'sell' if row.trade_type == 'S' else 'unknown', + 'rank': int(row.rank_no) if row.rank_no else None, + 'dept_name': row.dept_name, + 'buy_amount': format_decimal(row.buy_amount), + 'sell_amount': format_decimal(row.sell_amount), + 'net_amount': format_decimal(row.net_amount) + }) + + # 按日期分组 + grouped_data = {} + for item in data: + date = item['date'] + if date not in grouped_data: + grouped_data[date] = { + 'date': date, + 'info_types': set(), + 'buyers': [], + 'sellers': [], + 'total_buy': 0, + 'total_sell': 0, + 'net_amount': 0 + } + + grouped_data[date]['info_types'].add(item['info_type']) + + if item['trade_type'] == 'buy': + grouped_data[date]['buyers'].append(item) + grouped_data[date]['total_buy'] += item['buy_amount'] or 0 + elif item['trade_type'] == 'sell': + grouped_data[date]['sellers'].append(item) + grouped_data[date]['total_sell'] += item['sell_amount'] or 0 + + grouped_data[date]['net_amount'] = grouped_data[date]['total_buy'] - grouped_data[date]['total_sell'] + + # 转换set为list + for date in grouped_data: + grouped_data[date]['info_types'] = list(grouped_data[date]['info_types']) + + return jsonify({ + 'success': True, + 'data': data, + 'grouped_data': list(grouped_data.values()) + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/pledge/', methods=['GET']) +def get_pledge_data(seccode): + """获取股权质押数据""" + try: + query = text(""" + SELECT ENDDATE, + STARTDATE, + SECNAME, + F001N as unrestricted_pledge, + F002N as restricted_pledge, + F003N as total_shares_a, + F004N as pledge_count, + F005N as pledge_ratio + FROM ea_pledgeratio + WHERE SECCODE = :seccode + ORDER BY ENDDATE DESC LIMIT 12 + """) + + result = engine.execute(query, seccode=seccode) + + data = [] + for row in result: + total_pledge = (row.unrestricted_pledge or 0) + (row.restricted_pledge or 0) + data.append({ + 'end_date': format_date(row.ENDDATE), + 'start_date': format_date(row.STARTDATE), + 'stock_name': row.SECNAME, + 'unrestricted_pledge': format_decimal(row.unrestricted_pledge), + 'restricted_pledge': format_decimal(row.restricted_pledge), + 'total_pledge': format_decimal(total_pledge), + 'total_shares': format_decimal(row.total_shares_a), + 'pledge_count': int(row.pledge_count) if row.pledge_count else None, + 'pledge_ratio': format_decimal(row.pledge_ratio) + }) + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/summary/', methods=['GET']) +def get_market_summary(seccode): + """获取市场数据汇总""" + try: + # 获取最新交易数据 + trade_query = text(""" + SELECT * + FROM ea_trade + WHERE SECCODE = :seccode + ORDER BY TRADEDATE DESC LIMIT 1 + """) + + # 获取最新融资融券数据 + funding_query = text(""" + SELECT * + FROM ea_funding + WHERE SECCODE = :seccode + ORDER BY TRADEDATE DESC LIMIT 1 + """) + + # 获取最新质押数据 + pledge_query = text(""" + SELECT * + FROM ea_pledgeratio + WHERE SECCODE = :seccode + ORDER BY ENDDATE DESC LIMIT 1 + """) + + trade_result = engine.execute(trade_query, seccode=seccode).fetchone() + funding_result = engine.execute(funding_query, seccode=seccode).fetchone() + pledge_result = engine.execute(pledge_query, seccode=seccode).fetchone() + + summary = { + 'stock_code': seccode, + 'stock_name': trade_result.SECNAME if trade_result else None, + 'latest_trade': { + 'date': format_date(trade_result.TRADEDATE) if trade_result else None, + 'close': format_decimal(trade_result.F007N) if trade_result else None, + 'change_percent': format_decimal(trade_result.F010N) if trade_result else None, + 'volume': format_decimal(trade_result.F004N) if trade_result else None, + 'amount': format_decimal(trade_result.F011N) if trade_result else None, + 'pe_ratio': format_decimal(trade_result.F026N) if trade_result else None, + 'turnover_rate': format_decimal(trade_result.F012N) if trade_result else None, + } if trade_result else None, + 'latest_funding': { + 'date': format_date(funding_result.TRADEDATE) if funding_result else None, + 'financing_balance': format_decimal(funding_result.F001N) if funding_result else None, + 'securities_balance': format_decimal(funding_result.F004N) if funding_result else None, + 'total_balance': format_decimal(funding_result.F009N) if funding_result else None, + } if funding_result else None, + 'latest_pledge': { + 'date': format_date(pledge_result.ENDDATE) if pledge_result else None, + 'pledge_ratio': format_decimal(pledge_result.F005N) if pledge_result else None, + 'pledge_count': int(pledge_result.F004N) if pledge_result and pledge_result.F004N else None, + } if pledge_result else None + } + + return jsonify({ + 'success': True, + 'data': summary + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/stocks/search', methods=['GET']) +def search_stocks(): + """搜索股票(支持股票代码、股票简称、拼音首字母)""" + try: + query = request.args.get('q', '').strip() + limit = request.args.get('limit', 20, type=int) + + if not query: + return jsonify({ + 'success': False, + 'error': '请输入搜索关键词' + }), 400 + + with engine.connect() as conn: + test_sql = text(""" + SELECT SECCODE, SECNAME, F001V, F003V, F010V, F011V + FROM ea_stocklist + WHERE SECCODE = '300750' + OR F001V LIKE '%ndsd%' LIMIT 5 + """) + test_result = conn.execute(test_sql).fetchall() + + # 构建搜索SQL - 支持股票代码、股票简称、拼音简称搜索 + search_sql = text(""" + SELECT DISTINCT SECCODE as stock_code, + SECNAME as stock_name, + F001V as pinyin_abbr, + F003V as security_type, + F005V as exchange, + F011V as listing_status + FROM ea_stocklist + WHERE ( + UPPER(SECCODE) LIKE UPPER(:query_pattern) + OR UPPER(SECNAME) LIKE UPPER(:query_pattern) + OR UPPER(F001V) LIKE UPPER(:query_pattern) + ) + -- 基本过滤条件:只搜索正常的A股和B股 + AND (F011V = '正常上市' OR F010V = '013001') -- 正常上市状态 + AND F003V IN ('A股', 'B股') -- 只搜索A股和B股 + ORDER BY CASE + WHEN UPPER(SECCODE) = UPPER(:exact_query) THEN 1 + WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2 + WHEN UPPER(F001V) = UPPER(:exact_query) THEN 3 + WHEN UPPER(SECCODE) LIKE UPPER(:prefix_pattern) THEN 4 + WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 5 + WHEN UPPER(F001V) LIKE UPPER(:prefix_pattern) THEN 6 + ELSE 7 + END, + SECCODE LIMIT :limit + """) + + result = conn.execute(search_sql, { + 'query_pattern': f'%{query}%', + 'exact_query': query, + 'prefix_pattern': f'{query}%', + 'limit': limit + }).fetchall() + + stocks = [] + for row in result: + # 获取当前价格 + current_price, _ = get_latest_price_from_clickhouse(row.stock_code) + + stocks.append({ + 'stock_code': row.stock_code, + 'stock_name': row.stock_name, + 'current_price': current_price or 0, # 添加当前价格 + 'pinyin_abbr': row.pinyin_abbr, + 'security_type': row.security_type, + 'exchange': row.exchange, + 'listing_status': row.listing_status + }) + + return jsonify({ + 'success': True, + 'data': stocks, + 'count': len(stocks) + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/heatmap', methods=['GET']) +def get_market_heatmap(): + """获取市场热力图数据(基于市值和涨跌幅)""" + try: + # 获取交易日期参数 + trade_date = request.args.get('date') + # 前端显示用的limit,但统计数据会基于全部股票 + display_limit = request.args.get('limit', 500, type=int) + + with engine.connect() as conn: + # 如果没有指定日期,获取最新交易日 + if not trade_date: + latest_date_result = conn.execute(text(""" + SELECT MAX(TRADEDATE) as latest_date + FROM ea_trade + """)).fetchone() + trade_date = latest_date_result.latest_date if latest_date_result else None + + if not trade_date: + return jsonify({ + 'success': False, + 'error': '无法获取交易数据' + }), 404 + + # 获取全部股票数据用于统计 + all_stocks_sql = text(""" + SELECT t.SECCODE as stock_code, + t.SECNAME as stock_name, + t.F010N as change_percent, -- 涨跌幅 + t.F007N as close_price, -- 收盘价 + t.F021N * t.F007N / 100000000 as market_cap, -- 市值(亿元) + t.F011N / 100000000 as amount, -- 成交额(亿元) + t.F012N as turnover_rate, -- 换手率 + b.F034V as industry, -- 申万行业分类一级名称 + b.F026V as province -- 所属省份 + FROM ea_trade t + LEFT JOIN ea_baseinfo b ON t.SECCODE = b.SECCODE + WHERE t.TRADEDATE = :trade_date + AND t.F010N IS NOT NULL -- 仅统计当日有涨跌幅数据的股票 + ORDER BY market_cap DESC + """) + + all_result = conn.execute(all_stocks_sql, { + 'trade_date': trade_date + }).fetchall() + + # 计算统计数据(基于全部股票) + total_market_cap = 0 + total_amount = 0 + rising_count = 0 + falling_count = 0 + flat_count = 0 + + all_data = [] + for row in all_result: + # F010N 已在 SQL 中确保非空 + change_percent = float(row.change_percent) + market_cap = float(row.market_cap) if row.market_cap else 0 + amount = float(row.amount) if row.amount else 0 + + total_market_cap += market_cap + total_amount += amount + + if change_percent > 0: + rising_count += 1 + elif change_percent < 0: + falling_count += 1 + else: + flat_count += 1 + + all_data.append({ + 'stock_code': row.stock_code, + 'stock_name': row.stock_name, + 'change_percent': change_percent, + 'close_price': float(row.close_price) if row.close_price else 0, + 'market_cap': market_cap, + 'amount': amount, + 'turnover_rate': float(row.turnover_rate) if row.turnover_rate else 0, + 'industry': row.industry, + 'province': row.province + }) + + # 只返回前display_limit条用于热力图显示 + heatmap_data = all_data[:display_limit] + + return jsonify({ + 'success': True, + 'data': heatmap_data, + 'trade_date': trade_date.strftime('%Y-%m-%d') if hasattr(trade_date, 'strftime') else str(trade_date), + 'count': len(all_data), # 全部股票数量 + 'display_count': len(heatmap_data), # 显示的股票数量 + 'statistics': { + 'total_market_cap': round(total_market_cap, 2), # 总市值(亿元) + 'total_amount': round(total_amount, 2), # 总成交额(亿元) + 'rising_count': rising_count, # 上涨家数 + 'falling_count': falling_count, # 下跌家数 + 'flat_count': flat_count # 平盘家数 + } + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/statistics', methods=['GET']) +def get_market_statistics(): + """获取市场统计数据(从ea_blocktrading表)""" + try: + # 获取交易日期参数 + trade_date = request.args.get('date') + + with engine.connect() as conn: + # 如果没有指定日期,获取最新交易日 + if not trade_date: + latest_date_result = conn.execute(text(""" + SELECT MAX(TRADEDATE) as latest_date + FROM ea_blocktrading + """)).fetchone() + trade_date = latest_date_result.latest_date if latest_date_result else None + + if not trade_date: + return jsonify({ + 'success': False, + 'error': '无法获取统计数据' + }), 404 + + # 获取沪深两市的统计数据 + stats_sql = text(""" + SELECT EXCHANGECODE, + EXCHANGENAME, + F001V as indicator_code, + F002V as indicator_name, + F003N as indicator_value, + F004V as unit, + TRADEDATE + FROM ea_blocktrading + WHERE TRADEDATE = :trade_date + AND EXCHANGECODE IN ('012001', '012002') -- 只获取上交所和深交所的数据 + AND F001V IN ( + '250006', '250014', -- 深交所股票总市值、上交所市价总值 + '250007', '250015', -- 深交所股票流通市值、上交所流通市值 + '250008', -- 深交所股票成交金额 + '250010', '250019', -- 深交所股票平均市盈率、上交所平均市盈率 + '250050', '250001' -- 上交所上市公司家数、深交所上市公司数 + ) + """) + + result = conn.execute(stats_sql, { + 'trade_date': trade_date + }).fetchall() + + # 整理数据 + statistics = {} + for row in result: + key = f"{row.EXCHANGECODE}_{row.indicator_code}" + statistics[key] = { + 'exchange_code': row.EXCHANGECODE, + 'exchange_name': row.EXCHANGENAME, + 'indicator_code': row.indicator_code, + 'indicator_name': row.indicator_name, + 'value': float(row.indicator_value) if row.indicator_value else 0, + 'unit': row.unit + } + + # 汇总数据 + summary = { + 'total_market_cap': 0, # 总市值 + 'total_float_cap': 0, # 流通市值 + 'total_amount': 0, # 成交额 + 'sh_pe_ratio': 0, # 上交所市盈率 + 'sz_pe_ratio': 0, # 深交所市盈率 + 'sh_companies': 0, # 上交所上市公司数 + 'sz_companies': 0 # 深交所上市公司数 + } + + # 计算汇总值 + if '012001_250014' in statistics: # 上交所市价总值 + summary['total_market_cap'] += statistics['012001_250014']['value'] + if '012002_250006' in statistics: # 深交所股票总市值 + summary['total_market_cap'] += statistics['012002_250006']['value'] + + if '012001_250015' in statistics: # 上交所流通市值 + summary['total_float_cap'] += statistics['012001_250015']['value'] + if '012002_250007' in statistics: # 深交所股票流通市值 + summary['total_float_cap'] += statistics['012002_250007']['value'] + + # 成交额需要获取上交所的数据 + # 获取上交所成交金额 + sh_amount_result = conn.execute(text(""" + SELECT F003N + FROM ea_blocktrading + WHERE TRADEDATE = :trade_date + AND EXCHANGECODE = '012001' + AND F002V LIKE '%成交金额%' LIMIT 1 + """), {'trade_date': trade_date}).fetchone() + + sh_amount = float(sh_amount_result.F003N) if sh_amount_result and sh_amount_result.F003N else 0 + sz_amount = statistics['012002_250008']['value'] if '012002_250008' in statistics else 0 + summary['total_amount'] = sh_amount + sz_amount + + if '012001_250019' in statistics: # 上交所平均市盈率 + summary['sh_pe_ratio'] = statistics['012001_250019']['value'] + if '012002_250010' in statistics: # 深交所股票平均市盈率 + summary['sz_pe_ratio'] = statistics['012002_250010']['value'] + + if '012001_250050' in statistics: # 上交所上市公司家数 + summary['sh_companies'] = int(statistics['012001_250050']['value']) + if '012002_250001' in statistics: # 深交所上市公司数 + summary['sz_companies'] = int(statistics['012002_250001']['value']) + + # 获取可用的交易日期列表 + available_dates_result = conn.execute(text(""" + SELECT DISTINCT TRADEDATE + FROM ea_blocktrading + WHERE EXCHANGECODE IN ('012001', '012002') + ORDER BY TRADEDATE DESC LIMIT 30 + """)).fetchall() + + available_dates = [str(row.TRADEDATE) for row in available_dates_result] + + return jsonify({ + 'success': True, + 'trade_date': str(trade_date), + 'summary': summary, + 'details': list(statistics.values()), + 'available_dates': available_dates + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/concepts/daily-top', methods=['GET']) +def get_daily_top_concepts(): + """获取每日涨幅靠前的概念板块""" + try: + # 获取交易日期参数 + trade_date = request.args.get('date') + limit = request.args.get('limit', 6, type=int) + + # 构建概念中心API的URL + concept_api_url = 'http://222.128.1.157:16801/search' + + # 准备请求数据 + request_data = { + 'query': '', + 'size': limit, + 'page': 1, + 'sort_by': 'change_pct' + } + + if trade_date: + request_data['trade_date'] = trade_date + + # 调用概念中心API + response = requests.post(concept_api_url, json=request_data, timeout=10) + + if response.status_code == 200: + data = response.json() + top_concepts = [] + + for concept in data.get('results', []): + top_concepts.append({ + 'concept_id': concept.get('concept_id'), + 'concept_name': concept.get('concept'), + 'description': concept.get('description'), + 'change_percent': concept.get('price_info', {}).get('avg_change_pct', 0), + 'stock_count': concept.get('stock_count', 0), + 'stocks': concept.get('stocks', [])[:5] # 只返回前5只股票 + }) + + return jsonify({ + 'success': True, + 'data': top_concepts, + 'trade_date': data.get('price_date'), + 'count': len(top_concepts) + }) + else: + return jsonify({ + 'success': False, + 'error': '获取概念数据失败' + }), 500 + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/rise-analysis/', methods=['GET']) +def get_rise_analysis(seccode): + """获取股票涨幅分析数据""" + try: + # 获取日期范围参数 + start_date = request.args.get('start_date') + end_date = request.args.get('end_date') + + query = text(""" + SELECT stock_code, + stock_name, + trade_date, + rise_rate, + close_price, + volume, + amount, + main_business, + rise_reason_brief, + rise_reason_detail, + news_summary, + announcements, + guba_sentiment, + analysis_time + FROM stock_rise_analysis + WHERE stock_code = :stock_code + """) + + params = {'stock_code': seccode} + + # 添加日期筛选 + if start_date and end_date: + query = text(""" + SELECT stock_code, + stock_name, + trade_date, + rise_rate, + close_price, + volume, + amount, + main_business, + rise_reason_brief, + rise_reason_detail, + news_summary, + announcements, + guba_sentiment, + analysis_time + FROM stock_rise_analysis + WHERE stock_code = :stock_code + AND trade_date BETWEEN :start_date AND :end_date + ORDER BY trade_date DESC + """) + params['start_date'] = start_date + params['end_date'] = end_date + else: + query = text(""" + SELECT stock_code, + stock_name, + trade_date, + rise_rate, + close_price, + volume, + amount, + main_business, + rise_reason_brief, + rise_reason_detail, + news_summary, + announcements, + guba_sentiment, + analysis_time + FROM stock_rise_analysis + WHERE stock_code = :stock_code + ORDER BY trade_date DESC LIMIT 100 + """) + + result = engine.execute(query, **params).fetchall() + + # 格式化数据 + rise_analysis_data = [] + for row in result: + rise_analysis_data.append({ + 'stock_code': row.stock_code, + 'stock_name': row.stock_name, + 'trade_date': format_date(row.trade_date), + 'rise_rate': format_decimal(row.rise_rate), + 'close_price': format_decimal(row.close_price), + 'volume': format_decimal(row.volume), + 'amount': format_decimal(row.amount), + 'main_business': row.main_business, + 'rise_reason_brief': row.rise_reason_brief, + 'rise_reason_detail': row.rise_reason_detail, + 'news_summary': row.news_summary, + 'announcements': row.announcements, + 'guba_sentiment': row.guba_sentiment, + 'analysis_time': row.analysis_time.strftime('%Y-%m-%d %H:%M:%S') if row.analysis_time else None + }) + + return jsonify({ + 'success': True, + 'data': rise_analysis_data, + 'count': len(rise_analysis_data) + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +# ============================================ +# 公司分析相关接口 +# ============================================ + +@app.route('/api/company/comprehensive-analysis/', methods=['GET']) +def get_comprehensive_analysis(company_code): + """获取公司综合分析数据""" + try: + # 获取公司定性分析 + qualitative_query = text(""" + SELECT one_line_intro, + investment_highlights, + business_model_desc, + company_story, + positioning_analysis, + unique_value_proposition, + business_logic_explanation, + revenue_driver_analysis, + customer_value_analysis, + strategy_description, + strategic_initiatives, + created_at, + updated_at + FROM company_analysis + WHERE company_code = :company_code + """) + + qualitative_result = engine.execute(qualitative_query, company_code=company_code).fetchone() + + # 获取业务板块分析 + segments_query = text(""" + SELECT segment_name, + segment_description, + competitive_position, + future_potential, + key_customers, + value_chain_position, + created_at, + updated_at + FROM business_segment_analysis + WHERE company_code = :company_code + ORDER BY created_at DESC + """) + + segments_result = engine.execute(segments_query, company_code=company_code).fetchall() + + # 获取竞争地位数据 - 最新一期 + competitive_query = text(""" + SELECT market_position_score, + technology_score, + brand_score, + operation_score, + finance_score, + innovation_score, + risk_score, + growth_score, + industry_avg_comparison, + main_competitors, + competitive_advantages, + competitive_disadvantages, + industry_rank, + total_companies, + report_period, + updated_at + FROM company_competitive_position + WHERE company_code = :company_code + ORDER BY report_period DESC LIMIT 1 + """) + + competitive_result = engine.execute(competitive_query, company_code=company_code).fetchone() + + # 获取业务结构数据 - 最新一期 + business_structure_query = text(""" + SELECT business_name, + parent_business, + business_level, + revenue, + revenue_unit, + revenue_ratio, + profit, + profit_unit, + profit_ratio, + revenue_growth, + profit_growth, + gross_margin, + customer_count, + market_share, + report_period + FROM company_business_structure + WHERE company_code = :company_code + AND report_period = (SELECT MAX(report_period) + FROM company_business_structure + WHERE company_code = :company_code) + ORDER BY revenue_ratio DESC + """) + + business_structure_result = engine.execute(business_structure_query, company_code=company_code).fetchall() + + # 构建返回数据 + response_data = { + 'company_code': company_code, + 'qualitative_analysis': None, + 'business_segments': [], + 'competitive_position': None, + 'business_structure': [] + } + + # 处理定性分析数据 + if qualitative_result: + response_data['qualitative_analysis'] = { + 'core_positioning': { + 'one_line_intro': qualitative_result.one_line_intro, + 'investment_highlights': qualitative_result.investment_highlights, + 'business_model_desc': qualitative_result.business_model_desc, + 'company_story': qualitative_result.company_story + }, + 'business_understanding': { + 'positioning_analysis': qualitative_result.positioning_analysis, + 'unique_value_proposition': qualitative_result.unique_value_proposition, + 'business_logic_explanation': qualitative_result.business_logic_explanation, + 'revenue_driver_analysis': qualitative_result.revenue_driver_analysis, + 'customer_value_analysis': qualitative_result.customer_value_analysis + }, + 'strategy': { + 'strategy_description': qualitative_result.strategy_description, + 'strategic_initiatives': qualitative_result.strategic_initiatives + }, + 'updated_at': qualitative_result.updated_at.strftime( + '%Y-%m-%d %H:%M:%S') if qualitative_result.updated_at else None + } + + # 处理业务板块数据 + for segment in segments_result: + response_data['business_segments'].append({ + 'segment_name': segment.segment_name, + 'segment_description': segment.segment_description, + 'competitive_position': segment.competitive_position, + 'future_potential': segment.future_potential, + 'key_customers': segment.key_customers, + 'value_chain_position': segment.value_chain_position, + 'updated_at': segment.updated_at.strftime('%Y-%m-%d %H:%M:%S') if segment.updated_at else None + }) + + # 处理竞争地位数据 + if competitive_result: + response_data['competitive_position'] = { + 'scores': { + 'market_position': competitive_result.market_position_score, + 'technology': competitive_result.technology_score, + 'brand': competitive_result.brand_score, + 'operation': competitive_result.operation_score, + 'finance': competitive_result.finance_score, + 'innovation': competitive_result.innovation_score, + 'risk': competitive_result.risk_score, + 'growth': competitive_result.growth_score + }, + 'analysis': { + 'industry_avg_comparison': competitive_result.industry_avg_comparison, + 'main_competitors': competitive_result.main_competitors, + 'competitive_advantages': competitive_result.competitive_advantages, + 'competitive_disadvantages': competitive_result.competitive_disadvantages + }, + 'ranking': { + 'industry_rank': competitive_result.industry_rank, + 'total_companies': competitive_result.total_companies, + 'rank_percentage': round( + (competitive_result.industry_rank / competitive_result.total_companies * 100), + 2) if competitive_result.industry_rank and competitive_result.total_companies else None + }, + 'report_period': competitive_result.report_period, + 'updated_at': competitive_result.updated_at.strftime( + '%Y-%m-%d %H:%M:%S') if competitive_result.updated_at else None + } + + # 处理业务结构数据 + for business in business_structure_result: + response_data['business_structure'].append({ + 'business_name': business.business_name, + 'parent_business': business.parent_business, + 'business_level': business.business_level, + 'revenue': format_decimal(business.revenue), + 'revenue_unit': business.revenue_unit, + 'profit': format_decimal(business.profit), + 'profit_unit': business.profit_unit, + 'financial_metrics': { + 'revenue': format_decimal(business.revenue), + 'revenue_ratio': format_decimal(business.revenue_ratio), + 'profit': format_decimal(business.profit), + 'profit_ratio': format_decimal(business.profit_ratio), + 'gross_margin': format_decimal(business.gross_margin) + }, + 'growth_metrics': { + 'revenue_growth': format_decimal(business.revenue_growth), + 'profit_growth': format_decimal(business.profit_growth) + }, + 'market_metrics': { + 'customer_count': business.customer_count, + 'market_share': format_decimal(business.market_share) + }, + 'report_period': business.report_period + }) + + return jsonify({ + 'success': True, + 'data': response_data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/company/value-chain-analysis/', methods=['GET']) +def get_value_chain_analysis(company_code): + """获取公司产业链分析数据""" + try: + # 获取产业链节点数据 + nodes_query = text(""" + SELECT node_name, + node_type, + node_level, + node_description, + importance_score, + market_share, + dependency_degree, + created_at + FROM company_value_chain_nodes + WHERE company_code = :company_code + ORDER BY node_level ASC, importance_score DESC + """) + + nodes_result = engine.execute(nodes_query, company_code=company_code).fetchall() + + # 获取产业链流向数据 + flows_query = text(""" + SELECT source_node, + source_type, + source_level, + target_node, + target_type, + target_level, + flow_value, + flow_ratio, + flow_type, + relationship_desc, + transaction_volume + FROM company_value_chain_flows + WHERE company_code = :company_code + ORDER BY flow_ratio DESC + """) + + flows_result = engine.execute(flows_query, company_code=company_code).fetchall() + + # 构建节点数据结构 + nodes_by_level = {} + all_nodes = [] + + for node in nodes_result: + node_data = { + 'node_name': node.node_name, + 'node_type': node.node_type, + 'node_level': node.node_level, + 'node_description': node.node_description, + 'importance_score': node.importance_score, + 'market_share': format_decimal(node.market_share), + 'dependency_degree': format_decimal(node.dependency_degree), + 'created_at': node.created_at.strftime('%Y-%m-%d %H:%M:%S') if node.created_at else None + } + + all_nodes.append(node_data) + + # 按层级分组 + level_key = f"level_{node.node_level}" + if level_key not in nodes_by_level: + nodes_by_level[level_key] = [] + nodes_by_level[level_key].append(node_data) + + # 构建流向数据 + flows_data = [] + for flow in flows_result: + flows_data.append({ + 'source': { + 'node_name': flow.source_node, + 'node_type': flow.source_type, + 'node_level': flow.source_level + }, + 'target': { + 'node_name': flow.target_node, + 'node_type': flow.target_type, + 'node_level': flow.target_level + }, + 'flow_metrics': { + 'flow_value': format_decimal(flow.flow_value), + 'flow_ratio': format_decimal(flow.flow_ratio), + 'flow_type': flow.flow_type + }, + 'relationship_info': { + 'relationship_desc': flow.relationship_desc, + 'transaction_volume': flow.transaction_volume + } + }) + + # 移除循环边,确保Sankey图数据是DAG(有向无环图) + flows_data = remove_cycles_from_sankey_flows(flows_data) + + # 统计各层级节点数量 + level_stats = {} + for level_key, nodes in nodes_by_level.items(): + level_stats[level_key] = { + 'count': len(nodes), + 'avg_importance': round(sum(node['importance_score'] or 0 for node in nodes) / len(nodes), + 2) if nodes else 0 + } + + response_data = { + 'company_code': company_code, + 'value_chain_structure': { + 'nodes_by_level': nodes_by_level, + 'level_statistics': level_stats, + 'total_nodes': len(all_nodes) + }, + 'value_chain_flows': flows_data, + 'analysis_summary': { + 'total_flows': len(flows_data), + 'upstream_nodes': len([n for n in all_nodes if n['node_level'] < 0]), + 'company_nodes': len([n for n in all_nodes if n['node_level'] == 0]), + 'downstream_nodes': len([n for n in all_nodes if n['node_level'] > 0]) + } + } + + return jsonify({ + 'success': True, + 'data': response_data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/company/key-factors-timeline/', methods=['GET']) +def get_key_factors_timeline(company_code): + """获取公司关键因素和时间线数据""" + try: + # 获取请求参数 + report_period = request.args.get('report_period') # 可选的报告期筛选 + event_limit = request.args.get('event_limit', 50, type=int) # 时间线事件数量限制 + + # 获取关键因素类别 + categories_query = text(""" + SELECT id, + category_name, + category_desc, + display_order + FROM company_key_factor_categories + WHERE company_code = :company_code + ORDER BY display_order ASC, created_at ASC + """) + + categories_result = engine.execute(categories_query, company_code=company_code).fetchall() + + # 获取关键因素详情 + factors_query = text(""" + SELECT kf.category_id, + kf.factor_name, + kf.factor_type, + kf.factor_value, + kf.factor_unit, + kf.factor_desc, + kf.impact_direction, + kf.impact_weight, + kf.report_period, + kf.year_on_year, + kf.data_source, + kf.created_at, + kf.updated_at + FROM company_key_factors kf + WHERE kf.company_code = :company_code + """) + + params = {'company_code': company_code} + + # 如果指定了报告期,添加筛选条件 + if report_period: + factors_query = text(""" + SELECT kf.category_id, + kf.factor_name, + kf.factor_type, + kf.factor_value, + kf.factor_unit, + kf.factor_desc, + kf.impact_direction, + kf.impact_weight, + kf.report_period, + kf.year_on_year, + kf.data_source, + kf.created_at, + kf.updated_at + FROM company_key_factors kf + WHERE kf.company_code = :company_code + AND kf.report_period = :report_period + ORDER BY kf.impact_weight DESC, kf.updated_at DESC + """) + params['report_period'] = report_period + else: + factors_query = text(""" + SELECT kf.category_id, + kf.factor_name, + kf.factor_type, + kf.factor_value, + kf.factor_unit, + kf.factor_desc, + kf.impact_direction, + kf.impact_weight, + kf.report_period, + kf.year_on_year, + kf.data_source, + kf.created_at, + kf.updated_at + FROM company_key_factors kf + WHERE kf.company_code = :company_code + ORDER BY kf.report_period DESC, kf.impact_weight DESC, kf.updated_at DESC + """) + + factors_result = engine.execute(factors_query, **params).fetchall() + + # 获取发展时间线事件 + timeline_query = text(""" + SELECT event_date, + event_type, + event_title, + event_desc, + impact_score, + is_positive, + related_products, + related_partners, + financial_impact, + created_at + FROM company_timeline_events + WHERE company_code = :company_code + ORDER BY event_date DESC LIMIT :limit + """) + + timeline_result = engine.execute(timeline_query, + company_code=company_code, + limit=event_limit).fetchall() + + # 构建关键因素数据结构 + key_factors_data = {} + factors_by_category = {} + + # 先建立类别索引 + categories_map = {} + for category in categories_result: + categories_map[category.id] = { + 'category_name': category.category_name, + 'category_desc': category.category_desc, + 'display_order': category.display_order, + 'factors': [] + } + + # 将因素分组到类别中 + for factor in factors_result: + factor_data = { + 'factor_name': factor.factor_name, + 'factor_type': factor.factor_type, + 'factor_value': factor.factor_value, + 'factor_unit': factor.factor_unit, + 'factor_desc': factor.factor_desc, + 'impact_direction': factor.impact_direction, + 'impact_weight': factor.impact_weight, + 'report_period': factor.report_period, + 'year_on_year': format_decimal(factor.year_on_year), + 'data_source': factor.data_source, + 'updated_at': factor.updated_at.strftime('%Y-%m-%d %H:%M:%S') if factor.updated_at else None + } + + category_id = factor.category_id + if category_id and category_id in categories_map: + categories_map[category_id]['factors'].append(factor_data) + + # 构建时间线数据 + timeline_data = [] + for event in timeline_result: + timeline_data.append({ + 'event_date': event.event_date.strftime('%Y-%m-%d') if event.event_date else None, + 'event_type': event.event_type, + 'event_title': event.event_title, + 'event_desc': event.event_desc, + 'impact_metrics': { + 'impact_score': event.impact_score, + 'is_positive': event.is_positive + }, + 'related_info': { + 'related_products': event.related_products, + 'related_partners': event.related_partners, + 'financial_impact': event.financial_impact + }, + 'created_at': event.created_at.strftime('%Y-%m-%d %H:%M:%S') if event.created_at else None + }) + + # 统计信息 + total_factors = len(factors_result) + positive_events = len([e for e in timeline_result if e.is_positive]) + negative_events = len(timeline_result) - positive_events + + response_data = { + 'company_code': company_code, + 'key_factors': { + 'categories': list(categories_map.values()), + 'total_factors': total_factors, + 'report_period': report_period + }, + 'development_timeline': { + 'events': timeline_data, + 'statistics': { + 'total_events': len(timeline_data), + 'positive_events': positive_events, + 'negative_events': negative_events, + 'event_types': list(set(event.event_type for event in timeline_result if event.event_type)) + } + } + } + + return jsonify({ + 'success': True, + 'data': response_data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +# ============================================ +# 模拟盘服务函数 +# ============================================ + +def get_or_create_simulation_account(user_id): + """获取或创建模拟账户""" + account = SimulationAccount.query.filter_by(user_id=user_id).first() + if not account: + account = SimulationAccount( + user_id=user_id, + account_name=f'模拟账户_{user_id}', + initial_capital=1000000.00, + available_cash=1000000.00 + ) + db.session.add(account) + db.session.commit() + return account + + +def is_trading_time(): + """判断是否为交易时间""" + now = beijing_now() + # 检查是否为工作日 + if now.weekday() >= 5: # 周六日 + return False + + # 检查是否为交易时间 + current_time = now.time() + morning_start = dt_time(9, 30) + morning_end = dt_time(11, 30) + afternoon_start = dt_time(13, 0) + afternoon_end = dt_time(15, 0) + + if (morning_start <= current_time <= morning_end) or \ + (afternoon_start <= current_time <= afternoon_end): + return True + + return False + + +def get_latest_price_from_clickhouse(stock_code): + """从ClickHouse获取最新价格(优先分钟数据,备选日线数据)""" + try: + client = get_clickhouse_client() + + # 确保stock_code包含后缀 + if '.' not in stock_code: + stock_code = f"{stock_code}.SH" if stock_code.startswith('6') else f"{stock_code}.SZ" + + # 1. 首先尝试获取最新的分钟数据(近30天) + minute_query = """ + SELECT close, timestamp + FROM stock_minute + WHERE code = %(code)s + AND timestamp >= today() - 30 + ORDER BY timestamp DESC + LIMIT 1 \ + """ + + result = client.execute(minute_query, {'code': stock_code}) + + if result: + return float(result[0][0]), result[0][1] + + # 2. 如果没有分钟数据,获取最新的日线收盘价 + daily_query = """ + SELECT close, date + FROM stock_daily + WHERE code = %(code)s + AND date >= today() - 90 + ORDER BY date DESC + LIMIT 1 \ + """ + + daily_result = client.execute(daily_query, {'code': stock_code}) + + if daily_result: + return float(daily_result[0][0]), daily_result[0][1] + + # 3. 如果还是没有,尝试从其他表获取(如果有的话) + fallback_query = """ + SELECT close_price, trade_date + FROM stock_minute_kline + WHERE stock_code = %(code6)s + AND trade_date >= today() - 30 + ORDER BY trade_date DESC, trade_time DESC LIMIT 1 \ + """ + + # 提取6位代码 + code6 = stock_code.split('.')[0] + fallback_result = client.execute(fallback_query, {'code6': code6}) + + if fallback_result: + return float(fallback_result[0][0]), fallback_result[0][1] + + print(f"警告: 无法获取股票 {stock_code} 的价格数据") + return None, None + + except Exception as e: + print(f"获取最新价格失败 {stock_code}: {e}") + return None, None + + +def get_next_minute_price(stock_code, order_time): + """获取下单后一分钟内的收盘价作为成交价""" + try: + client = get_clickhouse_client() + + # 确保stock_code包含后缀 + if '.' not in stock_code: + stock_code = f"{stock_code}.SH" if stock_code.startswith('6') else f"{stock_code}.SZ" + + # 获取下单后一分钟内的数据 + query = """ + SELECT close, timestamp + FROM stock_minute + WHERE code = %(code)s + AND timestamp \ + > %(order_time)s + AND timestamp <= %(end_time)s + ORDER BY timestamp ASC + LIMIT 1 \ + """ + + end_time = order_time + timedelta(minutes=1) + + result = client.execute(query, { + 'code': stock_code, + 'order_time': order_time, + 'end_time': end_time + }) + + if result: + return float(result[0][0]), result[0][1] + + # 如果一分钟内没有数据,获取最近的数据 + query = """ + SELECT close, timestamp + FROM stock_minute + WHERE code = %(code)s + AND timestamp \ + > %(order_time)s + ORDER BY timestamp ASC + LIMIT 1 \ + """ + + result = client.execute(query, { + 'code': stock_code, + 'order_time': order_time + }) + + if result: + return float(result[0][0]), result[0][1] + + # 如果没有后续分钟数据,使用最新可用价格 + print(f"没有找到下单后的分钟数据,使用最新价格: {stock_code}") + return get_latest_price_from_clickhouse(stock_code) + + except Exception as e: + print(f"获取成交价格失败: {e}") + # 出错时也尝试获取最新价格 + return get_latest_price_from_clickhouse(stock_code) + + +def validate_and_get_stock_info(stock_input): + """验证股票输入并获取标准代码和名称 + + 支持输入格式: + - 股票代码:600519 或 600519.SH + - 股票名称:贵州茅台 + - 拼音首字母:gzmt + - 名称(代码):贵州茅台(600519) + + 返回: (stock_code_with_suffix, stock_code_6digit, stock_name) 或 (None, None, None) + """ + # 先尝试标准化输入 + code6, name_from_input = _normalize_stock_input(stock_input) + + if code6: + # 如果能解析出6位代码,查询股票名称 + stock_name = name_from_input or _query_stock_name_by_code(code6) + stock_code_full = f"{code6}.SH" if code6.startswith('6') else f"{code6}.SZ" + return stock_code_full, code6, stock_name + + # 如果不是标准代码格式,尝试搜索 + with engine.connect() as conn: + search_sql = text(""" + SELECT DISTINCT SECCODE as stock_code, + SECNAME as stock_name + FROM ea_stocklist + WHERE ( + UPPER(SECCODE) = UPPER(:exact_match) + OR UPPER(SECNAME) = UPPER(:exact_match) + OR UPPER(F001V) = UPPER(:exact_match) + ) + AND F011V = '正常上市' + AND F003V IN ('A股', 'B股') LIMIT 1 + """) + + result = conn.execute(search_sql, { + 'exact_match': stock_input.upper() + }).fetchone() + + if result: + code6 = result.stock_code + stock_name = result.stock_name + stock_code_full = f"{code6}.SH" if code6.startswith('6') else f"{code6}.SZ" + return stock_code_full, code6, stock_name + + return None, None, None + + +def execute_simulation_order(order): + """执行模拟订单(优化版)""" + try: + # 标准化股票代码 + stock_code_full, code6, stock_name = validate_and_get_stock_info(order.stock_code) + + if not stock_code_full: + order.status = 'REJECTED' + order.reject_reason = '无效的股票代码' + db.session.commit() + return False + + # 更新订单的股票信息 + order.stock_code = stock_code_full + order.stock_name = stock_name + + # 获取成交价格(下单后一分钟的收盘价) + filled_price, filled_time = get_next_minute_price(stock_code_full, order.order_time) + + if not filled_price: + # 如果无法获取价格,订单保持PENDING状态,等待后台处理 + order.status = 'PENDING' + db.session.commit() + return True # 返回True表示下单成功,但未成交 + + # 更新订单信息 + order.filled_qty = order.order_qty + order.filled_price = filled_price + order.filled_amount = filled_price * order.order_qty + order.filled_time = filled_time or beijing_now() + + # 计算费用 + order.calculate_fees() + + # 获取账户 + account = SimulationAccount.query.get(order.account_id) + + if order.order_type == 'BUY': + # 买入操作 + total_cost = float(order.filled_amount) + float(order.total_fee) + + # 检查资金是否充足 + if float(account.available_cash) < total_cost: + order.status = 'REJECTED' + order.reject_reason = '可用资金不足' + db.session.commit() + return False + + # 扣除资金 + account.available_cash -= Decimal(str(total_cost)) + + # 更新或创建持仓 + position = SimulationPosition.query.filter_by( + account_id=account.id, + stock_code=order.stock_code + ).first() + + if position: + # 更新持仓 + total_cost_before = float(position.avg_cost) * position.position_qty + total_cost_after = total_cost_before + float(order.filled_amount) + total_qty_after = position.position_qty + order.filled_qty + + position.avg_cost = Decimal(str(total_cost_after / total_qty_after)) + position.position_qty = total_qty_after + # 今日买入,T+1才可用 + position.frozen_qty += order.filled_qty + else: + # 创建新持仓 + position = SimulationPosition( + account_id=account.id, + stock_code=order.stock_code, + stock_name=order.stock_name, + position_qty=order.filled_qty, + available_qty=0, # T+1 + frozen_qty=order.filled_qty, # 今日买入冻结 + avg_cost=order.filled_price, + current_price=order.filled_price + ) + db.session.add(position) + + # 更新持仓市值 + position.update_market_value(order.filled_price) + + else: # SELL + # 卖出操作 + print(f"🔍 调试:查找持仓,账户ID: {account.id}, 股票代码: {order.stock_code}") + + # 先尝试用完整格式查找 + position = SimulationPosition.query.filter_by( + account_id=account.id, + stock_code=order.stock_code + ).first() + + # 如果没找到,尝试用6位数字格式查找 + if not position and '.' in order.stock_code: + code6 = order.stock_code.split('.')[0] + print(f"🔍 调试:尝试用6位格式查找: {code6}") + position = SimulationPosition.query.filter_by( + account_id=account.id, + stock_code=code6 + ).first() + + print(f"🔍 调试:找到持仓: {position}") + if position: + print( + f"🔍 调试:持仓详情 - 股票代码: {position.stock_code}, 持仓数量: {position.position_qty}, 可用数量: {position.available_qty}") + + # 检查持仓是否存在 + if not position: + order.status = 'REJECTED' + order.reject_reason = '持仓不存在' + db.session.commit() + return False + + # 检查总持仓数量是否足够(包括冻结的) + total_holdings = position.position_qty + if total_holdings < order.order_qty: + order.status = 'REJECTED' + order.reject_reason = f'持仓数量不足,当前持仓: {total_holdings} 股,需要: {order.order_qty} 股' + db.session.commit() + return False + + # 如果可用数量不足,但总持仓足够,则从冻结数量中解冻 + if position.available_qty < order.order_qty: + # 计算需要解冻的数量 + need_to_unfreeze = order.order_qty - position.available_qty + if position.frozen_qty >= need_to_unfreeze: + # 解冻部分冻结数量 + position.frozen_qty -= need_to_unfreeze + position.available_qty += need_to_unfreeze + print(f"解冻 {need_to_unfreeze} 股用于卖出") + else: + order.status = 'REJECTED' + order.reject_reason = f'可用数量不足,可用: {position.available_qty} 股,冻结: {position.frozen_qty} 股,需要: {order.order_qty} 股' + db.session.commit() + return False + + # 更新持仓 + position.position_qty -= order.filled_qty + position.available_qty -= order.filled_qty + + # 增加资金 + account.available_cash += Decimal(str(float(order.filled_amount) - float(order.total_fee))) + + # 如果全部卖出,删除持仓记录 + if position.position_qty == 0: + db.session.delete(position) + + # 创建成交记录 + transaction = SimulationTransaction( + account_id=account.id, + order_id=order.id, + transaction_no=f"T{int(beijing_now().timestamp() * 1000000)}", + stock_code=order.stock_code, + stock_name=order.stock_name, + transaction_type=order.order_type, + transaction_price=order.filled_price, + transaction_qty=order.filled_qty, + transaction_amount=order.filled_amount, + commission=order.commission, + stamp_tax=order.stamp_tax, + transfer_fee=order.transfer_fee, + total_fee=order.total_fee, + transaction_time=order.filled_time, + settlement_date=(order.filled_time + timedelta(days=1)).date() + ) + db.session.add(transaction) + + # 更新订单状态 + order.status = 'FILLED' + + # 更新账户总资产 + update_account_assets(account) + + db.session.commit() + return True + + except Exception as e: + print(f"执行订单失败: {e}") + db.session.rollback() + return False + + +def update_account_assets(account): + """更新账户资产(轻量级版本,不实时获取价格)""" + try: + # 只计算已有的持仓市值,不实时获取价格 + # 价格更新由后台脚本负责 + positions = SimulationPosition.query.filter_by(account_id=account.id).all() + total_market_value = sum(position.market_value or Decimal('0') for position in positions) + + account.position_value = total_market_value + account.calculate_total_assets() + + db.session.commit() + + except Exception as e: + print(f"更新账户资产失败: {e}") + db.session.rollback() + + +def update_all_positions_price(): + """更新所有持仓的最新价格(定时任务调用)""" + try: + positions = SimulationPosition.query.all() + + for position in positions: + latest_price, _ = get_latest_price_from_clickhouse(position.stock_code) + if latest_price: + # 记录昨日收盘价(用于计算今日盈亏) + yesterday_close = position.current_price + + # 更新市值 + position.update_market_value(latest_price) + + # 计算今日盈亏 + position.today_profit = (Decimal(str(latest_price)) - yesterday_close) * position.position_qty + position.today_profit_rate = ((Decimal( + str(latest_price)) - yesterday_close) / yesterday_close * 100) if yesterday_close > 0 else 0 + + db.session.commit() + + except Exception as e: + print(f"更新持仓价格失败: {e}") + db.session.rollback() + + +def process_t1_settlement(): + """处理T+1结算(每日收盘后运行)""" + try: + # 获取所有需要结算的持仓 + positions = SimulationPosition.query.filter(SimulationPosition.frozen_qty > 0).all() + + for position in positions: + # 将冻结数量转为可用数量 + position.available_qty += position.frozen_qty + position.frozen_qty = 0 + + db.session.commit() + + except Exception as e: + print(f"T+1结算失败: {e}") + db.session.rollback() + + +# ============================================ +# 模拟盘API接口 +# ============================================ + +@app.route('/api/simulation/account', methods=['GET']) +@login_required +def get_simulation_account(): + """获取模拟账户信息""" + try: + account = get_or_create_simulation_account(current_user.id) + + # 更新账户资产 + update_account_assets(account) + + return jsonify({ + 'success': True, + 'data': { + 'account_id': account.id, + 'account_name': account.account_name, + 'initial_capital': float(account.initial_capital), + 'available_cash': float(account.available_cash), + 'frozen_cash': float(account.frozen_cash), + 'position_value': float(account.position_value), + 'total_assets': float(account.total_assets), + 'total_profit': float(account.total_profit), + 'total_profit_rate': float(account.total_profit_rate), + 'daily_profit': float(account.daily_profit), + 'daily_profit_rate': float(account.daily_profit_rate), + 'created_at': account.created_at.isoformat(), + 'updated_at': account.updated_at.isoformat() + } + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/simulation/positions', methods=['GET']) +@login_required +def get_simulation_positions(): + """获取模拟持仓列表(优化版本,使用缓存的价格数据)""" + try: + account = get_or_create_simulation_account(current_user.id) + + # 直接获取持仓数据,不实时更新价格(由后台脚本负责) + positions = SimulationPosition.query.filter_by(account_id=account.id).all() + + positions_data = [] + for position in positions: + positions_data.append({ + 'id': position.id, + 'stock_code': position.stock_code, + 'stock_name': position.stock_name, + 'position_qty': position.position_qty, + 'available_qty': position.available_qty, + 'frozen_qty': position.frozen_qty, + 'avg_cost': float(position.avg_cost), + 'current_price': float(position.current_price or 0), + 'market_value': float(position.market_value or 0), + 'profit': float(position.profit or 0), + 'profit_rate': float(position.profit_rate or 0), + 'today_profit': float(position.today_profit or 0), + 'today_profit_rate': float(position.today_profit_rate or 0), + 'updated_at': position.updated_at.isoformat() + }) + + return jsonify({ + 'success': True, + 'data': positions_data + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/simulation/orders', methods=['GET']) +@login_required +def get_simulation_orders(): + """获取模拟订单列表""" + try: + account = get_or_create_simulation_account(current_user.id) + + # 获取查询参数 + status = request.args.get('status') # 订单状态筛选 + date_str = request.args.get('date') # 日期筛选 + limit = request.args.get('limit', 50, type=int) + + query = SimulationOrder.query.filter_by(account_id=account.id) + + if status: + query = query.filter_by(status=status) + + if date_str: + try: + date = datetime.strptime(date_str, '%Y-%m-%d').date() + start_time = datetime.combine(date, dt_time(0, 0, 0)) + end_time = datetime.combine(date, dt_time(23, 59, 59)) + query = query.filter(SimulationOrder.order_time.between(start_time, end_time)) + except ValueError: + pass + + orders = query.order_by(SimulationOrder.order_time.desc()).limit(limit).all() + + orders_data = [] + for order in orders: + orders_data.append({ + 'id': order.id, + 'order_no': order.order_no, + 'stock_code': order.stock_code, + 'stock_name': order.stock_name, + 'order_type': order.order_type, + 'price_type': order.price_type, + 'order_price': float(order.order_price) if order.order_price else None, + 'order_qty': order.order_qty, + 'filled_qty': order.filled_qty, + 'filled_price': float(order.filled_price) if order.filled_price else None, + 'filled_amount': float(order.filled_amount) if order.filled_amount else None, + 'commission': float(order.commission), + 'stamp_tax': float(order.stamp_tax), + 'transfer_fee': float(order.transfer_fee), + 'total_fee': float(order.total_fee), + 'status': order.status, + 'reject_reason': order.reject_reason, + 'order_time': order.order_time.isoformat(), + 'filled_time': order.filled_time.isoformat() if order.filled_time else None + }) + + return jsonify({ + 'success': True, + 'data': orders_data + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/simulation/place-order', methods=['POST']) +@login_required +def place_simulation_order(): + """下单""" + try: + # 移除交易时间检查,允许7x24小时下单 + # 非交易时间下的单子会保持PENDING状态,等待行情数据 + + data = request.get_json() + stock_code = data.get('stock_code') + order_type = data.get('order_type') # BUY/SELL + order_qty = data.get('order_qty') + price_type = data.get('price_type', 'MARKET') # 目前只支持市价单 + + # 标准化股票代码格式 + if stock_code and '.' not in stock_code: + # 如果没有后缀,根据股票代码添加后缀 + if stock_code.startswith('6'): + stock_code = f"{stock_code}.SH" + elif stock_code.startswith('0') or stock_code.startswith('3'): + stock_code = f"{stock_code}.SZ" + + # 参数验证 + if not all([stock_code, order_type, order_qty]): + return jsonify({'success': False, 'error': '缺少必要参数'}), 400 + + if order_type not in ['BUY', 'SELL']: + return jsonify({'success': False, 'error': '订单类型错误'}), 400 + + order_qty = int(order_qty) + if order_qty <= 0 or order_qty % 100 != 0: + return jsonify({'success': False, 'error': '下单数量必须为100的整数倍'}), 400 + + # 获取账户 + account = get_or_create_simulation_account(current_user.id) + + # 获取股票信息 + stock_name = None + with engine.connect() as conn: + result = conn.execute(text( + "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" + ), {"code": stock_code.split('.')[0]}).fetchone() + if result: + stock_name = result[0] + + # 创建订单 + order = SimulationOrder( + account_id=account.id, + order_no=f"O{int(beijing_now().timestamp() * 1000000)}", + stock_code=stock_code, + stock_name=stock_name, + order_type=order_type, + price_type=price_type, + order_qty=order_qty, + status='PENDING' + ) + + db.session.add(order) + db.session.commit() + + # 执行订单 + print(f"🔍 调试:开始执行订单,股票代码: {order.stock_code}, 订单类型: {order.order_type}") + success = execute_simulation_order(order) + print(f"🔍 调试:订单执行结果: {success}, 订单状态: {order.status}") + + if success: + # 重新查询订单状态,因为可能在execute_simulation_order中被修改 + db.session.refresh(order) + + if order.status == 'FILLED': + return jsonify({ + 'success': True, + 'message': '订单执行成功,已成交', + 'data': { + 'order_no': order.order_no, + 'status': 'FILLED', + 'filled_price': float(order.filled_price) if order.filled_price else None, + 'filled_qty': order.filled_qty, + 'filled_amount': float(order.filled_amount) if order.filled_amount else None, + 'total_fee': float(order.total_fee) + } + }) + elif order.status == 'PENDING': + return jsonify({ + 'success': True, + 'message': '订单提交成功,等待行情数据成交', + 'data': { + 'order_no': order.order_no, + 'status': 'PENDING', + 'order_qty': order.order_qty, + 'order_price': float(order.order_price) if order.order_price else None + } + }) + else: + return jsonify({ + 'success': False, + 'error': order.reject_reason or '订单状态异常' + }), 400 + else: + return jsonify({ + 'success': False, + 'error': order.reject_reason or '订单执行失败' + }), 400 + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/simulation/cancel-order/', methods=['POST']) +@login_required +def cancel_simulation_order(order_id): + """撤销订单""" + try: + account = get_or_create_simulation_account(current_user.id) + + order = SimulationOrder.query.filter_by( + id=order_id, + account_id=account.id, + status='PENDING' + ).first() + + if not order: + return jsonify({'success': False, 'error': '订单不存在或无法撤销'}), 404 + + order.status = 'CANCELLED' + order.cancel_time = beijing_now() + + db.session.commit() + + return jsonify({ + 'success': True, + 'message': '订单已撤销' + }) + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/simulation/transactions', methods=['GET']) +@login_required +def get_simulation_transactions(): + """获取成交记录""" + try: + account = get_or_create_simulation_account(current_user.id) + + # 获取查询参数 + date_str = request.args.get('date') + limit = request.args.get('limit', 100, type=int) + + query = SimulationTransaction.query.filter_by(account_id=account.id) + + if date_str: + try: + date = datetime.strptime(date_str, '%Y-%m-%d').date() + start_time = datetime.combine(date, dt_time(0, 0, 0)) + end_time = datetime.combine(date, dt_time(23, 59, 59)) + query = query.filter(SimulationTransaction.transaction_time.between(start_time, end_time)) + except ValueError: + pass + + transactions = query.order_by(SimulationTransaction.transaction_time.desc()).limit(limit).all() + + transactions_data = [] + for trans in transactions: + transactions_data.append({ + 'id': trans.id, + 'transaction_no': trans.transaction_no, + 'stock_code': trans.stock_code, + 'stock_name': trans.stock_name, + 'transaction_type': trans.transaction_type, + 'transaction_price': float(trans.transaction_price), + 'transaction_qty': trans.transaction_qty, + 'transaction_amount': float(trans.transaction_amount), + 'commission': float(trans.commission), + 'stamp_tax': float(trans.stamp_tax), + 'transfer_fee': float(trans.transfer_fee), + 'total_fee': float(trans.total_fee), + 'transaction_time': trans.transaction_time.isoformat(), + 'settlement_date': trans.settlement_date.isoformat() if trans.settlement_date else None + }) + + return jsonify({ + 'success': True, + 'data': transactions_data + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +def get_simulation_statistics(): + """获取模拟交易统计""" + try: + account = get_or_create_simulation_account(current_user.id) + + # 获取统计时间范围 + days = request.args.get('days', 30, type=int) + end_date = beijing_now().date() + start_date = end_date - timedelta(days=days) + + # 查询日统计数据 + daily_stats = SimulationDailyStats.query.filter( + SimulationDailyStats.account_id == account.id, + SimulationDailyStats.stat_date >= start_date, + SimulationDailyStats.stat_date <= end_date + ).order_by(SimulationDailyStats.stat_date).all() + + # 查询总体统计 + total_transactions = SimulationTransaction.query.filter_by(account_id=account.id).count() + win_transactions = SimulationTransaction.query.filter( + SimulationTransaction.account_id == account.id, + SimulationTransaction.transaction_type == 'SELL' + ).all() + + win_count = 0 + total_profit = Decimal('0') + for trans in win_transactions: + # 查找对应的买入记录计算盈亏 + position = SimulationPosition.query.filter_by( + account_id=account.id, + stock_code=trans.stock_code + ).first() + if position and trans.transaction_price > position.avg_cost: + win_count += 1 + profit = (trans.transaction_price - position.avg_cost) * trans.transaction_qty if position else 0 + total_profit += profit + + # 构建日收益曲线 + daily_returns = [] + for stat in daily_stats: + daily_returns.append({ + 'date': stat.stat_date.isoformat(), + 'daily_profit': float(stat.daily_profit), + 'daily_profit_rate': float(stat.daily_profit_rate), + 'total_profit': float(stat.total_profit), + 'total_profit_rate': float(stat.total_profit_rate), + 'closing_assets': float(stat.closing_assets) + }) + + return jsonify({ + 'success': True, + 'data': { + 'summary': { + 'total_transactions': total_transactions, + 'win_count': win_count, + 'win_rate': (win_count / len(win_transactions) * 100) if win_transactions else 0, + 'total_profit': float(total_profit), + 'average_profit_per_trade': float(total_profit / len(win_transactions)) if win_transactions else 0 + }, + 'daily_returns': daily_returns + } + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/simulation/t1-settlement', methods=['POST']) +@login_required +def trigger_t1_settlement(): + """手动触发T+1结算""" + try: + # 导入后台处理器的函数 + from simulation_background_processor import process_t1_settlement + + # 执行T+1结算 + process_t1_settlement() + + return jsonify({ + 'success': True, + 'message': 'T+1结算执行成功' + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/simulation/debug-positions', methods=['GET']) +@login_required +def debug_positions(): + """调试接口:查看持仓数据""" + try: + account = get_or_create_simulation_account(current_user.id) + + positions = SimulationPosition.query.filter_by(account_id=account.id).all() + + positions_data = [] + for position in positions: + positions_data.append({ + 'stock_code': position.stock_code, + 'stock_name': position.stock_name, + 'position_qty': position.position_qty, + 'available_qty': position.available_qty, + 'frozen_qty': position.frozen_qty, + 'avg_cost': float(position.avg_cost), + 'current_price': float(position.current_price or 0) + }) + + return jsonify({ + 'success': True, + 'data': positions_data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/simulation/debug-transactions', methods=['GET']) +@login_required +def debug_transactions(): + """调试接口:查看成交记录数据""" + try: + account = get_or_create_simulation_account(current_user.id) + + transactions = SimulationTransaction.query.filter_by(account_id=account.id).all() + + transactions_data = [] + for trans in transactions: + transactions_data.append({ + 'id': trans.id, + 'transaction_no': trans.transaction_no, + 'stock_code': trans.stock_code, + 'stock_name': trans.stock_name, + 'transaction_type': trans.transaction_type, + 'transaction_price': float(trans.transaction_price), + 'transaction_qty': trans.transaction_qty, + 'transaction_amount': float(trans.transaction_amount), + 'commission': float(trans.commission), + 'stamp_tax': float(trans.stamp_tax), + 'transfer_fee': float(trans.transfer_fee), + 'total_fee': float(trans.total_fee), + 'transaction_time': trans.transaction_time.isoformat(), + 'settlement_date': trans.settlement_date.isoformat() if trans.settlement_date else None + }) + + return jsonify({ + 'success': True, + 'data': transactions_data, + 'count': len(transactions_data) + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/simulation/daily-settlement', methods=['POST']) +@login_required +def trigger_daily_settlement(): + """手动触发日结算""" + try: + # 导入后台处理器的函数 + from simulation_background_processor import generate_daily_stats + + # 执行日结算 + generate_daily_stats() + + return jsonify({ + 'success': True, + 'message': '日结算执行成功' + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/simulation/reset', methods=['POST']) +@login_required +def reset_simulation_account(): + """重置模拟账户""" + try: + account = SimulationAccount.query.filter_by(user_id=current_user.id).first() + + if account: + # 删除所有相关数据 + SimulationPosition.query.filter_by(account_id=account.id).delete() + SimulationOrder.query.filter_by(account_id=account.id).delete() + SimulationTransaction.query.filter_by(account_id=account.id).delete() + SimulationDailyStats.query.filter_by(account_id=account.id).delete() + + # 重置账户数据 + account.available_cash = account.initial_capital + account.frozen_cash = Decimal('0') + account.position_value = Decimal('0') + account.total_assets = account.initial_capital + account.total_profit = Decimal('0') + account.total_profit_rate = Decimal('0') + account.daily_profit = Decimal('0') + account.daily_profit_rate = Decimal('0') + account.updated_at = beijing_now() + + db.session.commit() + + return jsonify({ + 'success': True, + 'message': '模拟账户已重置' + }) + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +if __name__ == '__main__': + # 创建数据库表 + with app.app_context(): + try: + db.create_all() + # 安全地初始化订阅套餐 + initialize_subscription_plans_safe() + except Exception as e: + app.logger.error(f"数据库初始化失败: {e}") + + # 初始化事件轮询机制(WebSocket 推送) + initialize_event_polling() + + # 使用 socketio.run 替代 app.run 以支持 WebSocket + socketio.run(app, host='0.0.0.0', port=5001, debug=False, allow_unsafe_werkzeug=True) \ No newline at end of file diff --git a/app.py.backup_20251114_145340 b/app.py.backup_20251114_145340 new file mode 100644 index 00000000..b5922c1a --- /dev/null +++ b/app.py.backup_20251114_145340 @@ -0,0 +1,12556 @@ +import base64 +import csv +import io +import os +import time +import urllib +import uuid +from functools import wraps +import qrcode +from flask_mail import Mail, Message +from flask_socketio import SocketIO, emit, join_room, leave_room +import pytz +import requests +from celery import Celery +from flask_compress import Compress +from pathlib import Path +import json +from sqlalchemy import Column, Integer, String, Boolean, DateTime, create_engine, text, func, or_ +from flask import Flask, render_template, request, jsonify, redirect, url_for, flash, session, render_template_string, \ + current_app, make_response +from flask_sqlalchemy import SQLAlchemy +from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user +import random +from werkzeug.security import generate_password_hash, check_password_hash +import re +import string +from datetime import datetime, timedelta, time as dt_time, date +from clickhouse_driver import Client as Cclient +from flask_cors import CORS + +from collections import defaultdict +from functools import lru_cache +import jieba +import jieba.analyse +from flask_cors import cross_origin +from tencentcloud.common import credential +from tencentcloud.common.profile.client_profile import ClientProfile +from tencentcloud.common.profile.http_profile import HttpProfile +from tencentcloud.sms.v20210111 import sms_client, models +from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException +from sqlalchemy import text, desc, and_ +import pandas as pd +from decimal import Decimal +from apscheduler.schedulers.background import BackgroundScheduler + +# 交易日数据缓存 +trading_days = [] +trading_days_set = set() + + +def load_trading_days(): + """加载交易日数据""" + global trading_days, trading_days_set + try: + with open('tdays.csv', 'r') as f: + reader = csv.DictReader(f) + for row in reader: + date_str = row['DateTime'] + # 解析日期 (格式: 2010/1/4) + date = datetime.strptime(date_str, '%Y/%m/%d').date() + trading_days.append(date) + trading_days_set.add(date) + + # 排序交易日 + trading_days.sort() + print(f"成功加载 {len(trading_days)} 个交易日数据") + except Exception as e: + print(f"加载交易日数据失败: {e}") + + +def get_trading_day_near_date(target_date): + """ + 获取距离目标日期最近的交易日 + 如果目标日期是交易日,返回该日期 + 如果不是,返回下一个交易日 + """ + if not trading_days: + load_trading_days() + + if not trading_days: + return None + + # 如果目标日期是datetime,转换为date + if isinstance(target_date, datetime): + target_date = target_date.date() + + # 检查目标日期是否是交易日 + if target_date in trading_days_set: + return target_date + + # 查找下一个交易日 + for trading_day in trading_days: + if trading_day >= target_date: + return trading_day + + # 如果没有找到,返回最后一个交易日 + return trading_days[-1] if trading_days else None + + +# 应用启动时加载交易日数据 +load_trading_days() + +engine = create_engine( + "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock?charset=utf8mb4", + echo=False, + pool_size=10, + pool_recycle=3600, + pool_pre_ping=True, + pool_timeout=30, + max_overflow=20 +) +engine_med = create_engine( + "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/med?charset=utf8mb4", + echo=False, + pool_size=5, + pool_recycle=3600, + pool_pre_ping=True, + pool_timeout=30, + max_overflow=10 +) +engine_2 = create_engine( + "mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/valuefrontier?charset=utf8mb4", + echo=False, + pool_size=5, + pool_recycle=3600, + pool_pre_ping=True, + pool_timeout=30, + max_overflow=10 +) +app = Flask(__name__) +# 存储验证码的临时字典(生产环境应使用Redis) +verification_codes = {} +wechat_qr_sessions = {} +# 腾讯云短信配置 +SMS_SECRET_ID = 'AKID2we9TacdTAhCjCSYTErHVimeJo9Yr00s' +SMS_SECRET_KEY = 'pMlBWijlkgT9fz5ziEXdWEnAPTJzRfkf' +SMS_SDK_APP_ID = "1400972398" +SMS_SIGN_NAME = "价值前沿科技" +SMS_TEMPLATE_REGISTER = "2386557" # 注册模板 +SMS_TEMPLATE_LOGIN = "2386540" # 登录模板 + +# 微信开放平台配置 +WECHAT_APPID = 'wxa8d74c47041b5f87' +WECHAT_APPSECRET = 'eedef95b11787fd7ca7f1acc6c9061bc' +WECHAT_REDIRECT_URI = 'http://valuefrontier.cn/api/auth/wechat/callback' + +# 邮件服务配置(QQ企业邮箱) +MAIL_SERVER = 'smtp.exmail.qq.com' +MAIL_PORT = 465 +MAIL_USE_SSL = True +MAIL_USE_TLS = False +MAIL_USERNAME = 'admin@valuefrontier.cn' +MAIL_PASSWORD = 'QYncRu6WUdASvTg4' +MAIL_DEFAULT_SENDER = 'admin@valuefrontier.cn' + +# Session和安全配置 +app.config['SECRET_KEY'] = ''.join(random.choices(string.ascii_letters + string.digits, k=32)) +app.config['SESSION_COOKIE_SECURE'] = False # 如果生产环境使用HTTPS,应设为True +app.config['SESSION_COOKIE_HTTPONLY'] = True # 生产环境应设为True,防止XSS攻击 +app.config['SESSION_COOKIE_SAMESITE'] = 'Lax' # 使用'Lax'以平衡安全性和功能性 +app.config['SESSION_COOKIE_DOMAIN'] = None # 不限制域名 +app.config['SESSION_COOKIE_PATH'] = '/' # 设置cookie路径 +app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(days=7) # session持续7天 +app.config['REMEMBER_COOKIE_DURATION'] = timedelta(days=30) # 记住登录30天 +app.config['REMEMBER_COOKIE_SECURE'] = False # 记住登录cookie不要求HTTPS +app.config['REMEMBER_COOKIE_HTTPONLY'] = False # 允许JavaScript访问 + +# 配置邮件 +app.config['MAIL_SERVER'] = MAIL_SERVER +app.config['MAIL_PORT'] = MAIL_PORT +app.config['MAIL_USE_SSL'] = MAIL_USE_SSL +app.config['MAIL_USE_TLS'] = MAIL_USE_TLS +app.config['MAIL_USERNAME'] = MAIL_USERNAME +app.config['MAIL_PASSWORD'] = MAIL_PASSWORD +app.config['MAIL_DEFAULT_SENDER'] = MAIL_DEFAULT_SENDER + +# 允许前端跨域访问 - 修复CORS配置 +try: + CORS(app, + origins=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5173", "https://valuefrontier.cn", + "http://valuefrontier.cn"], # 明确指定允许的源 + methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-Requested-With"], + supports_credentials=True, # 允许携带凭据 + expose_headers=["Content-Type", "Authorization"]) +except ImportError: + pass # 如果未安装flask_cors则跳过 + +# 初始化 Flask-Login +login_manager = LoginManager() +login_manager.init_app(app) +login_manager.login_view = 'login' +login_manager.login_message = '请先登录访问此页面' +login_manager.remember_cookie_duration = timedelta(days=30) # 记住登录持续时间 +Compress(app) +MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB max file size +# Configure Flask-Compress +app.config['COMPRESS_ALGORITHM'] = ['gzip', 'br'] +app.config['COMPRESS_MIMETYPES'] = [ + 'text/html', + 'text/css', + 'text/xml', + 'application/json', + 'application/javascript', + 'application/x-javascript' +] +app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock?charset=utf8mb4' +app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False +app.config['SQLALCHEMY_ENGINE_OPTIONS'] = { + 'pool_size': 10, + 'pool_recycle': 3600, + 'pool_pre_ping': True, + 'pool_timeout': 30, + 'max_overflow': 20 +} +# Cache directory setup +CACHE_DIR = Path('cache') +CACHE_DIR.mkdir(exist_ok=True) + + +def beijing_now(): + # 使用 pytz 处理时区,但返回 naive datetime(适合数据库存储) + beijing_tz = pytz.timezone('Asia/Shanghai') + return datetime.now(beijing_tz).replace(tzinfo=None) + + +# 检查用户是否登录的装饰器 +def login_required(f): + @wraps(f) + def decorated_function(*args, **kwargs): + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + return f(*args, **kwargs) + + return decorated_function + + +# Memory management constants +MAX_MEMORY_PERCENT = 75 +MEMORY_CHECK_INTERVAL = 300 +MAX_CACHE_ITEMS = 50 +db = SQLAlchemy(app) + +# 初始化邮件服务 +mail = Mail(app) + +# 初始化 Flask-SocketIO(用于实时事件推送) +socketio = SocketIO( + app, + cors_allowed_origins=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5173", + "https://valuefrontier.cn", "http://valuefrontier.cn"], + async_mode='gevent', + logger=True, + engineio_logger=False, + ping_timeout=120, # 心跳超时时间(秒),客户端120秒内无响应才断开 + ping_interval=25 # 心跳检测间隔(秒),每25秒发送一次ping +) + + +@login_manager.user_loader +def load_user(user_id): + """Flask-Login 用户加载回调""" + try: + return User.query.get(int(user_id)) + except Exception as e: + app.logger.error(f"用户加载错误: {e}") + return None + + +# 全局错误处理器 - 确保API接口始终返回JSON +@app.errorhandler(404) +def not_found_error(error): + """404错误处理""" + if request.path.startswith('/api/'): + return jsonify({'success': False, 'error': '接口不存在'}), 404 + return error + + +@app.errorhandler(500) +def internal_error(error): + """500错误处理""" + db.session.rollback() + if request.path.startswith('/api/'): + return jsonify({'success': False, 'error': '服务器内部错误'}), 500 + return error + + +@app.errorhandler(405) +def method_not_allowed_error(error): + """405错误处理""" + if request.path.startswith('/api/'): + return jsonify({'success': False, 'error': '请求方法不被允许'}), 405 + return error + + +class Post(db.Model): + """帖子模型""" + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + + # 内容 + title = db.Column(db.String(200)) # 标题(可选) + content = db.Column(db.Text, nullable=False) # 内容 + content_type = db.Column(db.String(20), default='text') # 内容类型:text/rich_text/link + + # 时间 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + # 统计 + likes_count = db.Column(db.Integer, default=0) + comments_count = db.Column(db.Integer, default=0) + view_count = db.Column(db.Integer, default=0) + + # 状态 + status = db.Column(db.String(20), default='active') # active/hidden/deleted + is_top = db.Column(db.Boolean, default=False) # 是否置顶 + + # 关系 + user = db.relationship('User', backref='posts') + likes = db.relationship('PostLike', backref='post', lazy='dynamic') + comments = db.relationship('Comment', backref='post', lazy='dynamic') + + +class Comment(db.Model): + """帖子评论模型""" + id = db.Column(db.Integer, primary_key=True) + post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + + # 内容 + content = db.Column(db.Text, nullable=False) + parent_id = db.Column(db.Integer, db.ForeignKey('comment.id')) + + # 时间 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + # 统计 + likes_count = db.Column(db.Integer, default=0) + + # 状态 + status = db.Column(db.String(20), default='active') # active/hidden/deleted + + # 关系 + user = db.relationship('User', backref='comments') + replies = db.relationship('Comment', backref=db.backref('parent', remote_side=[id])) + + +class User(UserMixin, db.Model): + """用户模型 - 完全匹配现有数据库表结构""" + __tablename__ = 'user' + + # 主键 + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + + # 基础账号信息 + username = db.Column(db.String(80), unique=True, nullable=False) + email = db.Column(db.String(120), unique=True, nullable=True) + password_hash = db.Column(db.String(255), nullable=True) + email_confirmed = db.Column(db.Boolean, nullable=True, default=True) + + # 时间字段 + created_at = db.Column(db.DateTime, nullable=True, default=beijing_now) + last_seen = db.Column(db.DateTime, nullable=True, default=beijing_now) + + # 账号状态 + status = db.Column(db.String(20), nullable=True, default='active') + + # 个人资料信息 + nickname = db.Column(db.String(30), nullable=True) + avatar_url = db.Column(db.String(200), nullable=True) + banner_url = db.Column(db.String(200), nullable=True) + bio = db.Column(db.String(200), nullable=True) + gender = db.Column(db.String(10), nullable=True) + birth_date = db.Column(db.Date, nullable=True) + location = db.Column(db.String(100), nullable=True) + + # 联系方式 + phone = db.Column(db.String(20), nullable=True) + wechat_id = db.Column(db.String(80), nullable=True) # 微信号 + + # 实名认证 + real_name = db.Column(db.String(30), nullable=True) + id_number = db.Column(db.String(18), nullable=True) + is_verified = db.Column(db.Boolean, nullable=True, default=False) + verify_time = db.Column(db.DateTime, nullable=True) + + # 投资偏好 + trading_experience = db.Column(db.String(200), nullable=True) + investment_style = db.Column(db.String(50), nullable=True) + risk_preference = db.Column(db.String(20), nullable=True) + investment_amount = db.Column(db.String(20), nullable=True) + preferred_markets = db.Column(db.String(200), nullable=True) + + # 社区数据 + user_level = db.Column(db.Integer, nullable=True, default=1) + reputation_score = db.Column(db.Integer, nullable=True, default=0) + contribution_point = db.Column(db.Integer, nullable=True, default=0) + post_count = db.Column(db.Integer, nullable=True, default=0) + comment_count = db.Column(db.Integer, nullable=True, default=0) + follower_count = db.Column(db.Integer, nullable=True, default=0) + following_count = db.Column(db.Integer, nullable=True, default=0) + + # 创作者相关 + is_creator = db.Column(db.Boolean, nullable=True, default=False) + creator_type = db.Column(db.String(20), nullable=True) + creator_tags = db.Column(db.String(200), nullable=True) + + # 通知设置 + email_notifications = db.Column(db.Boolean, nullable=True, default=True) + sms_notifications = db.Column(db.Boolean, nullable=True, default=False) + wechat_notifications = db.Column(db.Boolean, nullable=True, default=False) + notification_preferences = db.Column(db.String(500), nullable=True) + + # 隐私和界面设置 + privacy_level = db.Column(db.String(20), nullable=True, default='public') + theme_preference = db.Column(db.String(20), nullable=True, default='light') + blocked_keywords = db.Column(db.String(500), nullable=True) + + # 手机验证相关 + phone_confirmed = db.Column(db.Boolean, nullable=True, default=False) # 注意:原表中是blob,这里改为Boolean更合理 + phone_confirm_time = db.Column(db.DateTime, nullable=True) + + # 微信登录相关字段 + wechat_union_id = db.Column(db.String(100), nullable=True) # 微信UnionID + wechat_open_id = db.Column(db.String(100), nullable=True) # 微信OpenID + + def __init__(self, username, email=None, password=None, phone=None): + """初始化用户""" + self.username = username + if email: + self.email = email + if phone: + self.phone = phone + if password: + self.set_password(password) + self.nickname = username # 默认昵称为用户名 + self.created_at = beijing_now() + self.last_seen = beijing_now() + + def set_password(self, password): + """设置密码""" + if password: + self.password_hash = generate_password_hash(password) + + def check_password(self, password): + """验证密码""" + if not password or not self.password_hash: + return False + return check_password_hash(self.password_hash, password) + + def update_last_seen(self): + """更新最后活跃时间""" + self.last_seen = beijing_now() + db.session.commit() + + def confirm_email(self): + """确认邮箱""" + self.email_confirmed = True + db.session.commit() + + def confirm_phone(self): + """确认手机号""" + self.phone_confirmed = True + self.phone_confirm_time = beijing_now() + db.session.commit() + + def bind_wechat(self, open_id, union_id=None, wechat_info=None): + """绑定微信账号""" + self.wechat_open_id = open_id + if union_id: + self.wechat_union_id = union_id + + # 如果提供了微信用户信息,更新头像和昵称 + if wechat_info: + if not self.avatar_url and wechat_info.get('headimgurl'): + self.avatar_url = wechat_info['headimgurl'] + if not self.nickname and wechat_info.get('nickname'): + # 确保昵称编码正确且长度合理 + nickname = self._sanitize_nickname(wechat_info['nickname']) + self.nickname = nickname + + db.session.commit() + + def _sanitize_nickname(self, nickname): + """清理和验证昵称""" + if not nickname: + return '微信用户' + + try: + # 确保是正确的UTF-8字符串 + sanitized = str(nickname).strip() + + # 移除可能的控制字符 + import re + sanitized = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', sanitized) + + # 限制长度(避免过长的昵称) + if len(sanitized) > 50: + sanitized = sanitized[:47] + '...' + + # 如果清理后为空,使用默认值 + if not sanitized: + sanitized = '微信用户' + + return sanitized + except Exception as e: + return '微信用户' + + def unbind_wechat(self): + """解绑微信账号""" + self.wechat_open_id = None + self.wechat_union_id = None + db.session.commit() + + def increment_post_count(self): + """增加发帖数""" + self.post_count = (self.post_count or 0) + 1 + db.session.commit() + + def increment_comment_count(self): + """增加评论数""" + self.comment_count = (self.comment_count or 0) + 1 + db.session.commit() + + def add_reputation(self, points): + """增加声誉分数""" + self.reputation_score = (self.reputation_score or 0) + points + db.session.commit() + + def to_dict(self, include_sensitive=False): + """转换为字典""" + data = { + 'id': self.id, + 'username': self.username, + 'nickname': self.nickname or self.username, + 'avatar_url': self.avatar_url, + 'banner_url': self.banner_url, + 'bio': self.bio, + 'gender': self.gender, + 'location': self.location, + 'user_level': self.user_level or 1, + 'reputation_score': self.reputation_score or 0, + 'contribution_point': self.contribution_point or 0, + 'post_count': self.post_count or 0, + 'comment_count': self.comment_count or 0, + 'follower_count': self.follower_count or 0, + 'following_count': self.following_count or 0, + 'is_creator': self.is_creator or False, + 'creator_type': self.creator_type, + 'creator_tags': self.creator_tags, + 'is_verified': self.is_verified or False, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'last_seen': self.last_seen.isoformat() if self.last_seen else None, + 'status': self.status, + 'has_wechat': bool(self.wechat_open_id), + 'is_authenticated': True + } + + # 敏感信息只在需要时包含 + if include_sensitive: + data.update({ + 'email': self.email, + 'phone': self.phone, + 'email_confirmed': self.email_confirmed, + 'phone_confirmed': self.phone_confirmed, + 'real_name': self.real_name, + 'birth_date': self.birth_date.isoformat() if self.birth_date else None, + 'trading_experience': self.trading_experience, + 'investment_style': self.investment_style, + 'risk_preference': self.risk_preference, + 'investment_amount': self.investment_amount, + 'preferred_markets': self.preferred_markets, + 'email_notifications': self.email_notifications, + 'sms_notifications': self.sms_notifications, + 'wechat_notifications': self.wechat_notifications, + 'privacy_level': self.privacy_level, + 'theme_preference': self.theme_preference + }) + + return data + + def to_public_dict(self): + """公开信息字典(用于显示给其他用户)""" + return { + 'id': self.id, + 'username': self.username, + 'nickname': self.nickname or self.username, + 'avatar_url': self.avatar_url, + 'bio': self.bio, + 'user_level': self.user_level or 1, + 'reputation_score': self.reputation_score or 0, + 'post_count': self.post_count or 0, + 'follower_count': self.follower_count or 0, + 'is_creator': self.is_creator or False, + 'creator_type': self.creator_type, + 'is_verified': self.is_verified or False, + 'created_at': self.created_at.isoformat() if self.created_at else None + } + + @staticmethod + def find_by_login_info(login_info): + """根据登录信息查找用户(支持用户名、邮箱、手机号)""" + return User.query.filter( + db.or_( + User.username == login_info, + User.email == login_info, + User.phone == login_info + ) + ).first() + + @staticmethod + def find_by_wechat_openid(open_id): + """根据微信OpenID查找用户""" + return User.query.filter_by(wechat_open_id=open_id).first() + + @staticmethod + def find_by_wechat_unionid(union_id): + """根据微信UnionID查找用户""" + return User.query.filter_by(wechat_union_id=union_id).first() + + @staticmethod + def is_username_taken(username): + """检查用户名是否已被使用""" + return User.query.filter_by(username=username).first() is not None + + @staticmethod + def is_email_taken(email): + """检查邮箱是否已被使用""" + return User.query.filter_by(email=email).first() is not None + + @staticmethod + def is_phone_taken(phone): + """检查手机号是否已被使用""" + return User.query.filter_by(phone=phone).first() is not None + + def __repr__(self): + return f'' + + +# ============================================ +# 订阅功能模块(安全版本 - 独立表) +# ============================================ +class UserSubscription(db.Model): + """用户订阅表 - 独立于现有User表""" + __tablename__ = 'user_subscriptions' + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + user_id = db.Column(db.Integer, nullable=False, unique=True, index=True) + subscription_type = db.Column(db.String(10), nullable=False, default='free') + subscription_status = db.Column(db.String(20), nullable=False, default='active') + start_date = db.Column(db.DateTime, nullable=True) + end_date = db.Column(db.DateTime, nullable=True) + billing_cycle = db.Column(db.String(10), nullable=True) + auto_renewal = db.Column(db.Boolean, nullable=False, default=False) + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + def is_active(self): + if self.subscription_status != 'active': + return False + if self.subscription_type == 'free': + return True + if self.end_date: + try: + now = beijing_now() + if self.end_date < now: + return False + except Exception as e: + return False + return True + + def days_left(self): + if self.subscription_type == 'free' or not self.end_date: + return 999 + try: + now = beijing_now() + delta = self.end_date - now + return max(0, delta.days) + except Exception as e: + return 0 + + def to_dict(self): + return { + 'type': self.subscription_type, + 'status': self.subscription_status, + 'is_active': self.is_active(), + 'days_left': self.days_left(), + 'start_date': self.start_date.isoformat() if self.start_date else None, + 'end_date': self.end_date.isoformat() if self.end_date else None, + 'billing_cycle': self.billing_cycle, + 'auto_renewal': self.auto_renewal + } + + +class SubscriptionPlan(db.Model): + """订阅套餐表""" + __tablename__ = 'subscription_plans' + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + name = db.Column(db.String(50), nullable=False, unique=True) + display_name = db.Column(db.String(100), nullable=False) + description = db.Column(db.Text, nullable=True) + monthly_price = db.Column(db.Numeric(10, 2), nullable=False) + yearly_price = db.Column(db.Numeric(10, 2), nullable=False) + features = db.Column(db.Text, nullable=True) + pricing_options = db.Column(db.Text, nullable=True) # JSON格式:[{"months": 1, "price": 99}, {"months": 12, "price": 999}] + is_active = db.Column(db.Boolean, default=True) + sort_order = db.Column(db.Integer, default=0) + created_at = db.Column(db.DateTime, default=beijing_now) + + def to_dict(self): + # 解析pricing_options(如果存在) + pricing_opts = None + if self.pricing_options: + try: + pricing_opts = json.loads(self.pricing_options) + except: + pricing_opts = None + + # 如果没有pricing_options,则从monthly_price和yearly_price生成默认选项 + if not pricing_opts: + pricing_opts = [ + { + 'months': 1, + 'price': float(self.monthly_price) if self.monthly_price else 0, + 'label': '月付', + 'cycle_key': 'monthly' + }, + { + 'months': 12, + 'price': float(self.yearly_price) if self.yearly_price else 0, + 'label': '年付', + 'cycle_key': 'yearly', + 'discount_percent': 20 # 年付默认20%折扣 + } + ] + + return { + 'id': self.id, + 'name': self.name, + 'display_name': self.display_name, + 'description': self.description, + 'monthly_price': float(self.monthly_price) if self.monthly_price else 0, + 'yearly_price': float(self.yearly_price) if self.yearly_price else 0, + 'pricing_options': pricing_opts, # 新增:灵活计费周期选项 + 'features': json.loads(self.features) if self.features else [], + 'is_active': self.is_active, + 'sort_order': self.sort_order + } + + +class PaymentOrder(db.Model): + """支付订单表""" + __tablename__ = 'payment_orders' + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + order_no = db.Column(db.String(32), unique=True, nullable=False) + user_id = db.Column(db.Integer, nullable=False) + plan_name = db.Column(db.String(20), nullable=False) + billing_cycle = db.Column(db.String(10), nullable=False) + amount = db.Column(db.Numeric(10, 2), nullable=False) + wechat_order_id = db.Column(db.String(64), nullable=True) + prepay_id = db.Column(db.String(64), nullable=True) + qr_code_url = db.Column(db.String(200), nullable=True) + status = db.Column(db.String(20), default='pending') + created_at = db.Column(db.DateTime, default=beijing_now) + paid_at = db.Column(db.DateTime, nullable=True) + expired_at = db.Column(db.DateTime, nullable=True) + remark = db.Column(db.String(200), nullable=True) + + def __init__(self, user_id, plan_name, billing_cycle, amount): + self.user_id = user_id + self.plan_name = plan_name + self.billing_cycle = billing_cycle + self.amount = amount + import random + timestamp = int(beijing_now().timestamp() * 1000000) + random_suffix = random.randint(1000, 9999) + self.order_no = f"{timestamp}{user_id:04d}{random_suffix}" + self.expired_at = beijing_now() + timedelta(minutes=30) + + def is_expired(self): + if not self.expired_at: + return False + try: + now = beijing_now() + return now > self.expired_at + except Exception as e: + return False + + def mark_as_paid(self, wechat_order_id, transaction_id=None): + self.status = 'paid' + self.paid_at = beijing_now() + self.wechat_order_id = wechat_order_id + + def to_dict(self): + return { + 'id': self.id, + 'order_no': self.order_no, + 'user_id': self.user_id, + 'plan_name': self.plan_name, + 'billing_cycle': self.billing_cycle, + 'amount': float(self.amount) if self.amount else 0, + 'original_amount': float(self.original_amount) if hasattr(self, 'original_amount') and self.original_amount else None, + 'discount_amount': float(self.discount_amount) if hasattr(self, 'discount_amount') and self.discount_amount else 0, + 'promo_code': self.promo_code.code if hasattr(self, 'promo_code') and self.promo_code else None, + 'is_upgrade': self.is_upgrade if hasattr(self, 'is_upgrade') else False, + 'qr_code_url': self.qr_code_url, + 'status': self.status, + 'is_expired': self.is_expired(), + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'paid_at': self.paid_at.isoformat() if self.paid_at else None, + 'expired_at': self.expired_at.isoformat() if self.expired_at else None, + 'remark': self.remark + } + + +class PromoCode(db.Model): + """优惠码表""" + __tablename__ = 'promo_codes' + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + code = db.Column(db.String(50), unique=True, nullable=False, index=True) + description = db.Column(db.String(200), nullable=True) + + # 折扣类型和值 + discount_type = db.Column(db.String(20), nullable=False) # 'percentage' 或 'fixed_amount' + discount_value = db.Column(db.Numeric(10, 2), nullable=False) + + # 适用范围 + applicable_plans = db.Column(db.String(200), nullable=True) # JSON格式 + applicable_cycles = db.Column(db.String(50), nullable=True) # JSON格式 + min_amount = db.Column(db.Numeric(10, 2), nullable=True) + + # 使用限制 + max_uses = db.Column(db.Integer, nullable=True) + max_uses_per_user = db.Column(db.Integer, default=1) + current_uses = db.Column(db.Integer, default=0) + + # 有效期 + valid_from = db.Column(db.DateTime, nullable=False) + valid_until = db.Column(db.DateTime, nullable=False) + + # 状态 + is_active = db.Column(db.Boolean, default=True) + created_by = db.Column(db.Integer, nullable=True) + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + def to_dict(self): + return { + 'id': self.id, + 'code': self.code, + 'description': self.description, + 'discount_type': self.discount_type, + 'discount_value': float(self.discount_value) if self.discount_value else 0, + 'applicable_plans': json.loads(self.applicable_plans) if self.applicable_plans else None, + 'applicable_cycles': json.loads(self.applicable_cycles) if self.applicable_cycles else None, + 'min_amount': float(self.min_amount) if self.min_amount else None, + 'max_uses': self.max_uses, + 'max_uses_per_user': self.max_uses_per_user, + 'current_uses': self.current_uses, + 'valid_from': self.valid_from.isoformat() if self.valid_from else None, + 'valid_until': self.valid_until.isoformat() if self.valid_until else None, + 'is_active': self.is_active + } + + +class PromoCodeUsage(db.Model): + """优惠码使用记录表""" + __tablename__ = 'promo_code_usage' + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + promo_code_id = db.Column(db.Integer, db.ForeignKey('promo_codes.id'), nullable=False) + user_id = db.Column(db.Integer, nullable=False, index=True) + order_id = db.Column(db.Integer, db.ForeignKey('payment_orders.id'), nullable=False) + + original_amount = db.Column(db.Numeric(10, 2), nullable=False) + discount_amount = db.Column(db.Numeric(10, 2), nullable=False) + final_amount = db.Column(db.Numeric(10, 2), nullable=False) + + used_at = db.Column(db.DateTime, default=beijing_now) + + # 关系 + promo_code = db.relationship('PromoCode', backref='usages') + order = db.relationship('PaymentOrder', backref='promo_usage') + + +class SubscriptionUpgrade(db.Model): + """订阅升级/降级记录表""" + __tablename__ = 'subscription_upgrades' + + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + user_id = db.Column(db.Integer, nullable=False, index=True) + order_id = db.Column(db.Integer, db.ForeignKey('payment_orders.id'), nullable=False) + + # 原订阅信息 + from_plan = db.Column(db.String(20), nullable=False) + from_cycle = db.Column(db.String(10), nullable=False) + from_end_date = db.Column(db.DateTime, nullable=True) + + # 新订阅信息 + to_plan = db.Column(db.String(20), nullable=False) + to_cycle = db.Column(db.String(10), nullable=False) + to_end_date = db.Column(db.DateTime, nullable=False) + + # 价格计算 + remaining_value = db.Column(db.Numeric(10, 2), nullable=False) + upgrade_amount = db.Column(db.Numeric(10, 2), nullable=False) + actual_amount = db.Column(db.Numeric(10, 2), nullable=False) + + upgrade_type = db.Column(db.String(20), nullable=False) # 'plan_upgrade', 'cycle_change', 'both' + created_at = db.Column(db.DateTime, default=beijing_now) + + # 关系 + order = db.relationship('PaymentOrder', backref='upgrade_record') + + +# ============================================ +# 模拟盘相关模型 +# ============================================ +class SimulationAccount(db.Model): + """模拟账户""" + __tablename__ = 'simulation_accounts' + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False, unique=True) + account_name = db.Column(db.String(100), default='我的模拟账户') + initial_capital = db.Column(db.Numeric(15, 2), default=1000000.00) # 初始资金 + available_cash = db.Column(db.Numeric(15, 2), default=1000000.00) # 可用资金 + frozen_cash = db.Column(db.Numeric(15, 2), default=0.00) # 冻结资金 + position_value = db.Column(db.Numeric(15, 2), default=0.00) # 持仓市值 + total_assets = db.Column(db.Numeric(15, 2), default=1000000.00) # 总资产 + total_profit = db.Column(db.Numeric(15, 2), default=0.00) # 总盈亏 + total_profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 总收益率 + daily_profit = db.Column(db.Numeric(15, 2), default=0.00) # 日盈亏 + daily_profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 日收益率 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + last_settlement_date = db.Column(db.Date) # 最后结算日期 + + # 关系 + user = db.relationship('User', backref='simulation_account') + positions = db.relationship('SimulationPosition', backref='account', lazy='dynamic') + orders = db.relationship('SimulationOrder', backref='account', lazy='dynamic') + transactions = db.relationship('SimulationTransaction', backref='account', lazy='dynamic') + + def calculate_total_assets(self): + """计算总资产""" + self.total_assets = self.available_cash + self.frozen_cash + self.position_value + self.total_profit = self.total_assets - self.initial_capital + self.total_profit_rate = (self.total_profit / self.initial_capital) * 100 if self.initial_capital > 0 else 0 + return self.total_assets + + +class SimulationPosition(db.Model): + """模拟持仓""" + __tablename__ = 'simulation_positions' + + id = db.Column(db.Integer, primary_key=True) + account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) + stock_code = db.Column(db.String(20), nullable=False) + stock_name = db.Column(db.String(100)) + position_qty = db.Column(db.Integer, default=0) # 持仓数量 + available_qty = db.Column(db.Integer, default=0) # 可用数量(T+1) + frozen_qty = db.Column(db.Integer, default=0) # 冻结数量 + avg_cost = db.Column(db.Numeric(10, 3), default=0.00) # 平均成本 + current_price = db.Column(db.Numeric(10, 3), default=0.00) # 当前价格 + market_value = db.Column(db.Numeric(15, 2), default=0.00) # 市值 + profit = db.Column(db.Numeric(15, 2), default=0.00) # 盈亏 + profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 盈亏比例 + today_profit = db.Column(db.Numeric(15, 2), default=0.00) # 今日盈亏 + today_profit_rate = db.Column(db.Numeric(10, 4), default=0.00) # 今日盈亏比例 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + __table_args__ = ( + db.UniqueConstraint('account_id', 'stock_code', name='unique_account_stock'), + ) + + def update_market_value(self, current_price): + """更新市值和盈亏""" + self.current_price = current_price + self.market_value = self.position_qty * current_price + total_cost = self.position_qty * self.avg_cost + self.profit = self.market_value - total_cost + self.profit_rate = (self.profit / total_cost * 100) if total_cost > 0 else 0 + return self.market_value + + +class SimulationOrder(db.Model): + """模拟订单""" + __tablename__ = 'simulation_orders' + + id = db.Column(db.Integer, primary_key=True) + account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) + order_no = db.Column(db.String(32), unique=True, nullable=False) + stock_code = db.Column(db.String(20), nullable=False) + stock_name = db.Column(db.String(100)) + order_type = db.Column(db.String(10), nullable=False) # BUY/SELL + price_type = db.Column(db.String(10), default='MARKET') # MARKET/LIMIT + order_price = db.Column(db.Numeric(10, 3)) # 委托价格 + order_qty = db.Column(db.Integer, nullable=False) # 委托数量 + filled_qty = db.Column(db.Integer, default=0) # 成交数量 + filled_price = db.Column(db.Numeric(10, 3)) # 成交价格 + filled_amount = db.Column(db.Numeric(15, 2)) # 成交金额 + commission = db.Column(db.Numeric(10, 2), default=0.00) # 手续费 + stamp_tax = db.Column(db.Numeric(10, 2), default=0.00) # 印花税 + transfer_fee = db.Column(db.Numeric(10, 2), default=0.00) # 过户费 + total_fee = db.Column(db.Numeric(10, 2), default=0.00) # 总费用 + status = db.Column(db.String(20), default='PENDING') # PENDING/PARTIAL/FILLED/CANCELLED/REJECTED + reject_reason = db.Column(db.String(200)) + order_time = db.Column(db.DateTime, default=beijing_now) + filled_time = db.Column(db.DateTime) + cancel_time = db.Column(db.DateTime) + + def calculate_fees(self): + """计算交易费用""" + if not self.filled_amount: + return 0 + + # 佣金(万分之2.5,最低5元) + self.commission = max(float(self.filled_amount) * 0.00025, 5.0) + + # 印花税(卖出时收取千分之1) + if self.order_type == 'SELL': + self.stamp_tax = float(self.filled_amount) * 0.001 + else: + self.stamp_tax = 0 + + # 过户费(双向收取,万分之0.2) + self.transfer_fee = float(self.filled_amount) * 0.00002 + + # 总费用 + self.total_fee = self.commission + self.stamp_tax + self.transfer_fee + + return self.total_fee + + +class SimulationTransaction(db.Model): + """模拟成交记录""" + __tablename__ = 'simulation_transactions' + + id = db.Column(db.Integer, primary_key=True) + account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) + order_id = db.Column(db.Integer, db.ForeignKey('simulation_orders.id'), nullable=False) + transaction_no = db.Column(db.String(32), unique=True, nullable=False) + stock_code = db.Column(db.String(20), nullable=False) + stock_name = db.Column(db.String(100)) + transaction_type = db.Column(db.String(10), nullable=False) # BUY/SELL + transaction_price = db.Column(db.Numeric(10, 3), nullable=False) + transaction_qty = db.Column(db.Integer, nullable=False) + transaction_amount = db.Column(db.Numeric(15, 2), nullable=False) + commission = db.Column(db.Numeric(10, 2), default=0.00) + stamp_tax = db.Column(db.Numeric(10, 2), default=0.00) + transfer_fee = db.Column(db.Numeric(10, 2), default=0.00) + total_fee = db.Column(db.Numeric(10, 2), default=0.00) + transaction_time = db.Column(db.DateTime, default=beijing_now) + settlement_date = db.Column(db.Date) # T+1结算日期 + + # 关系 + order = db.relationship('SimulationOrder', backref='transactions') + + +class SimulationDailyStats(db.Model): + """模拟账户日统计""" + __tablename__ = 'simulation_daily_stats' + + id = db.Column(db.Integer, primary_key=True) + account_id = db.Column(db.Integer, db.ForeignKey('simulation_accounts.id'), nullable=False) + stat_date = db.Column(db.Date, nullable=False) + opening_assets = db.Column(db.Numeric(15, 2)) # 期初资产 + closing_assets = db.Column(db.Numeric(15, 2)) # 期末资产 + daily_profit = db.Column(db.Numeric(15, 2)) # 日盈亏 + daily_profit_rate = db.Column(db.Numeric(10, 4)) # 日收益率 + total_profit = db.Column(db.Numeric(15, 2)) # 累计盈亏 + total_profit_rate = db.Column(db.Numeric(10, 4)) # 累计收益率 + trade_count = db.Column(db.Integer, default=0) # 交易次数 + win_count = db.Column(db.Integer, default=0) # 盈利次数 + loss_count = db.Column(db.Integer, default=0) # 亏损次数 + max_profit = db.Column(db.Numeric(15, 2)) # 最大盈利 + max_loss = db.Column(db.Numeric(15, 2)) # 最大亏损 + created_at = db.Column(db.DateTime, default=beijing_now) + + __table_args__ = ( + db.UniqueConstraint('account_id', 'stat_date', name='unique_account_date'), + ) + + +def get_user_subscription_safe(user_id): + """安全地获取用户订阅信息""" + try: + subscription = UserSubscription.query.filter_by(user_id=user_id).first() + if not subscription: + subscription = UserSubscription(user_id=user_id) + db.session.add(subscription) + db.session.commit() + return subscription + except Exception as e: + # 返回默认免费版本对象 + class DefaultSub: + def to_dict(self): + return { + 'type': 'free', + 'status': 'active', + 'is_active': True, + 'days_left': 999, + 'billing_cycle': None, + 'auto_renewal': False + } + + return DefaultSub() + + +def activate_user_subscription(user_id, plan_type, billing_cycle, extend_from_now=False): + """激活用户订阅 + + Args: + user_id: 用户ID + plan_type: 套餐类型 + billing_cycle: 计费周期 + extend_from_now: 是否从当前时间开始延长(用于升级场景) + """ + try: + subscription = UserSubscription.query.filter_by(user_id=user_id).first() + if not subscription: + subscription = UserSubscription(user_id=user_id) + db.session.add(subscription) + + subscription.subscription_type = plan_type + subscription.subscription_status = 'active' + subscription.billing_cycle = billing_cycle + + if not extend_from_now or not subscription.start_date: + subscription.start_date = beijing_now() + + if billing_cycle == 'monthly': + subscription.end_date = beijing_now() + timedelta(days=30) + else: # yearly + subscription.end_date = beijing_now() + timedelta(days=365) + + subscription.updated_at = beijing_now() + db.session.commit() + return subscription + except Exception as e: + return None + + +def validate_promo_code(code, plan_name, billing_cycle, amount, user_id): + """验证优惠码 + + Returns: + tuple: (promo_code_obj, error_message) + """ + try: + promo = PromoCode.query.filter_by(code=code.upper(), is_active=True).first() + + if not promo: + return None, "优惠码不存在或已失效" + + # 检查有效期 + now = beijing_now() + if now < promo.valid_from: + return None, "优惠码尚未生效" + if now > promo.valid_until: + return None, "优惠码已过期" + + # 检查使用次数 + if promo.max_uses and promo.current_uses >= promo.max_uses: + return None, "优惠码已被使用完" + + # 检查每用户使用次数 + if promo.max_uses_per_user: + user_usage_count = PromoCodeUsage.query.filter_by( + promo_code_id=promo.id, + user_id=user_id + ).count() + if user_usage_count >= promo.max_uses_per_user: + return None, f"您已使用过此优惠码(限用{promo.max_uses_per_user}次)" + + # 检查适用套餐 + if promo.applicable_plans: + try: + applicable = json.loads(promo.applicable_plans) + if plan_name not in applicable: + return None, "该优惠码不适用于此套餐" + except: + pass + + # 检查适用周期 + if promo.applicable_cycles: + try: + applicable = json.loads(promo.applicable_cycles) + if billing_cycle not in applicable: + return None, "该优惠码不适用于此计费周期" + except: + pass + + # 检查最低消费 + if promo.min_amount and amount < float(promo.min_amount): + return None, f"需满{float(promo.min_amount):.2f}元才可使用此优惠码" + + return promo, None + except Exception as e: + return None, f"验证优惠码时出错: {str(e)}" + + +def calculate_discount(promo_code, amount): + """计算优惠金额""" + try: + if promo_code.discount_type == 'percentage': + discount = amount * (float(promo_code.discount_value) / 100) + else: # fixed_amount + discount = float(promo_code.discount_value) + + # 确保折扣不超过总金额 + return min(discount, amount) + except: + return 0 + + +def calculate_remaining_value(subscription, current_plan): + """计算当前订阅的剩余价值""" + try: + if not subscription or not subscription.end_date: + return 0 + + now = beijing_now() + if subscription.end_date <= now: + return 0 + + days_left = (subscription.end_date - now).days + + if subscription.billing_cycle == 'monthly': + daily_value = float(current_plan.monthly_price) / 30 + else: # yearly + daily_value = float(current_plan.yearly_price) / 365 + + return daily_value * days_left + except: + return 0 + + +def calculate_upgrade_price(user_id, to_plan_name, to_cycle, promo_code=None): + """计算升级所需价格 + + Returns: + dict: 包含价格计算结果的字典 + """ + try: + # 1. 获取当前订阅 + current_sub = UserSubscription.query.filter_by(user_id=user_id).first() + + # 2. 获取目标套餐 + to_plan = SubscriptionPlan.query.filter_by(name=to_plan_name, is_active=True).first() + if not to_plan: + return {'error': '目标套餐不存在'} + + # 3. 计算目标套餐价格 + new_price = float(to_plan.yearly_price if to_cycle == 'yearly' else to_plan.monthly_price) + + # 4. 如果是新订阅(非升级) + if not current_sub or current_sub.subscription_type == 'free': + result = { + 'is_upgrade': False, + 'new_plan_price': new_price, + 'remaining_value': 0, + 'upgrade_amount': new_price, + 'original_amount': new_price, + 'discount_amount': 0, + 'final_amount': new_price, + 'promo_code': None + } + + # 应用优惠码 + if promo_code: + promo, error = validate_promo_code(promo_code, to_plan_name, to_cycle, new_price, user_id) + if promo: + discount = calculate_discount(promo, new_price) + result['discount_amount'] = discount + result['final_amount'] = new_price - discount + result['promo_code'] = promo.code + elif error: + result['promo_error'] = error + + return result + + # 5. 升级场景:计算剩余价值 + current_plan = SubscriptionPlan.query.filter_by(name=current_sub.subscription_type, is_active=True).first() + if not current_plan: + return {'error': '当前套餐信息不存在'} + + remaining_value = calculate_remaining_value(current_sub, current_plan) + + # 6. 计算升级差价 + upgrade_amount = max(0, new_price - remaining_value) + + # 7. 判断升级类型 + upgrade_type = 'new' + if current_sub.subscription_type != to_plan_name and current_sub.billing_cycle != to_cycle: + upgrade_type = 'both' + elif current_sub.subscription_type != to_plan_name: + upgrade_type = 'plan_upgrade' + elif current_sub.billing_cycle != to_cycle: + upgrade_type = 'cycle_change' + + result = { + 'is_upgrade': True, + 'upgrade_type': upgrade_type, + 'current_plan': current_sub.subscription_type, + 'current_cycle': current_sub.billing_cycle, + 'current_end_date': current_sub.end_date.isoformat() if current_sub.end_date else None, + 'new_plan_price': new_price, + 'remaining_value': remaining_value, + 'upgrade_amount': upgrade_amount, + 'original_amount': upgrade_amount, + 'discount_amount': 0, + 'final_amount': upgrade_amount, + 'promo_code': None + } + + # 8. 应用优惠码 + if promo_code and upgrade_amount > 0: + promo, error = validate_promo_code(promo_code, to_plan_name, to_cycle, upgrade_amount, user_id) + if promo: + discount = calculate_discount(promo, upgrade_amount) + result['discount_amount'] = discount + result['final_amount'] = upgrade_amount - discount + result['promo_code'] = promo.code + elif error: + result['promo_error'] = error + + return result + except Exception as e: + return {'error': str(e)} + + +def initialize_subscription_plans_safe(): + """安全地初始化订阅套餐""" + try: + if SubscriptionPlan.query.first(): + return + + pro_plan = SubscriptionPlan( + name='pro', + display_name='Pro版本', + description='适合个人投资者的基础功能套餐', + monthly_price=0.01, + yearly_price=0.08, + features=json.dumps([ + "基础股票分析工具", + "历史数据查询", + "基础财务报表", + "简单投资计划记录", + "标准客服支持" + ]), + sort_order=1 + ) + + max_plan = SubscriptionPlan( + name='max', + display_name='Max版本', + description='适合专业投资者的全功能套餐', + monthly_price=0.1, + yearly_price=0.8, + features=json.dumps([ + "全部Pro版本功能", + "高级分析工具", + "实时数据推送", + "专业财务分析报告", + "AI投资建议", + "无限投资计划存储", + "优先客服支持", + "独家研报访问" + ]), + sort_order=2 + ) + + db.session.add(pro_plan) + db.session.add(max_plan) + db.session.commit() + except Exception as e: + pass + + +# -------------------------------------------- +# 订阅等级工具函数 +# -------------------------------------------- +def _get_current_subscription_info(): + """获取当前登录用户订阅信息的字典形式,未登录或异常时视为免费用户。""" + try: + user_id = session.get('user_id') + if not user_id: + return { + 'type': 'free', + 'status': 'active', + 'is_active': True + } + sub = get_user_subscription_safe(user_id) + data = sub.to_dict() + # 标准化字段名 + return { + 'type': data.get('type') or data.get('subscription_type') or 'free', + 'status': data.get('status') or data.get('subscription_status') or 'active', + 'is_active': data.get('is_active', True) + } + except Exception: + return { + 'type': 'free', + 'status': 'active', + 'is_active': True + } + + +def _subscription_level(sub_type): + """将订阅类型映射到等级数值,free=0, pro=1, max=2。""" + mapping = {'free': 0, 'pro': 1, 'max': 2} + return mapping.get((sub_type or 'free').lower(), 0) + + +def _has_required_level(required: str) -> bool: + """判断当前用户是否达到所需订阅级别。""" + info = _get_current_subscription_info() + if not info.get('is_active', True): + return False + return _subscription_level(info.get('type')) >= _subscription_level(required) + + +# ============================================ +# 订阅相关API接口 +# ============================================ + +@app.route('/api/subscription/plans', methods=['GET']) +def get_subscription_plans(): + """获取订阅套餐列表""" + try: + plans = SubscriptionPlan.query.filter_by(is_active=True).order_by(SubscriptionPlan.sort_order).all() + return jsonify({ + 'success': True, + 'data': [plan.to_dict() for plan in plans] + }) + except Exception as e: + # 返回默认套餐(包含pricing_options以兼容新前端) + default_plans = [ + { + 'id': 1, + 'name': 'pro', + 'display_name': 'Pro版本', + 'description': '适合个人投资者的基础功能套餐', + 'monthly_price': 198, + 'yearly_price': 2000, + 'pricing_options': [ + {'months': 1, 'price': 198, 'label': '月付', 'cycle_key': 'monthly'}, + {'months': 3, 'price': 534, 'label': '3个月', 'cycle_key': '3months', 'discount_percent': 10}, + {'months': 6, 'price': 950, 'label': '半年', 'cycle_key': '6months', 'discount_percent': 20}, + {'months': 12, 'price': 2000, 'label': '1年', 'cycle_key': 'yearly', 'discount_percent': 16}, + {'months': 24, 'price': 3600, 'label': '2年', 'cycle_key': '2years', 'discount_percent': 24}, + {'months': 36, 'price': 5040, 'label': '3年', 'cycle_key': '3years', 'discount_percent': 29} + ], + 'features': ['基础股票分析工具', '历史数据查询', '基础财务报表', '简单投资计划记录', '标准客服支持'], + 'is_active': True, + 'sort_order': 1 + }, + { + 'id': 2, + 'name': 'max', + 'display_name': 'Max版本', + 'description': '适合专业投资者的全功能套餐', + 'monthly_price': 998, + 'yearly_price': 10000, + 'pricing_options': [ + {'months': 1, 'price': 998, 'label': '月付', 'cycle_key': 'monthly'}, + {'months': 3, 'price': 2695, 'label': '3个月', 'cycle_key': '3months', 'discount_percent': 10}, + {'months': 6, 'price': 4790, 'label': '半年', 'cycle_key': '6months', 'discount_percent': 20}, + {'months': 12, 'price': 10000, 'label': '1年', 'cycle_key': 'yearly', 'discount_percent': 17}, + {'months': 24, 'price': 18000, 'label': '2年', 'cycle_key': '2years', 'discount_percent': 25}, + {'months': 36, 'price': 25200, 'label': '3年', 'cycle_key': '3years', 'discount_percent': 30} + ], + 'features': ['全部Pro版本功能', '高级分析工具', '实时数据推送', 'API访问', '优先客服支持'], + 'is_active': True, + 'sort_order': 2 + } + ] + return jsonify({ + 'success': True, + 'data': default_plans + }) + + +@app.route('/api/subscription/current', methods=['GET']) +def get_current_subscription(): + """获取当前用户的订阅信息""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + subscription = get_user_subscription_safe(session['user_id']) + return jsonify({ + 'success': True, + 'data': subscription.to_dict() + }) + except Exception as e: + return jsonify({ + 'success': True, + 'data': { + 'type': 'free', + 'status': 'active', + 'is_active': True, + 'days_left': 999 + } + }) + + +@app.route('/api/subscription/info', methods=['GET']) +def get_subscription_info(): + """获取当前用户的订阅信息 - 前端专用接口""" + try: + info = _get_current_subscription_info() + return jsonify({ + 'success': True, + 'data': info + }) + except Exception as e: + print(f"获取订阅信息错误: {e}") + return jsonify({ + 'success': True, + 'data': { + 'type': 'free', + 'status': 'active', + 'is_active': True, + 'days_left': 999 + } + }) + + +@app.route('/api/promo-code/validate', methods=['POST']) +def validate_promo_code_api(): + """验证优惠码""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + data = request.get_json() + code = data.get('code', '').strip() + plan_name = data.get('plan_name') + billing_cycle = data.get('billing_cycle') + amount = data.get('amount', 0) + + if not code or not plan_name or not billing_cycle: + return jsonify({'success': False, 'error': '参数不完整'}), 400 + + # 验证优惠码 + promo, error = validate_promo_code(code, plan_name, billing_cycle, amount, session['user_id']) + + if error: + return jsonify({ + 'success': False, + 'valid': False, + 'error': error + }) + + # 计算折扣 + discount_amount = calculate_discount(promo, amount) + final_amount = amount - discount_amount + + return jsonify({ + 'success': True, + 'valid': True, + 'promo_code': promo.to_dict(), + 'discount_amount': discount_amount, + 'final_amount': final_amount + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': f'验证失败: {str(e)}' + }), 500 + + +@app.route('/api/subscription/calculate-price', methods=['POST']) +def calculate_subscription_price(): + """计算订阅价格(支持升级和优惠码)""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + data = request.get_json() + to_plan = data.get('to_plan') + to_cycle = data.get('to_cycle') + promo_code = data.get('promo_code', '').strip() or None + + if not to_plan or not to_cycle: + return jsonify({'success': False, 'error': '参数不完整'}), 400 + + # 计算价格 + result = calculate_upgrade_price(session['user_id'], to_plan, to_cycle, promo_code) + + if 'error' in result: + return jsonify({ + 'success': False, + 'error': result['error'] + }), 400 + + return jsonify({ + 'success': True, + 'data': result + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': f'计算失败: {str(e)}' + }), 500 + + +@app.route('/api/payment/create-order', methods=['POST']) +def create_payment_order(): + """创建支付订单(支持升级和优惠码)""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + data = request.get_json() + plan_name = data.get('plan_name') + billing_cycle = data.get('billing_cycle') + promo_code = data.get('promo_code', '').strip() or None + + if not plan_name or not billing_cycle: + return jsonify({'success': False, 'error': '参数不完整'}), 400 + + # 计算价格(包括升级和优惠码) + price_result = calculate_upgrade_price(session['user_id'], plan_name, billing_cycle, promo_code) + + if 'error' in price_result: + return jsonify({'success': False, 'error': price_result['error']}), 400 + + amount = price_result['final_amount'] + original_amount = price_result['original_amount'] + discount_amount = price_result['discount_amount'] + is_upgrade = price_result.get('is_upgrade', False) + + # 创建订单 + try: + order = PaymentOrder( + user_id=session['user_id'], + plan_name=plan_name, + billing_cycle=billing_cycle, + amount=amount + ) + + # 添加扩展字段(使用动态属性) + if hasattr(order, 'original_amount') or True: # 兼容性检查 + order.original_amount = original_amount + order.discount_amount = discount_amount + order.is_upgrade = is_upgrade + + # 如果使用了优惠码,关联优惠码 + if promo_code and price_result.get('promo_code'): + promo_obj = PromoCode.query.filter_by(code=promo_code.upper()).first() + if promo_obj: + order.promo_code_id = promo_obj.id + + # 如果是升级,记录原套餐信息 + if is_upgrade: + order.upgrade_from_plan = price_result.get('current_plan') + + db.session.add(order) + db.session.commit() + + # 如果是升级订单,创建升级记录 + if is_upgrade and price_result.get('upgrade_type'): + try: + upgrade_record = SubscriptionUpgrade( + user_id=session['user_id'], + order_id=order.id, + from_plan=price_result['current_plan'], + from_cycle=price_result['current_cycle'], + from_end_date=datetime.fromisoformat(price_result['current_end_date']) if price_result.get('current_end_date') else None, + to_plan=plan_name, + to_cycle=billing_cycle, + to_end_date=beijing_now() + timedelta(days=365 if billing_cycle == 'yearly' else 30), + remaining_value=price_result['remaining_value'], + upgrade_amount=price_result['upgrade_amount'], + actual_amount=amount, + upgrade_type=price_result['upgrade_type'] + ) + db.session.add(upgrade_record) + db.session.commit() + except Exception as e: + print(f"创建升级记录失败: {e}") + # 不影响主流程 + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': f'订单创建失败: {str(e)}'}), 500 + + # 尝试调用真实的微信支付API + try: + from wechat_pay import create_wechat_pay_instance, check_wechat_pay_ready + + # 检查微信支付是否就绪 + is_ready, ready_msg = check_wechat_pay_ready() + if not is_ready: + # 使用模拟二维码 + order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" + order.remark = f"演示模式 - {ready_msg}" + else: + wechat_pay = create_wechat_pay_instance() + + # 创建微信支付订单 + plan_display_name = f"{plan_name.upper()}版本-{billing_cycle}" + wechat_result = wechat_pay.create_native_order( + order_no=order.order_no, + total_fee=float(amount), + body=f"VFr-{plan_display_name}", + product_id=f"{plan_name}_{billing_cycle}" + ) + + if wechat_result['success']: + + # 获取微信返回的原始code_url + wechat_code_url = wechat_result['code_url'] + + # 将微信协议URL转换为二维码图片URL + import urllib.parse + encoded_url = urllib.parse.quote(wechat_code_url, safe='') + qr_image_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data={encoded_url}" + + order.qr_code_url = qr_image_url + order.prepay_id = wechat_result.get('prepay_id') + order.remark = f"微信支付 - {wechat_code_url}" + + else: + order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" + order.remark = f"微信支付失败: {wechat_result.get('error')}" + + except ImportError as e: + order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" + order.remark = "微信支付模块未配置" + except Exception as e: + order.qr_code_url = f"https://api.qrserver.com/v1/create-qr-code/?size=200x200&data=wxpay://order/{order.order_no}" + order.remark = f"支付异常: {str(e)}" + + db.session.commit() + + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '订单创建成功' + }) + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': '创建订单失败'}), 500 + + +@app.route('/api/payment/order//status', methods=['GET']) +def check_order_status(order_id): + """查询订单支付状态""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + # 查找订单 + order = PaymentOrder.query.filter_by( + id=order_id, + user_id=session['user_id'] + ).first() + + if not order: + return jsonify({'success': False, 'error': '订单不存在'}), 404 + + # 如果订单已经是已支付状态,直接返回 + if order.status == 'paid': + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '订单已支付', + 'payment_success': True + }) + + # 如果订单过期,标记为过期 + if order.is_expired(): + order.status = 'expired' + db.session.commit() + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '订单已过期' + }) + + # 调用微信支付API查询真实状态 + try: + from wechat_pay import create_wechat_pay_instance + wechat_pay = create_wechat_pay_instance() + + query_result = wechat_pay.query_order(order_no=order.order_no) + + if query_result['success']: + trade_state = query_result.get('trade_state') + transaction_id = query_result.get('transaction_id') + + if trade_state == 'SUCCESS': + # 支付成功,更新订单状态 + order.mark_as_paid(transaction_id) + + # 激活用户订阅 + activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) + + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '支付成功!订阅已激活', + 'payment_success': True + }) + elif trade_state in ['NOTPAY', 'USERPAYING']: + # 未支付或支付中 + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '等待支付...', + 'payment_success': False + }) + else: + # 支付失败或取消 + order.status = 'cancelled' + db.session.commit() + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '支付已取消', + 'payment_success': False + }) + else: + # 微信查询失败,返回当前状态 + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': f"查询失败: {query_result.get('error')}", + 'payment_success': False + }) + + except Exception as e: + # 查询失败,返回当前订单状态 + return jsonify({ + 'success': True, + 'data': order.to_dict(), + 'message': '无法查询支付状态,请稍后重试', + 'payment_success': False + }) + + except Exception as e: + return jsonify({'success': False, 'error': '查询失败'}), 500 + + +@app.route('/api/payment/order//force-update', methods=['POST']) +def force_update_order_status(order_id): + """强制更新订单支付状态(调试用)""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + # 查找订单 + order = PaymentOrder.query.filter_by( + id=order_id, + user_id=session['user_id'] + ).first() + + if not order: + return jsonify({'success': False, 'error': '订单不存在'}), 404 + + # 检查微信支付状态 + try: + from wechat_pay import create_wechat_pay_instance + wechat_pay = create_wechat_pay_instance() + + query_result = wechat_pay.query_order(order_no=order.order_no) + + if query_result['success'] and query_result.get('trade_state') == 'SUCCESS': + # 强制更新为已支付 + old_status = order.status + order.mark_as_paid(query_result.get('transaction_id')) + + # 激活用户订阅 + activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) + + # 记录优惠码使用(如果使用了优惠码) + if hasattr(order, 'promo_code_id') and order.promo_code_id: + try: + promo_usage = PromoCodeUsage( + promo_code_id=order.promo_code_id, + user_id=order.user_id, + order_id=order.id, + original_amount=order.original_amount if hasattr(order, 'original_amount') else order.amount, + discount_amount=order.discount_amount if hasattr(order, 'discount_amount') else 0, + final_amount=order.amount + ) + db.session.add(promo_usage) + + # 更新优惠码使用次数 + promo = PromoCode.query.get(order.promo_code_id) + if promo: + promo.current_uses = (promo.current_uses or 0) + 1 + except Exception as e: + print(f"记录优惠码使用失败: {e}") + + db.session.commit() + + print(f"✅ 订单状态强制更新成功: {old_status} -> paid") + + return jsonify({ + 'success': True, + 'message': f'订单状态已从 {old_status} 更新为 paid', + 'data': order.to_dict(), + 'payment_success': True + }) + else: + return jsonify({ + 'success': False, + 'error': '微信支付状态不是成功状态,无法强制更新' + }) + + except Exception as e: + print(f"❌ 强制更新失败: {e}") + return jsonify({ + 'success': False, + 'error': f'强制更新失败: {str(e)}' + }) + + except Exception as e: + print(f"强制更新订单状态失败: {str(e)}") + return jsonify({'success': False, 'error': '操作失败'}), 500 + + +@app.route('/api/payment/wechat/callback', methods=['POST']) +def wechat_payment_callback(): + """微信支付回调处理""" + try: + # 获取原始XML数据 + raw_data = request.get_data() + print(f"📥 收到微信支付回调: {raw_data}") + + # 验证回调数据 + try: + from wechat_pay import create_wechat_pay_instance + wechat_pay = create_wechat_pay_instance() + verify_result = wechat_pay.verify_callback(raw_data.decode('utf-8')) + + if not verify_result['success']: + print(f"❌ 微信支付回调验证失败: {verify_result['error']}") + return '' + + callback_data = verify_result['data'] + + except Exception as e: + print(f"❌ 微信支付回调处理异常: {e}") + # 简单解析XML(fallback) + callback_data = _parse_xml_callback(raw_data.decode('utf-8')) + if not callback_data: + return '' + + # 获取关键字段 + return_code = callback_data.get('return_code') + result_code = callback_data.get('result_code') + order_no = callback_data.get('out_trade_no') + transaction_id = callback_data.get('transaction_id') + + print(f"📦 回调数据解析:") + print(f" 返回码: {return_code}") + print(f" 结果码: {result_code}") + print(f" 订单号: {order_no}") + print(f" 交易号: {transaction_id}") + + if not order_no: + return '' + + # 查找订单 + order = PaymentOrder.query.filter_by(order_no=order_no).first() + if not order: + print(f"❌ 订单不存在: {order_no}") + return '' + + # 处理支付成功 + if return_code == 'SUCCESS' and result_code == 'SUCCESS': + print(f"🎉 支付回调成功: 订单 {order_no}") + + # 检查订单是否已经处理过 + if order.status == 'paid': + print(f"ℹ️ 订单已处理过: {order_no}") + db.session.commit() + return '' + + # 更新订单状态(无论之前是什么状态) + old_status = order.status + order.mark_as_paid(transaction_id) + print(f"📝 订单状态已更新: {old_status} -> paid") + + # 激活用户订阅 + subscription = activate_user_subscription(order.user_id, order.plan_name, order.billing_cycle) + + if subscription: + print(f"✅ 用户订阅已激活: 用户{order.user_id}, 套餐{order.plan_name}") + else: + print(f"⚠️ 订阅激活失败,但订单已标记为已支付") + + db.session.commit() + + # 返回成功响应给微信 + return '' + + except Exception as e: + db.session.rollback() + print(f"❌ 微信支付回调处理失败: {e}") + import traceback + app.logger.error(f"回调处理错误: {e}", exc_info=True) + return '' + + +def _parse_xml_callback(xml_data): + """简单的XML回调数据解析""" + try: + import xml.etree.ElementTree as ET + root = ET.fromstring(xml_data) + result = {} + for child in root: + result[child.tag] = child.text + return result + except Exception as e: + print(f"XML解析失败: {e}") + return None + + +@app.route('/api/auth/session', methods=['GET']) +def get_session_info(): + """获取当前登录用户信息""" + if 'user_id' in session: + user = User.query.get(session['user_id']) + if user: + # 获取用户订阅信息 + subscription_info = get_user_subscription_safe(user.id).to_dict() + + return jsonify({ + 'success': True, + 'isAuthenticated': True, + 'user': { + 'id': user.id, + 'username': user.username, + 'nickname': user.nickname or user.username, + 'email': user.email, + 'phone': user.phone, + 'phone_confirmed': bool(user.phone_confirmed), + 'email_confirmed': bool(user.email_confirmed) if hasattr(user, 'email_confirmed') else None, + 'avatar_url': user.avatar_url, + 'has_wechat': bool(user.wechat_open_id), + 'created_at': user.created_at.isoformat() if user.created_at else None, + 'last_seen': user.last_seen.isoformat() if user.last_seen else None, + # 将订阅字段映射到前端期望的字段名 + 'subscription_type': subscription_info['type'], + 'subscription_status': subscription_info['status'], + 'subscription_end_date': subscription_info['end_date'], + 'is_subscription_active': subscription_info['is_active'], + 'subscription_days_left': subscription_info['days_left'] + } + }) + + return jsonify({ + 'success': True, + 'isAuthenticated': False, + 'user': None + }) + + +def generate_verification_code(): + """生成6位数字验证码""" + return ''.join(random.choices(string.digits, k=6)) + + +@app.route('/api/auth/login', methods=['POST']) +def login(): + """传统登录 - 使用Session""" + try: + + username = request.form.get('username') + email = request.form.get('email') + phone = request.form.get('phone') + password = request.form.get('password') + + # 验证必要参数 + if not password: + return jsonify({'success': False, 'error': '密码不能为空'}), 400 + + # 根据提供的信息查找用户 + user = None + if username: + # 检查username是否为手机号格式 + if re.match(r'^1[3-9]\d{9}$', username): + # 如果username是手机号格式,先按手机号查找 + user = User.query.filter_by(phone=username).first() + if not user: + # 如果没找到,再按用户名查找 + user = User.find_by_login_info(username) + else: + # 不是手机号格式,按用户名查找 + user = User.find_by_login_info(username) + elif email: + user = User.query.filter_by(email=email).first() + elif phone: + user = User.query.filter_by(phone=phone).first() + else: + return jsonify({'success': False, 'error': '请提供用户名、邮箱或手机号'}), 400 + + if not user: + return jsonify({'success': False, 'error': '用户不存在'}), 404 + + # 尝试密码验证 + password_valid = user.check_password(password) + + if not password_valid: + # 还可以尝试直接验证 + if user.password_hash: + from werkzeug.security import check_password_hash + direct_check = check_password_hash(user.password_hash, password) + return jsonify({'success': False, 'error': '密码错误'}), 401 + + # 设置session + session.permanent = True # 使用永久session + session['user_id'] = user.id + session['username'] = user.username + session['logged_in'] = True + + # Flask-Login 登录 + login_user(user, remember=True) + + # 更新最后登录时间 + user.update_last_seen() + + return jsonify({ + 'success': True, + 'message': '登录成功', + 'user': { + 'id': user.id, + 'username': user.username, + 'nickname': user.nickname or user.username, + 'email': user.email, + 'phone': user.phone, + 'avatar_url': user.avatar_url, + 'has_wechat': bool(user.wechat_open_id) + } + }) + + except Exception as e: + import traceback + app.logger.error(f"回调处理错误: {e}", exc_info=True) + return jsonify({'success': False, 'error': '登录处理失败,请重试'}), 500 + + +# 添加OPTIONS请求处理 +@app.before_request +def handle_preflight(): + if request.method == "OPTIONS": + response = make_response() + response.headers.add("Access-Control-Allow-Origin", "*") + response.headers.add('Access-Control-Allow-Headers', "*") + response.headers.add('Access-Control-Allow-Methods', "*") + return response + + +# 修改密码API +@app.route('/api/account/change-password', methods=['POST']) +@login_required +def change_password(): + """修改当前用户密码""" + try: + data = request.get_json() or request.form + current_password = data.get('currentPassword') or data.get('current_password') + new_password = data.get('newPassword') or data.get('new_password') + is_first_set = data.get('isFirstSet', False) # 是否为首次设置密码 + + if not new_password: + return jsonify({'success': False, 'error': '新密码不能为空'}), 400 + + if len(new_password) < 6: + return jsonify({'success': False, 'error': '新密码至少需要6个字符'}), 400 + + # 获取当前用户 + user = current_user + if not user: + return jsonify({'success': False, 'error': '用户未登录'}), 401 + + # 检查是否为微信用户且首次设置密码 + is_wechat_user = bool(user.wechat_open_id) + + # 如果是微信用户首次设置密码,或者明确标记为首次设置,则跳过当前密码验证 + if is_first_set or (is_wechat_user and not current_password): + pass # 跳过当前密码验证 + else: + # 普通用户或非首次设置,需要验证当前密码 + if not current_password: + return jsonify({'success': False, 'error': '请输入当前密码'}), 400 + + if not user.check_password(current_password): + return jsonify({'success': False, 'error': '当前密码错误'}), 400 + + # 设置新密码 + user.set_password(new_password) + db.session.commit() + + return jsonify({ + 'success': True, + 'message': '密码设置成功' if (is_first_set or is_wechat_user) else '密码修改成功' + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +# 检查用户密码状态API +@app.route('/api/account/password-status', methods=['GET']) +@login_required +def get_password_status(): + """获取当前用户的密码状态信息""" + try: + user = current_user + if not user: + return jsonify({'success': False, 'error': '用户未登录'}), 401 + + is_wechat_user = bool(user.wechat_open_id) + + return jsonify({ + 'success': True, + 'data': { + 'isWechatUser': is_wechat_user, + 'hasPassword': bool(user.password_hash), + 'needsFirstTimeSetup': is_wechat_user # 微信用户需要首次设置 + } + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +# 检查用户信息完整性API +@app.route('/api/account/profile-completeness', methods=['GET']) +@login_required +def get_profile_completeness(): + try: + user = current_user + if not user: + return jsonify({'success': False, 'error': '用户未登录'}), 401 + + is_wechat_user = bool(user.wechat_open_id) + + # 检查各项信息 + completeness = { + 'hasPassword': bool(user.password_hash), + 'hasPhone': bool(user.phone), + 'hasEmail': bool(user.email and '@' in user.email and not user.email.endswith('@valuefrontier.temp')), + 'isWechatUser': is_wechat_user + } + + # 计算完整度 + total_items = 3 + completed_items = sum([completeness['hasPassword'], completeness['hasPhone'], completeness['hasEmail']]) + completeness_percentage = int((completed_items / total_items) * 100) + + # 智能判断是否需要提醒 + needs_attention = False + missing_items = [] + + # 只在用户首次登录或最近登录时提醒 + if is_wechat_user: + # 检查用户是否是新用户(注册7天内) + is_new_user = (datetime.now() - user.created_at).days < 7 + + # 检查是否最近没有提醒过(使用session记录) + last_reminder = session.get('last_completeness_reminder') + should_remind = False + + if not last_reminder: + should_remind = True + else: + # 每7天最多提醒一次 + days_since_reminder = (datetime.now() - datetime.fromisoformat(last_reminder)).days + should_remind = days_since_reminder >= 7 + + # 只对新用户或长时间未完善的用户提醒 + if (is_new_user or completeness_percentage < 50) and should_remind: + needs_attention = True + if not completeness['hasPassword']: + missing_items.append('登录密码') + if not completeness['hasPhone']: + missing_items.append('手机号') + if not completeness['hasEmail']: + missing_items.append('邮箱') + + # 记录本次提醒时间 + session['last_completeness_reminder'] = datetime.now().isoformat() + + return jsonify({ + 'success': True, + 'data': { + 'completeness': completeness, + 'completenessPercentage': completeness_percentage, + 'needsAttention': needs_attention, + 'missingItems': missing_items, + 'isComplete': completed_items == total_items, + 'showReminder': needs_attention # 前端使用这个字段决定是否显示提醒 + } + }) + + except Exception as e: + print(f"获取资料完整性错误: {e}") + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/auth/logout', methods=['POST']) +def logout(): + """登出 - 清除Session""" + logout_user() # Flask-Login 登出 + session.clear() + return jsonify({'success': True, 'message': '已登出'}) + + +@app.route('/api/auth/send-verification-code', methods=['POST']) +def send_verification_code(): + """发送验证码(支持手机号和邮箱)""" + try: + data = request.get_json() + credential = data.get('credential') # 手机号或邮箱 + code_type = data.get('type') # 'phone' 或 'email' + purpose = data.get('purpose', 'login') # 'login' 或 'register' + + if not credential or not code_type: + return jsonify({'success': False, 'error': '缺少必要参数'}), 400 + + # 清理格式字符(空格、横线、括号等) + if code_type == 'phone': + # 移除手机号中的空格、横线、括号、加号等格式字符 + credential = re.sub(r'[\s\-\(\)\+]', '', credential) + print(f"📱 清理后的手机号: {credential}") + elif code_type == 'email': + # 邮箱只移除空格 + credential = credential.strip() + + # 生成验证码 + verification_code = generate_verification_code() + + # 存储验证码到session(实际生产环境建议使用Redis) + session_key = f'verification_code_{code_type}_{credential}_{purpose}' + session[session_key] = { + 'code': verification_code, + 'timestamp': time.time(), + 'attempts': 0 + } + + if code_type == 'phone': + # 手机号验证码发送 + if not re.match(r'^1[3-9]\d{9}$', credential): + return jsonify({'success': False, 'error': '手机号格式不正确'}), 400 + + # 发送真实短信验证码 + if send_sms_code(credential, verification_code, SMS_TEMPLATE_LOGIN): + print(f"[短信已发送] 验证码到 {credential}: {verification_code}") + else: + return jsonify({'success': False, 'error': '短信发送失败,请稍后重试'}), 500 + + elif code_type == 'email': + # 邮箱验证码发送 + if not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', credential): + return jsonify({'success': False, 'error': '邮箱格式不正确'}), 400 + + # 发送真实邮件验证码 + if send_email_code(credential, verification_code): + print(f"[邮件已发送] 验证码到 {credential}: {verification_code}") + else: + return jsonify({'success': False, 'error': '邮件发送失败,请稍后重试'}), 500 + + else: + return jsonify({'success': False, 'error': '不支持的验证码类型'}), 400 + + return jsonify({ + 'success': True, + 'message': f'验证码已发送到您的{code_type}' + }) + + except Exception as e: + print(f"发送验证码错误: {e}") + return jsonify({'success': False, 'error': '发送验证码失败'}), 500 + + +@app.route('/api/auth/login-with-code', methods=['POST']) +def login_with_verification_code(): + """使用验证码登录/注册(自动注册)""" + try: + data = request.get_json() + credential = data.get('credential') # 手机号或邮箱 + verification_code = data.get('verification_code') + login_type = data.get('login_type') # 'phone' 或 'email' + + if not credential or not verification_code or not login_type: + return jsonify({'success': False, 'error': '缺少必要参数'}), 400 + + # 清理格式字符(空格、横线、括号等) + if login_type == 'phone': + # 移除手机号中的空格、横线、括号、加号等格式字符 + original_credential = credential + credential = re.sub(r'[\s\-\(\)\+]', '', credential) + if original_credential != credential: + print(f"📱 登录时清理手机号: {original_credential} -> {credential}") + elif login_type == 'email': + # 邮箱只移除前后空格 + credential = credential.strip() + + # 检查验证码 + session_key = f'verification_code_{login_type}_{credential}_login' + stored_code_info = session.get(session_key) + + if not stored_code_info: + return jsonify({'success': False, 'error': '验证码已过期或不存在'}), 400 + + # 检查验证码是否过期(5分钟) + if time.time() - stored_code_info['timestamp'] > 300: + session.pop(session_key, None) + return jsonify({'success': False, 'error': '验证码已过期'}), 400 + + # 检查尝试次数 + if stored_code_info['attempts'] >= 3: + session.pop(session_key, None) + return jsonify({'success': False, 'error': '验证码错误次数过多'}), 400 + + # 验证码错误 + if stored_code_info['code'] != verification_code: + stored_code_info['attempts'] += 1 + session[session_key] = stored_code_info + return jsonify({'success': False, 'error': '验证码错误'}), 400 + + # 验证码正确,查找用户 + user = None + is_new_user = False + + if login_type == 'phone': + user = User.query.filter_by(phone=credential).first() + if not user: + # 自动注册新用户 + is_new_user = True + # 生成唯一用户名 + base_username = f"user_{credential}" + username = base_username + counter = 1 + while User.query.filter_by(username=username).first(): + username = f"{base_username}_{counter}" + counter += 1 + + # 创建新用户 + user = User(username=username, phone=credential) + user.phone_confirmed = True + user.email = f"{username}@valuefrontier.temp" # 临时邮箱 + db.session.add(user) + db.session.commit() + + elif login_type == 'email': + user = User.query.filter_by(email=credential).first() + if not user: + # 自动注册新用户 + is_new_user = True + # 从邮箱生成用户名 + email_prefix = credential.split('@')[0] + base_username = f"user_{email_prefix}" + username = base_username + counter = 1 + while User.query.filter_by(username=username).first(): + username = f"{base_username}_{counter}" + counter += 1 + + # 如果用户不存在,自动创建新用户 + if not user: + try: + # 生成用户名 + if login_type == 'phone': + # 使用手机号生成用户名 + base_username = f"用户{credential[-4:]}" + elif login_type == 'email': + # 使用邮箱前缀生成用户名 + base_username = credential.split('@')[0] + else: + base_username = "新用户" + + # 确保用户名唯一 + username = base_username + counter = 1 + while User.is_username_taken(username): + username = f"{base_username}_{counter}" + counter += 1 + + # 创建新用户 + user = User(username=username) + + # 设置手机号或邮箱 + if login_type == 'phone': + user.phone = credential + elif login_type == 'email': + user.email = credential + + # 设置默认密码(使用随机密码,用户后续可以修改) + user.set_password(uuid.uuid4().hex) + user.status = 'active' + user.nickname = username + + db.session.add(user) + db.session.commit() + + is_new_user = True + print(f"✅ 自动创建新用户: {username}, {login_type}: {credential}") + + except Exception as e: + print(f"❌ 创建用户失败: {e}") + db.session.rollback() + return jsonify({'success': False, 'error': '创建用户失败'}), 500 + + # 清除验证码 + session.pop(session_key, None) + + # 设置session + session.permanent = True + session['user_id'] = user.id + session['username'] = user.username + session['logged_in'] = True + + # Flask-Login 登录 + login_user(user, remember=True) + + # 更新最后登录时间 + user.update_last_seen() + + # 根据是否为新用户返回不同的消息 + message = '注册成功,欢迎加入!' if is_new_user else '登录成功' + + return jsonify({ + 'success': True, + 'message': message, + 'is_new_user': is_new_user, + 'user': { + 'id': user.id, + 'username': user.username, + 'nickname': user.nickname or user.username, + 'email': user.email, + 'phone': user.phone, + 'avatar_url': user.avatar_url, + 'has_wechat': bool(user.wechat_open_id) + } + }) + + except Exception as e: + print(f"验证码登录错误: {e}") + db.session.rollback() + return jsonify({'success': False, 'error': '登录失败'}), 500 + + +@app.route('/api/auth/register', methods=['POST']) +def register(): + """用户注册 - 使用Session""" + username = request.form.get('username') + email = request.form.get('email') + password = request.form.get('password') + + # 验证输入 + if not all([username, email, password]): + return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 + + # 检查用户名和邮箱是否已存在 + if User.is_username_taken(username): + return jsonify({'success': False, 'error': '用户名已存在'}), 400 + + if User.is_email_taken(email): + return jsonify({'success': False, 'error': '邮箱已被使用'}), 400 + + try: + # 创建新用户 + user = User(username=username, email=email) + user.set_password(password) + user.email_confirmed = True # 暂时默认已确认 + + db.session.add(user) + db.session.commit() + + # 自动登录 + session.permanent = True + session['user_id'] = user.id + session['username'] = user.username + session['logged_in'] = True + + # Flask-Login 登录 + login_user(user, remember=True) + + return jsonify({ + 'success': True, + 'message': '注册成功', + 'user': { + 'id': user.id, + 'username': user.username, + 'nickname': user.nickname or user.username, + 'email': user.email + } + }), 201 + + except Exception as e: + db.session.rollback() + print(f"验证码登录/注册错误: {e}") + return jsonify({'success': False, 'error': '登录失败'}), 500 + + +def send_sms_code(phone, code, template_id): + """发送短信验证码""" + try: + cred = credential.Credential(SMS_SECRET_ID, SMS_SECRET_KEY) + httpProfile = HttpProfile() + httpProfile.endpoint = "sms.tencentcloudapi.com" + + clientProfile = ClientProfile() + clientProfile.httpProfile = httpProfile + client = sms_client.SmsClient(cred, "ap-beijing", clientProfile) + + req = models.SendSmsRequest() + params = { + "PhoneNumberSet": [phone], + "SmsSdkAppId": SMS_SDK_APP_ID, + "TemplateId": template_id, + "SignName": SMS_SIGN_NAME, + "TemplateParamSet": [code, "5"] if template_id == SMS_TEMPLATE_LOGIN else [code] + } + req.from_json_string(json.dumps(params)) + + resp = client.SendSms(req) + return True + except TencentCloudSDKException as err: + print(f"SMS Error: {err}") + return False + + +def send_email_code(email, code): + """发送邮件验证码""" + try: + print(f"[邮件发送] 准备发送验证码到: {email}") + print(f"[邮件配置] 服务器: {MAIL_SERVER}, 端口: {MAIL_PORT}, SSL: {MAIL_USE_SSL}") + + msg = Message( + subject='价值前沿 - 验证码', + recipients=[email], + body=f'您的验证码是:{code},有效期5分钟。如非本人操作,请忽略此邮件。' + ) + mail.send(msg) + print(f"[邮件发送] 验证码邮件发送成功到: {email}") + return True + except Exception as e: + print(f"[邮件发送错误] 发送到 {email} 失败: {str(e)}") + print(f"[邮件发送错误] 错误类型: {type(e).__name__}") + return False + + +@app.route('/api/auth/send-sms-code', methods=['POST']) +def send_sms_verification(): + """发送手机验证码""" + data = request.get_json() + phone = data.get('phone') + + if not phone: + return jsonify({'error': '手机号不能为空'}), 400 + + # 注册时验证是否已注册;若用于绑定手机,需要另外接口 + # 这里保留原逻辑,新增绑定接口处理不同规则 + if User.query.filter_by(phone=phone).first(): + return jsonify({'error': '该手机号已注册'}), 400 + + # 生成验证码 + code = generate_verification_code() + + # 发送短信 + if send_sms_code(phone, code, SMS_TEMPLATE_REGISTER): + # 存储验证码(5分钟有效) + verification_codes[f'phone_{phone}'] = { + 'code': code, + 'expires': time.time() + 300 + } + return jsonify({'message': '验证码已发送'}), 200 + else: + return jsonify({'error': '验证码发送失败'}), 500 + + +@app.route('/api/auth/send-email-code', methods=['POST']) +def send_email_verification(): + """发送邮箱验证码""" + data = request.get_json() + email = data.get('email') + + if not email: + return jsonify({'error': '邮箱不能为空'}), 400 + + if User.query.filter_by(email=email).first(): + return jsonify({'error': '该邮箱已注册'}), 400 + + # 生成验证码 + code = generate_verification_code() + + # 发送邮件 + if send_email_code(email, code): + # 存储验证码(5分钟有效) + verification_codes[f'email_{email}'] = { + 'code': code, + 'expires': time.time() + 300 + } + return jsonify({'message': '验证码已发送'}), 200 + else: + return jsonify({'error': '验证码发送失败'}), 500 + + +@app.route('/api/auth/register/phone', methods=['POST']) +def register_with_phone(): + """手机号注册 - 使用Session""" + data = request.get_json() + phone = data.get('phone') + code = data.get('code') + password = data.get('password') + username = data.get('username') + + if not all([phone, code, password, username]): + return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 + + # 验证验证码 + stored_code = verification_codes.get(f'phone_{phone}') + if not stored_code or stored_code['expires'] < time.time(): + return jsonify({'success': False, 'error': '验证码已过期'}), 400 + + if stored_code['code'] != code: + return jsonify({'success': False, 'error': '验证码错误'}), 400 + + if User.query.filter_by(username=username).first(): + return jsonify({'success': False, 'error': '用户名已存在'}), 400 + + try: + # 创建用户 + user = User(username=username, phone=phone) + user.email = f"{username}@valuefrontier.temp" + user.set_password(password) + user.phone_confirmed = True + + db.session.add(user) + db.session.commit() + + # 清除验证码 + del verification_codes[f'phone_{phone}'] + + # 自动登录 + session.permanent = True + session['user_id'] = user.id + session['username'] = user.username + session['logged_in'] = True + + # Flask-Login 登录 + login_user(user, remember=True) + + return jsonify({ + 'success': True, + 'message': '注册成功', + 'user': { + 'id': user.id, + 'username': user.username, + 'phone': user.phone + } + }), 201 + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': '注册失败,请重试'}), 500 + + +@app.route('/api/account/phone/send-code', methods=['POST']) +def send_sms_bind_code(): + """发送绑定手机验证码(需已登录)""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + data = request.get_json() + phone = data.get('phone') + if not phone: + return jsonify({'error': '手机号不能为空'}), 400 + + # 绑定时要求手机号未被占用 + if User.query.filter_by(phone=phone).first(): + return jsonify({'error': '该手机号已被其他账号使用'}), 400 + + code = generate_verification_code() + if send_sms_code(phone, code, SMS_TEMPLATE_REGISTER): + verification_codes[f'bind_{phone}'] = { + 'code': code, + 'expires': time.time() + 300 + } + return jsonify({'message': '验证码已发送'}), 200 + else: + return jsonify({'error': '验证码发送失败'}), 500 + + +@app.route('/api/account/phone/bind', methods=['POST']) +def bind_phone(): + """当前登录用户绑定手机号""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + data = request.get_json() + phone = data.get('phone') + code = data.get('code') + + if not phone or not code: + return jsonify({'error': '手机号和验证码不能为空'}), 400 + + stored = verification_codes.get(f'bind_{phone}') + if not stored or stored['expires'] < time.time(): + return jsonify({'error': '验证码已过期'}), 400 + if stored['code'] != code: + return jsonify({'error': '验证码错误'}), 400 + + if User.query.filter_by(phone=phone).first(): + return jsonify({'error': '该手机号已被其他账号使用'}), 400 + + try: + user = User.query.get(session.get('user_id')) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + user.phone = phone + user.confirm_phone() + # 清除验证码 + del verification_codes[f'bind_{phone}'] + + return jsonify({'message': '绑定成功', 'success': True}), 200 + except Exception as e: + print(f"Bind phone error: {e}") + db.session.rollback() + return jsonify({'error': '绑定失败,请重试'}), 500 + + +@app.route('/api/account/phone/unbind', methods=['POST']) +def unbind_phone(): + """解绑手机号(需已登录)""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + try: + user = User.query.get(session.get('user_id')) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + user.phone = None + user.phone_confirmed = False + user.phone_confirm_time = None + db.session.commit() + return jsonify({'message': '解绑成功', 'success': True}), 200 + except Exception as e: + print(f"Unbind phone error: {e}") + db.session.rollback() + return jsonify({'error': '解绑失败,请重试'}), 500 + + +@app.route('/api/account/email/send-bind-code', methods=['POST']) +def send_email_bind_code(): + """发送绑定邮箱验证码(需已登录)""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + data = request.get_json() + email = data.get('email') + + if not email: + return jsonify({'error': '邮箱不能为空'}), 400 + + # 邮箱格式验证 + if not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', email): + return jsonify({'error': '邮箱格式不正确'}), 400 + + # 检查邮箱是否已被其他账号使用 + if User.query.filter_by(email=email).first(): + return jsonify({'error': '该邮箱已被其他账号使用'}), 400 + + # 生成验证码 + code = ''.join(random.choices(string.digits, k=6)) + + if send_email_code(email, code): + # 存储验证码(5分钟有效) + verification_codes[f'bind_{email}'] = { + 'code': code, + 'expires': time.time() + 300 + } + return jsonify({'message': '验证码已发送'}), 200 + else: + return jsonify({'error': '验证码发送失败'}), 500 + + +@app.route('/api/account/email/bind', methods=['POST']) +def bind_email(): + """当前登录用户绑定邮箱""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + data = request.get_json() + email = data.get('email') + code = data.get('code') + + if not email or not code: + return jsonify({'error': '邮箱和验证码不能为空'}), 400 + + stored = verification_codes.get(f'bind_{email}') + if not stored or stored['expires'] < time.time(): + return jsonify({'error': '验证码已过期'}), 400 + if stored['code'] != code: + return jsonify({'error': '验证码错误'}), 400 + + if User.query.filter_by(email=email).first(): + return jsonify({'error': '该邮箱已被其他账号使用'}), 400 + + try: + user = User.query.get(session.get('user_id')) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + user.email = email + user.confirm_email() + db.session.commit() + + # 清除验证码 + del verification_codes[f'bind_{email}'] + + return jsonify({ + 'message': '邮箱绑定成功', + 'success': True, + 'user': { + 'email': user.email, + 'email_confirmed': user.email_confirmed + } + }), 200 + except Exception as e: + print(f"Bind email error: {e}") + db.session.rollback() + return jsonify({'error': '绑定失败,请重试'}), 500 + + +@app.route('/api/account/email/unbind', methods=['POST']) +def unbind_email(): + """解绑邮箱(需已登录)""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + try: + user = User.query.get(session.get('user_id')) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + user.email = None + user.email_confirmed = False + db.session.commit() + return jsonify({'message': '解绑成功', 'success': True}), 200 + except Exception as e: + print(f"Unbind email error: {e}") + db.session.rollback() + return jsonify({'error': '解绑失败,请重试'}), 500 + + +@app.route('/api/auth/register/email', methods=['POST']) +def register_with_email(): + """邮箱注册 - 使用Session""" + data = request.get_json() + email = data.get('email') + code = data.get('code') + password = data.get('password') + username = data.get('username') + + if not all([email, code, password, username]): + return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 + + # 验证验证码 + stored_code = verification_codes.get(f'email_{email}') + if not stored_code or stored_code['expires'] < time.time(): + return jsonify({'success': False, 'error': '验证码已过期'}), 400 + + if stored_code['code'] != code: + return jsonify({'success': False, 'error': '验证码错误'}), 400 + + if User.query.filter_by(username=username).first(): + return jsonify({'success': False, 'error': '用户名已存在'}), 400 + + try: + # 创建用户 + user = User(username=username, email=email) + user.set_password(password) + user.email_confirmed = True + + db.session.add(user) + db.session.commit() + + # 清除验证码 + del verification_codes[f'email_{email}'] + + # 自动登录 + session.permanent = True + session['user_id'] = user.id + session['username'] = user.username + session['logged_in'] = True + + # Flask-Login 登录 + login_user(user, remember=True) + + return jsonify({ + 'success': True, + 'message': '注册成功', + 'user': { + 'id': user.id, + 'username': user.username, + 'email': user.email + } + }), 201 + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': '注册失败,请重试'}), 500 + + +def get_wechat_access_token(code): + """通过code获取微信access_token""" + url = "https://api.weixin.qq.com/sns/oauth2/access_token" + params = { + 'appid': WECHAT_APPID, + 'secret': WECHAT_APPSECRET, + 'code': code, + 'grant_type': 'authorization_code' + } + + try: + response = requests.get(url, params=params, timeout=10) + data = response.json() + + if 'errcode' in data: + print(f"WeChat access token error: {data}") + return None + + return data + except Exception as e: + print(f"WeChat access token request error: {e}") + return None + + +def get_wechat_userinfo(access_token, openid): + """获取微信用户信息(包含UnionID)""" + url = "https://api.weixin.qq.com/sns/userinfo" + params = { + 'access_token': access_token, + 'openid': openid, + 'lang': 'zh_CN' + } + + try: + response = requests.get(url, params=params, timeout=10) + response.encoding = 'utf-8' # 明确设置编码为UTF-8 + data = response.json() + + if 'errcode' in data: + print(f"WeChat userinfo error: {data}") + return None + + # 确保nickname字段的编码正确 + if 'nickname' in data and data['nickname']: + # 确保昵称是正确的UTF-8编码 + try: + # 检查是否已经是正确的UTF-8字符串 + data['nickname'] = data['nickname'].encode('utf-8').decode('utf-8') + except (UnicodeEncodeError, UnicodeDecodeError) as e: + print(f"Nickname encoding error: {e}, using default") + data['nickname'] = '微信用户' + + return data + except Exception as e: + print(f"WeChat userinfo request error: {e}") + return None + + +@app.route('/api/auth/wechat/qrcode', methods=['GET']) +def get_wechat_qrcode(): + """返回微信授权URL,前端使用iframe展示""" + # 生成唯一state参数 + state = uuid.uuid4().hex + + # URL编码回调地址 + redirect_uri = urllib.parse.quote_plus(WECHAT_REDIRECT_URI) + + # 构建微信授权URL + wechat_auth_url = ( + f"https://open.weixin.qq.com/connect/qrconnect?" + f"appid={WECHAT_APPID}&redirect_uri={redirect_uri}" + f"&response_type=code&scope=snsapi_login&state={state}" + "#wechat_redirect" + ) + + # 存储session信息 + wechat_qr_sessions[state] = { + 'status': 'waiting', + 'expires': time.time() + 300, # 5分钟过期 + 'user_info': None, + 'wechat_openid': None, + 'wechat_unionid': None + } + + return jsonify({"code":0, + "data": + { + 'auth_url': wechat_auth_url, + 'session_id': state, + 'expires_in': 300 + }}), 200 + + +@app.route('/api/account/wechat/qrcode', methods=['GET']) +def get_wechat_bind_qrcode(): + """发起微信绑定二维码,会话标记为绑定模式""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + # 生成唯一state参数 + state = uuid.uuid4().hex + + # URL编码回调地址 + redirect_uri = urllib.parse.quote_plus(WECHAT_REDIRECT_URI) + + # 构建微信授权URL + wechat_auth_url = ( + f"https://open.weixin.qq.com/connect/qrconnect?" + f"appid={WECHAT_APPID}&redirect_uri={redirect_uri}" + f"&response_type=code&scope=snsapi_login&state={state}" + "#wechat_redirect" + ) + + # 存储session信息,标记为绑定模式并记录目标用户 + wechat_qr_sessions[state] = { + 'status': 'waiting', + 'expires': time.time() + 300, + 'mode': 'bind', + 'bind_user_id': session.get('user_id'), + 'user_info': None, + 'wechat_openid': None, + 'wechat_unionid': None + } + + return jsonify({ + 'auth_url': wechat_auth_url, + 'session_id': state, + 'expires_in': 300 + }), 200 + + +@app.route('/api/auth/wechat/check', methods=['POST']) +def check_wechat_scan(): + """检查微信扫码状态""" + data = request.get_json() + session_id = data.get('session_id') + + if not session_id or session_id not in wechat_qr_sessions: + return jsonify({'status': 'invalid', 'error': '无效的session'}), 400 + + session = wechat_qr_sessions[session_id] + + # 检查是否过期 + if time.time() > session['expires']: + del wechat_qr_sessions[session_id] + return jsonify({'status': 'expired'}), 200 + + return jsonify({ + 'status': session['status'], + 'user_info': session.get('user_info'), + 'expires_in': int(session['expires'] - time.time()) + }), 200 + + +@app.route('/api/account/wechat/check', methods=['POST']) +def check_wechat_bind_scan(): + """检查微信扫码绑定状态""" + data = request.get_json() + session_id = data.get('session_id') + + if not session_id or session_id not in wechat_qr_sessions: + return jsonify({'status': 'invalid', 'error': '无效的session'}), 400 + + sess = wechat_qr_sessions[session_id] + + # 绑定模式限制 + if sess.get('mode') != 'bind': + return jsonify({'status': 'invalid', 'error': '会话模式错误'}), 400 + + # 过期处理 + if time.time() > sess['expires']: + del wechat_qr_sessions[session_id] + return jsonify({'status': 'expired'}), 200 + + return jsonify({ + 'status': sess['status'], + 'user_info': sess.get('user_info'), + 'expires_in': int(sess['expires'] - time.time()) + }), 200 + + +@app.route('/api/auth/wechat/callback', methods=['GET']) +def wechat_callback(): + """微信授权回调处理 - 使用Session""" + code = request.args.get('code') + state = request.args.get('state') + error = request.args.get('error') + + # 错误处理:用户拒绝授权 + if error: + if state in wechat_qr_sessions: + wechat_qr_sessions[state]['status'] = 'auth_denied' + wechat_qr_sessions[state]['error'] = '用户拒绝授权' + print(f"❌ 用户拒绝授权: state={state}") + return redirect('/auth/signin?error=wechat_auth_denied') + + # 参数验证 + if not code or not state: + if state in wechat_qr_sessions: + wechat_qr_sessions[state]['status'] = 'auth_failed' + wechat_qr_sessions[state]['error'] = '授权参数缺失' + return redirect('/auth/signin?error=wechat_auth_failed') + + # 验证state + if state not in wechat_qr_sessions: + return redirect('/auth/signin?error=session_expired') + + session_data = wechat_qr_sessions[state] + + # 检查过期 + if time.time() > session_data['expires']: + del wechat_qr_sessions[state] + return redirect('/auth/signin?error=session_expired') + + try: + # 步骤1: 用户已扫码并授权(微信回调过来说明用户已完成扫码+授权) + session_data['status'] = 'scanned' + print(f"✅ 微信扫码回调: state={state}, code={code[:10]}...") + + # 步骤2: 获取access_token + token_data = get_wechat_access_token(code) + if not token_data: + session_data['status'] = 'auth_failed' + session_data['error'] = '获取访问令牌失败' + print(f"❌ 获取微信access_token失败: state={state}") + return redirect('/auth/signin?error=token_failed') + + # 步骤3: Token获取成功,标记为已授权 + session_data['status'] = 'authorized' + print(f"✅ 微信授权成功: openid={token_data['openid']}") + + # 步骤4: 获取用户信息 + user_info = get_wechat_userinfo(token_data['access_token'], token_data['openid']) + if not user_info: + session_data['status'] = 'auth_failed' + session_data['error'] = '获取用户信息失败' + print(f"❌ 获取微信用户信息失败: openid={token_data['openid']}") + return redirect('/auth/signin?error=userinfo_failed') + + # 查找或创建用户 / 或处理绑定 + openid = token_data['openid'] + unionid = user_info.get('unionid') or token_data.get('unionid') + + # 如果是绑定流程 + session_item = wechat_qr_sessions.get(state) + if session_item and session_item.get('mode') == 'bind': + try: + target_user_id = session.get('user_id') or session_item.get('bind_user_id') + if not target_user_id: + return redirect('/auth/signin?error=bind_no_user') + + target_user = User.query.get(target_user_id) + if not target_user: + return redirect('/auth/signin?error=bind_user_missing') + + # 检查该微信是否已被其他账户绑定 + existing = None + if unionid: + existing = User.query.filter_by(wechat_union_id=unionid).first() + if not existing: + existing = User.query.filter_by(wechat_open_id=openid).first() + + if existing and existing.id != target_user.id: + session_item['status'] = 'bind_conflict' + return redirect('/home?bind=conflict') + + # 执行绑定 + target_user.bind_wechat(openid, unionid, wechat_info=user_info) + + # 标记绑定完成,供前端轮询 + session_item['status'] = 'bind_ready' + session_item['user_info'] = {'user_id': target_user.id} + + return redirect('/home?bind=success') + except Exception as e: + print(f"❌ 微信绑定失败: {e}") + db.session.rollback() + session_item['status'] = 'bind_failed' + return redirect('/home?bind=failed') + + user = None + is_new_user = False + + if unionid: + user = User.query.filter_by(wechat_union_id=unionid).first() + if not user: + user = User.query.filter_by(wechat_open_id=openid).first() + + if not user: + # 创建新用户 + # 先清理微信昵称 + raw_nickname = user_info.get('nickname', '微信用户') + # 创建临时用户实例以使用清理方法 + temp_user = User.__new__(User) + sanitized_nickname = temp_user._sanitize_nickname(raw_nickname) + + username = sanitized_nickname + counter = 1 + while User.is_username_taken(username): + username = f"{sanitized_nickname}_{counter}" + counter += 1 + + user = User(username=username) + user.nickname = sanitized_nickname + user.avatar_url = user_info.get('headimgurl') + user.wechat_open_id = openid + user.wechat_union_id = unionid + user.set_password(uuid.uuid4().hex) + user.status = 'active' + + db.session.add(user) + db.session.commit() + + is_new_user = True + print(f"✅ 微信扫码自动创建新用户: {username}, openid: {openid}") + + # 更新最后登录时间 + user.update_last_seen() + + # 设置session + session.permanent = True + session['user_id'] = user.id + session['username'] = user.username + session['logged_in'] = True + session['wechat_login'] = True # 标记是微信登录 + + # Flask-Login 登录 + login_user(user, remember=True) + + # 更新微信session状态,供前端轮询检测 + if state in wechat_qr_sessions: + session_item = wechat_qr_sessions[state] + # 仅处理登录/注册流程,不处理绑定流程 + if not session_item.get('mode'): + # 更新状态和用户信息 + session_item['status'] = 'register_ready' if is_new_user else 'login_ready' + session_item['user_info'] = {'user_id': user.id} + print(f"✅ 微信扫码状态已更新: {session_item['status']}, user_id: {user.id}") + + # 直接跳转到首页 + return redirect('/home') + + except Exception as e: + print(f"❌ 微信登录失败: {e}") + import traceback + traceback.print_exc() + db.session.rollback() + + # 更新session状态为失败 + if state in wechat_qr_sessions: + wechat_qr_sessions[state]['status'] = 'auth_failed' + wechat_qr_sessions[state]['error'] = str(e) + + return redirect('/auth/signin?error=login_failed') + + +@app.route('/api/auth/login/wechat', methods=['POST']) +def login_with_wechat(): + """微信登录 - 修复版本""" + data = request.get_json() + session_id = data.get('session_id') + + if not session_id: + return jsonify({'success': False, 'error': 'session_id不能为空'}), 400 + + # 验证session + session = wechat_qr_sessions.get(session_id) + if not session: + return jsonify({'success': False, 'error': '会话不存在或已过期'}), 400 + + # 检查session状态 + if session['status'] not in ['login_ready', 'register_ready']: + return jsonify({'success': False, 'error': '会话状态无效'}), 400 + + # 检查是否有用户信息 + user_info = session.get('user_info') + if not user_info or not user_info.get('user_id'): + return jsonify({'success': False, 'error': '用户信息不完整'}), 400 + + try: + user = User.query.get(user_info['user_id']) + if not user: + return jsonify({'success': False, 'error': '用户不存在'}), 404 + + # 更新最后登录时间 + user.update_last_seen() + + # 清除session + del wechat_qr_sessions[session_id] + + # 生成登录响应 + response_data = { + 'success': True, + 'message': '登录成功' if session['status'] == 'login_ready' else '注册并登录成功', + 'user': { + 'id': user.id, + 'username': user.username, + 'nickname': user.nickname or user.username, + 'email': user.email, + 'avatar_url': user.avatar_url, + 'has_wechat': True, + 'wechat_open_id': user.wechat_open_id, + 'wechat_union_id': user.wechat_union_id, + 'created_at': user.created_at.isoformat() if user.created_at else None, + 'last_seen': user.last_seen.isoformat() if user.last_seen else None + } + } + + # 如果需要token认证,可以在这里生成 + # response_data['token'] = generate_token(user.id) + + return jsonify(response_data), 200 + + except Exception as e: + print(f"❌ 微信登录错误: {e}") + import traceback + app.logger.error(f"回调处理错误: {e}", exc_info=True) + return jsonify({ + 'success': False, + 'error': '登录失败,请重试' + }), 500 + + +@app.route('/api/account/wechat/unbind', methods=['POST']) +def unbind_wechat_account(): + """解绑当前登录用户的微信""" + if not session.get('logged_in'): + return jsonify({'error': '未登录'}), 401 + + try: + user = User.query.get(session.get('user_id')) + if not user: + return jsonify({'error': '用户不存在'}), 404 + + user.unbind_wechat() + return jsonify({'message': '解绑成功', 'success': True}), 200 + except Exception as e: + print(f"Unbind wechat error: {e}") + db.session.rollback() + return jsonify({'error': '解绑失败,请重试'}), 500 + + +# 评论模型 +class EventComment(db.Model): + """事件评论""" + __tablename__ = 'event_comment' + + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, nullable=False) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=True) + author = db.Column(db.String(50), default='匿名用户') + content = db.Column(db.Text, nullable=False) + parent_id = db.Column(db.Integer, db.ForeignKey('event_comment.id')) + likes = db.Column(db.Integer, default=0) + created_at = db.Column(db.DateTime, default=beijing_now) + status = db.Column(db.String(20), default='active') + + user = db.relationship('User', backref='event_comments') + replies = db.relationship('EventComment', backref=db.backref('parent', remote_side=[id])) + + def to_dict(self, user_session_id=None, current_user_id=None): + # 检查当前用户是否已点赞 + user_liked = False + if user_session_id: + like_record = CommentLike.query.filter_by( + comment_id=self.id, + session_id=user_session_id + ).first() + user_liked = like_record is not None + + # 检查当前用户是否可以删除此评论 + can_delete = current_user_id is not None and self.user_id == current_user_id + + return { + 'id': self.id, + 'event_id': self.event_id, + 'author': self.author, + 'content': self.content, + 'parent_id': self.parent_id, + 'likes': self.likes, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'user_liked': user_liked, + 'can_delete': can_delete, + 'user_id': self.user_id, + 'replies': [reply.to_dict(user_session_id, current_user_id) for reply in self.replies if + reply.status == 'active'] + } + + +class CommentLike(db.Model): + """评论点赞记录""" + __tablename__ = 'comment_like' + + id = db.Column(db.Integer, primary_key=True) + comment_id = db.Column(db.Integer, db.ForeignKey('event_comment.id'), nullable=False) + session_id = db.Column(db.String(100), nullable=False) + created_at = db.Column(db.DateTime, default=beijing_now) + + __table_args__ = (db.UniqueConstraint('comment_id', 'session_id'),) + + +@app.after_request +def after_request(response): + """处理所有响应,添加CORS头部和安全头部""" + origin = request.headers.get('Origin') + allowed_origins = ['http://localhost:3000', 'http://127.0.0.1:3000', 'http://localhost:5173', + 'https://valuefrontier.cn', 'http://valuefrontier.cn'] + + if origin in allowed_origins: + response.headers['Access-Control-Allow-Origin'] = origin + response.headers['Access-Control-Allow-Credentials'] = 'true' + response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization,X-Requested-With' + response.headers['Access-Control-Allow-Methods'] = 'GET,PUT,POST,DELETE,OPTIONS' + response.headers['Access-Control-Expose-Headers'] = 'Content-Type,Authorization' + + # 处理预检请求 + if request.method == 'OPTIONS': + response.status_code = 200 + + return response + + +def add_cors_headers(response): + """添加CORS头(保留原有函数以兼容)""" + origin = request.headers.get('Origin') + allowed_origins = ['http://localhost:3000', 'http://127.0.0.1:3000', 'http://localhost:5173', + 'https://valuefrontier.cn', 'http://valuefrontier.cn'] + + if origin in allowed_origins: + response.headers['Access-Control-Allow-Origin'] = origin + else: + response.headers['Access-Control-Allow-Origin'] = 'http://localhost:3000' + + response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization,X-Requested-With' + response.headers['Access-Control-Allow-Methods'] = 'GET,PUT,POST,DELETE,OPTIONS' + response.headers['Access-Control-Allow-Credentials'] = 'true' + return response + + +class EventFollow(db.Model): + """事件关注""" + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) + created_at = db.Column(db.DateTime, default=beijing_now) + + user = db.relationship('User', backref='event_follows') + + __table_args__ = (db.UniqueConstraint('user_id', 'event_id'),) + + +class FutureEventFollow(db.Model): + """未来事件关注""" + __tablename__ = 'future_event_follow' + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + future_event_id = db.Column(db.Integer, nullable=False) # future_events表的id + created_at = db.Column(db.DateTime, default=beijing_now) + + user = db.relationship('User', backref='future_event_follows') + + __table_args__ = (db.UniqueConstraint('user_id', 'future_event_id'),) + + +# —— 自选股输入统一化与名称补全工具 —— +def _normalize_stock_input(raw_input: str): + """解析用户输入为标准6位股票代码与可选名称。 + + 支持: + - 6位代码: "600519",或带后缀 "600519.SH"/"600519.SZ" + - 名称(代码): "贵州茅台(600519)" 或 "贵州茅台(600519)" + 返回 (code6, name_or_none) + """ + if not raw_input: + return None, None + s = str(raw_input).strip() + + # 名称(600519) 或 名称(600519) + m = re.match(r"^(.+?)[\((]\s*(\d{6})\s*[\))]\s*$", s) + if m: + name = m.group(1).strip() + code = m.group(2) + return code, (name if name else None) + + # 600519 或 600519.SH / 600519.SZ + m2 = re.match(r"^(\d{6})(?:\.(?:SH|SZ))?$", s, re.IGNORECASE) + if m2: + return m2.group(1), None + + # SH600519 / SZ000001 + m3 = re.match(r"^(SH|SZ)(\d{6})$", s, re.IGNORECASE) + if m3: + return m3.group(2), None + + return None, None + + +def _query_stock_name_by_code(code6: str): + """根据6位代码查询股票名称,查不到返回None。""" + try: + with engine.connect() as conn: + q = text(""" + SELECT SECNAME + FROM ea_baseinfo + WHERE SECCODE = :c LIMIT 1 + """) + row = conn.execute(q, {'c': code6}).fetchone() + if row: + return row[0] + except Exception: + pass + return None + + +class Watchlist(db.Model): + """用户自选股""" + __tablename__ = 'watchlist' + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + stock_code = db.Column(db.String(20), nullable=False) + stock_name = db.Column(db.String(100), nullable=True) + created_at = db.Column(db.DateTime, default=beijing_now) + + user = db.relationship('User', backref='watchlist') + + __table_args__ = (db.UniqueConstraint('user_id', 'stock_code'),) + + +@app.route('/api/account/watchlist', methods=['GET']) +def get_my_watchlist(): + """获取当前用户的自选股列表""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + items = Watchlist.query.filter_by(user_id=session['user_id']).order_by(Watchlist.created_at.desc()).all() + + # 懒更新:统一代码为6位、补全缺失的名称,并去重(同一代码保留一个记录) + from collections import defaultdict + groups = defaultdict(list) + for i in items: + code6, _ = _normalize_stock_input(i.stock_code) + normalized_code = code6 or (i.stock_code.strip().upper() if isinstance(i.stock_code, str) else i.stock_code) + groups[normalized_code].append(i) + + dirty = False + to_delete = [] + for code6, group in groups.items(): + # 选择保留记录:优先有名称的,其次创建时间早的 + def sort_key(x): + return (x.stock_name is None, x.created_at or datetime.min) + + group_sorted = sorted(group, key=sort_key) + keep = group_sorted[0] + # 规范保留项 + if keep.stock_code != code6: + keep.stock_code = code6 + dirty = True + if not keep.stock_name and code6: + nm = _query_stock_name_by_code(code6) + if nm: + keep.stock_name = nm + dirty = True + # 其余删除 + for g in group_sorted[1:]: + to_delete.append(g) + + if to_delete: + for g in to_delete: + db.session.delete(g) + dirty = True + + if dirty: + db.session.commit() + + return jsonify({'success': True, 'data': [ + { + 'id': i.id, + 'stock_code': i.stock_code, + 'stock_name': i.stock_name, + 'created_at': i.created_at.isoformat() if i.created_at else None + } for i in items + ]}) + except Exception as e: + print(f"Error in get_my_watchlist: {str(e)}") + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/account/watchlist', methods=['POST']) +def add_to_watchlist(): + """添加到自选股""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + data = request.get_json() or {} + raw_code = data.get('stock_code') + raw_name = data.get('stock_name') + + code6, name_from_input = _normalize_stock_input(raw_code) + if not code6: + return jsonify({'success': False, 'error': '无效的股票标识'}), 400 + + # 优先使用传入名称,其次从输入解析中获得,最后查库补全 + final_name = raw_name or name_from_input or _query_stock_name_by_code(code6) + + # 查找已存在记录,兼容历史:6位/带后缀 + candidates = [code6, f"{code6}.SH", f"{code6}.SZ"] + existing = Watchlist.query.filter( + Watchlist.user_id == session['user_id'], + Watchlist.stock_code.in_(candidates) + ).first() + if existing: + # 统一为6位,补全名称 + updated = False + if existing.stock_code != code6: + existing.stock_code = code6 + updated = True + if (not existing.stock_name) and final_name: + existing.stock_name = final_name + updated = True + if updated: + db.session.commit() + return jsonify({'success': True, 'data': {'id': existing.id}}) + + item = Watchlist(user_id=session['user_id'], stock_code=code6, stock_name=final_name) + db.session.add(item) + db.session.commit() + return jsonify({'success': True, 'data': {'id': item.id}}) + + +@app.route('/api/account/watchlist/', methods=['DELETE']) +def remove_from_watchlist(stock_code): + """从自选股移除""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + code6, _ = _normalize_stock_input(stock_code) + candidates = [] + if code6: + candidates = [code6, f"{code6}.SH", f"{code6}.SZ"] + # 包含原始传入(以兼容历史) + if stock_code not in candidates: + candidates.append(stock_code) + + item = Watchlist.query.filter( + Watchlist.user_id == session['user_id'], + Watchlist.stock_code.in_(candidates) + ).first() + if not item: + return jsonify({'success': False, 'error': '未找到自选项'}), 404 + db.session.delete(item) + db.session.commit() + return jsonify({'success': True}) + + +@app.route('/api/account/watchlist/realtime', methods=['GET']) +def get_watchlist_realtime(): + """获取自选股实时行情数据(基于分钟线)""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + # 获取用户自选股列表 + watchlist = Watchlist.query.filter_by(user_id=session['user_id']).all() + if not watchlist: + return jsonify({'success': True, 'data': []}) + + # 获取股票代码列表 + stock_codes = [] + for item in watchlist: + code6, _ = _normalize_stock_input(item.stock_code) + # 统一内部查询代码 + normalized = code6 or str(item.stock_code).strip().upper() + stock_codes.append(normalized) + + # 使用现有的分钟线接口获取最新行情 + client = get_clickhouse_client() + quotes_data = {} + + # 获取最新交易日 + today = datetime.now().date() + + # 获取每只股票的最新价格 + for code in stock_codes: + raw_code = str(code).strip().upper() + if '.' in raw_code: + stock_code_full = raw_code + else: + stock_code_full = f"{raw_code}.SH" if raw_code.startswith('6') else f"{raw_code}.SZ" + + # 获取最新分钟线数据(先查近7天,若无数据再兜底倒序取最近一条) + query = """ + SELECT + close, timestamp, high, low, volume, amt + FROM stock_minute + WHERE code = %(code)s + AND timestamp >= %(start)s + ORDER BY timestamp DESC + LIMIT 1 \ + """ + + # 获取最近7天的分钟数据 + start_date = today - timedelta(days=7) + + result = client.execute(query, { + 'code': stock_code_full, + 'start': datetime.combine(start_date, dt_time(9, 30)) + }) + + # 若近7天无数据,兜底直接取最近一条 + if not result: + fallback_query = """ + SELECT + close, timestamp, high, low, volume, amt + FROM stock_minute + WHERE code = %(code)s + ORDER BY timestamp DESC + LIMIT 1 \ + """ + result = client.execute(fallback_query, {'code': stock_code_full}) + + if result: + latest_data = result[0] + latest_ts = latest_data[1] + + # 获取该bar所属交易日前一个交易日的收盘价 + prev_close_query = """ + SELECT close + FROM stock_minute + WHERE code = %(code)s + AND timestamp \ + < %(start)s + ORDER BY timestamp DESC + LIMIT 1 \ + """ + + prev_result = client.execute(prev_close_query, { + 'code': stock_code_full, + 'start': datetime.combine(latest_ts.date(), dt_time(9, 30)) + }) + + prev_close = float(prev_result[0][0]) if prev_result else float(latest_data[0]) + + # 计算涨跌幅 + change = float(latest_data[0]) - prev_close + change_percent = (change / prev_close * 100) if prev_close > 0 else 0.0 + + quotes_data[code] = { + 'price': float(latest_data[0]), + 'prev_close': float(prev_close), + 'change': float(change), + 'change_percent': float(change_percent), + 'high': float(latest_data[2]), + 'low': float(latest_data[3]), + 'volume': int(latest_data[4]), + 'amount': float(latest_data[5]), + 'update_time': latest_ts.strftime('%H:%M:%S') + } + + # 构建响应数据 + response_data = [] + for item in watchlist: + code6, _ = _normalize_stock_input(item.stock_code) + quote = quotes_data.get(code6 or item.stock_code, {}) + response_data.append({ + 'stock_code': code6 or item.stock_code, + 'stock_name': item.stock_name or (code6 and _query_stock_name_by_code(code6)) or None, + 'current_price': quote.get('price', 0), + 'prev_close': quote.get('prev_close', 0), + 'change': quote.get('change', 0), + 'change_percent': quote.get('change_percent', 0), + 'high': quote.get('high', 0), + 'low': quote.get('low', 0), + 'volume': quote.get('volume', 0), + 'amount': quote.get('amount', 0), + 'update_time': quote.get('update_time', ''), + # industry 字段在 Watchlist 模型中不存在,先不返回该字段 + }) + + return jsonify({ + 'success': True, + 'data': response_data + }) + + except Exception as e: + print(f"获取实时行情失败: {str(e)}") + return jsonify({'success': False, 'error': '获取实时行情失败'}), 500 + + +# 投资计划和复盘相关的模型 +class InvestmentPlan(db.Model): + __tablename__ = 'investment_plans' + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + date = db.Column(db.Date, nullable=False) + title = db.Column(db.String(200), nullable=False) + content = db.Column(db.Text) + type = db.Column(db.String(20)) # 'plan' or 'review' + stocks = db.Column(db.Text) # JSON array of stock codes + tags = db.Column(db.String(500)) # JSON array of tags + status = db.Column(db.String(20), default='active') # active, completed, cancelled + created_at = db.Column(db.DateTime, default=datetime.utcnow) + updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + def to_dict(self): + return { + 'id': self.id, + 'date': self.date.isoformat() if self.date else None, + 'title': self.title, + 'content': self.content, + 'type': self.type, + 'stocks': json.loads(self.stocks) if self.stocks else [], + 'tags': json.loads(self.tags) if self.tags else [], + 'status': self.status, + 'created_at': self.created_at.isoformat() if self.created_at else None, + 'updated_at': self.updated_at.isoformat() if self.updated_at else None + } + + +@app.route('/api/account/investment-plans', methods=['GET']) +def get_investment_plans(): + """获取投资计划和复盘记录""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + plan_type = request.args.get('type') # 'plan', 'review', or None for all + start_date = request.args.get('start_date') + end_date = request.args.get('end_date') + + query = InvestmentPlan.query.filter_by(user_id=session['user_id']) + + if plan_type: + query = query.filter_by(type=plan_type) + + if start_date: + query = query.filter(InvestmentPlan.date >= datetime.fromisoformat(start_date).date()) + + if end_date: + query = query.filter(InvestmentPlan.date <= datetime.fromisoformat(end_date).date()) + + plans = query.order_by(InvestmentPlan.date.desc()).all() + + return jsonify({ + 'success': True, + 'data': [plan.to_dict() for plan in plans] + }) + + except Exception as e: + print(f"获取投资计划失败: {str(e)}") + return jsonify({'success': False, 'error': '获取数据失败'}), 500 + + +@app.route('/api/account/investment-plans', methods=['POST']) +def create_investment_plan(): + """创建投资计划或复盘记录""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + data = request.get_json() + + # 验证必要字段 + if not data.get('date') or not data.get('title') or not data.get('type'): + return jsonify({'success': False, 'error': '缺少必要字段'}), 400 + + plan = InvestmentPlan( + user_id=session['user_id'], + date=datetime.fromisoformat(data['date']).date(), + title=data['title'], + content=data.get('content', ''), + type=data['type'], + stocks=json.dumps(data.get('stocks', [])), + tags=json.dumps(data.get('tags', [])), + status=data.get('status', 'active') + ) + + db.session.add(plan) + db.session.commit() + + return jsonify({ + 'success': True, + 'data': plan.to_dict() + }) + + except Exception as e: + db.session.rollback() + print(f"创建投资计划失败: {str(e)}") + return jsonify({'success': False, 'error': '创建失败'}), 500 + + +@app.route('/api/account/investment-plans/', methods=['PUT']) +def update_investment_plan(plan_id): + """更新投资计划或复盘记录""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + plan = InvestmentPlan.query.filter_by(id=plan_id, user_id=session['user_id']).first() + if not plan: + return jsonify({'success': False, 'error': '未找到该记录'}), 404 + + data = request.get_json() + + if 'date' in data: + plan.date = datetime.fromisoformat(data['date']).date() + if 'title' in data: + plan.title = data['title'] + if 'content' in data: + plan.content = data['content'] + if 'stocks' in data: + plan.stocks = json.dumps(data['stocks']) + if 'tags' in data: + plan.tags = json.dumps(data['tags']) + if 'status' in data: + plan.status = data['status'] + + plan.updated_at = datetime.utcnow() + db.session.commit() + + return jsonify({ + 'success': True, + 'data': plan.to_dict() + }) + + except Exception as e: + db.session.rollback() + print(f"更新投资计划失败: {str(e)}") + return jsonify({'success': False, 'error': '更新失败'}), 500 + + +@app.route('/api/account/investment-plans/', methods=['DELETE']) +def delete_investment_plan(plan_id): + """删除投资计划或复盘记录""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + plan = InvestmentPlan.query.filter_by(id=plan_id, user_id=session['user_id']).first() + if not plan: + return jsonify({'success': False, 'error': '未找到该记录'}), 404 + + db.session.delete(plan) + db.session.commit() + + return jsonify({'success': True}) + + except Exception as e: + db.session.rollback() + print(f"删除投资计划失败: {str(e)}") + return jsonify({'success': False, 'error': '删除失败'}), 500 + + +@app.route('/api/account/events/following', methods=['GET']) +def get_my_following_events(): + """获取我关注的事件列表""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + follows = EventFollow.query.filter_by(user_id=session['user_id']).order_by(EventFollow.created_at.desc()).all() + event_ids = [f.event_id for f in follows] + if not event_ids: + return jsonify({'success': True, 'data': []}) + + events = Event.query.filter(Event.id.in_(event_ids)).all() + data = [] + for ev in events: + data.append({ + 'id': ev.id, + 'title': ev.title, + 'event_type': ev.event_type, + 'start_time': ev.start_time.isoformat() if ev.start_time else None, + 'hot_score': ev.hot_score, + 'follower_count': ev.follower_count, + }) + return jsonify({'success': True, 'data': data}) + + +@app.route('/api/account/events/comments', methods=['GET']) +def get_my_event_comments(): + """获取我在事件上的评论(EventComment)""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + comments = EventComment.query.filter_by(user_id=session['user_id']).order_by(EventComment.created_at.desc()).limit( + 100).all() + return jsonify({'success': True, 'data': [c.to_dict() for c in comments]}) + + +@app.route('/api/account/future-events/following', methods=['GET']) +def get_my_following_future_events(): + """获取当前用户关注的未来事件""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + try: + # 获取用户关注的未来事件ID列表 + follows = FutureEventFollow.query.filter_by(user_id=session['user_id']).all() + future_event_ids = [f.future_event_id for f in follows] + + if not future_event_ids: + return jsonify({'success': True, 'data': []}) + + # 查询未来事件详情 + sql = """ + SELECT * + FROM future_events + WHERE data_id IN :event_ids + ORDER BY calendar_time \ + """ + + result = db.session.execute( + text(sql), + {'event_ids': tuple(future_event_ids)} + ) + + events = [] + for row in result: + event_data = { + 'id': row.data_id, + 'title': row.title, + 'type': row.type, + 'calendar_time': row.calendar_time.isoformat(), + 'star': row.star, + 'former': row.former, + 'forecast': row.forecast, + 'fact': row.fact, + 'is_following': True, # 这些都是已关注的 + 'related_stocks': parse_json_field(row.related_stocks), + 'concepts': parse_json_field(row.concepts) + } + events.append(event_data) + + return jsonify({'success': True, 'data': events}) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +class PostLike(db.Model): + """帖子点赞""" + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False) + post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False) + created_at = db.Column(db.DateTime, default=beijing_now) + + user = db.relationship('User', backref='post_likes') + + __table_args__ = (db.UniqueConstraint('user_id', 'post_id'),) + + +class Event(db.Model): + """事件模型""" + id = db.Column(db.Integer, primary_key=True) + title = db.Column(db.String(200), nullable=False) + description = db.Column(db.Text) + + # 事件类型与状态 + event_type = db.Column(db.String(50)) + status = db.Column(db.String(20), default='active') + + # 时间相关 + start_time = db.Column(db.DateTime, default=beijing_now) + end_time = db.Column(db.DateTime) + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now) + + # 热度与统计 + hot_score = db.Column(db.Float, default=0) + view_count = db.Column(db.Integer, default=0) + trending_score = db.Column(db.Float, default=0) + post_count = db.Column(db.Integer, default=0) + follower_count = db.Column(db.Integer, default=0) + + # 关联信息 + related_industries = db.Column(db.JSON) + keywords = db.Column(db.JSON) + files = db.Column(db.JSON) + importance = db.Column(db.String(20)) + related_avg_chg = db.Column(db.Float, default=0) + related_max_chg = db.Column(db.Float, default=0) + related_week_chg = db.Column(db.Float, default=0) + + # 新增字段 + invest_score = db.Column(db.Integer) # 超预期得分 + expectation_surprise_score = db.Column(db.Integer) + # 创建者信息 + creator_id = db.Column(db.Integer, db.ForeignKey('user.id')) + creator = db.relationship('User', backref='created_events') + + # 关系 + posts = db.relationship('Post', backref='event', lazy='dynamic') + followers = db.relationship('EventFollow', backref='event', lazy='dynamic') + related_stocks = db.relationship('RelatedStock', backref='event', lazy='dynamic') + historical_events = db.relationship('HistoricalEvent', backref='event', lazy='dynamic') + related_data = db.relationship('RelatedData', backref='event', lazy='dynamic') + related_concepts = db.relationship('RelatedConcepts', backref='event', lazy='dynamic') + + @property + def keywords_list(self): + """返回解析后的关键词列表""" + if not self.keywords: + return [] + + if isinstance(self.keywords, list): + return self.keywords + + try: + # 如果是字符串,尝试解析JSON + if isinstance(self.keywords, str): + decoded = json.loads(self.keywords) + # 处理Unicode编码的情况 + if isinstance(decoded, list): + return [ + keyword.encode('utf-8').decode('unicode_escape') + if isinstance(keyword, str) and '\\u' in keyword + else keyword + for keyword in decoded + ] + return [] + + # 如果已经是字典或其他格式,尝试转换为列表 + return list(self.keywords) + except (json.JSONDecodeError, AttributeError, TypeError): + return [] + + def set_keywords(self, keywords): + """设置关键词列表""" + if isinstance(keywords, list): + self.keywords = json.dumps(keywords, ensure_ascii=False) + elif isinstance(keywords, str): + try: + # 尝试解析JSON字符串 + parsed = json.loads(keywords) + if isinstance(parsed, list): + self.keywords = json.dumps(parsed, ensure_ascii=False) + else: + self.keywords = json.dumps([keywords], ensure_ascii=False) + except json.JSONDecodeError: + # 如果不是有效的JSON,将其作为单个关键词 + self.keywords = json.dumps([keywords], ensure_ascii=False) + + +class RelatedStock(db.Model): + """相关标的模型""" + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id')) + stock_code = db.Column(db.String(20)) # 股票代码 + stock_name = db.Column(db.String(100)) # 股票名称 + sector = db.Column(db.String(100)) # 关联类型 + relation_desc = db.Column(db.String(1024)) # 关联原因描述 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + correlation = db.Column(db.Float()) + momentum = db.Column(db.String(1024)) # 动量 + retrieved_sources = db.Column(db.JSON) # 动量 + + +class RelatedData(db.Model): + """关联数据模型""" + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id')) + title = db.Column(db.String(200)) # 数据标题 + data_type = db.Column(db.String(50)) # 数据类型 + data_content = db.Column(db.JSON) # 数据内容(JSON格式) + description = db.Column(db.Text) # 数据描述 + created_at = db.Column(db.DateTime, default=beijing_now) + + +class RelatedConcepts(db.Model): + """关联数据模型""" + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id')) + concept_code = db.Column(db.String(20)) # 数据标题 + concept = db.Column(db.String(100)) # 数据类型 + reason = db.Column(db.Text) # 数据描述 + image_paths = db.Column(db.JSON) # 数据内容(JSON格式) + created_at = db.Column(db.DateTime, default=beijing_now) + + @property + def image_paths_list(self): + """返回解析后的图片路径列表""" + if not self.image_paths: + return [] + + try: + # 如果是字符串,先解析成JSON + if isinstance(self.image_paths, str): + paths = json.loads(self.image_paths) + else: + paths = self.image_paths + + # 确保paths是列表 + if not isinstance(paths, list): + paths = [paths] + + # 从每个对象中提取path字段 + return [item['path'] if isinstance(item, dict) and 'path' in item + else item for item in paths] + except Exception as e: + print(f"Error processing image paths: {e}") + return [] + + def get_first_image_path(self): + """获取第一张图片的完整路径""" + paths = self.image_paths_list + if not paths: + return None + + # 获取第一个路径 + first_path = paths[0] + # 返回完整路径 + return first_path + + +class EventHotHistory(db.Model): + """事件热度历史记录""" + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id')) + score = db.Column(db.Float) # 总分 + interaction_score = db.Column(db.Float) # 互动分数 + follow_score = db.Column(db.Float) # 关注度分数 + view_score = db.Column(db.Float) # 浏览量分数 + recent_activity_score = db.Column(db.Float) # 最近活跃度分数 + time_decay = db.Column(db.Float) # 时间衰减因子 + created_at = db.Column(db.DateTime, default=beijing_now) + + event = db.relationship('Event', backref='hot_history') + + +class EventTransmissionNode(db.Model): + """事件传导节点模型""" + __tablename__ = 'event_transmission_nodes' + + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) + node_type = db.Column(db.Enum('company', 'industry', 'policy', 'technology', + 'market', 'event', 'other'), nullable=False) + node_name = db.Column(db.String(200), nullable=False) + node_description = db.Column(db.Text) + importance_score = db.Column(db.Integer, default=50) + stock_code = db.Column(db.String(20)) + is_main_event = db.Column(db.Boolean, default=False) + + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + # Relationships + event = db.relationship('Event', backref='transmission_nodes') + outgoing_edges = db.relationship('EventTransmissionEdge', + foreign_keys='EventTransmissionEdge.from_node_id', + backref='from_node', cascade='all, delete-orphan') + incoming_edges = db.relationship('EventTransmissionEdge', + foreign_keys='EventTransmissionEdge.to_node_id', + backref='to_node', cascade='all, delete-orphan') + + __table_args__ = ( + db.Index('idx_event_id', 'event_id'), + db.Index('idx_node_type', 'node_type'), + db.Index('idx_main_event', 'is_main_event'), + ) + + +class EventTransmissionEdge(db.Model): + """事件传导边模型""" + __tablename__ = 'event_transmission_edges' + + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) + from_node_id = db.Column(db.Integer, db.ForeignKey('event_transmission_nodes.id'), nullable=False) + to_node_id = db.Column(db.Integer, db.ForeignKey('event_transmission_nodes.id'), nullable=False) + + transmission_type = db.Column(db.Enum('supply_chain', 'competition', 'policy', + 'technology', 'capital_flow', 'expectation', + 'cyclic_effect', 'other'), nullable=False) + transmission_mechanism = db.Column(db.Text) + direction = db.Column(db.Enum('positive', 'negative', 'neutral', 'mixed'), default='neutral') + strength = db.Column(db.Integer, default=50) + impact = db.Column(db.Text) + is_circular = db.Column(db.Boolean, default=False) + + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + # Relationship + event = db.relationship('Event', backref='transmission_edges') + + __table_args__ = ( + db.Index('idx_event_id', 'event_id'), + db.Index('idx_strength', 'strength'), + db.Index('idx_from_to', 'from_node_id', 'to_node_id'), + db.Index('idx_circular', 'is_circular'), + ) + + +# 在 paste-2.txt 的模型定义部分添加 +class EventSankeyFlow(db.Model): + """事件桑基流模型""" + __tablename__ = 'event_sankey_flows' + + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False) + + # 流的基本信息 + source_node = db.Column(db.String(200), nullable=False) + source_type = db.Column(db.Enum('event', 'policy', 'technology', 'industry', + 'company', 'product'), nullable=False) + source_level = db.Column(db.Integer, nullable=False, default=0) + + target_node = db.Column(db.String(200), nullable=False) + target_type = db.Column(db.Enum('policy', 'technology', 'industry', + 'company', 'product'), nullable=False) + target_level = db.Column(db.Integer, nullable=False, default=1) + + # 流量信息 + flow_value = db.Column(db.Numeric(10, 2), nullable=False) + flow_ratio = db.Column(db.Numeric(5, 4), nullable=False) + + # 传导机制 + transmission_path = db.Column(db.String(500)) + impact_description = db.Column(db.Text) + evidence_strength = db.Column(db.Integer, default=50) + + # 时间戳 + created_at = db.Column(db.DateTime, default=beijing_now) + updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) + + # 关系 + event = db.relationship('Event', backref='sankey_flows') + + __table_args__ = ( + db.Index('idx_event_id', 'event_id'), + db.Index('idx_source_target', 'source_node', 'target_node'), + db.Index('idx_levels', 'source_level', 'target_level'), + db.Index('idx_flow_value', 'flow_value'), + ) + + +class HistoricalEvent(db.Model): + """历史事件模型""" + id = db.Column(db.Integer, primary_key=True) + event_id = db.Column(db.Integer, db.ForeignKey('event.id')) + title = db.Column(db.String(200)) + content = db.Column(db.Text) + event_date = db.Column(db.DateTime) + relevance = db.Column(db.Integer) # 相关性 + importance = db.Column(db.Integer) # 重要程度 + related_stock = db.Column(db.JSON) # 保留JSON字段 + created_at = db.Column(db.DateTime, default=beijing_now) + + # 新增关系 + stocks = db.relationship('HistoricalEventStock', backref='historical_event', lazy='dynamic', + cascade='all, delete-orphan') + + +class HistoricalEventStock(db.Model): + """历史事件相关股票模型""" + __tablename__ = 'historical_event_stocks' + + id = db.Column(db.Integer, primary_key=True) + historical_event_id = db.Column(db.Integer, db.ForeignKey('historical_event.id'), nullable=False) + stock_code = db.Column(db.String(20), nullable=False) + stock_name = db.Column(db.String(50)) + relation_desc = db.Column(db.Text) + correlation = db.Column(db.Float, default=0.5) + sector = db.Column(db.String(100)) + created_at = db.Column(db.DateTime, default=beijing_now) + + __table_args__ = ( + db.UniqueConstraint('historical_event_id', 'stock_code', name='unique_event_stock'), + ) + + +# === 股票盈利预测(自有表) === +class StockForecastData(db.Model): + """股票盈利预测数据 + + 源于本地表 stock_forecast_data,由独立离线程序写入。 + 字段与表结构保持一致,仅用于读取聚合后输出前端报表所需的结构。 + """ + __tablename__ = 'stock_forecast_data' + + id = db.Column(db.Integer, primary_key=True) + stock_code = db.Column(db.String(6), nullable=False) + indicator_name = db.Column(db.String(50), nullable=False) + year_2022a = db.Column(db.Numeric(15, 2)) + year_2023a = db.Column(db.Numeric(15, 2)) + year_2024a = db.Column(db.Numeric(15, 2)) + year_2025e = db.Column(db.Numeric(15, 2)) + year_2026e = db.Column(db.Numeric(15, 2)) + year_2027e = db.Column(db.Numeric(15, 2)) + process_time = db.Column(db.DateTime, nullable=False) + + __table_args__ = ( + db.UniqueConstraint('stock_code', 'indicator_name', name='unique_stock_indicator'), + ) + + def values_by_year(self): + years = ['2022A', '2023A', '2024A', '2025E', '2026E', '2027E'] + vals = [self.year_2022a, self.year_2023a, self.year_2024a, self.year_2025e, self.year_2026e, self.year_2027e] + + def _to_float(x): + try: + return float(x) if x is not None else None + except Exception: + return None + + return years, [_to_float(v) for v in vals] + + +@app.route('/api/events/', methods=['GET']) +def get_event_detail(event_id): + """获取事件详情""" + try: + event = Event.query.get_or_404(event_id) + + # 增加浏览计数 + event.view_count += 1 + db.session.commit() + + return jsonify({ + 'success': True, + 'data': { + 'id': event.id, + 'title': event.title, + 'description': event.description, + 'event_type': event.event_type, + 'status': event.status, + 'start_time': event.start_time.isoformat() if event.start_time else None, + 'end_time': event.end_time.isoformat() if event.end_time else None, + 'created_at': event.created_at.isoformat() if event.created_at else None, + 'hot_score': event.hot_score, + 'view_count': event.view_count, + 'trending_score': event.trending_score, + 'post_count': event.post_count, + 'follower_count': event.follower_count, + 'related_industries': event.related_industries, + 'keywords': event.keywords_list, + 'importance': event.importance, + 'related_avg_chg': event.related_avg_chg, + 'related_max_chg': event.related_max_chg, + 'related_week_chg': event.related_week_chg, + 'invest_score': event.invest_score, + 'expectation_surprise_score': event.expectation_surprise_score, + 'creator_id': event.creator_id, + 'has_chain_analysis': ( + EventTransmissionNode.query.filter_by(event_id=event_id).first() is not None or + EventSankeyFlow.query.filter_by(event_id=event_id).first() is not None + ), + 'is_following': False, # 需要根据当前用户状态判断 + } + }) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//stocks', methods=['GET']) +def get_related_stocks(event_id): + """获取相关股票列表""" + try: + # 订阅控制:相关标的需要 Pro 及以上 + if not _has_required_level('pro'): + return jsonify({'success': False, 'error': '需要Pro订阅', 'required_level': 'pro'}), 403 + event = Event.query.get_or_404(event_id) + stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all() + + stocks_data = [] + for stock in stocks: + if stock.retrieved_sources is not None: + stocks_data.append({ + 'id': stock.id, + 'stock_code': stock.stock_code, + 'stock_name': stock.stock_name, + 'sector': stock.sector, + 'relation_desc': {"data":stock.retrieved_sources}, + 'retrieved_sources': stock.retrieved_sources, + 'correlation': stock.correlation, + 'momentum': stock.momentum, + 'created_at': stock.created_at.isoformat() if stock.created_at else None, + 'updated_at': stock.updated_at.isoformat() if stock.updated_at else None + }) + else: + stocks_data.append({ + 'id': stock.id, + 'stock_code': stock.stock_code, + 'stock_name': stock.stock_name, + 'sector': stock.sector, + 'relation_desc': stock.relation_desc, + 'correlation': stock.correlation, + 'momentum': stock.momentum, + 'created_at': stock.created_at.isoformat() if stock.created_at else None, + 'updated_at': stock.updated_at.isoformat() if stock.updated_at else None + }) + + return jsonify({ + 'success': True, + 'data': stocks_data + }) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//stocks', methods=['POST']) +def add_related_stock(event_id): + """添加相关股票""" + try: + event = Event.query.get_or_404(event_id) + data = request.get_json() + + # 验证必要字段 + if not data.get('stock_code') or not data.get('relation_desc'): + return jsonify({'success': False, 'error': '缺少必要字段'}), 400 + + # 检查是否已存在 + existing = RelatedStock.query.filter_by( + event_id=event_id, + stock_code=data['stock_code'] + ).first() + + if existing: + return jsonify({'success': False, 'error': '该股票已存在'}), 400 + + # 创建新的相关股票记录 + new_stock = RelatedStock( + event_id=event_id, + stock_code=data['stock_code'], + stock_name=data.get('stock_name', ''), + sector=data.get('sector', ''), + relation_desc=data['relation_desc'], + correlation=data.get('correlation', 0.5), + momentum=data.get('momentum', '') + ) + + db.session.add(new_stock) + db.session.commit() + + return jsonify({ + 'success': True, + 'data': { + 'id': new_stock.id, + 'stock_code': new_stock.stock_code, + 'relation_desc': new_stock.relation_desc + } + }) + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stocks/', methods=['DELETE']) +def delete_related_stock(stock_id): + """删除相关股票""" + try: + stock = RelatedStock.query.get_or_404(stock_id) + db.session.delete(stock) + db.session.commit() + + return jsonify({'success': True, 'message': '删除成功'}) + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//concepts', methods=['GET']) +def get_related_concepts(event_id): + """获取相关概念列表""" + try: + # 订阅控制:相关概念需要 Pro 及以上 + if not _has_required_level('pro'): + return jsonify({'success': False, 'error': '需要Pro订阅', 'required_level': 'pro'}), 403 + event = Event.query.get_or_404(event_id) + concepts = event.related_concepts.all() + + concepts_data = [] + for concept in concepts: + concepts_data.append({ + 'id': concept.id, + 'concept_code': concept.concept_code, + 'concept': concept.concept, + 'reason': concept.reason, + 'image_paths': concept.image_paths_list, + 'first_image_path': concept.get_first_image_path(), + 'created_at': concept.created_at.isoformat() if concept.created_at else None + }) + + return jsonify({ + 'success': True, + 'data': concepts_data + }) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//historical', methods=['GET']) +def get_historical_events(event_id): + """获取历史事件对比""" + try: + event = Event.query.get_or_404(event_id) + historical_events = event.historical_events.order_by(HistoricalEvent.event_date.desc()).all() + + events_data = [] + for hist_event in historical_events: + events_data.append({ + 'id': hist_event.id, + 'title': hist_event.title, + 'content': hist_event.content, + 'event_date': hist_event.event_date.isoformat() if hist_event.event_date else None, + 'importance': hist_event.importance, + 'relevance': hist_event.relevance, + 'created_at': hist_event.created_at.isoformat() if hist_event.created_at else None + }) + + # 订阅控制:免费用户仅返回前2条;Pro/Max返回全部 + info = _get_current_subscription_info() + sub_type = (info.get('type') or 'free').lower() + if sub_type == 'free': + return jsonify({ + 'success': True, + 'data': events_data[:2], + 'truncated': len(events_data) > 2, + 'required_level': 'pro' + }) + return jsonify({'success': True, 'data': events_data}) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/historical-events//stocks', methods=['GET']) +def get_historical_event_stocks(event_id): + """获取历史事件相关股票列表""" + try: + # 直接查询历史事件,不需要通过主事件 + hist_event = HistoricalEvent.query.get_or_404(event_id) + stocks = hist_event.stocks.order_by(HistoricalEventStock.correlation.desc()).all() + + # 获取事件对应的交易日 + event_trading_date = None + if hist_event.event_date: + event_trading_date = get_trading_day_near_date(hist_event.event_date) + + stocks_data = [] + for stock in stocks: + stock_data = { + 'id': stock.id, + 'stock_code': stock.stock_code, + 'stock_name': stock.stock_name, + 'sector': stock.sector, + 'relation_desc': stock.relation_desc, + 'correlation': stock.correlation, + 'created_at': stock.created_at.isoformat() if stock.created_at else None + } + + # 添加涨幅数据 + if event_trading_date: + try: + # 查询股票在事件对应交易日的数据 + with engine.connect() as conn: + query = text(""" + SELECT close_price, change_pct + FROM ea_dailyline + WHERE seccode = :stock_code + AND date = :trading_date + ORDER BY date DESC + LIMIT 1 + """) + + result = conn.execute(query, { + 'stock_code': stock.stock_code, + 'trading_date': event_trading_date + }).fetchone() + + if result: + stock_data['event_day_close'] = float(result[0]) if result[0] else None + stock_data['event_day_change_pct'] = float(result[1]) if result[1] else None + else: + stock_data['event_day_close'] = None + stock_data['event_day_change_pct'] = None + except Exception as e: + print(f"查询股票{stock.stock_code}在{event_trading_date}的数据失败: {e}") + stock_data['event_day_close'] = None + stock_data['event_day_change_pct'] = None + else: + stock_data['event_day_close'] = None + stock_data['event_day_change_pct'] = None + + stocks_data.append(stock_data) + + return jsonify({ + 'success': True, + 'data': stocks_data, + 'event_trading_date': event_trading_date.isoformat() if event_trading_date else None + }) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//expectation-score', methods=['GET']) +def get_expectation_score(event_id): + """获取超预期得分""" + try: + event = Event.query.get_or_404(event_id) + + # 如果事件有超预期得分,直接返回 + if event.expectation_surprise_score is not None: + score = event.expectation_surprise_score + else: + # 如果没有,根据历史事件计算一个模拟得分 + historical_events = event.historical_events.all() + if historical_events: + # 基于历史事件数量和重要性计算得分 + total_importance = sum(ev.importance or 0 for ev in historical_events) + avg_importance = total_importance / len(historical_events) if historical_events else 0 + score = min(100, max(0, int(avg_importance * 20 + len(historical_events) * 5))) + else: + # 默认得分 + score = 65 + + return jsonify({ + 'success': True, + 'data': { + 'score': score, + 'description': '基于历史事件判断当前事件的超预期情况,满分100分' + } + }) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//follow', methods=['POST']) +def toggle_event_follow(event_id): + """切换事件关注状态(需登录)""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + try: + event = Event.query.get_or_404(event_id) + user_id = session['user_id'] + + existing = EventFollow.query.filter_by(user_id=user_id, event_id=event_id).first() + if existing: + # 取消关注 + db.session.delete(existing) + event.follower_count = max(0, (event.follower_count or 0) - 1) + db.session.commit() + return jsonify({'success': True, 'data': {'is_following': False, 'follower_count': event.follower_count}}) + else: + # 关注 + follow = EventFollow(user_id=user_id, event_id=event_id) + db.session.add(follow) + event.follower_count = (event.follower_count or 0) + 1 + db.session.commit() + return jsonify({'success': True, 'data': {'is_following': True, 'follower_count': event.follower_count}}) + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//transmission', methods=['GET']) +def get_transmission_chain(event_id): + try: + # 订阅控制:传导链分析需要 Max 及以上 + if not _has_required_level('max'): + return jsonify({'success': False, 'error': '需要Max订阅', 'required_level': 'max'}), 403 + # 确保数据库连接是活跃的 + db.session.execute(text('SELECT 1')) + + event = Event.query.get_or_404(event_id) + nodes = EventTransmissionNode.query.filter_by(event_id=event_id).all() + edges = EventTransmissionEdge.query.filter_by(event_id=event_id).all() + + # 过滤孤立节点 + connected_node_ids = set() + for edge in edges: + connected_node_ids.add(edge.from_node_id) + connected_node_ids.add(edge.to_node_id) + + # 只保留有连接的节点 + connected_nodes = [node for node in nodes if node.id in connected_node_ids] + + # 如果没有主事件节点,也保留主事件节点 + main_event_node = next((node for node in nodes if node.is_main_event), None) + if main_event_node and main_event_node not in connected_nodes: + connected_nodes.append(main_event_node) + + if not connected_nodes: + return jsonify({'success': False, 'message': '暂无传导链分析数据'}) + + # 节点类型到中文类别的映射 + categories = { + 'event': "事件", 'industry': "行业", 'company': "公司", + 'policy': "政策", 'technology': "技术", 'market': "市场", 'other': "其他" + } + + nodes_data = [] + for node in connected_nodes: + node_category = categories.get(node.node_type, "其他") + nodes_data.append({ + 'id': str(node.id), # 转换为字符串以保持一致性 + 'name': node.node_name, + 'category': node_category, + 'value': node.importance_score or 20, + 'extra': { + 'node_type': node.node_type, + 'description': node.node_description, + 'importance_score': node.importance_score, + 'stock_code': node.stock_code, + 'is_main_event': node.is_main_event + } + }) + + edges_data = [] + for edge in edges: + # 确保边的两端节点都在连接节点列表中 + if edge.from_node_id in connected_node_ids and edge.to_node_id in connected_node_ids: + edges_data.append({ + 'source': str(edge.from_node_id), # 转换为字符串以保持一致性 + 'target': str(edge.to_node_id), # 转换为字符串以保持一致性 + 'value': edge.strength or 50, + 'extra': { + 'transmission_type': edge.transmission_type, + 'transmission_mechanism': edge.transmission_mechanism, + 'direction': edge.direction, + 'strength': edge.strength, + 'impact': edge.impact, + 'is_circular': edge.is_circular, + } + }) + + return jsonify({ + 'success': True, + 'data': { + 'nodes': nodes_data, + 'edges': edges_data + } + }) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +# 修复股票报价API - 支持GET和POST方法 +@app.route('/api/stock/quotes', methods=['GET', 'POST']) +def get_stock_quotes(): + try: + if request.method == 'GET': + # GET 请求从查询参数获取数据 + codes_str = request.args.get('codes', '') + codes = [code.strip() for code in codes_str.split(',') if code.strip()] + event_time_str = request.args.get('event_time') + else: + # POST 请求从 JSON 获取数据 + codes = request.json.get('codes', []) + event_time_str = request.json.get('event_time') + + if not codes: + return jsonify({'success': False, 'error': '请提供股票代码'}), 400 + + # 处理事件时间 + if event_time_str: + try: + event_time = datetime.fromisoformat(event_time_str.replace('Z', '+00:00')) + except: + event_time = datetime.now() + else: + event_time = datetime.now() + + current_time = datetime.now() + client = get_clickhouse_client() + + # Get stock names from MySQL + stock_names = {} + with engine.connect() as conn: + for code in codes: + codez = code.split('.')[0] + result = conn.execute(text( + "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" + ), {"code": codez}).fetchone() + if result: + stock_names[code] = result[0] + else: + stock_names[code] = f"股票{codez}" + + def get_trading_day_and_times(event_datetime): + event_date = event_datetime.date() + event_time = event_datetime.time() + + # Trading hours + market_open = dt_time(9, 30) + market_close = dt_time(15, 0) + + with engine.connect() as conn: + # First check if the event date itself is a trading day + is_trading_day = conn.execute(text(""" + SELECT 1 + FROM trading_days + WHERE EXCHANGE_DATE = :date + """), {"date": event_date}).fetchone() is not None + + if is_trading_day: + # If it's a trading day, determine time period based on event time + if event_time < market_open: + # Before market opens - use full trading day + return event_date, market_open, market_close + elif event_time > market_close: + # After market closes - get next trading day + next_trading_day = conn.execute(text(""" + SELECT EXCHANGE_DATE + FROM trading_days + WHERE EXCHANGE_DATE > :date + ORDER BY EXCHANGE_DATE LIMIT 1 + """), {"date": event_date}).fetchone() + # Convert to date object if we found a next trading day + return (next_trading_day[0].date() if next_trading_day else None, + market_open, market_close) + else: + # During trading hours + return event_date, event_time, market_close + else: + # If not a trading day, get next trading day + next_trading_day = conn.execute(text(""" + SELECT EXCHANGE_DATE + FROM trading_days + WHERE EXCHANGE_DATE > :date + ORDER BY EXCHANGE_DATE LIMIT 1 + """), {"date": event_date}).fetchone() + # Convert to date object if we found a next trading day + return (next_trading_day[0].date() if next_trading_day else None, + market_open, market_close) + + trading_day, start_time, end_time = get_trading_day_and_times(event_time) + + if not trading_day: + return jsonify({ + 'success': True, + 'data': {code: {'name': name, 'price': None, 'change': None} + for code, name in stock_names.items()} + }) + + # For historical dates, ensure we're using actual data + start_datetime = datetime.combine(trading_day, start_time) + end_datetime = datetime.combine(trading_day, end_time) + + # If the trading day is in the future relative to current time, + # return only names without data + if trading_day > current_time.date(): + return jsonify({ + 'success': True, + 'data': {code: {'name': name, 'price': None, 'change': None} + for code, name in stock_names.items()} + }) + + results = {} + print(f"处理股票代码: {codes}, 交易日: {trading_day}, 时间范围: {start_datetime} - {end_datetime}") + + for code in codes: + try: + print(f"正在查询股票 {code} 的价格数据...") + # Get the first price and last price for the trading period + data = client.execute(""" + WITH first_price AS (SELECT close + FROM stock_minute + WHERE code = %(code)s + AND timestamp >= %(start)s + AND timestamp <= %(end)s + ORDER BY timestamp + LIMIT 1 + ), + last_price AS ( + SELECT close + FROM stock_minute + WHERE code = %(code)s + AND timestamp >= %(start)s + AND timestamp <= %(end)s + ORDER BY timestamp DESC + LIMIT 1 + ) + SELECT last_price.close as last_price, + (last_price.close - first_price.close) / first_price.close * 100 as change + FROM last_price + CROSS JOIN first_price + WHERE EXISTS (SELECT 1 FROM first_price) + AND EXISTS (SELECT 1 FROM last_price) + """, { + 'code': code, + 'start': start_datetime, + 'end': end_datetime + }) + + print(f"股票 {code} 查询结果: {data}") + if data and data[0] and data[0][0] is not None: + price = float(data[0][0]) if data[0][0] is not None else None + change = float(data[0][1]) if data[0][1] is not None else None + + results[code] = { + 'price': price, + 'change': change, + 'name': stock_names.get(code, f'股票{code.split(".")[0]}') + } + else: + results[code] = { + 'price': None, + 'change': None, + 'name': stock_names.get(code, f'股票{code.split(".")[0]}') + } + except Exception as e: + print(f"Error processing stock {code}: {e}") + results[code] = { + 'price': None, + 'change': None, + 'name': stock_names.get(code, f'股票{code.split(".")[0]}') + } + + # 返回标准格式 + return jsonify({'success': True, 'data': results}) + + except Exception as e: + print(f"Stock quotes API error: {e}") + return jsonify({'success': False, 'error': str(e)}), 500 + + +def get_clickhouse_client(): + return Cclient( + host='222.128.1.157', + port=18000, + user='default', + password='Zzl33818!', + database='stock' + ) + + +@app.route('/api/account/calendar/events', methods=['GET', 'POST']) +def account_calendar_events(): + """返回当前用户的投资计划与关注的未来事件(合并)。 + GET: 可按日期范围/月份过滤;POST: 新增投资计划(写入 InvestmentPlan)。 + """ + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + if request.method == 'POST': + data = request.get_json() or {} + title = data.get('title') + event_date_str = data.get('event_date') or data.get('date') + plan_type = data.get('type') or 'plan' + description = data.get('description') or data.get('content') or '' + stocks = data.get('stocks') or [] + + if not title or not event_date_str: + return jsonify({'success': False, 'error': '缺少必填字段'}), 400 + + try: + event_date = datetime.fromisoformat(event_date_str).date() + except Exception: + return jsonify({'success': False, 'error': '日期格式错误'}), 400 + + plan = InvestmentPlan( + user_id=session['user_id'], + date=event_date, + title=title, + content=description, + type=plan_type, + stocks=json.dumps(stocks), + tags=json.dumps(data.get('tags', [])), + status=data.get('status', 'active') + ) + db.session.add(plan) + db.session.commit() + + return jsonify({'success': True, 'data': { + 'id': plan.id, + 'title': plan.title, + 'event_date': plan.date.isoformat(), + 'type': plan.type, + 'description': plan.content, + 'stocks': json.loads(plan.stocks) if plan.stocks else [], + 'source': 'plan' + }}) + + # GET + # 解析过滤参数:date 或 (year, month) 或 (start_date, end_date) + date_str = request.args.get('date') + year = request.args.get('year', type=int) + month = request.args.get('month', type=int) + start_date_str = request.args.get('start_date') + end_date_str = request.args.get('end_date') + + start_date = None + end_date = None + if date_str: + try: + d = datetime.fromisoformat(date_str).date() + start_date = d + end_date = d + except Exception: + pass + elif year and month: + # 月份范围 + start_date = datetime(year, month, 1).date() + if month == 12: + end_date = datetime(year + 1, 1, 1).date() - timedelta(days=1) + else: + end_date = datetime(year, month + 1, 1).date() - timedelta(days=1) + elif start_date_str and end_date_str: + try: + start_date = datetime.fromisoformat(start_date_str).date() + end_date = datetime.fromisoformat(end_date_str).date() + except Exception: + start_date = None + end_date = None + + # 查询投资计划 + plans_query = InvestmentPlan.query.filter_by(user_id=session['user_id']) + if start_date and end_date: + plans_query = plans_query.filter(InvestmentPlan.date >= start_date, InvestmentPlan.date <= end_date) + elif start_date: + plans_query = plans_query.filter(InvestmentPlan.date == start_date) + plans = plans_query.order_by(InvestmentPlan.date.asc()).all() + + plan_events = [{ + 'id': p.id, + 'title': p.title, + 'event_date': p.date.isoformat(), + 'type': p.type or 'plan', + 'description': p.content, + 'importance': 3, + 'stocks': json.loads(p.stocks) if p.stocks else [], + 'source': 'plan' + } for p in plans] + + # 查询关注的未来事件 + follows = FutureEventFollow.query.filter_by(user_id=session['user_id']).all() + future_event_ids = [f.future_event_id for f in follows] + + future_events = [] + if future_event_ids: + base_sql = """ + SELECT data_id, \ + title, \ + type, \ + calendar_time, \ + star, \ + former, \ + forecast, \ + fact, \ + related_stocks, \ + concepts + FROM future_events + WHERE data_id IN :event_ids \ + """ + + params = {'event_ids': tuple(future_event_ids)} + # 日期过滤(按 calendar_time 的日期) + if start_date and end_date: + base_sql += " AND DATE(calendar_time) BETWEEN :start_date AND :end_date" + params.update({'start_date': start_date, 'end_date': end_date}) + elif start_date: + base_sql += " AND DATE(calendar_time) = :start_date" + params.update({'start_date': start_date}) + + base_sql += " ORDER BY calendar_time" + + result = db.session.execute(text(base_sql), params) + for row in result: + # related_stocks 形如 [[code,name,reason,score], ...] + rs = parse_json_field(row.related_stocks) + stock_tags = [] + try: + for it in rs: + if isinstance(it, (list, tuple)) and len(it) >= 2: + stock_tags.append(f"{it[0]} {it[1]}") + elif isinstance(it, str): + stock_tags.append(it) + except Exception: + pass + + future_events.append({ + 'id': row.data_id, + 'title': row.title, + 'event_date': (row.calendar_time.date().isoformat() if row.calendar_time else None), + 'type': 'future_event', + 'importance': int(row.star) if getattr(row, 'star', None) is not None else 3, + 'description': row.former or '', + 'stocks': stock_tags, + 'is_following': True, + 'source': 'future' + }) + + return jsonify({'success': True, 'data': plan_events + future_events}) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/account/calendar/events/', methods=['DELETE']) +def delete_account_calendar_event(event_id): + """删除用户创建的投资计划事件(不影响关注的未来事件)。""" + try: + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + plan = InvestmentPlan.query.filter_by(id=event_id, user_id=session['user_id']).first() + if not plan: + return jsonify({'success': False, 'error': '未找到该记录'}), 404 + db.session.delete(plan) + db.session.commit() + return jsonify({'success': True}) + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//kline') +def get_stock_kline(stock_code): + chart_type = request.args.get('type', 'minute') + event_time = request.args.get('event_time') + + try: + event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now() + except ValueError: + return jsonify({'error': 'Invalid event_time format'}), 400 + + # 获取股票名称 + with engine.connect() as conn: + result = conn.execute(text( + "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" + ), {"code": stock_code.split('.')[0]}).fetchone() + stock_name = result[0] if result else 'Unknown' + + if chart_type == 'daily': + return get_daily_kline(stock_code, event_datetime, stock_name) + elif chart_type == 'minute': + return get_minute_kline(stock_code, event_datetime, stock_name) + elif chart_type == 'timeline': + return get_timeline_data(stock_code, event_datetime, stock_name) + else: + # 对于未知的类型,返回错误 + return jsonify({'error': f'Unsupported chart type: {chart_type}'}), 400 + + +@app.route('/api/stock//latest-minute', methods=['GET']) +def get_latest_minute_data(stock_code): + """获取最新交易日的分钟频数据""" + client = get_clickhouse_client() + + # 确保股票代码包含后缀 + if '.' not in stock_code: + stock_code = f"{stock_code}.SH" if stock_code.startswith('6') else f"{stock_code}.SZ" + + # 获取股票名称 + with engine.connect() as conn: + result = conn.execute(text( + "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" + ), {"code": stock_code.split('.')[0]}).fetchone() + stock_name = result[0] if result else 'Unknown' + + # 查找最近30天内有数据的最新交易日 + target_date = None + current_date = datetime.now().date() + + for i in range(30): + check_date = current_date - timedelta(days=i) + trading_day = get_trading_day_near_date(check_date) + + if trading_day and trading_day <= current_date: + # 检查这个交易日是否有分钟数据 + test_data = client.execute(""" + SELECT COUNT(*) + FROM stock_minute + WHERE code = %(code)s + AND timestamp BETWEEN %(start)s AND %(end)s + LIMIT 1 + """, { + 'code': stock_code, + 'start': datetime.combine(trading_day, dt_time(9, 30)), + 'end': datetime.combine(trading_day, dt_time(15, 0)) + }) + + if test_data and test_data[0][0] > 0: + target_date = trading_day + break + + if not target_date: + return jsonify({ + 'error': 'No data available', + 'code': stock_code, + 'name': stock_name, + 'data': [], + 'trade_date': current_date.strftime('%Y-%m-%d'), + 'type': 'minute' + }) + + # 获取目标日期的完整交易时段数据 + data = client.execute(""" + SELECT + timestamp, + open, + high, + low, + close, + volume, + amt + FROM stock_minute + WHERE code = %(code)s + AND timestamp BETWEEN %(start)s AND %(end)s + ORDER BY timestamp + """, { + 'code': stock_code, + 'start': datetime.combine(target_date, dt_time(9, 30)), + 'end': datetime.combine(target_date, dt_time(15, 0)) + }) + + kline_data = [{ + 'time': row[0].strftime('%H:%M'), + 'open': float(row[1]), + 'high': float(row[2]), + 'low': float(row[3]), + 'close': float(row[4]), + 'volume': float(row[5]), + 'amount': float(row[6]) + } for row in data] + + return jsonify({ + 'code': stock_code, + 'name': stock_name, + 'data': kline_data, + 'trade_date': target_date.strftime('%Y-%m-%d'), + 'type': 'minute', + 'is_latest': True + }) + + +@app.route('/api/stock//forecast-report', methods=['GET']) +def get_stock_forecast_report(stock_code): + """基于 stock_forecast_data 输出报表所需数据结构 + + 返回: + - income_profit_trend: 营业收入/归母净利润趋势 + - growth_bars: 增长率柱状图数据(基于营业收入同比) + - eps_trend: EPS 折线 + - pe_peg_axes: PE/PEG 双轴 + - detail_table: 详细数据表格(与附件结构一致) + """ + try: + # 读取该股票所有指标 + rows = StockForecastData.query.filter_by(stock_code=stock_code).all() + if not rows: + return jsonify({'success': False, 'error': 'no_data'}), 404 + + # 将指标映射为字典 + indicators = {} + for r in rows: + years, vals = r.values_by_year() + indicators[r.indicator_name] = dict(zip(years, vals)) + + def safe(x): + return x if x is not None else None + + years = ['2022A', '2023A', '2024A', '2025E', '2026E', '2027E'] + + # 营业收入与净利润趋势 + income = indicators.get('营业总收入(百万元)', {}) + profit = indicators.get('归母净利润(百万元)', {}) + income_profit_trend = { + 'years': years, + 'income': [safe(income.get(y)) for y in years], + 'profit': [safe(profit.get(y)) for y in years] + } + + # 增长率柱状(若表内已有"增长率(%)",直接使用;否则按营业收入同比计算) + growth = indicators.get('增长率(%)') + if growth is None: + # 计算同比: (curr - prev)/prev*100 + growth_vals = [] + prev = None + for y in years: + curr = income.get(y) + if prev is not None and prev not in (None, 0) and curr is not None: + growth_vals.append(round((float(curr) - float(prev)) / float(prev) * 100, 2)) + else: + growth_vals.append(None) + prev = curr + else: + growth_vals = [safe(growth.get(y)) for y in years] + growth_bars = { + 'years': years, + 'revenue_growth_pct': growth_vals, + 'net_profit_growth_pct': None # 如后续需要可扩展 + } + + # EPS 趋势 + eps = indicators.get('EPS(稀释)') or indicators.get('EPS(元/股)') or {} + eps_trend = { + 'years': years, + 'eps': [safe(eps.get(y)) for y in years] + } + + # PE / PEG 双轴 + pe = indicators.get('PE') or {} + peg = indicators.get('PEG') or {} + pe_peg_axes = { + 'years': years, + 'pe': [safe(pe.get(y)) for y in years], + 'peg': [safe(peg.get(y)) for y in years] + } + + # 详细数据表格(列顺序固定) + def fmt(val): + try: + return None if val is None else round(float(val), 2) + except Exception: + return None + + detail_rows = [ + { + '指标': '营业总收入(百万元)', + **{y: fmt(income.get(y)) for y in years}, + }, + { + '指标': '增长率(%)', + **{y: fmt(v) for y, v in zip(years, growth_vals)}, + }, + { + '指标': '归母净利润(百万元)', + **{y: fmt(profit.get(y)) for y in years}, + }, + { + '指标': 'EPS(稀释)', + **{y: fmt(eps.get(y)) for y in years}, + }, + { + '指标': 'PE', + **{y: fmt(pe.get(y)) for y in years}, + }, + { + '指标': 'PEG', + **{y: fmt(peg.get(y)) for y in years}, + }, + ] + + return jsonify({ + 'success': True, + 'data': { + 'income_profit_trend': income_profit_trend, + 'growth_bars': growth_bars, + 'eps_trend': eps_trend, + 'pe_peg_axes': pe_peg_axes, + 'detail_table': { + 'years': years, + 'rows': detail_rows + } + } + }) + except Exception as e: + app.logger.error(f"forecast report error: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//basic-info', methods=['GET']) +def get_stock_basic_info(stock_code): + """获取股票基本信息(来自ea_baseinfo表)""" + try: + with engine.connect() as conn: + query = text(""" + SELECT SECCODE, + SECNAME, + ORGNAME, + F001V as en_name, + F002V as en_short_name, + F003V as legal_representative, + F004V as reg_address, + F005V as office_address, + F006V as post_code, + F007N as reg_capital, + F009V as currency, + F010D as establish_date, + F011V as website, + F012V as email, + F013V as tel, + F014V as fax, + F015V as main_business, + F016V as business_scope, + F017V as company_intro, + F018V as secretary, + F019V as secretary_tel, + F020V as secretary_fax, + F021V as secretary_email, + F024V as listing_status, + F026V as province, + F028V as city, + F030V as industry_l1, + F032V as industry_l2, + F034V as sw_industry_l1, + F036V as sw_industry_l2, + F038V as sw_industry_l3, + F039V as accounting_firm, + F040V as law_firm, + F041V as chairman, + F042V as general_manager, + F043V as independent_directors, + F050V as credit_code, + F054V as company_size, + UPDATE_DATE + FROM ea_baseinfo + WHERE SECCODE = :stock_code LIMIT 1 + """) + + result = conn.execute(query, {'stock_code': stock_code}).fetchone() + + if not result: + return jsonify({ + 'success': False, + 'error': f'未找到股票代码 {stock_code} 的基本信息' + }), 404 + + # 转换为字典 + basic_info = {} + for key, value in zip(result.keys(), result): + if isinstance(value, datetime): + basic_info[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + basic_info[key] = float(value) + else: + basic_info[key] = value + + return jsonify({ + 'success': True, + 'data': basic_info + }) + + except Exception as e: + app.logger.error(f"Error getting stock basic info: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//announcements', methods=['GET']) +def get_stock_announcements(stock_code): + """获取股票公告列表""" + try: + limit = request.args.get('limit', 50, type=int) + + with engine.connect() as conn: + query = text(""" + SELECT F001D as announce_date, + F002V as title, + F003V as url, + F004V as format, + F005N as file_size, + F006V as info_type, + UPDATE_DATE + FROM ea_baseinfolist + WHERE SECCODE = :stock_code + ORDER BY F001D DESC LIMIT :limit + """) + + result = conn.execute(query, {'stock_code': stock_code, 'limit': limit}).fetchall() + + announcements = [] + for row in result: + announcement = {} + for key, value in zip(row.keys(), row): + if value is None: + announcement[key] = None + elif isinstance(value, datetime): + announcement[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + announcement[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + announcement[key] = float(value) + else: + announcement[key] = value + announcements.append(announcement) + + return jsonify({ + 'success': True, + 'data': announcements, + 'total': len(announcements) + }) + + except Exception as e: + app.logger.error(f"Error getting stock announcements: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//disclosure-schedule', methods=['GET']) +def get_stock_disclosure_schedule(stock_code): + """获取股票财报预披露时间表""" + try: + with engine.connect() as conn: + query = text(""" + SELECT distinct F001D as report_period, + F002D as scheduled_date, + F003D as change_date1, + F004D as change_date2, + F005D as change_date3, + F006D as actual_date, + F007D as change_date4, + F008D as change_date5, + MODTIME as mod_time + FROM ea_pretime + WHERE SECCODE = :stock_code + ORDER BY F001D DESC LIMIT 20 + """) + + result = conn.execute(query, {'stock_code': stock_code}).fetchall() + + schedules = [] + for row in result: + schedule = {} + for key, value in zip(row.keys(), row): + if value is None: + schedule[key] = None + elif isinstance(value, datetime): + schedule[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + schedule[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + schedule[key] = float(value) + else: + schedule[key] = value + + # 计算最新的预约日期 + latest_scheduled = schedule.get('scheduled_date') + for change_field in ['change_date5', 'change_date4', 'change_date3', 'change_date2', 'change_date1']: + if schedule.get(change_field): + latest_scheduled = schedule[change_field] + break + + schedule['latest_scheduled_date'] = latest_scheduled + schedule['is_disclosed'] = bool(schedule.get('actual_date')) + + # 格式化报告期名称 + if schedule.get('report_period'): + period_date = schedule['report_period'] + if period_date.endswith('-03-31'): + schedule['report_name'] = f"{period_date[:4]}年一季报" + elif period_date.endswith('-06-30'): + schedule['report_name'] = f"{period_date[:4]}年中报" + elif period_date.endswith('-09-30'): + schedule['report_name'] = f"{period_date[:4]}年三季报" + elif period_date.endswith('-12-31'): + schedule['report_name'] = f"{period_date[:4]}年年报" + else: + schedule['report_name'] = period_date + + schedules.append(schedule) + + return jsonify({ + 'success': True, + 'data': schedules, + 'total': len(schedules) + }) + + except Exception as e: + app.logger.error(f"Error getting disclosure schedule: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//actual-control', methods=['GET']) +def get_stock_actual_control(stock_code): + """获取股票实际控制人信息""" + try: + with engine.connect() as conn: + query = text(""" + SELECT DECLAREDATE as declare_date, + ENDDATE as end_date, + F001V as direct_holder_id, + F002V as direct_holder_name, + F003V as actual_controller_id, + F004V as actual_controller_name, + F005N as holding_shares, + F006N as holding_ratio, + F007V as control_type_code, + F008V as control_type, + F012V as direct_controller_id, + F013V as direct_controller_name, + F014V as controller_type, + ORGNAME as org_name, + SECCODE as sec_code, + SECNAME as sec_name + FROM ea_actualcon + WHERE SECCODE = :stock_code + ORDER BY ENDDATE DESC, DECLAREDATE DESC LIMIT 20 + """) + + result = conn.execute(query, {'stock_code': stock_code}).fetchall() + + control_info = [] + for row in result: + control_record = {} + for key, value in zip(row.keys(), row): + if value is None: + control_record[key] = None + elif isinstance(value, datetime): + control_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + control_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + control_record[key] = float(value) + else: + control_record[key] = value + + control_info.append(control_record) + + return jsonify({ + 'success': True, + 'data': control_info, + 'total': len(control_info) + }) + + except Exception as e: + app.logger.error(f"Error getting actual control info: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//concentration', methods=['GET']) +def get_stock_concentration(stock_code): + """获取股票股权集中度信息""" + try: + with engine.connect() as conn: + query = text(""" + SELECT ENDDATE as end_date, + F001V as stat_item, + F002N as holding_shares, + F003N as holding_ratio, + F004N as ratio_change, + ORGNAME as org_name, + SECCODE as sec_code, + SECNAME as sec_name + FROM ea_concentration + WHERE SECCODE = :stock_code + ORDER BY ENDDATE DESC LIMIT 20 + """) + + result = conn.execute(query, {'stock_code': stock_code}).fetchall() + + concentration_info = [] + for row in result: + concentration_record = {} + for key, value in zip(row.keys(), row): + if value is None: + concentration_record[key] = None + elif isinstance(value, datetime): + concentration_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + concentration_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + concentration_record[key] = float(value) + else: + concentration_record[key] = value + + concentration_info.append(concentration_record) + + return jsonify({ + 'success': True, + 'data': concentration_info, + 'total': len(concentration_info) + }) + + except Exception as e: + app.logger.error(f"Error getting concentration info: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//management', methods=['GET']) +def get_stock_management(stock_code): + """获取股票管理层信息""" + try: + # 获取是否只显示在职人员参数 + active_only = request.args.get('active_only', 'true').lower() == 'true' + + with engine.connect() as conn: + base_query = """ + SELECT DECLAREDATE as declare_date, \ + F001V as person_id, \ + F002V as name, \ + F007D as start_date, \ + F008D as end_date, \ + F009V as position_name, \ + F010V as gender, \ + F011V as education, \ + F012V as birth_year, \ + F013V as nationality, \ + F014V as position_category_code, \ + F015V as position_category, \ + F016V as position_code, \ + F017V as highest_degree, \ + F019V as resume, \ + F020C as is_active, \ + ORGNAME as org_name, \ + SECCODE as sec_code, \ + SECNAME as sec_name + FROM ea_management + WHERE SECCODE = :stock_code \ + """ + + if active_only: + base_query += " AND F020C = '1'" + + base_query += " ORDER BY DECLAREDATE DESC, F007D DESC" + + query = text(base_query) + + result = conn.execute(query, {'stock_code': stock_code}).fetchall() + + management_info = [] + for row in result: + management_record = {} + for key, value in zip(row.keys(), row): + if value is None: + management_record[key] = None + elif isinstance(value, datetime): + management_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + management_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + management_record[key] = float(value) + else: + management_record[key] = value + + management_info.append(management_record) + + return jsonify({ + 'success': True, + 'data': management_info, + 'total': len(management_info) + }) + + except Exception as e: + app.logger.error(f"Error getting management info: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//top-circulation-shareholders', methods=['GET']) +def get_stock_top_circulation_shareholders(stock_code): + """获取股票十大流通股东信息""" + try: + limit = request.args.get('limit', 10, type=int) + + with engine.connect() as conn: + query = text(""" + SELECT DECLAREDATE as declare_date, + ENDDATE as end_date, + F001N as shareholder_rank, + F002V as shareholder_id, + F003V as shareholder_name, + F004V as shareholder_type, + F005N as holding_shares, + F006N as total_share_ratio, + F007N as circulation_share_ratio, + F011V as share_nature, + F012N as b_shares, + F013N as h_shares, + F014N as other_shares, + ORGNAME as org_name, + SECCODE as sec_code, + SECNAME as sec_name + FROM ea_tencirculation + WHERE SECCODE = :stock_code + ORDER BY ENDDATE DESC, F001N ASC LIMIT :limit + """) + + result = conn.execute(query, {'stock_code': stock_code, 'limit': limit}).fetchall() + + shareholders_info = [] + for row in result: + shareholder_record = {} + for key, value in zip(row.keys(), row): + if value is None: + shareholder_record[key] = None + elif isinstance(value, datetime): + shareholder_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + shareholder_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + shareholder_record[key] = float(value) + else: + shareholder_record[key] = value + + shareholders_info.append(shareholder_record) + + return jsonify({ + 'success': True, + 'data': shareholders_info, + 'total': len(shareholders_info) + }) + + except Exception as e: + app.logger.error(f"Error getting top circulation shareholders: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//top-shareholders', methods=['GET']) +def get_stock_top_shareholders(stock_code): + """获取股票十大股东信息""" + try: + limit = request.args.get('limit', 10, type=int) + + with engine.connect() as conn: + query = text(""" + SELECT DECLAREDATE as declare_date, + ENDDATE as end_date, + F001N as shareholder_rank, + F002V as shareholder_name, + F003V as shareholder_id, + F004V as shareholder_type, + F005N as holding_shares, + F006N as total_share_ratio, + F007N as circulation_share_ratio, + F011V as share_nature, + F016N as restricted_shares, + F017V as concert_party_group, + F018N as circulation_shares, + ORGNAME as org_name, + SECCODE as sec_code, + SECNAME as sec_name + FROM ea_tenshareholder + WHERE SECCODE = :stock_code + ORDER BY ENDDATE DESC, F001N ASC LIMIT :limit + """) + + result = conn.execute(query, {'stock_code': stock_code, 'limit': limit}).fetchall() + + shareholders_info = [] + for row in result: + shareholder_record = {} + for key, value in zip(row.keys(), row): + if value is None: + shareholder_record[key] = None + elif isinstance(value, datetime): + shareholder_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + shareholder_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + shareholder_record[key] = float(value) + else: + shareholder_record[key] = value + + shareholders_info.append(shareholder_record) + + return jsonify({ + 'success': True, + 'data': shareholders_info, + 'total': len(shareholders_info) + }) + + except Exception as e: + app.logger.error(f"Error getting top shareholders: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//branches', methods=['GET']) +def get_stock_branches(stock_code): + """获取股票分支机构信息""" + try: + with engine.connect() as conn: + query = text(""" + SELECT CRECODE as cre_code, + F001V as branch_name, + F002V as register_capital, + F003V as business_status, + F004D as register_date, + F005N as related_company_count, + F006V as legal_person, + ORGNAME as org_name, + SECCODE as sec_code, + SECNAME as sec_name + FROM ea_branch + WHERE SECCODE = :stock_code + ORDER BY F004D DESC + """) + + result = conn.execute(query, {'stock_code': stock_code}).fetchall() + + branches_info = [] + for row in result: + branch_record = {} + for key, value in zip(row.keys(), row): + if value is None: + branch_record[key] = None + elif isinstance(value, datetime): + branch_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + branch_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + branch_record[key] = float(value) + else: + branch_record[key] = value + + branches_info.append(branch_record) + + return jsonify({ + 'success': True, + 'data': branches_info, + 'total': len(branches_info) + }) + + except Exception as e: + app.logger.error(f"Error getting branches info: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/stock//patents', methods=['GET']) +def get_stock_patents(stock_code): + """获取股票专利信息""" + try: + limit = request.args.get('limit', 50, type=int) + patent_type = request.args.get('type', None) # 专利类型筛选 + + with engine.connect() as conn: + base_query = """ + SELECT CRECODE as cre_code, \ + F001V as patent_name, \ + F002V as application_number, \ + F003V as publication_number, \ + F004V as classification_number, \ + F005D as publication_date, \ + F006D as application_date, \ + F007V as patent_type, \ + F008V as applicant, \ + F009V as inventor, \ + ID as id, \ + ORGNAME as org_name, \ + SECCODE as sec_code, \ + SECNAME as sec_name + FROM ea_patent + WHERE SECCODE = :stock_code \ + """ + + params = {'stock_code': stock_code, 'limit': limit} + + if patent_type: + base_query += " AND F007V = :patent_type" + params['patent_type'] = patent_type + + base_query += " ORDER BY F006D DESC, F005D DESC LIMIT :limit" + + query = text(base_query) + + result = conn.execute(query, params).fetchall() + + patents_info = [] + for row in result: + patent_record = {} + for key, value in zip(row.keys(), row): + if value is None: + patent_record[key] = None + elif isinstance(value, datetime): + patent_record[key] = value.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(value, date): + patent_record[key] = value.strftime('%Y-%m-%d') + elif isinstance(value, Decimal): + patent_record[key] = float(value) + else: + patent_record[key] = value + + patents_info.append(patent_record) + + return jsonify({ + 'success': True, + 'data': patents_info, + 'total': len(patents_info) + }) + + except Exception as e: + app.logger.error(f"Error getting patents info: {e}", exc_info=True) + return jsonify({'success': False, 'error': str(e)}), 500 + + +def get_daily_kline(stock_code, event_datetime, stock_name): + """处理日K线数据""" + stock_code = stock_code.split('.')[0] + + with engine.connect() as conn: + # 获取事件日期前后的数据 + kline_sql = """ + WITH date_range AS (SELECT TRADEDATE \ + FROM ea_trade \ + WHERE SECCODE = :stock_code \ + AND TRADEDATE BETWEEN DATE_SUB(:trade_date, INTERVAL 60 DAY) \ + AND DATE_ADD(:trade_date, INTERVAL 30 DAY) \ + GROUP BY TRADEDATE \ + ORDER BY TRADEDATE) + SELECT t.TRADEDATE, + CAST(t.F003N AS FLOAT) as open, + CAST(t.F007N AS FLOAT) as close, + CAST(t.F005N AS FLOAT) as high, + CAST(t.F006N AS FLOAT) as low, + CAST(t.F004N AS FLOAT) as volume + FROM ea_trade t + JOIN date_range d \ + ON t.TRADEDATE = d.TRADEDATE + WHERE t.SECCODE = :stock_code + ORDER BY t.TRADEDATE \ + """ + + result = conn.execute(text(kline_sql), { + "stock_code": stock_code, + "trade_date": event_datetime.date() + }).fetchall() + + if not result: + return jsonify({ + 'error': 'No data available', + 'code': stock_code, + 'name': stock_name, + 'data': [], + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'daily' + }) + + kline_data = [{ + 'time': row.TRADEDATE.strftime('%Y-%m-%d'), + 'open': float(row.open), + 'high': float(row.high), + 'low': float(row.low), + 'close': float(row.close), + 'volume': float(row.volume) + } for row in result] + + return jsonify({ + 'code': stock_code, + 'name': stock_name, + 'data': kline_data, + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'daily', + 'is_history': True + }) + + +def get_minute_kline(stock_code, event_datetime, stock_name): + """处理分钟K线数据""" + client = get_clickhouse_client() + + target_date = get_trading_day_near_date(event_datetime.date()) + is_after_market = event_datetime.time() > dt_time(15, 0) + + # 核心逻辑改动:先判断当前日期是否是交易日,以及是否已收盘 + if target_date and is_after_market: + # 如果是交易日且已收盘,查找下一个交易日 + next_trade_date = get_trading_day_near_date(target_date + timedelta(days=1)) + if next_trade_date: + target_date = next_trade_date + + if not target_date: + return jsonify({ + 'error': 'No data available', + 'code': stock_code, + 'name': stock_name, + 'data': [], + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'minute' + }) + + # 获取目标日期的完整交易时段数据 + data = client.execute(""" + SELECT + timestamp, open, high, low, close, volume, amt + FROM stock_minute + WHERE code = %(code)s + AND timestamp BETWEEN %(start)s + AND %(end)s + ORDER BY timestamp + """, { + 'code': stock_code, + 'start': datetime.combine(target_date, dt_time(9, 30)), + 'end': datetime.combine(target_date, dt_time(15, 0)) + }) + + kline_data = [{ + 'time': row[0].strftime('%H:%M'), + 'open': float(row[1]), + 'high': float(row[2]), + 'low': float(row[3]), + 'close': float(row[4]), + 'volume': float(row[5]), + 'amount': float(row[6]) + } for row in data] + + return jsonify({ + 'code': stock_code, + 'name': stock_name, + 'data': kline_data, + 'trade_date': target_date.strftime('%Y-%m-%d'), + 'type': 'minute', + 'is_history': target_date < event_datetime.date() + }) + + +def get_timeline_data(stock_code, event_datetime, stock_name): + """处理分时均价线数据(timeline)。 + 规则: + - 若事件时间在交易日的15:00之后,则展示下一个交易日的分时数据; + - 若事件日非交易日,优先展示下一个交易日;如无,则回退到最近一个交易日; + - 数据区间固定为 09:30-15:00。 + """ + client = get_clickhouse_client() + + target_date = get_trading_day_near_date(event_datetime.date()) + is_after_market = event_datetime.time() > dt_time(15, 0) + + # 与分钟K逻辑保持一致的日期选择规则 + if target_date and is_after_market: + next_trade_date = get_trading_day_near_date(target_date + timedelta(days=1)) + if next_trade_date: + target_date = next_trade_date + + if not target_date: + return jsonify({ + 'error': 'No data available', + 'code': stock_code, + 'name': stock_name, + 'data': [], + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'timeline' + }) + + # 获取昨收盘价 + prev_close_query = """ + SELECT close + FROM stock_minute + WHERE code = %(code)s + AND timestamp \ + < %(start)s + ORDER BY timestamp DESC + LIMIT 1 \ + """ + + prev_close_result = client.execute(prev_close_query, { + 'code': stock_code, + 'start': datetime.combine(target_date, dt_time(9, 30)) + }) + + prev_close = float(prev_close_result[0][0]) if prev_close_result else None + + data = client.execute( + """ + SELECT + timestamp, close, volume + FROM stock_minute + WHERE code = %(code)s + AND timestamp BETWEEN %(start)s + AND %(end)s + ORDER BY timestamp + """, + { + 'code': stock_code, + 'start': datetime.combine(target_date, dt_time(9, 30)), + 'end': datetime.combine(target_date, dt_time(15, 0)), + } + ) + + timeline_data = [] + total_amount = 0 + total_volume = 0 + for row in data: + price = float(row[1]) + volume = float(row[2]) + total_amount += price * volume + total_volume += volume + avg_price = total_amount / total_volume if total_volume > 0 else price + + # 计算涨跌幅 + change_percent = ((price - prev_close) / prev_close * 100) if prev_close else 0 + + timeline_data.append({ + 'time': row[0].strftime('%H:%M'), + 'price': price, + 'avg_price': avg_price, + 'volume': volume, + 'change_percent': change_percent, + }) + + return jsonify({ + 'code': stock_code, + 'name': stock_name, + 'data': timeline_data, + 'trade_date': target_date.strftime('%Y-%m-%d'), + 'type': 'timeline', + 'is_history': target_date < event_datetime.date(), + 'prev_close': prev_close, + }) + + +# ==================== 指数行情API(与股票逻辑一致,数据表为 index_minute) ==================== +@app.route('/api/index//kline') +def get_index_kline(index_code): + chart_type = request.args.get('type', 'minute') + event_time = request.args.get('event_time') + + try: + event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now() + except ValueError: + return jsonify({'error': 'Invalid event_time format'}), 400 + + # 指数名称(暂无索引表,先返回代码本身) + index_name = index_code + + if chart_type == 'minute': + return get_index_minute_kline(index_code, event_datetime, index_name) + elif chart_type == 'timeline': + return get_index_timeline_data(index_code, event_datetime, index_name) + elif chart_type == 'daily': + return get_index_daily_kline(index_code, event_datetime, index_name) + else: + return jsonify({'error': f'Unsupported chart type: {chart_type}'}), 400 + + +def get_index_minute_kline(index_code, event_datetime, index_name): + client = get_clickhouse_client() + target_date = get_trading_day_near_date(event_datetime.date()) + + if not target_date: + return jsonify({ + 'error': 'No data available', + 'code': index_code, + 'name': index_name, + 'data': [], + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'minute' + }) + + data = client.execute( + """ + SELECT timestamp, open, high, low, close, volume, amt + FROM index_minute + WHERE code = %(code)s + AND timestamp BETWEEN %(start)s + AND %(end)s + ORDER BY timestamp + """, + { + 'code': index_code, + 'start': datetime.combine(target_date, dt_time(9, 30)), + 'end': datetime.combine(target_date, dt_time(15, 0)), + } + ) + + kline_data = [{ + 'time': row[0].strftime('%H:%M'), + 'open': float(row[1]), + 'high': float(row[2]), + 'low': float(row[3]), + 'close': float(row[4]), + 'volume': float(row[5]), + 'amount': float(row[6]), + } for row in data] + + return jsonify({ + 'code': index_code, + 'name': index_name, + 'data': kline_data, + 'trade_date': target_date.strftime('%Y-%m-%d'), + 'type': 'minute', + 'is_history': target_date < event_datetime.date(), + }) + + +def get_index_timeline_data(index_code, event_datetime, index_name): + client = get_clickhouse_client() + target_date = get_trading_day_near_date(event_datetime.date()) + + if not target_date: + return jsonify({ + 'error': 'No data available', + 'code': index_code, + 'name': index_name, + 'data': [], + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'timeline' + }) + + data = client.execute( + """ + SELECT timestamp, close, volume + FROM index_minute + WHERE code = %(code)s + AND timestamp BETWEEN %(start)s + AND %(end)s + ORDER BY timestamp + """, + { + 'code': index_code, + 'start': datetime.combine(target_date, dt_time(9, 30)), + 'end': datetime.combine(target_date, dt_time(15, 0)), + } + ) + + timeline = [] + total_amount = 0 + total_volume = 0 + for row in data: + price = float(row[1]) + volume = float(row[2]) + total_amount += price * volume + total_volume += volume + avg_price = total_amount / total_volume if total_volume > 0 else price + timeline.append({ + 'time': row[0].strftime('%H:%M'), + 'price': price, + 'avg_price': avg_price, + 'volume': volume, + }) + + return jsonify({ + 'code': index_code, + 'name': index_name, + 'data': timeline, + 'trade_date': target_date.strftime('%Y-%m-%d'), + 'type': 'timeline', + 'is_history': target_date < event_datetime.date(), + }) + + +def get_index_daily_kline(index_code, event_datetime, index_name): + """从 MySQL 的 stock.ea_exchangetrade 获取指数日线 + 注意:表中 INDEXCODE 无后缀,例如 000001.SH -> 000001 + 字段: + F003N 开市指数 -> open + F004N 最高指数 -> high + F005N 最低指数 -> low + F006N 最近指数 -> close(作为当日收盘或最近价使用) + F007N 昨日收市指数 -> prev_close + """ + # 去掉后缀 + code_no_suffix = index_code.split('.')[0] + + # 选择展示的最后交易日 + target_date = get_trading_day_near_date(event_datetime.date()) + if not target_date: + return jsonify({ + 'error': 'No data available', + 'code': index_code, + 'name': index_name, + 'data': [], + 'trade_date': event_datetime.date().strftime('%Y-%m-%d'), + 'type': 'daily' + }) + + # 取最近一段时间的日线(倒序再反转为升序) + with engine.connect() as conn: + rows = conn.execute(text( + """ + SELECT TRADEDATE, F003N, F004N, F005N, F006N, F007N + FROM ea_exchangetrade + WHERE INDEXCODE = :code + AND TRADEDATE <= :end_dt + ORDER BY TRADEDATE DESC LIMIT 180 + """ + ), { + 'code': code_no_suffix, + 'end_dt': datetime.combine(target_date, dt_time(23, 59, 59)) + }).fetchall() + + # 反转为时间升序 + rows = list(reversed(rows)) + + daily = [] + for i, r in enumerate(rows): + trade_dt = r[0] + open_v = r[1] + high_v = r[2] + low_v = r[3] + last_v = r[4] + prev_close_v = r[5] + + # 正确的前收盘价逻辑:使用前一个交易日的F006N(收盘价) + calculated_prev_close = None + if i > 0 and rows[i - 1][4] is not None: + # 使用前一个交易日的收盘价作为前收盘价 + calculated_prev_close = float(rows[i - 1][4]) + else: + # 第一条记录,尝试使用F007N字段作为备选 + if prev_close_v is not None and prev_close_v > 0: + calculated_prev_close = float(prev_close_v) + + daily.append({ + 'time': trade_dt.strftime('%Y-%m-%d') if hasattr(trade_dt, 'strftime') else str(trade_dt), + 'open': float(open_v) if open_v is not None else None, + 'high': float(high_v) if high_v is not None else None, + 'low': float(low_v) if low_v is not None else None, + 'close': float(last_v) if last_v is not None else None, + 'prev_close': calculated_prev_close, + }) + + return jsonify({ + 'code': index_code, + 'name': index_name, + 'data': daily, + 'trade_date': target_date.strftime('%Y-%m-%d'), + 'type': 'daily', + 'is_history': target_date < event_datetime.date(), + }) + + +# ==================== 日历API ==================== +@app.route('/api/v1/calendar/event-counts', methods=['GET']) +def get_event_counts(): + """获取日历事件数量统计""" + try: + # 获取月份参数 + year = request.args.get('year', datetime.now().year, type=int) + month = request.args.get('month', datetime.now().month, type=int) + + # 计算月份的开始和结束日期 + start_date = datetime(year, month, 1) + if month == 12: + end_date = datetime(year + 1, 1, 1) + else: + end_date = datetime(year, month + 1, 1) + + # 查询事件数量 + query = """ + SELECT DATE(calendar_time) as date, COUNT(*) as count + FROM future_events + WHERE calendar_time BETWEEN :start_date AND :end_date + AND type = 'event' + GROUP BY DATE(calendar_time) +""" + + result = db.session.execute(text(query), { + 'start_date': start_date, + 'end_date': end_date + }) + + # 格式化结果 + events = [] + for day in result: + events.append({ + 'date': day.date.isoformat(), + 'count': day.count, + 'className': get_event_class(day.count) + }) + + return jsonify({ + 'success': True, + 'data': events + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/v1/calendar/events', methods=['GET']) +def get_calendar_events(): + """获取指定日期的事件列表""" + date_str = request.args.get('date') + event_type = request.args.get('type', 'all') + + if not date_str: + return jsonify({ + 'success': False, + 'error': 'Date parameter required' + }), 400 + + try: + date = datetime.strptime(date_str, '%Y-%m-%d') + except ValueError: + return jsonify({ + 'success': False, + 'error': 'Invalid date format' + }), 400 + + # 修复SQL语法:去掉函数名后的空格,去掉参数前的空格 + query = """ + SELECT * + FROM future_events + WHERE DATE(calendar_time) = :date + """ + + params = {'date': date} + + if event_type != 'all': + query += " AND type = :type" + params['type'] = event_type + + query += " ORDER BY calendar_time" + + result = db.session.execute(text(query), params) + + events = [] + user_following_ids = set() + if 'user_id' in session: + follows = FutureEventFollow.query.filter_by(user_id=session['user_id']).all() + user_following_ids = {f.future_event_id for f in follows} + + for row in result: + event_data = { + 'id': row.data_id, + 'title': row.title, + 'type': row.type, + 'calendar_time': row.calendar_time.isoformat(), + 'star': row.star, + 'former': row.former, + 'forecast': row.forecast, + 'fact': row.fact, + 'is_following': row.data_id in user_following_ids + } + + # 解析相关股票和概念 + if row.related_stocks: + try: + if isinstance(row.related_stocks, str): + if row.related_stocks.startswith('['): + event_data['related_stocks'] = json.loads(row.related_stocks) + else: + event_data['related_stocks'] = row.related_stocks.split(',') + else: + event_data['related_stocks'] = row.related_stocks + except: + event_data['related_stocks'] = [] + else: + event_data['related_stocks'] = [] + + if row.concepts: + try: + if isinstance(row.concepts, str): + if row.concepts.startswith('['): + event_data['concepts'] = json.loads(row.concepts) + else: + event_data['concepts'] = row.concepts.split(',') + else: + event_data['concepts'] = row.concepts + except: + event_data['concepts'] = [] + else: + event_data['concepts'] = [] + + events.append(event_data) + + return jsonify({ + 'success': True, + 'data': events + }) + +@app.route('/api/v1/calendar/events/', methods=['GET']) +def get_calendar_event_detail(event_id): + """获取日历事件详情""" + try: + sql = """ + SELECT * + FROM future_events + WHERE data_id = :event_id \ + """ + + result = db.session.execute(text(sql), {'event_id': event_id}).first() + + if not result: + return jsonify({ + 'success': False, + 'error': 'Event not found' + }), 404 + + event_data = { + 'id': result.data_id, + 'title': result.title, + 'type': result.type, + 'calendar_time': result.calendar_time.isoformat(), + 'star': result.star, + 'former': result.former, + 'forecast': result.forecast, + 'fact': result.fact, + 'related_stocks': parse_json_field(result.related_stocks), + 'concepts': parse_json_field(result.concepts) + } + + # 检查当前用户是否关注了该未来事件 + if 'user_id' in session: + is_following = FutureEventFollow.query.filter_by( + user_id=session['user_id'], + future_event_id=event_id + ).first() is not None + event_data['is_following'] = is_following + else: + event_data['is_following'] = False + + return jsonify({ + 'success': True, + 'data': event_data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/v1/calendar/events//follow', methods=['POST']) +def toggle_future_event_follow(event_id): + """切换未来事件关注状态(需登录)""" + if 'user_id' not in session: + return jsonify({'success': False, 'error': '未登录'}), 401 + + try: + # 检查未来事件是否存在 + sql = """ + SELECT data_id \ + FROM future_events \ + WHERE data_id = :event_id \ + """ + result = db.session.execute(text(sql), {'event_id': event_id}).first() + + if not result: + return jsonify({'success': False, 'error': '未来事件不存在'}), 404 + + user_id = session['user_id'] + + # 检查是否已关注 + existing = FutureEventFollow.query.filter_by( + user_id=user_id, + future_event_id=event_id + ).first() + + if existing: + # 取消关注 + db.session.delete(existing) + db.session.commit() + return jsonify({ + 'success': True, + 'data': {'is_following': False} + }) + else: + # 关注 + follow = FutureEventFollow( + user_id=user_id, + future_event_id=event_id + ) + db.session.add(follow) + db.session.commit() + return jsonify({ + 'success': True, + 'data': {'is_following': True} + }) + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +def get_event_class(count): + """根据事件数量返回CSS类名""" + if count >= 10: + return 'event-high' + elif count >= 5: + return 'event-medium' + elif count > 0: + return 'event-low' + return '' + + +def parse_json_field(field_value): + """解析JSON字段""" + if not field_value: + return [] + try: + if isinstance(field_value, str): + if field_value.startswith('['): + return json.loads(field_value) + else: + return field_value.split(',') + else: + return field_value + except: + return [] + + +# ==================== 行业API ==================== +@app.route('/api/classifications', methods=['GET']) +def get_classifications(): + """获取申银万国行业分类树形结构""" + try: + # 查询申银万国行业分类的所有数据 + sql = """ + SELECT f003v as code, f004v as level1, f005v as level2, f006v as level3,f007v as level4 + FROM ea_sector + WHERE f002v = '申银万国行业分类' + AND f003v IS NOT NULL + AND f004v IS NOT NULL + ORDER BY f003v + """ + + result = db.session.execute(text(sql)).all() + + # 构建树形结构 + tree_dict = {} + + for row in result: + code = row.code + level1 = row.level1 + level2 = row.level2 + level3 = row.level3 + + # 跳过空数据 + if not level1: + continue + + # 第一层 + if level1 not in tree_dict: + # 获取第一层的code(取前3位或前缀) + level1_code = code[:3] if len(code) >= 3 else code + tree_dict[level1] = { + 'value': level1_code, + 'label': level1, + 'children_dict': {} + } + + # 第二层 + if level2: + if level2 not in tree_dict[level1]['children_dict']: + # 获取第二层的code(取前6位) + level2_code = code[:6] if len(code) >= 6 else code + tree_dict[level1]['children_dict'][level2] = { + 'value': level2_code, + 'label': level2, + 'children_dict': {} + } + + # 第三层 + if level3: + if level3 not in tree_dict[level1]['children_dict'][level2]['children_dict']: + tree_dict[level1]['children_dict'][level2]['children_dict'][level3] = { + 'value': code, + 'label': level3 + } + + # 转换为最终格式 + result_list = [] + for level1_name, level1_data in tree_dict.items(): + level1_node = { + 'value': level1_data['value'], + 'label': level1_data['label'] + } + + # 处理第二层 + if level1_data['children_dict']: + level1_children = [] + for level2_name, level2_data in level1_data['children_dict'].items(): + level2_node = { + 'value': level2_data['value'], + 'label': level2_data['label'] + } + + # 处理第三层 + if level2_data['children_dict']: + level2_children = [] + for level3_name, level3_data in level2_data['children_dict'].items(): + level2_children.append({ + 'value': level3_data['value'], + 'label': level3_data['label'] + }) + if level2_children: + level2_node['children'] = level2_children + + level1_children.append(level2_node) + + if level1_children: + level1_node['children'] = level1_children + + result_list.append(level1_node) + + return jsonify({ + 'success': True, + 'data': result_list + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/stocklist', methods=['GET']) +def get_stock_list(): + """获取股票列表""" + try: + sql = """ + SELECT DISTINCT SECCODE as code, SECNAME as name + FROM ea_stocklist + ORDER BY SECCODE + """ + + result = db.session.execute(text(sql)).all() + + stocks = [{'code': row.code, 'name': row.name} for row in result] + + return jsonify(stocks) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/events', methods=['GET'], strict_slashes=False) +def api_get_events(): + """ + 获取事件列表API - 支持筛选、排序、分页,兼容前端调用 + """ + try: + # 分页参数 + page = max(1, request.args.get('page', 1, type=int)) + per_page = min(100, max(1, request.args.get('per_page', 10, type=int))) + + # 基础筛选参数 + event_type = request.args.get('type', 'all') + event_status = request.args.get('status', 'active') + importance = request.args.get('importance', 'all') + + # 日期筛选参数 + start_date = request.args.get('start_date') + end_date = request.args.get('end_date') + date_range = request.args.get('date_range') + recent_days = request.args.get('recent_days', type=int) + + # 行业筛选参数(只支持申银万国行业分类) + industry_code = request.args.get('industry_code') # 申万行业代码,如 "S370502" + + # 概念/标签筛选参数 + tag = request.args.get('tag') + tags = request.args.get('tags') + keywords = request.args.get('keywords') + + # 搜索参数 + search_query = request.args.get('q') + search_type = request.args.get('search_type', 'topic') + search_fields = request.args.get('search_fields', 'title,description').split(',') + + # 排序参数 + sort_by = request.args.get('sort', 'new') + return_type = request.args.get('return_type', 'avg') + order = request.args.get('order', 'desc') + + # 收益率筛选参数 + min_avg_return = request.args.get('min_avg_return', type=float) + max_avg_return = request.args.get('max_avg_return', type=float) + min_max_return = request.args.get('min_max_return', type=float) + max_max_return = request.args.get('max_max_return', type=float) + min_week_return = request.args.get('min_week_return', type=float) + max_week_return = request.args.get('max_week_return', type=float) + + # 其他筛选参数 + min_hot_score = request.args.get('min_hot_score', type=float) + max_hot_score = request.args.get('max_hot_score', type=float) + min_view_count = request.args.get('min_view_count', type=int) + creator_id = request.args.get('creator_id', type=int) + + # 返回格式参数 + include_creator = request.args.get('include_creator', 'true').lower() == 'true' + include_stats = request.args.get('include_stats', 'true').lower() == 'true' + include_related_data = request.args.get('include_related_data', 'false').lower() == 'true' + + # ==================== 构建查询 ==================== + query = Event.query + if event_status != 'all': + query = query.filter_by(status=event_status) + if event_type != 'all': + query = query.filter_by(event_type=event_type) + # 支持多个重要性级别筛选,用逗号分隔(如 importance=S,A) + if importance != 'all': + if ',' in importance: + # 多个重要性级别 + importance_list = [imp.strip() for imp in importance.split(',') if imp.strip()] + query = query.filter(Event.importance.in_(importance_list)) + else: + # 单个重要性级别 + query = query.filter_by(importance=importance) + if creator_id: + query = query.filter_by(creator_id=creator_id) + # 新增:行业代码过滤(申银万国行业分类) + if industry_code: + # related_industries 格式: [{"申银万国行业分类": "S370502"}, ...] + # 支持多个行业代码,用逗号分隔 + json_path = '$[*]."申银万国行业分类"' + + # 如果包含逗号,说明是多个行业代码 + if ',' in industry_code: + codes = [code.strip() for code in industry_code.split(',') if code.strip()] + # 使用 OR 条件匹配任意一个行业代码 + conditions = [] + for code in codes: + conditions.append( + text("JSON_CONTAINS(JSON_EXTRACT(related_industries, :json_path), :code)") + .bindparams(json_path=json_path, code=json.dumps(code)) + ) + query = query.filter(db.or_(*conditions)) + else: + # 单个行业代码 + query = query.filter( + text("JSON_CONTAINS(JSON_EXTRACT(related_industries, :json_path), :industry_code)") + ).params(json_path=json_path, industry_code=json.dumps(industry_code)) + # 新增:关键词/全文搜索过滤(MySQL JSON) + if search_query: + like_pattern = f"%{search_query}%" + query = query.filter( + db.or_( + Event.title.ilike(like_pattern), + Event.description.ilike(like_pattern), + text(f"JSON_SEARCH(keywords, 'one', '%{search_query}%') IS NOT NULL") + ) + ) + if recent_days: + from datetime import datetime, timedelta + cutoff_date = datetime.now() - timedelta(days=recent_days) + query = query.filter(Event.created_at >= cutoff_date) + else: + if date_range and ' 至 ' in date_range: + try: + start_date_str, end_date_str = date_range.split(' 至 ') + start_date = start_date_str.strip() + end_date = end_date_str.strip() + except ValueError: + pass + if start_date: + from datetime import datetime + try: + if len(start_date) == 10: + start_datetime = datetime.strptime(start_date, '%Y-%m-%d') + else: + start_datetime = datetime.strptime(start_date, '%Y-%m-%d %H:%M:%S') + query = query.filter(Event.created_at >= start_datetime) + except ValueError: + pass + if end_date: + from datetime import datetime + try: + if len(end_date) == 10: + end_datetime = datetime.strptime(end_date, '%Y-%m-%d') + end_datetime = end_datetime.replace(hour=23, minute=59, second=59) + else: + end_datetime = datetime.strptime(end_date, '%Y-%m-%d %H:%M:%S') + query = query.filter(Event.created_at <= end_datetime) + except ValueError: + pass + if min_view_count is not None: + query = query.filter(Event.view_count >= min_view_count) + # 排序 + from sqlalchemy import desc, asc, case + order_func = desc if order.lower() == 'desc' else asc + if sort_by == 'hot': + query = query.order_by(order_func(Event.hot_score)) + elif sort_by == 'new': + query = query.order_by(order_func(Event.created_at)) + elif sort_by == 'returns': + if return_type == 'avg': + query = query.order_by(order_func(Event.related_avg_chg)) + elif return_type == 'max': + query = query.order_by(order_func(Event.related_max_chg)) + elif return_type == 'week': + query = query.order_by(order_func(Event.related_week_chg)) + elif sort_by == 'importance': + importance_order = case( + (Event.importance == 'S', 1), + (Event.importance == 'A', 2), + (Event.importance == 'B', 3), + (Event.importance == 'C', 4), + else_=5 + ) + if order.lower() == 'desc': + query = query.order_by(importance_order) + else: + query = query.order_by(desc(importance_order)) + elif sort_by == 'view_count': + query = query.order_by(order_func(Event.view_count)) + # 分页 + paginated = query.paginate(page=page, per_page=per_page, error_out=False) + events_data = [] + for event in paginated.items: + event_dict = { + 'id': event.id, + 'title': event.title, + 'description': event.description, + 'event_type': event.event_type, + 'importance': event.importance, + 'status': event.status, + 'created_at': event.created_at.isoformat() if event.created_at else None, + 'updated_at': event.updated_at.isoformat() if event.updated_at else None, + 'start_time': event.start_time.isoformat() if event.start_time else None, + 'end_time': event.end_time.isoformat() if event.end_time else None, + } + if include_stats: + event_dict.update({ + 'hot_score': event.hot_score, + 'view_count': event.view_count, + 'post_count': event.post_count, + 'follower_count': event.follower_count, + 'related_avg_chg': event.related_avg_chg, + 'related_max_chg': event.related_max_chg, + 'related_week_chg': event.related_week_chg, + 'invest_score': event.invest_score, + 'trending_score': event.trending_score, + }) + if include_creator: + event_dict['creator'] = { + 'id': event.creator.id if event.creator else None, + 'username': event.creator.username if event.creator else 'Anonymous' + } + event_dict['keywords'] = event.keywords_list if hasattr(event, 'keywords_list') else event.keywords + event_dict['related_industries'] = event.related_industries + if include_related_data: + pass + events_data.append(event_dict) + applied_filters = {} + if event_type != 'all': + applied_filters['type'] = event_type + if importance != 'all': + applied_filters['importance'] = importance + if start_date: + applied_filters['start_date'] = start_date + if end_date: + applied_filters['end_date'] = end_date + if industry_code: + applied_filters['industry_code'] = industry_code + if tag: + applied_filters['tag'] = tag + if tags: + applied_filters['tags'] = tags + if search_query: + applied_filters['search_query'] = search_query + applied_filters['search_type'] = search_type + return jsonify({ + 'success': True, + 'data': { + 'events': events_data, + 'pagination': { + 'page': paginated.page, + 'per_page': paginated.per_page, + 'total': paginated.total, + 'pages': paginated.pages, + 'has_prev': paginated.has_prev, + 'has_next': paginated.has_next + }, + 'filters': { + 'applied_filters': applied_filters, + 'total_count': paginated.total + } + } + }) + except Exception as e: + app.logger.error(f"获取事件列表出错: {str(e)}", exc_info=True) + return jsonify({ + 'success': False, + 'error': str(e), + 'error_type': type(e).__name__ + }), 500 + + +@app.route('/api/events/hot', methods=['GET']) +def get_hot_events(): + """获取热点事件""" + try: + from datetime import datetime, timedelta + days = request.args.get('days', 3, type=int) + limit = request.args.get('limit', 4, type=int) + since_date = datetime.now() - timedelta(days=days) + hot_events = Event.query.filter( + Event.status == 'active', + Event.created_at >= since_date, + Event.related_avg_chg != None, + Event.related_avg_chg > 0 + ).order_by(Event.related_avg_chg.desc()).limit(limit).all() + if len(hot_events) < limit: + additional_events = Event.query.filter( + Event.status == 'active', + Event.created_at >= since_date, + ~Event.id.in_([event.id for event in hot_events]) + ).order_by(Event.hot_score.desc()).limit(limit - len(hot_events)).all() + hot_events.extend(additional_events) + events_data = [] + for event in hot_events: + events_data.append({ + 'id': event.id, + 'title': event.title, + 'description': event.description, + 'importance': event.importance, + 'created_at': event.created_at.isoformat() if event.created_at else None, + 'related_avg_chg': event.related_avg_chg, + 'creator': { + 'username': event.creator.username if event.creator else 'Anonymous' + } + }) + return jsonify({'success': True, 'data': events_data}) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events/keywords/popular', methods=['GET']) +def get_popular_keywords(): + """获取热门关键词""" + try: + limit = request.args.get('limit', 20, type=int) + sql = ''' + WITH RECURSIVE \ + numbers AS (SELECT 0 as n \ + UNION ALL \ + SELECT n + 1 \ + FROM numbers \ + WHERE n < 100), \ + json_array AS (SELECT JSON_UNQUOTE(JSON_EXTRACT(e.keywords, CONCAT('$[', n.n, ']'))) as keyword, \ + COUNT(*) as count + FROM event e + CROSS JOIN numbers n + WHERE + e.status = 'active' + AND JSON_EXTRACT(e.keywords \ + , CONCAT('$[' \ + , n.n \ + , ']')) IS NOT NULL + GROUP BY JSON_UNQUOTE(JSON_EXTRACT(e.keywords, CONCAT('$[', n.n, ']'))) + HAVING keyword IS NOT NULL + ) + SELECT keyword, count + FROM json_array + ORDER BY count DESC, keyword LIMIT :limit \ + ''' + result = db.session.execute(text(sql), {'limit': limit}).all() + keywords_data = [{'keyword': row.keyword, 'count': row.count} for row in result] + return jsonify({'success': True, 'data': keywords_data}) + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/events//sankey-data') +def get_event_sankey_data(event_id): + """ + 获取事件桑基图数据 (最终优化版) + - 处理重名节点 + - 检测并打破循环依赖 + """ + flows = EventSankeyFlow.query.filter_by(event_id=event_id).order_by( + EventSankeyFlow.source_level, EventSankeyFlow.target_level + ).all() + + if not flows: + return jsonify({'success': False, 'message': '暂无桑基图数据'}) + + nodes_map = {} + links = [] + type_colors = { + 'event': '#ff4757', 'policy': '#10ac84', 'technology': '#ee5a6f', + 'industry': '#00d2d3', 'company': '#54a0ff', 'product': '#ffd93d' + } + + # --- 1. 识别并处理重名节点 (与上一版相同) --- + all_node_keys = set() + name_counts = {} + for flow in flows: + source_key = f"{flow.source_node}|{flow.source_level}" + target_key = f"{flow.target_node}|{flow.target_level}" + all_node_keys.add(source_key) + all_node_keys.add(target_key) + name_counts.setdefault(flow.source_node, set()).add(flow.source_level) + name_counts.setdefault(flow.target_node, set()).add(flow.target_level) + + duplicate_names = {name for name, levels in name_counts.items() if len(levels) > 1} + + for flow in flows: + source_key = f"{flow.source_node}|{flow.source_level}" + if source_key not in nodes_map: + display_name = f"{flow.source_node} (L{flow.source_level})" if flow.source_node in duplicate_names else flow.source_node + nodes_map[source_key] = {'name': display_name, 'type': flow.source_type, 'level': flow.source_level, + 'color': type_colors.get(flow.source_type)} + + target_key = f"{flow.target_node}|{flow.target_level}" + if target_key not in nodes_map: + display_name = f"{flow.target_node} (L{flow.target_level})" if flow.target_node in duplicate_names else flow.target_node + nodes_map[target_key] = {'name': display_name, 'type': flow.target_type, 'level': flow.target_level, + 'color': type_colors.get(flow.target_type)} + + links.append({ + 'source_key': source_key, 'target_key': target_key, 'value': float(flow.flow_value), + 'ratio': float(flow.flow_ratio), 'transmission_path': flow.transmission_path, + 'impact_description': flow.impact_description, 'evidence_strength': flow.evidence_strength + }) + + # --- 2. 循环检测与处理 --- + # 构建邻接表 + adj = defaultdict(list) + for link in links: + adj[link['source_key']].append(link['target_key']) + + # 深度优先搜索(DFS)来检测循环 + path = set() # 记录当前递归路径上的节点 + visited = set() # 记录所有访问过的节点 + back_edges = set() # 记录导致循环的"回流边" + + def detect_cycle_util(node): + path.add(node) + visited.add(node) + for neighbour in adj.get(node, []): + if neighbour in path: + # 发现了循环,记录这条回流边 (target, source) + back_edges.add((neighbour, node)) + elif neighbour not in visited: + detect_cycle_util(neighbour) + path.remove(node) + + # 从所有节点开始检测 + for node_key in list(adj.keys()): + if node_key not in visited: + detect_cycle_util(node_key) + + # 过滤掉导致循环的边 + if back_edges: + print(f"检测到并移除了 {len(back_edges)} 条循环边: {back_edges}") + + valid_links_no_cycle = [] + for link in links: + if (link['source_key'], link['target_key']) not in back_edges and \ + (link['target_key'], link['source_key']) not in back_edges: # 移除非严格意义上的双向边 + valid_links_no_cycle.append(link) + + # --- 3. 构建最终的 JSON 响应 (与上一版相似) --- + node_list = [] + node_index_map = {} + sorted_node_keys = sorted(nodes_map.keys(), key=lambda k: (nodes_map[k]['level'], nodes_map[k]['name'])) + + for i, key in enumerate(sorted_node_keys): + node_list.append(nodes_map[key]) + node_index_map[key] = i + + final_links = [] + for link in valid_links_no_cycle: + source_idx = node_index_map.get(link['source_key']) + target_idx = node_index_map.get(link['target_key']) + if source_idx is not None and target_idx is not None: + # 移除临时的 key,只保留 ECharts 需要的字段 + link.pop('source_key', None) + link.pop('target_key', None) + link['source'] = source_idx + link['target'] = target_idx + final_links.append(link) + + # ... (统计信息计算部分保持不变) ... + stats = { + 'total_nodes': len(node_list), 'total_flows': len(final_links), + 'total_flow_value': sum(link['value'] for link in final_links), + 'max_level': max((node['level'] for node in node_list), default=0), + 'node_type_counts': {ntype: sum(1 for n in node_list if n['type'] == ntype) for ntype in type_colors} + } + + return jsonify({ + 'success': True, + 'data': {'nodes': node_list, 'links': final_links, 'stats': stats} + }) + + +# 优化后的传导链分析 API +@app.route('/api/events//chain-analysis') +def get_event_chain_analysis(event_id): + """获取事件传导链分析数据""" + nodes = EventTransmissionNode.query.filter_by(event_id=event_id).all() + if not nodes: + return jsonify({'success': False, 'message': '暂无传导链分析数据'}) + + edges = EventTransmissionEdge.query.filter_by(event_id=event_id).all() + + # 过滤孤立节点 + connected_node_ids = set() + for edge in edges: + connected_node_ids.add(edge.from_node_id) + connected_node_ids.add(edge.to_node_id) + + # 只保留有连接的节点 + connected_nodes = [node for node in nodes if node.id in connected_node_ids] + + if not connected_nodes: + return jsonify({'success': False, 'message': '所有节点都是孤立的,暂无传导关系'}) + + # 节点分类,用于力导向图的图例 + categories = { + 'event': "事件", 'industry': "行业", 'company': "公司", + 'policy': "政策", 'technology': "技术", 'market': "市场", 'other': "其他" + } + + # 计算每个节点的连接数 + node_connection_count = {} + for node in connected_nodes: + count = sum(1 for edge in edges + if edge.from_node_id == node.id or edge.to_node_id == node.id) + node_connection_count[node.id] = count + + nodes_data = [] + for node in connected_nodes: + connection_count = node_connection_count[node.id] + + nodes_data.append({ + 'id': str(node.id), + 'name': node.node_name, + 'value': node.importance_score, # 用于控制节点大小的基础值 + 'category': categories.get(node.node_type), + 'extra': { + 'node_type': node.node_type, + 'description': node.node_description, + 'importance_score': node.importance_score, + 'stock_code': node.stock_code, + 'is_main_event': node.is_main_event, + 'connection_count': connection_count, # 添加连接数信息 + } + }) + + edges_data = [] + for edge in edges: + # 确保边的两端节点都在连接节点列表中 + if edge.from_node_id in connected_node_ids and edge.to_node_id in connected_node_ids: + edges_data.append({ + 'source': str(edge.from_node_id), + 'target': str(edge.to_node_id), + 'value': edge.strength, # 用于控制边的宽度 + 'extra': { + 'transmission_type': edge.transmission_type, + 'transmission_mechanism': edge.transmission_mechanism, + 'direction': edge.direction, + 'strength': edge.strength, + 'impact': edge.impact, + 'is_circular': edge.is_circular, + } + }) + + # 重新计算统计信息(基于连接的节点和边) + stats = { + 'total_nodes': len(connected_nodes), + 'total_edges': len(edges_data), + 'node_types': {cat: sum(1 for n in connected_nodes if n.node_type == node_type) + for node_type, cat in categories.items()}, + 'edge_types': {edge.transmission_type: sum(1 for e in edges_data + if e['extra']['transmission_type'] == edge.transmission_type) for + edge in edges}, + 'avg_importance': sum(node.importance_score for node in connected_nodes) / len( + connected_nodes) if connected_nodes else 0, + 'avg_strength': sum(edge.strength for edge in edges) / len(edges) if edges else 0 + } + + return jsonify({ + 'success': True, + 'data': { + 'nodes': nodes_data, + 'edges': edges_data, + 'categories': list(categories.values()), + 'stats': stats + } + }) + + +@app.route('/api/events//chain-node/', methods=['GET']) +@cross_origin() +def get_chain_node_detail(event_id, node_id): + """获取传导链节点及其直接关联节点的详细信息""" + node = db.session.get(EventTransmissionNode, node_id) + if not node or node.event_id != event_id: + return jsonify({'success': False, 'message': '节点不存在'}) + + # 验证节点是否为孤立节点 + total_connections = (EventTransmissionEdge.query.filter_by(from_node_id=node_id).count() + + EventTransmissionEdge.query.filter_by(to_node_id=node_id).count()) + + if total_connections == 0 and not node.is_main_event: + return jsonify({'success': False, 'message': '该节点为孤立节点,无连接关系'}) + + # 找出影响当前节点的父节点 + parents_info = [] + incoming_edges = EventTransmissionEdge.query.filter_by(to_node_id=node_id).all() + for edge in incoming_edges: + parent = db.session.get(EventTransmissionNode, edge.from_node_id) + if parent: + parents_info.append({ + 'id': parent.id, + 'name': parent.node_name, + 'type': parent.node_type, + 'direction': edge.direction, + 'strength': edge.strength, + 'transmission_type': edge.transmission_type, + 'transmission_mechanism': edge.transmission_mechanism, # 修复字段名 + 'is_circular': edge.is_circular, + 'impact': edge.impact + }) + + # 找出被当前节点影响的子节点 + children_info = [] + outgoing_edges = EventTransmissionEdge.query.filter_by(from_node_id=node_id).all() + for edge in outgoing_edges: + child = db.session.get(EventTransmissionNode, edge.to_node_id) + if child: + children_info.append({ + 'id': child.id, + 'name': child.node_name, + 'type': child.node_type, + 'direction': edge.direction, + 'strength': edge.strength, + 'transmission_type': edge.transmission_type, + 'transmission_mechanism': edge.transmission_mechanism, # 修复字段名 + 'is_circular': edge.is_circular, + 'impact': edge.impact + }) + + node_data = { + 'id': node.id, + 'name': node.node_name, + 'type': node.node_type, + 'description': node.node_description, + 'importance_score': node.importance_score, + 'stock_code': node.stock_code, + 'is_main_event': node.is_main_event, + 'total_connections': total_connections, + 'incoming_connections': len(incoming_edges), + 'outgoing_connections': len(outgoing_edges) + } + + return jsonify({ + 'success': True, + 'data': { + 'node': node_data, + 'parents': parents_info, + 'children': children_info + } + }) + + +@app.route('/api/events//posts', methods=['GET']) +def get_event_posts(event_id): + """获取事件下的帖子""" + try: + sort_type = request.args.get('sort', 'latest') + page = request.args.get('page', 1, type=int) + per_page = request.args.get('per_page', 20, type=int) + + # 查询事件下的帖子 + query = Post.query.filter_by(event_id=event_id, status='active') + + if sort_type == 'hot': + query = query.order_by(Post.likes_count.desc(), Post.created_at.desc()) + else: # latest + query = query.order_by(Post.created_at.desc()) + + # 分页 + pagination = query.paginate(page=page, per_page=per_page, error_out=False) + posts = pagination.items + + posts_data = [] + for post in posts: + post_dict = { + 'id': post.id, + 'event_id': post.event_id, + 'user_id': post.user_id, + 'title': post.title, + 'content': post.content, + 'content_type': post.content_type, + 'created_at': post.created_at.isoformat(), + 'updated_at': post.updated_at.isoformat(), + 'likes_count': post.likes_count, + 'comments_count': post.comments_count, + 'view_count': post.view_count, + 'is_top': post.is_top, + 'user': { + 'id': post.user.id, + 'username': post.user.username, + 'avatar_url': post.user.avatar_url + } if post.user else None, + 'liked': False # 后续可以根据当前用户判断 + } + posts_data.append(post_dict) + + return jsonify({ + 'success': True, + 'data': posts_data, + 'pagination': { + 'page': page, + 'per_page': per_page, + 'total': pagination.total, + 'pages': pagination.pages + } + }) + + except Exception as e: + print(f"获取帖子失败: {e}") + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/events//posts', methods=['POST']) +@login_required +def create_event_post(event_id): + """在事件下创建帖子""" + try: + data = request.get_json() + content = data.get('content', '').strip() + title = data.get('title', '').strip() + content_type = data.get('content_type', 'text') + + if not content: + return jsonify({ + 'success': False, + 'message': '帖子内容不能为空' + }), 400 + + # 创建新帖子 + post = Post( + event_id=event_id, + user_id=current_user.id, + title=title, + content=content, + content_type=content_type + ) + + db.session.add(post) + + # 更新事件的帖子数 + event = Event.query.get(event_id) + if event: + event.post_count = Post.query.filter_by(event_id=event_id, status='active').count() + + # 更新用户发帖数 + current_user.post_count = (current_user.post_count or 0) + 1 + + db.session.commit() + + return jsonify({ + 'success': True, + 'data': { + 'id': post.id, + 'event_id': post.event_id, + 'user_id': post.user_id, + 'title': post.title, + 'content': post.content, + 'content_type': post.content_type, + 'created_at': post.created_at.isoformat(), + 'user': { + 'id': current_user.id, + 'username': current_user.username, + 'avatar_url': current_user.avatar_url + } + }, + 'message': '帖子发布成功' + }) + + except Exception as e: + db.session.rollback() + print(f"创建帖子失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +@app.route('/api/posts//comments', methods=['GET']) +def get_post_comments(post_id): + """获取帖子的评论""" + try: + sort_type = request.args.get('sort', 'latest') + + # 查询帖子的顶级评论(非回复) + query = Comment.query.filter_by(post_id=post_id, parent_id=None, status='active') + + if sort_type == 'hot': + comments = query.order_by(Comment.likes_count.desc(), Comment.created_at.desc()).all() + else: # latest + comments = query.order_by(Comment.created_at.desc()).all() + + comments_data = [] + for comment in comments: + comment_dict = { + 'id': comment.id, + 'post_id': comment.post_id, + 'user_id': comment.user_id, + 'content': comment.content, + 'created_at': comment.created_at.isoformat(), + 'updated_at': comment.updated_at.isoformat(), + 'likes_count': comment.likes_count, + 'user': { + 'id': comment.user.id, + 'username': comment.user.username, + 'avatar_url': comment.user.avatar_url + } if comment.user else None, + 'replies': [] # 加载回复 + } + + # 加载回复 + replies = Comment.query.filter_by(parent_id=comment.id, status='active').order_by(Comment.created_at).all() + for reply in replies: + reply_dict = { + 'id': reply.id, + 'post_id': reply.post_id, + 'user_id': reply.user_id, + 'content': reply.content, + 'parent_id': reply.parent_id, + 'created_at': reply.created_at.isoformat(), + 'likes_count': reply.likes_count, + 'user': { + 'id': reply.user.id, + 'username': reply.user.username, + 'avatar_url': reply.user.avatar_url + } if reply.user else None + } + comment_dict['replies'].append(reply_dict) + + comments_data.append(comment_dict) + + return jsonify({ + 'success': True, + 'data': comments_data + }) + + except Exception as e: + print(f"获取评论失败: {e}") + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/posts//comments', methods=['POST']) +@login_required +def create_post_comment(post_id): + """在帖子下创建评论""" + try: + data = request.get_json() + content = data.get('content', '').strip() + parent_id = data.get('parent_id') + + if not content: + return jsonify({ + 'success': False, + 'message': '评论内容不能为空' + }), 400 + + # 创建新评论 + comment = Comment( + post_id=post_id, + user_id=current_user.id, + content=content, + parent_id=parent_id + ) + + db.session.add(comment) + + # 更新帖子评论数 + post = Post.query.get(post_id) + if post: + post.comments_count = Comment.query.filter_by(post_id=post_id, status='active').count() + + # 更新用户评论数 + current_user.comment_count = (current_user.comment_count or 0) + 1 + + db.session.commit() + + return jsonify({ + 'success': True, + 'data': { + 'id': comment.id, + 'post_id': comment.post_id, + 'user_id': comment.user_id, + 'content': comment.content, + 'parent_id': comment.parent_id, + 'created_at': comment.created_at.isoformat(), + 'user': { + 'id': current_user.id, + 'username': current_user.username, + 'avatar_url': current_user.avatar_url + } + }, + 'message': '评论发布成功' + }) + + except Exception as e: + db.session.rollback() + print(f"创建评论失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +# 兼容旧的评论接口,转换为帖子模式 +@app.route('/api/events//comments', methods=['GET']) +def get_event_comments(event_id): + """获取事件评论(兼容旧接口)""" + # 将事件评论转换为获取事件下所有帖子的评论 + return get_event_posts(event_id) + + +@app.route('/api/events//comments', methods=['POST']) +@login_required +def add_event_comment(event_id): + """添加事件评论(兼容旧接口)""" + try: + data = request.get_json() + content = data.get('content', '').strip() + parent_id = data.get('parent_id') + + if not content: + return jsonify({ + 'success': False, + 'message': '评论内容不能为空' + }), 400 + + # 如果有 parent_id,说明是回复,需要找到对应的帖子 + if parent_id: + # 这是一个回复,需要将其转换为对应帖子的评论 + # 首先需要找到 parent_id 对应的帖子 + # 这里假设旧的 parent_id 是之前的 EventComment id + # 需要在数据迁移时处理这个映射关系 + return jsonify({ + 'success': False, + 'message': '回复功能正在升级中,请稍后再试' + }), 503 + + # 如果没有 parent_id,说明是顶级评论,创建为新帖子 + post = Post( + event_id=event_id, + user_id=current_user.id, + content=content, + content_type='text' + ) + + db.session.add(post) + + # 更新事件的帖子数 + event = Event.query.get(event_id) + if event: + event.post_count = Post.query.filter_by(event_id=event_id, status='active').count() + + # 更新用户发帖数 + current_user.post_count = (current_user.post_count or 0) + 1 + + db.session.commit() + + # 返回兼容旧接口的数据格式 + return jsonify({ + 'success': True, + 'data': { + 'id': post.id, + 'event_id': post.event_id, + 'user_id': post.user_id, + 'author': current_user.username, + 'content': post.content, + 'parent_id': None, + 'likes': 0, + 'created_at': post.created_at.isoformat(), + 'status': 'active', + 'user': { + 'id': current_user.id, + 'username': current_user.username, + 'avatar_url': current_user.avatar_url + }, + 'replies': [] + }, + 'message': '评论发布成功' + }) + + except Exception as e: + db.session.rollback() + print(f"添加事件评论失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +# ==================== WebSocket 事件处理器(实时事件推送) ==================== + +@socketio.on('connect') +def handle_connect(): + """客户端连接事件""" + print(f'\n[WebSocket DEBUG] ========== 客户端连接 ==========') + print(f'[WebSocket DEBUG] Socket ID: {request.sid}') + print(f'[WebSocket DEBUG] Remote Address: {request.remote_addr if hasattr(request, "remote_addr") else "N/A"}') + print(f'[WebSocket] 客户端已连接: {request.sid}') + + emit('connection_response', { + 'status': 'connected', + 'sid': request.sid, + 'message': '已连接到事件推送服务' + }) + print(f'[WebSocket DEBUG] ✓ 已发送 connection_response') + print(f'[WebSocket DEBUG] ========== 连接完成 ==========\n') + + +@socketio.on('subscribe_events') +def handle_subscribe(data): + """ + 客户端订阅事件推送 + data: { + 'event_type': 'all' | 'policy' | 'market' | 'tech' | ..., + 'importance': 'all' | 'S' | 'A' | 'B' | 'C', + 'filters': {...} # 可选的其他筛选条件 + } + """ + try: + print(f'\n[WebSocket DEBUG] ========== 收到订阅请求 ==========') + print(f'[WebSocket DEBUG] Socket ID: {request.sid}') + print(f'[WebSocket DEBUG] 订阅数据: {data}') + + event_type = data.get('event_type', 'all') + importance = data.get('importance', 'all') + + print(f'[WebSocket DEBUG] 事件类型: {event_type}') + print(f'[WebSocket DEBUG] 重要性: {importance}') + + # 加入对应的房间 + room_name = f"events_{event_type}" + print(f'[WebSocket DEBUG] 准备加入房间: {room_name}') + join_room(room_name) + print(f'[WebSocket DEBUG] ✓ 已加入房间: {room_name}') + + print(f'[WebSocket] 客户端 {request.sid} 订阅了房间: {room_name}') + + response_data = { + 'success': True, + 'room': room_name, + 'event_type': event_type, + 'importance': importance, + 'message': f'已订阅 {event_type} 类型的事件推送' + } + print(f'[WebSocket DEBUG] 准备发送 subscription_confirmed: {response_data}') + emit('subscription_confirmed', response_data) + print(f'[WebSocket DEBUG] ✓ 已发送 subscription_confirmed') + print(f'[WebSocket DEBUG] ========== 订阅完成 ==========\n') + + except Exception as e: + print(f'[WebSocket ERROR] 订阅失败: {e}') + import traceback + traceback.print_exc() + emit('subscription_error', { + 'success': False, + 'error': str(e) + }) + + +@socketio.on('unsubscribe_events') +def handle_unsubscribe(data): + """取消订阅事件推送""" + try: + print(f'\n[WebSocket DEBUG] ========== 收到取消订阅请求 ==========') + print(f'[WebSocket DEBUG] Socket ID: {request.sid}') + print(f'[WebSocket DEBUG] 数据: {data}') + + event_type = data.get('event_type', 'all') + room_name = f"events_{event_type}" + + print(f'[WebSocket DEBUG] 准备离开房间: {room_name}') + leave_room(room_name) + print(f'[WebSocket DEBUG] ✓ 已离开房间: {room_name}') + + print(f'[WebSocket] 客户端 {request.sid} 取消订阅房间: {room_name}') + + emit('unsubscription_confirmed', { + 'success': True, + 'room': room_name, + 'message': f'已取消订阅 {event_type} 类型的事件推送' + }) + print(f'[WebSocket DEBUG] ========== 取消订阅完成 ==========\n') + + except Exception as e: + print(f'[WebSocket ERROR] 取消订阅失败: {e}') + import traceback + traceback.print_exc() + emit('unsubscription_error', { + 'success': False, + 'error': str(e) + }) + + +@socketio.on('disconnect') +def handle_disconnect(): + """客户端断开连接事件""" + print(f'\n[WebSocket DEBUG] ========== 客户端断开 ==========') + print(f'[WebSocket DEBUG] Socket ID: {request.sid}') + print(f'[WebSocket] 客户端已断开: {request.sid}') + print(f'[WebSocket DEBUG] ========== 断开完成 ==========\n') + + +# ==================== WebSocket 辅助函数 ==================== + +def broadcast_new_event(event): + """ + 广播新事件到所有订阅的客户端 + 在创建新事件时调用此函数 + + Args: + event: Event 模型实例 + """ + try: + print(f'\n[WebSocket DEBUG] ========== 广播新事件 ==========') + print(f'[WebSocket DEBUG] 事件ID: {event.id}') + print(f'[WebSocket DEBUG] 事件标题: {event.title}') + print(f'[WebSocket DEBUG] 事件类型: {event.event_type}') + print(f'[WebSocket DEBUG] 重要性: {event.importance}') + + event_data = { + 'id': event.id, + 'title': event.title, + 'description': event.description, + 'event_type': event.event_type, + 'importance': event.importance, + 'status': event.status, + 'created_at': event.created_at.isoformat() if event.created_at else None, + 'hot_score': event.hot_score, + 'view_count': event.view_count, + 'related_avg_chg': event.related_avg_chg, + 'related_max_chg': event.related_max_chg, + 'keywords': event.keywords_list if hasattr(event, 'keywords_list') else event.keywords, + } + + print(f'[WebSocket DEBUG] 准备发送的数据: {event_data}') + + # 发送到所有订阅者(all 房间) + print(f'[WebSocket DEBUG] 正在发送到房间: events_all') + socketio.emit('new_event', event_data, room='events_all', namespace='/') + print(f'[WebSocket DEBUG] ✓ 已发送到 events_all') + + # 发送到特定类型订阅者 + if event.event_type: + room_name = f"events_{event.event_type}" + print(f'[WebSocket DEBUG] 正在发送到房间: {room_name}') + socketio.emit('new_event', event_data, room=room_name, namespace='/') + print(f'[WebSocket DEBUG] ✓ 已发送到 {room_name}') + print(f'[WebSocket] 已推送新事件到房间: events_all, {room_name}') + else: + print(f'[WebSocket] 已推送新事件到房间: events_all') + + print(f'[WebSocket DEBUG] ========== 广播完成 ==========\n') + + except Exception as e: + print(f'[WebSocket ERROR] 推送新事件失败: {e}') + import traceback + traceback.print_exc() + + +# ==================== WebSocket 轮询机制(检测新事件) ==================== + +# 内存变量:记录近24小时内已知的事件ID集合和最大ID +known_event_ids_in_24h = set() # 近24小时内已知的所有事件ID +last_max_event_id = 0 # 已知的最大事件ID + +def poll_new_events(): + """ + 定期轮询数据库,检查是否有新事件 + 每 30 秒执行一次 + + 新的设计思路(修复 created_at 不是入库时间的问题): + 1. 查询近24小时内的所有活跃事件(按 created_at,因为这是事件发生时间) + 2. 通过对比事件ID(自增ID)来判断是否为新插入的事件 + 3. 推送 ID > last_max_event_id 的事件 + 4. 更新已知事件ID集合和最大ID + """ + global known_event_ids_in_24h, last_max_event_id + + try: + with app.app_context(): + from datetime import datetime, timedelta + + current_time = datetime.now() + print(f'\n[轮询 DEBUG] ========== 开始轮询 ==========') + print(f'[轮询 DEBUG] 当前时间: {current_time.strftime("%Y-%m-%d %H:%M:%S")}') + print(f'[轮询 DEBUG] 已知事件ID数量: {len(known_event_ids_in_24h)}') + print(f'[轮询 DEBUG] 当前最大事件ID: {last_max_event_id}') + + # 查询近24小时内的所有活跃事件(按事件发生时间 created_at) + time_24h_ago = current_time - timedelta(hours=24) + print(f'[轮询 DEBUG] 查询时间范围: 近24小时({time_24h_ago.strftime("%Y-%m-%d %H:%M:%S")} ~ 现在)') + + # 查询所有近24小时内的活跃事件 + events_in_24h = Event.query.filter( + Event.created_at >= time_24h_ago, + Event.status == 'active' + ).order_by(Event.id.asc()).all() + + print(f'[轮询 DEBUG] 数据库查询结果: 找到 {len(events_in_24h)} 个近24小时内的事件') + + # 找出新插入的事件(ID > last_max_event_id) + new_events = [ + event for event in events_in_24h + if event.id > last_max_event_id + ] + + print(f'[轮询 DEBUG] 新事件数量(ID > {last_max_event_id}): {len(new_events)} 个') + + if new_events: + print(f'[轮询] 发现 {len(new_events)} 个新事件') + + for event in new_events: + print(f'[轮询 DEBUG] 新事件详情:') + print(f'[轮询 DEBUG] - ID: {event.id}') + print(f'[轮询 DEBUG] - 标题: {event.title}') + print(f'[轮询 DEBUG] - 事件发生时间(created_at): {event.created_at}') + print(f'[轮询 DEBUG] - 事件类型: {event.event_type}') + + # 推送新事件 + print(f'[轮询 DEBUG] 准备推送事件 ID={event.id}') + broadcast_new_event(event) + print(f'[轮询] ✓ 已推送事件 ID={event.id}, 标题={event.title}') + + # 更新已知事件ID集合(所有近24小时内的事件ID) + known_event_ids_in_24h = set(event.id for event in events_in_24h) + + # 更新最大事件ID + new_max_id = max(event.id for event in events_in_24h) + print(f'[轮询 DEBUG] 更新最大事件ID: {last_max_event_id} -> {new_max_id}') + last_max_event_id = new_max_id + + print(f'[轮询 DEBUG] 更新后已知事件ID数量: {len(known_event_ids_in_24h)}') + + else: + print(f'[轮询 DEBUG] 没有新事件需要推送') + + # 即使没有新事件,也要更新已知事件集合(清理超过24小时的) + if events_in_24h: + known_event_ids_in_24h = set(event.id for event in events_in_24h) + current_max_id = max(event.id for event in events_in_24h) + if current_max_id != last_max_event_id: + print(f'[轮询 DEBUG] 更新最大事件ID: {last_max_event_id} -> {current_max_id}') + last_max_event_id = current_max_id + + print(f'[轮询 DEBUG] ========== 轮询结束 ==========\n') + + except Exception as e: + print(f'[轮询 ERROR] 检查新事件时出错: {e}') + import traceback + traceback.print_exc() + + +def initialize_event_polling(): + """ + 初始化事件轮询机制 + 在应用启动时调用 + """ + global known_event_ids_in_24h, last_max_event_id + + try: + from datetime import datetime, timedelta + + with app.app_context(): + current_time = datetime.now() + time_24h_ago = current_time - timedelta(hours=24) + + print(f'\n[轮询] ========== 初始化事件轮询 ==========') + print(f'[轮询] 当前时间: {current_time.strftime("%Y-%m-%d %H:%M:%S")}') + + # 查询近24小时内的所有活跃事件 + events_in_24h = Event.query.filter( + Event.created_at >= time_24h_ago, + Event.status == 'active' + ).order_by(Event.id.asc()).all() + + # 初始化已知事件ID集合 + known_event_ids_in_24h = set(event.id for event in events_in_24h) + + # 初始化最大事件ID + if events_in_24h: + last_max_event_id = max(event.id for event in events_in_24h) + print(f'[轮询] 近24小时内共有 {len(events_in_24h)} 个活跃事件') + print(f'[轮询] 初始最大事件ID: {last_max_event_id}') + print(f'[轮询] 事件ID范围: {min(event.id for event in events_in_24h)} ~ {last_max_event_id}') + else: + last_max_event_id = 0 + print(f'[轮询] 近24小时内没有活跃事件') + print(f'[轮询] 初始最大事件ID: 0') + + # 统计数据库中的事件总数 + total_events = Event.query.filter_by(status='active').count() + print(f'[轮询] 数据库中共有 {total_events} 个活跃事件(所有时间)') + print(f'[轮询] 只会推送 ID > {last_max_event_id} 的新事件') + print(f'[轮询] ========== 初始化完成 ==========\n') + + # 创建后台调度器 + scheduler = BackgroundScheduler() + # 每 30 秒执行一次轮询 + scheduler.add_job( + func=poll_new_events, + trigger='interval', + seconds=30, + id='poll_new_events', + name='检查新事件并推送', + replace_existing=True + ) + scheduler.start() + print('[轮询] 调度器已启动,每 30 秒检查一次新事件') + + except Exception as e: + print(f'[轮询] 初始化失败: {e}') + + +# ==================== 结束 WebSocket 部分 ==================== + + +@app.route('/api/posts//like', methods=['POST']) +@login_required +def like_post(post_id): + """点赞/取消点赞帖子""" + try: + post = Post.query.get_or_404(post_id) + + # 检查是否已经点赞 + existing_like = PostLike.query.filter_by( + post_id=post_id, + user_id=current_user.id + ).first() + + if existing_like: + # 取消点赞 + db.session.delete(existing_like) + post.likes_count = max(0, post.likes_count - 1) + message = '取消点赞成功' + liked = False + else: + # 添加点赞 + new_like = PostLike(post_id=post_id, user_id=current_user.id) + db.session.add(new_like) + post.likes_count += 1 + message = '点赞成功' + liked = True + + db.session.commit() + + return jsonify({ + 'success': True, + 'message': message, + 'likes_count': post.likes_count, + 'liked': liked + }) + + except Exception as e: + db.session.rollback() + print(f"点赞失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +@app.route('/api/comments//like', methods=['POST']) +@login_required +def like_comment(comment_id): + """点赞/取消点赞评论""" + try: + comment = Comment.query.get_or_404(comment_id) + + # 检查是否已经点赞(需要创建 CommentLike 关联到新的 Comment 模型) + # 暂时使用简单的计数器 + comment.likes_count += 1 + db.session.commit() + + return jsonify({ + 'success': True, + 'message': '点赞成功', + 'likes_count': comment.likes_count + }) + + except Exception as e: + db.session.rollback() + print(f"点赞失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +@app.route('/api/posts/', methods=['DELETE']) +@login_required +def delete_post(post_id): + """删除帖子""" + try: + post = Post.query.get_or_404(post_id) + + # 检查权限:只能删除自己的帖子 + if post.user_id != current_user.id: + return jsonify({ + 'success': False, + 'message': '您只能删除自己的帖子' + }), 403 + + # 软删除 + post.status = 'deleted' + + # 更新事件的帖子数 + event = Event.query.get(post.event_id) + if event: + event.post_count = Post.query.filter_by(event_id=post.event_id, status='active').count() + + # 更新用户发帖数 + if current_user.post_count > 0: + current_user.post_count -= 1 + + db.session.commit() + + return jsonify({ + 'success': True, + 'message': '帖子删除成功' + }) + + except Exception as e: + db.session.rollback() + print(f"删除帖子失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +@app.route('/api/comments/', methods=['DELETE']) +@login_required +def delete_comment(comment_id): + """删除评论""" + try: + comment = Comment.query.get_or_404(comment_id) + + # 检查权限:只能删除自己的评论 + if comment.user_id != current_user.id: + return jsonify({ + 'success': False, + 'message': '您只能删除自己的评论' + }), 403 + + # 软删除 + comment.status = 'deleted' + comment.content = '[该评论已被删除]' + + # 更新帖子评论数 + post = Post.query.get(comment.post_id) + if post: + post.comments_count = Comment.query.filter_by(post_id=comment.post_id, status='active').count() + + # 更新用户评论数 + if current_user.comment_count > 0: + current_user.comment_count -= 1 + + db.session.commit() + + return jsonify({ + 'success': True, + 'message': '评论删除成功' + }) + + except Exception as e: + db.session.rollback() + print(f"删除评论失败: {e}") + return jsonify({ + 'success': False, + 'message': str(e) + }), 500 + + +def format_decimal(value): + """格式化decimal类型数据""" + if value is None: + return None + if isinstance(value, Decimal): + return float(value) + return float(value) + + +def format_date(date_obj): + """格式化日期""" + if date_obj is None: + return None + if isinstance(date_obj, datetime): + return date_obj.strftime('%Y-%m-%d') + return str(date_obj) + + +def remove_cycles_from_sankey_flows(flows_data): + """ + 移除Sankey图数据中的循环边,确保数据是DAG(有向无环图) + 使用拓扑排序算法检测循环,优先保留flow_ratio高的边 + + Args: + flows_data: list of flow objects with 'source', 'target', 'flow_metrics' keys + + Returns: + list of flows without cycles + """ + if not flows_data: + return flows_data + + # 按flow_ratio降序排序,优先保留重要的边 + sorted_flows = sorted( + flows_data, + key=lambda x: x.get('flow_metrics', {}).get('flow_ratio', 0) or 0, + reverse=True + ) + + # 构建图的邻接表和入度表 + def build_graph(flows): + graph = {} # node -> list of successors + in_degree = {} # node -> in-degree count + all_nodes = set() + + for flow in flows: + source = flow['source']['node_name'] + target = flow['target']['node_name'] + all_nodes.add(source) + all_nodes.add(target) + + if source not in graph: + graph[source] = [] + graph[source].append(target) + + if target not in in_degree: + in_degree[target] = 0 + in_degree[target] += 1 + + if source not in in_degree: + in_degree[source] = 0 + + return graph, in_degree, all_nodes + + # 使用Kahn算法检测是否有环 + def has_cycle(graph, in_degree, all_nodes): + # 找到所有入度为0的节点 + queue = [node for node in all_nodes if in_degree.get(node, 0) == 0] + visited_count = 0 + + while queue: + node = queue.pop(0) + visited_count += 1 + + # 访问所有邻居 + for neighbor in graph.get(node, []): + in_degree[neighbor] -= 1 + if in_degree[neighbor] == 0: + queue.append(neighbor) + + # 如果访问的节点数等于总节点数,说明没有环 + return visited_count < len(all_nodes) + + # 逐个添加边,如果添加后产生环则跳过 + result_flows = [] + + for flow in sorted_flows: + # 尝试添加这条边 + temp_flows = result_flows + [flow] + + # 检查是否产生环 + graph, in_degree, all_nodes = build_graph(temp_flows) + + # 复制in_degree用于检测(因为检测过程会修改它) + in_degree_copy = in_degree.copy() + + if not has_cycle(graph, in_degree_copy, all_nodes): + # 没有产生环,可以添加 + result_flows.append(flow) + else: + # 产生环,跳过这条边 + print(f"Skipping edge that creates cycle: {flow['source']['node_name']} -> {flow['target']['node_name']}") + + removed_count = len(flows_data) - len(result_flows) + if removed_count > 0: + print(f"Removed {removed_count} edges to eliminate cycles in Sankey diagram") + + return result_flows + + +def get_report_type(date_str): + """获取报告期类型""" + if not date_str: + return '' + if isinstance(date_str, str): + date = datetime.strptime(date_str, '%Y-%m-%d') + else: + date = date_str + + month = date.month + year = date.year + + if month == 3: + return f"{year}年一季报" + elif month == 6: + return f"{year}年中报" + elif month == 9: + return f"{year}年三季报" + elif month == 12: + return f"{year}年年报" + else: + return str(date_str) + + +@app.route('/api/financial/stock-info/', methods=['GET']) +def get_stock_info(seccode): + """获取股票基本信息和最新财务摘要""" + try: + # 获取最新的财务数据 + query = text(""" + SELECT distinct a.SECCODE, + a.SECNAME, + a.ENDDATE, + a.F003N as eps, + a.F004N as basic_eps, + a.F005N as diluted_eps, + a.F006N as deducted_eps, + a.F007N as undistributed_profit_ps, + a.F008N as bvps, + a.F010N as capital_reserve_ps, + a.F014N as roe, + a.F067N as roe_weighted, + a.F016N as roa, + a.F078N as gross_margin, + a.F017N as net_margin, + a.F089N as revenue, + a.F101N as net_profit, + a.F102N as parent_net_profit, + a.F118N as total_assets, + a.F121N as total_liabilities, + a.F128N as total_equity, + a.F052N as revenue_growth, + a.F053N as profit_growth, + a.F054N as equity_growth, + a.F056N as asset_growth, + a.F122N as share_capital + FROM ea_financialindex a + WHERE a.SECCODE = :seccode + ORDER BY a.ENDDATE DESC LIMIT 1 + """) + + result = engine.execute(query, seccode=seccode).fetchone() + + if not result: + return jsonify({ + 'success': False, + 'message': f'未找到股票代码 {seccode} 的财务数据' + }), 404 + + # 获取最近的业绩预告 + forecast_query = text(""" + SELECT distinct F001D as report_date, + F003V as forecast_type, + F004V as content, + F007N as profit_lower, + F008N as profit_upper, + F009N as change_lower, + F010N as change_upper + FROM ea_forecast + WHERE SECCODE = :seccode + AND F006C = 'T' + ORDER BY F001D DESC LIMIT 1 + """) + + forecast_result = engine.execute(forecast_query, seccode=seccode).fetchone() + + data = { + 'stock_code': result.SECCODE, + 'stock_name': result.SECNAME, + 'latest_period': format_date(result.ENDDATE), + 'report_type': get_report_type(result.ENDDATE), + 'key_metrics': { + 'eps': format_decimal(result.eps), + 'basic_eps': format_decimal(result.basic_eps), + 'diluted_eps': format_decimal(result.diluted_eps), + 'deducted_eps': format_decimal(result.deducted_eps), + 'bvps': format_decimal(result.bvps), + 'roe': format_decimal(result.roe), + 'roe_weighted': format_decimal(result.roe_weighted), + 'roa': format_decimal(result.roa), + 'gross_margin': format_decimal(result.gross_margin), + 'net_margin': format_decimal(result.net_margin), + }, + 'financial_summary': { + 'revenue': format_decimal(result.revenue), + 'net_profit': format_decimal(result.net_profit), + 'parent_net_profit': format_decimal(result.parent_net_profit), + 'total_assets': format_decimal(result.total_assets), + 'total_liabilities': format_decimal(result.total_liabilities), + 'total_equity': format_decimal(result.total_equity), + 'share_capital': format_decimal(result.share_capital), + }, + 'growth_rates': { + 'revenue_growth': format_decimal(result.revenue_growth), + 'profit_growth': format_decimal(result.profit_growth), + 'equity_growth': format_decimal(result.equity_growth), + 'asset_growth': format_decimal(result.asset_growth), + } + } + + # 添加业绩预告信息 + if forecast_result: + data['latest_forecast'] = { + 'report_date': format_date(forecast_result.report_date), + 'forecast_type': forecast_result.forecast_type, + 'content': forecast_result.content, + 'profit_range': { + 'lower': format_decimal(forecast_result.profit_lower), + 'upper': format_decimal(forecast_result.profit_upper), + }, + 'change_range': { + 'lower': format_decimal(forecast_result.change_lower), + 'upper': format_decimal(forecast_result.change_upper), + } + } + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/balance-sheet/', methods=['GET']) +def get_balance_sheet(seccode): + """获取完整的资产负债表数据""" + try: + limit = request.args.get('limit', 12, type=int) + + query = text(""" + SELECT distinct ENDDATE, + DECLAREDATE, + -- 流动资产 + F006N as cash, -- 货币资金 + F007N as trading_financial_assets, -- 交易性金融资产 + F008N as notes_receivable, -- 应收票据 + F009N as accounts_receivable, -- 应收账款 + F010N as prepayments, -- 预付款项 + F011N as other_receivables, -- 其他应收款 + F013N as interest_receivable, -- 应收利息 + F014N as dividends_receivable, -- 应收股利 + F015N as inventory, -- 存货 + F016N as consumable_biological_assets, -- 消耗性生物资产 + F017N as non_current_assets_due_within_one_year, -- 一年内到期的非流动资产 + F018N as other_current_assets, -- 其他流动资产 + F019N as total_current_assets, -- 流动资产合计 + + -- 非流动资产 + F020N as available_for_sale_financial_assets, -- 可供出售金融资产 + F021N as held_to_maturity_investments, -- 持有至到期投资 + F022N as long_term_receivables, -- 长期应收款 + F023N as long_term_equity_investments, -- 长期股权投资 + F024N as investment_property, -- 投资性房地产 + F025N as fixed_assets, -- 固定资产 + F026N as construction_in_progress, -- 在建工程 + F027N as engineering_materials, -- 工程物资 + F029N as productive_biological_assets, -- 生产性生物资产 + F030N as oil_and_gas_assets, -- 油气资产 + F031N as intangible_assets, -- 无形资产 + F032N as development_expenditure, -- 开发支出 + F033N as goodwill, -- 商誉 + F034N as long_term_deferred_expenses, -- 长期待摊费用 + F035N as deferred_tax_assets, -- 递延所得税资产 + F036N as other_non_current_assets, -- 其他非流动资产 + F037N as total_non_current_assets, -- 非流动资产合计 + F038N as total_assets, -- 资产总计 + + -- 流动负债 + F039N as short_term_borrowings, -- 短期借款 + F040N as trading_financial_liabilities, -- 交易性金融负债 + F041N as notes_payable, -- 应付票据 + F042N as accounts_payable, -- 应付账款 + F043N as advance_receipts, -- 预收款项 + F044N as employee_compensation_payable, -- 应付职工薪酬 + F045N as taxes_payable, -- 应交税费 + F046N as interest_payable, -- 应付利息 + F047N as dividends_payable, -- 应付股利 + F048N as other_payables, -- 其他应付款 + F050N as non_current_liabilities_due_within_one_year, -- 一年内到期的非流动负债 + F051N as other_current_liabilities, -- 其他流动负债 + F052N as total_current_liabilities, -- 流动负债合计 + + -- 非流动负债 + F053N as long_term_borrowings, -- 长期借款 + F054N as bonds_payable, -- 应付债券 + F055N as long_term_payables, -- 长期应付款 + F056N as special_payables, -- 专项应付款 + F057N as estimated_liabilities, -- 预计负债 + F058N as deferred_tax_liabilities, -- 递延所得税负债 + F059N as other_non_current_liabilities, -- 其他非流动负债 + F060N as total_non_current_liabilities, -- 非流动负债合计 + F061N as total_liabilities, -- 负债合计 + + -- 所有者权益 + F062N as share_capital, -- 股本 + F063N as capital_reserve, -- 资本公积 + F064N as surplus_reserve, -- 盈余公积 + F065N as undistributed_profit, -- 未分配利润 + F066N as treasury_stock, -- 库存股 + F067N as minority_interests, -- 少数股东权益 + F070N as total_equity, -- 所有者权益合计 + F071N as total_liabilities_and_equity, -- 负债和所有者权益合计 + F073N as parent_company_equity, -- 归属于母公司所有者权益 + F074N as other_comprehensive_income, -- 其他综合收益 + + -- 新会计准则科目 + F110N as other_debt_investments, -- 其他债权投资 + F111N as other_equity_investments, -- 其他权益工具投资 + F112N as other_non_current_financial_assets, -- 其他非流动金融资产 + F115N as contract_liabilities, -- 合同负债 + F119N as contract_assets, -- 合同资产 + F120N as receivables_financing, -- 应收款项融资 + F121N as right_of_use_assets, -- 使用权资产 + F122N as lease_liabilities -- 租赁负债 + FROM ea_asset + WHERE SECCODE = :seccode + and F002V = '071001' + ORDER BY ENDDATE DESC LIMIT :limit + """) + + result = engine.execute(query, seccode=seccode, limit=limit) + data = [] + + for row in result: + # 安全计算关键比率,避免 Decimal 与 None 运算错误 + def to_float(v): + try: + return float(v) if v is not None else None + except Exception: + return None + + ta = to_float(row.total_assets) + tl = to_float(row.total_liabilities) + tca = to_float(row.total_current_assets) + tcl = to_float(row.total_current_liabilities) + inv = to_float(row.inventory) or 0.0 + + asset_liability_ratio_val = None + if ta is not None and ta != 0 and tl is not None: + asset_liability_ratio_val = (tl / ta) * 100 + + current_ratio_val = None + if tcl is not None and tcl != 0 and tca is not None: + current_ratio_val = tca / tcl + + quick_ratio_val = None + if tcl is not None and tcl != 0 and tca is not None: + quick_ratio_val = (tca - inv) / tcl + + period_data = { + 'period': format_date(row.ENDDATE), + 'declare_date': format_date(row.DECLAREDATE), + 'report_type': get_report_type(row.ENDDATE), + + # 资产部分 + 'assets': { + 'current_assets': { + 'cash': format_decimal(row.cash), + 'trading_financial_assets': format_decimal(row.trading_financial_assets), + 'notes_receivable': format_decimal(row.notes_receivable), + 'accounts_receivable': format_decimal(row.accounts_receivable), + 'prepayments': format_decimal(row.prepayments), + 'other_receivables': format_decimal(row.other_receivables), + 'inventory': format_decimal(row.inventory), + 'contract_assets': format_decimal(row.contract_assets), + 'other_current_assets': format_decimal(row.other_current_assets), + 'total': format_decimal(row.total_current_assets), + }, + 'non_current_assets': { + 'long_term_equity_investments': format_decimal(row.long_term_equity_investments), + 'investment_property': format_decimal(row.investment_property), + 'fixed_assets': format_decimal(row.fixed_assets), + 'construction_in_progress': format_decimal(row.construction_in_progress), + 'intangible_assets': format_decimal(row.intangible_assets), + 'goodwill': format_decimal(row.goodwill), + 'right_of_use_assets': format_decimal(row.right_of_use_assets), + 'deferred_tax_assets': format_decimal(row.deferred_tax_assets), + 'other_non_current_assets': format_decimal(row.other_non_current_assets), + 'total': format_decimal(row.total_non_current_assets), + }, + 'total': format_decimal(row.total_assets), + }, + + # 负债部分 + 'liabilities': { + 'current_liabilities': { + 'short_term_borrowings': format_decimal(row.short_term_borrowings), + 'notes_payable': format_decimal(row.notes_payable), + 'accounts_payable': format_decimal(row.accounts_payable), + 'advance_receipts': format_decimal(row.advance_receipts), + 'contract_liabilities': format_decimal(row.contract_liabilities), + 'employee_compensation_payable': format_decimal(row.employee_compensation_payable), + 'taxes_payable': format_decimal(row.taxes_payable), + 'other_payables': format_decimal(row.other_payables), + 'non_current_liabilities_due_within_one_year': format_decimal( + row.non_current_liabilities_due_within_one_year), + 'total': format_decimal(row.total_current_liabilities), + }, + 'non_current_liabilities': { + 'long_term_borrowings': format_decimal(row.long_term_borrowings), + 'bonds_payable': format_decimal(row.bonds_payable), + 'lease_liabilities': format_decimal(row.lease_liabilities), + 'deferred_tax_liabilities': format_decimal(row.deferred_tax_liabilities), + 'other_non_current_liabilities': format_decimal(row.other_non_current_liabilities), + 'total': format_decimal(row.total_non_current_liabilities), + }, + 'total': format_decimal(row.total_liabilities), + }, + + # 股东权益部分 + 'equity': { + 'share_capital': format_decimal(row.share_capital), + 'capital_reserve': format_decimal(row.capital_reserve), + 'surplus_reserve': format_decimal(row.surplus_reserve), + 'undistributed_profit': format_decimal(row.undistributed_profit), + 'treasury_stock': format_decimal(row.treasury_stock), + 'other_comprehensive_income': format_decimal(row.other_comprehensive_income), + 'parent_company_equity': format_decimal(row.parent_company_equity), + 'minority_interests': format_decimal(row.minority_interests), + 'total': format_decimal(row.total_equity), + }, + + # 关键比率 + 'key_ratios': { + 'asset_liability_ratio': format_decimal(asset_liability_ratio_val), + 'current_ratio': format_decimal(current_ratio_val), + 'quick_ratio': format_decimal(quick_ratio_val), + } + } + data.append(period_data) + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/income-statement/', methods=['GET']) +def get_income_statement(seccode): + """获取完整的利润表数据""" + try: + limit = request.args.get('limit', 12, type=int) + + query = text(""" + SELECT distinct ENDDATE, + STARTDATE, + DECLAREDATE, + -- 营业收入部分 + F006N as revenue, -- 营业收入 + F035N as total_operating_revenue, -- 营业总收入 + F051N as other_income, -- 其他收入 + + -- 营业成本部分 + F007N as cost, -- 营业成本 + F008N as taxes_and_surcharges, -- 税金及附加 + F009N as selling_expenses, -- 销售费用 + F010N as admin_expenses, -- 管理费用 + F056N as rd_expenses, -- 研发费用 + F012N as financial_expenses, -- 财务费用 + F062N as interest_expense, -- 利息费用 + F063N as interest_income, -- 利息收入 + F013N as asset_impairment_loss, -- 资产减值损失(营业总成本) + F057N as credit_impairment_loss, -- 信用减值损失(营业总成本) + F036N as total_operating_cost, -- 营业总成本 + + -- 其他收益 + F014N as fair_value_change_income, -- 公允价值变动净收益 + F015N as investment_income, -- 投资收益 + F016N as investment_income_from_associates, -- 对联营企业和合营企业的投资收益 + F037N as exchange_income, -- 汇兑收益 + F058N as net_exposure_hedging_income, -- 净敞口套期收益 + F059N as asset_disposal_income, -- 资产处置收益 + + -- 利润部分 + F018N as operating_profit, -- 营业利润 + F019N as subsidy_income, -- 补贴收入 + F020N as non_operating_income, -- 营业外收入 + F021N as non_operating_expenses, -- 营业外支出 + F022N as non_current_asset_disposal_loss, -- 非流动资产处置损失 + F024N as total_profit, -- 利润总额 + F025N as income_tax_expense, -- 所得税 + F027N as net_profit, -- 净利润 + F028N as parent_net_profit, -- 归属于母公司所有者的净利润 + F029N as minority_profit, -- 少数股东损益 + + -- 持续经营 + F060N as continuing_operations_net_profit, -- 持续经营净利润 + F061N as discontinued_operations_net_profit, -- 终止经营净利润 + + -- 每股收益 + F031N as basic_eps, -- 基本每股收益 + F032N as diluted_eps, -- 稀释每股收益 + + -- 综合收益 + F038N as other_comprehensive_income_after_tax, -- 其他综合收益的税后净额 + F039N as total_comprehensive_income, -- 综合收益总额 + F040N as parent_company_comprehensive_income, -- 归属于母公司的综合收益 + F041N as minority_comprehensive_income -- 归属于少数股东的综合收益 + FROM ea_profit + WHERE SECCODE = :seccode + and F002V = '071001' + ORDER BY ENDDATE DESC LIMIT :limit + """) + + result = engine.execute(query, seccode=seccode, limit=limit) + data = [] + + for row in result: + # 计算一些衍生指标 + gross_profit = (row.revenue - row.cost) if row.revenue and row.cost else None + gross_margin = (gross_profit / row.revenue * 100) if row.revenue and gross_profit else None + operating_margin = ( + row.operating_profit / row.revenue * 100) if row.revenue and row.operating_profit else None + net_margin = (row.net_profit / row.revenue * 100) if row.revenue and row.net_profit else None + + # 三费合计 + three_expenses = 0 + if row.selling_expenses: + three_expenses += row.selling_expenses + if row.admin_expenses: + three_expenses += row.admin_expenses + if row.financial_expenses: + three_expenses += row.financial_expenses + + # 四费合计(加研发) + four_expenses = three_expenses + if row.rd_expenses: + four_expenses += row.rd_expenses + + period_data = { + 'period': format_date(row.ENDDATE), + 'start_date': format_date(row.STARTDATE), + 'declare_date': format_date(row.DECLAREDATE), + 'report_type': get_report_type(row.ENDDATE), + + # 收入部分 + 'revenue': { + 'operating_revenue': format_decimal(row.revenue), + 'total_operating_revenue': format_decimal(row.total_operating_revenue), + 'other_income': format_decimal(row.other_income), + }, + + # 成本费用部分 + 'costs': { + 'operating_cost': format_decimal(row.cost), + 'taxes_and_surcharges': format_decimal(row.taxes_and_surcharges), + 'selling_expenses': format_decimal(row.selling_expenses), + 'admin_expenses': format_decimal(row.admin_expenses), + 'rd_expenses': format_decimal(row.rd_expenses), + 'financial_expenses': format_decimal(row.financial_expenses), + 'interest_expense': format_decimal(row.interest_expense), + 'interest_income': format_decimal(row.interest_income), + 'asset_impairment_loss': format_decimal(row.asset_impairment_loss), + 'credit_impairment_loss': format_decimal(row.credit_impairment_loss), + 'total_operating_cost': format_decimal(row.total_operating_cost), + 'three_expenses_total': format_decimal(three_expenses), + 'four_expenses_total': format_decimal(four_expenses), + }, + + # 其他收益 + 'other_gains': { + 'fair_value_change': format_decimal(row.fair_value_change_income), + 'investment_income': format_decimal(row.investment_income), + 'investment_income_from_associates': format_decimal(row.investment_income_from_associates), + 'exchange_income': format_decimal(row.exchange_income), + 'asset_disposal_income': format_decimal(row.asset_disposal_income), + }, + + # 利润 + 'profit': { + 'gross_profit': format_decimal(gross_profit), + 'operating_profit': format_decimal(row.operating_profit), + 'total_profit': format_decimal(row.total_profit), + 'net_profit': format_decimal(row.net_profit), + 'parent_net_profit': format_decimal(row.parent_net_profit), + 'minority_profit': format_decimal(row.minority_profit), + 'continuing_operations_net_profit': format_decimal(row.continuing_operations_net_profit), + 'discontinued_operations_net_profit': format_decimal(row.discontinued_operations_net_profit), + }, + + # 非经营项目 + 'non_operating': { + 'subsidy_income': format_decimal(row.subsidy_income), + 'non_operating_income': format_decimal(row.non_operating_income), + 'non_operating_expenses': format_decimal(row.non_operating_expenses), + }, + + # 每股收益 + 'per_share': { + 'basic_eps': format_decimal(row.basic_eps), + 'diluted_eps': format_decimal(row.diluted_eps), + }, + + # 综合收益 + 'comprehensive_income': { + 'other_comprehensive_income': format_decimal(row.other_comprehensive_income_after_tax), + 'total_comprehensive_income': format_decimal(row.total_comprehensive_income), + 'parent_comprehensive_income': format_decimal(row.parent_company_comprehensive_income), + 'minority_comprehensive_income': format_decimal(row.minority_comprehensive_income), + }, + + # 关键比率 + 'margins': { + 'gross_margin': format_decimal(gross_margin), + 'operating_margin': format_decimal(operating_margin), + 'net_margin': format_decimal(net_margin), + 'expense_ratio': format_decimal(four_expenses / row.revenue * 100) if row.revenue else None, + 'rd_ratio': format_decimal( + row.rd_expenses / row.revenue * 100) if row.revenue and row.rd_expenses else None, + } + } + data.append(period_data) + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/cashflow/', methods=['GET']) +def get_cashflow(seccode): + """获取完整的现金流量表数据""" + try: + limit = request.args.get('limit', 12, type=int) + + query = text(""" + SELECT distinct ENDDATE, + STARTDATE, + DECLAREDATE, + -- 经营活动现金流 + F006N as cash_from_sales, -- 销售商品、提供劳务收到的现金 + F007N as tax_refunds, -- 收到的税费返还 + F008N as other_operating_cash_received, -- 收到其他与经营活动有关的现金 + F009N as total_operating_cash_inflow, -- 经营活动现金流入小计 + F010N as cash_paid_for_goods, -- 购买商品、接受劳务支付的现金 + F011N as cash_paid_to_employees, -- 支付给职工以及为职工支付的现金 + F012N as taxes_paid, -- 支付的各项税费 + F013N as other_operating_cash_paid, -- 支付其他与经营活动有关的现金 + F014N as total_operating_cash_outflow, -- 经营活动现金流出小计 + F015N as net_operating_cash_flow, -- 经营活动产生的现金流量净额 + + -- 投资活动现金流 + F016N as cash_from_investment_recovery, -- 收回投资收到的现金 + F017N as cash_from_investment_income, -- 取得投资收益收到的现金 + F018N as cash_from_asset_disposal, -- 处置固定资产、无形资产和其他长期资产收回的现金净额 + F019N as cash_from_subsidiary_disposal, -- 处置子公司及其他营业单位收到的现金净额 + F020N as other_investment_cash_received, -- 收到其他与投资活动有关的现金 + F021N as total_investment_cash_inflow, -- 投资活动现金流入小计 + F022N as cash_paid_for_assets, -- 购建固定资产、无形资产和其他长期资产支付的现金 + F023N as cash_paid_for_investments, -- 投资支付的现金 + F024N as cash_paid_for_subsidiaries, -- 取得子公司及其他营业单位支付的现金净额 + F025N as other_investment_cash_paid, -- 支付其他与投资活动有关的现金 + F026N as total_investment_cash_outflow, -- 投资活动现金流出小计 + F027N as net_investment_cash_flow, -- 投资活动产生的现金流量净额 + + -- 筹资活动现金流 + F028N as cash_from_capital, -- 吸收投资收到的现金 + F029N as cash_from_borrowings, -- 取得借款收到的现金 + F030N as other_financing_cash_received, -- 收到其他与筹资活动有关的现金 + F031N as total_financing_cash_inflow, -- 筹资活动现金流入小计 + F032N as cash_paid_for_debt, -- 偿还债务支付的现金 + F033N as cash_paid_for_distribution, -- 分配股利、利润或偿付利息支付的现金 + F034N as other_financing_cash_paid, -- 支付其他与筹资活动有关的现金 + F035N as total_financing_cash_outflow, -- 筹资活动现金流出小计 + F036N as net_financing_cash_flow, -- 筹资活动产生的现金流量净额 + + -- 汇率变动影响 + F037N as exchange_rate_effect, -- 汇率变动对现金及现金等价物的影响 + F038N as other_cash_effect, -- 其他原因对现金的影响 + + -- 现金净增加额 + F039N as net_cash_increase, -- 现金及现金等价物净增加额 + F040N as beginning_cash_balance, -- 期初现金及现金等价物余额 + F041N as ending_cash_balance, -- 期末现金及现金等价物余额 + + -- 补充资料部分 + F044N as net_profit, -- 净利润 + F045N as asset_impairment, -- 资产减值准备 + F096N as credit_impairment, -- 信用减值损失 + F046N as depreciation, -- 固定资产折旧、油气资产折耗、生产性生物资产折旧 + F097N as right_of_use_asset_depreciation, -- 使用权资产折旧/摊销 + F047N as intangible_amortization, -- 无形资产摊销 + F048N as long_term_expense_amortization, -- 长期待摊费用摊销 + F049N as loss_on_disposal, -- 处置固定资产、无形资产和其他长期资产的损失 + F050N as fixed_asset_scrap_loss, -- 固定资产报废损失 + F051N as fair_value_change_loss, -- 公允价值变动损失 + F052N as financial_expenses, -- 财务费用 + F053N as investment_loss, -- 投资损失 + F054N as deferred_tax_asset_decrease, -- 递延所得税资产减少 + F055N as deferred_tax_liability_increase, -- 递延所得税负债增加 + F056N as inventory_decrease, -- 存货的减少 + F057N as operating_receivables_decrease, -- 经营性应收项目的减少 + F058N as operating_payables_increase, -- 经营性应付项目的增加 + F059N as other, -- 其他 + F060N as net_operating_cash_flow_indirect, -- 经营活动产生的现金流量净额(间接法) + + -- 特殊行业科目(金融) + F072N as customer_deposit_increase, -- 客户存款和同业存放款项净增加额 + F073N as central_bank_borrowing_increase, -- 向中央银行借款净增加额 + F081N as interest_and_commission_received, -- 收取利息、手续费及佣金的现金 + F087N as interest_and_commission_paid -- 支付利息、手续费及佣金的现金 + FROM ea_cashflow + WHERE SECCODE = :seccode + and F002V = '071001' + ORDER BY ENDDATE DESC LIMIT :limit + """) + + result = engine.execute(query, seccode=seccode, limit=limit) + data = [] + + for row in result: + # 计算一些衍生指标 + free_cash_flow = None + if row.net_operating_cash_flow and row.cash_paid_for_assets: + free_cash_flow = row.net_operating_cash_flow - row.cash_paid_for_assets + + period_data = { + 'period': format_date(row.ENDDATE), + 'start_date': format_date(row.STARTDATE), + 'declare_date': format_date(row.DECLAREDATE), + 'report_type': get_report_type(row.ENDDATE), + + # 经营活动现金流 + 'operating_activities': { + 'inflow': { + 'cash_from_sales': format_decimal(row.cash_from_sales), + 'tax_refunds': format_decimal(row.tax_refunds), + 'other': format_decimal(row.other_operating_cash_received), + 'total': format_decimal(row.total_operating_cash_inflow), + }, + 'outflow': { + 'cash_for_goods': format_decimal(row.cash_paid_for_goods), + 'cash_for_employees': format_decimal(row.cash_paid_to_employees), + 'taxes_paid': format_decimal(row.taxes_paid), + 'other': format_decimal(row.other_operating_cash_paid), + 'total': format_decimal(row.total_operating_cash_outflow), + }, + 'net_flow': format_decimal(row.net_operating_cash_flow), + }, + + # 投资活动现金流 + 'investment_activities': { + 'inflow': { + 'investment_recovery': format_decimal(row.cash_from_investment_recovery), + 'investment_income': format_decimal(row.cash_from_investment_income), + 'asset_disposal': format_decimal(row.cash_from_asset_disposal), + 'subsidiary_disposal': format_decimal(row.cash_from_subsidiary_disposal), + 'other': format_decimal(row.other_investment_cash_received), + 'total': format_decimal(row.total_investment_cash_inflow), + }, + 'outflow': { + 'asset_purchase': format_decimal(row.cash_paid_for_assets), + 'investments': format_decimal(row.cash_paid_for_investments), + 'subsidiaries': format_decimal(row.cash_paid_for_subsidiaries), + 'other': format_decimal(row.other_investment_cash_paid), + 'total': format_decimal(row.total_investment_cash_outflow), + }, + 'net_flow': format_decimal(row.net_investment_cash_flow), + }, + + # 筹资活动现金流 + 'financing_activities': { + 'inflow': { + 'capital': format_decimal(row.cash_from_capital), + 'borrowings': format_decimal(row.cash_from_borrowings), + 'other': format_decimal(row.other_financing_cash_received), + 'total': format_decimal(row.total_financing_cash_inflow), + }, + 'outflow': { + 'debt_repayment': format_decimal(row.cash_paid_for_debt), + 'distribution': format_decimal(row.cash_paid_for_distribution), + 'other': format_decimal(row.other_financing_cash_paid), + 'total': format_decimal(row.total_financing_cash_outflow), + }, + 'net_flow': format_decimal(row.net_financing_cash_flow), + }, + + # 现金变动 + 'cash_changes': { + 'exchange_rate_effect': format_decimal(row.exchange_rate_effect), + 'other_effect': format_decimal(row.other_cash_effect), + 'net_increase': format_decimal(row.net_cash_increase), + 'beginning_balance': format_decimal(row.beginning_cash_balance), + 'ending_balance': format_decimal(row.ending_cash_balance), + }, + + # 补充资料(间接法) + 'indirect_method': { + 'net_profit': format_decimal(row.net_profit), + 'adjustments': { + 'asset_impairment': format_decimal(row.asset_impairment), + 'credit_impairment': format_decimal(row.credit_impairment), + 'depreciation': format_decimal(row.depreciation), + 'intangible_amortization': format_decimal(row.intangible_amortization), + 'financial_expenses': format_decimal(row.financial_expenses), + 'investment_loss': format_decimal(row.investment_loss), + 'inventory_decrease': format_decimal(row.inventory_decrease), + 'receivables_decrease': format_decimal(row.operating_receivables_decrease), + 'payables_increase': format_decimal(row.operating_payables_increase), + }, + 'net_operating_cash_flow': format_decimal(row.net_operating_cash_flow_indirect), + }, + + # 关键指标 + 'key_metrics': { + 'free_cash_flow': format_decimal(free_cash_flow), + 'cash_flow_to_profit_ratio': format_decimal( + row.net_operating_cash_flow / row.net_profit) if row.net_profit and row.net_operating_cash_flow else None, + 'capex': format_decimal(row.cash_paid_for_assets), + } + } + data.append(period_data) + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/financial-metrics/', methods=['GET']) +def get_financial_metrics(seccode): + """获取完整的财务指标数据""" + try: + limit = request.args.get('limit', 12, type=int) + + query = text(""" + SELECT distinct ENDDATE, + STARTDATE, + -- 每股指标 + F003N as eps, -- 每股收益 + F004N as basic_eps, -- 基本每股收益 + F005N as diluted_eps, -- 稀释每股收益 + F006N as deducted_eps, -- 扣除非经常性损益每股收益 + F007N as undistributed_profit_ps, -- 每股未分配利润 + F008N as bvps, -- 每股净资产 + F009N as adjusted_bvps, -- 调整后每股净资产 + F010N as capital_reserve_ps, -- 每股资本公积金 + F059N as cash_flow_ps, -- 每股现金流量 + F060N as operating_cash_flow_ps, -- 每股经营现金流量 + + -- 盈利能力指标 + F011N as operating_profit_margin, -- 营业利润率 + F012N as tax_rate, -- 营业税金率 + F013N as cost_ratio, -- 营业成本率 + F014N as roe, -- 净资产收益率 + F066N as roe_deducted, -- 净资产收益率(扣除非经常性损益) + F067N as roe_weighted, -- 净资产收益率-加权 + F068N as roe_weighted_deducted, -- 净资产收益率-加权(扣除非经常性损益) + F015N as investment_return, -- 投资收益率 + F016N as roa, -- 总资产报酬率 + F017N as net_profit_margin, -- 净利润率 + F078N as gross_margin, -- 毛利率 + F020N as cost_profit_ratio, -- 成本费用利润率 + + -- 费用率指标 + F018N as admin_expense_ratio, -- 管理费用率 + F019N as financial_expense_ratio, -- 财务费用率 + F021N as three_expense_ratio, -- 三费比重 + F091N as selling_expense, -- 销售费用 + F092N as admin_expense, -- 管理费用 + F093N as financial_expense, -- 财务费用 + F094N as three_expense_total, -- 三费合计 + F130N as rd_expense, -- 研发费用 + F131N as rd_expense_ratio, -- 研发费用率 + F132N as selling_expense_ratio, -- 销售费用率 + F133N as four_expense_ratio, -- 四费费用率 + + -- 运营能力指标 + F022N as receivable_turnover, -- 应收账款周转率 + F023N as inventory_turnover, -- 存货周转率 + F024N as working_capital_turnover, -- 运营资金周转率 + F025N as total_asset_turnover, -- 总资产周转率 + F026N as fixed_asset_turnover, -- 固定资产周转率 + F027N as receivable_days, -- 应收账款周转天数 + F028N as inventory_days, -- 存货周转天数 + F029N as current_asset_turnover, -- 流动资产周转率 + F030N as current_asset_days, -- 流动资产周转天数 + F031N as total_asset_days, -- 总资产周转天数 + F032N as equity_turnover, -- 股东权益周转率 + + -- 偿债能力指标 + F041N as asset_liability_ratio, -- 资产负债率 + F042N as current_ratio, -- 流动比率 + F043N as quick_ratio, -- 速动比率 + F044N as cash_ratio, -- 现金比率 + F045N as interest_coverage, -- 利息保障倍数 + F049N as conservative_quick_ratio, -- 保守速动比率 + F050N as cash_to_maturity_debt_ratio, -- 现金到期债务比率 + F051N as tangible_asset_debt_ratio, -- 有形资产净值债务率 + + -- 成长能力指标 + F052N as revenue_growth, -- 营业收入增长率 + F053N as net_profit_growth, -- 净利润增长率 + F054N as equity_growth, -- 净资产增长率 + F055N as fixed_asset_growth, -- 固定资产增长率 + F056N as total_asset_growth, -- 总资产增长率 + F057N as investment_income_growth, -- 投资收益增长率 + F058N as operating_profit_growth, -- 营业利润增长率 + F141N as deducted_profit_growth, -- 扣除非经常性损益后的净利润同比变化率 + F142N as parent_profit_growth, -- 归属于母公司所有者的净利润同比变化率 + F143N as operating_cash_flow_growth, -- 经营活动产生的现金流净额同比变化率 + + -- 现金流量指标 + F061N as operating_cash_to_short_debt, -- 经营净现金比率(短期债务) + F062N as operating_cash_to_total_debt, -- 经营净现金比率(全部债务) + F063N as operating_cash_to_profit_ratio, -- 经营活动现金净流量与净利润比率 + F064N as cash_revenue_ratio, -- 营业收入现金含量 + F065N as cash_recovery_rate, -- 全部资产现金回收率 + F082N as cash_to_profit_ratio, -- 净利含金量 + + -- 财务结构指标 + F033N as current_asset_ratio, -- 流动资产比率 + F034N as cash_ratio_structure, -- 货币资金比率 + F036N as inventory_ratio, -- 存货比率 + F037N as fixed_asset_ratio, -- 固定资产比率 + F038N as liability_structure_ratio, -- 负债结构比 + F039N as equity_ratio, -- 产权比率 + F040N as net_asset_ratio, -- 净资产比率 + F046N as working_capital, -- 营运资金 + F047N as non_current_liability_ratio, -- 非流动负债比率 + F048N as current_liability_ratio, -- 流动负债比率 + + -- 非经常性损益 + F076N as deducted_net_profit, -- 扣除非经常性损益后的净利润 + F077N as non_recurring_items, -- 非经常性损益合计 + F083N as non_recurring_ratio, -- 非经常性损益占比 + + -- 综合指标 + F085N as ebit, -- 基本获利能力(EBIT) + F086N as receivable_to_asset_ratio, -- 应收账款占比 + F087N as inventory_to_asset_ratio -- 存货占比 + FROM ea_financialindex + WHERE SECCODE = :seccode + ORDER BY ENDDATE DESC LIMIT :limit + """) + + result = engine.execute(query, seccode=seccode, limit=limit) + data = [] + + for row in result: + period_data = { + 'period': format_date(row.ENDDATE), + 'start_date': format_date(row.STARTDATE), + 'report_type': get_report_type(row.ENDDATE), + + # 每股指标 + 'per_share_metrics': { + 'eps': format_decimal(row.eps), + 'basic_eps': format_decimal(row.basic_eps), + 'diluted_eps': format_decimal(row.diluted_eps), + 'deducted_eps': format_decimal(row.deducted_eps), + 'bvps': format_decimal(row.bvps), + 'adjusted_bvps': format_decimal(row.adjusted_bvps), + 'undistributed_profit_ps': format_decimal(row.undistributed_profit_ps), + 'capital_reserve_ps': format_decimal(row.capital_reserve_ps), + 'cash_flow_ps': format_decimal(row.cash_flow_ps), + 'operating_cash_flow_ps': format_decimal(row.operating_cash_flow_ps), + }, + + # 盈利能力 + 'profitability': { + 'roe': format_decimal(row.roe), + 'roe_deducted': format_decimal(row.roe_deducted), + 'roe_weighted': format_decimal(row.roe_weighted), + 'roa': format_decimal(row.roa), + 'gross_margin': format_decimal(row.gross_margin), + 'net_profit_margin': format_decimal(row.net_profit_margin), + 'operating_profit_margin': format_decimal(row.operating_profit_margin), + 'cost_profit_ratio': format_decimal(row.cost_profit_ratio), + 'ebit': format_decimal(row.ebit), + }, + + # 费用率 + 'expense_ratios': { + 'selling_expense_ratio': format_decimal(row.selling_expense_ratio), + 'admin_expense_ratio': format_decimal(row.admin_expense_ratio), + 'financial_expense_ratio': format_decimal(row.financial_expense_ratio), + 'rd_expense_ratio': format_decimal(row.rd_expense_ratio), + 'three_expense_ratio': format_decimal(row.three_expense_ratio), + 'four_expense_ratio': format_decimal(row.four_expense_ratio), + }, + + # 运营能力 + 'operational_efficiency': { + 'receivable_turnover': format_decimal(row.receivable_turnover), + 'receivable_days': format_decimal(row.receivable_days), + 'inventory_turnover': format_decimal(row.inventory_turnover), + 'inventory_days': format_decimal(row.inventory_days), + 'total_asset_turnover': format_decimal(row.total_asset_turnover), + 'total_asset_days': format_decimal(row.total_asset_days), + 'fixed_asset_turnover': format_decimal(row.fixed_asset_turnover), + 'current_asset_turnover': format_decimal(row.current_asset_turnover), + 'working_capital_turnover': format_decimal(row.working_capital_turnover), + }, + + # 偿债能力 + 'solvency': { + 'current_ratio': format_decimal(row.current_ratio), + 'quick_ratio': format_decimal(row.quick_ratio), + 'cash_ratio': format_decimal(row.cash_ratio), + 'conservative_quick_ratio': format_decimal(row.conservative_quick_ratio), + 'asset_liability_ratio': format_decimal(row.asset_liability_ratio), + 'interest_coverage': format_decimal(row.interest_coverage), + 'cash_to_maturity_debt_ratio': format_decimal(row.cash_to_maturity_debt_ratio), + 'tangible_asset_debt_ratio': format_decimal(row.tangible_asset_debt_ratio), + }, + + # 成长能力 + 'growth': { + 'revenue_growth': format_decimal(row.revenue_growth), + 'net_profit_growth': format_decimal(row.net_profit_growth), + 'deducted_profit_growth': format_decimal(row.deducted_profit_growth), + 'parent_profit_growth': format_decimal(row.parent_profit_growth), + 'equity_growth': format_decimal(row.equity_growth), + 'total_asset_growth': format_decimal(row.total_asset_growth), + 'fixed_asset_growth': format_decimal(row.fixed_asset_growth), + 'operating_profit_growth': format_decimal(row.operating_profit_growth), + 'operating_cash_flow_growth': format_decimal(row.operating_cash_flow_growth), + }, + + # 现金流量 + 'cash_flow_quality': { + 'operating_cash_to_profit_ratio': format_decimal(row.operating_cash_to_profit_ratio), + 'cash_to_profit_ratio': format_decimal(row.cash_to_profit_ratio), + 'cash_revenue_ratio': format_decimal(row.cash_revenue_ratio), + 'cash_recovery_rate': format_decimal(row.cash_recovery_rate), + 'operating_cash_to_short_debt': format_decimal(row.operating_cash_to_short_debt), + 'operating_cash_to_total_debt': format_decimal(row.operating_cash_to_total_debt), + }, + + # 财务结构 + 'financial_structure': { + 'current_asset_ratio': format_decimal(row.current_asset_ratio), + 'fixed_asset_ratio': format_decimal(row.fixed_asset_ratio), + 'inventory_ratio': format_decimal(row.inventory_ratio), + 'receivable_to_asset_ratio': format_decimal(row.receivable_to_asset_ratio), + 'current_liability_ratio': format_decimal(row.current_liability_ratio), + 'non_current_liability_ratio': format_decimal(row.non_current_liability_ratio), + 'equity_ratio': format_decimal(row.equity_ratio), + }, + + # 非经常性损益 + 'non_recurring': { + 'deducted_net_profit': format_decimal(row.deducted_net_profit), + 'non_recurring_items': format_decimal(row.non_recurring_items), + 'non_recurring_ratio': format_decimal(row.non_recurring_ratio), + } + } + data.append(period_data) + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/main-business/', methods=['GET']) +def get_main_business(seccode): + """获取主营业务构成数据(包括产品和行业分类)""" + try: + limit = request.args.get('periods', 4, type=int) # 获取最近几期的数据 + + # 获取最近的报告期 + period_query = text(""" + SELECT DISTINCT ENDDATE + FROM ea_mainproduct + WHERE SECCODE = :seccode + ORDER BY ENDDATE DESC LIMIT :limit + """) + + periods = engine.execute(period_query, seccode=seccode, limit=limit).fetchall() + + # 产品分类数据 + product_data = [] + for period in periods: + query = text(""" + SELECT distinct ENDDATE, + F002V as category, + F003V as content, + F005N as revenue, + F006N as cost, + F007N as profit + FROM ea_mainproduct + WHERE SECCODE = :seccode + AND ENDDATE = :enddate + ORDER BY F005N DESC + """) + + result = engine.execute(query, seccode=seccode, enddate=period[0]) + # Convert result to list to allow multiple iterations + rows = list(result) + + period_products = [] + total_revenue = 0 + for row in rows: + if row.revenue: + total_revenue += row.revenue + + for row in rows: + product = { + 'category': row.category, + 'content': row.content, + 'revenue': format_decimal(row.revenue), + 'cost': format_decimal(row.cost), + 'profit': format_decimal(row.profit), + 'profit_margin': format_decimal( + (row.profit / row.revenue * 100) if row.revenue and row.profit else None), + 'revenue_ratio': format_decimal( + (row.revenue / total_revenue * 100) if total_revenue and row.revenue else None) + } + period_products.append(product) + + if period_products: + product_data.append({ + 'period': format_date(period[0]), + 'report_type': get_report_type(period[0]), + 'total_revenue': format_decimal(total_revenue), + 'products': period_products + }) + + # 行业分类数据(从ea_mainind表) + industry_data = [] + for period in periods: + query = text(""" + SELECT distinct ENDDATE, + F002V as business_content, + F007N as main_revenue, + F008N as main_cost, + F009N as main_profit, + F010N as gross_margin, + F012N as revenue_ratio + FROM ea_mainind + WHERE SECCODE = :seccode + AND ENDDATE = :enddate + ORDER BY F007N DESC + """) + + result = engine.execute(query, seccode=seccode, enddate=period[0]) + # Convert result to list to allow multiple iterations + rows = list(result) + + period_industries = [] + for row in rows: + industry = { + 'content': row.business_content, + 'revenue': format_decimal(row.main_revenue), + 'cost': format_decimal(row.main_cost), + 'profit': format_decimal(row.main_profit), + 'gross_margin': format_decimal(row.gross_margin), + 'revenue_ratio': format_decimal(row.revenue_ratio) + } + period_industries.append(industry) + + if period_industries: + industry_data.append({ + 'period': format_date(period[0]), + 'report_type': get_report_type(period[0]), + 'industries': period_industries + }) + + return jsonify({ + 'success': True, + 'data': { + 'product_classification': product_data, + 'industry_classification': industry_data + } + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/forecast/', methods=['GET']) +def get_forecast(seccode): + """获取业绩预告和预披露时间""" + try: + # 获取业绩预告 + forecast_query = text(""" + SELECT distinct DECLAREDATE, + F001D as report_date, + F002V as forecast_type_code, + F003V as forecast_type, + F004V as content, + F005V as reason, + F006C as latest_flag, + F007N as profit_lower, + F008N as profit_upper, + F009N as change_lower, + F010N as change_upper, + UPDATE_DATE + FROM ea_forecast + WHERE SECCODE = :seccode + ORDER BY F001D DESC, UPDATE_DATE DESC LIMIT 10 + """) + + forecast_result = engine.execute(forecast_query, seccode=seccode) + forecast_data = [] + + for row in forecast_result: + forecast = { + 'declare_date': format_date(row.DECLAREDATE), + 'report_date': format_date(row.report_date), + 'report_type': get_report_type(row.report_date), + 'forecast_type': row.forecast_type, + 'forecast_type_code': row.forecast_type_code, + 'content': row.content, + 'reason': row.reason, + 'is_latest': row.latest_flag == 'T', + 'profit_range': { + 'lower': format_decimal(row.profit_lower), + 'upper': format_decimal(row.profit_upper), + }, + 'change_range': { + 'lower': format_decimal(row.change_lower), + 'upper': format_decimal(row.change_upper), + }, + 'update_date': format_date(row.UPDATE_DATE) + } + forecast_data.append(forecast) + + # 获取预披露时间 + pretime_query = text(""" + SELECT distinct F001D as report_period, + F002D as scheduled_date, + F003D as change_date_1, + F004D as change_date_2, + F005D as change_date_3, + F006D as actual_date, + F007D as change_date_4, + F008D as change_date_5, + UPDATE_DATE + FROM ea_pretime + WHERE SECCODE = :seccode + ORDER BY F001D DESC LIMIT 8 + """) + + pretime_result = engine.execute(pretime_query, seccode=seccode) + pretime_data = [] + + for row in pretime_result: + # 收集所有变更日期 + change_dates = [] + for date in [row.change_date_1, row.change_date_2, row.change_date_3, + row.change_date_4, row.change_date_5]: + if date: + change_dates.append(format_date(date)) + + pretime = { + 'report_period': format_date(row.report_period), + 'report_type': get_report_type(row.report_period), + 'scheduled_date': format_date(row.scheduled_date), + 'actual_date': format_date(row.actual_date), + 'change_dates': change_dates, + 'update_date': format_date(row.UPDATE_DATE), + 'status': 'completed' if row.actual_date else 'pending' + } + pretime_data.append(pretime) + + return jsonify({ + 'success': True, + 'data': { + 'forecasts': forecast_data, + 'disclosure_schedule': pretime_data + } + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/industry-rank/', methods=['GET']) +def get_industry_rank(seccode): + """获取行业排名数据""" + try: + limit = request.args.get('limit', 4, type=int) + + query = text(""" + SELECT distinct F001V as industry_level, + F002V as level_description, + F003D as report_date, + INDNAME as industry_name, + -- 每股收益 + F004N as eps, + F005N as eps_industry_avg, + F006N as eps_rank, + -- 扣除后每股收益 + F007N as deducted_eps, + F008N as deducted_eps_industry_avg, + F009N as deducted_eps_rank, + -- 每股净资产 + F010N as bvps, + F011N as bvps_industry_avg, + F012N as bvps_rank, + -- 净资产收益率 + F013N as roe, + F014N as roe_industry_avg, + F015N as roe_rank, + -- 每股未分配利润 + F016N as undistributed_profit_ps, + F017N as undistributed_profit_ps_industry_avg, + F018N as undistributed_profit_ps_rank, + -- 每股经营现金流量 + F019N as operating_cash_flow_ps, + F020N as operating_cash_flow_ps_industry_avg, + F021N as operating_cash_flow_ps_rank, + -- 营业收入增长率 + F022N as revenue_growth, + F023N as revenue_growth_industry_avg, + F024N as revenue_growth_rank, + -- 净利润增长率 + F025N as profit_growth, + F026N as profit_growth_industry_avg, + F027N as profit_growth_rank, + -- 营业利润率 + F028N as operating_margin, + F029N as operating_margin_industry_avg, + F030N as operating_margin_rank, + -- 资产负债率 + F031N as debt_ratio, + F032N as debt_ratio_industry_avg, + F033N as debt_ratio_rank, + -- 应收账款周转率 + F034N as receivable_turnover, + F035N as receivable_turnover_industry_avg, + F036N as receivable_turnover_rank, + UPDATE_DATE + FROM ea_finindexrank + WHERE SECCODE = :seccode + ORDER BY F003D DESC, F001V ASC LIMIT :limit_total + """) + + # 获取多个报告期的数据 + result = engine.execute(query, seccode=seccode, limit_total=limit * 4) + + # 按报告期和行业级别组织数据 + data_by_period = {} + for row in result: + period = format_date(row.report_date) + if period not in data_by_period: + data_by_period[period] = [] + + rank_data = { + 'industry_level': row.industry_level, + 'level_description': row.level_description, + 'industry_name': row.industry_name, + 'metrics': { + 'eps': { + 'value': format_decimal(row.eps), + 'industry_avg': format_decimal(row.eps_industry_avg), + 'rank': int(row.eps_rank) if row.eps_rank else None + }, + 'deducted_eps': { + 'value': format_decimal(row.deducted_eps), + 'industry_avg': format_decimal(row.deducted_eps_industry_avg), + 'rank': int(row.deducted_eps_rank) if row.deducted_eps_rank else None + }, + 'bvps': { + 'value': format_decimal(row.bvps), + 'industry_avg': format_decimal(row.bvps_industry_avg), + 'rank': int(row.bvps_rank) if row.bvps_rank else None + }, + 'roe': { + 'value': format_decimal(row.roe), + 'industry_avg': format_decimal(row.roe_industry_avg), + 'rank': int(row.roe_rank) if row.roe_rank else None + }, + 'operating_cash_flow_ps': { + 'value': format_decimal(row.operating_cash_flow_ps), + 'industry_avg': format_decimal(row.operating_cash_flow_ps_industry_avg), + 'rank': int(row.operating_cash_flow_ps_rank) if row.operating_cash_flow_ps_rank else None + }, + 'revenue_growth': { + 'value': format_decimal(row.revenue_growth), + 'industry_avg': format_decimal(row.revenue_growth_industry_avg), + 'rank': int(row.revenue_growth_rank) if row.revenue_growth_rank else None + }, + 'profit_growth': { + 'value': format_decimal(row.profit_growth), + 'industry_avg': format_decimal(row.profit_growth_industry_avg), + 'rank': int(row.profit_growth_rank) if row.profit_growth_rank else None + }, + 'operating_margin': { + 'value': format_decimal(row.operating_margin), + 'industry_avg': format_decimal(row.operating_margin_industry_avg), + 'rank': int(row.operating_margin_rank) if row.operating_margin_rank else None + }, + 'debt_ratio': { + 'value': format_decimal(row.debt_ratio), + 'industry_avg': format_decimal(row.debt_ratio_industry_avg), + 'rank': int(row.debt_ratio_rank) if row.debt_ratio_rank else None + }, + 'receivable_turnover': { + 'value': format_decimal(row.receivable_turnover), + 'industry_avg': format_decimal(row.receivable_turnover_industry_avg), + 'rank': int(row.receivable_turnover_rank) if row.receivable_turnover_rank else None + } + } + } + data_by_period[period].append(rank_data) + + # 转换为列表格式 + data = [] + for period, ranks in data_by_period.items(): + data.append({ + 'period': period, + 'report_type': get_report_type(period), + 'rankings': ranks + }) + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/financial/comparison/', methods=['GET']) +def get_period_comparison(seccode): + """获取不同报告期的对比数据""" + try: + periods = request.args.get('periods', 8, type=int) + + # 获取多期财务数据进行对比 + query = text(""" + SELECT distinct fi.ENDDATE, + fi.F089N as revenue, + fi.F101N as net_profit, + fi.F102N as parent_net_profit, + fi.F078N as gross_margin, + fi.F017N as net_margin, + fi.F014N as roe, + fi.F016N as roa, + fi.F052N as revenue_growth, + fi.F053N as profit_growth, + fi.F003N as eps, + fi.F060N as operating_cash_flow_ps, + fi.F042N as current_ratio, + fi.F041N as debt_ratio, + fi.F105N as operating_cash_flow, + fi.F118N as total_assets, + fi.F121N as total_liabilities, + fi.F128N as total_equity + FROM ea_financialindex fi + WHERE fi.SECCODE = :seccode + ORDER BY fi.ENDDATE DESC LIMIT :periods + """) + + result = engine.execute(query, seccode=seccode, periods=periods) + + data = [] + for row in result: + period_data = { + 'period': format_date(row.ENDDATE), + 'report_type': get_report_type(row.ENDDATE), + 'performance': { + 'revenue': format_decimal(row.revenue), + 'net_profit': format_decimal(row.net_profit), + 'parent_net_profit': format_decimal(row.parent_net_profit), + 'operating_cash_flow': format_decimal(row.operating_cash_flow), + }, + 'profitability': { + 'gross_margin': format_decimal(row.gross_margin), + 'net_margin': format_decimal(row.net_margin), + 'roe': format_decimal(row.roe), + 'roa': format_decimal(row.roa), + }, + 'growth': { + 'revenue_growth': format_decimal(row.revenue_growth), + 'profit_growth': format_decimal(row.profit_growth), + }, + 'per_share': { + 'eps': format_decimal(row.eps), + 'operating_cash_flow_ps': format_decimal(row.operating_cash_flow_ps), + }, + 'financial_health': { + 'current_ratio': format_decimal(row.current_ratio), + 'debt_ratio': format_decimal(row.debt_ratio), + 'total_assets': format_decimal(row.total_assets), + 'total_liabilities': format_decimal(row.total_liabilities), + 'total_equity': format_decimal(row.total_equity), + } + } + data.append(period_data) + + # 计算同比和环比变化 + for i in range(len(data)): + if i > 0: # 环比 + data[i]['qoq_changes'] = { + 'revenue': calculate_change(data[i]['performance']['revenue'], + data[i - 1]['performance']['revenue']), + 'net_profit': calculate_change(data[i]['performance']['net_profit'], + data[i - 1]['performance']['net_profit']), + } + + # 同比(找到去年同期) + current_period = data[i]['period'] + yoy_period = get_yoy_period(current_period) + for j in range(len(data)): + if data[j]['period'] == yoy_period: + data[i]['yoy_changes'] = { + 'revenue': calculate_change(data[i]['performance']['revenue'], + data[j]['performance']['revenue']), + 'net_profit': calculate_change(data[i]['performance']['net_profit'], + data[j]['performance']['net_profit']), + } + break + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +# 辅助函数 +def calculate_change(current, previous): + """计算变化率""" + if previous and current: + return format_decimal((current - previous) / abs(previous) * 100) + return None + + +def get_yoy_period(date_str): + """获取去年同期""" + if not date_str: + return None + try: + date = datetime.strptime(date_str, '%Y-%m-%d') + yoy_date = date.replace(year=date.year - 1) + return yoy_date.strftime('%Y-%m-%d') + except: + return None + + +@app.route('/api/market/trade/', methods=['GET']) +def get_trade_data(seccode): + """获取股票交易数据(日K线)""" + try: + days = request.args.get('days', 60, type=int) + end_date = request.args.get('end_date', datetime.now().strftime('%Y-%m-%d')) + + query = text(""" + SELECT TRADEDATE, + SECNAME, + F002N as pre_close, + F003N as open, + F004N as volume, + F005N as high, + F006N as low, + F007N as close, + F008N as trades_count, + F009N as change_amount, + F010N as change_percent, + F011N as amount, + F012N as turnover_rate, + F013N as amplitude, + F020N as total_shares, + F021N as float_shares, + F026N as pe_ratio + FROM ea_trade + WHERE SECCODE = :seccode + AND TRADEDATE <= :end_date + ORDER BY TRADEDATE DESC + LIMIT :days + """) + + result = engine.execute(query, seccode=seccode, end_date=end_date, days=days) + + data = [] + for row in result: + data.append({ + 'date': format_date(row.TRADEDATE), + 'stock_name': row.SECNAME, + 'open': format_decimal(row.open), + 'high': format_decimal(row.high), + 'low': format_decimal(row.low), + 'close': format_decimal(row.close), + 'pre_close': format_decimal(row.pre_close), + 'volume': format_decimal(row.volume), + 'amount': format_decimal(row.amount), + 'change_amount': format_decimal(row.change_amount), + 'change_percent': format_decimal(row.change_percent), + 'turnover_rate': format_decimal(row.turnover_rate), + 'amplitude': format_decimal(row.amplitude), + 'trades_count': format_decimal(row.trades_count), + 'pe_ratio': format_decimal(row.pe_ratio), + 'total_shares': format_decimal(row.total_shares), + 'float_shares': format_decimal(row.float_shares), + }) + + # 倒序,让最早的日期在前 + data.reverse() + + # 计算统计数据 + if data: + prices = [d['close'] for d in data if d['close']] + stats = { + 'highest': max(prices) if prices else None, + 'lowest': min(prices) if prices else None, + 'average': sum(prices) / len(prices) if prices else None, + 'latest_price': data[-1]['close'] if data else None, + 'total_volume': sum([d['volume'] for d in data if d['volume']]) if data else None, + 'total_amount': sum([d['amount'] for d in data if d['amount']]) if data else None, + } + else: + stats = {} + + return jsonify({ + 'success': True, + 'data': data, + 'stats': stats + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/funding/', methods=['GET']) +def get_funding_data(seccode): + """获取融资融券数据""" + try: + days = request.args.get('days', 30, type=int) + + query = text(""" + SELECT TRADEDATE, + SECNAME, + F001N as financing_balance, + F002N as financing_buy, + F003N as financing_repay, + F004N as securities_balance, + F006N as securities_sell, + F007N as securities_repay, + F008N as securities_balance_amount, + F009N as total_balance + FROM ea_funding + WHERE SECCODE = :seccode + ORDER BY TRADEDATE DESC LIMIT :days + """) + + result = engine.execute(query, seccode=seccode, days=days) + + data = [] + for row in result: + data.append({ + 'date': format_date(row.TRADEDATE), + 'stock_name': row.SECNAME, + 'financing': { + 'balance': format_decimal(row.financing_balance), + 'buy': format_decimal(row.financing_buy), + 'repay': format_decimal(row.financing_repay), + 'net': format_decimal( + row.financing_buy - row.financing_repay) if row.financing_buy and row.financing_repay else None + }, + 'securities': { + 'balance': format_decimal(row.securities_balance), + 'sell': format_decimal(row.securities_sell), + 'repay': format_decimal(row.securities_repay), + 'balance_amount': format_decimal(row.securities_balance_amount) + }, + 'total_balance': format_decimal(row.total_balance) + }) + + data.reverse() + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/bigdeal/', methods=['GET']) +def get_bigdeal_data(seccode): + """获取大宗交易数据""" + try: + days = request.args.get('days', 30, type=int) + + query = text(""" + SELECT TRADEDATE, + SECNAME, + F001V as exchange, + F002V as buyer_dept, + F003V as seller_dept, + F004N as price, + F005N as volume, + F006N as amount, + F007N as seq_no + FROM ea_bigdeal + WHERE SECCODE = :seccode + ORDER BY TRADEDATE DESC, F007N LIMIT :days + """) + + result = engine.execute(query, seccode=seccode, days=days) + + data = [] + for row in result: + data.append({ + 'date': format_date(row.TRADEDATE), + 'stock_name': row.SECNAME, + 'exchange': row.exchange, + 'buyer_dept': row.buyer_dept, + 'seller_dept': row.seller_dept, + 'price': format_decimal(row.price), + 'volume': format_decimal(row.volume), + 'amount': format_decimal(row.amount), + 'seq_no': int(row.seq_no) if row.seq_no else None + }) + + # 按日期分组统计 + daily_stats = {} + for item in data: + date = item['date'] + if date not in daily_stats: + daily_stats[date] = { + 'date': date, + 'count': 0, + 'total_volume': 0, + 'total_amount': 0, + 'avg_price': 0, + 'deals': [] + } + daily_stats[date]['count'] += 1 + daily_stats[date]['total_volume'] += item['volume'] or 0 + daily_stats[date]['total_amount'] += item['amount'] or 0 + daily_stats[date]['deals'].append(item) + + # 计算平均价格 + for date in daily_stats: + if daily_stats[date]['total_volume'] > 0: + daily_stats[date]['avg_price'] = daily_stats[date]['total_amount'] / daily_stats[date]['total_volume'] + + return jsonify({ + 'success': True, + 'data': data, + 'daily_stats': list(daily_stats.values()) + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/unusual/', methods=['GET']) +def get_unusual_data(seccode): + """获取龙虎榜数据""" + try: + days = request.args.get('days', 30, type=int) + + query = text(""" + SELECT TRADEDATE, + SECNAME, + F001V as info_type_code, + F002V as info_type, + F003C as trade_type, + F004N as rank_no, + F005V as dept_name, + F006N as buy_amount, + F007N as sell_amount, + F008N as net_amount + FROM ea_unusual + WHERE SECCODE = :seccode + ORDER BY TRADEDATE DESC, F004N LIMIT 100 + """) + + result = engine.execute(query, seccode=seccode) + + data = [] + for row in result: + data.append({ + 'date': format_date(row.TRADEDATE), + 'stock_name': row.SECNAME, + 'info_type': row.info_type, + 'info_type_code': row.info_type_code, + 'trade_type': 'buy' if row.trade_type == 'B' else 'sell' if row.trade_type == 'S' else 'unknown', + 'rank': int(row.rank_no) if row.rank_no else None, + 'dept_name': row.dept_name, + 'buy_amount': format_decimal(row.buy_amount), + 'sell_amount': format_decimal(row.sell_amount), + 'net_amount': format_decimal(row.net_amount) + }) + + # 按日期分组 + grouped_data = {} + for item in data: + date = item['date'] + if date not in grouped_data: + grouped_data[date] = { + 'date': date, + 'info_types': set(), + 'buyers': [], + 'sellers': [], + 'total_buy': 0, + 'total_sell': 0, + 'net_amount': 0 + } + + grouped_data[date]['info_types'].add(item['info_type']) + + if item['trade_type'] == 'buy': + grouped_data[date]['buyers'].append(item) + grouped_data[date]['total_buy'] += item['buy_amount'] or 0 + elif item['trade_type'] == 'sell': + grouped_data[date]['sellers'].append(item) + grouped_data[date]['total_sell'] += item['sell_amount'] or 0 + + grouped_data[date]['net_amount'] = grouped_data[date]['total_buy'] - grouped_data[date]['total_sell'] + + # 转换set为list + for date in grouped_data: + grouped_data[date]['info_types'] = list(grouped_data[date]['info_types']) + + return jsonify({ + 'success': True, + 'data': data, + 'grouped_data': list(grouped_data.values()) + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/pledge/', methods=['GET']) +def get_pledge_data(seccode): + """获取股权质押数据""" + try: + query = text(""" + SELECT ENDDATE, + STARTDATE, + SECNAME, + F001N as unrestricted_pledge, + F002N as restricted_pledge, + F003N as total_shares_a, + F004N as pledge_count, + F005N as pledge_ratio + FROM ea_pledgeratio + WHERE SECCODE = :seccode + ORDER BY ENDDATE DESC LIMIT 12 + """) + + result = engine.execute(query, seccode=seccode) + + data = [] + for row in result: + total_pledge = (row.unrestricted_pledge or 0) + (row.restricted_pledge or 0) + data.append({ + 'end_date': format_date(row.ENDDATE), + 'start_date': format_date(row.STARTDATE), + 'stock_name': row.SECNAME, + 'unrestricted_pledge': format_decimal(row.unrestricted_pledge), + 'restricted_pledge': format_decimal(row.restricted_pledge), + 'total_pledge': format_decimal(total_pledge), + 'total_shares': format_decimal(row.total_shares_a), + 'pledge_count': int(row.pledge_count) if row.pledge_count else None, + 'pledge_ratio': format_decimal(row.pledge_ratio) + }) + + return jsonify({ + 'success': True, + 'data': data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/summary/', methods=['GET']) +def get_market_summary(seccode): + """获取市场数据汇总""" + try: + # 获取最新交易数据 + trade_query = text(""" + SELECT * + FROM ea_trade + WHERE SECCODE = :seccode + ORDER BY TRADEDATE DESC LIMIT 1 + """) + + # 获取最新融资融券数据 + funding_query = text(""" + SELECT * + FROM ea_funding + WHERE SECCODE = :seccode + ORDER BY TRADEDATE DESC LIMIT 1 + """) + + # 获取最新质押数据 + pledge_query = text(""" + SELECT * + FROM ea_pledgeratio + WHERE SECCODE = :seccode + ORDER BY ENDDATE DESC LIMIT 1 + """) + + trade_result = engine.execute(trade_query, seccode=seccode).fetchone() + funding_result = engine.execute(funding_query, seccode=seccode).fetchone() + pledge_result = engine.execute(pledge_query, seccode=seccode).fetchone() + + summary = { + 'stock_code': seccode, + 'stock_name': trade_result.SECNAME if trade_result else None, + 'latest_trade': { + 'date': format_date(trade_result.TRADEDATE) if trade_result else None, + 'close': format_decimal(trade_result.F007N) if trade_result else None, + 'change_percent': format_decimal(trade_result.F010N) if trade_result else None, + 'volume': format_decimal(trade_result.F004N) if trade_result else None, + 'amount': format_decimal(trade_result.F011N) if trade_result else None, + 'pe_ratio': format_decimal(trade_result.F026N) if trade_result else None, + 'turnover_rate': format_decimal(trade_result.F012N) if trade_result else None, + } if trade_result else None, + 'latest_funding': { + 'date': format_date(funding_result.TRADEDATE) if funding_result else None, + 'financing_balance': format_decimal(funding_result.F001N) if funding_result else None, + 'securities_balance': format_decimal(funding_result.F004N) if funding_result else None, + 'total_balance': format_decimal(funding_result.F009N) if funding_result else None, + } if funding_result else None, + 'latest_pledge': { + 'date': format_date(pledge_result.ENDDATE) if pledge_result else None, + 'pledge_ratio': format_decimal(pledge_result.F005N) if pledge_result else None, + 'pledge_count': int(pledge_result.F004N) if pledge_result and pledge_result.F004N else None, + } if pledge_result else None + } + + return jsonify({ + 'success': True, + 'data': summary + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/stocks/search', methods=['GET']) +def search_stocks(): + """搜索股票(支持股票代码、股票简称、拼音首字母)""" + try: + query = request.args.get('q', '').strip() + limit = request.args.get('limit', 20, type=int) + + if not query: + return jsonify({ + 'success': False, + 'error': '请输入搜索关键词' + }), 400 + + with engine.connect() as conn: + test_sql = text(""" + SELECT SECCODE, SECNAME, F001V, F003V, F010V, F011V + FROM ea_stocklist + WHERE SECCODE = '300750' + OR F001V LIKE '%ndsd%' LIMIT 5 + """) + test_result = conn.execute(test_sql).fetchall() + + # 构建搜索SQL - 支持股票代码、股票简称、拼音简称搜索 + search_sql = text(""" + SELECT DISTINCT SECCODE as stock_code, + SECNAME as stock_name, + F001V as pinyin_abbr, + F003V as security_type, + F005V as exchange, + F011V as listing_status + FROM ea_stocklist + WHERE ( + UPPER(SECCODE) LIKE UPPER(:query_pattern) + OR UPPER(SECNAME) LIKE UPPER(:query_pattern) + OR UPPER(F001V) LIKE UPPER(:query_pattern) + ) + -- 基本过滤条件:只搜索正常的A股和B股 + AND (F011V = '正常上市' OR F010V = '013001') -- 正常上市状态 + AND F003V IN ('A股', 'B股') -- 只搜索A股和B股 + ORDER BY CASE + WHEN UPPER(SECCODE) = UPPER(:exact_query) THEN 1 + WHEN UPPER(SECNAME) = UPPER(:exact_query) THEN 2 + WHEN UPPER(F001V) = UPPER(:exact_query) THEN 3 + WHEN UPPER(SECCODE) LIKE UPPER(:prefix_pattern) THEN 4 + WHEN UPPER(SECNAME) LIKE UPPER(:prefix_pattern) THEN 5 + WHEN UPPER(F001V) LIKE UPPER(:prefix_pattern) THEN 6 + ELSE 7 + END, + SECCODE LIMIT :limit + """) + + result = conn.execute(search_sql, { + 'query_pattern': f'%{query}%', + 'exact_query': query, + 'prefix_pattern': f'{query}%', + 'limit': limit + }).fetchall() + + stocks = [] + for row in result: + # 获取当前价格 + current_price, _ = get_latest_price_from_clickhouse(row.stock_code) + + stocks.append({ + 'stock_code': row.stock_code, + 'stock_name': row.stock_name, + 'current_price': current_price or 0, # 添加当前价格 + 'pinyin_abbr': row.pinyin_abbr, + 'security_type': row.security_type, + 'exchange': row.exchange, + 'listing_status': row.listing_status + }) + + return jsonify({ + 'success': True, + 'data': stocks, + 'count': len(stocks) + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/heatmap', methods=['GET']) +def get_market_heatmap(): + """获取市场热力图数据(基于市值和涨跌幅)""" + try: + # 获取交易日期参数 + trade_date = request.args.get('date') + # 前端显示用的limit,但统计数据会基于全部股票 + display_limit = request.args.get('limit', 500, type=int) + + with engine.connect() as conn: + # 如果没有指定日期,获取最新交易日 + if not trade_date: + latest_date_result = conn.execute(text(""" + SELECT MAX(TRADEDATE) as latest_date + FROM ea_trade + """)).fetchone() + trade_date = latest_date_result.latest_date if latest_date_result else None + + if not trade_date: + return jsonify({ + 'success': False, + 'error': '无法获取交易数据' + }), 404 + + # 获取全部股票数据用于统计 + all_stocks_sql = text(""" + SELECT t.SECCODE as stock_code, + t.SECNAME as stock_name, + t.F010N as change_percent, -- 涨跌幅 + t.F007N as close_price, -- 收盘价 + t.F021N * t.F007N / 100000000 as market_cap, -- 市值(亿元) + t.F011N / 100000000 as amount, -- 成交额(亿元) + t.F012N as turnover_rate, -- 换手率 + b.F034V as industry, -- 申万行业分类一级名称 + b.F026V as province -- 所属省份 + FROM ea_trade t + LEFT JOIN ea_baseinfo b ON t.SECCODE = b.SECCODE + WHERE t.TRADEDATE = :trade_date + AND t.F010N IS NOT NULL -- 仅统计当日有涨跌幅数据的股票 + ORDER BY market_cap DESC + """) + + all_result = conn.execute(all_stocks_sql, { + 'trade_date': trade_date + }).fetchall() + + # 计算统计数据(基于全部股票) + total_market_cap = 0 + total_amount = 0 + rising_count = 0 + falling_count = 0 + flat_count = 0 + + all_data = [] + for row in all_result: + # F010N 已在 SQL 中确保非空 + change_percent = float(row.change_percent) + market_cap = float(row.market_cap) if row.market_cap else 0 + amount = float(row.amount) if row.amount else 0 + + total_market_cap += market_cap + total_amount += amount + + if change_percent > 0: + rising_count += 1 + elif change_percent < 0: + falling_count += 1 + else: + flat_count += 1 + + all_data.append({ + 'stock_code': row.stock_code, + 'stock_name': row.stock_name, + 'change_percent': change_percent, + 'close_price': float(row.close_price) if row.close_price else 0, + 'market_cap': market_cap, + 'amount': amount, + 'turnover_rate': float(row.turnover_rate) if row.turnover_rate else 0, + 'industry': row.industry, + 'province': row.province + }) + + # 只返回前display_limit条用于热力图显示 + heatmap_data = all_data[:display_limit] + + return jsonify({ + 'success': True, + 'data': heatmap_data, + 'trade_date': trade_date.strftime('%Y-%m-%d') if hasattr(trade_date, 'strftime') else str(trade_date), + 'count': len(all_data), # 全部股票数量 + 'display_count': len(heatmap_data), # 显示的股票数量 + 'statistics': { + 'total_market_cap': round(total_market_cap, 2), # 总市值(亿元) + 'total_amount': round(total_amount, 2), # 总成交额(亿元) + 'rising_count': rising_count, # 上涨家数 + 'falling_count': falling_count, # 下跌家数 + 'flat_count': flat_count # 平盘家数 + } + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/statistics', methods=['GET']) +def get_market_statistics(): + """获取市场统计数据(从ea_blocktrading表)""" + try: + # 获取交易日期参数 + trade_date = request.args.get('date') + + with engine.connect() as conn: + # 如果没有指定日期,获取最新交易日 + if not trade_date: + latest_date_result = conn.execute(text(""" + SELECT MAX(TRADEDATE) as latest_date + FROM ea_blocktrading + """)).fetchone() + trade_date = latest_date_result.latest_date if latest_date_result else None + + if not trade_date: + return jsonify({ + 'success': False, + 'error': '无法获取统计数据' + }), 404 + + # 获取沪深两市的统计数据 + stats_sql = text(""" + SELECT EXCHANGECODE, + EXCHANGENAME, + F001V as indicator_code, + F002V as indicator_name, + F003N as indicator_value, + F004V as unit, + TRADEDATE + FROM ea_blocktrading + WHERE TRADEDATE = :trade_date + AND EXCHANGECODE IN ('012001', '012002') -- 只获取上交所和深交所的数据 + AND F001V IN ( + '250006', '250014', -- 深交所股票总市值、上交所市价总值 + '250007', '250015', -- 深交所股票流通市值、上交所流通市值 + '250008', -- 深交所股票成交金额 + '250010', '250019', -- 深交所股票平均市盈率、上交所平均市盈率 + '250050', '250001' -- 上交所上市公司家数、深交所上市公司数 + ) + """) + + result = conn.execute(stats_sql, { + 'trade_date': trade_date + }).fetchall() + + # 整理数据 + statistics = {} + for row in result: + key = f"{row.EXCHANGECODE}_{row.indicator_code}" + statistics[key] = { + 'exchange_code': row.EXCHANGECODE, + 'exchange_name': row.EXCHANGENAME, + 'indicator_code': row.indicator_code, + 'indicator_name': row.indicator_name, + 'value': float(row.indicator_value) if row.indicator_value else 0, + 'unit': row.unit + } + + # 汇总数据 + summary = { + 'total_market_cap': 0, # 总市值 + 'total_float_cap': 0, # 流通市值 + 'total_amount': 0, # 成交额 + 'sh_pe_ratio': 0, # 上交所市盈率 + 'sz_pe_ratio': 0, # 深交所市盈率 + 'sh_companies': 0, # 上交所上市公司数 + 'sz_companies': 0 # 深交所上市公司数 + } + + # 计算汇总值 + if '012001_250014' in statistics: # 上交所市价总值 + summary['total_market_cap'] += statistics['012001_250014']['value'] + if '012002_250006' in statistics: # 深交所股票总市值 + summary['total_market_cap'] += statistics['012002_250006']['value'] + + if '012001_250015' in statistics: # 上交所流通市值 + summary['total_float_cap'] += statistics['012001_250015']['value'] + if '012002_250007' in statistics: # 深交所股票流通市值 + summary['total_float_cap'] += statistics['012002_250007']['value'] + + # 成交额需要获取上交所的数据 + # 获取上交所成交金额 + sh_amount_result = conn.execute(text(""" + SELECT F003N + FROM ea_blocktrading + WHERE TRADEDATE = :trade_date + AND EXCHANGECODE = '012001' + AND F002V LIKE '%成交金额%' LIMIT 1 + """), {'trade_date': trade_date}).fetchone() + + sh_amount = float(sh_amount_result.F003N) if sh_amount_result and sh_amount_result.F003N else 0 + sz_amount = statistics['012002_250008']['value'] if '012002_250008' in statistics else 0 + summary['total_amount'] = sh_amount + sz_amount + + if '012001_250019' in statistics: # 上交所平均市盈率 + summary['sh_pe_ratio'] = statistics['012001_250019']['value'] + if '012002_250010' in statistics: # 深交所股票平均市盈率 + summary['sz_pe_ratio'] = statistics['012002_250010']['value'] + + if '012001_250050' in statistics: # 上交所上市公司家数 + summary['sh_companies'] = int(statistics['012001_250050']['value']) + if '012002_250001' in statistics: # 深交所上市公司数 + summary['sz_companies'] = int(statistics['012002_250001']['value']) + + # 获取可用的交易日期列表 + available_dates_result = conn.execute(text(""" + SELECT DISTINCT TRADEDATE + FROM ea_blocktrading + WHERE EXCHANGECODE IN ('012001', '012002') + ORDER BY TRADEDATE DESC LIMIT 30 + """)).fetchall() + + available_dates = [str(row.TRADEDATE) for row in available_dates_result] + + return jsonify({ + 'success': True, + 'trade_date': str(trade_date), + 'summary': summary, + 'details': list(statistics.values()), + 'available_dates': available_dates + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/concepts/daily-top', methods=['GET']) +def get_daily_top_concepts(): + """获取每日涨幅靠前的概念板块""" + try: + # 获取交易日期参数 + trade_date = request.args.get('date') + limit = request.args.get('limit', 6, type=int) + + # 构建概念中心API的URL + concept_api_url = 'http://222.128.1.157:16801/search' + + # 准备请求数据 + request_data = { + 'query': '', + 'size': limit, + 'page': 1, + 'sort_by': 'change_pct' + } + + if trade_date: + request_data['trade_date'] = trade_date + + # 调用概念中心API + response = requests.post(concept_api_url, json=request_data, timeout=10) + + if response.status_code == 200: + data = response.json() + top_concepts = [] + + for concept in data.get('results', []): + top_concepts.append({ + 'concept_id': concept.get('concept_id'), + 'concept_name': concept.get('concept'), + 'description': concept.get('description'), + 'change_percent': concept.get('price_info', {}).get('avg_change_pct', 0), + 'stock_count': concept.get('stock_count', 0), + 'stocks': concept.get('stocks', [])[:5] # 只返回前5只股票 + }) + + return jsonify({ + 'success': True, + 'data': top_concepts, + 'trade_date': data.get('price_date'), + 'count': len(top_concepts) + }) + else: + return jsonify({ + 'success': False, + 'error': '获取概念数据失败' + }), 500 + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/market/rise-analysis/', methods=['GET']) +def get_rise_analysis(seccode): + """获取股票涨幅分析数据""" + try: + # 获取日期范围参数 + start_date = request.args.get('start_date') + end_date = request.args.get('end_date') + + query = text(""" + SELECT stock_code, + stock_name, + trade_date, + rise_rate, + close_price, + volume, + amount, + main_business, + rise_reason_brief, + rise_reason_detail, + news_summary, + announcements, + guba_sentiment, + analysis_time + FROM stock_rise_analysis + WHERE stock_code = :stock_code + """) + + params = {'stock_code': seccode} + + # 添加日期筛选 + if start_date and end_date: + query = text(""" + SELECT stock_code, + stock_name, + trade_date, + rise_rate, + close_price, + volume, + amount, + main_business, + rise_reason_brief, + rise_reason_detail, + news_summary, + announcements, + guba_sentiment, + analysis_time + FROM stock_rise_analysis + WHERE stock_code = :stock_code + AND trade_date BETWEEN :start_date AND :end_date + ORDER BY trade_date DESC + """) + params['start_date'] = start_date + params['end_date'] = end_date + else: + query = text(""" + SELECT stock_code, + stock_name, + trade_date, + rise_rate, + close_price, + volume, + amount, + main_business, + rise_reason_brief, + rise_reason_detail, + news_summary, + announcements, + guba_sentiment, + analysis_time + FROM stock_rise_analysis + WHERE stock_code = :stock_code + ORDER BY trade_date DESC LIMIT 100 + """) + + result = engine.execute(query, **params).fetchall() + + # 格式化数据 + rise_analysis_data = [] + for row in result: + rise_analysis_data.append({ + 'stock_code': row.stock_code, + 'stock_name': row.stock_name, + 'trade_date': format_date(row.trade_date), + 'rise_rate': format_decimal(row.rise_rate), + 'close_price': format_decimal(row.close_price), + 'volume': format_decimal(row.volume), + 'amount': format_decimal(row.amount), + 'main_business': row.main_business, + 'rise_reason_brief': row.rise_reason_brief, + 'rise_reason_detail': row.rise_reason_detail, + 'news_summary': row.news_summary, + 'announcements': row.announcements, + 'guba_sentiment': row.guba_sentiment, + 'analysis_time': row.analysis_time.strftime('%Y-%m-%d %H:%M:%S') if row.analysis_time else None + }) + + return jsonify({ + 'success': True, + 'data': rise_analysis_data, + 'count': len(rise_analysis_data) + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +# ============================================ +# 公司分析相关接口 +# ============================================ + +@app.route('/api/company/comprehensive-analysis/', methods=['GET']) +def get_comprehensive_analysis(company_code): + """获取公司综合分析数据""" + try: + # 获取公司定性分析 + qualitative_query = text(""" + SELECT one_line_intro, + investment_highlights, + business_model_desc, + company_story, + positioning_analysis, + unique_value_proposition, + business_logic_explanation, + revenue_driver_analysis, + customer_value_analysis, + strategy_description, + strategic_initiatives, + created_at, + updated_at + FROM company_analysis + WHERE company_code = :company_code + """) + + qualitative_result = engine.execute(qualitative_query, company_code=company_code).fetchone() + + # 获取业务板块分析 + segments_query = text(""" + SELECT segment_name, + segment_description, + competitive_position, + future_potential, + key_customers, + value_chain_position, + created_at, + updated_at + FROM business_segment_analysis + WHERE company_code = :company_code + ORDER BY created_at DESC + """) + + segments_result = engine.execute(segments_query, company_code=company_code).fetchall() + + # 获取竞争地位数据 - 最新一期 + competitive_query = text(""" + SELECT market_position_score, + technology_score, + brand_score, + operation_score, + finance_score, + innovation_score, + risk_score, + growth_score, + industry_avg_comparison, + main_competitors, + competitive_advantages, + competitive_disadvantages, + industry_rank, + total_companies, + report_period, + updated_at + FROM company_competitive_position + WHERE company_code = :company_code + ORDER BY report_period DESC LIMIT 1 + """) + + competitive_result = engine.execute(competitive_query, company_code=company_code).fetchone() + + # 获取业务结构数据 - 最新一期 + business_structure_query = text(""" + SELECT business_name, + parent_business, + business_level, + revenue, + revenue_unit, + revenue_ratio, + profit, + profit_unit, + profit_ratio, + revenue_growth, + profit_growth, + gross_margin, + customer_count, + market_share, + report_period + FROM company_business_structure + WHERE company_code = :company_code + AND report_period = (SELECT MAX(report_period) + FROM company_business_structure + WHERE company_code = :company_code) + ORDER BY revenue_ratio DESC + """) + + business_structure_result = engine.execute(business_structure_query, company_code=company_code).fetchall() + + # 构建返回数据 + response_data = { + 'company_code': company_code, + 'qualitative_analysis': None, + 'business_segments': [], + 'competitive_position': None, + 'business_structure': [] + } + + # 处理定性分析数据 + if qualitative_result: + response_data['qualitative_analysis'] = { + 'core_positioning': { + 'one_line_intro': qualitative_result.one_line_intro, + 'investment_highlights': qualitative_result.investment_highlights, + 'business_model_desc': qualitative_result.business_model_desc, + 'company_story': qualitative_result.company_story + }, + 'business_understanding': { + 'positioning_analysis': qualitative_result.positioning_analysis, + 'unique_value_proposition': qualitative_result.unique_value_proposition, + 'business_logic_explanation': qualitative_result.business_logic_explanation, + 'revenue_driver_analysis': qualitative_result.revenue_driver_analysis, + 'customer_value_analysis': qualitative_result.customer_value_analysis + }, + 'strategy': { + 'strategy_description': qualitative_result.strategy_description, + 'strategic_initiatives': qualitative_result.strategic_initiatives + }, + 'updated_at': qualitative_result.updated_at.strftime( + '%Y-%m-%d %H:%M:%S') if qualitative_result.updated_at else None + } + + # 处理业务板块数据 + for segment in segments_result: + response_data['business_segments'].append({ + 'segment_name': segment.segment_name, + 'segment_description': segment.segment_description, + 'competitive_position': segment.competitive_position, + 'future_potential': segment.future_potential, + 'key_customers': segment.key_customers, + 'value_chain_position': segment.value_chain_position, + 'updated_at': segment.updated_at.strftime('%Y-%m-%d %H:%M:%S') if segment.updated_at else None + }) + + # 处理竞争地位数据 + if competitive_result: + response_data['competitive_position'] = { + 'scores': { + 'market_position': competitive_result.market_position_score, + 'technology': competitive_result.technology_score, + 'brand': competitive_result.brand_score, + 'operation': competitive_result.operation_score, + 'finance': competitive_result.finance_score, + 'innovation': competitive_result.innovation_score, + 'risk': competitive_result.risk_score, + 'growth': competitive_result.growth_score + }, + 'analysis': { + 'industry_avg_comparison': competitive_result.industry_avg_comparison, + 'main_competitors': competitive_result.main_competitors, + 'competitive_advantages': competitive_result.competitive_advantages, + 'competitive_disadvantages': competitive_result.competitive_disadvantages + }, + 'ranking': { + 'industry_rank': competitive_result.industry_rank, + 'total_companies': competitive_result.total_companies, + 'rank_percentage': round( + (competitive_result.industry_rank / competitive_result.total_companies * 100), + 2) if competitive_result.industry_rank and competitive_result.total_companies else None + }, + 'report_period': competitive_result.report_period, + 'updated_at': competitive_result.updated_at.strftime( + '%Y-%m-%d %H:%M:%S') if competitive_result.updated_at else None + } + + # 处理业务结构数据 + for business in business_structure_result: + response_data['business_structure'].append({ + 'business_name': business.business_name, + 'parent_business': business.parent_business, + 'business_level': business.business_level, + 'revenue': format_decimal(business.revenue), + 'revenue_unit': business.revenue_unit, + 'profit': format_decimal(business.profit), + 'profit_unit': business.profit_unit, + 'financial_metrics': { + 'revenue': format_decimal(business.revenue), + 'revenue_ratio': format_decimal(business.revenue_ratio), + 'profit': format_decimal(business.profit), + 'profit_ratio': format_decimal(business.profit_ratio), + 'gross_margin': format_decimal(business.gross_margin) + }, + 'growth_metrics': { + 'revenue_growth': format_decimal(business.revenue_growth), + 'profit_growth': format_decimal(business.profit_growth) + }, + 'market_metrics': { + 'customer_count': business.customer_count, + 'market_share': format_decimal(business.market_share) + }, + 'report_period': business.report_period + }) + + return jsonify({ + 'success': True, + 'data': response_data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/company/value-chain-analysis/', methods=['GET']) +def get_value_chain_analysis(company_code): + """获取公司产业链分析数据""" + try: + # 获取产业链节点数据 + nodes_query = text(""" + SELECT node_name, + node_type, + node_level, + node_description, + importance_score, + market_share, + dependency_degree, + created_at + FROM company_value_chain_nodes + WHERE company_code = :company_code + ORDER BY node_level ASC, importance_score DESC + """) + + nodes_result = engine.execute(nodes_query, company_code=company_code).fetchall() + + # 获取产业链流向数据 + flows_query = text(""" + SELECT source_node, + source_type, + source_level, + target_node, + target_type, + target_level, + flow_value, + flow_ratio, + flow_type, + relationship_desc, + transaction_volume + FROM company_value_chain_flows + WHERE company_code = :company_code + ORDER BY flow_ratio DESC + """) + + flows_result = engine.execute(flows_query, company_code=company_code).fetchall() + + # 构建节点数据结构 + nodes_by_level = {} + all_nodes = [] + + for node in nodes_result: + node_data = { + 'node_name': node.node_name, + 'node_type': node.node_type, + 'node_level': node.node_level, + 'node_description': node.node_description, + 'importance_score': node.importance_score, + 'market_share': format_decimal(node.market_share), + 'dependency_degree': format_decimal(node.dependency_degree), + 'created_at': node.created_at.strftime('%Y-%m-%d %H:%M:%S') if node.created_at else None + } + + all_nodes.append(node_data) + + # 按层级分组 + level_key = f"level_{node.node_level}" + if level_key not in nodes_by_level: + nodes_by_level[level_key] = [] + nodes_by_level[level_key].append(node_data) + + # 构建流向数据 + flows_data = [] + for flow in flows_result: + flows_data.append({ + 'source': { + 'node_name': flow.source_node, + 'node_type': flow.source_type, + 'node_level': flow.source_level + }, + 'target': { + 'node_name': flow.target_node, + 'node_type': flow.target_type, + 'node_level': flow.target_level + }, + 'flow_metrics': { + 'flow_value': format_decimal(flow.flow_value), + 'flow_ratio': format_decimal(flow.flow_ratio), + 'flow_type': flow.flow_type + }, + 'relationship_info': { + 'relationship_desc': flow.relationship_desc, + 'transaction_volume': flow.transaction_volume + } + }) + + # 移除循环边,确保Sankey图数据是DAG(有向无环图) + flows_data = remove_cycles_from_sankey_flows(flows_data) + + # 统计各层级节点数量 + level_stats = {} + for level_key, nodes in nodes_by_level.items(): + level_stats[level_key] = { + 'count': len(nodes), + 'avg_importance': round(sum(node['importance_score'] or 0 for node in nodes) / len(nodes), + 2) if nodes else 0 + } + + response_data = { + 'company_code': company_code, + 'value_chain_structure': { + 'nodes_by_level': nodes_by_level, + 'level_statistics': level_stats, + 'total_nodes': len(all_nodes) + }, + 'value_chain_flows': flows_data, + 'analysis_summary': { + 'total_flows': len(flows_data), + 'upstream_nodes': len([n for n in all_nodes if n['node_level'] < 0]), + 'company_nodes': len([n for n in all_nodes if n['node_level'] == 0]), + 'downstream_nodes': len([n for n in all_nodes if n['node_level'] > 0]) + } + } + + return jsonify({ + 'success': True, + 'data': response_data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/company/key-factors-timeline/', methods=['GET']) +def get_key_factors_timeline(company_code): + """获取公司关键因素和时间线数据""" + try: + # 获取请求参数 + report_period = request.args.get('report_period') # 可选的报告期筛选 + event_limit = request.args.get('event_limit', 50, type=int) # 时间线事件数量限制 + + # 获取关键因素类别 + categories_query = text(""" + SELECT id, + category_name, + category_desc, + display_order + FROM company_key_factor_categories + WHERE company_code = :company_code + ORDER BY display_order ASC, created_at ASC + """) + + categories_result = engine.execute(categories_query, company_code=company_code).fetchall() + + # 获取关键因素详情 + factors_query = text(""" + SELECT kf.category_id, + kf.factor_name, + kf.factor_type, + kf.factor_value, + kf.factor_unit, + kf.factor_desc, + kf.impact_direction, + kf.impact_weight, + kf.report_period, + kf.year_on_year, + kf.data_source, + kf.created_at, + kf.updated_at + FROM company_key_factors kf + WHERE kf.company_code = :company_code + """) + + params = {'company_code': company_code} + + # 如果指定了报告期,添加筛选条件 + if report_period: + factors_query = text(""" + SELECT kf.category_id, + kf.factor_name, + kf.factor_type, + kf.factor_value, + kf.factor_unit, + kf.factor_desc, + kf.impact_direction, + kf.impact_weight, + kf.report_period, + kf.year_on_year, + kf.data_source, + kf.created_at, + kf.updated_at + FROM company_key_factors kf + WHERE kf.company_code = :company_code + AND kf.report_period = :report_period + ORDER BY kf.impact_weight DESC, kf.updated_at DESC + """) + params['report_period'] = report_period + else: + factors_query = text(""" + SELECT kf.category_id, + kf.factor_name, + kf.factor_type, + kf.factor_value, + kf.factor_unit, + kf.factor_desc, + kf.impact_direction, + kf.impact_weight, + kf.report_period, + kf.year_on_year, + kf.data_source, + kf.created_at, + kf.updated_at + FROM company_key_factors kf + WHERE kf.company_code = :company_code + ORDER BY kf.report_period DESC, kf.impact_weight DESC, kf.updated_at DESC + """) + + factors_result = engine.execute(factors_query, **params).fetchall() + + # 获取发展时间线事件 + timeline_query = text(""" + SELECT event_date, + event_type, + event_title, + event_desc, + impact_score, + is_positive, + related_products, + related_partners, + financial_impact, + created_at + FROM company_timeline_events + WHERE company_code = :company_code + ORDER BY event_date DESC LIMIT :limit + """) + + timeline_result = engine.execute(timeline_query, + company_code=company_code, + limit=event_limit).fetchall() + + # 构建关键因素数据结构 + key_factors_data = {} + factors_by_category = {} + + # 先建立类别索引 + categories_map = {} + for category in categories_result: + categories_map[category.id] = { + 'category_name': category.category_name, + 'category_desc': category.category_desc, + 'display_order': category.display_order, + 'factors': [] + } + + # 将因素分组到类别中 + for factor in factors_result: + factor_data = { + 'factor_name': factor.factor_name, + 'factor_type': factor.factor_type, + 'factor_value': factor.factor_value, + 'factor_unit': factor.factor_unit, + 'factor_desc': factor.factor_desc, + 'impact_direction': factor.impact_direction, + 'impact_weight': factor.impact_weight, + 'report_period': factor.report_period, + 'year_on_year': format_decimal(factor.year_on_year), + 'data_source': factor.data_source, + 'updated_at': factor.updated_at.strftime('%Y-%m-%d %H:%M:%S') if factor.updated_at else None + } + + category_id = factor.category_id + if category_id and category_id in categories_map: + categories_map[category_id]['factors'].append(factor_data) + + # 构建时间线数据 + timeline_data = [] + for event in timeline_result: + timeline_data.append({ + 'event_date': event.event_date.strftime('%Y-%m-%d') if event.event_date else None, + 'event_type': event.event_type, + 'event_title': event.event_title, + 'event_desc': event.event_desc, + 'impact_metrics': { + 'impact_score': event.impact_score, + 'is_positive': event.is_positive + }, + 'related_info': { + 'related_products': event.related_products, + 'related_partners': event.related_partners, + 'financial_impact': event.financial_impact + }, + 'created_at': event.created_at.strftime('%Y-%m-%d %H:%M:%S') if event.created_at else None + }) + + # 统计信息 + total_factors = len(factors_result) + positive_events = len([e for e in timeline_result if e.is_positive]) + negative_events = len(timeline_result) - positive_events + + response_data = { + 'company_code': company_code, + 'key_factors': { + 'categories': list(categories_map.values()), + 'total_factors': total_factors, + 'report_period': report_period + }, + 'development_timeline': { + 'events': timeline_data, + 'statistics': { + 'total_events': len(timeline_data), + 'positive_events': positive_events, + 'negative_events': negative_events, + 'event_types': list(set(event.event_type for event in timeline_result if event.event_type)) + } + } + } + + return jsonify({ + 'success': True, + 'data': response_data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +# ============================================ +# 模拟盘服务函数 +# ============================================ + +def get_or_create_simulation_account(user_id): + """获取或创建模拟账户""" + account = SimulationAccount.query.filter_by(user_id=user_id).first() + if not account: + account = SimulationAccount( + user_id=user_id, + account_name=f'模拟账户_{user_id}', + initial_capital=1000000.00, + available_cash=1000000.00 + ) + db.session.add(account) + db.session.commit() + return account + + +def is_trading_time(): + """判断是否为交易时间""" + now = beijing_now() + # 检查是否为工作日 + if now.weekday() >= 5: # 周六日 + return False + + # 检查是否为交易时间 + current_time = now.time() + morning_start = dt_time(9, 30) + morning_end = dt_time(11, 30) + afternoon_start = dt_time(13, 0) + afternoon_end = dt_time(15, 0) + + if (morning_start <= current_time <= morning_end) or \ + (afternoon_start <= current_time <= afternoon_end): + return True + + return False + + +def get_latest_price_from_clickhouse(stock_code): + """从ClickHouse获取最新价格(优先分钟数据,备选日线数据)""" + try: + client = get_clickhouse_client() + + # 确保stock_code包含后缀 + if '.' not in stock_code: + stock_code = f"{stock_code}.SH" if stock_code.startswith('6') else f"{stock_code}.SZ" + + # 1. 首先尝试获取最新的分钟数据(近30天) + minute_query = """ + SELECT close, timestamp + FROM stock_minute + WHERE code = %(code)s + AND timestamp >= today() - 30 + ORDER BY timestamp DESC + LIMIT 1 \ + """ + + result = client.execute(minute_query, {'code': stock_code}) + + if result: + return float(result[0][0]), result[0][1] + + # 2. 如果没有分钟数据,获取最新的日线收盘价 + daily_query = """ + SELECT close, date + FROM stock_daily + WHERE code = %(code)s + AND date >= today() - 90 + ORDER BY date DESC + LIMIT 1 \ + """ + + daily_result = client.execute(daily_query, {'code': stock_code}) + + if daily_result: + return float(daily_result[0][0]), daily_result[0][1] + + # 3. 如果还是没有,尝试从其他表获取(如果有的话) + fallback_query = """ + SELECT close_price, trade_date + FROM stock_minute_kline + WHERE stock_code = %(code6)s + AND trade_date >= today() - 30 + ORDER BY trade_date DESC, trade_time DESC LIMIT 1 \ + """ + + # 提取6位代码 + code6 = stock_code.split('.')[0] + fallback_result = client.execute(fallback_query, {'code6': code6}) + + if fallback_result: + return float(fallback_result[0][0]), fallback_result[0][1] + + print(f"警告: 无法获取股票 {stock_code} 的价格数据") + return None, None + + except Exception as e: + print(f"获取最新价格失败 {stock_code}: {e}") + return None, None + + +def get_next_minute_price(stock_code, order_time): + """获取下单后一分钟内的收盘价作为成交价""" + try: + client = get_clickhouse_client() + + # 确保stock_code包含后缀 + if '.' not in stock_code: + stock_code = f"{stock_code}.SH" if stock_code.startswith('6') else f"{stock_code}.SZ" + + # 获取下单后一分钟内的数据 + query = """ + SELECT close, timestamp + FROM stock_minute + WHERE code = %(code)s + AND timestamp \ + > %(order_time)s + AND timestamp <= %(end_time)s + ORDER BY timestamp ASC + LIMIT 1 \ + """ + + end_time = order_time + timedelta(minutes=1) + + result = client.execute(query, { + 'code': stock_code, + 'order_time': order_time, + 'end_time': end_time + }) + + if result: + return float(result[0][0]), result[0][1] + + # 如果一分钟内没有数据,获取最近的数据 + query = """ + SELECT close, timestamp + FROM stock_minute + WHERE code = %(code)s + AND timestamp \ + > %(order_time)s + ORDER BY timestamp ASC + LIMIT 1 \ + """ + + result = client.execute(query, { + 'code': stock_code, + 'order_time': order_time + }) + + if result: + return float(result[0][0]), result[0][1] + + # 如果没有后续分钟数据,使用最新可用价格 + print(f"没有找到下单后的分钟数据,使用最新价格: {stock_code}") + return get_latest_price_from_clickhouse(stock_code) + + except Exception as e: + print(f"获取成交价格失败: {e}") + # 出错时也尝试获取最新价格 + return get_latest_price_from_clickhouse(stock_code) + + +def validate_and_get_stock_info(stock_input): + """验证股票输入并获取标准代码和名称 + + 支持输入格式: + - 股票代码:600519 或 600519.SH + - 股票名称:贵州茅台 + - 拼音首字母:gzmt + - 名称(代码):贵州茅台(600519) + + 返回: (stock_code_with_suffix, stock_code_6digit, stock_name) 或 (None, None, None) + """ + # 先尝试标准化输入 + code6, name_from_input = _normalize_stock_input(stock_input) + + if code6: + # 如果能解析出6位代码,查询股票名称 + stock_name = name_from_input or _query_stock_name_by_code(code6) + stock_code_full = f"{code6}.SH" if code6.startswith('6') else f"{code6}.SZ" + return stock_code_full, code6, stock_name + + # 如果不是标准代码格式,尝试搜索 + with engine.connect() as conn: + search_sql = text(""" + SELECT DISTINCT SECCODE as stock_code, + SECNAME as stock_name + FROM ea_stocklist + WHERE ( + UPPER(SECCODE) = UPPER(:exact_match) + OR UPPER(SECNAME) = UPPER(:exact_match) + OR UPPER(F001V) = UPPER(:exact_match) + ) + AND F011V = '正常上市' + AND F003V IN ('A股', 'B股') LIMIT 1 + """) + + result = conn.execute(search_sql, { + 'exact_match': stock_input.upper() + }).fetchone() + + if result: + code6 = result.stock_code + stock_name = result.stock_name + stock_code_full = f"{code6}.SH" if code6.startswith('6') else f"{code6}.SZ" + return stock_code_full, code6, stock_name + + return None, None, None + + +def execute_simulation_order(order): + """执行模拟订单(优化版)""" + try: + # 标准化股票代码 + stock_code_full, code6, stock_name = validate_and_get_stock_info(order.stock_code) + + if not stock_code_full: + order.status = 'REJECTED' + order.reject_reason = '无效的股票代码' + db.session.commit() + return False + + # 更新订单的股票信息 + order.stock_code = stock_code_full + order.stock_name = stock_name + + # 获取成交价格(下单后一分钟的收盘价) + filled_price, filled_time = get_next_minute_price(stock_code_full, order.order_time) + + if not filled_price: + # 如果无法获取价格,订单保持PENDING状态,等待后台处理 + order.status = 'PENDING' + db.session.commit() + return True # 返回True表示下单成功,但未成交 + + # 更新订单信息 + order.filled_qty = order.order_qty + order.filled_price = filled_price + order.filled_amount = filled_price * order.order_qty + order.filled_time = filled_time or beijing_now() + + # 计算费用 + order.calculate_fees() + + # 获取账户 + account = SimulationAccount.query.get(order.account_id) + + if order.order_type == 'BUY': + # 买入操作 + total_cost = float(order.filled_amount) + float(order.total_fee) + + # 检查资金是否充足 + if float(account.available_cash) < total_cost: + order.status = 'REJECTED' + order.reject_reason = '可用资金不足' + db.session.commit() + return False + + # 扣除资金 + account.available_cash -= Decimal(str(total_cost)) + + # 更新或创建持仓 + position = SimulationPosition.query.filter_by( + account_id=account.id, + stock_code=order.stock_code + ).first() + + if position: + # 更新持仓 + total_cost_before = float(position.avg_cost) * position.position_qty + total_cost_after = total_cost_before + float(order.filled_amount) + total_qty_after = position.position_qty + order.filled_qty + + position.avg_cost = Decimal(str(total_cost_after / total_qty_after)) + position.position_qty = total_qty_after + # 今日买入,T+1才可用 + position.frozen_qty += order.filled_qty + else: + # 创建新持仓 + position = SimulationPosition( + account_id=account.id, + stock_code=order.stock_code, + stock_name=order.stock_name, + position_qty=order.filled_qty, + available_qty=0, # T+1 + frozen_qty=order.filled_qty, # 今日买入冻结 + avg_cost=order.filled_price, + current_price=order.filled_price + ) + db.session.add(position) + + # 更新持仓市值 + position.update_market_value(order.filled_price) + + else: # SELL + # 卖出操作 + print(f"🔍 调试:查找持仓,账户ID: {account.id}, 股票代码: {order.stock_code}") + + # 先尝试用完整格式查找 + position = SimulationPosition.query.filter_by( + account_id=account.id, + stock_code=order.stock_code + ).first() + + # 如果没找到,尝试用6位数字格式查找 + if not position and '.' in order.stock_code: + code6 = order.stock_code.split('.')[0] + print(f"🔍 调试:尝试用6位格式查找: {code6}") + position = SimulationPosition.query.filter_by( + account_id=account.id, + stock_code=code6 + ).first() + + print(f"🔍 调试:找到持仓: {position}") + if position: + print( + f"🔍 调试:持仓详情 - 股票代码: {position.stock_code}, 持仓数量: {position.position_qty}, 可用数量: {position.available_qty}") + + # 检查持仓是否存在 + if not position: + order.status = 'REJECTED' + order.reject_reason = '持仓不存在' + db.session.commit() + return False + + # 检查总持仓数量是否足够(包括冻结的) + total_holdings = position.position_qty + if total_holdings < order.order_qty: + order.status = 'REJECTED' + order.reject_reason = f'持仓数量不足,当前持仓: {total_holdings} 股,需要: {order.order_qty} 股' + db.session.commit() + return False + + # 如果可用数量不足,但总持仓足够,则从冻结数量中解冻 + if position.available_qty < order.order_qty: + # 计算需要解冻的数量 + need_to_unfreeze = order.order_qty - position.available_qty + if position.frozen_qty >= need_to_unfreeze: + # 解冻部分冻结数量 + position.frozen_qty -= need_to_unfreeze + position.available_qty += need_to_unfreeze + print(f"解冻 {need_to_unfreeze} 股用于卖出") + else: + order.status = 'REJECTED' + order.reject_reason = f'可用数量不足,可用: {position.available_qty} 股,冻结: {position.frozen_qty} 股,需要: {order.order_qty} 股' + db.session.commit() + return False + + # 更新持仓 + position.position_qty -= order.filled_qty + position.available_qty -= order.filled_qty + + # 增加资金 + account.available_cash += Decimal(str(float(order.filled_amount) - float(order.total_fee))) + + # 如果全部卖出,删除持仓记录 + if position.position_qty == 0: + db.session.delete(position) + + # 创建成交记录 + transaction = SimulationTransaction( + account_id=account.id, + order_id=order.id, + transaction_no=f"T{int(beijing_now().timestamp() * 1000000)}", + stock_code=order.stock_code, + stock_name=order.stock_name, + transaction_type=order.order_type, + transaction_price=order.filled_price, + transaction_qty=order.filled_qty, + transaction_amount=order.filled_amount, + commission=order.commission, + stamp_tax=order.stamp_tax, + transfer_fee=order.transfer_fee, + total_fee=order.total_fee, + transaction_time=order.filled_time, + settlement_date=(order.filled_time + timedelta(days=1)).date() + ) + db.session.add(transaction) + + # 更新订单状态 + order.status = 'FILLED' + + # 更新账户总资产 + update_account_assets(account) + + db.session.commit() + return True + + except Exception as e: + print(f"执行订单失败: {e}") + db.session.rollback() + return False + + +def update_account_assets(account): + """更新账户资产(轻量级版本,不实时获取价格)""" + try: + # 只计算已有的持仓市值,不实时获取价格 + # 价格更新由后台脚本负责 + positions = SimulationPosition.query.filter_by(account_id=account.id).all() + total_market_value = sum(position.market_value or Decimal('0') for position in positions) + + account.position_value = total_market_value + account.calculate_total_assets() + + db.session.commit() + + except Exception as e: + print(f"更新账户资产失败: {e}") + db.session.rollback() + + +def update_all_positions_price(): + """更新所有持仓的最新价格(定时任务调用)""" + try: + positions = SimulationPosition.query.all() + + for position in positions: + latest_price, _ = get_latest_price_from_clickhouse(position.stock_code) + if latest_price: + # 记录昨日收盘价(用于计算今日盈亏) + yesterday_close = position.current_price + + # 更新市值 + position.update_market_value(latest_price) + + # 计算今日盈亏 + position.today_profit = (Decimal(str(latest_price)) - yesterday_close) * position.position_qty + position.today_profit_rate = ((Decimal( + str(latest_price)) - yesterday_close) / yesterday_close * 100) if yesterday_close > 0 else 0 + + db.session.commit() + + except Exception as e: + print(f"更新持仓价格失败: {e}") + db.session.rollback() + + +def process_t1_settlement(): + """处理T+1结算(每日收盘后运行)""" + try: + # 获取所有需要结算的持仓 + positions = SimulationPosition.query.filter(SimulationPosition.frozen_qty > 0).all() + + for position in positions: + # 将冻结数量转为可用数量 + position.available_qty += position.frozen_qty + position.frozen_qty = 0 + + db.session.commit() + + except Exception as e: + print(f"T+1结算失败: {e}") + db.session.rollback() + + +# ============================================ +# 模拟盘API接口 +# ============================================ + +@app.route('/api/simulation/account', methods=['GET']) +@login_required +def get_simulation_account(): + """获取模拟账户信息""" + try: + account = get_or_create_simulation_account(current_user.id) + + # 更新账户资产 + update_account_assets(account) + + return jsonify({ + 'success': True, + 'data': { + 'account_id': account.id, + 'account_name': account.account_name, + 'initial_capital': float(account.initial_capital), + 'available_cash': float(account.available_cash), + 'frozen_cash': float(account.frozen_cash), + 'position_value': float(account.position_value), + 'total_assets': float(account.total_assets), + 'total_profit': float(account.total_profit), + 'total_profit_rate': float(account.total_profit_rate), + 'daily_profit': float(account.daily_profit), + 'daily_profit_rate': float(account.daily_profit_rate), + 'created_at': account.created_at.isoformat(), + 'updated_at': account.updated_at.isoformat() + } + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/simulation/positions', methods=['GET']) +@login_required +def get_simulation_positions(): + """获取模拟持仓列表(优化版本,使用缓存的价格数据)""" + try: + account = get_or_create_simulation_account(current_user.id) + + # 直接获取持仓数据,不实时更新价格(由后台脚本负责) + positions = SimulationPosition.query.filter_by(account_id=account.id).all() + + positions_data = [] + for position in positions: + positions_data.append({ + 'id': position.id, + 'stock_code': position.stock_code, + 'stock_name': position.stock_name, + 'position_qty': position.position_qty, + 'available_qty': position.available_qty, + 'frozen_qty': position.frozen_qty, + 'avg_cost': float(position.avg_cost), + 'current_price': float(position.current_price or 0), + 'market_value': float(position.market_value or 0), + 'profit': float(position.profit or 0), + 'profit_rate': float(position.profit_rate or 0), + 'today_profit': float(position.today_profit or 0), + 'today_profit_rate': float(position.today_profit_rate or 0), + 'updated_at': position.updated_at.isoformat() + }) + + return jsonify({ + 'success': True, + 'data': positions_data + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/simulation/orders', methods=['GET']) +@login_required +def get_simulation_orders(): + """获取模拟订单列表""" + try: + account = get_or_create_simulation_account(current_user.id) + + # 获取查询参数 + status = request.args.get('status') # 订单状态筛选 + date_str = request.args.get('date') # 日期筛选 + limit = request.args.get('limit', 50, type=int) + + query = SimulationOrder.query.filter_by(account_id=account.id) + + if status: + query = query.filter_by(status=status) + + if date_str: + try: + date = datetime.strptime(date_str, '%Y-%m-%d').date() + start_time = datetime.combine(date, dt_time(0, 0, 0)) + end_time = datetime.combine(date, dt_time(23, 59, 59)) + query = query.filter(SimulationOrder.order_time.between(start_time, end_time)) + except ValueError: + pass + + orders = query.order_by(SimulationOrder.order_time.desc()).limit(limit).all() + + orders_data = [] + for order in orders: + orders_data.append({ + 'id': order.id, + 'order_no': order.order_no, + 'stock_code': order.stock_code, + 'stock_name': order.stock_name, + 'order_type': order.order_type, + 'price_type': order.price_type, + 'order_price': float(order.order_price) if order.order_price else None, + 'order_qty': order.order_qty, + 'filled_qty': order.filled_qty, + 'filled_price': float(order.filled_price) if order.filled_price else None, + 'filled_amount': float(order.filled_amount) if order.filled_amount else None, + 'commission': float(order.commission), + 'stamp_tax': float(order.stamp_tax), + 'transfer_fee': float(order.transfer_fee), + 'total_fee': float(order.total_fee), + 'status': order.status, + 'reject_reason': order.reject_reason, + 'order_time': order.order_time.isoformat(), + 'filled_time': order.filled_time.isoformat() if order.filled_time else None + }) + + return jsonify({ + 'success': True, + 'data': orders_data + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/simulation/place-order', methods=['POST']) +@login_required +def place_simulation_order(): + """下单""" + try: + # 移除交易时间检查,允许7x24小时下单 + # 非交易时间下的单子会保持PENDING状态,等待行情数据 + + data = request.get_json() + stock_code = data.get('stock_code') + order_type = data.get('order_type') # BUY/SELL + order_qty = data.get('order_qty') + price_type = data.get('price_type', 'MARKET') # 目前只支持市价单 + + # 标准化股票代码格式 + if stock_code and '.' not in stock_code: + # 如果没有后缀,根据股票代码添加后缀 + if stock_code.startswith('6'): + stock_code = f"{stock_code}.SH" + elif stock_code.startswith('0') or stock_code.startswith('3'): + stock_code = f"{stock_code}.SZ" + + # 参数验证 + if not all([stock_code, order_type, order_qty]): + return jsonify({'success': False, 'error': '缺少必要参数'}), 400 + + if order_type not in ['BUY', 'SELL']: + return jsonify({'success': False, 'error': '订单类型错误'}), 400 + + order_qty = int(order_qty) + if order_qty <= 0 or order_qty % 100 != 0: + return jsonify({'success': False, 'error': '下单数量必须为100的整数倍'}), 400 + + # 获取账户 + account = get_or_create_simulation_account(current_user.id) + + # 获取股票信息 + stock_name = None + with engine.connect() as conn: + result = conn.execute(text( + "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code" + ), {"code": stock_code.split('.')[0]}).fetchone() + if result: + stock_name = result[0] + + # 创建订单 + order = SimulationOrder( + account_id=account.id, + order_no=f"O{int(beijing_now().timestamp() * 1000000)}", + stock_code=stock_code, + stock_name=stock_name, + order_type=order_type, + price_type=price_type, + order_qty=order_qty, + status='PENDING' + ) + + db.session.add(order) + db.session.commit() + + # 执行订单 + print(f"🔍 调试:开始执行订单,股票代码: {order.stock_code}, 订单类型: {order.order_type}") + success = execute_simulation_order(order) + print(f"🔍 调试:订单执行结果: {success}, 订单状态: {order.status}") + + if success: + # 重新查询订单状态,因为可能在execute_simulation_order中被修改 + db.session.refresh(order) + + if order.status == 'FILLED': + return jsonify({ + 'success': True, + 'message': '订单执行成功,已成交', + 'data': { + 'order_no': order.order_no, + 'status': 'FILLED', + 'filled_price': float(order.filled_price) if order.filled_price else None, + 'filled_qty': order.filled_qty, + 'filled_amount': float(order.filled_amount) if order.filled_amount else None, + 'total_fee': float(order.total_fee) + } + }) + elif order.status == 'PENDING': + return jsonify({ + 'success': True, + 'message': '订单提交成功,等待行情数据成交', + 'data': { + 'order_no': order.order_no, + 'status': 'PENDING', + 'order_qty': order.order_qty, + 'order_price': float(order.order_price) if order.order_price else None + } + }) + else: + return jsonify({ + 'success': False, + 'error': order.reject_reason or '订单状态异常' + }), 400 + else: + return jsonify({ + 'success': False, + 'error': order.reject_reason or '订单执行失败' + }), 400 + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/simulation/cancel-order/', methods=['POST']) +@login_required +def cancel_simulation_order(order_id): + """撤销订单""" + try: + account = get_or_create_simulation_account(current_user.id) + + order = SimulationOrder.query.filter_by( + id=order_id, + account_id=account.id, + status='PENDING' + ).first() + + if not order: + return jsonify({'success': False, 'error': '订单不存在或无法撤销'}), 404 + + order.status = 'CANCELLED' + order.cancel_time = beijing_now() + + db.session.commit() + + return jsonify({ + 'success': True, + 'message': '订单已撤销' + }) + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/simulation/transactions', methods=['GET']) +@login_required +def get_simulation_transactions(): + """获取成交记录""" + try: + account = get_or_create_simulation_account(current_user.id) + + # 获取查询参数 + date_str = request.args.get('date') + limit = request.args.get('limit', 100, type=int) + + query = SimulationTransaction.query.filter_by(account_id=account.id) + + if date_str: + try: + date = datetime.strptime(date_str, '%Y-%m-%d').date() + start_time = datetime.combine(date, dt_time(0, 0, 0)) + end_time = datetime.combine(date, dt_time(23, 59, 59)) + query = query.filter(SimulationTransaction.transaction_time.between(start_time, end_time)) + except ValueError: + pass + + transactions = query.order_by(SimulationTransaction.transaction_time.desc()).limit(limit).all() + + transactions_data = [] + for trans in transactions: + transactions_data.append({ + 'id': trans.id, + 'transaction_no': trans.transaction_no, + 'stock_code': trans.stock_code, + 'stock_name': trans.stock_name, + 'transaction_type': trans.transaction_type, + 'transaction_price': float(trans.transaction_price), + 'transaction_qty': trans.transaction_qty, + 'transaction_amount': float(trans.transaction_amount), + 'commission': float(trans.commission), + 'stamp_tax': float(trans.stamp_tax), + 'transfer_fee': float(trans.transfer_fee), + 'total_fee': float(trans.total_fee), + 'transaction_time': trans.transaction_time.isoformat(), + 'settlement_date': trans.settlement_date.isoformat() if trans.settlement_date else None + }) + + return jsonify({ + 'success': True, + 'data': transactions_data + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +def get_simulation_statistics(): + """获取模拟交易统计""" + try: + account = get_or_create_simulation_account(current_user.id) + + # 获取统计时间范围 + days = request.args.get('days', 30, type=int) + end_date = beijing_now().date() + start_date = end_date - timedelta(days=days) + + # 查询日统计数据 + daily_stats = SimulationDailyStats.query.filter( + SimulationDailyStats.account_id == account.id, + SimulationDailyStats.stat_date >= start_date, + SimulationDailyStats.stat_date <= end_date + ).order_by(SimulationDailyStats.stat_date).all() + + # 查询总体统计 + total_transactions = SimulationTransaction.query.filter_by(account_id=account.id).count() + win_transactions = SimulationTransaction.query.filter( + SimulationTransaction.account_id == account.id, + SimulationTransaction.transaction_type == 'SELL' + ).all() + + win_count = 0 + total_profit = Decimal('0') + for trans in win_transactions: + # 查找对应的买入记录计算盈亏 + position = SimulationPosition.query.filter_by( + account_id=account.id, + stock_code=trans.stock_code + ).first() + if position and trans.transaction_price > position.avg_cost: + win_count += 1 + profit = (trans.transaction_price - position.avg_cost) * trans.transaction_qty if position else 0 + total_profit += profit + + # 构建日收益曲线 + daily_returns = [] + for stat in daily_stats: + daily_returns.append({ + 'date': stat.stat_date.isoformat(), + 'daily_profit': float(stat.daily_profit), + 'daily_profit_rate': float(stat.daily_profit_rate), + 'total_profit': float(stat.total_profit), + 'total_profit_rate': float(stat.total_profit_rate), + 'closing_assets': float(stat.closing_assets) + }) + + return jsonify({ + 'success': True, + 'data': { + 'summary': { + 'total_transactions': total_transactions, + 'win_count': win_count, + 'win_rate': (win_count / len(win_transactions) * 100) if win_transactions else 0, + 'total_profit': float(total_profit), + 'average_profit_per_trade': float(total_profit / len(win_transactions)) if win_transactions else 0 + }, + 'daily_returns': daily_returns + } + }) + + except Exception as e: + return jsonify({'success': False, 'error': str(e)}), 500 + + +@app.route('/api/simulation/t1-settlement', methods=['POST']) +@login_required +def trigger_t1_settlement(): + """手动触发T+1结算""" + try: + # 导入后台处理器的函数 + from simulation_background_processor import process_t1_settlement + + # 执行T+1结算 + process_t1_settlement() + + return jsonify({ + 'success': True, + 'message': 'T+1结算执行成功' + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/simulation/debug-positions', methods=['GET']) +@login_required +def debug_positions(): + """调试接口:查看持仓数据""" + try: + account = get_or_create_simulation_account(current_user.id) + + positions = SimulationPosition.query.filter_by(account_id=account.id).all() + + positions_data = [] + for position in positions: + positions_data.append({ + 'stock_code': position.stock_code, + 'stock_name': position.stock_name, + 'position_qty': position.position_qty, + 'available_qty': position.available_qty, + 'frozen_qty': position.frozen_qty, + 'avg_cost': float(position.avg_cost), + 'current_price': float(position.current_price or 0) + }) + + return jsonify({ + 'success': True, + 'data': positions_data + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/simulation/debug-transactions', methods=['GET']) +@login_required +def debug_transactions(): + """调试接口:查看成交记录数据""" + try: + account = get_or_create_simulation_account(current_user.id) + + transactions = SimulationTransaction.query.filter_by(account_id=account.id).all() + + transactions_data = [] + for trans in transactions: + transactions_data.append({ + 'id': trans.id, + 'transaction_no': trans.transaction_no, + 'stock_code': trans.stock_code, + 'stock_name': trans.stock_name, + 'transaction_type': trans.transaction_type, + 'transaction_price': float(trans.transaction_price), + 'transaction_qty': trans.transaction_qty, + 'transaction_amount': float(trans.transaction_amount), + 'commission': float(trans.commission), + 'stamp_tax': float(trans.stamp_tax), + 'transfer_fee': float(trans.transfer_fee), + 'total_fee': float(trans.total_fee), + 'transaction_time': trans.transaction_time.isoformat(), + 'settlement_date': trans.settlement_date.isoformat() if trans.settlement_date else None + }) + + return jsonify({ + 'success': True, + 'data': transactions_data, + 'count': len(transactions_data) + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/simulation/daily-settlement', methods=['POST']) +@login_required +def trigger_daily_settlement(): + """手动触发日结算""" + try: + # 导入后台处理器的函数 + from simulation_background_processor import generate_daily_stats + + # 执行日结算 + generate_daily_stats() + + return jsonify({ + 'success': True, + 'message': '日结算执行成功' + }) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + +@app.route('/api/simulation/reset', methods=['POST']) +@login_required +def reset_simulation_account(): + """重置模拟账户""" + try: + account = SimulationAccount.query.filter_by(user_id=current_user.id).first() + + if account: + # 删除所有相关数据 + SimulationPosition.query.filter_by(account_id=account.id).delete() + SimulationOrder.query.filter_by(account_id=account.id).delete() + SimulationTransaction.query.filter_by(account_id=account.id).delete() + SimulationDailyStats.query.filter_by(account_id=account.id).delete() + + # 重置账户数据 + account.available_cash = account.initial_capital + account.frozen_cash = Decimal('0') + account.position_value = Decimal('0') + account.total_assets = account.initial_capital + account.total_profit = Decimal('0') + account.total_profit_rate = Decimal('0') + account.daily_profit = Decimal('0') + account.daily_profit_rate = Decimal('0') + account.updated_at = beijing_now() + + db.session.commit() + + return jsonify({ + 'success': True, + 'message': '模拟账户已重置' + }) + + except Exception as e: + db.session.rollback() + return jsonify({'success': False, 'error': str(e)}), 500 + + +if __name__ == '__main__': + # 创建数据库表 + with app.app_context(): + try: + db.create_all() + # 安全地初始化订阅套餐 + initialize_subscription_plans_safe() + except Exception as e: + app.logger.error(f"数据库初始化失败: {e}") + + # 初始化事件轮询机制(WebSocket 推送) + initialize_event_polling() + + # 使用 socketio.run 替代 app.run 以支持 WebSocket + socketio.run(app, host='0.0.0.0', port=5001, debug=False, allow_unsafe_werkzeug=True) \ No newline at end of file diff --git a/src/components/Citation/CitedContent.js b/src/components/Citation/CitedContent.js index a9726ab3..0f1426ef 100644 --- a/src/components/Citation/CitedContent.js +++ b/src/components/Citation/CitedContent.js @@ -21,6 +21,8 @@ const { Text } = Typography; * @param {Object} props.prefixStyle - 前缀标签的自定义样式(可选) * @param {boolean} props.showAIBadge - 是否显示右上角 AI 标识,默认 true(可选) * @param {Object} props.containerStyle - 容器额外样式(可选) + * @param {string} props.textColor - 文本颜色,默认自动判断背景色(可选) + * @param {string} props.titleColor - 标题颜色,默认继承 textColor(可选) * * @example * */ const CitedContent = ({ @@ -38,7 +41,9 @@ const CitedContent = ({ prefix = '', prefixStyle = {}, showAIBadge = true, - containerStyle = {} + containerStyle = {}, + textColor, + titleColor }) => { // 处理数据 const processed = processCitationData(data); @@ -52,6 +57,19 @@ const CitedContent = ({ return null; } + // 自动判断文本颜色:如果容器背景是深色,使用浅色文本 + const bgColor = containerStyle.backgroundColor; + const isDarkBg = bgColor && ( + bgColor.includes('rgba(0,0,0') || + bgColor.includes('rgba(0, 0, 0') || + bgColor === 'transparent' || + bgColor.includes('#1A202C') || + bgColor.includes('#171923') + ); + + const finalTextColor = textColor || (isDarkBg ? '#E2E8F0' : '#262626'); + const finalTitleColor = titleColor || finalTextColor; + return (
- + {title}
@@ -105,6 +123,7 @@ const CitedContent = ({ fontWeight: 'bold', display: 'inline', marginRight: 4, + color: finalTextColor, ...prefixStyle }}> {prefix} @@ -114,7 +133,7 @@ const CitedContent = ({ {processed.segments.map((segment, index) => ( {/* 文本片段 */} - + {segment.text} @@ -126,7 +145,7 @@ const CitedContent = ({ {/* 在片段之间添加逗号分隔符(最后一个不加) */} {index < processed.segments.length - 1 && ( - + )} ))} diff --git a/src/views/Community/components/DynamicNewsDetail/StockListItem.js b/src/views/Community/components/DynamicNewsDetail/StockListItem.js index d10adfb5..f5a361e4 100644 --- a/src/views/Community/components/DynamicNewsDetail/StockListItem.js +++ b/src/views/Community/components/DynamicNewsDetail/StockListItem.js @@ -311,6 +311,7 @@ const StockListItem = ({ data={stock.relation_desc} title="" showAIBadge={true} + textColor={PROFESSIONAL_COLORS.text.primary} containerStyle={{ backgroundColor: 'transparent', borderRadius: '0', diff --git a/src/views/EventDetail/components/HistoricalEvents.js b/src/views/EventDetail/components/HistoricalEvents.js index 48cc4204..e5539ea4 100644 --- a/src/views/EventDetail/components/HistoricalEvents.js +++ b/src/views/EventDetail/components/HistoricalEvents.js @@ -344,6 +344,7 @@ const HistoricalEvents = ({ data={content} title="" showAIBadge={true} + textColor={PROFESSIONAL_COLORS.text.primary} containerStyle={{ backgroundColor: useColorModeValue('#f7fafc', 'rgba(45, 55, 72, 0.6)'), borderRadius: '8px', diff --git a/src/views/EventDetail/components/TransmissionChainAnalysis.js b/src/views/EventDetail/components/TransmissionChainAnalysis.js index 34009aa0..ab6896a8 100644 --- a/src/views/EventDetail/components/TransmissionChainAnalysis.js +++ b/src/views/EventDetail/components/TransmissionChainAnalysis.js @@ -972,6 +972,7 @@ const TransmissionChainAnalysis = ({ eventId }) => { ) : ( `${selectedNode.extra.description}(AI合成)` @@ -1081,7 +1082,8 @@ const TransmissionChainAnalysis = ({ eventId }) => { data={parent.transmission_mechanism} title="" prefix="机制:" - prefixStyle={{ fontSize: 12, color: '#666', fontWeight: 'bold' }} + prefixStyle={{ fontSize: 12, color: PROFESSIONAL_COLORS.text.secondary, fontWeight: 'bold' }} + textColor={PROFESSIONAL_COLORS.text.primary} containerStyle={{ marginTop: 8 }} showAIBadge={false} /> @@ -1136,7 +1138,8 @@ const TransmissionChainAnalysis = ({ eventId }) => { data={child.transmission_mechanism} title="" prefix="机制:" - prefixStyle={{ fontSize: 12, color: '#666', fontWeight: 'bold' }} + prefixStyle={{ fontSize: 12, color: PROFESSIONAL_COLORS.text.secondary, fontWeight: 'bold' }} + textColor={PROFESSIONAL_COLORS.text.primary} containerStyle={{ marginTop: 8 }} showAIBadge={false} />