6784 lines
263 KiB
Python
6784 lines
263 KiB
Python
# ===================== Gevent monkey patch (must run before anything else) =====================
# Detect whether we run under gevent/gunicorn; if so, apply the monkey patch
# before any other module captures the unpatched stdlib.
import os
import sys

# Either an explicit env var or an already-imported gevent module enables patching.
_USE_GEVENT = os.environ.get('USE_GEVENT', 'false').lower() == 'true'
if _USE_GEVENT or 'gevent' in sys.modules:
    try:
        from gevent import monkey
        monkey.patch_all()
        print("✅ Gevent monkey patch 已应用")
    except ImportError:
        # gevent not installed: run with the plain stdlib instead of failing.
        print("⚠️ Gevent 未安装,跳过 monkey patch")
# ===================== End of gevent monkey patch =====================
|
||
|
||
import csv
|
||
import logging
|
||
import random
|
||
import re
|
||
import math
|
||
import secrets
|
||
import string
|
||
|
||
import pytz
|
||
import requests
|
||
from flask_compress import Compress
|
||
from functools import wraps
|
||
from pathlib import Path
|
||
import time
|
||
from sqlalchemy import create_engine, text, func, or_, case, event, desc, asc
|
||
from flask import Flask, has_request_context, render_template, request, jsonify, redirect, url_for, flash, session, \
|
||
render_template_string, current_app, send_from_directory
|
||
|
||
# Flask 3.x compatibility shim: older flask-sqlalchemy releases still import
# flask._app_ctx_stack, which Flask 3 removed.
try:
    from flask import _app_ctx_stack
except ImportError:
    import flask
    from werkzeug.local import LocalStack
    import threading


    # Minimal LocalStack subclass exposing the legacy __ident_func__ attribute
    # that old flask-sqlalchemy reads.
    class CompatLocalStack(LocalStack):
        @property
        def __ident_func__(self):
            # Identity function for the current execution context:
            # prefer greenlets (coroutines) when available, else threads.
            try:
                from greenlet import getcurrent
                return getcurrent
            except ImportError:
                return threading.get_ident


    # Re-install the attribute old flask-sqlalchemy expects on the flask module.
    flask._app_ctx_stack = CompatLocalStack()
|
||
|
||
from flask_sqlalchemy import SQLAlchemy
|
||
from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user
|
||
from flask_mail import Mail, Message
|
||
from itsdangerous import URLSafeTimedSerializer
|
||
from flask_migrate import Migrate
|
||
from flask_session import Session # type: ignore
|
||
from sqlalchemy.dialects.mysql.base import MySQLDialect
|
||
from sqlalchemy.dialects.postgresql import JSONB
|
||
from werkzeug.utils import secure_filename
|
||
from PIL import Image
|
||
from datetime import datetime, timedelta, time as dt_time
|
||
from werkzeug.security import generate_password_hash, check_password_hash
|
||
import json
|
||
from clickhouse_driver import Client as Cclient
|
||
from queue import Queue, Empty, Full
|
||
from threading import Lock, RLock
|
||
from contextlib import contextmanager
|
||
import jwt
|
||
from docx import Document
|
||
from tencentcloud.common import credential
|
||
from tencentcloud.common.profile.client_profile import ClientProfile
|
||
from tencentcloud.common.profile.http_profile import HttpProfile
|
||
from tencentcloud.sms.v20210111 import sms_client, models
|
||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
|
||
|
||
# SQLAlchemy engines for the three MySQL databases used by this app.
# NOTE(review): credentials are hard-coded in source; they should be moved to
# environment variables or a secrets store.
engine = create_engine("mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock", echo=False, pool_size=20,
                       max_overflow=50)
engine_med = create_engine("mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/med", echo=False)
engine_2 = create_engine("mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/valuefrontier", echo=False)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ===================== ClickHouse connection pool (enhanced) =====================
import threading as _threading
import atexit


class ClickHouseConnectionPool:
    """
    ClickHouse connection pool (enhanced).

    - Reuses connections instead of creating/destroying one per query
    - Connection-acquisition timeout and health checks
    - Maximum connection lifetime (guards against zombie connections)
    - Background thread that reaps expired connections
    - Automatic retry of failed queries
    - Thread-safe
    """

    def __init__(self, host, port, user, password, database,
                 pool_size=10, max_overflow=10,
                 connection_timeout=10, query_timeout=30,
                 health_check_interval=60,
                 max_connection_lifetime=300,  # max lifetime of a connection (seconds)
                 cleanup_interval=60,  # reaper interval (seconds)
                 max_retries=3):  # max query retries
        """
        Initialise the connection pool.

        Args:
            host: ClickHouse host address
            port: ClickHouse port
            user: username
            password: password
            database: database name
            pool_size: core pool size (connections kept warm)
            max_overflow: extra burst connections (total = pool_size + max_overflow)
            connection_timeout: timeout for acquiring a connection (seconds)
            query_timeout: per-query timeout (seconds)
            health_check_interval: health-check interval (seconds)
            max_connection_lifetime: max connection lifetime (seconds); older
                connections are closed and recreated
            cleanup_interval: background reaper interval (seconds)
            max_retries: max retries for a failed query
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.pool_size = pool_size
        self.max_overflow = max_overflow
        self.connection_timeout = connection_timeout
        self.query_timeout = query_timeout
        self.health_check_interval = health_check_interval
        self.max_connection_lifetime = max_connection_lifetime
        self.cleanup_interval = cleanup_interval
        self.max_retries = max_retries

        # Queue of idle connections.
        self._pool = Queue(maxsize=pool_size + max_overflow)
        # Number of live connections (idle + checked out).
        self._active_connections = 0
        # Re-entrant lock guarding counters (also taken inside _create_connection).
        self._lock = RLock()
        # conn id -> last-used timestamp.
        self._last_used = {}
        # conn id -> creation timestamp.
        self._created_at = {}
        # Pool-shutdown flag.
        self._closed = False
        # Background reaper thread state.
        self._cleanup_thread = None
        self._cleanup_stop_event = _threading.Event()

        # Core connections are created lazily on first use.
        self._initialized = False
        # The reaper thread is also started lazily (starting threads before a
        # fork, e.g. under gunicorn, causes problems).
        self._cleanup_thread_started = False
        logger.info(f"ClickHouse 连接池配置完成: pool_size={pool_size}, max_overflow={max_overflow}, "
                    f"max_lifetime={max_connection_lifetime}s")

        # Close everything at interpreter exit.
        atexit.register(self.close_all)

    def _start_cleanup_thread(self):
        """Start the background reaper (deferred to the pool's first use)."""
        if self._cleanup_thread_started or self._closed:
            return

        with self._lock:
            # Re-check under the lock: another thread may have started it.
            if self._cleanup_thread_started or self._closed:
                return

            def cleanup_worker():
                # Wake every cleanup_interval seconds until the stop event fires.
                while not self._cleanup_stop_event.wait(self.cleanup_interval):
                    if self._closed:
                        break
                    try:
                        self._cleanup_expired_connections()
                    except Exception as e:
                        logger.error(f"清理连接时出错: {e}")

            self._cleanup_thread = _threading.Thread(target=cleanup_worker, daemon=True, name="CH-Pool-Cleanup")
            self._cleanup_thread.start()
            self._cleanup_thread_started = True
            logger.debug("后台清理线程已启动")

    def _cleanup_expired_connections(self):
        """Close idle connections that are expired or fail a health check."""
        current_time = time.time()
        cleaned_count = 0

        # Connections that survive the sweep and go back into the pool.
        valid_connections = []

        # Drain the queue, inspecting each idle connection.
        while True:
            try:
                conn = self._pool.get_nowait()
                conn_id = id(conn)
                created_at = self._created_at.get(conn_id, current_time)
                last_used = self._last_used.get(conn_id, current_time)

                # How long the connection has existed / sat idle.
                lifetime = current_time - created_at
                idle_time = current_time - last_used

                if lifetime > self.max_connection_lifetime:
                    # Connection has lived too long: recycle it.
                    logger.debug(f"连接 {conn_id} 存活时间 {lifetime:.0f}s 超过限制,关闭")
                    self._close_connection(conn)
                    cleaned_count += 1
                elif idle_time > self.health_check_interval * 3:
                    # Long-idle connection: verify it is still usable.
                    if not self._check_connection_health(conn):
                        self._close_connection(conn)
                        cleaned_count += 1
                    else:
                        valid_connections.append(conn)
                else:
                    valid_connections.append(conn)
            except Empty:
                break

        # Return survivors to the pool.
        for conn in valid_connections:
            try:
                self._pool.put_nowait(conn)
            except Full:
                self._close_connection(conn)

        if cleaned_count > 0:
            logger.info(f"清理了 {cleaned_count} 个过期连接,当前活跃连接数: {self._active_connections}")

    def _init_pool(self):
        """Lazily initialise the pool: start the reaper and pre-create one core connection (non-blocking)."""
        if self._initialized:
            return

        with self._lock:
            if self._initialized:
                return

            # Start the reaper thread (deferred to first use).
            self._start_cleanup_thread()

            # Pre-create only a single connection; the rest are made on demand.
            init_count = min(1, self.pool_size)
            for i in range(init_count):
                try:
                    conn = self._create_connection()
                    if conn:
                        self._pool.put(conn)
                        logger.info(f"预创建 ClickHouse 连接 {i+1}/{init_count} 成功")
                except Exception as e:
                    logger.warning(f"预创建 ClickHouse 连接失败 ({i+1}/{init_count}): {e}")
                    # Warm-up failure must not block startup; connections are
                    # created on demand later.
                    break

            self._initialized = True

    def _create_connection(self):
        """Create a new ClickHouse client and register its bookkeeping entries."""
        try:
            client = Cclient(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                database=self.database,
                connect_timeout=self.connection_timeout,
                send_receive_timeout=self.query_timeout,
                sync_request_timeout=self.query_timeout,
                settings={
                    'max_execution_time': self.query_timeout,
                    'connect_timeout': self.connection_timeout,
                }
            )
            conn_id = id(client)
            current_time = time.time()
            self._created_at[conn_id] = current_time
            self._last_used[conn_id] = current_time
            with self._lock:
                self._active_connections += 1
            logger.debug(f"创建新的 ClickHouse 连接: {conn_id}")
            return client
        except Exception as e:
            logger.error(f"创建 ClickHouse 连接失败: {e}")
            raise

    def _check_connection_health(self, conn):
        """Return True if the connection is healthy enough to hand out (ping guarded by timeouts)."""
        try:
            conn_id = id(conn)
            last_used = self._last_used.get(conn_id, 0)
            created_at = self._created_at.get(conn_id, 0)
            current_time = time.time()

            # Expired connections are never healthy.
            if current_time - created_at > self.max_connection_lifetime:
                logger.debug(f"连接 {conn_id} 超过最大存活时间,标记为不健康")
                return False

            # Ping only connections that have been idle past the check interval.
            if current_time - last_used > self.health_check_interval:
                # Trivial query doubles as a liveness probe.
                conn.execute("SELECT 1")
                self._last_used[conn_id] = current_time
                logger.debug(f"连接 {conn_id} 健康检查通过")

            return True
        except Exception as e:
            logger.warning(f"连接健康检查失败: {e}")
            return False

    def _close_connection(self, conn):
        """Close a connection and drop its bookkeeping entries."""
        if conn is None:
            return
        try:
            conn_id = id(conn)
            try:
                conn.disconnect()
            except:
                pass  # ignore errors while disconnecting
            self._last_used.pop(conn_id, None)
            self._created_at.pop(conn_id, None)
            with self._lock:
                self._active_connections = max(0, self._active_connections - 1)
            logger.debug(f"关闭 ClickHouse 连接: {conn_id}")
        except Exception as e:
            logger.warning(f"关闭连接时出错: {e}")

    def get_connection(self, timeout=None):
        """
        Acquire a connection from the pool.

        Args:
            timeout: acquisition timeout in seconds; defaults to connection_timeout

        Returns:
            a ClickHouse client connection

        Raises:
            TimeoutError: no connection became available in time
            RuntimeError: the pool has been closed
            Exception: connection creation failed
        """
        if self._closed:
            raise RuntimeError("连接池已关闭")

        # Lazy initialisation on first use.
        if not self._initialized:
            self._init_pool()

        timeout = timeout or self.connection_timeout
        start_time = time.time()

        while True:
            elapsed = time.time() - start_time
            if elapsed >= timeout:
                raise TimeoutError(f"获取 ClickHouse 连接超时 (timeout={timeout}s)")

            remaining_timeout = timeout - elapsed

            # First, try to grab an idle connection from the queue.
            try:
                conn = self._pool.get(block=True, timeout=min(remaining_timeout, 1.0))

                # Validate before handing it out.
                if self._check_connection_health(conn):
                    self._last_used[id(conn)] = time.time()
                    return conn
                else:
                    # Unhealthy: discard and loop for another one.
                    self._close_connection(conn)
                    continue

            except Empty:
                # Queue empty: create an overflow connection if allowed.
                with self._lock:
                    if self._active_connections < self.pool_size + self.max_overflow:
                        try:
                            return self._create_connection()
                        except Exception as e:
                            logger.error(f"创建溢出连接失败: {e}")
                            # Do not fail immediately; keep waiting within the timeout.

                # Brief pause before retrying.
                time.sleep(0.1)

    def release_connection(self, conn):
        """
        Return a connection to the pool (or close it if expired / pool full).

        Args:
            conn: the connection to release
        """
        if conn is None:
            return

        if self._closed:
            self._close_connection(conn)
            return

        conn_id = id(conn)
        created_at = self._created_at.get(conn_id, 0)

        # Expired connections are closed instead of being pooled again.
        if time.time() - created_at > self.max_connection_lifetime:
            logger.debug(f"连接 {conn_id} 超过最大存活时间,关闭而不放回池中")
            self._close_connection(conn)
            return

        self._last_used[conn_id] = time.time()

        try:
            self._pool.put(conn, block=False)
            logger.debug(f"连接 {conn_id} 已释放回连接池")
        except Full:
            # Pool already full: drop the surplus connection.
            logger.debug(f"连接池已满,关闭多余连接: {conn_id}")
            self._close_connection(conn)

    @contextmanager
    def connection(self, timeout=None):
        """
        Context-manager helper around get/release.

        Usage:
            with pool.connection() as conn:
                result = conn.execute("SELECT * FROM table")
        """
        conn = None
        try:
            conn = self.get_connection(timeout)
            yield conn
        except Exception as e:
            # On error, probe the connection and discard it if broken.
            if conn:
                try:
                    # Trivial query tells us whether the connection survived.
                    conn.execute("SELECT 1")
                except:
                    # Connection is unusable; close it so it is not released.
                    self._close_connection(conn)
                    conn = None
            raise
        finally:
            if conn:
                self.release_connection(conn)

    def execute(self, query, params=None, timeout=None):
        """
        Execute a query with automatic connection handling and retries.

        Args:
            query: SQL statement
            params: query parameters
            timeout: per-attempt timeout

        Returns:
            the query result
        """
        last_error = None
        for retry in range(self.max_retries):
            try:
                with self.connection(timeout) as conn:
                    return conn.execute(query, params)
            except (TimeoutError, RuntimeError) as e:
                # Pool-level failures must not be retried.
                raise
            except Exception as e:
                last_error = e
                logger.warning(f"查询执行失败 (重试 {retry + 1}/{self.max_retries}): {e}")
                if retry < self.max_retries - 1:
                    time.sleep(0.5 * (retry + 1))  # back off a little more each retry

        raise last_error

    def get_pool_status(self):
        """Return a snapshot of pool metrics (useful for monitoring endpoints)."""
        return {
            'pool_size': self.pool_size,
            'max_overflow': self.max_overflow,
            'active_connections': self._active_connections,
            'available_connections': self._pool.qsize(),
            'max_connections': self.pool_size + self.max_overflow,
            'max_connection_lifetime': self.max_connection_lifetime,
            'initialized': self._initialized,
            'closed': self._closed
        }

    def close_all(self):
        """Shut down the pool: stop the reaper thread and close every pooled connection."""
        self._closed = True
        self._cleanup_stop_event.set()  # ask the reaper thread to stop

        # Give the reaper a moment to exit.
        if self._cleanup_thread and self._cleanup_thread.is_alive():
            self._cleanup_thread.join(timeout=2)

        # Close every idle connection still in the queue.
        while not self._pool.empty():
            try:
                conn = self._pool.get_nowait()
                self._close_connection(conn)
            except Empty:
                break
        logger.info("ClickHouse 连接池已关闭所有连接")
|
||
|
||
|
||
# Global ClickHouse pool, created lazily on first use (keeps worker forks safe).
clickhouse_pool = None
_pool_lock = Lock()


def _init_clickhouse_pool():
    """Lazily create the global ClickHouse pool (double-checked locking)."""
    global clickhouse_pool
    if clickhouse_pool is None:
        with _pool_lock:
            # Re-check under the lock: another thread may have created it.
            if clickhouse_pool is None:
                # NOTE(review): credentials are hard-coded; move to env/config.
                clickhouse_pool = ClickHouseConnectionPool(
                    host='222.128.1.157',
                    port=18000,
                    user='default',
                    password='Zzl33818!',
                    database='stock',
                    pool_size=5,  # core connections
                    max_overflow=20,  # burst connections; 25 concurrent total
                    connection_timeout=15,  # connection acquisition timeout (s)
                    query_timeout=60,  # query timeout (s); generous for heavy queries
                    health_check_interval=30,  # health-check every 30 s
                    max_connection_lifetime=300,  # recycle connections after 5 min
                    cleanup_interval=60,  # reaper runs every 60 s
                    max_retries=3  # retry failed queries up to 3 times
                )
    return clickhouse_pool
# ===================== End of ClickHouse connection pool =====================
|
||
app = Flask(__name__)
# Enable response compression for the MIME types configured below.
Compress(app)
UPLOAD_FOLDER = 'static/uploads/avatars'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16MB max file size
# Configure Flask-Compress
app.config['COMPRESS_ALGORITHM'] = ['gzip', 'br']
app.config['COMPRESS_MIMETYPES'] = [
    'text/html',
    'text/css',
    'text/xml',
    'application/json',
    'application/javascript',
    'application/x-javascript'
]
|
||
|
||
# ===================== Token storage (shared across workers) =====================
class TokenStore:
    """Token storage backed by Redis when available (shared across workers),
    falling back to an in-process dict for single-worker deployments."""

    def __init__(self):
        self._redis_client = None
        self._memory_store = {}
        self._prefix = 'vf_token:'
        self._initialized = False

    def _ensure_initialized(self):
        """Connect to Redis lazily so the connection is opened after any fork."""
        if self._initialized:
            return
        # Mark first, so a failed connection attempt is not repeated per call.
        self._initialized = True

        redis_url = os.environ.get('REDIS_URL', 'redis://localhost:6379/0')
        try:
            import redis
            client = redis.from_url(redis_url)
            client.ping()
            self._redis_client = client
            logger.info(f"✅ Token 存储: Redis ({redis_url})")
        except Exception as e:
            logger.warning(f"⚠️ Redis 不可用 ({e}),Token 使用内存存储(多 worker 模式下会有问题!)")
            self._redis_client = None

    def get(self, token):
        """Return the payload stored for *token*, or None when absent."""
        self._ensure_initialized()
        if not self._redis_client:
            return self._memory_store.get(token)
        try:
            raw = self._redis_client.get(f"{self._prefix}{token}")
            return json.loads(raw) if raw else None
        except Exception as e:
            logger.error(f"Redis get error: {e}")
            return self._memory_store.get(token)

    def set(self, token, data, expire_seconds=30*24*3600):
        """Store *data* under *token*; Redis entries also get a TTL."""
        self._ensure_initialized()
        if self._redis_client:
            try:
                # JSON cannot encode datetimes, so serialise 'expires' first.
                payload = dict(data)
                expires = payload.get('expires')
                if isinstance(expires, datetime):
                    payload['expires'] = expires.isoformat()
                self._redis_client.setex(
                    f"{self._prefix}{token}",
                    expire_seconds,
                    json.dumps(payload)
                )
                return
            except Exception as e:
                logger.error(f"Redis set error: {e}")
        # Memory fallback (also used when the Redis write failed).
        self._memory_store[token] = data

    def delete(self, token):
        """Remove *token* from whichever backend currently holds it."""
        self._ensure_initialized()
        if self._redis_client:
            try:
                self._redis_client.delete(f"{self._prefix}{token}")
                return
            except Exception as e:
                logger.error(f"Redis delete error: {e}")
        self._memory_store.pop(token, None)

    def __contains__(self, token):
        """Enable `token in store` membership checks."""
        return self.get(token) is not None
|
||
|
# Shared token store instance (Redis-backed when available).
user_tokens = TokenStore()

# NOTE(review): the secret keys and mail/SMS/WeChat credentials below are
# hard-coded in source; move them to environment variables or a secrets store.
app.config['SECRET_KEY'] = 'vf7891574233241'
app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://root:Zzl5588161!@222.128.1.157:33060/stock?charset=utf8mb4'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

app.config['JSON_AS_ASCII'] = False
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True

# Mail settings (Tencent Exmail over SSL)
app.config['MAIL_SERVER'] = 'smtp.exmail.qq.com'
app.config['MAIL_PORT'] = 465
app.config['MAIL_USE_SSL'] = True
app.config['MAIL_USERNAME'] = 'admin@valuefrontier.cn'
app.config['MAIL_PASSWORD'] = 'QYncRu6WUdASvTg4'

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = MAX_CONTENT_LENGTH

# Tencent Cloud SMS settings
SMS_SECRET_ID = 'AKID2we9TacdTAhCjCSYTErHVimeJo9Yr00s'
SMS_SECRET_KEY = 'pMlBWijlkgT9fz5ziEXdWEnAPTJzRfkf'
SMS_SDK_APP_ID = "1400972398"
SMS_SIGN_NAME = "价值前沿科技"
SMS_TEMPLATE_REGISTER = "2386557"  # registration template id
SMS_TEMPLATE_LOGIN = "2386540"  # login template id
# Pending SMS verification codes (in-process only, not shared across workers).
verification_codes = {}

# WeChat mini-program / OAuth settings
app.config['WECHAT_APP_ID'] = 'wx0edeaab76d4fa414'
app.config['WECHAT_APP_SECRET'] = '0d0c70084f05a8c1411f6b89da7e815d'
app.config['BASE_URL'] = 'http://43.143.189.195:5002'
app.config['WECHAT_REDIRECT_URI'] = f"{app.config['BASE_URL']}/api/wechat/callback"
WECHAT_APP_ID = 'wx0edeaab76d4fa414'
WECHAT_APP_SECRET = '0d0c70084f05a8c1411f6b89da7e815d'
JWT_SECRET_KEY = 'vfllmgreat33818!'  # TODO: replace with a securely generated secret
JWT_ALGORITHM = 'HS256'
JWT_EXPIRATION_HOURS = 24 * 7  # token valid for 7 days

# Session configuration
# Prefer Redis (shared across workers); fall back to the filesystem backend.
_REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/0')
_USE_REDIS_SESSION = os.environ.get('USE_REDIS_SESSION', 'true').lower() == 'true'

try:
    if _USE_REDIS_SESSION:
        import redis
        # Probe the Redis connection before committing to it.
        _redis_client = redis.from_url(_REDIS_URL)
        _redis_client.ping()

        app.config['SESSION_TYPE'] = 'redis'
        app.config['SESSION_REDIS'] = _redis_client
        app.config['SESSION_KEY_PREFIX'] = 'vf_session:'
        logger.info(f"✅ Session 存储: Redis ({_REDIS_URL})")
    else:
        raise Exception("Redis session disabled by config")
except Exception as e:
    # Redis unavailable: fall back to per-machine filesystem sessions.
    logger.warning(f"⚠️ Redis 不可用 ({e}),使用文件系统 session(多 worker 模式下可能不稳定)")
    app.config['SESSION_TYPE'] = 'filesystem'
    app.config['SESSION_FILE_DIR'] = os.path.join(os.path.dirname(__file__), 'flask_session')
    os.makedirs(app.config['SESSION_FILE_DIR'], exist_ok=True)

app.config['SESSION_PERMANENT'] = True
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(days=7)  # sessions valid 7 days
app.config['SESSION_COOKIE_SECURE'] = False  # set True when serving over HTTPS
app.config['SESSION_COOKIE_HTTPONLY'] = True
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'

# Cache directory setup
CACHE_DIR = Path('cache')
CACHE_DIR.mkdir(exist_ok=True)

# Memory management constants
MAX_MEMORY_PERCENT = 75
MEMORY_CHECK_INTERVAL = 300
MAX_CACHE_ITEMS = 50

# Shenyin & Wanguo (SW) industry-classification cache, warmed at startup so
# requests avoid a per-call database lookup.
# Shape: {industry_level: {industry_name: [code_prefix1, code_prefix2, ...]}}
SYWG_INDUSTRY_CACHE = {
    2: {},  # level2: tier-1 industries
    3: {},  # level3: tier-2 industries
    4: {},  # level4: tier-3 industries
    5: {}  # level5: tier-4 industries
}

# Initialise Flask extensions
db = SQLAlchemy(app)
mail = Mail(app)
login_manager = LoginManager(app)
login_manager.login_view = 'login'
serializer = URLSafeTimedSerializer(app.config['SECRET_KEY'])

migrate = Migrate(app, db)

DOMAIN = 'http://43.143.189.195:5002'

JWT_SECRET = 'Llmgreat123'
JWT_EXPIRES_SECONDS = 3600  # 1-hour validity

Session(app)
||
|
||
|
||
def token_required(f):
    """Decorator for endpoints that require Bearer-token authentication.

    Reads the token from the `Authorization: Bearer <token>` header, validates
    it against the shared token store, checks expiry, and attaches the resolved
    user to the request context (`request.user`, `request.current_user_id`)
    before invoking the wrapped view. Returns 401/404 JSON errors otherwise.
    """
    # `wraps` is already imported at module level; the previous function-local
    # `from functools import wraps` was redundant.
    @wraps(f)
    def decorated_function(*args, **kwargs):
        # Extract the bearer token from the Authorization header.
        token = None
        auth_header = request.headers.get('Authorization')
        if auth_header and auth_header.startswith('Bearer '):
            token = auth_header[7:]

        if not token:
            return jsonify({'message': '缺少认证token'}), 401

        token_data = user_tokens.get(token)
        if not token_data:
            return jsonify({'message': 'Token无效', 'code': 401}), 401

        # Expiry may be stored as an ISO string (Redis) or a datetime (memory);
        # a record without 'expires' is treated as non-expiring instead of
        # crashing with a KeyError.
        expires = token_data.get('expires')
        if isinstance(expires, str):
            expires = datetime.fromisoformat(expires)
        if expires is not None and expires < datetime.now():
            user_tokens.delete(token)
            return jsonify({'message': 'Token已过期'}), 401

        # Resolve the user and expose it to the wrapped view.
        user = User.query.get(token_data['user_id'])
        if not user:
            return jsonify({'message': '用户不存在'}), 404

        request.user = user
        request.current_user_id = token_data['user_id']

        return f(*args, **kwargs)

    return decorated_function
|
||
|
||
|
||
def beijing_now():
    """Return the current wall-clock time in Beijing (Asia/Shanghai) as an
    aware datetime, using pytz for the timezone handling."""
    shanghai = pytz.timezone('Asia/Shanghai')
    return datetime.now(shanghai)
|
||
|
||
|
||
# ============================================
# Subscription module (kept consistent with app.py)
# ============================================
class UserSubscription(db.Model):
    """Per-user subscription record, independent of the existing User table."""
    __tablename__ = 'user_subscriptions'

    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    user_id = db.Column(db.Integer, nullable=False, unique=True, index=True)  # link to User.id (no FK)
    subscription_type = db.Column(db.String(10), nullable=False, default='free')  # 'free' / 'pro' / 'max'
    subscription_status = db.Column(db.String(20), nullable=False, default='active')
    start_date = db.Column(db.DateTime, nullable=True)
    end_date = db.Column(db.DateTime, nullable=True)  # None means no expiry
    billing_cycle = db.Column(db.String(10), nullable=True)
    auto_renewal = db.Column(db.Boolean, nullable=False, default=False)
    created_at = db.Column(db.DateTime, default=beijing_now)
    updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now)

    def is_active(self):
        """Return True when the subscription currently grants access."""
        if self.subscription_status != 'active':
            return False
        if self.subscription_type == 'free':
            return True
        if self.end_date:
            # The DB stores naive datetimes; localise before comparing with
            # the aware value from beijing_now().
            beijing_tz = pytz.timezone('Asia/Shanghai')
            end_date_aware = self.end_date if self.end_date.tzinfo else beijing_tz.localize(self.end_date)
            return beijing_now() <= end_date_aware
        return True

    def days_left(self):
        """Return the whole days remaining (999 = effectively unlimited, 0 = expired/error)."""
        if not self.is_active():
            return 0
        if self.subscription_type == 'free':
            return 999
        if not self.end_date:
            return 999
        try:
            now = beijing_now()
            # Localise the naive DB datetime before subtracting (see is_active).
            beijing_tz = pytz.timezone('Asia/Shanghai')
            end_date_aware = self.end_date if self.end_date.tzinfo else beijing_tz.localize(self.end_date)
            delta = end_date_aware - now
            return max(0, delta.days)
        except Exception as e:
            return 0

    def to_dict(self):
        """Serialise the subscription for JSON API responses."""
        return {
            'type': self.subscription_type,
            'status': self.subscription_status,
            'is_active': self.is_active(),
            'start_date': self.start_date.isoformat() if self.start_date else None,
            'end_date': self.end_date.isoformat() if self.end_date else None,
            'days_left': self.days_left(),
            'billing_cycle': self.billing_cycle,
            'auto_renewal': self.auto_renewal
        }
|
||
|
||
|
||
# ============================================
# Subscription-level helpers
# ============================================
def get_user_subscription_safe(user_id):
    """
    Fetch the user's subscription, creating a default free record if missing.

    Never raises: on any database failure the session is rolled back and a
    transient (unsaved) free subscription object is returned instead.

    :param user_id: user ID
    :return: a UserSubscription instance (persisted, or transient free fallback)
    """
    try:
        subscription = UserSubscription.query.filter_by(user_id=user_id).first()
        if not subscription:
            # First access for this user: provision a default free row.
            subscription = UserSubscription(
                user_id=user_id,
                subscription_type='free',
                subscription_status='active'
            )
            db.session.add(subscription)
            db.session.commit()
        return subscription
    except Exception as e:
        # Roll back so a failed flush/commit does not poison the session for
        # subsequent requests (previously missing).
        db.session.rollback()
        logger.error(f"获取用户订阅信息失败: {e}")
        # Return a transient free subscription (not saved to the database).
        temp_sub = UserSubscription(
            user_id=user_id,
            subscription_type='free',
            subscription_status='active'
        )
        return temp_sub
|
||
|
||
|
||
def _get_current_subscription_info():
    """
    Return the current user's subscription info as a dict.

    Anonymous users and any failure are treated as an active free
    subscription. In the mini-program flow the user id comes from
    request.current_user_id (set by token_required).
    """
    try:
        user_id = getattr(request, 'current_user_id', None)
        if not user_id:
            return {
                'type': 'free',
                'status': 'active',
                'is_active': True
            }
        sub = get_user_subscription_safe(user_id)
        return {
            'type': sub.subscription_type,
            'status': sub.subscription_status,
            'is_active': sub.is_active(),
            'start_date': sub.start_date.isoformat() if sub.start_date else None,
            'end_date': sub.end_date.isoformat() if sub.end_date else None,
            'days_left': sub.days_left()
        }
    except Exception as e:
        # Route failures through the configured logger (was a bare print).
        logger.error(f"获取订阅信息异常: {e}")
        return {
            'type': 'free',
            'status': 'active',
            'is_active': True
        }
|
||
|
||
|
||
def _subscription_level(sub_type):
|
||
"""将订阅类型映射到等级数值,free=0, pro=1, max=2。"""
|
||
mapping = {'free': 0, 'pro': 1, 'max': 2}
|
||
return mapping.get((sub_type or 'free').lower(), 0)
|
||
|
||
|
||
def _has_required_level(required: str) -> bool:
    """Return True when the current user's active subscription meets *required*."""
    info = _get_current_subscription_info()
    # Inactive subscriptions never qualify, regardless of their tier.
    is_active = info.get('is_active', True)
    return bool(is_active) and _subscription_level(info.get('type')) >= _subscription_level(required)
|
||
|
||
|
||
# ============================================
# Permission decorators
# ============================================
def subscription_required(level='pro'):
    """
    Subscription-level decorator for mini-program endpoints.

    Usage:
        @subscription_required('pro')  # Pro or Max members
        @subscription_required('max')  # Max members only

    Note: apply after token_required so request.current_user_id is populated.
    """
    # `wraps` is already imported at module level; the previous function-local
    # `from functools import wraps` was redundant.
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            if not _has_required_level(level):
                current_info = _get_current_subscription_info()
                current_type = current_info.get('type', 'free')
                is_active = current_info.get('is_active', False)

                # Distinguish "expired" from "tier too low" in the error payload.
                if not is_active:
                    return jsonify({
                        'success': False,
                        'error': '您的订阅已过期,请续费后继续使用',
                        'error_code': 'SUBSCRIPTION_EXPIRED',
                        'current_subscription': current_type,
                        'required_subscription': level
                    }), 403

                return jsonify({
                    'success': False,
                    'error': f'此功能需要 {level.upper()} 或更高等级会员',
                    'error_code': 'SUBSCRIPTION_REQUIRED',
                    'current_subscription': current_type,
                    'required_subscription': level
                }), 403

            return f(*args, **kwargs)

        return decorated_function

    return decorator
|
||
|
||
|
||
def pro_or_max_required(f):
    """
    Shortcut decorator requiring a Pro or Max subscription (mini-program use).

    Equivalent to @subscription_required('pro'), but with a mini-program
    specific error payload.
    """
    # `wraps` is already imported at module level; the previous function-local
    # `from functools import wraps` was redundant.
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if not _has_required_level('pro'):
            current_info = _get_current_subscription_info()
            current_type = current_info.get('type', 'free')

            return jsonify({
                'success': False,
                'error': '小程序功能仅对 Pro 和 Max 会员开放',
                'error_code': 'MINIPROGRAM_PRO_REQUIRED',
                'current_subscription': current_type,
                'required_subscription': 'pro',
                'message': '请升级到 Pro 或 Max 会员以使用小程序完整功能'
            }), 403

        return f(*args, **kwargs)

    return decorated_function
|
||
|
||
|
||
class User(UserMixin, db.Model):
    """User model: account, profile, community and notification settings."""
    id = db.Column(db.Integer, primary_key=True)

    # Basic account info (required at registration)
    username = db.Column(db.String(80), unique=True, nullable=False)  # username
    email = db.Column(db.String(120), unique=True, nullable=False)  # email address
    password_hash = db.Column(db.String(128), nullable=False)  # password hash
    email_confirmed = db.Column(db.Boolean, default=False)  # email verified?
    wechat_union_id = db.Column(db.String(100), unique=True)  # WeChat UnionID
    wechat_open_id = db.Column(db.String(100))  # WeChat OpenID

    # Account status
    created_at = db.Column(db.DateTime, default=beijing_now)  # registration time
    last_seen = db.Column(db.DateTime, default=beijing_now)  # last activity time
    status = db.Column(db.String(20), default='active')  # account status: active/banned/deleted

    # Profile (optional; completed later in the profile centre)
    nickname = db.Column(db.String(30))  # community nickname
    avatar_url = db.Column(db.String(200))  # avatar URL
    banner_url = db.Column(db.String(200))  # profile-page banner image
    bio = db.Column(db.String(200))  # short bio
    gender = db.Column(db.String(10))  # gender
    birth_date = db.Column(db.Date)  # birthday
    location = db.Column(db.String(100))  # location

    # Contact details (optional)
    phone = db.Column(db.String(20))  # phone number
    wechat_id = db.Column(db.String(80))  # WeChat handle

    # Real-name verification (optional)
    real_name = db.Column(db.String(30))  # real name
    id_number = db.Column(db.String(18))  # national ID number (stored encrypted)
    is_verified = db.Column(db.Boolean, default=False)  # real-name verified?
    verify_time = db.Column(db.DateTime)  # verification time

    # Investment-related info (optional)
    trading_experience = db.Column(db.String(200))  # years of trading experience
    investment_style = db.Column(db.String(50))  # investment style
    risk_preference = db.Column(db.String(20))  # risk appetite
    investment_amount = db.Column(db.String(20))  # investment size
    preferred_markets = db.Column(db.String(200), default='[]')  # preferred markets, JSON list

    # Community stats (maintained by the system)
    user_level = db.Column(db.Integer, default=1)  # user level
    reputation_score = db.Column(db.Integer, default=0)  # reputation points
    contribution_point = db.Column(db.Integer, default=0)  # contribution points
    post_count = db.Column(db.Integer, default=0)  # number of posts
    comment_count = db.Column(db.Integer, default=0)  # number of comments
    follower_count = db.Column(db.Integer, default=0)  # follower count
    following_count = db.Column(db.Integer, default=0)  # following count

    # Creator info (optional)
    is_creator = db.Column(db.Boolean, default=False)  # is a content creator?
    creator_type = db.Column(db.String(20))  # creator type
    creator_tags = db.Column(db.String(200), default='[]')  # creator tags, JSON list

    # System settings
    email_notifications = db.Column(db.Boolean, default=True)  # email notifications
    sms_notifications = db.Column(db.Boolean, default=False)  # SMS notifications
    wechat_notifications = db.Column(db.Boolean, default=False)  # WeChat notifications
    notification_preferences = db.Column(db.String(500), default='{}')  # notification prefs, JSON dict
    privacy_level = db.Column(db.String(20), default='public')  # privacy level
    theme_preference = db.Column(db.String(20), default='light')  # UI theme preference
    blocked_keywords = db.Column(db.String(500), default='[]')  # blocked keywords, JSON list
    # Phone verification
    phone_confirmed = db.Column(db.Boolean, default=False)  # phone verified?
    phone_confirm_time = db.Column(db.DateTime)  # phone verification time
|
||
def __init__(self, username, email=None, password=None, phone=None):
    """Create a user from the minimal required information.

    Only ``username`` is mandatory; email/phone are stored when given
    and a provided password is hashed before storage.
    """
    self.username = username
    for attr, value in (('email', email), ('phone', phone)):
        if value:
            setattr(self, attr, value)
    if password:
        self.set_password(password)
    self.created_at = beijing_now()
    self.last_seen = beijing_now()
|
||
|
||
def set_password(self, password):
    """Hash *password* and store the digest on the user."""
    hashed = generate_password_hash(password)
    self.password_hash = hashed
|
||
|
||
def check_password(self, password):
    """Return True when *password* matches the stored hash."""
    is_match = check_password_hash(self.password_hash, password)
    return is_match
|
||
|
||
def update_last_seen(self):
    """Stamp the user as active right now (Beijing time)."""
    now = beijing_now()
    self.last_seen = now
|
||
|
||
# JSON 字段的getter和setter
|
||
def get_preferred_markets(self):
    """Return the preferred-markets list stored as a JSON string.

    Returns:
        list: the decoded markets, or [] when the field is empty or
        holds invalid JSON.
    """
    if not self.preferred_markets:
        return []
    try:
        return json.loads(self.preferred_markets)
    # Narrowed from a bare except: only JSON/type failures mean bad data;
    # anything else (e.g. KeyboardInterrupt) should propagate.
    except (json.JSONDecodeError, TypeError):
        return []
|
||
|
||
def get_blocked_keywords(self):
    """Return the blocked-keywords list stored as a JSON string.

    Returns:
        list: the decoded keywords, or [] when the field is empty or
        holds invalid JSON.
    """
    if not self.blocked_keywords:
        return []
    try:
        return json.loads(self.blocked_keywords)
    # Narrowed from a bare except: fail soft only on bad stored data.
    except (json.JSONDecodeError, TypeError):
        return []
|
||
|
||
def get_notification_preferences(self):
    """Return the notification-preference mapping stored as a JSON string.

    Returns:
        dict: the decoded preferences, or {} when the field is empty or
        holds invalid JSON.
    """
    if not self.notification_preferences:
        return {}
    try:
        return json.loads(self.notification_preferences)
    # Narrowed from a bare except: fail soft only on bad stored data.
    except (json.JSONDecodeError, TypeError):
        return {}
|
||
|
||
def get_creator_tags(self):
    """Return the creator-tags list stored as a JSON string.

    Returns:
        list: the decoded tags, or [] when the field is empty or holds
        invalid JSON.
    """
    if not self.creator_tags:
        return []
    try:
        return json.loads(self.creator_tags)
    # Narrowed from a bare except: fail soft only on bad stored data.
    except (json.JSONDecodeError, TypeError):
        return []
|
||
|
||
def set_preferred_markets(self, markets):
    """Serialize *markets* to JSON and store it on the user."""
    encoded = json.dumps(markets)
    self.preferred_markets = encoded
|
||
|
||
def set_blocked_keywords(self, keywords):
    """Serialize *keywords* to JSON and store it on the user."""
    encoded = json.dumps(keywords)
    self.blocked_keywords = encoded
|
||
|
||
def set_notification_preferences(self, preferences):
    """Serialize *preferences* to JSON and store it on the user."""
    encoded = json.dumps(preferences)
    self.notification_preferences = encoded
|
||
|
||
def set_creator_tags(self, tags):
    """Serialize *tags* to JSON and store it on the user."""
    encoded = json.dumps(tags)
    self.creator_tags = encoded
|
||
|
||
def to_dict(self):
    """Serialize the public profile fields of this user."""
    payload = {
        'id': self.id,
        'username': self.username,
        'email': self.email,
        'nickname': self.nickname,
        # Avatar is expanded to an absolute URL for API consumers.
        'avatar_url': get_full_avatar_url(self.avatar_url),
        'bio': self.bio,
        'is_verified': self.is_verified,
        'user_level': self.user_level,
        'reputation_score': self.reputation_score,
        'is_creator': self.is_creator,
    }
    return payload
|
||
|
||
def __repr__(self):
    """Debug representation in the form ``<User username>``."""
    return '<User {}>'.format(self.username)
|
||
|
||
|
||
class Notification(db.Model):
    """In-app notification delivered to a single user."""
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'))
    type = db.Column(db.String(50))  # notification type
    content = db.Column(db.Text)  # notification body
    link = db.Column(db.String(200))  # related link
    is_read = db.Column(db.Boolean, default=False)  # whether the user has read it
    created_at = db.Column(db.DateTime, default=beijing_now)

    def __init__(self, user_id, type, content, link=None):
        # NOTE: `type` shadows the builtin but is kept for API compatibility.
        self.user_id = user_id
        self.type = type
        self.content = content
        self.link = link
|
||
|
||
|
||
class Event(db.Model):
    """A market/news event that users can follow and discuss."""
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(200), nullable=False)
    description = db.Column(db.Text)

    # Type and lifecycle state
    event_type = db.Column(db.String(50))
    status = db.Column(db.String(20), default='active')

    # Timestamps
    start_time = db.Column(db.DateTime, default=beijing_now)
    end_time = db.Column(db.DateTime)
    created_at = db.Column(db.DateTime, default=beijing_now)
    updated_at = db.Column(db.DateTime, default=beijing_now)

    # Popularity and statistics
    hot_score = db.Column(db.Float, default=0)
    view_count = db.Column(db.Integer, default=0)
    trending_score = db.Column(db.Float, default=0)
    post_count = db.Column(db.Integer, default=0)
    follower_count = db.Column(db.Integer, default=0)

    # Associated metadata
    related_industries = db.Column(db.JSON)
    keywords = db.Column(db.JSON)
    files = db.Column(db.JSON)
    importance = db.Column(db.String(20))
    related_avg_chg = db.Column(db.Float, default=0)   # avg change of related stocks
    related_max_chg = db.Column(db.Float, default=0)   # max change of related stocks
    related_week_chg = db.Column(db.Float, default=0)  # weekly change of related stocks

    # Newer scoring fields
    invest_score = db.Column(db.Integer)  # "beyond expectation" score
    expectation_surprise_score = db.Column(db.Integer)
    # Creator info
    creator_id = db.Column(db.Integer, db.ForeignKey('user.id'))
    creator = db.relationship('User', backref='created_events')

    # Relationships
    posts = db.relationship('Post', backref='event', lazy='dynamic')
    followers = db.relationship('EventFollow', backref='event', lazy='dynamic')
    related_stocks = db.relationship('RelatedStock', backref='event', lazy='dynamic')
    historical_events = db.relationship('HistoricalEvent', backref='event', lazy='dynamic')
    related_data = db.relationship('RelatedData', backref='event', lazy='dynamic')
    related_concepts = db.relationship('RelatedConcepts', backref='event', lazy='dynamic')
    ind_type = db.Column(db.String(255))

    @property
    def keywords_list(self):
        """Return ``keywords`` normalised to a plain Python list.

        Accepts the column holding a native list, a JSON-encoded string,
        or anything ``list()``-convertible; returns [] on missing or
        unparseable data.
        """
        if not self.keywords:
            return []

        if isinstance(self.keywords, list):
            return self.keywords

        try:
            # String form: parse the JSON first.
            if isinstance(self.keywords, str):
                decoded = json.loads(self.keywords)
                # Repair double-escaped \uXXXX sequences.
                # NOTE(review): unicode_escape round-tripping can mangle
                # non-ASCII text — confirm legacy rows still need this.
                if isinstance(decoded, list):
                    return [
                        keyword.encode('utf-8').decode('unicode_escape')
                        if isinstance(keyword, str) and '\\u' in keyword
                        else keyword
                        for keyword in decoded
                    ]
                return []

            # Dict or other iterable: best-effort conversion to a list.
            return list(self.keywords)
        except (json.JSONDecodeError, AttributeError, TypeError):
            return []

    def set_keywords(self, keywords):
        """Store *keywords* as a JSON array string.

        A list is serialised directly; a string is parsed as JSON first
        and, failing that, wrapped as a single-element list.
        """
        if isinstance(keywords, list):
            self.keywords = json.dumps(keywords, ensure_ascii=False)
        elif isinstance(keywords, str):
            try:
                # Try to interpret the string as a JSON document.
                parsed = json.loads(keywords)
                if isinstance(parsed, list):
                    self.keywords = json.dumps(parsed, ensure_ascii=False)
                else:
                    self.keywords = json.dumps([keywords], ensure_ascii=False)
            except json.JSONDecodeError:
                # Not valid JSON: treat the whole string as one keyword.
                self.keywords = json.dumps([keywords], ensure_ascii=False)
|
||
|
||
|
||
class RelatedStock(db.Model):
    """A stock linked to an event, with the reason for the link."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
    stock_code = db.Column(db.String(20))  # stock code
    stock_name = db.Column(db.String(100))  # stock name
    sector = db.Column(db.String(100))  # relation type / sector
    relation_desc = db.Column(db.String(1024))  # why the stock relates to the event
    created_at = db.Column(db.DateTime, default=beijing_now)
    updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now)
    correlation = db.Column(db.Float())
    momentum = db.Column(db.String(1024))  # momentum description
    # Research-report retrieval results
    retrieved_sources = db.Column(db.JSON)  # retrieved research-report sources
    retrieved_update_time = db.Column(db.DateTime)  # when retrieval data was refreshed
|
||
|
||
|
||
class RelatedData(db.Model):
    """Auxiliary data set attached to an event."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
    title = db.Column(db.String(200))  # data title
    data_type = db.Column(db.String(50))  # data type
    data_content = db.Column(db.JSON)  # payload (JSON)
    description = db.Column(db.Text)  # description
    created_at = db.Column(db.DateTime, default=beijing_now)
|
||
|
||
|
||
class RelatedConcepts(db.Model):
    """A concept/board linked to an event, optionally with chart images."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
    concept_code = db.Column(db.String(20))  # concept code
    concept = db.Column(db.String(100))  # concept name
    reason = db.Column(db.Text)  # why the concept relates to the event
    image_paths = db.Column(db.JSON)  # image path entries (JSON)
    created_at = db.Column(db.DateTime, default=beijing_now)

    @property
    def image_paths_list(self):
        """Return image paths as a flat list.

        Handles the column holding a JSON string, a single entry, or a
        list of either strings or ``{'path': ...}`` dicts; returns [] on
        any processing error.
        """
        if not self.image_paths:
            return []

        try:
            # String form: decode the JSON first.
            if isinstance(self.image_paths, str):
                paths = json.loads(self.image_paths)
            else:
                paths = self.image_paths

            # Normalise a single entry into a one-element list.
            if not isinstance(paths, list):
                paths = [paths]

            # Entries may be dicts with a 'path' key or bare values.
            return [item['path'] if isinstance(item, dict) and 'path' in item
                    else item for item in paths]
        except Exception as e:
            print(f"Error processing image paths: {e}")
            return []

    def get_first_image_path(self):
        """Return the first image path, or None when there are none."""
        paths = self.image_paths_list
        if not paths:
            return None

        first_path = paths[0]
        return first_path
|
||
|
||
|
||
class EventHotHistory(db.Model):
    """Snapshot of an event's popularity score components over time."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
    score = db.Column(db.Float)  # total score
    interaction_score = db.Column(db.Float)  # interaction component
    follow_score = db.Column(db.Float)  # follower component
    view_score = db.Column(db.Float)  # view-count component
    recent_activity_score = db.Column(db.Float)  # recent-activity component
    time_decay = db.Column(db.Float)  # time-decay factor
    created_at = db.Column(db.DateTime, default=beijing_now)

    event = db.relationship('Event', backref='hot_history')
|
||
|
||
|
||
class Post(db.Model):
    """User post attached to an event."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)

    # Content
    title = db.Column(db.String(200))  # optional title
    content = db.Column(db.Text, nullable=False)  # post body
    content_type = db.Column(db.String(20), default='text')  # text/rich_text/link

    # Timestamps
    created_at = db.Column(db.DateTime, default=beijing_now)
    updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now)

    # Counters
    likes_count = db.Column(db.Integer, default=0)
    comments_count = db.Column(db.Integer, default=0)
    view_count = db.Column(db.Integer, default=0)

    # State
    status = db.Column(db.String(20), default='active')  # active/hidden/deleted
    is_top = db.Column(db.Boolean, default=False)  # pinned flag

    # Relationships
    user = db.relationship('User', backref='posts')
    likes = db.relationship('PostLike', backref='post', lazy='dynamic')
    comments = db.relationship('Comment', backref='post', lazy='dynamic')
|
||
|
||
|
||
# 辅助模型
|
||
class EventFollow(db.Model):
    """Follow relation between a user and an event (one row per pair)."""
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False)
    created_at = db.Column(db.DateTime, default=beijing_now)

    user = db.relationship('User', backref='event_follows')

    # A user may follow a given event at most once.
    __table_args__ = (db.UniqueConstraint('user_id', 'event_id'),)
|
||
|
||
|
||
class PostLike(db.Model):
    """Like relation between a user and a post (one row per pair)."""
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
    post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False)
    created_at = db.Column(db.DateTime, default=beijing_now)

    user = db.relationship('User', backref='post_likes')

    # A user may like a given post at most once.
    __table_args__ = (db.UniqueConstraint('user_id', 'post_id'),)
|
||
|
||
|
||
class Comment(db.Model):
    """Comment on a post; supports threaded replies via ``parent_id``."""
    id = db.Column(db.Integer, primary_key=True)
    post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
    content = db.Column(db.Text, nullable=False)
    parent_id = db.Column(db.Integer, db.ForeignKey('comment.id'))  # parent comment for replies
    created_at = db.Column(db.DateTime, default=beijing_now)
    status = db.Column(db.String(20), default='active')

    user = db.relationship('User', backref='comments')
    # Self-referential relationship: replies point back to their parent.
    replies = db.relationship('Comment', backref=db.backref('parent', remote_side=[id]))
|
||
|
||
|
||
class StockBasicInfo(db.Model):
    """Listed-security master record (table ``ea_stocklist``)."""
    __tablename__ = 'ea_stocklist'

    SECCODE = db.Column(db.String(10), primary_key=True)  # security code
    SECNAME = db.Column(db.String(40))  # security name
    ORGNAME = db.Column(db.String(100))  # organisation name
    F001V = db.Column(db.String(100))  # Pinyin abbreviation
    F003V = db.Column(db.String(50))  # Security category
    F005V = db.Column(db.String(50))  # Trading market
    F006D = db.Column(db.DateTime)  # Listing date
    F011V = db.Column(db.String(50))  # Listing status
|
||
|
||
|
||
class CompanyInfo(db.Model):
    """Company profile record (table ``ea_baseinfo``)."""
    __tablename__ = 'ea_baseinfo'

    SECCODE = db.Column(db.String(10), primary_key=True)  # security code
    SECNAME = db.Column(db.String(40))  # security name
    ORGNAME = db.Column(db.String(100))  # organisation name
    F001V = db.Column(db.String(100))  # English name
    F003V = db.Column(db.String(40))  # Legal representative
    F015V = db.Column(db.String(500))  # Main business
    F016V = db.Column(db.String(4000))  # Business scope
    F017V = db.Column(db.String(2000))  # Company introduction
    F030V = db.Column(db.String(60))  # CSRC industry first level
    F032V = db.Column(db.String(60))  # CSRC industry second level
|
||
|
||
|
||
class TradeData(db.Model):
    """Daily trading record (table ``ea_trade``), keyed by code + date."""
    __tablename__ = 'ea_trade'

    SECCODE = db.Column(db.String(10), primary_key=True)  # security code
    SECNAME = db.Column(db.String(40))  # security name
    TRADEDATE = db.Column(db.Date, primary_key=True)  # trading date
    F002N = db.Column(db.Numeric(18, 4))  # Previous close
    F003N = db.Column(db.Numeric(18, 4))  # Open price
    F004N = db.Column(db.Numeric(18, 4))  # Trading volume
    F005N = db.Column(db.Numeric(18, 4))  # High price
    F006N = db.Column(db.Numeric(18, 4))  # Low price
    F007N = db.Column(db.Numeric(18, 4))  # Close price
    F009N = db.Column(db.Numeric(18, 4))  # Change
    F010N = db.Column(db.Numeric(18, 4))  # Change percentage
    F011N = db.Column(db.Numeric(18, 4))  # Trading amount
|
||
|
||
|
||
class SectorInfo(db.Model):
    """Sector classification record (table ``ea_sector``)."""
    __tablename__ = 'ea_sector'

    SECCODE = db.Column(db.String(10), primary_key=True)  # security code
    SECNAME = db.Column(db.String(40))  # security name
    F001V = db.Column(db.String(50), primary_key=True)  # Classification standard code
    F002V = db.Column(db.String(50))  # Classification standard
    F003V = db.Column(db.String(50))  # Sector code
    F004V = db.Column(db.String(50))  # Sector level 1 name
    F005V = db.Column(db.String(50))  # Sector level 2 name
    F006V = db.Column(db.String(50))  # Sector level 3 name
    F007V = db.Column(db.String(50))  # Sector level 4 name
|
||
|
||
|
||
def init_sywg_industry_cache():
    """Populate ``SYWG_INDUSTRY_CACHE`` with Shenyin-Wanguo industry data.

    Called once at startup. For each industry level (2-5) it maps every
    industry name to the list of sector-code prefixes that identify it.
    Failures are logged and the cache is left with whatever loaded.
    """
    global SYWG_INDUSTRY_CACHE

    try:
        app.logger.info('开始初始化申银万国行业分类缓存...')

        # Level -> ea_sector column holding that level's industry name.
        level_column_map = {
            2: 'f004v',  # level 2: first-tier industry
            3: 'f005v',  # level 3: second-tier industry
            4: 'f006v',  # level 4: third-tier industry
            5: 'f007v'   # level 5: fourth-tier industry
        }

        # Level -> length of the sector-code prefix identifying it.
        prefix_length_map = {
            2: 3,  # 'S' + 2 digits
            3: 5,  # 'S' + 2 + 2
            4: 7,  # 'S' + 2 + 2 + 2
            5: 9   # full code
        }

        # Walk every level and build its name -> prefixes mapping.
        for level, column_name in level_column_map.items():
            # column_name comes from the fixed map above, not user input,
            # so the f-string interpolation is injection-safe.
            query_sql = f"""
                SELECT DISTINCT {column_name} as industry_name, f003v as code
                FROM ea_sector
                WHERE f002v = '申银万国行业分类'
                AND {column_name} IS NOT NULL
                AND {column_name} != ''
            """

            result = db.session.execute(text(query_sql))
            rows = result.fetchall()

            # industry name -> set of code prefixes (deduplicated).
            industry_dict = {}
            for row in rows:
                industry_name = row[0]
                code = row[1]

                if industry_name and code:
                    # Truncate the full sector code to this level's prefix.
                    prefix_length = prefix_length_map[level]
                    code_prefix = code[:prefix_length]

                    if industry_name not in industry_dict:
                        industry_dict[industry_name] = set()
                    industry_dict[industry_name].add(code_prefix)

            # Store as plain lists for cheap iteration elsewhere.
            for industry_name, prefixes in industry_dict.items():
                SYWG_INDUSTRY_CACHE[level][industry_name] = list(prefixes)

            app.logger.info(f'Level {level} 缓存完成,共 {len(industry_dict)} 个行业')

        # Log the grand total across all levels.
        total_count = sum(len(industries) for industries in SYWG_INDUSTRY_CACHE.values())
        app.logger.info(f'申银万国行业分类缓存初始化完成,共缓存 {total_count} 个行业分类')

    except Exception as e:
        # Best effort: startup continues with a partial (or empty) cache.
        app.logger.error(f'初始化申银万国行业分类缓存失败: {str(e)}')
        import traceback
        app.logger.error(traceback.format_exc())
|
||
|
||
|
||
def send_async_email(msg):
    """Best-effort mail dispatch: failures are logged, never raised."""
    try:
        mail.send(msg)
    except Exception as e:
        app.logger.error(f"Error sending async email: {str(e)}")
|
||
|
||
|
||
def verify_sms_code(phone_number, code):
    """Validate an SMS verification code against the values in the session.

    Args:
        phone_number: the phone number the user claims the code was sent to.
        code: the code the user entered.

    Returns:
        (bool, str): success flag and a user-facing message.
    """
    stored_code = session.get('sms_verification_code')
    stored_phone = session.get('sms_verification_phone')
    expiration = session.get('sms_verification_expiration')

    if not all([stored_code, stored_phone, expiration]):
        return False, "请先获取验证码"

    if stored_phone != phone_number:
        return False, "手机号与验证码不匹配"

    if beijing_now().timestamp() > expiration:
        return False, "验证码已过期"

    # Constant-time comparison so the check does not leak the stored
    # code through response timing.
    if not secrets.compare_digest(str(code), str(stored_code)):
        return False, "验证码错误"

    return True, "验证成功"
|
||
|
||
|
||
def allowed_file(filename):
    """Return True when *filename* has an extension in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
|
||
|
||
|
||
# ============================================
|
||
# 订阅相关 API 接口(小程序专用)
|
||
# ============================================
|
||
@app.route('/api/subscription/info', methods=['GET'])
@token_required
def get_subscription_info():
    """Return the caller's subscription summary (mini-program endpoint).

    On any backend error this fails open to an active free-tier payload
    instead of surfacing a 5xx to the client.
    """
    try:
        return jsonify({
            'success': True,
            'data': _get_current_subscription_info()
        })
    except Exception as e:
        print(f"获取订阅信息错误: {e}")
        fallback = {
            'type': 'free',
            'status': 'active',
            'is_active': True,
            'days_left': 0
        }
        return jsonify({'success': True, 'data': fallback})
|
||
|
||
|
||
@app.route('/api/subscription/check', methods=['GET'])
@token_required
def check_subscription_access():
    """Report whether the caller may use the mini-program (Pro/Max only)."""
    try:
        allowed = _has_required_level('pro')
        info = _get_current_subscription_info()
        payload = {
            'has_access': allowed,
            'subscription_type': info.get('type', 'free'),
            'is_active': info.get('is_active', False),
            'message': '您可以使用小程序功能' if allowed else '小程序功能仅对 Pro 和 Max 会员开放'
        }
        return jsonify({'success': True, 'data': payload})
    except Exception as e:
        print(f"检查订阅权限错误: {e}")
        return jsonify({'success': False, 'error': str(e)}), 500
|
||
|
||
|
||
# ============================================
|
||
# 现有接口示例(应用权限控制)
|
||
# ============================================
|
||
|
||
# 更新视图函数
|
||
@app.route('/settings/profile', methods=['POST'])
@token_required
def update_profile():
    """Update the authenticated user's basic profile and optional avatar.

    Returns a JSON {'success': bool, 'message': str} payload; the whole
    update is rolled back on any error.
    """
    try:
        user = request.user
        form = request.form

        # Basic fields — all optional; an absent field clears the value.
        user.nickname = form.get('nickname')
        user.bio = form.get('bio')
        user.gender = form.get('gender')
        user.birth_date = datetime.strptime(form.get('birth_date'), '%Y-%m-%d') if form.get('birth_date') else None
        user.phone = form.get('phone')
        user.location = form.get('location')
        user.wechat_id = form.get('wechat_id')

        # Avatar upload (optional).
        if 'avatar' in request.files:
            file = request.files['avatar']
            if file and allowed_file(file.filename):
                # Unique, sanitised name: <user id>_<timestamp>_<original>.
                filename = secure_filename(f"{user.id}_{int(datetime.now().timestamp())}_{file.filename}")
                filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)

                # Make sure the upload directory exists.
                os.makedirs(os.path.dirname(filepath), exist_ok=True)

                # Downscale to at most 300x300 before saving.
                image = Image.open(file)
                image.thumbnail((300, 300))
                image.save(filepath)

                # BUG FIX: the stored URL previously omitted the generated
                # filename, so it never pointed at the saved file.
                user.avatar_url = f'{DOMAIN}/static/uploads/avatars/{filename}'

        db.session.commit()
        return jsonify({'success': True, 'message': '个人资料已更新'})

    except Exception as e:
        db.session.rollback()
        app.logger.error(f"Error updating profile: {str(e)}")
        return jsonify({'success': False, 'message': '更新失败,请重试'})
|
||
|
||
|
||
# 投资偏好设置
|
||
@app.route('/settings/investment_preferences', methods=['POST'])
@token_required
def update_investment_preferences():
    """Persist the authenticated user's investment-preference fields."""
    try:
        user = request.user
        form = request.form

        # Simple scalar fields copied straight from the form.
        for field in ('trading_experience', 'investment_style',
                      'risk_preference', 'investment_amount'):
            setattr(user, field, form.get(field))

        # Multi-select markets are stored as a JSON array string.
        user.preferred_markets = json.dumps(request.form.getlist('preferred_markets'))

        db.session.commit()
        return jsonify({'success': True, 'message': '投资偏好已更新'})

    except Exception as e:
        db.session.rollback()
        app.logger.error(f"Error updating investment preferences: {str(e)}")
        return jsonify({'success': False, 'message': '更新失败,请重试'})
|
||
|
||
|
||
def get_clickhouse_client():
    """Return the lazily-initialised ClickHouse connection pool.

    The pool object supports two usage patterns:

    1. (preferred) call ``execute`` directly:
        client = get_clickhouse_client()
        result = client.execute("SELECT * FROM table", {'param': value})

    2. context-manager style:
        client = get_clickhouse_client()
        with client.connection() as conn:
            result = conn.execute("SELECT * FROM table")
    """
    return _init_clickhouse_pool()
|
||
|
||
|
||
@app.route('/api/system/clickhouse-pool-status', methods=['GET'])
def api_clickhouse_pool_status():
    """Expose ClickHouse pool health (monitoring-only endpoint)."""
    try:
        status = _init_clickhouse_pool().get_pool_status()
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': status
        })
    except Exception as e:
        return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
|
||
|
||
|
||
@app.route('/api/stock/<stock_code>/kline')
def get_stock_kline(stock_code):
    """K-line data endpoint — Pro/Max members only (mini-program feature)."""
    chart_type = request.args.get('chart_type', 'daily')  # daily bars by default
    event_time = request.args.get('event_time')

    try:
        event_datetime = datetime.fromisoformat(event_time) if event_time else datetime.now()
    except ValueError:
        return jsonify({'error': 'Invalid event_time format'}), 400

    # Resolve the display name; fall back to 'Unknown' on lookup failure.
    stock_name = 'Unknown'
    try:
        with engine.connect() as conn:
            row = conn.execute(text(
                "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code"
            ), {"code": stock_code.split('.')[0]}).fetchone()
            if row:
                stock_name = row[0]
    except Exception as e:
        print(f"Error getting stock name: {e}")

    # Dispatch to the handler for the requested chart type.
    handlers = {'daily': get_daily_kline, 'minute': get_minute_kline}
    handler = handlers.get(chart_type)
    if handler is None:
        return jsonify({
            'error': 'Invalid chart type',
            'message': 'Supported types: daily, minute',
            'code': stock_code,
            'name': stock_name
        }), 400
    return handler(stock_code, event_datetime, stock_name)
|
||
|
||
|
||
def get_daily_kline(stock_code, event_datetime, stock_name):
    """Return daily OHLCV bars for *stock_code* around the event date.

    Primary query covers the 60 days up to (and including) the event
    date; when that returns nothing, falls back to the earliest 100
    complete rows up to the event date. Responds with an empty-data
    payload when no rows exist at all, and a 500 payload on DB errors.
    """
    stock_code = stock_code.split('.')[0]  # strip exchange suffix

    print(f"Debug: stock_code={stock_code}, event_datetime={event_datetime}, stock_name={stock_name}")

    try:
        with engine.connect() as conn:
            # Bars in the 60-day window ending at the event date.
            kline_sql = """
                WITH date_range AS (SELECT TRADEDATE \
                    FROM ea_trade \
                    WHERE SECCODE = :stock_code \
                    AND TRADEDATE BETWEEN DATE_SUB(:trade_date, INTERVAL 60 DAY) \
                    AND :trade_date \
                    GROUP BY TRADEDATE \
                    ORDER BY TRADEDATE)
                SELECT t.TRADEDATE,
                    CAST(t.F003N AS FLOAT) as open,
                    CAST(t.F007N AS FLOAT) as close,
                    CAST(t.F005N AS FLOAT) as high,
                    CAST(t.F006N AS FLOAT) as low,
                    CAST(t.F004N AS FLOAT) as volume
                FROM ea_trade t
                JOIN date_range d \
                    ON t.TRADEDATE = d.TRADEDATE
                WHERE t.SECCODE = :stock_code
                ORDER BY t.TRADEDATE \
                """

            result = conn.execute(text(kline_sql), {
                "stock_code": stock_code,
                "trade_date": event_datetime.date()
            }).fetchall()

            print(f"Debug: Query result count: {len(result)}")

            if not result:
                print("Debug: No data found, trying fallback query...")
                # Fallback: most recent trading data with all fields present.
                fallback_sql = """
                    SELECT TRADEDATE,
                        CAST(F003N AS FLOAT) as open,
                        CAST(F007N AS FLOAT) as close,
                        CAST(F005N AS FLOAT) as high,
                        CAST(F006N AS FLOAT) as low,
                        CAST(F004N AS FLOAT) as volume
                    FROM ea_trade
                    WHERE SECCODE = :stock_code
                    AND TRADEDATE <= :trade_date
                    AND F003N IS NOT NULL
                    AND F007N IS NOT NULL
                    AND F005N IS NOT NULL
                    AND F006N IS NOT NULL
                    AND F004N IS NOT NULL
                    ORDER BY TRADEDATE
                    LIMIT 100 \
                    """

                result = conn.execute(text(fallback_sql), {
                    "stock_code": stock_code,
                    "trade_date": event_datetime.date()
                }).fetchall()

                print(f"Debug: Fallback query result count: {len(result)}")

            if not result:
                # No rows even after fallback: empty-data payload (HTTP 200).
                return jsonify({
                    'error': 'No data available',
                    'code': stock_code,
                    'name': stock_name,
                    'data': [],
                    'trade_date': event_datetime.date().strftime('%Y-%m-%d'),
                    'type': 'daily'
                })

            # Convert rows to the chart payload; skip malformed rows.
            kline_data = []
            for row in result:
                try:
                    kline_data.append({
                        'time': row.TRADEDATE.strftime('%Y-%m-%d'),
                        'open': float(row.open) if row.open else 0,
                        'high': float(row.high) if row.high else 0,
                        'low': float(row.low) if row.low else 0,
                        'close': float(row.close) if row.close else 0,
                        'volume': float(row.volume) if row.volume else 0
                    })
                except (ValueError, TypeError) as e:
                    print(f"Debug: Error processing row: {e}")
                    continue

            print(f"Debug: Final kline_data count: {len(kline_data)}")

            return jsonify({
                'code': stock_code,
                'name': stock_name,
                'data': kline_data,
                'trade_date': event_datetime.date().strftime('%Y-%m-%d'),
                'event_time': event_datetime.isoformat(),
                'type': 'daily',
                'is_history': True,
                'data_count': len(kline_data)
            })

    except Exception as e:
        print(f"Error in get_daily_kline: {e}")
        return jsonify({
            'error': f'Database error: {str(e)}',
            'code': stock_code,
            'name': stock_name,
            'data': [],
            'trade_date': event_datetime.date().strftime('%Y-%m-%d'),
            'type': 'daily'
        }), 500
|
||
|
||
|
||
def get_minute_kline(stock_code, event_datetime, stock_name):
    """Return intraday minute bars for the trading day implied by the event,
    including the previous trading day's close as the "zero axis" baseline.

    Non-trading days and after-close events are rolled forward to the next
    trading day (or back to the most recent one when no future day exists).
    Minute bars come from ClickHouse; closes come from the SQL ``ea_trade``
    table; the trading calendar comes from ``tdays.csv``.
    """
    client = get_clickhouse_client()
    stock_code_short = stock_code.split('.')[0]  # code without exchange suffix

    def get_trading_days():
        # NOTE(review): tdays.csv is re-read on every request — consider
        # caching the parsed calendar at module level.
        trading_days = set()
        with open('tdays.csv', 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                trading_days.add(datetime.strptime(row['DateTime'], '%Y/%m/%d').date())
        return trading_days

    trading_days = get_trading_days()

    def find_next_trading_day(current_date):
        """Next trading day strictly after current_date, or None."""
        while current_date <= max(trading_days):
            current_date += timedelta(days=1)
            if current_date in trading_days:
                return current_date
        return None

    def find_prev_trading_day(current_date):
        """Most recent trading day strictly before current_date, or None."""
        while current_date >= min(trading_days):
            current_date -= timedelta(days=1)
            if current_date in trading_days:
                return current_date
        return None

    def get_prev_close(stock_code_short, target_date):
        """Previous trading day's close (the zero-axis baseline), or None."""
        prev_date = find_prev_trading_day(target_date)
        if not prev_date:
            return None

        try:
            with engine.connect() as conn:
                # Close price on the immediately preceding trading day.
                sql = """
                    SELECT CAST(F007N AS FLOAT) as close
                    FROM ea_trade
                    WHERE SECCODE = :stock_code
                    AND TRADEDATE = :prev_date
                    AND F007N IS NOT NULL
                    LIMIT 1 \
                    """
                result = conn.execute(text(sql), {
                    "stock_code": stock_code_short,
                    "prev_date": prev_date
                }).fetchone()

                if result:
                    return float(result.close)
                else:
                    # No row on that exact date: use the latest close before
                    # target_date instead.
                    fallback_sql = """
                        SELECT CAST(F007N AS FLOAT) as close, TRADEDATE
                        FROM ea_trade
                        WHERE SECCODE = :stock_code
                        AND TRADEDATE \
                        < :target_date
                        AND F007N IS NOT NULL
                        ORDER BY TRADEDATE DESC
                        LIMIT 1 \
                        """
                    result = conn.execute(text(fallback_sql), {
                        "stock_code": stock_code_short,
                        "target_date": target_date
                    }).fetchone()

                    if result:
                        print(f"Using close price from {result.TRADEDATE} as zero axis")
                        return float(result.close)

        except Exception as e:
            print(f"Error getting previous close: {e}")

        return None

    target_date = event_datetime.date()
    # Mainland A-share session ends at 15:00.
    is_after_market = event_datetime.time() > dt_time(15, 0)

    # Decide which trading day the chart should show.
    if target_date in trading_days and is_after_market:
        # Trading day but already closed: show the next trading day.
        next_trade_date = find_next_trading_day(target_date)
        if next_trade_date:
            target_date = next_trade_date
    elif target_date not in trading_days:
        # Not a trading day: prefer the next trading day ...
        next_trade_date = find_next_trading_day(target_date)
        if next_trade_date:
            target_date = next_trade_date
        else:
            # ... otherwise fall back to the most recent historical one.
            target_date = find_prev_trading_day(target_date)

    if not target_date:
        return jsonify({
            'error': 'No data available',
            'code': stock_code,
            'name': stock_name,
            'data': [],
            'trade_date': event_datetime.date().strftime('%Y-%m-%d'),
            'type': 'minute'
        })

    # Previous close = baseline for change/percent-change calculations.
    zero_axis = get_prev_close(stock_code_short, target_date)

    # Full-session minute bars (09:30-15:00) from ClickHouse.
    data = client.execute("""
        SELECT
            timestamp, open, high, low, close, volume, amt
        FROM stock_minute
        WHERE code = %(code)s
        AND timestamp BETWEEN %(start)s
        AND %(end)s
        ORDER BY timestamp
        """, {
        'code': stock_code,
        'start': datetime.combine(target_date, dt_time(9, 30)),
        'end': datetime.combine(target_date, dt_time(15, 0))
    })

    kline_data = []
    for row in data:
        point = {
            'time': row[0].strftime('%H:%M'),
            'open': float(row[1]),
            'high': float(row[2]),
            'low': float(row[3]),
            'close': float(row[4]),
            'volume': float(row[5]),
            'amount': float(row[6])
        }

        # With a baseline available, attach absolute and percentage change.
        if zero_axis:
            point['prev_close'] = zero_axis
            point['change'] = point['close'] - zero_axis  # absolute change
            point['change_pct'] = ((point['close'] - zero_axis) / zero_axis * 100) if zero_axis != 0 else 0  # percent change

        kline_data.append(point)

    response_data = {
        'code': stock_code,
        'name': stock_name,
        'data': kline_data,
        'trade_date': target_date.strftime('%Y-%m-%d'),
        'type': 'minute',
        'is_history': target_date < event_datetime.date()
    }

    # Surface the baseline and whole-day change in the envelope too.
    if zero_axis:
        response_data['zero_axis'] = zero_axis
        response_data['prev_close'] = zero_axis

        if kline_data:
            last_close = kline_data[-1]['close']
            response_data['day_change'] = last_close - zero_axis
            response_data['day_change_pct'] = ((last_close - zero_axis) / zero_axis * 100) if zero_axis != 0 else 0

    return jsonify(response_data)
|
||
|
||
|
||
class HistoricalEvent(db.Model):
    """A past event referenced as precedent for a current event."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
    title = db.Column(db.String(200))
    content = db.Column(db.Text)
    event_date = db.Column(db.DateTime)
    relevance = db.Column(db.Integer)  # relevance to the current event
    importance = db.Column(db.Integer)  # importance score
    related_stock = db.Column(db.JSON)  # legacy JSON field, kept for compatibility
    created_at = db.Column(db.DateTime, default=beijing_now)

    # Normalised replacement for the legacy JSON field above.
    stocks = db.relationship('HistoricalEventStock', backref='historical_event', lazy='dynamic',
                             cascade='all, delete-orphan')
|
||
|
||
|
||
class HistoricalEventStock(db.Model):
    """Stock related to a historical event (one row per event/stock pair)."""
    __tablename__ = 'historical_event_stocks'

    id = db.Column(db.Integer, primary_key=True)
    historical_event_id = db.Column(db.Integer, db.ForeignKey('historical_event.id'), nullable=False)
    stock_code = db.Column(db.String(20), nullable=False)
    stock_name = db.Column(db.String(50))
    relation_desc = db.Column(db.Text)  # why this stock is related
    correlation = db.Column(db.Float, default=0.5)  # correlation strength in [0, 1]
    sector = db.Column(db.String(100))
    created_at = db.Column(db.DateTime, default=beijing_now)

    # A stock may appear at most once per historical event.
    __table_args__ = (
        db.UniqueConstraint('historical_event_id', 'stock_code', name='unique_event_stock'),
    )
|
||
|
||
|
||
@app.route('/event/follow/<int:event_id>', methods=['POST'])
@token_required
def follow_event(event_id):
    """Toggle the current user's follow state for an event.

    Creates an EventFollow row (and bumps follower_count) when none exists,
    otherwise deletes it (and decrements follower_count). Returns JSON.
    """
    event = Event.query.get_or_404(event_id)
    existing = EventFollow.query.filter_by(
        user_id=request.user.id,
        event_id=event_id
    ).first()

    try:
        if existing is None:
            # Not following yet -> follow
            db.session.add(EventFollow(user_id=request.user.id, event_id=event_id))
            event.follower_count += 1
            message = '已关注'
        else:
            # Already following -> unfollow
            db.session.delete(existing)
            event.follower_count -= 1
            message = '已取消关注'

        db.session.commit()
        return jsonify({'success': True, 'message': message})

    except Exception:
        db.session.rollback()
        return jsonify({'success': False, 'message': '操作失败,请重试'})
|
||
|
||
|
||
# 帖子相关路由
|
||
@app.route('/post/create/<int:event_id>', methods=['GET', 'POST'])
|
||
@token_required
|
||
def create_post(event_id):
|
||
"""创建新帖子"""
|
||
event = Event.query.get_or_404(event_id)
|
||
|
||
if request.method == 'POST':
|
||
try:
|
||
post = Post(
|
||
event_id=event_id,
|
||
user_id=request.user.id,
|
||
title=request.form.get('title'),
|
||
content=request.form['content'],
|
||
content_type=request.form.get('content_type', 'text')
|
||
)
|
||
|
||
db.session.add(post)
|
||
event.post_count += 1
|
||
db.session.commit()
|
||
|
||
# 检查是否是 API 请求(通过 Accept header 或 Content-Type 判断)
|
||
if request.headers.get('Accept') == 'application/json' or \
|
||
request.headers.get('Content-Type', '').startswith('application/json'):
|
||
return jsonify({
|
||
'success': True,
|
||
'message': '发布成功',
|
||
'data': {
|
||
'post_id': post.id,
|
||
'event_id': event_id,
|
||
'redirect_url': url_for('event_detail', event_id=event_id)
|
||
}
|
||
})
|
||
else:
|
||
# 传统表单提交,添加成功消息并重定向
|
||
flash('发布成功', 'success')
|
||
return redirect(url_for('event_detail', event_id=event_id))
|
||
|
||
except Exception as e:
|
||
db.session.rollback()
|
||
if request.headers.get('Accept') == 'application/json' or \
|
||
request.headers.get('Content-Type', '').startswith('application/json'):
|
||
return jsonify({
|
||
'success': False,
|
||
'message': '发布失败,请重试'
|
||
}), 400
|
||
else:
|
||
flash('发布失败,请重试', 'error')
|
||
app.logger.error(f"Error creating post: {str(e)}")
|
||
|
||
return render_template('projects/create_post.html', event=event)
|
||
|
||
|
||
# 点赞相关路由
|
||
@app.route('/post/like/<int:post_id>', methods=['POST'])
|
||
@token_required
|
||
def like_post(post_id):
|
||
"""点赞/取消点赞帖子"""
|
||
post = Post.query.get_or_404(post_id)
|
||
like = PostLike.query.filter_by(
|
||
user_id=request.user.id,
|
||
post_id=post_id
|
||
).first()
|
||
|
||
try:
|
||
if like:
|
||
# 取消点赞
|
||
db.session.delete(like)
|
||
post.likes_count -= 1
|
||
message = '已取消点赞'
|
||
else:
|
||
# 添加点赞
|
||
like = PostLike(user_id=request.user.id, post_id=post_id)
|
||
db.session.add(like)
|
||
post.likes_count += 1
|
||
message = '已点赞'
|
||
|
||
db.session.commit()
|
||
return jsonify({
|
||
'success': True,
|
||
'message': message,
|
||
'likes_count': post.likes_count
|
||
})
|
||
|
||
except Exception as e:
|
||
db.session.rollback()
|
||
return jsonify({'success': False, 'message': '操作失败,请重试'})
|
||
|
||
|
||
def update_user_activity():
    """Recompute each user's activity score from the last 7 days of activity.

    Runs inside the Flask application context so it can be called from a
    background scheduler. Score = posts*2 + comments*1 + likes*0.5.

    Fix: the three outer joins (Post, Comment, PostLike) fan out into a cross
    product per user, so a plain count() multiplied each activity count by the
    row counts of the other two tables. Counting DISTINCT ids yields the true
    per-user totals.
    """
    with app.app_context():
        try:
            # Activity window: the past 7 days
            seven_days_ago = beijing_now() - timedelta(days=7)

            # Aggregate posts/comments/likes per user.
            # NOTE(review): the OR filter only requires *some* activity in the
            # window; counted rows may include older activity — confirm intent.
            active_users = db.session.query(
                User.id,
                db.func.count(db.distinct(Post.id)).label('post_count'),
                db.func.count(db.distinct(Comment.id)).label('comment_count'),
                db.func.count(db.distinct(PostLike.id)).label('like_count')
            ).outerjoin(Post, User.id == Post.user_id) \
                .outerjoin(Comment, User.id == Comment.user_id) \
                .outerjoin(PostLike, User.id == PostLike.user_id) \
                .filter(
                db.or_(
                    Post.created_at >= seven_days_ago,
                    Comment.created_at >= seven_days_ago,
                    PostLike.created_at >= seven_days_ago
                )
            ).group_by(User.id).all()

            # Weighted activity score: posts 2x, comments 1x, likes 0.5x
            for user_id, post_count, comment_count, like_count in active_users:
                activity_score = post_count * 2 + comment_count * 1 + like_count * 0.5
                User.query.filter_by(id=user_id).update({
                    'activity_score': activity_score,
                    'last_active': beijing_now()
                })

            db.session.commit()
            current_app.logger.info("Successfully updated user activity scores")

        except Exception as e:
            db.session.rollback()
            current_app.logger.error(f"Error updating user activity: {str(e)}")
|
||
|
||
|
||
@app.route('/post/comment/<int:post_id>', methods=['POST'])
@token_required
def add_comment(post_id):
    """Add a comment (or a reply, when parent_id is given) to a post."""
    post = Post.query.get_or_404(post_id)

    try:
        body = request.form.get('content')
        parent_id = request.form.get('parent_id', type=int)

        # Reject empty comment bodies up front.
        if not body:
            return jsonify({'success': False, 'message': '评论内容不能为空'})

        new_comment = Comment(
            post_id=post_id,
            user_id=request.user.id,
            content=body,
            parent_id=parent_id
        )
        db.session.add(new_comment)
        post.comments_count += 1
        db.session.commit()

        # Echo the persisted comment back so the client can render it inline.
        comment_payload = {
            'id': new_comment.id,
            'content': new_comment.content,
            'user_name': request.user.username,
            'user_avatar': get_full_avatar_url(request.user.avatar_url),
            'created_at': new_comment.created_at.strftime('%Y-%m-%d %H:%M:%S')
        }
        return jsonify({
            'success': True,
            'message': '评论成功',
            'comment': comment_payload
        })

    except Exception:
        db.session.rollback()
        return jsonify({'success': False, 'message': '评论失败,请重试'})
|
||
|
||
|
||
@app.route('/post/comments/<int:post_id>')
def get_comments(post_id):
    """Return a page of top-level comments for a post, each with up to 3 replies."""
    page = request.args.get('page', 1, type=int)

    # Top-level (parent-less) active comments, newest first, 20 per page.
    pagination = Comment.query.filter_by(
        post_id=post_id,
        parent_id=None,
        status='active'
    ).order_by(
        Comment.created_at.desc()
    ).paginate(page=page, per_page=20)

    def _user_payload(user):
        # Shared serialization for comment and reply authors.
        return {
            'id': user.id,
            'username': user.username,
            'avatar_url': get_full_avatar_url(user.avatar_url),
        }

    comments_data = []
    for top_comment in pagination.items:
        # Preview: the 3 oldest active replies under this comment.
        replies = Comment.query.filter_by(
            parent_id=top_comment.id,
            status='active'
        ).order_by(
            Comment.created_at.asc()
        ).limit(3).all()

        comments_data.append({
            'id': top_comment.id,
            'content': top_comment.content,
            'user': _user_payload(top_comment.user),
            'created_at': top_comment.created_at.strftime('%Y-%m-%d %H:%M:%S'),
            'replies': [
                {
                    'id': reply.id,
                    'content': reply.content,
                    'user': _user_payload(reply.user),
                    'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S')
                }
                for reply in replies
            ]
        })

    return jsonify({
        'comments': comments_data,
        'total': pagination.total,
        'pages': pagination.pages,
        'current_page': pagination.page
    })
|
||
|
||
|
||
# Beijing timezone; used below to localize naive datetimes before arithmetic.
beijing_tz = pytz.timezone('Asia/Shanghai')
|
||
|
||
|
||
def update_hot_scores():
    """
    Update the hot score of every active event.

    Runs inside the Flask application context so database access works when
    invoked from a background scheduler. For each active event it combines
    interaction, follow, view and recent-activity scores, applies exponential
    time decay, writes hot_score/trending_score, and records a component
    snapshot in EventHotHistory. Rolls back and re-raises on any failure.
    """
    with app.app_context():
        try:
            # Fetch all active events
            events = Event.query.filter_by(status='active').all()
            current_time = beijing_now()

            for event in events:
                # Ensure created_at is timezone-aware; avoids naive/aware
                # datetime comparison errors against beijing_now().
                created_at = beijing_tz.localize(
                    event.created_at) if event.created_at.tzinfo is None else event.created_at

                # Hours elapsed since creation (drives the decay factor)
                hours_passed = (current_time - created_at).total_seconds() / 3600

                # Base engagement inputs: the event's posts and their comment totals
                posts = Post.query.filter_by(event_id=event.id).all()
                post_count = len(posts)
                comment_count = sum(post.comments_count for post in posts)

                # New posts within the last 24 hours
                recent_posts = Post.query.filter(
                    Post.event_id == event.id,
                    Post.created_at >= current_time - timedelta(hours=24)
                ).count()

                # Total likes across the event's posts
                like_count = db.session.query(func.sum(Post.likes_count)).filter(
                    Post.event_id == event.id
                ).scalar() or 0

                # Interaction score = posts * 2 + comments * 1 + likes * 0.5
                interaction_score = (post_count * 2) + (comment_count * 1) + (like_count * 0.5)

                # Follow score = followers * 3
                follow_score = event.follower_count * 3

                # View score = log(view_count) * 2 (0 when there are no views)
                if event.view_count > 0:
                    view_score = math.log(event.view_count) * 2
                else:
                    view_score = 0

                # Time decay factor: falls to 1/e of the original after 72h (3 days)
                time_decay = math.exp(-hours_passed / 72)

                # Recent-activity boost: posts in the last 24h weigh heavily
                recent_activity_weight = (recent_posts * 5)

                # Total = (interaction + follow + view + recent activity) * decay
                total_score = (interaction_score + follow_score + view_score + recent_activity_weight) * time_decay

                # Persist the hot score
                event.hot_score = round(total_score, 2)

                # trending_score = log(total) * decay, used for trend ordering
                if total_score > 0:
                    event.trending_score = math.log(total_score) * time_decay
                else:
                    event.trending_score = 0

                # Record a snapshot of the score components for auditing
                history = EventHotHistory(
                    event_id=event.id,
                    score=event.hot_score,
                    interaction_score=interaction_score,
                    follow_score=follow_score,
                    view_score=view_score,
                    recent_activity_score=recent_activity_weight,
                    time_decay=time_decay
                )
                db.session.add(history)

            db.session.commit()
            app.logger.info("Successfully updated event hot scores")

        except Exception as e:
            db.session.rollback()
            app.logger.error(f"Error updating hot scores: {str(e)}")
            raise
|
||
|
||
|
||
# 添加热度历史记录模型
|
||
|
||
|
||
def calculate_hot_score(event):
    """Compute a lightweight hot score for a single event.

    Score = (views * 0.1 + posts * 0.5 + followers * 1) * linear time factor,
    where the factor decays over 72 hours and is floored at 0.1.

    Fix: beijing_now() is timezone-aware while event.created_at may be naive;
    subtracting them raises TypeError. Localize a naive created_at first,
    matching how update_hot_scores() handles the same field.
    """
    current_time = beijing_now()

    # Normalize to an aware datetime before arithmetic (consistent with
    # update_hot_scores).
    created_at = event.created_at
    if created_at.tzinfo is None:
        created_at = beijing_tz.localize(created_at)

    time_diff = (current_time - created_at).total_seconds() / 3600  # hours

    # Base score = views * 0.1 + posts * 0.5 + followers * 1
    base_score = (
        event.view_count * 0.1 +
        event.post_count * 0.5 +
        event.follower_count * 1
    )

    # Linear decay: full weight at creation, floor of 0.1 after ~65h;
    # events within 72 hours get relatively high weight.
    time_factor = max(1 - (time_diff / 72), 0.1)

    return base_score * time_factor
|
||
|
||
|
||
@app.route('/api/sector/hierarchy', methods=['GET'])
def api_sector_hierarchy():
    """Sector hierarchy endpoint: multi-level structure per classification system.

    Builds, for each configured classification system, a 4-level industry tree
    (F004V..F007V) with distinct stock counts at every level, all levels sorted
    by descending stock count. Returns {code, message, data} JSON.
    """
    try:
        # Classification systems to include in the response
        classification_systems = [
            '申银万国行业分类'
        ]

        result = []  # one entry per classification system

        for classification in classification_systems:
            # Rows for this classification standard
            sectors = SectorInfo.query.filter_by(F002V=classification).all()

            if not sectors:
                continue

            # Nested dict: level1 -> level2 -> level3 -> level4 list,
            # with a stock set at each level for distinct counting.
            hierarchy = {}

            for sector in sectors:
                level1 = sector.F004V  # level-1 industry
                level2 = sector.F005V  # level-2 industry
                level3 = sector.F006V  # level-3 industry
                level4 = sector.F007V  # level-4 industry

                # Stock code used for per-level distinct counts
                stock_code = sector.SECCODE

                # Initialize the level-1 bucket
                if level1 not in hierarchy:
                    hierarchy[level1] = {
                        'level2_sectors': {},
                        'stocks': set(),
                        'stocks_count': 0
                    }

                # Attribute the stock to the level-1 industry
                if stock_code:
                    hierarchy[level1]['stocks'].add(stock_code)

                # Level-2 (only when present; deeper levels nest under it)
                if level2:
                    if level2 not in hierarchy[level1]['level2_sectors']:
                        hierarchy[level1]['level2_sectors'][level2] = {
                            'level3_sectors': {},
                            'stocks': set(),
                            'stocks_count': 0
                        }

                    # Attribute the stock to the level-2 industry
                    if stock_code:
                        hierarchy[level1]['level2_sectors'][level2]['stocks'].add(stock_code)

                    # Level-3
                    if level3:
                        if level3 not in hierarchy[level1]['level2_sectors'][level2]['level3_sectors']:
                            hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3] = {
                                'level4_sectors': [],
                                'stocks': set(),
                                'stocks_count': 0
                            }

                        # Attribute the stock to the level-3 industry
                        if stock_code:
                            hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3]['stocks'].add(
                                stock_code)

                        # Level-4: a de-duplicated name list (no stock sets)
                        if level4 and level4 not in \
                                hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3]['level4_sectors']:
                            hierarchy[level1]['level2_sectors'][level2]['level3_sectors'][level3][
                                'level4_sectors'].append(level4)

            # Materialize counts and drop the (non-serializable) set objects
            formatted_hierarchy = []
            for level1, level1_data in hierarchy.items():
                level1_item = {
                    'level1_sector': level1,
                    'stocks_count': len(level1_data['stocks']),
                    'level2_sectors': []
                }

                for level2, level2_data in level1_data['level2_sectors'].items():
                    level2_item = {
                        'level2_sector': level2,
                        'stocks_count': len(level2_data['stocks']),
                        'level3_sectors': []
                    }

                    for level3, level3_data in level2_data['level3_sectors'].items():
                        level3_item = {
                            'level3_sector': level3,
                            'stocks_count': len(level3_data['stocks']),
                            'level4_sectors': level3_data['level4_sectors']
                        }
                        level2_item['level3_sectors'].append(level3_item)

                    # Sort level-3 industries by stock count, descending
                    level2_item['level3_sectors'].sort(key=lambda x: x['stocks_count'], reverse=True)
                    level1_item['level2_sectors'].append(level2_item)

                # Sort level-2 industries by stock count, descending
                level1_item['level2_sectors'].sort(key=lambda x: x['stocks_count'], reverse=True)
                formatted_hierarchy.append(level1_item)

            # Sort level-1 industries by stock count, descending
            formatted_hierarchy.sort(key=lambda x: x['stocks_count'], reverse=True)

            # Append this classification system to the result array
            result.append({
                'classification_name': classification,
                'total_level1_count': len(formatted_hierarchy),
                'total_stocks_count': sum(item['stocks_count'] for item in formatted_hierarchy),
                'hierarchy': formatted_hierarchy
            })

        # Sort classification systems by total stock count, descending
        result.sort(key=lambda x: x['total_stocks_count'], reverse=True)

        return jsonify({
            "code": 200,
            "message": "success",
            "data": result
        })

    except Exception as e:
        return jsonify({
            "code": 500,
            "message": str(e),
            "data": None
        }), 500
|
||
|
||
|
||
@app.route('/api/sector/banner', methods=['GET'])
def api_sector_banner():
    """Sector banner endpoint: primary sector groups with their sub-sector lists."""
    try:
        # Static sub-sector -> primary-sector mapping; insertion order of this
        # literal determines the order of the output groups and their members.
        sector_map = {
            '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期', '钢铁': '大周期', '基础化工': '大周期',
            '建筑材料': '大周期', '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期',
            '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期', '太阳能发电': '大周期',
            '建筑装饰': '大周期',

            '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费', '食品饮料': '大消费', '医药生物': '大消费',
            '纺织服饰': '大消费', '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费',
            '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费',

            '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产', '多元金融': '大金融地产',
            '综合金融': '大金融地产', '房地产': '大金融地产', '非银金融': '大金融地产',

            '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块',

            '交通运输': '公共产业板块', '电力公用事业': '公共产业板块', '建筑': '公共产业板块',
            '环保': '公共产业板块', '综合': '公共产业板块', '公用事业': '公共产业板块'
        }

        # Invert into primary -> [sub-sectors], preserving first-seen order.
        grouped = {}
        for sub_sector, primary_sector in sector_map.items():
            if primary_sector not in grouped:
                grouped[primary_sector] = []
            grouped[primary_sector].append(sub_sector)

        # Shape the payload as a list of group objects.
        payload = [
            {"primary_sector": primary, "sub_sectors": subs}
            for primary, subs in grouped.items()
        ]

        return jsonify({
            "code": 200,
            "message": "success",
            "data": payload
        })

    except Exception as e:
        return jsonify({
            "code": 500,
            "message": str(e),
            "data": None
        }), 500
|
||
|
||
|
||
def get_limit_rate(stock_code):
|
||
"""
|
||
根据股票代码获取涨跌停限制比例
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
|
||
Returns:
|
||
float: 涨跌停限制比例
|
||
"""
|
||
if not stock_code:
|
||
return 10.0
|
||
|
||
# 去掉市场后缀
|
||
clean_code = stock_code.replace('.SH', '').replace('.SZ', '').replace('.BJ', '')
|
||
|
||
# ST股票 (5%涨跌停)
|
||
if 'ST' in stock_code.upper():
|
||
return 5.0
|
||
|
||
# 科创板 (688开头, 20%涨跌停)
|
||
if clean_code.startswith('688'):
|
||
return 20.0
|
||
|
||
# 创业板注册制 (30开头, 20%涨跌停)
|
||
if clean_code.startswith('30'):
|
||
return 20.0
|
||
|
||
# 北交所 (43、83、87开头, 30%涨跌停)
|
||
if clean_code.startswith(('43', '83', '87')):
|
||
return 30.0
|
||
|
||
# 主板、中小板默认 (10%涨跌停)
|
||
return 10.0
|
||
|
||
|
||
@app.route('/api/events', methods=['GET'])
def api_get_events():
    """
    Event list API - optimized version (fully backward compatible).
    Restricted to Pro/Max members (mini-program feature).

    Optimization strategy:
    1. Use the ind_type column to simplify internal logic
    2. Batch stock quote fetching, including weekly change calculation
    3. Keep the original response structure unchanged

    Fixes (SQL text bind parameters):
    - The multi-tag loop reused the ':tag' name across filters and rebound it
      via Query.params(), so every condition ended up bound to the LAST tag.
    - The 'keywords' OR branch and the JSON_CONTAINS(:term) search branches
      never bound their parameters at all.
    - The industry-prefix OR reused ':prefix'/':pattern' across clauses,
      which conflicts at compile time for more than one prefix.
    All text() clauses now carry uniquely-named bindparams().
    """
    try:
        # ==================== Parameter parsing ====================

        # Pagination
        page = max(1, request.args.get('page', 1, type=int))
        per_page = min(100, max(1, request.args.get('per_page', 10, type=int)))

        # Basic filters
        event_type = request.args.get('type', 'all')
        event_status = request.args.get('status', 'active')
        importance = request.args.get('importance', 'all')

        # Date filters
        start_date = request.args.get('start_date')
        end_date = request.args.get('end_date')
        date_range = request.args.get('date_range')
        recent_days = request.args.get('recent_days', type=int)

        # Sector filters (redesigned)
        ind_type = request.args.get('ind_type', 'all')
        stock_sector = request.args.get('stock_sector', 'all')
        secondary_sector = request.args.get('secondary_sector', 'all')

        # Hierarchical industry filter parameters
        industry_level = request.args.get('industry_level', type=int)  # level: 1-4
        industry_classification = request.args.get('industry_classification')  # industry name

        # Map the legacy stock_sector parameter onto ind_type
        if ind_type == 'all' and stock_sector != 'all':
            ind_type = stock_sector

        # Tag filters
        tag = request.args.get('tag')
        tags = request.args.get('tags')
        keywords = request.args.get('keywords')

        # Search parameters
        search_query = request.args.get('q')
        search_type = request.args.get('search_type', 'topic')
        search_fields = request.args.get('search_fields', 'title,description').split(',')

        # Sorting
        sort_by = request.args.get('sort', 'new')
        return_type = request.args.get('return_type', 'avg')
        order = request.args.get('order', 'desc')

        # Return-rate filters
        min_avg_return = request.args.get('min_avg_return', type=float)
        max_avg_return = request.args.get('max_avg_return', type=float)
        min_max_return = request.args.get('min_max_return', type=float)
        max_max_return = request.args.get('max_max_return', type=float)
        min_week_return = request.args.get('min_week_return', type=float)
        max_week_return = request.args.get('max_week_return', type=float)

        # Other numeric filters
        min_hot_score = request.args.get('min_hot_score', type=float)
        max_hot_score = request.args.get('max_hot_score', type=float)
        min_view_count = request.args.get('min_view_count', type=int)
        creator_id = request.args.get('creator_id', type=int)

        # Response shape flags
        include_creator = request.args.get('include_creator', 'true').lower() == 'true'
        include_stats = request.args.get('include_stats', 'true').lower() == 'true'
        include_related_data = request.args.get('include_related_data', 'false').lower() == 'true'

        # ==================== Base query ====================

        query = Event.query

        # Status filter
        if event_status != 'all':
            query = query.filter_by(status=event_status)

        # Event type filter
        if event_type != 'all':
            query = query.filter_by(event_type=event_type)

        # Importance filter
        if importance != 'all':
            query = query.filter_by(importance=importance)

        # Sector type filter (ind_type column)
        if ind_type != 'all':
            query = query.filter_by(ind_type=ind_type)

        # Creator filter
        if creator_id:
            query = query.filter_by(creator_id=creator_id)

        # ==================== Date filtering ====================

        if recent_days:
            cutoff_date = datetime.now() - timedelta(days=recent_days)
            query = query.filter(Event.created_at >= cutoff_date)
        else:
            # A combined "start 至 end" range string overrides the pieces
            if date_range and ' 至 ' in date_range:
                try:
                    start_date_str, end_date_str = date_range.split(' 至 ')
                    start_date = start_date_str.strip()
                    end_date = end_date_str.strip()
                except ValueError:
                    pass

            # Start date (date-only or full datetime format)
            if start_date:
                try:
                    if len(start_date) == 10:
                        start_datetime = datetime.strptime(start_date, '%Y-%m-%d')
                    else:
                        start_datetime = datetime.strptime(start_date, '%Y-%m-%d %H:%M:%S')
                    query = query.filter(Event.created_at >= start_datetime)
                except ValueError:
                    pass

            # End date (date-only values are extended to end-of-day)
            if end_date:
                try:
                    if len(end_date) == 10:
                        end_datetime = datetime.strptime(end_date, '%Y-%m-%d')
                        end_datetime = end_datetime.replace(hour=23, minute=59, second=59)
                    else:
                        end_datetime = datetime.strptime(end_date, '%Y-%m-%d %H:%M:%S')
                    query = query.filter(Event.created_at <= end_datetime)
                except ValueError:
                    pass

        # ==================== Industry hierarchy filter (SW classification) ====================

        if industry_level and industry_classification:
            # Classification-system names themselves are not concrete industries
            classification_systems = [
                '申银万国行业分类', '中上协行业分类', '巨潮行业分类',
                '新财富行业分类', '证监会行业分类', '证监会行业分类(2001)'
            ]

            if industry_classification not in classification_systems:
                # Industry code prefixes come from an in-memory cache
                # (avoids a DB lookup per request). Front-end level values:
                # level=2 -> level-1 industry ... level=5 -> level-4 industry
                if industry_level in SYWG_INDUSTRY_CACHE:
                    code_prefixes = SYWG_INDUSTRY_CACHE[industry_level].get(industry_classification, [])

                    if code_prefixes:
                        # Find events whose related_industries JSON contains
                        # any of these code prefixes. Each condition gets a
                        # uniquely named bind param so the OR compiles and
                        # binds correctly for multiple prefixes.
                        if isinstance(db.engine.dialect, MySQLDialect):
                            conditions = []
                            for idx, prefix in enumerate(code_prefixes):
                                pname = f'ind_prefix_{idx}'
                                conditions.append(
                                    text(f"""
                                        JSON_SEARCH(
                                            related_industries,
                                            'one',
                                            CONCAT(:{pname}, '%'),
                                            NULL,
                                            '$[*]."申银万国行业分类"'
                                        ) IS NOT NULL
                                    """).bindparams(**{pname: prefix})
                                )

                            if conditions:
                                query = query.filter(or_(*conditions))
                        else:
                            # Non-MySQL (e.g. PostgreSQL) fallback: text LIKE
                            pattern_conditions = []
                            for idx, prefix in enumerate(code_prefixes):
                                pname = f'ind_pattern_{idx}'
                                pattern_conditions.append(
                                    text(f"related_industries::text LIKE :{pname}").bindparams(
                                        **{pname: f'%"申银万国行业分类": "{prefix}%'}
                                    )
                                )

                            if pattern_conditions:
                                query = query.filter(or_(*pattern_conditions))
                    else:
                        # No matching industry code -> empty result set
                        query = query.filter(Event.id == -1)
                else:
                    # Invalid level parameter
                    app.logger.warning(f"Invalid industry_level: {industry_level}")
            else:
                # industry_classification is a classification-system name; skip
                app.logger.info(
                    f"Skipping filter: industry_classification '{industry_classification}' is a classification system name")

        # ==================== Fine-grained sector filter (legacy) ====================

        elif secondary_sector != 'all':
            # Resolve the leaf industry name (level5/f007v) to its code
            sector_code_query = db.session.query(text("DISTINCT f003v")).select_from(
                text("ea_sector")
            ).filter(
                text("f002v = '申银万国行业分类' AND f007v = :sector_name")
            ).params(sector_name=secondary_sector)

            sector_result = sector_code_query.first()

            if sector_result and sector_result[0]:
                industry_code_to_search = sector_result[0]

                # Find events whose related_industries JSON contains the code
                if isinstance(db.engine.dialect, MySQLDialect):
                    query = query.filter(
                        text("""
                            JSON_SEARCH(
                                related_industries,
                                'one',
                                :industry_code,
                                NULL,
                                '$[*]."申银万国行业分类"'
                            ) IS NOT NULL
                        """).bindparams(industry_code=industry_code_to_search)
                    )
                else:
                    query = query.filter(
                        text("""
                            related_industries::text LIKE :pattern
                        """).bindparams(pattern=f'%"申银万国行业分类": "{industry_code_to_search}"%')
                    )
            else:
                # Unknown industry name -> empty result set
                query = query.filter(Event.id == -1)

        # ==================== Concept/tag filters ====================

        # Single tag
        if tag:
            if isinstance(db.engine.dialect, MySQLDialect):
                query = query.filter(
                    text("JSON_CONTAINS(keywords, :tag, '$')").bindparams(tag=json.dumps(tag))
                )
            else:
                query = query.filter(Event.keywords.cast(JSONB).contains([tag]))

        # Multiple tags (AND logic) — each condition binds its own value
        if tags:
            tag_list = [t.strip() for t in tags.split(',') if t.strip()]
            for idx, single_tag in enumerate(tag_list):
                if isinstance(db.engine.dialect, MySQLDialect):
                    pname = f'tag_{idx}'
                    query = query.filter(
                        text(f"JSON_CONTAINS(keywords, :{pname}, '$')").bindparams(
                            **{pname: json.dumps(single_tag)}
                        )
                    )
                else:
                    query = query.filter(Event.keywords.cast(JSONB).contains([single_tag]))

        # Keyword filters (OR logic)
        if keywords:
            keyword_list = [k.strip() for k in keywords.split(',') if k.strip()]
            keyword_filters = []
            for idx, keyword in enumerate(keyword_list):
                if isinstance(db.engine.dialect, MySQLDialect):
                    pname = f'kw_{idx}'
                    keyword_filters.append(
                        text(f"JSON_CONTAINS(keywords, :{pname}, '$')").bindparams(
                            **{pname: json.dumps(keyword)}
                        )
                    )
                else:
                    keyword_filters.append(Event.keywords.cast(JSONB).contains([keyword]))
            if keyword_filters:
                query = query.filter(or_(*keyword_filters))

        # ==================== Search ====================

        if search_query:
            search_terms = search_query.strip().split()

            if search_type == 'stock':
                # Stock search: match code or name on related stocks
                query = query.join(RelatedStock).filter(
                    or_(
                        RelatedStock.stock_code.ilike(f'%{search_query}%'),
                        RelatedStock.stock_name.ilike(f'%{search_query}%')
                    )
                ).distinct()
            elif search_type == 'all':
                # Global search: text fields OR related-stock match
                search_filters = []

                # Text-field search (one OR group per term)
                for idx, term in enumerate(search_terms):
                    term_filters = []
                    if 'title' in search_fields:
                        term_filters.append(Event.title.ilike(f'%{term}%'))
                    if 'description' in search_fields:
                        term_filters.append(Event.description.ilike(f'%{term}%'))
                    if 'keywords' in search_fields:
                        if isinstance(db.engine.dialect, MySQLDialect):
                            pname = f'all_term_{idx}'
                            term_filters.append(
                                text(f"JSON_CONTAINS(keywords, :{pname}, '$')").bindparams(
                                    **{pname: json.dumps(term)}
                                )
                            )
                        else:
                            term_filters.append(Event.keywords.cast(JSONB).contains([term]))

                    if term_filters:
                        search_filters.append(or_(*term_filters))

                # Related-stock match
                stock_subquery = db.session.query(RelatedStock.event_id).filter(
                    or_(
                        RelatedStock.stock_code.ilike(f'%{search_query}%'),
                        RelatedStock.stock_name.ilike(f'%{search_query}%')
                    )
                ).subquery()

                search_filters.append(Event.id.in_(stock_subquery))

                if search_filters:
                    query = query.filter(or_(*search_filters))
            else:
                # Topic search (default): every term must match some field
                for idx, term in enumerate(search_terms):
                    term_filters = []
                    if 'title' in search_fields:
                        term_filters.append(Event.title.ilike(f'%{term}%'))
                    if 'description' in search_fields:
                        term_filters.append(Event.description.ilike(f'%{term}%'))
                    if 'keywords' in search_fields:
                        if isinstance(db.engine.dialect, MySQLDialect):
                            pname = f'topic_term_{idx}'
                            term_filters.append(
                                text(f"JSON_CONTAINS(keywords, :{pname}, '$')").bindparams(
                                    **{pname: json.dumps(term)}
                                )
                            )
                        else:
                            term_filters.append(Event.keywords.cast(JSONB).contains([term]))

                    if term_filters:
                        query = query.filter(or_(*term_filters))

        # ==================== Return-rate filters ====================

        if min_avg_return is not None:
            query = query.filter(Event.related_avg_chg >= min_avg_return)
        if max_avg_return is not None:
            query = query.filter(Event.related_avg_chg <= max_avg_return)

        if min_max_return is not None:
            query = query.filter(Event.related_max_chg >= min_max_return)
        if max_max_return is not None:
            query = query.filter(Event.related_max_chg <= max_max_return)

        if min_week_return is not None:
            query = query.filter(Event.related_week_chg >= min_week_return)
        if max_week_return is not None:
            query = query.filter(Event.related_week_chg <= max_week_return)

        # ==================== Other numeric filters ====================

        if min_hot_score is not None:
            query = query.filter(Event.hot_score >= min_hot_score)
        if max_hot_score is not None:
            query = query.filter(Event.hot_score <= max_hot_score)

        if min_view_count is not None:
            query = query.filter(Event.view_count >= min_view_count)

        # ==================== Sorting ====================

        order_func = desc if order.lower() == 'desc' else asc

        if sort_by == 'hot':
            query = query.order_by(order_func(Event.hot_score), desc(Event.created_at))
        elif sort_by == 'new':
            query = query.order_by(order_func(Event.created_at))
        elif sort_by == 'returns':
            if return_type == 'avg':
                query = query.order_by(order_func(Event.related_avg_chg), desc(Event.created_at))
            elif return_type == 'max':
                query = query.order_by(order_func(Event.related_max_chg), desc(Event.created_at))
            elif return_type == 'week':
                query = query.order_by(order_func(Event.related_week_chg), desc(Event.created_at))
        elif sort_by == 'importance':
            # Map importance grades to sortable ranks (S best)
            importance_order = case(
                (Event.importance == 'S', 1),
                (Event.importance == 'A', 2),
                (Event.importance == 'B', 3),
                (Event.importance == 'C', 4),
                else_=5
            )
            if order.lower() == 'desc':
                query = query.order_by(importance_order, desc(Event.created_at))
            else:
                query = query.order_by(desc(importance_order), desc(Event.created_at))
        elif sort_by == 'view_count':
            query = query.order_by(order_func(Event.view_count), desc(Event.created_at))
        elif sort_by == 'follow' and hasattr(request, 'user') and request.user.is_authenticated:
            # Events the current user follows
            query = query.join(EventFollow).filter(
                EventFollow.user_id == request.user.id
            ).order_by(order_func(Event.created_at))
        else:
            # Fallback: always newest first
            query = query.order_by(desc(Event.created_at))

        # ==================== Pagination ====================

        paginated = query.paginate(page=page, per_page=per_page, error_out=False)

        # ==================== Response construction ====================

        events_data = []
        for event in paginated.items:
            # Event payload (original structure; per-stock data left empty)
            event_dict = {
                'id': event.id,
                'title': event.title,
                'description': event.description,
                'event_type': event.event_type,
                'importance': event.importance,
                'status': event.status,
                'created_at': event.created_at.isoformat() if event.created_at else None,
                'updated_at': event.updated_at.isoformat() if event.updated_at else None,
                'start_time': event.start_time.isoformat() if event.start_time else None,
                'end_time': event.end_time.isoformat() if event.end_time else None,
                # Per-stock data (intentionally empty)
                'related_stocks': [],
                # Stock stats: placeholders plus DB-backed change fields.
                # NOTE(review): stocks_count is hard-coded to 10 — confirm intent.
                'stocks_stats': {
                    'stocks_count': 10,
                    'valid_stocks_count': 0,
                    'avg_week_change': round(event.related_week_chg, 2) if event.related_week_chg else 0,
                    'max_week_change': round(event.related_max_chg, 2) if event.related_max_chg else 0,
                    'avg_daily_change': round(event.related_avg_chg, 2) if event.related_avg_chg else 0,
                    'max_daily_change': round(event.related_max_chg, 2) if event.related_max_chg else 0
                }
            }

            # Stats (optional)
            if include_stats:
                event_dict.update({
                    'hot_score': event.hot_score,
                    'view_count': event.view_count,
                    'post_count': event.post_count,
                    'follower_count': event.follower_count,
                    'related_avg_chg': event.related_avg_chg,
                    'related_max_chg': event.related_max_chg,
                    'related_week_chg': event.related_week_chg,
                    'invest_score': event.invest_score,
                    'trending_score': event.trending_score,
                })

            # Creator info (optional)
            if include_creator:
                event_dict['creator'] = {
                    'id': event.creator.id if event.creator else None,
                    'username': event.creator.username if event.creator else 'Anonymous',
                    'avatar_url': get_full_avatar_url(event.creator.avatar_url) if event.creator else None,
                    'is_creator': event.creator.is_creator if event.creator else False,
                    'creator_type': event.creator.creator_type if event.creator else None
                }

            # Related data
            event_dict['keywords'] = event.keywords if isinstance(event.keywords, list) else []
            event_dict['related_industries'] = event.related_industries

            # Counters block (optional; placeholders for compatibility)
            if include_stats:
                event_dict['stats'] = {
                    'related_stocks_count': 10,
                    'historical_events_count': 0,
                    'related_data_count': 0,
                    'related_concepts_count': 0
                }

            # Related data block (optional; intentionally empty)
            if include_related_data:
                event_dict['related_stocks'] = []

            events_data.append(event_dict)

        # ==================== Applied-filter bookkeeping ====================
        # NOTE(review): collected but not returned in the response below;
        # kept for compatibility with earlier revisions.

        applied_filters = {}

        if event_type != 'all':
            applied_filters['type'] = event_type
        if importance != 'all':
            applied_filters['importance'] = importance
        if start_date:
            applied_filters['start_date'] = start_date
        if end_date:
            applied_filters['end_date'] = end_date
        if industry_level and industry_classification:
            applied_filters['industry_level'] = industry_level
            applied_filters['industry_classification'] = industry_classification
        if tag:
            applied_filters['tag'] = tag
        if tags:
            applied_filters['tags'] = tags
        if search_query:
            applied_filters['search_query'] = search_query
            applied_filters['search_type'] = search_type

        # ==================== Response (fully compatible; stats zeroed) ====================

        return jsonify({
            'success': True,
            'data': {
                'events': events_data,
                'pagination': {
                    'page': paginated.page,
                    'per_page': paginated.per_page,
                    'total': paginated.total,
                    'pages': paginated.pages,
                    'has_prev': paginated.has_prev,
                    'has_next': paginated.has_next
                },
                'filters': {
                    'applied_filters': {
                        'type': event_type,
                        'status': event_status,
                        'importance': importance,
                        'stock_sector': stock_sector,  # kept for compatibility
                        'secondary_sector': secondary_sector,  # kept for compatibility
                        'sort': sort_by,
                        'order': order
                    }
                },
                # Overall stock change distribution (zeroed placeholders)
                'overall_stats': {
                    'total_stocks': 0,
                    'change_distribution': {
                        'limit_down': 0,
                        'down_over_5': 0,
                        'down_5_to_1': 0,
                        'down_within_1': 0,
                        'flat': 0,
                        'up_within_1': 0,
                        'up_1_to_5': 0,
                        'up_over_5': 0,
                        'limit_up': 0
                    },
                    'change_distribution_percentages': {
                        'limit_down': 0,
                        'down_over_5': 0,
                        'down_5_to_1': 0,
                        'down_within_1': 0,
                        'flat': 0,
                        'up_within_1': 0,
                        'up_1_to_5': 0,
                        'up_over_5': 0,
                        'limit_up': 0
                    }
                }
            }
        })

    except Exception as e:
        app.logger.error(f"获取事件列表出错: {str(e)}", exc_info=True)
        return jsonify({
            'success': False,
            'error': str(e),
            'error_type': type(e).__name__
        }), 500
|
||
|
||
|
||
def get_filter_counts(base_query):
    """Return option counts for the filter panel.

    Args:
        base_query: an Event query whose matching ids bound the counting.

    Returns:
        dict with 'importance' and 'event_type' keys, each mapping an option
        value (or 'unknown' for NULL) to the number of matching events.
        Returns {} on any error so the caller can degrade gracefully.
    """

    def _bucket_counts(column):
        # Group the events selected by base_query on the given column.
        id_subquery = base_query.with_entities(Event.id).subquery()
        rows = db.session.query(
            column,
            func.count(Event.id).label('count')
        ).filter(
            Event.id.in_(id_subquery)
        ).group_by(column).all()
        # NULL column values are reported under the 'unknown' bucket.
        return {(row[0] or 'unknown'): row.count for row in rows}

    try:
        return {
            'importance': _bucket_counts(Event.importance),
            'event_type': _bucket_counts(Event.event_type),
        }
    except Exception:
        return {}
|
||
|
||
|
||
def get_event_class(count):
    """Map an event count to the calendar badge's CSS gradient class.

    Higher counts get "hotter" colors: >=10 danger, >=7 warning,
    >=4 info, everything else success.
    """
    thresholds = (
        (10, 'bg-gradient-danger'),
        (7, 'bg-gradient-warning'),
        (4, 'bg-gradient-info'),
    )
    for minimum, css_class in thresholds:
        if count >= minimum:
            return css_class
    return 'bg-gradient-success'
|
||
|
||
|
||
@app.route('/api/calendar-event-counts')
def get_calendar_event_counts():
    """Return per-day counts of future_events rows (type='event') for the current month.

    The response is a list of FullCalendar-style objects:
    [{'title': 'N 个事件', 'start': 'YYYY-MM-DD', 'className': ...}, ...].

    Returns:
        200 with the JSON list, or 500 with {'error', 'error_type'} on failure.
    """
    try:
        # Month boundaries normalized to midnight.  The previous version used
        # today.replace(day=1), which kept the current clock time and silently
        # excluded events that occurred earlier in the day on the 1st.
        today = datetime.now()
        start_date = today.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        if today.month == 12:
            end_date = start_date.replace(year=today.year + 1, month=1)
        else:
            end_date = start_date.replace(month=today.month + 1)

        # Half-open interval [start, end): BETWEEN is inclusive on both ends,
        # which would have counted an event at exactly next month's midnight.
        query = """
            SELECT DATE(calendar_time) as date, COUNT(*) as count
            FROM future_events
            WHERE calendar_time >= :start_date
              AND calendar_time < :end_date
              AND type = 'event'
            GROUP BY DATE(calendar_time)
        """

        result = db.session.execute(text(query), {
            'start_date': start_date,
            'end_date': end_date
        })

        # One badge per day; styling scales with the count.
        events = [{
            'title': f'{day.count} 个事件',
            'start': day.date.isoformat() if day.date else None,
            'className': get_event_class(day.count)
        } for day in result]

        return jsonify(events)

    except Exception as e:
        app.logger.error(f"获取日历事件统计出错: {str(e)}", exc_info=True)
        return jsonify({'error': str(e), 'error_type': type(e).__name__}), 500
|
||
|
||
|
||
def get_full_avatar_url(avatar_url):
    """Normalize an avatar URL into an absolute, publicly reachable URL.

    Args:
        avatar_url: stored avatar URL (absolute, relative, or empty/None).

    Returns:
        An absolute URL string; the bundled default avatar when nothing
        is stored.
    """
    # No avatar on record -> serve the default image.
    if not avatar_url:
        return f"{DOMAIN}/static/assets/img/default-avatar.png"

    # Already absolute (http/https) -> pass through untouched.
    if avatar_url.startswith(('http://', 'https://')):
        return avatar_url

    # Relative path -> prefix with the site domain, inserting a slash only
    # when the stored value does not already begin with one.
    separator = '' if avatar_url.startswith('/') else '/'
    return f"{DOMAIN}{separator}{avatar_url}"
|
||
|
||
|
||
# Replacement for the User model's to_dict method
def to_dict(self):
    """Serialize this user into a JSON-safe dict for API responses."""
    created = self.created_at.isoformat() if self.created_at else None
    seen = self.last_seen.isoformat() if self.last_seen else None
    payload = {
        'id': self.id,
        'username': self.username,
        'email': self.email,
        'nickname': self.nickname,
        # Route through the shared helper so relative paths become absolute.
        'avatar_url': get_full_avatar_url(self.avatar_url),
        'bio': self.bio,
        'location': self.location,
        'is_verified': self.is_verified,
        'user_level': self.user_level,
        'reputation_score': self.reputation_score,
        'post_count': self.post_count,
        'follower_count': self.follower_count,
        'following_count': self.following_count,
    }
    payload['created_at'] = created
    payload['last_seen'] = seen
    return payload
|
||
|
||
|
||
# ==================== 标准化API接口 ====================
|
||
|
||
# 1. 首页接口
|
||
@app.route('/api/home', methods=['GET'])
def api_home():
    """Home feed endpoint: the 10 hottest active events of the last 7 days.

    For each event the related stocks are attached along with aggregate
    statistics (avg/max correlation, weekly and daily change).  Returns
    {"code": 200, "message": "success", "data": {"events": [...]}} or a
    {"code": 500, ...} envelope on failure.
    """
    try:
        seven_days_ago = datetime.now() - timedelta(days=7)
        hot_events = Event.query.filter(
            Event.status == 'active',
            Event.created_at >= seven_days_ago
        ).order_by(Event.hot_score.desc()).limit(10).all()

        events_data = []
        for event in hot_events:
            related_stocks = RelatedStock.query.filter_by(event_id=event.id).all()

            # Correlation statistics across the related stocks.
            correlations = [float(stock.correlation or 0) for stock in related_stocks]
            avg_correlation = sum(correlations) / len(correlations) if correlations else 0
            max_correlation = max(correlations) if correlations else 0

            stocks_data = []
            total_week_change = 0
            max_week_change = 0
            total_daily_change = 0  # running sum of daily change percentages
            max_daily_change = 0  # largest daily change seen so far
            valid_stocks_count = 0

            for stock in related_stocks:
                # Trade tables key on the bare code without the exchange suffix.
                stock_code = stock.stock_code.split('.')[0]

                # Most recent trading-day row for this stock.
                latest_trade = db.session.execute(text("""
                    SELECT *
                    FROM ea_trade
                    WHERE SECCODE = :stock_code
                    ORDER BY TRADEDATE DESC LIMIT 1
                """), {"stock_code": stock_code}).first()

                week_change = 0
                daily_change = 0  # daily change percentage

                if latest_trade and latest_trade.F007N:
                    latest_price = float(latest_trade.F007N or 0)
                    latest_date = latest_trade.TRADEDATE
                    daily_change = float(latest_trade.F010N or 0)  # F010N is the daily change-percent column

                    # Update the daily-change aggregates.
                    total_daily_change += daily_change
                    max_daily_change = max(max_daily_change, daily_change)

                    # Previous 5 trading rows; the oldest is used as the
                    # "one week ago" price anchor.
                    week_ago_trades = db.session.execute(text("""
                        SELECT *
                        FROM ea_trade
                        WHERE SECCODE = :stock_code
                          AND TRADEDATE < :latest_date
                        ORDER BY TRADEDATE DESC LIMIT 5
                    """), {
                        "stock_code": stock_code,
                        "latest_date": latest_date
                    }).fetchall()

                    if week_ago_trades and week_ago_trades[-1].F007N:
                        week_ago_price = float(week_ago_trades[-1].F007N or 0)
                        if week_ago_price > 0:
                            week_change = (latest_price - week_ago_price) / week_ago_price * 100
                            total_week_change += week_change
                            max_week_change = max(max_week_change, week_change)
                            # NOTE(review): valid_stocks_count only counts stocks
                            # with usable week data, yet it is also the divisor
                            # for avg_daily_change below — confirm intended.
                            valid_stocks_count += 1

                stocks_data.append({
                    "stock_code": stock.stock_code,
                    "stock_name": stock.stock_name,
                    "correlation": float(stock.correlation or 0),
                    "sector": stock.sector,
                    "week_change": round(week_change, 2),
                    "daily_change": round(daily_change, 2),  # per-stock daily change
                    "latest_trade_date": latest_trade.TRADEDATE.strftime("%Y-%m-%d") if latest_trade else None
                })

            # Averages over stocks that had usable week-change data.
            avg_week_change = total_week_change / valid_stocks_count if valid_stocks_count > 0 else 0
            avg_daily_change = total_daily_change / valid_stocks_count if valid_stocks_count > 0 else 0

            events_data.append({
                "id": event.id,
                "title": event.title,
                "description": event.description,
                'event_type': event.event_type,
                'importance': event.importance,  # event importance level
                'status': event.status,
                "created_at": event.created_at.strftime("%Y-%m-%d %H:%M:%S"),
                'updated_at': event.updated_at.isoformat() if event.updated_at else None,
                'start_time': event.start_time.isoformat() if event.start_time else None,
                'end_time': event.end_time.isoformat() if event.end_time else None,
                'view_count': event.view_count,  # page views
                'post_count': event.post_count,  # number of posts
                'follower_count': event.follower_count,  # number of followers
                "related_stocks": stocks_data,
                "stocks_stats": {
                    "avg_correlation": round(avg_correlation, 2),
                    "max_correlation": round(max_correlation, 2),
                    "stocks_count": len(related_stocks),
                    "valid_stocks_count": valid_stocks_count,
                    # weekly change statistics
                    "avg_week_change": round(avg_week_change, 2),
                    "max_week_change": round(max_week_change, 2),
                    # daily change statistics
                    "avg_daily_change": round(avg_daily_change, 2),
                    "max_daily_change": round(max_daily_change, 2)
                }
            })

        return jsonify({
            "code": 200,
            "message": "success",
            "data": {
                "events": events_data
            }
        })

    except Exception as e:
        print(f"Error in api_home: {str(e)}")
        return jsonify({
            "code": 500,
            "message": str(e),
            "data": None
        }), 500
|
||
|
||
|
||
@app.route('/api/auth/logout', methods=['POST'])
def logout_with_token():
    """Token logout: revoke the token (if known) and clear the session.

    The token is taken from the "Authorization: Bearer ..." header, falling
    back to a "token" field in the JSON body.  Always returns 200 so logout
    stays idempotent even for an unknown or already-expired token.
    """
    auth_header = request.headers.get('Authorization')
    if auth_header and auth_header.startswith('Bearer '):
        token = auth_header[7:]
    else:
        # silent=True: a body-less logout request must not abort with 400/415.
        data = request.get_json(silent=True)
        token = data.get('token') if data else None

    # BUGFIX: the token store is accessed via get/set/delete everywhere else
    # (login_with_phone, verify_token); `token in user_tokens` relied on a
    # __contains__ the store is never shown to implement.
    if token and user_tokens.get(token) is not None:
        user_tokens.delete(token)

    # Drop the cookie session as well.
    session.clear()

    return jsonify({'message': '登出成功'}), 200
|
||
|
||
|
||
def send_sms_code(phone, code, template_id):
    """Send an SMS verification code through Tencent Cloud SMS.

    Args:
        phone: recipient phone number.
        code: verification code to embed in the template.
        template_id: Tencent SMS template id (register vs. login).

    Returns:
        True when the API call completed without raising, False on a
        Tencent SDK error.
    """
    try:
        # Tencent Cloud client wired up for the Beijing region.
        cred = credential.Credential(SMS_SECRET_ID, SMS_SECRET_KEY)
        http_profile = HttpProfile()
        http_profile.endpoint = "sms.tencentcloudapi.com"

        client_profile = ClientProfile()
        client_profile.httpProfile = http_profile
        client = sms_client.SmsClient(cred, "ap-beijing", client_profile)

        # The register template takes only the code; other templates also
        # take the validity period ("5" minutes) as a second parameter.
        if template_id == SMS_TEMPLATE_REGISTER:
            template_params = [code]
        else:
            template_params = [code, "5"]

        send_request = models.SendSmsRequest()
        send_request.from_json_string(json.dumps({
            "PhoneNumberSet": [phone],
            "SmsSdkAppId": SMS_SDK_APP_ID,
            "TemplateId": template_id,
            "SignName": SMS_SIGN_NAME,
            "TemplateParamSet": template_params
        }))

        client.SendSms(send_request)
        return True
    except TencentCloudSDKException as err:
        print(f"SMS Error: {err}")
        return False
|
||
|
||
|
||
def generate_verification_code():
    """Generate a 6-digit numeric SMS verification code.

    SECURITY FIX: uses the `secrets` CSPRNG instead of `random.choices` —
    this code gates authentication, so it must not come from a predictable
    Mersenne-Twister PRNG.
    """
    return ''.join(secrets.choice(string.digits) for _ in range(6))
|
||
|
||
|
||
@app.route('/api/auth/send-sms', methods=['POST'])
def send_sms_verification():
    """Send a phone verification code (single endpoint for login and signup).

    The SMS template is chosen automatically: known phone numbers get the
    login template, unknown ones the registration template.  The response is
    identical in both cases so callers cannot probe which numbers exist.
    """
    payload = request.get_json()
    phone = payload.get('phone')

    if not phone:
        return jsonify({'error': '手机号不能为空'}), 400

    # Pick the template based on whether the phone is already registered.
    existing_user = User.query.filter_by(phone=phone).first()
    template_id = SMS_TEMPLATE_REGISTER if existing_user is None else SMS_TEMPLATE_LOGIN

    code = generate_verification_code()

    if not send_sms_code(phone, code, template_id):
        return jsonify({'error': '验证码发送失败'}), 500

    # Remember the code for 5 minutes.
    verification_codes[phone] = {
        'code': code,
        'expires': time.time() + 300
    }

    # Deliberately does not reveal whether the account exists.
    return jsonify({
        'message': '验证码已发送',
        'expires_in': 300  # code validity in seconds, for the client UI
    }), 200
|
||
|
||
|
||
def generate_token(length=32):
    """Return a cryptographically random alphanumeric token of *length* chars."""
    alphabet = string.ascii_letters + string.digits
    picked = [secrets.choice(alphabet) for _ in range(length)]
    return ''.join(picked)
|
||
|
||
|
||
@app.route('/api/auth/login/phone', methods=['POST'])
def login_with_phone():
    """Combined phone login / registration endpoint.

    Verifies the SMS code; if the phone number is unknown, a new account is
    created on the fly (username/password auto-generated unless provided).
    On success a 30-day opaque token is issued and the Flask session is
    populated.  Response body: {code, message, token, is_new_user, user}.
    """
    data = request.get_json()
    phone = data.get('phone')
    code = data.get('code')
    username = data.get('username')  # optional; new users may supply one
    password = data.get('password')  # optional; new users may supply one

    if not all([phone, code]):
        return jsonify({
            'code': 400,
            'message': '手机号和验证码不能为空'
        }), 400

    # Validate the SMS verification code.
    stored_code = verification_codes.get(phone)

    if not stored_code or stored_code['expires'] < time.time():
        return jsonify({
            'code': 400,
            'message': '验证码已过期'
        }), 400

    if stored_code['code'] != code:
        return jsonify({
            'code': 400,
            'message': '验证码错误'
        }), 400

    try:
        # Look up the account by phone number.
        user = User.query.filter_by(phone=phone).first()
        is_new_user = False

        # Unknown phone -> register automatically.
        if not user:
            is_new_user = True

            # If a username was supplied, reject duplicates outright.
            if username:
                if User.query.filter_by(username=username).first():
                    return jsonify({
                        'code': 400,
                        'message': '用户名已被使用,请换一个'
                    }), 400
            else:
                # Otherwise derive a unique username from the phone's tail.
                base_username = f"user_{phone[-4:]}"
                username = base_username
                counter = 1
                while User.query.filter_by(username=username).first():
                    username = f"{base_username}_{counter}"
                    counter += 1

            # Create the account with a placeholder e-mail address.
            user = User(username=username, phone=phone)
            user.email = f"{username}@valuefrontier.temp"

            # Use the provided password, otherwise a random throwaway one.
            if password:
                user.set_password(password)
            else:
                random_password = generate_token(16)
                user.set_password(random_password)

            user.phone_confirmed = True

            db.session.add(user)
            db.session.commit()

        # Issue an opaque API token.
        token = generate_token(32)

        # Persist the token -> user mapping (30-day validity).
        user_tokens.set(token, {
            'user_id': user.id,
            'expires': datetime.now() + timedelta(days=30)
        })

        # The code is single-use: discard it after a successful login.
        del verification_codes[phone]

        # Also populate the cookie session (kept for legacy endpoints).
        session.permanent = True
        session['user_id'] = user.id
        session['username'] = user.username
        session['logged_in'] = True

        # Build the response payload.
        response_data = {
            'code': 0,
            'message': '欢迎回来' if not is_new_user else '注册成功,欢迎加入',
            'token': token,
            'is_new_user': is_new_user,  # lets the client branch on signup
            'user': {
                'id': user.id,
                'username': user.username,
                'phone': user.phone,
                'need_complete_profile': is_new_user  # prompt new users to finish their profile
            }
        }

        return jsonify(response_data), 200

    except Exception as e:
        db.session.rollback()
        print(f"Login/Register error: {e}")
        return jsonify({
            'code': 500,
            'message': '操作失败,请重试'
        }), 500
|
||
|
||
|
||
@app.route('/api/auth/verify-token', methods=['POST'])
def verify_token():
    """Check whether an opaque API token is still valid.

    Returns the owning user's basic profile on success; otherwise a
    {'valid': False, ...} payload with an appropriate HTTP status
    (400 missing token, 401 invalid/expired, 404 user gone).
    """
    body = request.get_json()
    token = body.get('token')

    if not token:
        return jsonify({'valid': False, 'message': 'Token不能为空'}), 400

    record = user_tokens.get(token)
    if not record:
        return jsonify({'valid': False, 'message': 'Token无效', 'code': 401}), 401

    # The store may hand back the expiry as an ISO string or a datetime.
    expires = record['expires']
    if isinstance(expires, str):
        expires = datetime.fromisoformat(expires)
    if expires < datetime.now():
        # Purge expired tokens eagerly.
        user_tokens.delete(token)
        return jsonify({'valid': False, 'message': 'Token已过期'}), 401

    user = User.query.get(record['user_id'])
    if not user:
        return jsonify({'valid': False, 'message': '用户不存在'}), 404

    return jsonify({
        'valid': True,
        'user': {
            'id': user.id,
            'username': user.username,
            'phone': user.phone
        }
    }), 200
|
||
|
||
|
||
def generate_jwt_token(user_id):
    """Create a signed JWT carrying *user_id*, matching the legacy format.

    Args:
        user_id: id of the authenticated user.

    Returns:
        str: encoded JWT that expires after JWT_EXPIRATION_HOURS.
    """
    # NOTE(review): datetime.utcnow() is deprecated as of Python 3.12 —
    # consider datetime.now(timezone.utc) once the runtime allows it.
    claims = {
        'user_id': user_id,
        'exp': datetime.utcnow() + timedelta(hours=JWT_EXPIRATION_HOURS),
        'iat': datetime.utcnow(),
    }
    return jwt.encode(claims, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||
|
||
|
||
@app.route('/api/auth/login/wechat', methods=['POST'])
def api_login_wechat():
    """WeChat mini-program login: exchange a js_code for a local session.

    Calls WeChat's jscode2session API, then finds (by unionid, falling back
    to openid) or auto-creates the matching User, issues a 30-day opaque
    token, and populates the Flask session.  Response body:
    {code, data: {token, user}, message}.
    """
    try:
        # 1. Read the request payload.
        data = request.get_json()
        code = data.get('code') if data else None

        if not code:
            return jsonify({
                'code': 400,
                'message': '缺少必要的参数',
                'data': None
            }), 400

        # 2. Sanity-check the code format before spending a WeChat API call.
        if not isinstance(code, str) or len(code) < 10:
            return jsonify({
                'code': 400,
                'message': 'code格式无效',
                'data': None
            }), 400

        logger.info(f"开始处理微信登录,code长度: {len(code)}")

        # 3. Exchange the code for openid/session_key at WeChat.
        wx_api_url = 'https://api.weixin.qq.com/sns/jscode2session'
        params = {
            'appid': WECHAT_APP_ID,
            'secret': WECHAT_APP_SECRET,
            'js_code': code,
            'grant_type': 'authorization_code'
        }

        try:
            response = requests.get(wx_api_url, params=params, timeout=10)
            response.raise_for_status()
            wx_data = response.json()

            # Map WeChat error codes to user-facing messages.
            if 'errcode' in wx_data and wx_data['errcode'] != 0:
                error_messages = {
                    -1: '系统繁忙,请稍后重试',
                    40029: 'code无效或已过期',
                    45011: '频率限制,请稍后再试',
                    40013: 'AppID错误',
                    40125: 'AppSecret错误',
                    40226: '高风险用户,登录被拦截'
                }

                error_msg = error_messages.get(
                    wx_data['errcode'],
                    f"微信接口错误: {wx_data.get('errmsg', '未知错误')}"
                )

                logger.error(f"WeChat API error {wx_data['errcode']}: {error_msg}")

                return jsonify({
                    'code': 400,
                    'message': error_msg,
                    'data': None
                }), 400

            # Both fields are required for a usable login.
            if 'openid' not in wx_data or 'session_key' not in wx_data:
                logger.error("响应缺少必需字段")
                return jsonify({
                    'code': 500,
                    'message': '微信响应格式错误',
                    'data': None
                }), 500

            openid = wx_data['openid']
            # session_key is extracted but not used below — presumably kept
            # for future encrypted-data decryption; confirm before removing.
            session_key = wx_data['session_key']
            unionid = wx_data.get('unionid')  # absent unless bound to the open platform

            logger.info(f"成功获取微信用户信息 - OpenID: {openid[:8]}...")
            if unionid:
                logger.info(f"获取到UnionID: {unionid[:8]}...")

        except requests.exceptions.Timeout:
            logger.error("请求微信API超时")
            return jsonify({
                'code': 500,
                'message': '请求超时,请重试',
                'data': None
            }), 500
        except requests.exceptions.RequestException as e:
            logger.error(f"网络请求失败: {str(e)}")
            return jsonify({
                'code': 500,
                'message': '网络错误',
                'data': None
            }), 500

        # 4. Find or create the user — core matching logic.
        user = None
        is_new_user = False

        logger.info(f"开始查找用户 - UnionID: {unionid}, OpenID: {openid[:8]}...")

        if unionid:
            # Case 1: unionid present — stable cross-app id, so try it first.
            user = User.query.filter_by(wechat_union_id=unionid).first()

            if user:
                logger.info(f"通过UnionID找到用户: {user.username}")
                # Refresh the openid (it may differ across mini-programs).
                if user.wechat_open_id != openid:
                    user.wechat_open_id = openid
                    logger.info(f"更新用户OpenID: {openid[:8]}...")
            else:
                # Fall back to openid for accounts created before unionid existed.
                user = User.query.filter_by(wechat_open_id=openid).first()
                if user:
                    logger.info(f"通过OpenID找到用户: {user.username}")
                    # Backfill the unionid on the legacy record.
                    user.wechat_union_id = unionid
                    logger.info(f"为用户补充UnionID: {unionid[:8]}...")
        else:
            # Case 2: no unionid — openid is the only handle available.
            logger.warning("未获取到UnionID(小程序可能未绑定开放平台)")
            user = User.query.filter_by(wechat_open_id=openid).first()
            if user:
                logger.info(f"通过OpenID找到用户: {user.username}")

        # 5. First-time login: create an account.
        if not user:
            is_new_user = True

            # Derive a unique username from timestamp + openid tail.
            timestamp = int(time.time())
            username = f"wx_{timestamp}_{openid[-6:]}"

            # Ensure the username is unique.
            counter = 0
            base_username = username
            while User.query.filter_by(username=username).first():
                counter += 1
                username = f"{base_username}_{counter}"

            # NOTE(review): the password kwarg is a sentinel string — confirm
            # the User constructor hashes it rather than storing it verbatim.
            user = User(
                username=username,
                email=f"{username}@wechat.local",  # placeholder e-mail
                password="wechat_login_no_password"  # WeChat logins never use a password
            )

            # WeChat identity fields.
            user.wechat_open_id = openid
            user.wechat_union_id = unionid
            user.status = 'active'
            user.email_confirmed = False

            # Profile defaults.
            user.nickname = f"微信用户{openid[-4:]}"
            user.bio = ""  # empty bio
            user.avatar_url = None  # resolved to the default avatar at render time
            user.is_creator = False
            user.is_verified = False
            user.user_level = 1
            user.reputation_score = 0
            user.contribution_point = 0
            user.post_count = 0
            user.comment_count = 0
            user.follower_count = 0
            user.following_count = 0

            # Preference defaults.
            user.email_notifications = True
            user.privacy_level = 'public'
            user.theme_preference = 'light'

            db.session.add(user)
            logger.info(f"创建新用户: {username}")
        else:
            # Returning user: just bump the last-seen timestamp.
            user.update_last_seen()
            logger.info(f"用户登录: {user.username}")

        # 6. Persist all user changes.
        try:
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            logger.error(f"保存用户信息失败: {str(e)}")
            return jsonify({
                'code': 500,
                'message': '保存用户信息失败',
                'data': None
            }), 500

        # 7. Issue an opaque token (same scheme as the phone login).
        token = generate_token(32)

        # Keep the token -> user mapping for 30 days.
        user_tokens.set(token, {
            'user_id': user.id,
            'expires': datetime.now() + timedelta(days=30)
        })

        # Populate the cookie session too (kept for legacy endpoints).
        session.permanent = True
        session['user_id'] = user.id
        session['username'] = user.username
        session['logged_in'] = True

        # 9. Build the response in the agreed-upon shape.
        response_data = {
            'code': 200,
            'data': {
                'token': token,  # recognized by the token-auth decorator
                'user': {
                    'avatar_url': get_full_avatar_url(user.avatar_url),
                    'bio': user.bio or "",
                    'email': user.email,
                    'id': user.id,
                    'is_creator': user.is_creator,
                    'is_verified': user.is_verified,
                    'nickname': user.nickname or user.username,
                    'reputation_score': user.reputation_score,
                    'user_level': user.user_level,
                    'username': user.username
                }
            },
            'message': '登录成功'
        }

        # 10. Audit log.
        logger.info(
            f"微信登录成功 - 用户ID: {user.id}, "
            f"用户名: {user.username}, "
            f"新用户: {is_new_user}, "
            f"有UnionID: {unionid is not None}"
        )

        return jsonify(response_data), 200

    except Exception as e:
        # Catch-all so the client always gets a JSON error envelope.
        logger.error(f"微信登录处理异常: {str(e)}", exc_info=True)
        db.session.rollback()

        return jsonify({
            'code': 500,
            'message': '服务器内部错误',
            'data': None
        }), 500
|
||
|
||
|
||
@app.route('/api/auth/login/email', methods=['POST'])
def api_login_email():
    """E-mail + password login; returns a JWT and basic profile on success."""
    try:
        payload = request.get_json()
        email = payload.get('email')
        password = payload.get('password')

        # Both credentials are mandatory.
        if not email or not password:
            return jsonify({
                'code': 400,
                'message': '邮箱和密码不能为空',
                'data': None
            }), 400

        user = User.query.filter_by(email=email).first()
        # Single generic message so callers cannot tell which part failed.
        if user is None or not user.check_password(password):
            return jsonify({
                'code': 400,
                'message': '邮箱或密码错误',
                'data': None
            }), 400

        token = generate_jwt_token(user.id)
        login_user(user)
        user.update_last_seen()
        db.session.commit()

        return jsonify({
            'code': 200,
            'message': '登录成功',
            'data': {
                'token': token,
                'user_id': user.id,
                'username': user.username,
                'email': user.email,
                'is_verified': user.is_verified,
                'user_level': user.user_level
            }
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
|
||
|
||
|
||
# 5. 事件详情-相关标的接口
|
||
@app.route('/api/event/<int:event_id>/related-stocks-detail', methods=['GET'])
|
||
def api_event_related_stocks(event_id):
|
||
"""事件相关标的详情接口 - 仅限 Pro/Max 会员"""
|
||
try:
|
||
from datetime import datetime, timedelta, time as dt_time
|
||
from sqlalchemy import text
|
||
|
||
event = Event.query.get_or_404(event_id)
|
||
related_stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all()
|
||
|
||
# 获取ClickHouse客户端用于分时数据查询
|
||
client = get_clickhouse_client()
|
||
|
||
# 获取事件时间(如果事件有开始时间,使用开始时间;否则使用创建时间)
|
||
event_time = event.start_time if event.start_time else event.created_at
|
||
current_time = datetime.now()
|
||
|
||
# 定义交易日和时间范围计算函数(与 app.py 中的逻辑完全一致)
|
||
def get_trading_day_and_times(event_datetime):
|
||
event_date = event_datetime.date()
|
||
event_time_only = event_datetime.time()
|
||
|
||
# Trading hours
|
||
market_open = dt_time(9, 30)
|
||
market_close = dt_time(15, 0)
|
||
|
||
with engine.connect() as conn:
|
||
# First check if the event date itself is a trading day
|
||
is_trading_day = conn.execute(text("""
|
||
SELECT 1
|
||
FROM trading_days
|
||
WHERE EXCHANGE_DATE = :date
|
||
"""), {"date": event_date}).fetchone() is not None
|
||
|
||
if is_trading_day:
|
||
# If it's a trading day, determine time period based on event time
|
||
if event_time_only < market_open:
|
||
# Before market opens - use full trading day
|
||
return event_date, market_open, market_close
|
||
elif event_time_only > market_close:
|
||
# After market closes - get next trading day
|
||
next_trading_day = conn.execute(text("""
|
||
SELECT EXCHANGE_DATE
|
||
FROM trading_days
|
||
WHERE EXCHANGE_DATE > :date
|
||
ORDER BY EXCHANGE_DATE LIMIT 1
|
||
"""), {"date": event_date}).fetchone()
|
||
# Convert to date object if we found a next trading day
|
||
return (next_trading_day[0].date() if next_trading_day else None,
|
||
market_open, market_close)
|
||
else:
|
||
# During trading hours
|
||
return event_date, event_time_only, market_close
|
||
else:
|
||
# If not a trading day, get next trading day
|
||
next_trading_day = conn.execute(text("""
|
||
SELECT EXCHANGE_DATE
|
||
FROM trading_days
|
||
WHERE EXCHANGE_DATE > :date
|
||
ORDER BY EXCHANGE_DATE LIMIT 1
|
||
"""), {"date": event_date}).fetchone()
|
||
# Convert to date object if we found a next trading day
|
||
return (next_trading_day[0].date() if next_trading_day else None,
|
||
market_open, market_close)
|
||
|
||
trading_day, start_time, end_time = get_trading_day_and_times(event_time)
|
||
|
||
if not trading_day:
|
||
# 如果没有交易日,返回空数据
|
||
return jsonify({
|
||
'code': 200,
|
||
'message': 'success',
|
||
'data': {
|
||
'event_id': event_id,
|
||
'event_title': event.title,
|
||
'event_desc': event.description,
|
||
'event_type': event.event_type,
|
||
'event_importance': event.importance,
|
||
'event_status': event.status,
|
||
'event_created_at': event.created_at.strftime("%Y-%m-%d %H:%M:%S"),
|
||
'event_start_time': event.start_time.isoformat() if event.start_time else None,
|
||
'event_end_time': event.end_time.isoformat() if event.end_time else None,
|
||
'keywords': event.keywords,
|
||
'view_count': event.view_count,
|
||
'post_count': event.post_count,
|
||
'follower_count': event.follower_count,
|
||
'related_stocks': [],
|
||
'total_count': 0
|
||
}
|
||
})
|
||
|
||
# For historical dates, ensure we're using actual data
|
||
start_datetime = datetime.combine(trading_day, start_time)
|
||
end_datetime = datetime.combine(trading_day, end_time)
|
||
|
||
# If the trading day is in the future relative to current time, return only names without data
|
||
if trading_day > current_time.date():
|
||
start_datetime = datetime.combine(trading_day, start_time)
|
||
end_datetime = datetime.combine(trading_day, end_time)
|
||
|
||
print(f"事件时间: {event_time}, 交易日: {trading_day}, 时间范围: {start_datetime} - {end_datetime}")
|
||
|
||
def get_minute_chart_data(stock_code):
|
||
"""获取股票分时图数据"""
|
||
try:
|
||
# 获取当前日期或最新交易日的分时数据
|
||
from datetime import datetime, timedelta, time as dt_time
|
||
today = datetime.now().date()
|
||
|
||
# 获取最新交易日的分时数据
|
||
data = client.execute("""
|
||
SELECT
|
||
timestamp, open, high, low, close, volume, amt
|
||
FROM stock_minute
|
||
WHERE code = %(code)s
|
||
AND timestamp >= %(start)s
|
||
AND timestamp <= %(end)s
|
||
ORDER BY timestamp
|
||
""", {
|
||
'code': stock_code,
|
||
'start': datetime.combine(today, dt_time(9, 30)),
|
||
'end': datetime.combine(today, dt_time(15, 0))
|
||
})
|
||
|
||
# 如果今天没有数据,获取最近的交易日数据
|
||
if not data:
|
||
# 获取最近的交易日数据
|
||
recent_data = client.execute("""
|
||
SELECT
|
||
timestamp, open, high, low, close, volume, amt
|
||
FROM stock_minute
|
||
WHERE code = %(code)s
|
||
AND timestamp >= (
|
||
SELECT MAX (timestamp) - INTERVAL 1 DAY
|
||
FROM stock_minute
|
||
WHERE code = %(code)s
|
||
)
|
||
ORDER BY timestamp
|
||
""", {
|
||
'code': stock_code
|
||
})
|
||
data = recent_data
|
||
|
||
# 格式化数据
|
||
minute_data = []
|
||
for row in data:
|
||
minute_data.append({
|
||
'time': row[0].strftime('%H:%M'),
|
||
'open': float(row[1]) if row[1] else None,
|
||
'high': float(row[2]) if row[2] else None,
|
||
'low': float(row[3]) if row[3] else None,
|
||
'close': float(row[4]) if row[4] else None,
|
||
'volume': float(row[5]) if row[5] else None,
|
||
'amount': float(row[6]) if row[6] else None
|
||
})
|
||
|
||
return minute_data
|
||
|
||
except Exception as e:
|
||
print(f"Error fetching minute data for {stock_code}: {e}")
|
||
return []
|
||
|
||
# ==================== 性能优化:批量查询所有股票数据 ====================
|
||
# 1. 收集所有股票代码
|
||
stock_codes = [stock.stock_code for stock in related_stocks]
|
||
|
||
# 2. 批量查询股票基本信息
|
||
stock_info_map = {}
|
||
if stock_codes:
|
||
stock_infos = StockBasicInfo.query.filter(StockBasicInfo.SECCODE.in_(stock_codes)).all()
|
||
for info in stock_infos:
|
||
stock_info_map[info.SECCODE] = info
|
||
|
||
# 处理不带后缀的股票代码
|
||
base_codes = [code.split('.')[0] for code in stock_codes if '.' in code and code not in stock_info_map]
|
||
if base_codes:
|
||
base_infos = StockBasicInfo.query.filter(StockBasicInfo.SECCODE.in_(base_codes)).all()
|
||
for info in base_infos:
|
||
# 将不带后缀的信息映射到带后缀的代码
|
||
for code in stock_codes:
|
||
if code.split('.')[0] == info.SECCODE and code not in stock_info_map:
|
||
stock_info_map[code] = info
|
||
|
||
# 3. 批量查询 ClickHouse 数据(价格、涨跌幅、分时图数据)
|
||
price_data_map = {} # 存储价格和涨跌幅数据
|
||
minute_chart_map = {} # 存储分时图数据
|
||
|
||
try:
|
||
if stock_codes:
|
||
print(f"批量查询 {len(stock_codes)} 只股票的价格数据...")
|
||
|
||
# 3.1 批量查询价格和涨跌幅数据(使用子查询方式,避免窗口函数与 GROUP BY 冲突)
|
||
batch_price_query = """
|
||
WITH first_prices AS (SELECT code,
|
||
close as first_price \
|
||
, ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp ASC) as rn
|
||
FROM stock_minute
|
||
WHERE code IN %(codes)s
|
||
AND timestamp >= %(start)s
|
||
AND timestamp <= %(end)s
|
||
) \
|
||
, last_prices AS (
|
||
SELECT
|
||
code, close as last_price, open as open_price, high as high_price, low as low_price, volume, amt as amount, ROW_NUMBER() OVER (PARTITION BY code ORDER BY timestamp DESC) as rn
|
||
FROM stock_minute
|
||
WHERE code IN %(codes)s
|
||
AND timestamp >= %(start)s
|
||
AND timestamp <= %(end)s
|
||
)
|
||
SELECT fp.code, \
|
||
fp.first_price, \
|
||
lp.last_price, \
|
||
(lp.last_price - fp.first_price) / fp.first_price * 100 as change_pct, \
|
||
lp.open_price, \
|
||
lp.high_price, \
|
||
lp.low_price, \
|
||
lp.volume, \
|
||
lp.amount
|
||
FROM first_prices fp
|
||
INNER JOIN last_prices lp ON fp.code = lp.code
|
||
WHERE fp.rn = 1 \
|
||
AND lp.rn = 1 \
|
||
"""
|
||
|
||
price_data = client.execute(batch_price_query, {
|
||
'codes': stock_codes,
|
||
'start': start_datetime,
|
||
'end': end_datetime
|
||
})
|
||
|
||
print(f"批量查询返回 {len(price_data)} 条价格数据")
|
||
|
||
# 解析批量查询结果
|
||
for row in price_data:
|
||
code = row[0]
|
||
first_price = float(row[1]) if row[1] is not None else None
|
||
last_price = float(row[2]) if row[2] is not None else None
|
||
change_pct = float(row[3]) if row[3] is not None else None
|
||
open_price = float(row[4]) if row[4] is not None else None
|
||
high_price = float(row[5]) if row[5] is not None else None
|
||
low_price = float(row[6]) if row[6] is not None else None
|
||
volume = int(row[7]) if row[7] is not None else None
|
||
amount = float(row[8]) if row[8] is not None else None
|
||
|
||
change_amount = None
|
||
if last_price is not None and first_price is not None:
|
||
change_amount = last_price - first_price
|
||
|
||
price_data_map[code] = {
|
||
'latest_price': last_price,
|
||
'first_price': first_price,
|
||
'change_pct': change_pct,
|
||
'change_amount': change_amount,
|
||
'open_price': open_price,
|
||
'high_price': high_price,
|
||
'low_price': low_price,
|
||
'volume': volume,
|
||
'amount': amount,
|
||
}
|
||
|
||
# 3.2 批量查询分时图数据
|
||
print(f"批量查询分时图数据...")
|
||
minute_chart_query = """
|
||
SELECT code, timestamp, open, high, low, close, volume, amt
|
||
FROM stock_minute
|
||
WHERE code IN %(codes)s
|
||
AND timestamp >= %(start)s
|
||
AND timestamp <= %(end)s
|
||
ORDER BY code, timestamp \
|
||
"""
|
||
|
||
minute_data = client.execute(minute_chart_query, {
|
||
'codes': stock_codes,
|
||
'start': start_datetime,
|
||
'end': end_datetime
|
||
})
|
||
|
||
print(f"批量查询返回 {len(minute_data)} 条分时数据")
|
||
|
||
# 按股票代码分组分时数据
|
||
for row in minute_data:
|
||
code = row[0]
|
||
if code not in minute_chart_map:
|
||
minute_chart_map[code] = []
|
||
|
||
minute_chart_map[code].append({
|
||
'time': row[1].strftime('%H:%M'),
|
||
'open': float(row[2]) if row[2] else None,
|
||
'high': float(row[3]) if row[3] else None,
|
||
'low': float(row[4]) if row[4] else None,
|
||
'close': float(row[5]) if row[5] else None,
|
||
'volume': float(row[6]) if row[6] else None,
|
||
'amount': float(row[7]) if row[7] else None
|
||
})
|
||
|
||
except Exception as e:
|
||
print(f"批量查询 ClickHouse 失败: {e}")
|
||
# 如果批量查询失败,price_data_map 和 minute_chart_map 为空,后续会使用降级方案
|
||
|
||
# 4. 组装每个股票的数据(从批量查询结果中获取)
|
||
stocks_data = []
|
||
for stock in related_stocks:
|
||
print(f"正在组装股票 {stock.stock_code} 的数据...")
|
||
|
||
# 从批量查询结果中获取股票基本信息
|
||
stock_info = stock_info_map.get(stock.stock_code)
|
||
|
||
# 从批量查询结果中获取价格数据
|
||
price_info = price_data_map.get(stock.stock_code)
|
||
|
||
latest_price = None
|
||
first_price = None
|
||
change_pct = None
|
||
change_amount = None
|
||
open_price = None
|
||
high_price = None
|
||
low_price = None
|
||
volume = None
|
||
amount = None
|
||
trade_date = trading_day
|
||
|
||
if price_info:
|
||
# 使用批量查询的结果
|
||
latest_price = price_info['latest_price']
|
||
first_price = price_info['first_price']
|
||
change_pct = price_info['change_pct']
|
||
change_amount = price_info['change_amount']
|
||
open_price = price_info['open_price']
|
||
high_price = price_info['high_price']
|
||
low_price = price_info['low_price']
|
||
volume = price_info['volume']
|
||
amount = price_info['amount']
|
||
else:
|
||
# 如果批量查询没有返回数据,使用降级方案(TradeData)
|
||
print(f"股票 {stock.stock_code} 批量查询无数据,使用降级方案...")
|
||
try:
|
||
latest_trade = None
|
||
search_codes = [stock.stock_code, stock.stock_code.split('.')[0]]
|
||
|
||
for code in search_codes:
|
||
latest_trade = TradeData.query.filter_by(SECCODE=code) \
|
||
.order_by(TradeData.TRADEDATE.desc()).first()
|
||
if latest_trade:
|
||
break
|
||
|
||
if latest_trade:
|
||
latest_price = float(latest_trade.F007N) if latest_trade.F007N else None
|
||
open_price = float(latest_trade.F003N) if latest_trade.F003N else None
|
||
high_price = float(latest_trade.F005N) if latest_trade.F005N else None
|
||
low_price = float(latest_trade.F006N) if latest_trade.F006N else None
|
||
first_price = float(latest_trade.F002N) if latest_trade.F002N else None
|
||
volume = float(latest_trade.F004N) if latest_trade.F004N else None
|
||
amount = float(latest_trade.F011N) if latest_trade.F011N else None
|
||
trade_date = latest_trade.TRADEDATE
|
||
|
||
# 计算涨跌幅
|
||
if latest_trade.F010N:
|
||
change_pct = float(latest_trade.F010N)
|
||
if latest_trade.F009N:
|
||
change_amount = float(latest_trade.F009N)
|
||
except Exception as fallback_error:
|
||
print(f"降级查询也失败 {stock.stock_code}: {fallback_error}")
|
||
|
||
# 从批量查询结果中获取分时图数据
|
||
minute_chart_data = minute_chart_map.get(stock.stock_code, [])
|
||
|
||
stock_data = {
|
||
'id': stock.id,
|
||
'stock_code': stock.stock_code,
|
||
'stock_name': stock.stock_name,
|
||
'sector': stock.sector,
|
||
'relation_desc': stock.relation_desc,
|
||
'correlation': stock.correlation,
|
||
'momentum': stock.momentum,
|
||
'listing_date': stock_info.F006D.isoformat() if stock_info and stock_info.F006D else None,
|
||
'market': stock_info.F005V if stock_info else None,
|
||
|
||
# 交易数据
|
||
'trade_data': {
|
||
'latest_price': latest_price,
|
||
'first_price': first_price, # 事件发生时的价格
|
||
'open_price': open_price,
|
||
'high_price': high_price,
|
||
'low_price': low_price,
|
||
'change_amount': round(change_amount, 2) if change_amount is not None else None,
|
||
'change_pct': round(change_pct, 2) if change_pct is not None else None,
|
||
'volume': volume,
|
||
'amount': amount,
|
||
'trade_date': trade_date.isoformat() if trade_date else None,
|
||
'event_start_time': start_datetime.isoformat() if start_datetime else None, # 事件开始时间
|
||
'event_end_time': end_datetime.isoformat() if end_datetime else None, # 查询结束时间
|
||
} if latest_price is not None else None,
|
||
|
||
# 分时图数据
|
||
'minute_chart_data': minute_chart_data,
|
||
|
||
# 图表URL
|
||
'charts': {
|
||
'minute_chart_url': f"/api/stock/{stock.stock_code}/minute-chart",
|
||
'daily_chart_url': f"/api/stock/{stock.stock_code}/kline",
|
||
}
|
||
}
|
||
|
||
stocks_data.append(stock_data)
|
||
|
||
return jsonify({
|
||
'code': 200,
|
||
'message': 'success',
|
||
'data': {
|
||
'event_id': event_id,
|
||
'event_title': event.title,
|
||
'event_desc': event.description,
|
||
'event_type': event.event_type,
|
||
'event_importance': event.importance,
|
||
'event_status': event.status,
|
||
'event_created_at': event.created_at.strftime("%Y-%m-%d %H:%M:%S"),
|
||
'event_start_time': event.start_time.isoformat() if event.start_time else None,
|
||
'event_end_time': event.end_time.isoformat() if event.end_time else None,
|
||
'keywords': event.keywords,
|
||
'view_count': event.view_count,
|
||
'post_count': event.post_count,
|
||
'follower_count': event.follower_count,
|
||
'related_stocks': stocks_data,
|
||
'total_count': len(stocks_data)
|
||
}
|
||
})
|
||
except Exception as e:
|
||
return jsonify({
|
||
'code': 500,
|
||
'message': str(e),
|
||
'data': None
|
||
}), 500
|
||
|
||
|
||
@app.route('/api/stock/<stock_code>/minute-chart', methods=['GET'])
def get_minute_chart_data(stock_code):
    """Return minute-level (intraday) chart data for a stock.

    Registered as a Flask route but ALSO called directly as a plain helper
    (see api_stock_detail), so it returns a bare list of dicts rather than
    a Response object.
    NOTE(review): Flask only auto-serializes a list return value on
    Flask >= 2.2 — confirm the installed version supports this.
    NOTE(review): the original docstring said "Pro/Max members only", but no
    membership check is visible in this function — confirm enforcement
    happens elsewhere (e.g. middleware).

    Args:
        stock_code: security code as stored in ClickHouse ``stock_minute.code``.

    Returns:
        list[dict]: one entry per minute bar with keys
        time/open/high/low/close/volume/amount; [] on any error.
    """
    client = get_clickhouse_client()
    try:
        # Today's trading session window. (timedelta is imported but unused here.)
        from datetime import datetime, timedelta, time as dt_time
        today = datetime.now().date()

        # Minute bars for today's regular session (09:30–15:00).
        data = client.execute("""
            SELECT
                timestamp, open, high, low, close, volume, amt
            FROM stock_minute
            WHERE code = %(code)s
              AND timestamp >= %(start)s
              AND timestamp <= %(end)s
            ORDER BY timestamp
        """, {
            'code': stock_code,
            'start': datetime.combine(today, dt_time(9, 30)),
            'end': datetime.combine(today, dt_time(15, 0))
        })

        # No bars today (weekend/holiday/pre-open): fall back to the most
        # recent day of data present for this code.
        if not data:
            recent_data = client.execute("""
                SELECT
                    timestamp, open, high, low, close, volume, amt
                FROM stock_minute
                WHERE code = %(code)s
                  AND timestamp >= (
                      SELECT MAX (timestamp) - INTERVAL 1 DAY
                      FROM stock_minute
                      WHERE code = %(code)s
                  )
                ORDER BY timestamp
            """, {
                'code': stock_code
            })
            data = recent_data

        # Shape rows into the chart payload; NULL columns become None.
        minute_data = []
        for row in data:
            minute_data.append({
                'time': row[0].strftime('%H:%M'),
                'open': float(row[1]) if row[1] else None,
                'high': float(row[2]) if row[2] else None,
                'low': float(row[3]) if row[3] else None,
                'close': float(row[4]) if row[4] else None,
                'volume': float(row[5]) if row[5] else None,
                'amount': float(row[6]) if row[6] else None
            })

        return minute_data
    except Exception as e:
        # Best-effort endpoint: log and return an empty series on any failure.
        print(f"Error getting minute chart data: {e}")
        return []
|
||
|
||
|
||
@app.route('/api/event/<int:event_id>/stock/<stock_code>/detail', methods=['GET'])
def api_stock_detail(event_id, stock_code):
    """Stock-detail endpoint for a stock attached to an event.

    Query params:
        include_minute_data: 'true'/'false' — include intraday chart data
            (default 'true').
        include_full_sources: 'true'/'false' — return all research-report
            sources instead of a short preview (default 'false').

    Returns:
        JSON {code, message, data} where data bundles event info, stock
        basic/company info, the latest trade snapshot, minute chart data
        and the event-relation description. 404 when the stock is unknown,
        500 on unexpected errors.

    NOTE(review): the original docstring said "Pro/Max members only", but no
    membership check is visible in this function — confirm it is enforced
    elsewhere.
    """
    try:
        # Ensure the event exists (aborts with 404 otherwise).
        event = Event.query.get_or_404(event_id)

        # Query parameters.
        include_minute_data = request.args.get('include_minute_data', 'true').lower() == 'true'
        include_full_sources = request.args.get('include_full_sources', 'false').lower() == 'true'  # include full research-report sources?

        # Stock basic info; base_code is the code without the exchange suffix.
        basic_info = None
        base_code = stock_code.split('.')[0]

        # Look up basic info by decreasing strictness:
        # exact code -> prefix match on full code -> prefix match on base code.
        basic_info = StockBasicInfo.query.filter_by(SECCODE=stock_code).first()
        if not basic_info:
            basic_info = StockBasicInfo.query.filter(
                StockBasicInfo.SECCODE.ilike(f"{stock_code}%")
            ).first()
        if not basic_info:
            basic_info = StockBasicInfo.query.filter(
                StockBasicInfo.SECCODE.ilike(f"{base_code}%")
            ).first()

        company_info = CompanyInfo.query.filter_by(SECCODE=stock_code).first()
        if not company_info:
            company_info = CompanyInfo.query.filter_by(SECCODE=base_code).first()

        if not basic_info:
            return jsonify({
                'code': 404,
                'stock_code': stock_code,
                'message': '股票不存在',
                'data': None
            }), 404

        # Latest trade record: full code first, base code as fallback.
        latest_trade = TradeData.query.filter_by(SECCODE=stock_code) \
            .order_by(TradeData.TRADEDATE.desc()).first()
        if not latest_trade:
            latest_trade = TradeData.query.filter_by(SECCODE=base_code) \
                .order_by(TradeData.TRADEDATE.desc()).first()

        # Intraday chart data (optional; get_minute_chart_data returns a list).
        minute_chart_data = []
        if include_minute_data:
            minute_chart_data = get_minute_chart_data(stock_code)

        # Relation record linking this stock to the event, matching any
        # code variant (full code, base code, or base code with any suffix).
        related_stock = RelatedStock.query.filter_by(
            event_id=event_id
        ).filter(
            db.or_(
                RelatedStock.stock_code == stock_code,
                RelatedStock.stock_code == base_code,
                RelatedStock.stock_code.like(f"{base_code}.%")
            )
        ).first()

        related_desc = None
        if related_stock:
            # Research-report source handling.
            retrieved_sources_data = None
            sources_summary = None

            if related_stock.retrieved_sources:
                try:
                    # retrieved_sources may already be a list or a JSON string.
                    import json
                    sources = related_stock.retrieved_sources if isinstance(related_stock.retrieved_sources,
                                                                            list) else json.loads(
                        related_stock.retrieved_sources)

                    # Summary statistics.
                    sources_summary = {
                        'total_count': len(sources),
                        'has_sources': True,
                        'match_scores': {}
                    }

                    # Distribution of match_score values across sources.
                    for source in sources:
                        score = source.get('match_score', '未知')
                        sources_summary['match_scores'][score] = sources_summary['match_scores'].get(score, 0) + 1

                    # Full data vs. preview, controlled by the query parameter.
                    if include_full_sources:
                        # Return the complete source list.
                        retrieved_sources_data = sources
                    else:
                        # Preview: up to 3 high ("好") + 2 medium ("中") matches,
                        # falling back to the first 5 sources when neither exists.
                        high_quality_sources = [s for s in sources if s.get('match_score') == '好'][:3]
                        medium_quality_sources = [s for s in sources if s.get('match_score') == '中'][:2]

                        preview_sources = high_quality_sources + medium_quality_sources
                        if not preview_sources:  # no high/medium matches: take first 5
                            preview_sources = sources[:5]

                        retrieved_sources_data = []
                        for source in preview_sources:
                            retrieved_sources_data.append({
                                'report_title': source.get('report_title', ''),
                                'author': source.get('author', ''),
                                'sentences': source.get('sentences', '')[:200] + '...' if len(
                                    source.get('sentences', '')) > 200 else source.get('sentences', ''),  # truncate long excerpts
                                'match_score': source.get('match_score', ''),
                                'declare_date': source.get('declare_date', '')
                            })

                except Exception as e:
                    # Malformed source data must not fail the whole endpoint.
                    print(f"Error processing retrieved_sources for stock {stock_code}: {e}")
                    sources_summary = {'has_sources': False, 'error': str(e)}
            else:
                sources_summary = {'has_sources': False, 'total_count': 0}

            related_desc = {
                'event_id': related_stock.event_id,
                'relation_desc': related_stock.relation_desc,
                'sector': related_stock.sector,
                'correlation': float(related_stock.correlation) if related_stock.correlation else None,
                'momentum': related_stock.momentum,

                # Research-report source fields.
                'retrieved_sources': retrieved_sources_data,
                'sources_summary': sources_summary,
                'retrieved_update_time': related_stock.retrieved_update_time.isoformat() if related_stock.retrieved_update_time else None,

                # URL for fetching the complete source list.
                'sources_detail_url': f"/api/event/{event_id}/stock/{stock_code}/sources" if sources_summary.get(
                    'has_sources') else None
            }

        response_data = {
            'code': 200,
            'message': 'success',
            'data': {
                'event_info': {
                    'event_id': event.id,
                    'event_title': event.title,
                    'event_description': event.description,
                    'event_start_time': event.start_time.isoformat() if event.start_time else None,
                    'event_created_at': event.created_at.strftime("%Y-%m-%d %H:%M:%S") if event.created_at else None
                },
                'basic_info': {
                    'stock_code': basic_info.SECCODE,
                    'stock_name': basic_info.SECNAME,
                    'org_name': basic_info.ORGNAME,
                    'pinyin': basic_info.F001V,
                    'category': basic_info.F003V,
                    'market': basic_info.F005V,
                    'listing_date': basic_info.F006D.isoformat() if basic_info.F006D else None,
                    'status': basic_info.F011V
                },
                # company_info may be None; each field degrades to None.
                'company_info': {
                    'english_name': company_info.F001V if company_info else None,
                    'legal_representative': company_info.F003V if company_info else None,
                    'main_business': company_info.F015V if company_info else None,
                    'business_scope': company_info.F016V if company_info else None,
                    'company_intro': company_info.F017V if company_info else None,
                    'csrc_industry_l1': company_info.F030V if company_info else None,
                    'csrc_industry_l2': company_info.F032V if company_info else None
                },
                # F0xxN column meanings below mirror the fallback path in the
                # related-stocks endpoint — presumably close/change/volume etc.;
                # TODO confirm against the TradeData schema.
                'latest_trade': {
                    'trade_date': latest_trade.TRADEDATE.isoformat() if latest_trade else None,
                    'close_price': float(latest_trade.F007N) if latest_trade and latest_trade.F007N else None,
                    'change': float(latest_trade.F009N) if latest_trade and latest_trade.F009N else None,
                    'change_pct': float(latest_trade.F010N) if latest_trade and latest_trade.F010N else None,
                    'volume': float(latest_trade.F004N) if latest_trade and latest_trade.F004N else None,
                    'amount': float(latest_trade.F011N) if latest_trade and latest_trade.F011N else None
                } if latest_trade else None,
                'minute_chart_data': minute_chart_data,
                'related_desc': related_desc
            }
        }

        response = jsonify(response_data)
        # Explicit charset so non-ASCII payloads render correctly.
        response.headers['Content-Type'] = 'application/json; charset=utf-8'
        return response

    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
|
||
|
||
|
||
def get_stock_minute_chart_data(stock_code):
    """Return minute-level chart data for a stock on the latest trading day.

    Determines the most recent trading day from the local ``tdays.csv``
    calendar (one 'DateTime' column, '%Y/%m/%d' format), then queries
    ClickHouse for that day's 09:30–15:00 minute bars.

    Args:
        stock_code: security code as stored in ClickHouse ``stock_minute.code``.

    Returns:
        list[dict]: one entry per minute bar with keys
        time/open/high/low/close/volume/amount; [] when no trading day can
        be resolved or on any error.

    NOTE(review): tdays.csv is re-read on every call; consider caching the
    calendar if this becomes hot.
    """
    try:
        client = get_clickhouse_client()

        from datetime import datetime, timedelta, time as dt_time

        def _load_trading_days():
            """Parse tdays.csv into a set of ``date`` objects."""
            # Uses the module-level ``csv`` import (previously shadowed by a
            # redundant local ``import csv``).
            days = set()
            with open('tdays.csv', 'r') as f:
                for row in csv.DictReader(f):
                    days.add(datetime.strptime(row['DateTime'], '%Y/%m/%d').date())
            return days

        trading_days = _load_trading_days()
        if not trading_days:
            # Empty calendar: previously this fell through to min() raising
            # ValueError into the broad except; make the guard explicit.
            return []

        # Hoisted: min() was recomputed on every loop iteration before.
        earliest_day = min(trading_days)

        def _latest_trading_day(current_date):
            """Walk backwards from current_date to the most recent trading day."""
            while current_date >= earliest_day:
                if current_date in trading_days:
                    return current_date
                current_date -= timedelta(days=1)
            return None

        target_date = _latest_trading_day(datetime.now().date())
        if not target_date:
            return []

        # Minute bars for the resolved day's regular session (09:30–15:00).
        data = client.execute("""
            SELECT
                timestamp, open, high, low, close, volume, amt
            FROM stock_minute
            WHERE code = %(code)s
              AND timestamp BETWEEN %(start)s
              AND %(end)s
            ORDER BY timestamp
        """, {
            'code': stock_code,
            'start': datetime.combine(target_date, dt_time(9, 30)),
            'end': datetime.combine(target_date, dt_time(15, 0))
        })

        # NOTE: unlike the route variant, columns here are converted
        # unconditionally — a NULL column raises and yields [] via the
        # except below (original behavior, preserved).
        return [
            {
                'time': row[0].strftime('%H:%M'),
                'open': float(row[1]),
                'high': float(row[2]),
                'low': float(row[3]),
                'close': float(row[4]),
                'volume': float(row[5]),
                'amount': float(row[6])
            }
            for row in data
        ]

    except Exception as e:
        # Best-effort helper: log and return an empty series on any failure.
        print(f"Error getting minute chart data: {e}")
        return []
|
||
|
||
|
||
# 7. 事件详情-相关概念接口
|
||
@app.route('/api/event/<int:event_id>/related-concepts', methods=['GET'])
def api_event_related_concepts(event_id):
    """Return the concept tags linked to an event, with absolute image URLs.

    Responds with JSON {code, message, data}; data carries the event id/title,
    the concept list, and its count. 404 for an unknown event, 500 on error.
    """
    try:
        event = Event.query.get_or_404(event_id)
        host = request.host_url

        payload = []
        for item in event.related_concepts.all():
            paths = item.image_paths_list
            # Build absolute URLs under the concepts static prefix.
            urls = [host + 'data/concepts/' + rel_path for rel_path in paths]
            payload.append({
                'id': item.id,
                'concept_code': item.concept_code,
                'concept': item.concept,
                'reason': item.reason,
                'image_paths': paths,
                'image_urls': urls,
                'first_image': urls[0] if urls else None,
            })

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'event_id': event_id,
                'event_title': event.title,
                'related_concepts': payload,
                'total_count': len(payload),
            },
        })
    except Exception as exc:
        return jsonify({
            'code': 500,
            'message': str(exc),
            'data': None,
        }), 500
|
||
|
||
|
||
# 8. 事件详情-历史事件接口
|
||
@app.route('/api/event/<int:event_id>/historical-events', methods=['GET'])
def api_event_historical_events(event_id):
    """Historical (analogous) events for an event, with related-stock stats.

    For each historical event — and for the current event itself — computes
    the average and maximum daily change across its related stocks, taken
    from each stock's latest TradeData row.

    NOTE(review): this issues one TradeData query per stock (N+1 pattern);
    fine for small stock lists, worth batching if lists grow.
    """
    try:
        event = Event.query.get_or_404(event_id)
        # Most important first, then most recent.
        historical_events = event.historical_events.order_by(
            HistoricalEvent.importance.desc(),
            HistoricalEvent.event_date.desc()
        ).all()

        events_data = []
        for hist_event in historical_events:
            # Related-stock details for this historical event.
            related_stocks = []
            valid_changes = []  # daily changes collected for the avg/max stats

            for stock in hist_event.stocks.all():
                base_stock_code = stock.stock_code.split('.')[0]
                # Latest trade row, prefix-matched on the bare code
                # (covers exchange-suffixed variants).
                trade_data = TradeData.query.filter(
                    TradeData.SECCODE.startswith(base_stock_code)
                ).order_by(TradeData.TRADEDATE.desc()).first()

                # F010N is treated as the daily change percentage here, as in
                # the other endpoints — TODO confirm against the schema.
                if trade_data and trade_data.F010N is not None:
                    daily_change = float(trade_data.F010N)
                    valid_changes.append(daily_change)
                else:
                    daily_change = None

                stock_data = {
                    'stock_code': stock.stock_code,
                    'stock_name': stock.stock_name,
                    'relation_desc': stock.relation_desc,
                    'correlation': stock.correlation,
                    'sector': stock.sector,
                    'daily_change': daily_change,
                    'has_trade_data': True if trade_data else False
                }
                related_stocks.append(stock_data)

            # Average and maximum change across this event's related stocks.
            avg_change = None
            max_change = None
            if valid_changes:
                avg_change = sum(valid_changes) / len(valid_changes)
                max_change = max(valid_changes)

            events_data.append({
                'id': hist_event.id,
                'title': hist_event.title,
                'content': hist_event.content,
                'event_date': hist_event.event_date.isoformat() if hist_event.event_date else None,
                'relevance': hist_event.relevance,
                'importance': hist_event.importance,
                'related_stocks': related_stocks,
                # Computed stats (rounded to 2 decimals).
                'related_avg_chg': round(avg_change, 2) if avg_change is not None else None,
                'related_max_chg': round(max_change, 2) if max_change is not None else None
            })

        # Same avg/max stats for the CURRENT event's related stocks.
        current_valid_changes = []
        for stock in event.related_stocks:
            base_stock_code = stock.stock_code.split('.')[0]
            trade_data = TradeData.query.filter(
                TradeData.SECCODE.startswith(base_stock_code)
            ).order_by(TradeData.TRADEDATE.desc()).first()

            if trade_data and trade_data.F010N is not None:
                current_valid_changes.append(float(trade_data.F010N))

        current_avg_change = None
        current_max_change = None
        if current_valid_changes:
            current_avg_change = sum(current_valid_changes) / len(current_valid_changes)
            current_max_change = max(current_valid_changes)

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'event_id': event_id,
                'event_title': event.title,
                'invest_score': event.invest_score,
                'related_avg_chg': round(current_avg_change, 2) if current_avg_change is not None else None,
                'related_max_chg': round(current_max_change, 2) if current_max_change is not None else None,
                'historical_events': events_data,
                'total_count': len(events_data)
            }
        })
    except Exception as e:
        print(f"Error in api_event_historical_events: {str(e)}")
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
|
||
|
||
|
||
@app.route('/api/event/<int:event_id>/comments', methods=['GET'])
def get_event_comments(event_id):
    """Return all posts and (nested) comments for an event.

    Query params:
        page: page number (default 1).
        per_page: top-level comments per page (default 20, max 100).
        sort: 'time_desc' / 'time_asc' / 'hot' (default 'time_desc';
            'hot' currently falls back to time_desc ordering).
        include_posts: whether to include post details (default true).
        reply_limit: replies shown inline per comment (default 3, max 50).

    Response shape:
        {
          "success": true,
          "data": {
            "event": {id, title, description, ...},
            "posts": [post info],
            "posts_count": int,
            "total": total top-level comments,
            "current_page": int,
            "total_pages": int,
            "comments": [
              {
                "comment_id", "content", "created_at",
                "post_id", "post_title", "post_content_preview",
                "user": {user_id, username, nickname, avatar_url, ...},
                "reply_count": total replies,
                "has_more_replies": bool,
                "list": [  # inline replies (up to reply_limit)
                  {"comment_id", "content", "created_at",
                   "user": {...},
                   "reply_to": {user_id, nickname}}  # user being replied to
                ]
              }
            ]
          }
        }
    """
    try:
        # Query parameters.
        page = request.args.get('page', 1, type=int)
        per_page = request.args.get('per_page', 20, type=int)
        sort = request.args.get('sort', 'time_desc')
        include_posts = request.args.get('include_posts', 'true').lower() == 'true'
        reply_limit = request.args.get('reply_limit', 3, type=int)  # inline replies per comment

        # Clamp out-of-range values back to defaults.
        if page < 1:
            page = 1
        if per_page < 1 or per_page > 100:
            per_page = 20
        if reply_limit < 0 or reply_limit > 50:  # cap inline reply count
            reply_limit = 3

        # Event must exist (404 otherwise).
        event = Event.query.get_or_404(event_id)

        # All active posts under this event, pinned ('is_top') first.
        posts_query = Post.query.filter_by(event_id=event_id, status='active') \
            .order_by(Post.is_top.desc(), Post.created_at.desc())
        posts = posts_query.all()

        # Post payload (optional).
        posts_data = []
        if include_posts:
            for post in posts:
                posts_data.append({
                    'post_id': post.id,
                    'title': post.title,
                    'content': post.content,
                    'content_type': post.content_type,
                    'created_at': post.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                    'updated_at': post.updated_at.strftime('%Y-%m-%d %H:%M:%S') if post.updated_at else None,
                    'likes_count': post.likes_count,
                    'comments_count': post.comments_count,
                    'view_count': post.view_count,
                    'is_top': post.is_top,
                    'user': {
                        'user_id': post.user.id,
                        'username': post.user.username,
                        'nickname': post.user.nickname or post.user.username,
                        'avatar_url': get_full_avatar_url(post.user.avatar_url),
                        'user_level': post.user.user_level,
                        'is_verified': post.user.is_verified,
                        'is_creator': post.user.is_creator
                    }
                })

        # Post ids drive the comment query below.
        post_ids = [p.id for p in posts]

        # No posts means no comments: short-circuit with an empty page.
        if not post_ids:
            return jsonify({
                'success': True,
                'data': {
                    'event': {
                        'id': event.id,
                        'title': event.title,
                        'description': event.description,
                        'event_type': event.event_type,
                        'importance': event.importance,
                        'status': event.status
                    },
                    'posts': posts_data,
                    'total': 0,
                    'current_page': page,
                    'total_pages': 0,
                    'comments': []
                }
            })

        # Base query: top-level comments only (parent_id IS NULL).
        base_query = Comment.query.filter(
            Comment.post_id.in_(post_ids),
            Comment.parent_id == None,  # top-level comments only
            Comment.status == 'active'
        )

        # Ordering. 'hot' has no dedicated ranking yet and falls back to
        # newest-first, same as the default.
        if sort == 'time_asc':
            base_query = base_query.order_by(Comment.created_at.asc())
        elif sort == 'hot':
            base_query = base_query.order_by(Comment.created_at.desc())
        else:  # default: newest first
            base_query = base_query.order_by(Comment.created_at.desc())

        # Paginate top-level comments.
        pagination = base_query.paginate(page=page, per_page=per_page, error_out=False)

        # Build the nested comment payload.
        # NOTE(review): two queries per comment (count + limited replies) —
        # an N+1 pattern; acceptable at per_page<=100 but worth batching.
        comments_data = []
        for comment in pagination.items:
            # Total reply count for this comment.
            reply_count = Comment.query.filter_by(
                parent_id=comment.id,
                status='active'
            ).count()

            # Up to reply_limit inline replies, oldest first.
            replies_query = Comment.query.filter_by(
                parent_id=comment.id,
                status='active'
            ).order_by(Comment.created_at.asc())  # replies in chronological order

            if reply_limit > 0:
                replies = replies_query.limit(reply_limit).all()
            else:
                replies = []

            # Replies go into the comment's 'list' field.
            replies_list = []
            for reply in replies:
                # Replies here always target the top-level comment, so
                # reply_to is the parent comment's author.
                reply_to_user = {
                    'user_id': comment.user.id,
                    'nickname': comment.user.nickname or comment.user.username
                }

                replies_list.append({
                    'comment_id': reply.id,
                    'content': reply.content,
                    'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                    'user': {
                        'user_id': reply.user.id,
                        'username': reply.user.username,
                        'nickname': reply.user.nickname or reply.user.username,
                        'avatar_url': get_full_avatar_url(reply.user.avatar_url),
                        'user_level': reply.user.user_level,
                        'is_verified': reply.user.is_verified
                    },
                    'reply_to': reply_to_user
                })

            # Parent post of this comment (for title/preview fields).
            post = comment.post

            # Nested comment entry.
            comments_data.append({
                'comment_id': comment.id,
                'content': comment.content,
                'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'post_id': comment.post_id,
                'post_title': post.title if post else None,
                'post_content_preview': post.content[:100] + '...' if post and len(
                    post.content) > 100 else post.content if post else None,
                'user': {
                    'user_id': comment.user.id,
                    'username': comment.user.username,
                    'nickname': comment.user.nickname or comment.user.username,
                    'avatar_url': get_full_avatar_url(comment.user.avatar_url),
                    'user_level': comment.user.user_level,
                    'is_verified': comment.user.is_verified
                },
                'reply_count': reply_count,
                'has_more_replies': reply_count > len(replies_list),
                'list': replies_list  # inline reply list
            })

        return jsonify({
            'success': True,
            'data': {
                'event': {
                    'id': event.id,
                    'title': event.title,
                    'description': event.description,
                    'event_type': event.event_type,
                    'importance': event.importance,
                    'status': event.status,
                    'created_at': event.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                    'hot_score': event.hot_score,
                    'view_count': event.view_count,
                    'post_count': event.post_count,
                    'follower_count': event.follower_count
                },
                'posts': posts_data,
                'posts_count': len(posts_data),
                'total': pagination.total,
                'current_page': pagination.page,
                'total_pages': pagination.pages,
                'comments': comments_data
            }
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'message': '获取评论列表失败',
            'error': str(e)
        }), 500
|
||
|
||
|
||
@app.route('/api/comment/<int:comment_id>/replies', methods=['GET'])
def get_comment_replies(comment_id):
    """Return all replies to a comment, paginated.

    Query params:
        page: page number (default 1).
        per_page: replies per page (default 20, max 100).
        sort: 'time_desc' / 'time_asc' (default 'time_desc').

    Response shape:
        {
          "code": 200,
          "message": "success",
          "data": {
            "comment": {          # the original comment
              "comment_id", "content", "created_at",
              "user": {user_id, nickname, avatar_url}
            },
            "replies": {
              "total": int,
              "current_page": int,
              "total_pages": int,
              "items": [
                {"reply_id", "content", "created_at",
                 "user": {user_id, nickname, avatar_url},
                 "reply_to": {id, nickname}}  # user being replied to
              ]
            }
          }
        }
    """
    try:
        # Query parameters.
        page = request.args.get('page', 1, type=int)
        per_page = request.args.get('per_page', 20, type=int)
        sort = request.args.get('sort', 'time_desc')

        # Clamp out-of-range values back to defaults.
        if page < 1:
            page = 1
        if per_page < 1 or per_page > 100:
            per_page = 20

        # Original comment must exist and be active.
        comment = Comment.query.get_or_404(comment_id)
        if comment.status != 'active':
            return jsonify({
                'code': 404,
                'message': '评论不存在或已被删除',
                'data': None
            }), 404

        # Payload for the original comment.
        comment_data = {
            'comment_id': comment.id,
            'content': comment.content,
            'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'),
            'user': {
                'user_id': comment.user.id,
                'nickname': comment.user.nickname or comment.user.username,
                'avatar_url': get_full_avatar_url(comment.user.avatar_url),  # absolute avatar URL
            }
        }

        # Replies to this comment.
        replies_query = Comment.query.filter_by(
            parent_id=comment_id,
            status='active'
        )

        # Ordering.
        if sort == 'time_asc':
            replies_query = replies_query.order_by(Comment.created_at.asc())
        else:  # default: newest first
            replies_query = replies_query.order_by(Comment.created_at.desc())

        # Paginate.
        pagination = replies_query.paginate(page=page, per_page=per_page, error_out=False)

        # Shape the reply payload.
        replies_data = []
        for reply in pagination.items:
            # The user being replied to — the parent comment's author.
            # (For items of this query, parent_id is always comment_id, so
            # this re-fetches the same parent per reply; cacheable.)
            reply_to_user = None
            if reply.parent_id:
                parent_comment = Comment.query.get(reply.parent_id)
                if parent_comment:
                    reply_to_user = {
                        'id': parent_comment.user.id,
                        'nickname': parent_comment.user.nickname or parent_comment.user.username
                    }

            replies_data.append({
                'reply_id': reply.id,
                'content': reply.content,
                'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'user': {
                    'user_id': reply.user.id,
                    'nickname': reply.user.nickname or reply.user.username,
                    'avatar_url': get_full_avatar_url(reply.user.avatar_url),  # absolute avatar URL
                },
                'reply_to': reply_to_user
            })

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'comment': comment_data,
                'replies': {
                    'total': pagination.total,
                    'current_page': pagination.page,
                    'total_pages': pagination.pages,
                    'items': replies_data
                }
            }
        })

    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
|
||
|
||
|
||
# 工具函数:处理转义字符,保留 Markdown 格式
|
||
def unescape_markdown_text(text):
    """Turn escaped whitespace sequences stored in the DB into real characters.

    The database stores two-character literals like ``'\\n'``; this converts
    them to actual newline/carriage-return/tab characters and trims the
    result. Example: ``'\\n\\n#### title'`` -> ``'#### title'`` (after strip).

    Args:
        text: the stored string; falsy values (None, '') are returned as-is.

    Returns:
        The unescaped, stripped string (or the original falsy value).
    """
    if not text:
        return text

    # Map each two-character escape literal to its control character.
    replacements = (
        ('\\n', '\n'),
        ('\\r', '\r'),
        ('\\t', '\t'),
    )
    for literal, control_char in replacements:
        text = text.replace(literal, control_char)

    return text.strip()
|
||
|
||
|
||
# 工具函数:清理 Markdown 文本
|
||
def clean_markdown_text(text):
|
||
"""清理文本中的 Markdown 符号和多余的换行符
|
||
|
||
Args:
|
||
text: 原始文本(可能包含 Markdown 符号)
|
||
|
||
Returns:
|
||
清理后的纯文本
|
||
"""
|
||
if not text:
|
||
return text
|
||
|
||
import re
|
||
|
||
# 1. 移除 Markdown 标题符号 (### , ## , # )
|
||
text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE)
|
||
|
||
# 2. 移除 Markdown 加粗符号 (**text** 或 __text__)
|
||
text = re.sub(r'\*\*(.+?)\*\*', r'\1', text)
|
||
text = re.sub(r'__(.+?)__', r'\1', text)
|
||
|
||
# 3. 移除 Markdown 斜体符号 (*text* 或 _text_)
|
||
text = re.sub(r'\*(.+?)\*', r'\1', text)
|
||
text = re.sub(r'_(.+?)_', r'\1', text)
|
||
|
||
# 4. 移除 Markdown 列表符号 (- , * , + , 1. )
|
||
text = re.sub(r'^[\s]*[-*+]\s+', '', text, flags=re.MULTILINE)
|
||
text = re.sub(r'^[\s]*\d+\.\s+', '', text, flags=re.MULTILINE)
|
||
|
||
# 5. 移除 Markdown 引用符号 (> )
|
||
text = re.sub(r'^>\s+', '', text, flags=re.MULTILINE)
|
||
|
||
# 6. 移除 Markdown 代码块符号 (``` 或 `)
|
||
text = re.sub(r'```[\s\S]*?```', '', text)
|
||
text = re.sub(r'`(.+?)`', r'\1', text)
|
||
|
||
# 7. 移除 Markdown 链接 ([text](url) -> text)
|
||
text = re.sub(r'\[(.+?)\]\(.+?\)', r'\1', text)
|
||
|
||
# 8. 清理多余的换行符
|
||
# 将多个连续的换行符(\n\n\n...)替换为单个换行符
|
||
text = re.sub(r'\n{3,}', '\n\n', text)
|
||
|
||
# 9. 清理行首行尾的空白字符
|
||
text = re.sub(r'^\s+|\s+$', '', text, flags=re.MULTILINE)
|
||
|
||
# 10. 移除多余的空格(连续多个空格替换为单个空格)
|
||
text = re.sub(r' {2,}', ' ', text)
|
||
|
||
# 11. 清理首尾空白
|
||
text = text.strip()
|
||
|
||
return text
|
||
|
||
|
||
# 10. Investment calendar - events endpoint (enhanced version)
@app.route('/api/calendar/events', methods=['GET'])
def api_calendar_events():
    """Investment-calendar events endpoint backed by the future_events table.

    Query params:
        start / end: ISO-format bounds on calendar_time.
        importance:  star value to match, or 'all' (default).
        category:    inferred_tag value to match, or 'all' (default).
        q:           fuzzy search over title / related_stocks / concepts.
        page / per_page: pagination (defaults 1 / 10).

    Returns a paginated JSON payload; each event is enriched with parsed
    related stocks and concepts plus daily/weekly price changes looked up
    in ea_trade.
    """
    try:
        start_date = request.args.get('start')
        end_date = request.args.get('end')
        importance = request.args.get('importance', 'all')
        category = request.args.get('category', 'all')
        search_query = request.args.get('q', '').strip()  # free-text search term
        page = int(request.args.get('page', 1))
        per_page = int(request.args.get('per_page', 10))
        offset = (page - 1) * per_page

        # Base query against future_events; filters are appended
        # conditionally below, all bound via named parameters.
        query = """
                SELECT data_id, \
                       calendar_time, \
                       type, \
                       star, \
                       title, \
                       former, \
                       forecast, \
                       fact, \
                       related_stocks, \
                       concepts, \
                       inferred_tag
                FROM future_events
                WHERE 1 = 1 \
                """

        params = {}

        if start_date:
            query += " AND calendar_time >= :start_date"
            params['start_date'] = datetime.fromisoformat(start_date)
        if end_date:
            query += " AND calendar_time <= :end_date"
            params['end_date'] = datetime.fromisoformat(end_date)
        if importance != 'all':
            query += " AND star = :importance"
            params['importance'] = importance
        if category != 'all':
            # `category` filters the inferred_tag column (sector labels
            # such as "大周期", "大消费", ...).
            query += " AND inferred_tag = :category"
            params['category'] = category

        # Optional fuzzy search condition.
        if search_query:
            # LIKE-match the title plus the JSON columns; MySQL compares
            # CAST(... AS CHAR) values as plain text.
            query += """ AND (
                title LIKE :search_pattern
                OR CAST(related_stocks AS CHAR) LIKE :search_pattern
                OR CAST(concepts AS CHAR) LIKE :search_pattern
            )"""
            params['search_pattern'] = f'%{search_query}%'

        query += " ORDER BY calendar_time LIMIT :limit OFFSET :offset"
        params['limit'] = per_page
        params['offset'] = offset

        result = db.session.execute(text(query), params)
        events = result.fetchall()

        # Total row count for pagination (same filters, no LIMIT/OFFSET).
        count_query = """
                      SELECT COUNT(*) as count \
                      FROM future_events \
                      WHERE 1=1 \
                      """
        count_params = params.copy()
        count_params.pop('limit', None)
        count_params.pop('offset', None)

        if start_date:
            count_query += " AND calendar_time >= :start_date"
        if end_date:
            count_query += " AND calendar_time <= :end_date"
        if importance != 'all':
            count_query += " AND star = :importance"
        if category != 'all':
            count_query += " AND inferred_tag = :category"

        # Mirror the search condition in the count query.
        if search_query:
            count_query += """ AND (
                title LIKE :search_pattern
                OR CAST(related_stocks AS CHAR) LIKE :search_pattern
                OR CAST(concepts AS CHAR) LIKE :search_pattern
            )"""

        total_count_result = db.session.execute(text(count_query), count_params).fetchone()
        total_count = total_count_result.count if total_count_result else 0

        events_data = []
        # NOTE(review): the loop variable shadows the `event` name imported
        # from sqlalchemy at module level — confirm that import is unused here.
        for event in events:
            # Parsed related stocks plus aggregate change stats.
            related_stocks_list = []
            related_avg_chg = 0
            related_max_chg = 0
            related_week_chg = 0

            # Parse the related_stocks payload, if any.
            if event.related_stocks:
                try:
                    import json
                    import ast

                    # Same parsing strategy as the detail endpoint:
                    # strict JSON first, then Python-literal syntax.
                    if isinstance(event.related_stocks, str):
                        try:
                            stock_data = json.loads(event.related_stocks)
                        except:
                            stock_data = ast.literal_eval(event.related_stocks)
                    else:
                        stock_data = event.related_stocks

                    if stock_data:
                        daily_changes = []
                        week_changes = []

                        # Expected row shape: [code, name, description, score].
                        for stock_info in stock_data:
                            if isinstance(stock_info, list) and len(stock_info) >= 2:
                                stock_code = stock_info[0]  # stock code
                                stock_name = stock_info[1]  # stock name
                                description = stock_info[2] if len(stock_info) > 2 else ''
                                score = stock_info[3] if len(stock_info) > 3 else 0
                            else:
                                continue

                            if stock_code:
                                # Normalise the code by dropping exchange suffixes.
                                clean_code = stock_code.replace('.SZ', '').replace('.SH', '').replace('.BJ', '')

                                # Look up recent trades via prefix match.
                                # NOTE(review): one query per stock per event
                                # (N+1 pattern) — a candidate for batching.
                                trade_query = """
                                              SELECT F007N as close_price, F010N as change_pct, TRADEDATE
                                              FROM ea_trade
                                              WHERE SECCODE LIKE :stock_code_pattern
                                              ORDER BY TRADEDATE DESC LIMIT 7 \
                                              """
                                trade_result = db.session.execute(text(trade_query),
                                                                  {'stock_code_pattern': f'{clean_code}%'})
                                trade_data = trade_result.fetchall()

                                daily_chg = 0
                                week_chg = 0

                                if trade_data:
                                    # Daily change: most recent row's change_pct.
                                    daily_chg = float(trade_data[0].change_pct or 0)

                                    # Weekly change: latest close vs the close
                                    # 5 trading days earlier.
                                    if len(trade_data) >= 5:
                                        current_price = float(trade_data[0].close_price or 0)
                                        week_ago_price = float(trade_data[4].close_price or 0)
                                        if week_ago_price > 0:
                                            week_chg = ((current_price - week_ago_price) / week_ago_price) * 100

                                # Accumulate for the aggregate stats.
                                daily_changes.append(daily_chg)
                                week_changes.append(week_chg)

                                related_stocks_list.append({
                                    'code': stock_code,
                                    'name': stock_name,
                                    'description': description,
                                    'score': score,
                                    'daily_chg': daily_chg,
                                    'week_chg': week_chg
                                })

                        # Aggregate change statistics across related stocks.
                        if daily_changes:
                            related_avg_chg = round(sum(daily_changes) / len(daily_changes), 4)
                            related_max_chg = round(max(daily_changes), 4)

                        if week_changes:
                            related_week_chg = round(sum(week_changes) / len(week_changes), 4)

                except Exception as e:
                    print(f"Error processing related stocks for event {event.data_id}: {e}")

            # Parse the related concepts column.
            related_concepts = extract_concepts_from_concepts_field(event.concepts)

            # Star rating (same value as event.star).
            star_rating = event.star

            # Work out which field the search term matched, if any — an
            # optional hint for client-side highlighting.
            highlight_match = False
            if search_query:
                # Matched in the title?
                if search_query.lower() in (event.title or '').lower():
                    highlight_match = 'title'
                # Matched in a related stock?
                elif any(search_query.lower() in str(stock).lower() for stock in related_stocks_list):
                    highlight_match = 'stocks'
                # Matched in a concept?
                elif search_query.lower() in str(related_concepts).lower():
                    highlight_match = 'concepts'

            # Convert escaped newlines back to real ones, keeping Markdown.
            cleaned_former = unescape_markdown_text(event.former)
            cleaned_forecast = unescape_markdown_text(event.forecast)
            cleaned_fact = unescape_markdown_text(event.fact)

            event_dict = {
                'id': event.data_id,
                'title': event.title,
                'description': f"前值: {cleaned_former}, 预测: {cleaned_forecast}, 实际: {cleaned_fact}" if cleaned_former or cleaned_forecast or cleaned_fact else "",
                'start_time': event.calendar_time.isoformat() if event.calendar_time else None,
                'end_time': None,  # future_events has no end-time column
                'category': {
                    'event_type': event.type,
                    'importance': event.star,
                    'star_rating': star_rating,
                    'inferred_tag': event.inferred_tag  # sector label (also at top level)
                },
                'star_rating': star_rating,
                'inferred_tag': event.inferred_tag,  # sector label at top level
                'related_concepts': related_concepts,
                'related_stocks': related_stocks_list,
                'related_avg_chg': round(related_avg_chg, 2),
                'related_max_chg': round(related_max_chg, 2),
                'related_week_chg': round(related_week_chg, 2),
                'former': cleaned_former,
                'forecast': cleaned_forecast,
                'fact': cleaned_fact
            }

            # Optional: tag which field matched the search term.
            if search_query and highlight_match:
                event_dict['search_match'] = highlight_match

            events_data.append(event_dict)

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'events': events_data,
                'total_count': total_count,
                'page': page,
                'per_page': per_page,
                'total_pages': (total_count + per_page - 1) // per_page,
                'search_query': search_query  # echo the search term back
            }
        })

    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
|
||
|
||
|
||
# 11. Investment calendar - data endpoint
@app.route('/api/calendar/data', methods=['GET'])
def api_calendar_data():
    """Investment-calendar data endpoint.

    Merges rows from the RelatedData model with 'data'-type rows from the
    future_events table, sorts the union by creation time (newest first)
    and returns one page of it.

    Query params:
        start / end: ISO-format bounds on the creation time.
        type:        data_type filter, or 'all' (default).
        page / page_size: pagination (defaults 1 / 20, max 100 per page).
    """
    try:
        start_date = request.args.get('start')
        end_date = request.args.get('end')
        data_type = request.args.get('type', 'all')

        # Pagination parameters.
        page = int(request.args.get('page', 1))
        page_size = int(request.args.get('page_size', 20))  # default 20 per page

        # Clamp pagination to sane values (max 100 rows per page).
        if page < 1:
            page = 1
        if page_size < 1 or page_size > 100:
            page_size = 20

        # Source 1: RelatedData ORM rows with the same filters.
        query1 = RelatedData.query

        if start_date:
            query1 = query1.filter(RelatedData.created_at >= datetime.fromisoformat(start_date))
        if end_date:
            query1 = query1.filter(RelatedData.created_at <= datetime.fromisoformat(end_date))
        if data_type != 'all':
            query1 = query1.filter_by(data_type=data_type)

        data_list1 = query1.order_by(RelatedData.created_at.desc()).all()

        # Source 2: 'data'-type rows from future_events.
        query2_sql = """
            SELECT data_id as id,
                   title,
                   type as data_type,
                   former,
                   forecast,
                   fact,
                   star,
                   calendar_time as created_at
            FROM future_events
            WHERE type = 'data'
        """

        # Time / type filters.
        params = {}
        if start_date:
            query2_sql += " AND calendar_time >= :start_date"
            params['start_date'] = start_date
        if end_date:
            query2_sql += " AND calendar_time <= :end_date"
            params['end_date'] = end_date
        if data_type != 'all':
            # NOTE(review): combined with the fixed type = 'data' predicate
            # above, any data_type other than 'data' yields no rows here.
            query2_sql += " AND type = :data_type"
            params['data_type'] = data_type

        query2_sql += " ORDER BY calendar_time DESC"

        # BUGFIX: materialise the result once. The cursor used to be
        # iterated for row building and then counted again via
        # len(list(result2)), but the second pass always saw an exhausted
        # cursor, so future_events_count was reported as 0.
        future_event_rows = db.session.execute(text(query2_sql), params).fetchall()

        result_data = []

        # Rows coming from RelatedData.
        for data in data_list1:
            result_data.append({
                'id': data.id,
                'title': data.title,
                'data_type': data.data_type,
                'data_content': data.data_content,
                'description': data.description,
                'created_at': data.created_at.isoformat() if data.created_at else None,
                'event_id': data.event_id,
                'source': 'related_data',  # marks the data source
                'former': None,
                'forecast': None,
                'fact': None,
                'star': None
            })

        # Rows coming from future_events.
        for row in future_event_rows:
            result_data.append({
                'id': row.id,
                'title': row.title,
                'data_type': row.data_type,
                'data_content': None,
                'description': None,
                'created_at': row.created_at.isoformat() if row.created_at else None,
                'event_id': None,
                'source': 'future_events',  # marks the data source
                'former': row.former,
                'forecast': row.forecast,
                'fact': row.fact,
                'star': row.star
            })

        # Sort by creation time, newest first (None sorts as very old).
        result_data.sort(key=lambda x: x['created_at'] or '1900-01-01', reverse=True)

        # Pagination over the merged list.
        total_count = len(result_data)
        total_pages = (total_count + page_size - 1) // page_size  # ceiling division

        start_index = (page - 1) * page_size
        end_index = start_index + page_size

        # Slice out the requested page.
        current_page_data = result_data[start_index:end_index]

        # Per-source counts (kept for backwards compatibility).
        related_data_count = len(data_list1)
        future_events_count = len(future_event_rows)

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'data_list': current_page_data,
                'pagination': {
                    'current_page': page,
                    'page_size': page_size,
                    'total_count': total_count,
                    'total_pages': total_pages,
                    'has_next': page < total_pages,
                    'has_prev': page > 1
                },
                # Legacy top-level fields kept for compatibility.
                'total_count': total_count,
                'related_data_count': related_data_count,
                'future_events_count': future_events_count
            }
        })
    except ValueError as ve:
        # Malformed pagination / date parameters.
        return jsonify({
            'code': 400,
            'message': f'分页参数格式错误: {str(ve)}',
            'data': None
        }), 400
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
|
||
|
||
|
||
# 12. Investment calendar - detail endpoint helpers
def extract_concepts_from_concepts_field(concepts_text):
    """Extract concept entries from a raw `concepts` field.

    The field may be a JSON string, a Python-literal string, or an
    already-parsed sequence of [name, reason, score] triples.

    Returns:
        A list of {'name', 'reason', 'score'} dicts. Entries with fewer
        than 3 elements are skipped. On any parse failure the error is
        printed and an empty list is returned (best effort, never raises).
    """
    if not concepts_text:
        return []

    try:
        import json
        import ast

        # Parse string payloads: strict JSON first, then Python-literal
        # syntax (single quotes, tuples, ...) as a fallback.
        if isinstance(concepts_text, str):
            try:
                concepts_data = json.loads(concepts_text)
            except ValueError:
                # json.JSONDecodeError is a ValueError subclass; anything
                # else propagates to the outer best-effort handler.
                concepts_data = ast.literal_eval(concepts_text)
        else:
            concepts_data = concepts_text

        extracted_concepts = []
        for concept_info in concepts_data:
            # Accept tuples as well as lists: ast.literal_eval commonly
            # yields tuples for '(a, b, c)' style input.
            if isinstance(concept_info, (list, tuple)) and len(concept_info) >= 3:
                extracted_concepts.append({
                    'name': concept_info[0],    # concept name
                    'reason': concept_info[1],  # reason / description
                    'score': concept_info[2]    # relevance score
                })

        return extracted_concepts
    except Exception as e:
        # Best effort: a malformed field must never break the caller.
        print(f"Error extracting concepts: {e}")
        return []
|
||
|
||
|
||
@app.route('/api/calendar/detail/<int:item_id>', methods=['GET'])
def api_future_event_detail(item_id):
    """Future-event detail endpoint backed by the future_events table.

    Looks up one event by data_id and enriches it with parsed concepts,
    related stocks (Shenwan sector classification plus real daily/weekly
    price changes from ea_trade) and per-sector stock counts.

    NOTE(review): the original docstring said this endpoint is limited to
    Pro/Max members, but no access check is visible in this function —
    confirm whether enforcement happens elsewhere.
    """
    try:
        # Fetch the event row from future_events.
        query = """
                SELECT data_id, \
                       calendar_time, \
                       type, \
                       star, \
                       title, \
                       former, \
                       forecast, \
                       fact, \
                       related_stocks, \
                       concepts
                FROM future_events
                WHERE data_id = :item_id \
                """

        result = db.session.execute(text(query), {'item_id': item_id})
        event = result.fetchone()

        if not event:
            return jsonify({
                'code': 404,
                'message': 'Event not found',
                'data': None
            }), 404

        extracted_concepts = extract_concepts_from_concepts_field(event.concepts)

        # Parsed related stocks plus per-sector counters.
        related_stocks_list = []
        sector_stats = {
            '全部股票': 0,
            '大周期': 0,
            '大消费': 0,
            'TMT板块': 0,
            '大金融地产': 0,
            '公共产业板块': 0,
            '其他': 0
        }

        # Mapping from Shenwan (SW) level-1 industry names to the
        # top-level sector buckets used in sector_stats.
        sector_map = {
            # Cyclicals
            '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期',
            '钢铁': '大周期', '基础化工': '大周期', '建筑材料': '大周期',
            '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期',
            '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期',
            '太阳能发电': '大周期', '建筑装饰': '大周期', '建筑': '大周期',
            '交通运输': '大周期', '采掘': '大周期', '公用事业': '大周期',

            # Consumer
            '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费',
            '食品饮料': '大消费', '医药生物': '大消费', '纺织服饰': '大消费',
            '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费',
            '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费',
            '纺织服装': '大消费', '商业贸易': '大消费', '休闲服务': '大消费',

            # Financials & real estate
            '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产',
            '多元金融': '大金融地产', '综合金融': '大金融地产',
            '房地产': '大金融地产', '非银金融': '大金融地产',

            # TMT
            '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块',

            # Utilities / misc
            '环保': '公共产业板块', '综合': '公共产业板块'
        }

        # Aggregate change stats across related stocks.
        related_avg_chg = 0
        related_max_chg = 0
        related_week_chg = 0

        if event.related_stocks:
            try:
                import json
                import ast

                # Parse related_stocks: strict JSON first, then Python
                # literal syntax as fallback.
                if isinstance(event.related_stocks, str):
                    try:
                        stock_data = json.loads(event.related_stocks)
                    except:
                        stock_data = ast.literal_eval(event.related_stocks)
                else:
                    stock_data = event.related_stocks

                print(f"Parsed stock_data: {stock_data}")  # debug output

                if stock_data:
                    daily_changes = []
                    week_changes = []

                    # Expected row shape: [code, name, description, score].
                    for stock_info in stock_data:
                        if isinstance(stock_info, list) and len(stock_info) >= 2:
                            stock_code = stock_info[0]  # first element: stock code
                            stock_name = stock_info[1]  # second element: stock name
                            description = stock_info[2] if len(stock_info) > 2 else ''
                            score = stock_info[3] if len(stock_info) > 3 else 0
                        else:
                            continue  # skip malformed entries

                        if stock_code:
                            # Normalise the code by dropping exchange suffixes.
                            clean_code = stock_code.replace('.SZ', '').replace('.SH', '').replace('.BJ', '')

                            print(f"Processing stock: {clean_code} - {stock_name}")  # debug output

                            # Look up the SW level-1 industry (F004V) via
                            # LIKE prefix match on the security code.
                            sector_query = """
                                           SELECT F004V as sw_primary_sector
                                           FROM ea_sector
                                           WHERE SECCODE LIKE :stock_code_pattern
                                             AND F002V = '申银万国行业分类' LIMIT 1 \
                                           """
                            sector_result = db.session.execute(text(sector_query),
                                                               {'stock_code_pattern': f'{clean_code}%'})
                            sector_row = sector_result.fetchone()

                            # Map the SW level-1 industry onto a top-level
                            # sector; unknown industries go to '其他'.
                            sw_primary_sector = sector_row.sw_primary_sector if sector_row else None
                            primary_sector = sector_map.get(sw_primary_sector, '其他') if sw_primary_sector else '其他'

                            print(
                                f"Stock: {clean_code}, SW Primary: {sw_primary_sector}, Primary Sector: {primary_sector}")

                            # Fetch real daily/weekly changes from recent trades.
                            trade_query = """
                                          SELECT F007N as close_price, F010N as change_pct, TRADEDATE
                                          FROM ea_trade
                                          WHERE SECCODE LIKE :stock_code_pattern
                                          ORDER BY TRADEDATE DESC LIMIT 7 \
                                          """
                            trade_result = db.session.execute(text(trade_query),
                                                              {'stock_code_pattern': f'{clean_code}%'})
                            trade_data = trade_result.fetchall()

                            daily_chg = 0
                            week_chg = 0

                            if trade_data:
                                # Daily change: most recent row's change_pct.
                                daily_chg = float(trade_data[0].change_pct or 0)

                                # Weekly change: latest close vs the close
                                # 5 trading days earlier.
                                if len(trade_data) >= 5:
                                    current_price = float(trade_data[0].close_price or 0)
                                    week_ago_price = float(trade_data[4].close_price or 0)
                                    if week_ago_price > 0:
                                        week_chg = ((current_price - week_ago_price) / week_ago_price) * 100

                            print(
                                f"Trade data found: {len(trade_data) if trade_data else 0} records, daily_chg: {daily_chg}")

                            # Per-sector counters.
                            sector_stats['全部股票'] += 1
                            sector_stats[primary_sector] += 1

                            # Accumulate for the aggregate stats.
                            daily_changes.append(daily_chg)
                            week_changes.append(week_chg)

                            related_stocks_list.append({
                                'code': stock_code,  # original stock code
                                'name': stock_name,  # stock name
                                'description': description,  # relation description
                                'score': score,  # relation score
                                'sw_primary_sector': sw_primary_sector,  # SW level-1 industry (F004V)
                                'primary_sector': primary_sector,  # top-level sector bucket
                                'daily_change': daily_chg,  # real daily change
                                'week_change': week_chg  # real weekly change
                            })

                    # Aggregate change statistics.
                    if daily_changes:
                        related_avg_chg = sum(daily_changes) / len(daily_changes)
                        related_max_chg = max(daily_changes)

                    if week_changes:
                        related_week_chg = sum(week_changes) / len(week_changes)

            except Exception as e:
                print(f"Error processing related stocks: {e}")
                import traceback
                traceback.print_exc()

        # Assemble the response payload.
        detail_data = {
            'id': event.data_id,
            'title': event.title,
            'type': event.type,
            'star': event.star,
            'calendar_time': event.calendar_time.isoformat() if event.calendar_time else None,
            'former': event.former,
            'forecast': event.forecast,
            'fact': event.fact,
            'concepts': event.concepts,
            'extracted_concepts': extracted_concepts,
            'related_stocks': related_stocks_list,
            'sector_stats': sector_stats,
            'related_avg_chg': round(related_avg_chg, 2),
            'related_max_chg': round(related_max_chg, 2),
            'related_week_chg': round(related_week_chg, 2)
        }

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'type': 'future_event',
                'detail': detail_data
            }
        })

    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
|
||
|
||
|
||
# 13-15. Filter-dialog endpoints
@app.route('/api/filter/options', methods=['GET'])
def api_filter_options():
    """Return the option sets for the filter dialog: sort keys, industry
    classifications (with counts) and importance levels."""
    try:
        # Static sort choices exposed to the client.
        sort_choices = [
            {'key': key, 'name': name, 'desc': desc}
            for key, name, desc in (
                ('new', '最新', '按创建时间排序'),
                ('hot', '热门', '按热度分数排序'),
                ('returns', '收益率', '按收益率排序'),
                ('importance', '重要性', '按重要性等级排序'),
                ('view_count', '浏览量', '按浏览次数排序'),
            )
        ]

        # Distinct industry classifications and their row counts.
        industry_rows = db.session.execute(text("""
            SELECT DISTINCT f002v as classification_name, COUNT(*) as count
            FROM ea_sector
            WHERE f002v IS NOT NULL
            GROUP BY f002v
            ORDER BY f002v
        """)).fetchall()

        # Static importance levels.
        importance_choices = [
            {'key': key, 'name': name, 'desc': desc}
            for key, name, desc in (
                ('S', 'S级', '重大事件'),
                ('A', 'A级', '重要事件'),
                ('B', 'B级', '普通事件'),
                ('C', 'C级', '参考事件'),
            )
        ]

        industry_payload = [
            {'name': row.classification_name, 'count': row.count}
            for row in industry_rows
        ]

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'sort_options': sort_choices,
                'industry_options': industry_payload,
                'importance_options': importance_choices
            }
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
|
||
|
||
|
||
# 16-17. Membership benefit endpoints
@app.route('/api/membership/status', methods=['GET'])
@token_required
def api_membership_status():
    """Report the caller's membership status and the benefits it unlocks."""
    try:
        user = request.user

        # TODO: derive membership from real business rules; for now read
        # optional attributes off the user record (default: not a member).
        is_member = getattr(user, 'is_member', False)
        member_expire_date = getattr(user, 'member_expire_date', None)

        expire_iso = member_expire_date.isoformat() if member_expire_date else None

        # Every benefit is simply gated on membership.
        benefits = {
            benefit: is_member
            for benefit in ('unlimited_access', 'priority_support',
                            'advanced_analytics', 'custom_alerts')
        }

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'user_id': user.id,
                'is_member': is_member,
                'member_expire_date': expire_iso,
                'user_level': user.user_level,
                'benefits': benefits
            }
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
|
||
|
||
|
||
# 18-19. Profile endpoints
@app.route('/api/user/profile', methods=['GET'])
@token_required
def api_user_profile():
    """Personal profile endpoint.

    Returns the authenticated user's basic info, account status,
    engagement statistics, investment preferences, community stats and
    notification/privacy settings in one payload.
    """
    try:
        user = request.user

        # Engagement counters for the current user.
        likes_count = PostLike.query.filter_by(user_id=user.id).count()
        follows_count = EventFollow.query.filter_by(user_id=user.id).count()
        comments_made = Comment.query.filter_by(user_id=user.id).count()

        # Comments other users left on this user's posts.
        comments_received = db.session.query(Comment) \
            .join(Post, Comment.post_id == Post.id) \
            .filter(Post.user_id == user.id).count()

        # Replies to comments authored by this user.
        replies_received = Comment.query.filter(
            Comment.parent_id.in_(
                db.session.query(Comment.id).filter_by(user_id=user.id)
            )
        ).count()

        # Total = comments made + comments and replies received.
        total_comments = comments_made + comments_received + replies_received
        profile_data = {
            'basic_info': {
                'user_id': user.id,
                'username': user.username,
                'email': user.email,
                'phone': user.phone,
                'nickname': user.nickname,
                'avatar_url': get_full_avatar_url(user.avatar_url),  # resolved to a full URL
                'bio': user.bio,
                'gender': user.gender,
                'birth_date': user.birth_date.isoformat() if user.birth_date else None,
                'location': user.location
            },
            'account_status': {
                'email_confirmed': user.email_confirmed,
                'phone_confirmed': user.phone_confirmed,
                'is_verified': user.is_verified,
                'verify_time': user.verify_time.isoformat() if user.verify_time else None,
                'created_at': user.created_at.isoformat() if user.created_at else None,
                'last_seen': user.last_seen.isoformat() if user.last_seen else None
            },
            'statistics': {
                'likes_count': likes_count,  # likes given
                'follows_count': follows_count,  # events followed
                'total_comments': total_comments,  # combined comment activity
                'comments_detail': {
                    'comments_made': comments_made,  # comments authored
                    'comments_received': comments_received,  # comments on own posts
                    'replies_received': replies_received  # replies to own comments
                }
            },
            'investment_preferences': {
                'trading_experience': user.trading_experience,
                'investment_style': user.investment_style,
                'risk_preference': user.risk_preference,
                'investment_amount': user.investment_amount,
                'preferred_markets': user.preferred_markets
            },
            'community_stats': {
                'user_level': user.user_level,
                'reputation_score': user.reputation_score,
                'contribution_point': user.contribution_point,
                'post_count': user.post_count,
                'comment_count': user.comment_count,
                'follower_count': user.follower_count,
                'following_count': user.following_count
            },
            'settings': {
                'email_notifications': user.email_notifications,
                'sms_notifications': user.sms_notifications,
                'privacy_level': user.privacy_level,
                'theme_preference': user.theme_preference
            }
        }

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': profile_data
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
|
||
|
||
|
||
# Module-level cache for agreement documents, populated lazily by
# load_agreements_from_docx() on first use.
_agreements_cache = {}
_cache_loaded = False
|
||
|
||
|
||
def load_agreements_from_docx():
    """Load agreement documents from .docx files into the module cache.

    Each configured document is read only once; subsequent calls return
    the cached `_agreements_cache` directly. Missing or unreadable files
    get a placeholder entry carrying an 'error' key so the API can still
    respond.

    Returns:
        dict mapping agreement type -> {'title', 'content',
        'last_updated', 'version', 'file_path'[, 'error']}.
    """
    global _agreements_cache, _cache_loaded

    if _cache_loaded:
        return _agreements_cache

    try:
        # Agreement type -> source document filename.
        docx_files = {
            'about_us': 'about_us.docx',             # about us
            'service_terms': 'service_terms.docx',   # terms of service
            'privacy_policy': 'privacy_policy.docx'  # privacy policy
        }

        # Display titles per agreement type.
        titles = {
            'about_us': '关于我们',
            'service_terms': '服务条款',
            'privacy_policy': '隐私政策'
        }

        for agreement_type, filename in docx_files.items():
            file_path = os.path.join(os.path.dirname(__file__), filename)

            if os.path.exists(file_path):
                try:
                    # Read the docx file.
                    doc = Document(file_path)

                    # Collect non-empty paragraphs.
                    content_paragraphs = []
                    for paragraph in doc.paragraphs:
                        if paragraph.text.strip():  # skip blank paragraphs
                            content_paragraphs.append(paragraph.text.strip())

                    # Join all paragraphs with blank lines.
                    content = '\n\n'.join(content_paragraphs)

                    # Use the file mtime as a lightweight version marker.
                    file_stat = os.stat(file_path)
                    last_modified = file_stat.st_mtime

                    # Cache the document.
                    _agreements_cache[agreement_type] = {
                        'title': titles.get(agreement_type, agreement_type),
                        'content': content,
                        'last_updated': last_modified,
                        'version': '1.0',
                        'file_path': filename
                    }

                    # BUGFIX: these messages previously contained a literal
                    # "(unknown)" placeholder instead of the file name.
                    print(f"Successfully loaded {agreement_type} from {filename}")

                except Exception as e:
                    print(f"Error reading {filename}: {str(e)}")
                    # Reading failed: cache a placeholder with the error.
                    _agreements_cache[agreement_type] = {
                        'title': titles.get(agreement_type, agreement_type),
                        'content': f"协议内容正在加载中,请稍后再试。(文件:{filename})",
                        'last_updated': None,
                        'version': '1.0',
                        'file_path': filename,
                        'error': str(e)
                    }
            else:
                print(f"File not found: {filename}")
                # File missing: cache a placeholder entry.
                _agreements_cache[agreement_type] = {
                    'title': titles.get(agreement_type, agreement_type),
                    'content': f"协议文件未找到,请联系管理员。(文件:{filename})",
                    'last_updated': None,
                    'version': '1.0',
                    'file_path': filename,
                    'error': 'File not found'
                }

        _cache_loaded = True
        print(f"Agreements cache loaded successfully. Total: {len(_agreements_cache)} agreements")

    except Exception as e:
        print(f"Error loading agreements: {str(e)}")
        _cache_loaded = False

    return _agreements_cache
|
||
|
||
|
||
@app.route('/api/agreements', methods=['GET'])
def api_agreements():
    """Platform agreements endpoint, served from the docx-backed cache."""
    global _cache_loaded
    try:
        # Query params: a specific agreement type, and an optional
        # cache-busting reload flag.
        agreement_type = request.args.get('type')  # about_us, service_terms, privacy_policy
        force_reload = request.args.get('reload', 'false').lower() == 'true'

        # Drop the cache when a forced reload is requested.
        if force_reload:
            _cache_loaded = False
            _agreements_cache.clear()

        # Load (or re-use) the cached agreements.
        agreements_data = load_agreements_from_docx()

        if not agreements_data:
            return jsonify({
                'code': 500,
                'message': 'Failed to load agreements',
                'data': None
            }), 500

        # A known specific type was requested: return just that agreement.
        if agreement_type and agreement_type in agreements_data:
            payload = {'agreement_type': agreement_type}
            payload.update(agreements_data[agreement_type])
            return jsonify({
                'code': 200,
                'message': 'success',
                'data': payload
            })

        # Otherwise return every agreement plus cache metadata.
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'agreements': agreements_data,
                'available_types': list(agreements_data.keys()),
                'cache_loaded': _cache_loaded,
                'total_agreements': len(agreements_data)
            }
        })

    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
|
||
|
||
|
||
# 20. 个人中心-我的关注接口
|
||
# 20. Profile - "my activities" endpoint
@app.route('/api/user/activities', methods=['GET'])
@token_required
def api_user_activities():
    """User activity feed (my follows / my likes / my comments / comments on my posts).

    Query params:
        type:     one of 'follows', 'likes', 'comments', 'commented'
                  (default 'follows'); anything else returns HTTP 400.
        page:     1-based page number (default 1).
        per_page: page size, capped at 50 (default 20).

    Returns a paginated JSON envelope with the activity list.
    """
    try:
        user = request.user
        activity_type = request.args.get('type', 'follows')  # follows, comments, likes, commented
        page = request.args.get('page', 1, type=int)
        per_page = min(50, request.args.get('per_page', 20, type=int))

        # Bug fix: an unknown type previously fell through every branch,
        # leaving `activities`/`total`/`pages` unbound -> NameError -> 500.
        # Reject it explicitly instead.
        if activity_type not in ('follows', 'likes', 'comments', 'commented'):
            return jsonify({
                'code': 400,
                'message': 'Invalid activity type',
                'data': None
            }), 400

        if activity_type == 'follows':
            # Events the user follows.
            follows = EventFollow.query.filter_by(user_id=user.id) \
                .order_by(EventFollow.created_at.desc()) \
                .paginate(page=page, per_page=per_page, error_out=False)

            activities = []
            for follow in follows.items:
                # Collect up to five related stocks, each with its latest daily change.
                related_stocks_data = []
                for stock in follow.event.related_stocks.limit(5):
                    # Strip a possible exchange suffix such as ".SZ" / ".SH".
                    base_stock_code = stock.stock_code.split('.')[0]

                    # Most recent trade record for this stock.
                    trade_data = TradeData.query.filter(
                        TradeData.SECCODE.startswith(base_stock_code)
                    ).order_by(TradeData.TRADEDATE.desc()).first()

                    # F010N carries the single-day percentage change.
                    daily_change = None
                    if trade_data and trade_data.F010N is not None:
                        daily_change = float(trade_data.F010N)

                    related_stocks_data.append({
                        'stock_code': stock.stock_code,
                        'stock_name': stock.stock_name,
                        'correlation': stock.correlation,
                        'daily_change': daily_change,  # single-day change
                        'daily_change_formatted': f"{daily_change:.2f}%" if daily_change is not None else "暂无数据"
                    })

                activities.append({
                    'event_id': follow.event_id,
                    'event_title': follow.event.title,
                    'event_description': follow.event.description,
                    'follow_time': follow.created_at.isoformat() if follow.created_at else None,
                    'event_hot_score': follow.event.hot_score,
                    'importance': follow.event.importance,            # importance rating
                    'related_avg_chg': follow.event.related_avg_chg,  # average change
                    'related_max_chg': follow.event.related_max_chg,  # maximum change
                    'related_week_chg': follow.event.related_week_chg,  # weekly change
                    'related_stocks': related_stocks_data,            # includes daily change per stock
                    'created_at': follow.event.created_at.isoformat() if follow.event.created_at else None,
                    'preview': follow.event.description[:200] if follow.event.description else None,  # 200-char preview
                    'comment_count': follow.event.post_count,   # number of posts/comments
                    'view_count': follow.event.view_count,      # number of views (comment fixed: was mislabelled)
                    'follower_count': follow.event.follower_count  # number of followers
                })

            total = follows.total
            pages = follows.pages

        elif activity_type == 'likes':
            # Posts the user has liked.
            likes = PostLike.query.filter_by(user_id=user.id) \
                .order_by(PostLike.created_at.desc()) \
                .paginate(page=page, per_page=per_page, error_out=False)

            activities = [{
                'like_id': like.id,
                'post_id': like.post_id,
                'post_content': like.post.content,
                'like_time': like.created_at.isoformat() if like.created_at else None,
                # Post author info.
                'author': {
                    'nickname': like.post.user.nickname or like.post.user.username,
                    'avatar_url': get_full_avatar_url(like.post.user.avatar_url),
                }
            } for like in likes.items]

            total = likes.total
            pages = likes.pages

        elif activity_type == 'comments':
            # Comments written by the user, enriched with event importance/content.
            comments = Comment.query.filter_by(user_id=user.id) \
                .join(Post, Comment.post_id == Post.id) \
                .join(Event, Post.event_id == Event.id) \
                .order_by(Comment.created_at.desc()) \
                .paginate(page=page, per_page=per_page, error_out=False)

            activities = []
            for comment in comments.items:
                # Resolve the event via comment.post_id -> post -> post.event_id -> event.
                post = comment.post
                event = post.event if post else None

                activity_data = {
                    'comment_id': comment.id,
                    'post_id': comment.post_id,
                    'content': comment.content,  # the comment text itself
                    'created_at': comment.created_at.isoformat() if comment.created_at else None,
                    'post_title': post.title if post and post.title else None,
                    'post_content': post.content if post else None,

                    # Commenter info (the current user).
                    'commenter': {
                        'id': comment.user.id,
                        'username': comment.user.username,
                        'nickname': comment.user.nickname or comment.user.username,
                        'avatar_url': get_full_avatar_url(comment.user.avatar_url),
                        'user_level': comment.user.user_level,
                        'is_verified': comment.user.is_verified
                    },

                    # Event info.
                    'event': {
                        'id': event.id if event else None,
                        'title': event.title if event else None,
                        'description': event.description if event else None,  # event body
                        'importance': event.importance if event else None,    # importance rating
                        'event_type': event.event_type if event else None,
                        'hot_score': event.hot_score if event else None,
                        'view_count': event.view_count if event else None,
                        'related_avg_chg': event.related_avg_chg if event else None,
                        'created_at': event.created_at.isoformat() if event and event.created_at else None
                    },

                    # Post author info.
                    'post_author': {
                        'id': post.user.id if post else None,
                        'username': post.user.username if post else None,
                        'nickname': (post.user.nickname or post.user.username) if post else None,
                        'avatar_url': get_full_avatar_url(post.user.avatar_url) if post else None,
                    }
                }
                activities.append(activity_data)

            total = comments.total
            pages = comments.pages

        else:  # activity_type == 'commented'
            # Other users' comments on the current user's posts.
            my_posts = Post.query.filter_by(user_id=user.id).subquery()
            comments = Comment.query.join(my_posts, Comment.post_id == my_posts.c.id) \
                .filter(Comment.user_id != user.id) \
                .order_by(Comment.created_at.desc()) \
                .paginate(page=page, per_page=per_page, error_out=False)

            activities = [{
                'comment_id': comment.id,
                'comment_content': comment.content,
                'comment_time': comment.created_at.isoformat() if comment.created_at else None,
                'commenter_nickname': comment.user.nickname or comment.user.username,
                'commenter_avatar': get_full_avatar_url(comment.user.avatar_url),
                'post_content': comment.post.content,
                'event_title': comment.post.event.title,
                'event_id': comment.post.event_id
            } for comment in comments.items]

            total = comments.total
            pages = comments.pages

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'activities': activities,
                'total': total,
                'pages': pages,
                'current_page': page
            }
        })

    except Exception as e:
        print(f"Error in api_user_activities: {str(e)}")
        return jsonify({
            'code': 500,
            'message': '服务器内部错误',
            'data': None
        }), 500
|
||
|
||
|
||
class UserFeedback(db.Model):
    """A piece of user feedback submitted from the app, with optional admin reply."""
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
    type = db.Column(db.String(50), nullable=False)  # feedback category
    content = db.Column(db.Text, nullable=False)  # feedback body
    contact_info = db.Column(db.String(100))  # optional contact details
    status = db.Column(db.String(20), default='pending')  # pending/processing/resolved/closed
    admin_reply = db.Column(db.Text)  # reply written by an administrator
    created_at = db.Column(db.DateTime, default=beijing_now)
    updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now)

    # Relationships
    user = db.relationship('User', backref='feedbacks')

    def __init__(self, user_id, type, content, contact_info=None):
        self.user_id = user_id
        self.type = type
        self.content = content
        self.contact_info = contact_info

    def to_dict(self):
        """Serialize this record for JSON responses."""
        ts_fmt = '%Y-%m-%d %H:%M:%S'
        return {
            'id': self.id,
            'type': self.type,
            'content': self.content,
            'contact_info': self.contact_info,
            'status': self.status,
            'admin_reply': self.admin_reply,
            'created_at': self.created_at.strftime(ts_fmt),
            'updated_at': self.updated_at.strftime(ts_fmt),
        }
|
||
|
||
|
||
# 通用错误处理
|
||
@app.errorhandler(404)
def api_not_found(error):
    """Return a JSON 404 for API paths; leave non-API routes to the default handler."""
    if not request.path.startswith('/api/'):
        return error
    return jsonify({
        'code': 404,
        'message': '接口不存在',
        'data': None
    }), 404
|
||
|
||
|
||
@app.errorhandler(405)
def api_method_not_allowed(error):
    """Return a JSON 405 for API paths; leave non-API routes to the default handler."""
    if not request.path.startswith('/api/'):
        return error
    return jsonify({
        'code': 405,
        'message': '请求方法不允许',
        'data': None
    }), 405
|
||
|
||
|
||
# 应用启动时自动初始化(兼容 Gunicorn 和直接运行)
|
||
# Auto-initialization at startup (works under both Gunicorn and a direct run).
_sywg_cache_initialized = False


def ensure_sywg_cache_initialized():
    """Lazily initialize the SWS (申万) industry classification cache, once."""
    global _sywg_cache_initialized
    if _sywg_cache_initialized:
        return
    init_sywg_industry_cache()
    _sywg_cache_initialized = True
|
||
|
||
|
||
@app.before_request
def before_request_init():
    """Warm the SWS industry cache on the first incoming request.

    ensure_sywg_cache_initialized() already performs the check-and-set on the
    module flag, so the former duplicate flag check here (and its ``global``
    declaration, which never assigned the name) were redundant and removed.
    """
    ensure_sywg_cache_initialized()
|
||
|
||
# ===================== 应用启动配置 =====================
|
||
# 生产环境推荐使用 Gunicorn 启动(见下方命令)
|
||
#
|
||
# 【启动方式 1】使用 Gunicorn + gevent(推荐):
|
||
# USE_GEVENT=true gunicorn -w 4 -k gevent --worker-connections 1000 \
|
||
# -b 0.0.0.0:5002 --timeout 120 --graceful-timeout 30 \
|
||
# --certfile=/etc/letsencrypt/live/api.valuefrontier.cn/fullchain.pem \
|
||
# --keyfile=/etc/letsencrypt/live/api.valuefrontier.cn/privkey.pem \
|
||
# app_vx:app
|
||
#
|
||
# 【启动方式 2】使用 Gunicorn + 多进程(不使用 gevent):
|
||
# gunicorn -w 4 -b 0.0.0.0:5002 --timeout 120 --graceful-timeout 30 \
|
||
# --certfile=/etc/letsencrypt/live/api.valuefrontier.cn/fullchain.pem \
|
||
# --keyfile=/etc/letsencrypt/live/api.valuefrontier.cn/privkey.pem \
|
||
# app_vx:app
|
||
#
|
||
# 【启动方式 3】开发环境直接运行(仅限本地调试):
|
||
# python app_vx.py
|
||
#
|
||
# ===================== 应用启动配置结束 =====================
|
||
|
||
if __name__ == '__main__':
    # Direct (development) launch; production uses Gunicorn — see comments above.
    import argparse

    parser = argparse.ArgumentParser(description='启动 Flask 应用')
    parser.add_argument('--debug', action='store_true', help='启用调试模式(仅限开发环境)')
    parser.add_argument('--port', type=int, default=5002, help='监听端口(默认 5002)')
    parser.add_argument('--no-ssl', action='store_true', help='禁用 SSL(仅限开发环境)')
    args = parser.parse_args()

    # When running directly, prime the industry cache immediately.
    with app.app_context():
        init_sywg_industry_cache()
        _sywg_cache_initialized = True

    # debug=True must never be used in production.
    if args.debug:
        logger.warning("⚠️ 调试模式已启用,仅限开发环境使用!")

    # SSL setup: use the Let's Encrypt pair when present, otherwise plain HTTP.
    cert_file = '/etc/letsencrypt/live/api.valuefrontier.cn/fullchain.pem'
    key_file = '/etc/letsencrypt/live/api.valuefrontier.cn/privkey.pem'
    ssl_context = None
    if not args.no_ssl:
        if os.path.exists(cert_file) and os.path.exists(key_file):
            ssl_context = (cert_file, key_file)
        else:
            logger.warning("⚠️ SSL 证书文件不存在,将使用 HTTP 模式")

    logger.info(f"🚀 启动 Flask 应用: port={args.port}, debug={args.debug}, ssl={'enabled' if ssl_context else 'disabled'}")

    app.run(
        host='0.0.0.0',
        port=args.port,
        debug=args.debug,
        ssl_context=ssl_context,
        threaded=True  # handle requests on multiple threads
    )
|