Files
vf_react/app_vx_raw.py
2025-11-13 10:20:03 +08:00

8127 lines
296 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import hmac
import json
import base64
from hashlib import sha1
import csv
import logging
import re
from urllib.parse import urlencode, quote
import math
import os
import pytz
import requests
from flask_compress import Compress
from collections import defaultdict
import jieba
import jieba.analyse
from functools import lru_cache, wraps
import threading
from pathlib import Path
import pickle
import psutil
import time
import gc
from typing import Dict, Any, Optional, Tuple
import pandas as pd
from sqlalchemy import Column, Integer, String, Boolean, DateTime, create_engine, text, func, or_, case, event, desc, \
JSON, asc
from flask import Flask, session, has_request_context, render_template, request, jsonify, redirect, url_for, flash, \
session, render_template_string, current_app, send_from_directory
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user
from flask_mail import Mail, Message
from itsdangerous import URLSafeTimedSerializer
import random
import string
from flask_migrate import Migrate
from flask_session import Session # type: ignore
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.postgresql import JSONB
from werkzeug.utils import secure_filename
from PIL import Image
from datetime import datetime, timedelta, time as dt_time
import hashlib
from werkzeug.security import generate_password_hash, check_password_hash
import json
from config import STOP_WORDS
from clickhouse_driver import Client as Cclient
import jwt
import uuid
import redis
from docx import Document
# Raw SQLAlchemy engines for direct SQL against the `stock`, `med` and
# `valuefrontier` MySQL schemas. (Flask-Migrate itself is initialised further
# below, after the Flask-SQLAlchemy extension.)
# NOTE(review): database, Redis, SMTP, SMS, WeChat and JWT credentials are
# hard-coded throughout this section. They should be moved to environment
# variables / a secrets store and rotated, since anyone with read access to
# the source can use them.
# NOTE(review): these engine URLs do not specify a charset, unlike
# SQLALCHEMY_DATABASE_URI below (`charset=utf8mb4`) — confirm this is intended.
engine = create_engine("mysql+pymysql://root:Zzl5588161!@111.198.58.126:33060/stock", echo=False, pool_size=20,
                       max_overflow=50)
engine_med = create_engine("mysql+pymysql://root:Zzl5588161!@111.198.58.126:33060/med", echo=False)
engine_2 = create_engine("mysql+pymysql://root:Zzl5588161!@111.198.58.126:33060/valuefrontier", echo=False)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
app = Flask(__name__)
Compress(app)
# Avatar upload settings (copied into app.config below).
UPLOAD_FOLDER = 'static/uploads/avatars'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16MB max file size
# Configure Flask-Compress
app.config['COMPRESS_ALGORITHM'] = ['gzip', 'br']
app.config['COMPRESS_MIMETYPES'] = [
    'text/html',
    'text/css',
    'text/xml',
    'application/json',
    'application/javascript',
    'application/x-javascript'
]
# Shared Redis client. decode_responses=True, so reads return str, not bytes.
redis_client = redis.StrictRedis(
    host='43.143.189.195',  # change to your Redis server address
    port=6379,
    password='Zzl338180',
    db=0,
    decode_responses=True
)
app.config['SECRET_KEY'] = 'vf7891574233241'
app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://root:Zzl5588161!@111.198.58.126:33060/stock?charset=utf8mb4'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# NOTE(review): JSON_AS_ASCII / JSONIFY_PRETTYPRINT_REGULAR may be ignored on
# newer Flask releases (JSON provider API) — confirm the installed Flask version.
app.config['JSON_AS_ASCII'] = False
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
# Mail (SMTP over SSL) settings
app.config['MAIL_SERVER'] = 'smtp.exmail.qq.com'
app.config['MAIL_PORT'] = 465
app.config['MAIL_USE_SSL'] = True
app.config['MAIL_USERNAME'] = 'admin@valuefrontier.cn'
app.config['MAIL_PASSWORD'] = 'QYncRu6WUdASvTg4'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = MAX_CONTENT_LENGTH
# Qiniu SMS verification settings
app.config['QINIU_ACCESS_KEY'] = "0MIwksc8RvcNten1iUdTbtwB6orPOfzRiqYTXVOU"
app.config['QINIU_SECRET_KEY'] = "QmjyOi27XoLtBccv5AAI6khIcJncLfr0ErjSMu_i"
app.config['QINIU_TEMPLATE_ID'] = "1901640687774875648"
app.config['QINIU_SIGNATURE_ID'] = "1900745528702943232"
# WeChat OAuth settings; the redirect URI is derived from BASE_URL.
app.config['WECHAT_APP_ID'] = 'wxa8d74c47041b5f87'
app.config['WECHAT_APP_SECRET'] = 'eedef95b11787fd7ca7f1acc6c9061bc'
app.config['BASE_URL'] = 'http://43.143.189.195:5002'
app.config['WECHAT_REDIRECT_URI'] = f"{app.config['BASE_URL']}/api/wechat/callback"
# Cache settings. NOTE(review): no cache extension is initialised in this
# section — TODO confirm these CACHE_* keys are consumed elsewhere.
app.config['CACHE_TYPE'] = 'redis'
app.config['CACHE_REDIS_HOST'] = '43.143.189.195'  # use the actual server IP
app.config['CACHE_REDIS_PORT'] = 6379
app.config['CACHE_REDIS_PASSWORD'] = 'Zzl338180'
app.config['CACHE_DEFAULT_TIMEOUT'] = 300
# Server-side sessions (Flask-Session) in Redis. This is a separate client
# from redis_client above and does not set decode_responses.
app.config['SESSION_TYPE'] = 'redis'
app.config['SESSION_REDIS'] = redis.Redis(host='43.143.189.195', port=6379, password='Zzl338180')
# Cache directory setup
CACHE_DIR = Path('cache')
CACHE_DIR.mkdir(exist_ok=True)
# Memory management constants
MAX_MEMORY_PERCENT = 75
MEMORY_CHECK_INTERVAL = 300
MAX_CACHE_ITEMS = 50
# Initialise extensions (must run after the app.config values above are set).
db = SQLAlchemy(app)
mail = Mail(app)
login_manager = LoginManager(app)
login_manager.login_view = 'login'
serializer = URLSafeTimedSerializer(app.config['SECRET_KEY'])
migrate = Migrate(app, db)
DOMAIN = 'http://43.143.189.195:5002'  # same value as app.config['BASE_URL']
# JWT settings used by token_required.
# NOTE(review): hard-coded, short HS256 secret — move to configuration.
JWT_SECRET = 'Llmgreat123'
JWT_ALGORITHM = 'HS256'
JWT_EXPIRES_SECONDS = 3600  # valid for 1 hour
# Flask-Session reads SESSION_TYPE / SESSION_REDIS, so this stays after them.
Session(app)
def token_required(f):
    """Require a valid ``Authorization: Bearer <JWT>`` header on a Flask view.

    The token is decoded with ``JWT_SECRET`` / ``JWT_ALGORITHM``. Its
    ``user_id`` claim must reference an existing ``User``, which is attached
    to the request as ``request.user`` before the wrapped view runs.

    Every failure (missing or non-Bearer header, expired or invalid token,
    token without a ``user_id`` claim, unknown user) yields a JSON 401
    response instead of an exception.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        # Expected header form: "Authorization: Bearer <token>".
        auth_header = request.headers.get('Authorization', '')
        scheme, _, credentials = auth_header.partition(' ')
        token = credentials.strip() if scheme == 'Bearer' else ''
        if not token:
            return jsonify({'code': 401, 'message': '未提供 Token'}), 401
        try:
            payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
        except jwt.ExpiredSignatureError:
            return jsonify({'code': 401, 'message': 'Token 已过期'}), 401
        except jwt.InvalidTokenError:
            return jsonify({'code': 401, 'message': '无效 Token'}), 401
        # A correctly signed token may still lack the claim; treat it as invalid
        # rather than letting a KeyError surface as a 500.
        user_id = payload.get('user_id') if isinstance(payload, dict) else None
        user = User.query.get(user_id) if user_id is not None else None
        if not user:
            return jsonify({'code': 401, 'message': '无效 Token'}), 401
        # Expose the authenticated user to the view.
        request.user = user
        return f(*args, **kwargs)
    return decorated
def generate_qiniu_token(access_key, secret_key, method, path, host, content_type=None, body=None):
    """Build a Qiniu ``Authorization`` header value of the form ``Qiniu <ak>:<sig>``.

    The signed string is ``"<method> <path>\\nHost: <host>"``. It is followed by
    ``"\\nContent-Type: <content_type>"`` when a content type is given, then a
    blank line (``"\\n\\n"``), then the request body. The body is only appended
    when it is non-empty, a content type is set, and that content type is not
    ``application/octet-stream``. The signature is HMAC-SHA1 keyed with
    ``secret_key`` and URL-safe base64 encoded.
    """
    header_lines = [f"{method} {path}", f"Host: {host}"]
    if content_type:
        header_lines.append(f"Content-Type: {content_type}")
    signing_string = "\n".join(header_lines) + "\n\n"

    body_is_signed = bool(body) and bool(content_type) and content_type != "application/octet-stream"
    if body_is_signed:
        signing_string += body

    digest = hmac.new(secret_key.encode(), signing_string.encode(), sha1).digest()
    signature = base64.urlsafe_b64encode(digest).decode('utf-8')
    return "Qiniu {}:{}".format(access_key, signature)
def send_sms_verification_minimal(phone_number, redis_client):
    """Generate a 6-digit SMS code, store it in Redis and send it via Qiniu SMS.

    Uses ``urllib`` instead of ``requests`` (the original author reports
    problems with ``requests`` in this path). The code is stored under
    ``sms:login:<phone_number>`` with a 600-second TTL *before* the SMS API is
    called.

    :param phone_number: destination mobile number (format validated by callers)
    :param redis_client: Redis client used to store the code
    :returns: a tuple whose first two items are ``(success, message)``.
        NOTE(review): the arity is inconsistent and is kept for backward
        compatibility. When the API response is parsed, a third item is
        appended: the code on success, ``None`` on an API-reported failure.
        On HTTP/network errors and unexpected exceptions only
        ``(success, message)`` is returned. Callers should read ``result[0]``
        and ``result[1]`` instead of unpacking a fixed number of values.
    """
    import json as json_lib
    import secrets
    import ssl
    import string
    import urllib.error
    import urllib.request

    try:
        # Verification codes are authentication secrets: use a CSPRNG.
        verification_code = ''.join(secrets.choice(string.digits) for _ in range(6))

        redis_key = f"sms:login:{phone_number}"  # unified key format
        redis_client.setex(redis_key, 600, verification_code)
        # Deliberately do not log the code itself.
        logger.info("SMS verification code stored in Redis, key: %s", redis_key)

        host = "sms.qiniuapi.com"
        path = "/v1/message/single"
        method = "POST"
        content_type = "application/json"
        body = json_lib.dumps({
            "template_id": app.config['QINIU_TEMPLATE_ID'],
            "mobile": phone_number,
            "parameters": {
                "code": verification_code
            },
            "signature_id": app.config['QINIU_SIGNATURE_ID'],
        })

        # Same signing scheme as before, via the shared helper.
        token = generate_qiniu_token(
            app.config['QINIU_ACCESS_KEY'],
            app.config['QINIU_SECRET_KEY'],
            method, path, host, content_type, body,
        )

        req = urllib.request.Request(
            f"https://{host}{path}",
            data=body.encode('utf-8'),
            headers={
                "Content-Type": content_type,
                "Authorization": token,
                "User-Agent": "ValueFrontier/1.0",
            },
            method=method,
        )

        # NOTE(review): certificate and hostname verification are deliberately
        # disabled, which exposes the Authorization header and the code to
        # man-in-the-middle attacks. Prefer an unmodified
        # ssl.create_default_context() once the server's CA bundle works.
        ctx = ssl.create_default_context()
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE

        with urllib.request.urlopen(req, timeout=10, context=ctx) as response:
            result = json_lib.loads(response.read().decode('utf-8'))

        if 'message_id' in result:
            logger.info("SMS sent, message_id: %s", result.get('message_id'))
            return True, "验证码发送成功", verification_code
        error_msg = result.get('message', '发送失败')
        logger.warning("SMS send failed: %s", error_msg)
        return False, f"发送失败: {error_msg}", None
    except urllib.error.HTTPError as e:
        logger.error("SMS HTTP error: %s - %s", e.code, e.reason)
        return False, f"网络错误: {e.code}"
    except urllib.error.URLError as e:
        logger.error("SMS URL error: %s", e.reason)
        return False, "网络连接失败"
    except Exception as e:
        logger.exception("Unexpected SMS error: %s", type(e).__name__)
        return False, "发送失败,请重试"
def verify_sms_code(phone_number, code):
    """Verify an SMS verification code for ``phone_number``.

    Lookup order:
      1. Redis key ``sms:login:<phone_number>``, which is where
         ``send_sms_verification_minimal`` stores the code (600 s TTL). On a
         match the key is deleted so each code can be used only once.
      2. The legacy session keys ``sms_verification_code``,
         ``sms_verification_phone`` and ``sms_verification_expiration`` (a
         POSIX timestamp). These are used when Redis holds no code for this
         number or cannot be reached.

    :param phone_number: phone number the code was sent to
    :param code: code submitted by the user (may be ``None``)
    :returns: ``(ok, message)`` where ``message`` is user-facing text
    """
    submitted = '' if code is None else str(code).strip()

    redis_key = f"sms:login:{phone_number}"
    try:
        redis_code = redis_client.get(redis_key)
    except redis.RedisError as exc:
        logger.warning("Redis lookup of SMS code failed, falling back to session: %s", exc)
        redis_code = None

    if redis_code is not None:
        if isinstance(redis_code, bytes):
            redis_code = redis_code.decode('utf-8')
        # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
        if not submitted or not hmac.compare_digest(submitted.encode('utf-8'),
                                                    str(redis_code).encode('utf-8')):
            return False, "验证码错误"
        try:
            redis_client.delete(redis_key)
        except redis.RedisError as exc:
            logger.warning("Failed to delete used SMS code %s: %s", redis_key, exc)
        return True, "验证成功"

    # Legacy session-based storage.
    stored_code = session.get('sms_verification_code')
    stored_phone = session.get('sms_verification_phone')
    expiration = session.get('sms_verification_expiration')
    if not all([stored_code, stored_phone, expiration]):
        return False, "请先获取验证码"
    if stored_phone != phone_number:
        return False, "手机号与验证码不匹配"
    if beijing_now().timestamp() > expiration:
        return False, "验证码已过期"
    if not submitted or not hmac.compare_digest(submitted.encode('utf-8'),
                                                str(stored_code).encode('utf-8')):
        return False, "验证码错误"
    return True, "验证成功"
def beijing_now():
    """Return the current time as a timezone-aware datetime in Asia/Shanghai (pytz)."""
    shanghai = pytz.timezone('Asia/Shanghai')
    current = datetime.now(shanghai)
    return current
class User(UserMixin, db.Model):
    """User account model.

    Several settings/profile columns hold JSON documents in plain ``String``
    columns. Read and write them through the ``get_*`` / ``set_*`` helpers.
    """
    id = db.Column(db.Integer, primary_key=True)
    # Basic account information (required at registration)
    username = db.Column(db.String(80), unique=True, nullable=False)  # user name
    email = db.Column(db.String(120), unique=True, nullable=False)  # email address
    password_hash = db.Column(db.String(128), nullable=False)  # password hash
    email_confirmed = db.Column(db.Boolean, default=False)  # whether the email is verified
    wechat_union_id = db.Column(db.String(100), unique=True)  # WeChat UnionID
    wechat_open_id = db.Column(db.String(100))  # WeChat OpenID
    # Account status
    created_at = db.Column(db.DateTime, default=beijing_now)  # registration time
    last_seen = db.Column(db.DateTime, default=beijing_now)  # last active time
    status = db.Column(db.String(20), default='active')  # account status: active/banned/deleted
    # Profile (optional, completed later in the user centre)
    nickname = db.Column(db.String(30))  # community nickname
    avatar_url = db.Column(db.String(200))  # avatar URL
    banner_url = db.Column(db.String(200))  # profile page banner image
    bio = db.Column(db.String(200))  # short bio
    gender = db.Column(db.String(10))  # gender
    birth_date = db.Column(db.Date)  # birthday
    location = db.Column(db.String(100))  # location
    # Contact details (optional)
    phone = db.Column(db.String(20))  # mobile number
    wechat_id = db.Column(db.String(80))  # WeChat ID
    # Real-name verification (optional)
    real_name = db.Column(db.String(30))  # real name
    # National ID number. The original comment says "stored encrypted", but no
    # encryption is visible here — NOTE(review): confirm.
    id_number = db.Column(db.String(18))
    is_verified = db.Column(db.Boolean, default=False)  # real-name verified?
    verify_time = db.Column(db.DateTime)  # real-name verification time
    # Investment profile (optional)
    trading_experience = db.Column(db.String(200))  # years of trading experience
    investment_style = db.Column(db.String(50))  # investment style
    risk_preference = db.Column(db.String(20))  # risk preference
    investment_amount = db.Column(db.String(20))  # investment size
    preferred_markets = db.Column(db.String(200), default='[]')  # preferred markets, JSON list
    # Community statistics (maintained by the system)
    user_level = db.Column(db.Integer, default=1)  # user level
    reputation_score = db.Column(db.Integer, default=0)  # reputation score
    contribution_point = db.Column(db.Integer, default=0)  # contribution points
    post_count = db.Column(db.Integer, default=0)  # number of posts
    comment_count = db.Column(db.Integer, default=0)  # number of comments
    follower_count = db.Column(db.Integer, default=0)  # number of followers
    following_count = db.Column(db.Integer, default=0)  # number of users followed
    # Creator information (optional)
    is_creator = db.Column(db.Boolean, default=False)  # is a creator?
    creator_type = db.Column(db.String(20))  # creator type
    creator_tags = db.Column(db.String(200), default='[]')  # creator tags, JSON list
    # Settings
    email_notifications = db.Column(db.Boolean, default=True)  # email notifications
    sms_notifications = db.Column(db.Boolean, default=False)  # SMS notifications
    wechat_notifications = db.Column(db.Boolean, default=False)  # WeChat notifications
    notification_preferences = db.Column(db.String(500), default='{}')  # notification prefs, JSON object
    privacy_level = db.Column(db.String(20), default='public')  # privacy level
    theme_preference = db.Column(db.String(20), default='light')  # UI theme
    blocked_keywords = db.Column(db.String(500), default='[]')  # blocked keywords, JSON list
    # Phone verification
    phone_confirmed = db.Column(db.Boolean, default=False)  # phone verified?
    phone_confirm_time = db.Column(db.DateTime)  # phone verification time

    def __init__(self, username, email=None, password=None, phone=None):
        """Create a user from the basic registration fields.

        NOTE(review): ``email`` and ``password_hash`` are declared
        ``nullable=False`` but both are optional here. Users created without
        them will fail to insert unless the real DB schema allows NULL.
        """
        self.username = username
        if email:
            self.email = email
        if phone:
            self.phone = phone
        if password:
            self.set_password(password)
        now = beijing_now()  # one clock read so both timestamps agree
        self.created_at = now
        self.last_seen = now

    def set_password(self, password):
        """Store a salted hash of ``password`` (never the plain text)."""
        self.password_hash = generate_password_hash(password)

    def check_password(self, password):
        """Return True if ``password`` matches the stored hash.

        Returns False (instead of raising) when no hash is stored or no
        password was supplied.
        """
        if not self.password_hash or password is None:
            return False
        return check_password_hash(self.password_hash, password)

    def update_last_seen(self):
        """Set ``last_seen`` to the current Beijing time (caller commits)."""
        self.last_seen = beijing_now()

    @staticmethod
    def _decode_json_field(raw, fallback):
        """Decode a JSON text column, returning ``fallback`` if empty or malformed."""
        if not raw:
            return fallback
        try:
            return json.loads(raw)
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError
            return fallback

    # Getters/setters for the JSON-in-String columns.
    def get_preferred_markets(self):
        """Return the preferred-markets list ([] if unset or invalid)."""
        return self._decode_json_field(self.preferred_markets, [])

    def get_blocked_keywords(self):
        """Return the blocked-keywords list ([] if unset or invalid)."""
        return self._decode_json_field(self.blocked_keywords, [])

    def get_notification_preferences(self):
        """Return the notification-preferences dict ({} if unset or invalid)."""
        return self._decode_json_field(self.notification_preferences, {})

    def get_creator_tags(self):
        """Return the creator-tags list ([] if unset or invalid)."""
        return self._decode_json_field(self.creator_tags, [])

    # ensure_ascii=False keeps CJK text as-is instead of 6-char \uXXXX escapes,
    # which would quickly overflow the String(200)/String(500) columns.
    def set_preferred_markets(self, markets):
        """Store the preferred-markets list as JSON."""
        self.preferred_markets = json.dumps(markets, ensure_ascii=False)

    def set_blocked_keywords(self, keywords):
        """Store the blocked-keywords list as JSON."""
        self.blocked_keywords = json.dumps(keywords, ensure_ascii=False)

    def set_notification_preferences(self, preferences):
        """Store the notification preferences as JSON."""
        self.notification_preferences = json.dumps(preferences, ensure_ascii=False)

    def set_creator_tags(self, tags):
        """Store the creator tags as JSON."""
        self.creator_tags = json.dumps(tags, ensure_ascii=False)

    def to_dict(self):
        """Return a public dict view of the user (avatar URL made absolute)."""
        return {
            'id': self.id,
            'username': self.username,
            'email': self.email,
            'nickname': self.nickname,
            'avatar_url': get_full_avatar_url(self.avatar_url),
            'bio': self.bio,
            'is_verified': self.is_verified,
            'user_level': self.user_level,
            'reputation_score': self.reputation_score,
            'is_creator': self.is_creator
        }

    def __repr__(self):
        return f'<User {self.username}>'
@app.route('/send_sms_verification_route', methods=['POST'])
def send_sms_verification_route():
    """Send an SMS verification code. JSON body: ``{"phone": "1xxxxxxxxxx"}``.

    Always answers HTTP 200 with ``{"success": bool, "message": str}``.
    The phone must be an 11-digit string starting with ``1``.
    """
    try:
        if not request.is_json:
            return jsonify({
                'success': False,
                'message': '无效的请求'
            })
        payload = request.get_json(silent=True)
        phone = payload.get('phone') if isinstance(payload, dict) else None
        if not phone:
            return jsonify({
                'success': False,
                'message': '手机号不能为空'
            })
        # Validate phone number format.
        if not isinstance(phone, str) or not phone.isdigit() or len(phone) != 11 or not phone.startswith('1'):
            return jsonify({
                'success': False,
                'message': '无效的手机号码格式'
            })
        # The sender returns (success, message) or (success, message, code)
        # depending on the path taken. Take only the first two items, which
        # also keeps the code itself from ever reaching the client.
        result = send_sms_verification_minimal(phone, redis_client)
        success, message = result[0], result[1]
        return jsonify({
            'success': success,
            'message': message
        })
    except Exception:
        logger.exception("SMS verification route failed")
        return jsonify({
            'success': False,
            'message': '系统错误'
        })
@app.route('/check_phone', methods=['POST'])
def check_phone():
    """Report whether a phone number (JSON field ``phone``) is already registered.

    A missing or non-JSON body is treated as an empty phone number instead of
    raising an error.
    """
    payload = request.get_json(silent=True)
    phone = payload.get('phone', '') if isinstance(payload, dict) else ''
    user = User.query.filter_by(phone=phone).first()
    return jsonify({
        'exists': user is not None,
        'message': '该手机号已被注册' if user else 'ok'
    })
@app.route('/register_with_phone', methods=['POST'])
def register_with_phone():
    """Register an account with a phone number verified by SMS code.

    Form fields: ``username``, ``phone``, ``password``, ``verification_code``.
    On success the user is created with ``phone_confirmed=True``, logged in
    via Flask-Login, and ``redirect_url`` points at the index page. Responses
    are JSON ``{success, message[, error_field, redirect_url]}``.
    """
    username = request.form.get('username')
    phone = request.form.get('phone')
    password = request.form.get('password')
    verification_code = request.form.get('verification_code')

    # username and password_hash are NOT NULL columns. Reject missing values
    # here instead of letting the commit fail with a generic error.
    for field, value, message in (
        ('username', username, '用户名不能为空'),
        ('phone', phone, '手机号不能为空'),
        ('password', password, '密码不能为空'),
        ('verification_code', verification_code, '验证码不能为空'),
    ):
        if not value:
            return jsonify({'success': False, 'error_field': field, 'message': message})

    if User.query.filter_by(username=username).first():
        return jsonify({
            'success': False,
            'error_field': 'username',
            'message': '用户名已存在'
        })
    if User.query.filter_by(phone=phone).first():
        return jsonify({
            'success': False,
            'error_field': 'phone',
            'message': '手机号已被注册'
        })

    success, message = verify_sms_code(phone, verification_code)
    if not success:
        return jsonify({
            'success': False,
            'error_field': 'verification_code',
            'message': message
        })

    try:
        # NOTE(review): User.email is declared nullable=False but no email is
        # set here. Unless the real column allows NULL, this insert will fail.
        user = User(username=username, password=password)
        user.phone = phone
        user.phone_confirmed = True
        user.phone_confirm_time = beijing_now()
        db.session.add(user)
        db.session.commit()
        # Clear any legacy session-stored verification data.
        for key in ('sms_verification_code', 'sms_verification_phone', 'sms_verification_expiration'):
            session.pop(key, None)
        login_user(user)  # log the new user in straight away
        return jsonify({
            'success': True,
            'message': '注册成功!',
            'redirect_url': url_for('index')
        })
    except Exception as e:
        db.session.rollback()
        app.logger.error(f"手机注册错误: {str(e)}")
        return jsonify({
            'success': False,
            'message': '注册失败,请重试'
        })
@app.route('/login_with_phone', methods=['POST'])
def login_with_phone():
    """Log in with a phone number and SMS verification code.

    Form fields: ``phone``, ``verification_code`` and an optional ``next``
    path. Answers JSON ``{success, message[, redirect_url]}``. As with the
    password login, accounts whose ``status`` is not ``'active'`` are refused.
    """
    phone = request.form.get('phone')
    verification_code = request.form.get('verification_code')
    next_page = request.form.get('next')

    success, message = verify_sms_code(phone, verification_code)
    if not success:
        return jsonify({
            'success': False,
            'message': message
        })

    user = User.query.filter_by(phone=phone).first()
    if not user:
        return jsonify({
            'success': False,
            'message': '该手机号未注册'
        })
    if user.status != 'active':
        return jsonify({
            'success': False,
            'message': '您的账号已被禁用或删除'
        })

    try:
        login_user(user)
        # Clear any legacy session-stored verification data.
        for key in ('sms_verification_code', 'sms_verification_phone', 'sms_verification_expiration'):
            session.pop(key, None)
        user.update_last_seen()
        db.session.commit()
        # Only follow same-site paths. Browsers treat "//host" and
        # backslash variants as absolute URLs (open redirect).
        is_local_path = (
            bool(next_page)
            and next_page.startswith('/')
            and not next_page.startswith('//')
            and '\\' not in next_page
        )
        redirect_url = next_page if is_local_path else url_for('index')
        return jsonify({
            'success': True,
            'message': '登录成功!',
            'redirect_url': redirect_url
        })
    except Exception as e:
        db.session.rollback()
        app.logger.error(f"手机登录错误: {str(e)}")
        return jsonify({
            'success': False,
            'message': '登录失败,请重试'
        })
@app.route('/bind_phone', methods=['POST'])
@token_required
def bind_phone():
    """Attach an SMS-verified phone number to the JWT-authenticated user.

    Form fields: ``phone`` and ``verification_code``. The number is refused
    if another account already uses it. Answers JSON ``{success, message}``.
    """
    form = request.form
    phone = form.get('phone')
    code = form.get('verification_code')

    code_ok, code_msg = verify_sms_code(phone, code)
    if not code_ok:
        return jsonify({'success': False, 'message': code_msg})

    me = request.user
    owner = User.query.filter_by(phone=phone).first()
    if owner is not None and owner.id != me.id:
        return jsonify({'success': False, 'message': '该手机号已被其他账号绑定'})

    try:
        me.phone = phone
        me.phone_confirmed = True
        me.phone_confirm_time = beijing_now()
        db.session.commit()
        for session_key in ('sms_verification_code', 'sms_verification_phone', 'sms_verification_expiration'):
            session.pop(session_key, None)
        return jsonify({'success': True, 'message': '手机号绑定成功'})
    except Exception as err:
        db.session.rollback()
        app.logger.error(f"手机绑定错误: {str(err)}")
        return jsonify({'success': False, 'message': '绑定失败,请重试'})
# 路由
@app.route('/register', methods=['GET', 'POST'])
def register():
    """Email registration. GET renders the sign-up page; POST creates the account.

    POST form fields: ``username``, ``email``, ``password``,
    ``verification_code``. The code must match the session-stored
    ``verification_code`` for the same ``verification_email`` and must not be
    past ``verification_expiration`` (a POSIX timestamp). Responses are JSON
    ``{success, message[, error_field, redirect_url]}``.
    """
    if request.method != 'POST':
        return render_template('pages/sign-up/sign-up-basic.html')

    def fail(field, message):
        return jsonify({
            'success': False,
            'error_field': field,
            'message': message
        })

    username = request.form.get('username')
    email = request.form.get('email')
    password = request.form.get('password')
    verification_code = request.form.get('verification_code')

    # username and password_hash are NOT NULL columns. Reject missing values
    # here instead of letting the commit fail with a generic error.
    for field, value, message in (
        ('username', username, '用户名不能为空'),
        ('email', email, '邮箱不能为空'),
        ('password', password, '密码不能为空'),
        ('verification_code', verification_code, '验证码不能为空'),
    ):
        if not value:
            return fail(field, message)

    if User.query.filter_by(username=username).first():
        return fail('username', '用户名已存在')
    if User.query.filter_by(email=email).first():
        return fail('email', '邮箱已被注册')

    stored_code = session.get('verification_code')
    stored_email = session.get('verification_email')
    expiration = session.get('verification_expiration')
    if not all([stored_code, stored_email, expiration]):
        return fail('verification_code', '请先获取验证码')
    if stored_email != email:
        return fail('verification_code', '邮箱与验证码不匹配')
    if beijing_now().timestamp() > expiration:
        return fail('verification_code', '验证码已过期')
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    if not hmac.compare_digest(verification_code.encode('utf-8'), str(stored_code).encode('utf-8')):
        return fail('verification_code', '验证码错误')

    try:
        user = User(username=username, email=email, password=password)
        # The email was just verified via the code, so mark it confirmed.
        user.email_confirmed = True
        db.session.add(user)
        db.session.commit()
        for key in ('verification_code', 'verification_email', 'verification_expiration'):
            session.pop(key, None)
        return jsonify({
            'success': True,
            'message': '注册成功!',
            'redirect_url': url_for('login')
        })
    except Exception as e:
        db.session.rollback()
        app.logger.error(f"Registration error: {str(e)}")
        return jsonify({
            'success': False,
            'message': '注册失败,请重试'
        })
@app.route('/check_username', methods=['POST'])
def check_username():
    """Report whether a username (JSON field ``username``) is already taken.

    A missing or non-JSON body is treated as an empty username instead of
    raising an error.
    """
    payload = request.get_json(silent=True)
    username = payload.get('username', '') if isinstance(payload, dict) else ''
    user = User.query.filter_by(username=username).first()
    return jsonify({
        'exists': user is not None,
        'message': '该用户名已被使用' if user else 'ok'
    })
@app.route('/check_email', methods=['POST'])
def check_email():
    """Report whether an email (JSON field ``email``) is already registered.

    A missing or non-JSON body is treated as an empty email instead of
    raising an error.
    """
    payload = request.get_json(silent=True)
    email = payload.get('email', '') if isinstance(payload, dict) else ''
    user = User.query.filter_by(email=email).first()
    return jsonify({
        'exists': user is not None,
        'message': '该邮箱已被注册' if user else 'ok'
    })
@app.route('/resend_verification', methods=['GET'])
def resend_verification():
    """Re-send the email verification code to the session's ``verification_email``.

    ``/register`` only creates the ``User`` row after the code has been
    verified, so during sign-up there is normally no user for this email yet.
    As in ``/send_verification``, resending is therefore refused only when the
    address belongs to a user whose email is already confirmed.
    """
    email = session.get('verification_email')
    if not email:
        return jsonify({
            'success': False,
            'message': '请重新注册'
        })
    user = User.query.filter_by(email=email).first()
    if user is not None and user.email_confirmed:
        return jsonify({
            'success': False,
            'message': '邮箱已经验证过了'
        })
    send_verification_email(email)
    return jsonify({
        'success': True,
        'message': '新的验证码已发送,请查收'
    })
@app.route('/send_verification', methods=['POST'])
def send_verification():
    """Email a registration verification code. JSON body: ``{"email": ...}``.

    Refuses when the body is not JSON, the email is empty, or the address
    already belongs to a user whose email is confirmed.
    """
    def reply(ok, text):
        return jsonify({'success': ok, 'message': text})

    if not request.is_json:
        return reply(False, '无效的请求')

    email = request.json.get('email')
    if not email:
        return reply(False, '邮箱不能为空')

    existing = User.query.filter_by(email=email).first()
    already_verified = bool(existing and existing.email_confirmed)
    if already_verified:
        return reply(False, '该邮箱已注册且已验证')

    send_verification_email(email)
    return reply(True, '验证码已发送')
# Login route; password login also accepts a phone number, and
# login_type=phone redirects to the SMS-code login page.
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Password login. GET renders the sign-in page.

    POST form fields:
      * ``login_type``: ``username`` (default), ``email`` or ``phone``.
      * For ``username``: ``username`` + ``password``. The value is tried as a
        username, then as an 11-digit phone number, then as an email if it
        contains ``@``.
      * For ``email``: ``email`` + ``email_password``.
      * ``next``: optional local path to redirect to after login.
    ``phone`` logins are redirected to the SMS-code page. Any other
    ``login_type`` falls through to rendering the sign-in page.
    """
    if request.method == 'POST':
        login_type = request.form.get('login_type', 'username')
        next_page = request.form.get('next')

        if login_type in ['username', 'email']:
            # Missing fields become '' / None instead of crashing below.
            if login_type == 'username':
                credential = request.form.get('username') or ''
                password = request.form.get('password')
            else:
                credential = request.form.get('email') or ''
                password = request.form.get('email_password')

            user = None
            if login_type == 'email':
                if '@' in credential:
                    user = User.query.filter_by(email=credential).first()
            else:
                user = User.query.filter_by(username=credential).first()
                if user is None and credential.isdigit() and len(credential) == 11:
                    user = User.query.filter_by(phone=credential).first()
                if user is None and '@' in credential:
                    user = User.query.filter_by(email=credential).first()

            if user and password and user.check_password(password):
                if user.status != 'active':
                    flash('您的账号已被禁用或删除')
                    return redirect(url_for('login', next=next_page))
                # The account must have a verified email or phone.
                if not user.email_confirmed and not user.phone_confirmed:
                    flash('请先验证您的邮箱或手机号')
                    return redirect(url_for('login', next=next_page))
                login_user(user)
                user.update_last_seen()
                db.session.commit()
                # Only follow same-site paths. Browsers treat "//host" and
                # backslash variants as absolute URLs (open redirect).
                if (next_page and next_page.startswith('/')
                        and not next_page.startswith('//') and '\\' not in next_page):
                    return redirect(next_page)
                return redirect(url_for('index'))

            flash('账号或密码错误')
            return redirect(url_for('login', next=next_page))

        elif login_type == 'phone':
            return redirect(url_for('login_with_phone_page'))

    return render_template('pages/sign-in/index.html')
# Phone-number (SMS code) login page.
@app.route('/login/phone', methods=['GET'])
def login_with_phone_page():
    """Render the SMS-code login page, forwarding the optional ``next`` query argument."""
    return render_template('pages/sign-in/sign-in-phone.html', next=request.args.get('next'))
# Inline Jinja HTML bodies for notification emails, keyed by the template name
# callers pass to send_notification_email (no template files are read).
_NOTIFICATION_EMAIL_TEMPLATES = {
    'emails/notification_post_liked.html': """
<div style="max-width: 600px; margin: 0 auto; padding: 20px; font-family: Arial, sans-serif;">
<h2 style="color: #333;">你的帖子收到了新的点赞</h2>
<div style="background: #f7f7f7; padding: 20px; border-radius: 5px; margin: 20px 0;">
<p>{{ liker.username }} 点赞了你的帖子</p>
<p style="color: #666; font-size: 14px;">帖子内容: {{ post.content[:100] }}...</p>
</div>
<a href="{{ url_for('event_detail', event_id=post.event_id, _anchor='post-' + post.id|string, _external=True) }}"
style="display: inline-block; background: #4a90e2; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
查看详情
</a>
<p style="color: #666; font-size: 13px; margin-top: 20px;">
如果你不想再收到此类通知,可以在个人设置中关闭邮件通知
</p>
</div>
""",
    'emails/notification_post_commented.html': """
<div style="max-width: 600px; margin: 0 auto; padding: 20px; font-family: Arial, sans-serif;">
<h2 style="color: #333;">你的帖子收到了新的评论</h2>
<div style="background: #f7f7f7; padding: 20px; border-radius: 5px; margin: 20px 0;">
<p>{{ commenter.username }} 评论了你的帖子:</p>
<p style="color: #666; font-size: 14px;">帖子内容: {{ post.content[:100] }}...</p>
</div>
<a href="{{ url_for('event_detail', event_id=post.event_id, _anchor='post-' + post.id|string, _external=True) }}"
style="display: inline-block; background: #4a90e2; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
查看详情
</a>
<p style="color: #666; font-size: 13px; margin-top: 20px;">
如果你不想再收到此类通知,可以在个人设置中关闭邮件通知
</p>
</div>
""",
    'emails/notification_comment_replied.html': """
<div style="max-width: 600px; margin: 0 auto; padding: 20px; font-family: Arial, sans-serif;">
<h2 style="color: #333;">你的评论收到了新的回复</h2>
<div style="background: #f7f7f7; padding: 20px; border-radius: 5px; margin: 20px 0;">
<p>{{ replier.username }} 回复了你的评论:</p>
<p style="color: #666; font-size: 14px;">你的评论: {{ comment.content[:100] }}...</p>
</div>
<a href="{{ url_for('event_detail', event_id=comment.post.event_id, _anchor='comment-' + comment.id|string, _external=True) }}"
style="display: inline-block; background: #4a90e2; color: white; padding: 10px 20px; text-decoration: none; border-radius: 5px;">
查看详情
</a>
<p style="color: #666; font-size: 13px; margin-top: 20px;">
如果你不想再收到此类通知,可以在个人设置中关闭邮件通知
</p>
</div>
""",
}


def send_notification_email(recipient, subject, template, **kwargs):
    """Send a notification email rendered from one of the inline templates.

    :param recipient: recipient email address
    :param subject: email subject
    :param template: template name, a key of ``_NOTIFICATION_EMAIL_TEMPLATES``
        (``emails/notification_post_liked.html``,
        ``emails/notification_post_commented.html`` or
        ``emails/notification_comment_replied.html``)
    :param kwargs: variables passed to the template (e.g. ``post``, ``liker``,
        ``commenter``, ``comment``, ``replier``, depending on the template)
    :returns: False for an unknown template or if building/rendering the
        message raised; otherwise True. ``send_async_email`` logs and swallows
        SMTP errors, so True does not guarantee delivery.
    """
    html_template = _NOTIFICATION_EMAIL_TEMPLATES.get(template)
    if html_template is None:
        # Previously an unknown name produced an email with an empty body.
        app.logger.warning("Unknown notification email template: %s", template)
        return False
    try:
        msg = Message(
            subject=subject,
            sender=app.config['MAIL_USERNAME'],
            recipients=[recipient]
        )
        msg.html = render_template_string(html_template, **kwargs)
        send_async_email(msg)
        return True
    except Exception as e:
        app.logger.error(f"Error sending notification email: {str(e)}")
        return False
def send_async_email(msg):
    """Send ``msg`` through Flask-Mail.

    Despite the name, this runs synchronously in the caller's thread. Any
    exception is logged and swallowed, so callers cannot detect a failure.
    """
    try:
        mail.send(msg)
    except Exception as exc:
        app.logger.error("Error sending async email: " + str(exc))
@app.route('/logout')
def logout():
    """End the Flask-Login session and redirect to the index page.

    Not guarded by ``token_required``. This performs a session (Flask-Login)
    logout, and browser logout links carry no ``Authorization`` header, so
    the JWT guard made every such request fail with 401. Logging out an
    anonymous session is harmless.
    """
    logout_user()
    return redirect(url_for('index'))
@app.context_processor
def inject_user():
    """Expose ``current_user`` to templates.

    The value is chosen in this order:
      1. the user attached by ``token_required`` (``request.user``);
      2. the authenticated Flask-Login session user;
      3. ``None`` (anonymous users and no request context).

    This processor overrides Flask-Login's own ``current_user`` template
    variable. Without step 2, session-authenticated users appeared as ``None``.
    """
    if not has_request_context():
        return dict(current_user=None)
    if hasattr(request, 'user'):
        return dict(current_user=request.user)
    if current_user.is_authenticated:
        return dict(current_user=current_user)
    return dict(current_user=None)
@login_manager.user_loader
def load_user(user_id):
    """Flask-Login user loader: return the ``User`` for the session's id, or ``None``.

    A malformed id (e.g. from a stale or tampered session) yields ``None``,
    which Flask-Login treats as anonymous, instead of raising ValueError.
    """
    try:
        pk = int(user_id)
    except (TypeError, ValueError):
        return None
    return User.query.get(pk)
class Notification(db.Model):
    """In-app notification addressed to a single user."""
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'))
    type = db.Column(db.String(50))  # notification type
    content = db.Column(db.Text)  # notification text
    link = db.Column(db.String(200))  # related link
    is_read = db.Column(db.Boolean, default=False)  # read flag
    created_at = db.Column(db.DateTime, default=beijing_now)

    def __init__(self, user_id, type, content, link=None):
        """Build a notification. ``is_read`` defaults to False on insert.

        The parameter name ``type`` shadows the builtin but is kept because
        callers may pass it by keyword.
        """
        self.user_id, self.type = user_id, type
        self.content, self.link = content, link
class CreatorApplication(db.Model):
    """A user's application to become a creator, plus its review outcome."""
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'))
    status = db.Column(db.String(20), default='pending')  # pending, approved, rejected
    application_type = db.Column(db.String(20))  # analyst, strategist, trader, researcher
    description = db.Column(db.Text)  # application statement / self-introduction
    # Supporting material
    qualifications = db.Column(db.JSON)  # certificates, work experience, etc.
    sample_works = db.Column(db.JSON)  # links to sample works
    # Review information
    reviewer_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=True)  # reviewer
    review_time = db.Column(db.DateTime, nullable=True)  # review time
    review_notes = db.Column(db.Text, nullable=True)  # review notes
    # Timestamps
    created_at = db.Column(db.DateTime, default=beijing_now)
    updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now)
    # Relationships
    user = db.relationship('User', foreign_keys=[user_id], backref='creator_applications')
    reviewer = db.relationship('User', foreign_keys=[reviewer_id])

    def __init__(self, user_id, status='pending', application_type=None, description=None):
        """Create an application. The review fields are left unset."""
        initial_values = (
            ('user_id', user_id),
            ('status', status),
            ('application_type', application_type),
            ('description', description),
        )
        for attr_name, attr_value in initial_values:
            setattr(self, attr_name, attr_value)

    def _record_review(self, outcome, reviewer_id, notes):
        """Store the outcome, reviewer, current Beijing time and notes."""
        self.status = outcome
        self.reviewer_id = reviewer_id
        self.review_time = beijing_now()
        self.review_notes = notes

    def approve(self, reviewer_id, notes=None):
        """Approve the application and promote the applicant to creator.

        Does not commit; the caller owns the session.
        """
        self._record_review('approved', reviewer_id, notes)
        applicant = User.query.get(self.user_id)
        if applicant:
            applicant.is_creator = True
            applicant.creator_type = self.application_type

    def reject(self, reviewer_id, notes=None):
        """Reject the application (does not commit)."""
        self._record_review('rejected', reviewer_id, notes)

    def __repr__(self):
        return '<CreatorApplication {} - {}>'.format(self.user_id, self.status)
class Event(db.Model):
"""事件模型"""
id = db.Column(db.Integer, primary_key=True)
title = db.Column(db.String(200), nullable=False)
description = db.Column(db.Text)
# 事件类型与状态
event_type = db.Column(db.String(50))
status = db.Column(db.String(20), default='active')
# 时间相关
start_time = db.Column(db.DateTime, default=beijing_now)
end_time = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, default=beijing_now)
updated_at = db.Column(db.DateTime, default=beijing_now)
# 热度与统计
hot_score = db.Column(db.Float, default=0)
view_count = db.Column(db.Integer, default=0)
trending_score = db.Column(db.Float, default=0)
post_count = db.Column(db.Integer, default=0)
follower_count = db.Column(db.Integer, default=0)
# 关联信息
related_industries = db.Column(db.JSON)
keywords = db.Column(db.JSON)
files = db.Column(db.JSON)
importance = db.Column(db.String(20))
related_avg_chg = db.Column(db.Float, default=0)
related_max_chg = db.Column(db.Float, default=0)
related_week_chg = db.Column(db.Float, default=0)
# 新增字段
invest_score = db.Column(db.Integer) # 超预期得分
expectation_surprise_score = db.Column(db.Integer)
# 创建者信息
creator_id = db.Column(db.Integer, db.ForeignKey('user.id'))
creator = db.relationship('User', backref='created_events')
# 关系
posts = db.relationship('Post', backref='event', lazy='dynamic')
followers = db.relationship('EventFollow', backref='event', lazy='dynamic')
related_stocks = db.relationship('RelatedStock', backref='event', lazy='dynamic')
historical_events = db.relationship('HistoricalEvent', backref='event', lazy='dynamic')
related_data = db.relationship('RelatedData', backref='event', lazy='dynamic')
related_concepts = db.relationship('RelatedConcepts', backref='event', lazy='dynamic')
ind_type = db.Column(db.String(255))
@property
def keywords_list(self):
"""返回解析后的关键词列表"""
if not self.keywords:
return []
if isinstance(self.keywords, list):
return self.keywords
try:
# 如果是字符串尝试解析JSON
if isinstance(self.keywords, str):
decoded = json.loads(self.keywords)
# 处理Unicode编码的情况
if isinstance(decoded, list):
return [
keyword.encode('utf-8').decode('unicode_escape')
if isinstance(keyword, str) and '\\u' in keyword
else keyword
for keyword in decoded
]
return []
# 如果已经是字典或其他格式,尝试转换为列表
return list(self.keywords)
except (json.JSONDecodeError, AttributeError, TypeError):
return []
def set_keywords(self, keywords):
    """Store *keywords* on ``self.keywords`` as JSON text of a list.

    - A list is serialized as-is.
    - A string containing a JSON array is re-serialized as that array.
    - Any other string (invalid JSON, or JSON that is not an array) is stored
      as a one-element list holding the original string.
    - Values of any other type are ignored; ``self.keywords`` is left unchanged.

    Serialization uses ``ensure_ascii=False`` so non-ASCII text stays readable.
    """
    if isinstance(keywords, str):
        try:
            candidate = json.loads(keywords)
        except json.JSONDecodeError:
            candidate = None
        items = candidate if isinstance(candidate, list) else [keywords]
    elif isinstance(keywords, list):
        items = keywords
    else:
        return
    self.keywords = json.dumps(items, ensure_ascii=False)
class RelatedStock(db.Model):
"""Related-stock model: a stock linked to an event (via event_id)."""
id = db.Column(db.Integer, primary_key=True)
event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
stock_code = db.Column(db.String(20)) # stock code
stock_name = db.Column(db.String(100)) # stock name
sector = db.Column(db.String(100)) # labelled "relation type" in the original comment — NOTE(review): column is named `sector`; confirm meaning
relation_desc = db.Column(db.String(1024)) # description of why the stock is related
created_at = db.Column(db.DateTime, default=beijing_now)
updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now)
correlation = db.Column(db.Float())
momentum = db.Column(db.String(1024)) # momentum
class RelatedData(db.Model):
"""Related-data model: a titled data item attached to an event."""
id = db.Column(db.Integer, primary_key=True)
event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
title = db.Column(db.String(200)) # data title
data_type = db.Column(db.String(50)) # data type
data_content = db.Column(db.JSON) # data content (JSON)
description = db.Column(db.Text) # data description
created_at = db.Column(db.DateTime, default=beijing_now)
class RelatedConcepts(db.Model):
    """A concept linked to an event, with the linking reason and optional images."""
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
    concept_code = db.Column(db.String(20))
    concept = db.Column(db.String(100))
    reason = db.Column(db.Text)
    # JSON value (or JSON text) holding image entries: plain paths and/or
    # objects with a 'path' key; normalized by image_paths_list.
    image_paths = db.Column(db.JSON)
    created_at = db.Column(db.DateTime, default=beijing_now)

    @property
    def image_paths_list(self):
        """Return ``image_paths`` normalized to a list.

        JSON text is decoded first, and a single non-list value is wrapped in
        a list. Each dict entry that has a ``'path'`` key is replaced by that
        path; all other entries are returned unchanged. Returns ``[]`` when
        ``image_paths`` is empty or is unparseable JSON text (the parse error
        is logged).
        """
        if not self.image_paths:
            return []
        if isinstance(self.image_paths, str):
            try:
                paths = json.loads(self.image_paths)
            except ValueError as e:  # json.JSONDecodeError subclasses ValueError
                logging.getLogger(__name__).warning("Error processing image paths: %s", e)
                return []
        else:
            paths = self.image_paths
        if not isinstance(paths, list):
            paths = [paths]
        return [item['path'] if isinstance(item, dict) and 'path' in item else item
                for item in paths]

    def get_first_image_path(self):
        """Return the first entry of ``image_paths_list`` as stored, or None if there is none."""
        paths = self.image_paths_list
        return paths[0] if paths else None
class EventHotHistory(db.Model):
"""History record of an event's popularity ("heat") score and its components."""
id = db.Column(db.Integer, primary_key=True)
event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
score = db.Column(db.Float) # total score
interaction_score = db.Column(db.Float) # interaction score
follow_score = db.Column(db.Float) # follow/attention score
view_score = db.Column(db.Float) # view-count score
recent_activity_score = db.Column(db.Float) # recent-activity score
time_decay = db.Column(db.Float) # time-decay factor
created_at = db.Column(db.DateTime, default=beijing_now)
event = db.relationship('Event', backref='hot_history')
class Post(db.Model):
"""A user's post under an event."""
id = db.Column(db.Integer, primary_key=True)
event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False)
user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
# Content
title = db.Column(db.String(200)) # title (optional)
content = db.Column(db.Text, nullable=False) # body text
content_type = db.Column(db.String(20), default='text') # content type: text/rich_text/link
# Timestamps
created_at = db.Column(db.DateTime, default=beijing_now)
updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now)
# Counters
likes_count = db.Column(db.Integer, default=0)
comments_count = db.Column(db.Integer, default=0)
view_count = db.Column(db.Integer, default=0)
# Status
status = db.Column(db.String(20), default='active') # active/hidden/deleted
is_top = db.Column(db.Boolean, default=False) # pinned to the top
# Relationships
user = db.relationship('User', backref='posts')
likes = db.relationship('PostLike', backref='post', lazy='dynamic')
comments = db.relationship('Comment', backref='post', lazy='dynamic')
# Auxiliary models
class EventFollow(db.Model):
"""A user following an event."""
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
event_id = db.Column(db.Integer, db.ForeignKey('event.id'), nullable=False)
created_at = db.Column(db.DateTime, default=beijing_now)
user = db.relationship('User', backref='event_follows')
# At most one follow row per (user, event) pair.
__table_args__ = (db.UniqueConstraint('user_id', 'event_id'),)
class PostLike(db.Model):
"""A user's like on a post."""
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False)
created_at = db.Column(db.DateTime, default=beijing_now)
user = db.relationship('User', backref='post_likes')
# At most one like per (user, post) pair.
__table_args__ = (db.UniqueConstraint('user_id', 'post_id'),)
class Comment(db.Model):
"""A comment on a post; may reply to another comment via parent_id."""
id = db.Column(db.Integer, primary_key=True)
post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False)
user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
content = db.Column(db.Text, nullable=False)
parent_id = db.Column(db.Integer, db.ForeignKey('comment.id')) # parent comment ID, used for replies
created_at = db.Column(db.DateTime, default=beijing_now)
status = db.Column(db.String(20), default='active')
user = db.relationship('User', backref='comments')
# Self-referential: `replies` are child comments; backref `parent` is the comment replied to.
replies = db.relationship('Comment', backref=db.backref('parent', remote_side=[id]))
class StockBasicInfo(db.Model):
"""Read mapping of the `ea_stocklist` table (basic security listing info)."""
__tablename__ = 'ea_stocklist'
SECCODE = db.Column(db.String(10), primary_key=True)
SECNAME = db.Column(db.String(40))
ORGNAME = db.Column(db.String(100))
F001V = db.Column(db.String(100)) # Pinyin abbreviation
F003V = db.Column(db.String(50)) # Security category
F005V = db.Column(db.String(50)) # Trading market
F006D = db.Column(db.DateTime) # Listing date
F011V = db.Column(db.String(50)) # Listing status
class CompanyInfo(db.Model):
"""Mapping of the `ea_baseinfo` table (company profile information)."""
__tablename__ = 'ea_baseinfo'
SECCODE = db.Column(db.String(10), primary_key=True)
SECNAME = db.Column(db.String(40))
ORGNAME = db.Column(db.String(100))
F001V = db.Column(db.String(100)) # English name
F003V = db.Column(db.String(40)) # Legal representative
F015V = db.Column(db.String(500)) # Main business
F016V = db.Column(db.String(4000)) # Business scope
F017V = db.Column(db.String(2000)) # Company introduction
F030V = db.Column(db.String(60)) # CSRC industry first level
F032V = db.Column(db.String(60)) # CSRC industry second level
class TradeData(db.Model):
"""Mapping of the `ea_trade` table: one row per (SECCODE, TRADEDATE) daily bar."""
__tablename__ = 'ea_trade'
SECCODE = db.Column(db.String(10), primary_key=True)
SECNAME = db.Column(db.String(40))
TRADEDATE = db.Column(db.Date, primary_key=True)
F002N = db.Column(db.Numeric(18, 4)) # Previous close
F003N = db.Column(db.Numeric(18, 4)) # Open price
F004N = db.Column(db.Numeric(18, 4)) # Trading volume
F005N = db.Column(db.Numeric(18, 4)) # High price
F006N = db.Column(db.Numeric(18, 4)) # Low price
F007N = db.Column(db.Numeric(18, 4)) # Close price
F009N = db.Column(db.Numeric(18, 4)) # Change
F010N = db.Column(db.Numeric(18, 4)) # Change percentage
F011N = db.Column(db.Numeric(18, 4)) # Trading amount
class SectorInfo(db.Model):
"""Mapping of the `ea_sector` table: a security's sector under one classification standard."""
__tablename__ = 'ea_sector'
SECCODE = db.Column(db.String(10), primary_key=True)
SECNAME = db.Column(db.String(40))
F001V = db.Column(db.String(50), primary_key=True) # Classification standard code
F002V = db.Column(db.String(50)) # Classification standard
F003V = db.Column(db.String(50)) # Sector code
F004V = db.Column(db.String(50)) # Sector level 1 name
F005V = db.Column(db.String(50)) # Sector level 2 name
F006V = db.Column(db.String(50)) # Sector level 3 name
F007V = db.Column(db.String(50)) # Sector level 4 name
# Send a verification-code email.
def send_verification_email(email, purpose='register'):
    """Generate a 6-digit code, store it in the session, and email it to *email*.

    The code, the target address, the expiry (POSIX timestamp 10 minutes after
    ``beijing_now()``) and *purpose* are written to the Flask session for the
    later verification step. The message is handed to ``send_async_email``.

    :param email: recipient address
    :param purpose: why the code is sent. ``'register'`` (default) uses the
        registration wording; ``'change_email'`` (passed by ``update_email``)
        uses wording for re-binding an account's email.
    """
    import secrets  # CSPRNG: codes from `random` are predictable

    verification_code = ''.join(secrets.choice('0123456789') for _ in range(6))
    # Valid for 10 minutes.
    expiration = beijing_now().timestamp() + 600

    session['verification_code'] = verification_code
    session['verification_email'] = email
    session['verification_expiration'] = expiration
    session['verification_purpose'] = purpose

    subject = '价值前沿 - 邮箱验证'
    if purpose == 'change_email':
        intro = '您正在修改价值前沿账户的绑定邮箱,请使用以下验证码完成验证:'
    else:
        intro = '感谢您注册价值前沿!请使用以下验证码完成邮箱验证:'
    html = f'''
<div style="max-width: 600px; margin: 0 auto; padding: 20px; font-family: Arial, sans-serif;">
<h2 style="color: #333;">邮箱验证</h2>
<p>{intro}</p>
<div style="background: #f7f7f7; padding: 15px; text-align: center; font-size: 24px; font-weight: bold; letter-spacing: 5px; margin: 20px 0;">
{verification_code}
</div>
<p>验证码有效期为10分钟请尽快完成验证。</p>
<p>如果这不是您的操作,请忽略此邮件。</p>
<p style="color: #666; font-size: 13px; margin-top: 20px;">
此邮件由系统自动发送,请勿回复。
</p>
</div>
'''
    msg = Message(
        subject=subject,
        sender=app.config['MAIL_USERNAME'],
        recipients=[email],
        html=html
    )
    # Delivered asynchronously.
    send_async_email(msg)
# Public landing page; this route has no authentication decorator.
@app.route('/')
def index():
    """Serve the landing/presentation page."""
    landing_template = 'presentation.html'
    return render_template(landing_template)
@app.route('/profile')
@token_required
def profile():
    """Render the current user's profile page.

    Users missing an essential profile field (nickname, investment style or
    preferred markets) are redirected to the profile wizard. Otherwise the
    page shows their 5 newest active posts and 5 most recent event follows.
    """
    user = request.user
    # preferred_markets is saved as JSON text ('[]' when empty). An unset
    # value (None / '') also means the wizard was never completed.
    markets = user.preferred_markets
    if not user.nickname or not user.investment_style or not markets or markets == '[]':
        return redirect(url_for('profile_wizard'))
    posts = Post.query.filter_by(user_id=user.id, status='active') \
        .order_by(Post.created_at.desc()) \
        .limit(5).all()
    followed_events = EventFollow.query.filter_by(user_id=user.id) \
        .order_by(EventFollow.created_at.desc()) \
        .limit(5).all()
    return render_template('pages/account/profile.html',
                           user=user,
                           posts=posts,
                           followed_events=followed_events)
@app.route('/profile/wizard')
@token_required
def profile_wizard():
    """Show the multi-step profile setup wizard for the authenticated user."""
    current = request.user
    return render_template('pages/account/profile_wizard.html', user=current)
@app.route('/profile/wizard/save', methods=['POST'])
@token_required
def save_profile_wizard():
    """Persist profile-wizard fields sent as a JSON object.

    Only keys present in the body are updated. ``trading_experience`` is stored
    as an int (empty -> None). ``preferred_markets`` is stored as JSON text,
    with null/empty stored as '[]'. Responds with ``{'success': bool, 'message': str}``.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A missing/non-object body would otherwise fail on the `in` checks below.
        return jsonify({'success': False, 'message': '保存失败: 请求数据格式错误'})

    # Validate before touching the user so a bad value leaves nothing half-applied.
    trading_experience = None
    if 'trading_experience' in data:
        raw_experience = data.get('trading_experience')
        try:
            trading_experience = int(raw_experience) if raw_experience else None
        except (TypeError, ValueError):
            return jsonify({'success': False, 'message': '保存失败: 交易经验必须为整数'})

    try:
        user = request.user
        # Plain fields from steps 1-3 are copied through unchanged.
        for field in ('nickname', 'gender', 'bio', 'investment_style',
                      'investment_amount', 'risk_preference'):
            if field in data:
                setattr(user, field, data.get(field))
        if 'trading_experience' in data:
            user.trading_experience = trading_experience
        if 'preferred_markets' in data:
            # `or []` so an explicit null is saved as '[]' (what profile() checks), not 'null'.
            user.preferred_markets = json.dumps(data.get('preferred_markets') or [])
        db.session.commit()
        return jsonify({'success': True, 'message': '个人资料已成功更新'})
    except Exception as e:
        db.session.rollback()
        app.logger.exception(f"Error saving profile wizard data: {str(e)}")
        return jsonify({'success': False, 'message': f'保存失败: {str(e)}'})
@app.route('/settings', methods=['GET'])
@token_required
def settings():
    """Render the account settings page for the authenticated user."""
    account = request.user
    return render_template('pages/account/settings.html', user=account)
def allowed_file(filename, allowed_extensions=None):
    """Return True if *filename* has an extension in the allowed set.

    The extension is the text after the last '.', compared case-insensitively.
    Empty or ``None`` filenames and names without a '.' are rejected.

    :param filename: client-supplied file name (may be None)
    :param allowed_extensions: lowercase extensions to accept; defaults to the
        module-level ``ALLOWED_EXTENSIONS``
    """
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    if not filename or '.' not in filename:
        return False
    return filename.rsplit('.', 1)[1].lower() in allowed_extensions
# Profile-settings update view.
@app.route('/settings/profile', methods=['POST'])
@token_required
def update_profile():
    """Update the current user's basic profile fields and optional avatar.

    Overwrites nickname, bio, gender, birth_date (``YYYY-MM-DD``), phone,
    location and wechat_id from the form; a field absent from the form becomes
    None. An uploaded ``avatar`` with an allowed extension is shrunk to fit
    within 300x300 and saved under ``UPLOAD_FOLDER``. Responds with
    ``{'success': bool, 'message': str}``.
    """
    try:
        user = request.user
        form = request.form
        user.nickname = form.get('nickname')
        user.bio = form.get('bio')
        user.gender = form.get('gender')
        birth_date = form.get('birth_date')
        user.birth_date = datetime.strptime(birth_date, '%Y-%m-%d') if birth_date else None
        user.phone = form.get('phone')
        user.location = form.get('location')
        user.wechat_id = form.get('wechat_id')

        avatar = request.files.get('avatar')
        if avatar and allowed_file(avatar.filename):
            filename = secure_filename(f"{user.id}_{int(datetime.now().timestamp())}_{avatar.filename}")
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            # The context manager releases Pillow's handle on the image.
            with Image.open(avatar) as image:
                image.thumbnail((300, 300))
                image.save(filepath)
            # NOTE(review): the file is written to UPLOAD_FOLDER/<filename> but the
            # URL points at /static/uploads/avatars/<filename>, while other upload
            # handlers write into a subdirectory of UPLOAD_FOLDER — confirm they agree.
            user.avatar_url = f'{DOMAIN}/static/uploads/avatars/{filename}'

        db.session.commit()
        return jsonify({'success': True, 'message': '个人资料已更新'})
    except Exception as e:
        db.session.rollback()
        app.logger.exception(f"Error updating profile: {str(e)}")
        return jsonify({'success': False, 'message': '更新失败,请重试'})
def _is_valid_cn_resident_id(id_number):
    """Return True if *id_number* is a well-formed 18-character PRC resident ID.

    Checks the shape (17 ASCII digits then a digit or 'X') and the final check
    character (ISO 7064 MOD 11-2, as used by GB 11643-1999). Expects the value
    already stripped and upper-cased.
    """
    if not re.fullmatch(r'[0-9]{17}[0-9X]', id_number):
        return False
    weights = (7, 9, 10, 5, 8, 4, 2, 1, 6, 3, 7, 9, 10, 5, 8, 4, 2)
    total = sum(int(digit) * weight for digit, weight in zip(id_number, weights))
    return id_number[17] == '10X98765432'[total % 11]


@app.route('/settings/verify_identity', methods=['POST'])
@token_required
def verify_identity():
    """Handle real-name verification for the current user.

    Rejects users who are already verified. Requires ``real_name`` and a
    resident ID number that passes ``_is_valid_cn_resident_id``. On success it
    stores the name and a SHA-256 hex digest of the normalized ID, then marks
    the user verified with the current time.
    """
    try:
        user = request.user
        if user.is_verified:
            return jsonify({'success': False, 'message': '您已完成实名认证'})
        real_name = (request.form.get('real_name') or '').strip()
        # Normalize so 'x' and 'X' check characters hash identically.
        id_number = (request.form.get('id_number') or '').strip().upper()
        if not (real_name and _is_valid_cn_resident_id(id_number)):
            return jsonify({'success': False, 'message': '请输入有效的身份信息'})
        user.real_name = real_name
        # NOTE(review): an unsalted SHA-256 of an 18-digit ID is cheap to brute-force
        # (the birth date and region are embedded); consider a keyed HMAC instead.
        user.id_number = hashlib.sha256(id_number.encode()).hexdigest()
        user.is_verified = True
        user.verify_time = beijing_now()
        db.session.commit()
        return jsonify({'success': True, 'message': '实名认证成功'})
    except Exception as e:
        db.session.rollback()
        app.logger.exception(f"Error in identity verification: {str(e)}")
        return jsonify({'success': False, 'message': '认证失败,请重试'})
@app.route('/settings/password', methods=['POST'])
@token_required
def update_password():
    """Change the current user's password.

    Requires the correct ``current_password`` and a matching
    ``new_password``/``confirm_password`` pair of at least 8 characters.
    Missing fields are reported with the same messages instead of raising.
    """
    try:
        current_password = request.form.get('current_password')
        new_password = request.form.get('new_password')
        confirm_password = request.form.get('confirm_password')
        # Guard against None: check_password/len would otherwise raise.
        if not current_password or not request.user.check_password(current_password):
            return jsonify({'success': False, 'message': '当前密码错误'})
        if new_password != confirm_password:
            return jsonify({'success': False, 'message': '两次输入的密码不一致'})
        if not new_password or len(new_password) < 8:
            return jsonify({'success': False, 'message': '密码长度至少为8位'})
        request.user.set_password(new_password)
        db.session.commit()
        return jsonify({'success': True, 'message': '密码已更新'})
    except Exception as e:
        db.session.rollback()
        app.logger.exception(f"Error updating password: {str(e)}")
        return jsonify({'success': False, 'message': '密码更新失败,请重试'})
# Investment-preference settings
@app.route('/settings/investment_preferences', methods=['POST'])
@token_required
def update_investment_preferences():
    """Update the current user's investment preferences from the settings form.

    ``trading_experience`` is parsed to an int (empty -> None), the same way
    the profile wizard stores it. ``preferred_markets`` (a multi-value field)
    is stored as JSON text.
    """
    try:
        user = request.user
        form = request.form
        raw_experience = form.get('trading_experience')
        try:
            trading_experience = int(raw_experience) if raw_experience else None
        except ValueError:
            return jsonify({'success': False, 'message': '交易经验必须为整数'})
        user.trading_experience = trading_experience
        user.investment_style = form.get('investment_style')
        user.risk_preference = form.get('risk_preference')
        user.investment_amount = form.get('investment_amount')
        user.preferred_markets = json.dumps(form.getlist('preferred_markets'))
        db.session.commit()
        return jsonify({'success': True, 'message': '投资偏好已更新'})
    except Exception as e:
        db.session.rollback()
        app.logger.exception(f"Error updating investment preferences: {str(e)}")
        return jsonify({'success': False, 'message': '更新失败,请重试'})
@app.route('/settings/email', methods=['POST'])
@token_required
def update_email():
    """Start an email change: validate the address and send a verification code.

    Rejects blank/malformed addresses and addresses already used by an
    account. On success stores the pending address in ``session['new_email']``
    and redirects to ``verify_new_email``. Errors are flashed and redirect
    back to ``settings``.
    """
    try:
        new_email = (request.form.get('new_email') or '').strip()
        if not re.fullmatch(r'[^@\s]+@[^@\s]+\.[^@\s]+', new_email):
            flash('请输入有效的邮箱地址', 'error')
            return redirect(url_for('settings'))
        if User.query.filter_by(email=new_email).first():
            flash('该邮箱已被使用', 'error')
            return redirect(url_for('settings'))
        send_verification_email(new_email, purpose='change_email')
        session['new_email'] = new_email
        flash('验证邮件已发送到新邮箱,请查收', 'success')
        return redirect(url_for('verify_new_email'))
    except Exception as e:
        app.logger.exception(f"Error updating email: {str(e)}")
        flash('操作失败,请重试', 'error')
        return redirect(url_for('settings'))
@app.route('/settings/notifications', methods=['POST'])
@token_required
def update_notifications():
    """Save notification channel toggles and per-type notification preferences.

    A checkbox counts as enabled when its name is present in the submitted
    form. Per-type preferences are stored as JSON text.
    """
    try:
        user = request.user
        submitted = request.form
        for channel in ('email_notifications', 'sms_notifications', 'wechat_notifications'):
            setattr(user, channel, channel in submitted)
        preference_keys = ('notify_comments', 'notify_likes', 'notify_follows',
                           'notify_system', 'notify_news')
        user.notification_preferences = json.dumps(
            {key: key in submitted for key in preference_keys}
        )
        db.session.commit()
        return jsonify({'success': True, 'message': '通知设置已更新'})
    except Exception as e:
        db.session.rollback()
        app.logger.error(f"Error updating notifications: {str(e)}")
        return jsonify({'success': False, 'message': '设置更新失败,请重试'})
@app.route('/settings/privacy', methods=['POST'])
@token_required
def update_privacy():
    """Save the privacy level and the newline-separated blocked-keyword list.

    Blank lines are dropped and surrounding whitespace is stripped. The
    keyword list is stored as JSON text.
    """
    try:
        user = request.user
        user.privacy_level = request.form.get('privacy_level', 'public')
        blocked = []
        for line in request.form.get('blocked_keywords', '').split('\n'):
            cleaned = line.strip()
            if cleaned:
                blocked.append(cleaned)
        user.blocked_keywords = json.dumps(blocked)
        db.session.commit()
        return jsonify({'success': True, 'message': '隐私设置已更新'})
    except Exception as e:
        db.session.rollback()
        app.logger.error(f"Error updating privacy settings: {str(e)}")
        return jsonify({'success': False, 'message': '设置更新失败,请重试'})
@app.route('/settings/creator', methods=['POST'])
@token_required
def update_creator_settings():
    """Update creator type and tags; only allowed for users who are already creators."""
    try:
        user = request.user
        if user.is_creator:
            user.creator_type = request.form.get('creator_type')
            user.creator_tags = json.dumps(request.form.getlist('creator_tags'))
            db.session.commit()
            return jsonify({'success': True, 'message': '创作者设置已更新'})
        return jsonify({'success': False, 'message': '您还不是创作者'})
    except Exception as e:
        db.session.rollback()
        app.logger.error(f"Error updating creator settings: {str(e)}")
        return jsonify({'success': False, 'message': '设置更新失败,请重试'})
@app.route('/settings/apply_creator', methods=['POST'])
@token_required
def apply_creator():
    """Submit an application for the current user to become a creator.

    Preconditions (each failure returns ``success: False`` with a message):
    - not already a creator;
    - no pending application;
    - identity verified;
    - at least 5 posts.

    Requires the ``creator_type`` and ``description`` form fields. Optional:
    ``qualifications`` files (allowed extensions only) and a ``sample_works``
    JSON string. Admin notification and the confirmation email run after the
    application is committed and are best-effort: a failure there is logged
    but still reported as a successful submission, since the application exists.
    """
    try:
        user = request.user
        if user.is_creator:
            return jsonify({'success': False, 'message': '您已经是创作者了'})
        existing_application = CreatorApplication.query.filter_by(
            user_id=user.id,
            status='pending'
        ).first()
        if existing_application:
            return jsonify({'success': False, 'message': '您已有一个正在审核的申请'})
        if not user.is_verified:
            return jsonify({'success': False, 'message': '请先完成实名认证'})
        if (user.post_count or 0) < 5:
            return jsonify({'success': False, 'message': '需要至少发布5篇文章才能申请'})

        application_type = request.form.get('creator_type')
        description = request.form.get('description')
        if not all([application_type, description]):
            return jsonify({'success': False, 'message': '请填写完整的申请信息'})

        # Parse sample works before writing any files, so bad JSON leaves nothing on disk.
        raw_sample_works = request.form.get('sample_works')
        sample_works = None
        if raw_sample_works:
            try:
                sample_works = json.loads(raw_sample_works)
            except json.JSONDecodeError:
                return jsonify({'success': False, 'message': '作品案例格式错误'})

        qualifications = []
        for file in request.files.getlist('qualifications'):
            if file and allowed_file(file.filename):
                filename = secure_filename(f"qual_{user.id}_{int(time.time())}_{file.filename}")
                filepath = os.path.join(app.config['UPLOAD_FOLDER'], 'qualifications', filename)
                os.makedirs(os.path.dirname(filepath), exist_ok=True)
                file.save(filepath)
                qualifications.append({
                    'name': file.filename,
                    'path': f'/static/uploads/qualifications/{filename}'
                })

        application = CreatorApplication(
            user_id=user.id,
            status='pending',
            application_type=application_type,
            description=description
        )
        if qualifications:
            application.qualifications = qualifications
        if raw_sample_works:
            application.sample_works = sample_works
        db.session.add(application)
        db.session.commit()
    except Exception as e:
        db.session.rollback()
        app.logger.exception(f"Error in creator application: {str(e)}")
        return jsonify({'success': False, 'message': '申请提交失败,请重试'})

    # Post-commit side effects must not make a saved application look failed.
    try:
        notify_admins_new_application(application.id)
    except Exception as e:
        db.session.rollback()
        app.logger.exception(f"Error notifying admins of creator application {application.id}: {str(e)}")
    try:
        send_application_confirmation_email(user.email, application_type)
    except Exception as e:
        app.logger.exception(f"Error sending creator application confirmation: {str(e)}")

    return jsonify({
        'success': True,
        'message': '申请已提交我们会在3个工作日内审核并通知您结果'
    })
def notify_admins_new_application(application_id):
    """Notify every admin about a new creator application.

    Each admin gets an in-app Notification row. Admins with
    ``email_notifications`` enabled are also emailed. All notification rows
    are committed together at the end.
    """
    notice_text = f'有新的创作者申请需要审核 (ID: {application_id})'
    review_link = f'/admin/creator_applications/{application_id}'
    email_text = f'有新的创作者申请需要审核,请登录管理后台查看。\n申请ID: {application_id}'
    for admin in User.query.filter_by(is_admin=True).all():
        db.session.add(Notification(
            user_id=admin.id,
            type='new_creator_application',
            content=notice_text,
            link=review_link
        ))
        if admin.email_notifications:
            send_admin_notification_email(admin.email, '新的创作者申请待审核', email_text)
    db.session.commit()
def send_application_confirmation_email(email, application_type):
    """Email the applicant a confirmation that their creator application was received.

    Follows the same conventions as ``send_admin_notification_email``: the
    sender is ``MAIL_USERNAME``, and failures are logged and reported through
    the return value rather than raised.

    :returns: True once ``mail.send`` completes, False if any step raised.
    """
    try:
        html_content = render_template(
            'emails/creator_application_confirmation.html',
            application_type=application_type
        )
        msg = Message(
            '创作者申请确认',
            sender=app.config['MAIL_USERNAME'],
            recipients=[email],
            html=html_content
        )
        mail.send(msg)
        return True
    except Exception as e:
        app.logger.error(f"Error sending application confirmation email: {str(e)}")
        return False
def send_admin_notification_email(email, subject, content):
    """Email an administrator using the admin-notification template.

    :returns: True when ``mail.send`` completes, or False if any step raises
        (the error is logged through ``app.logger``).
    """
    try:
        message = Message(subject, sender=app.config['MAIL_USERNAME'], recipients=[email])
        message.html = render_template('emails/admin_notification.html',
                                       subject=subject, content=content)
        mail.send(message)
    except Exception as err:
        app.logger.error(f"Error sending admin notification email: {str(err)}")
        return False
    else:
        return True
@app.route('/api/stock/quotes', methods=['POST'])
def get_stock_quotes():
    """Return price and % change per stock over the event's trading window.

    Request JSON: ``{"codes": [...], "event_time": "<ISO-8601 datetime>"}``.
    The window comes from the ``trading_days`` table:
    - trading day, event before 09:30 -> that day, 09:30-15:00;
    - trading day, event during trading hours -> that day, event time to 15:00;
    - event after 15:00 or on a non-trading day -> next trading day, 09:30-15:00.

    ``change`` is the % move from the first to the last ``stock_minute`` close
    in the window. Every requested code is returned with ``name`` ('Unknown'
    if absent from ``ea_stocklist``), ``price`` and ``change``. Price and
    change are None when there is no usable trading day or no data. A missing
    or invalid body or ``event_time`` returns 400.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({'error': 'Invalid JSON body'}), 400
    codes = payload.get('codes') or []
    if not isinstance(codes, list):
        return jsonify({'error': 'codes must be a list'}), 400
    try:
        event_time = datetime.fromisoformat(payload.get('event_time'))
    except (TypeError, ValueError):
        return jsonify({'error': 'Invalid event_time format'}), 400
    current_time = datetime.now()

    # Stock names from MySQL; ea_stocklist keys on the code without exchange suffix.
    stock_names = {}
    with engine.connect() as conn:
        for code in codes:
            row = conn.execute(text(
                "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code"
            ), {"code": code.split('.')[0]}).fetchone()
            if row:
                stock_names[code] = row[0]

    def to_date(value):
        """EXCHANGE_DATE may come back as a datetime or a plain date; return a date."""
        return value.date() if isinstance(value, datetime) else value

    def next_trading_day(conn, day):
        """Return the first trading day strictly after *day*, or None."""
        row = conn.execute(text("""
            SELECT EXCHANGE_DATE FROM trading_days
            WHERE EXCHANGE_DATE > :date
            ORDER BY EXCHANGE_DATE LIMIT 1
        """), {"date": day}).fetchone()
        return to_date(row[0]) if row else None

    def get_trading_day_and_times(event_datetime):
        """Return (trading_day, window_start, window_end) for *event_datetime*."""
        event_date = event_datetime.date()
        event_clock = event_datetime.time()
        market_open = dt_time(9, 30)
        market_close = dt_time(15, 0)
        with engine.connect() as conn:
            is_trading_day = conn.execute(text("""
                SELECT 1 FROM trading_days
                WHERE EXCHANGE_DATE = :date
            """), {"date": event_date}).fetchone() is not None
            if not is_trading_day or event_clock > market_close:
                return next_trading_day(conn, event_date), market_open, market_close
            if event_clock < market_open:
                return event_date, market_open, market_close
            return event_date, event_clock, market_close

    def placeholder(code):
        return {'name': stock_names.get(code, 'Unknown'), 'price': None, 'change': None}

    trading_day, start_time, end_time = get_trading_day_and_times(event_time)
    # No upcoming trading day, or it is still in the future: names only, for every code.
    if not trading_day or trading_day > current_time.date():
        return jsonify({code: placeholder(code) for code in codes})

    start_datetime = datetime.combine(trading_day, start_time)
    end_datetime = datetime.combine(trading_day, end_time)

    results = {}
    client = get_clickhouse_client()
    try:
        for code in codes:
            # First and last minute close inside the window.
            data = client.execute("""
                WITH first_price AS (
                    SELECT close
                    FROM stock_minute
                    WHERE code = %(code)s
                    AND timestamp >= %(start)s
                    AND timestamp <= %(end)s
                    ORDER BY timestamp
                    LIMIT 1
                ),
                last_price AS (
                    SELECT close
                    FROM stock_minute
                    WHERE code = %(code)s
                    AND timestamp >= %(start)s
                    AND timestamp <= %(end)s
                    ORDER BY timestamp DESC
                    LIMIT 1
                )
                SELECT
                    last_price.close as last_price,
                    (last_price.close - first_price.close) / first_price.close * 100 as change
                FROM last_price
                CROSS JOIN first_price
                WHERE EXISTS (SELECT 1 FROM first_price) AND EXISTS (SELECT 1 FROM last_price)
            """, {
                'code': code,
                'start': start_datetime,
                'end': end_datetime
            })
            if data and data[0]:
                price, change = data[0][0], data[0][1]
                # A zero first close would yield inf/nan, which jsonify cannot emit as valid JSON.
                if change is not None and not math.isfinite(change):
                    change = None
                results[code] = {
                    'price': price,
                    'change': change,
                    'name': stock_names.get(code, 'Unknown')
                }
            else:
                results[code] = placeholder(code)
    finally:
        client.disconnect()
    return jsonify(results)
def get_clickhouse_client():
    """Create a ClickHouse client for the stock database.

    Connection settings can be overridden with the environment variables
    CLICKHOUSE_HOST, CLICKHOUSE_PORT, CLICKHOUSE_USER, CLICKHOUSE_PASSWORD and
    CLICKHOUSE_DATABASE. When unset, the previous built-in values are used so
    existing deployments keep working.
    """
    # NOTE(review): the fallback credentials are committed to source control;
    # rotate them and supply the values only via the environment.
    return Cclient(
        host=os.environ.get('CLICKHOUSE_HOST', '111.198.58.126'),
        port=int(os.environ.get('CLICKHOUSE_PORT', '18778')),
        user=os.environ.get('CLICKHOUSE_USER', 'default'),
        password=os.environ.get('CLICKHOUSE_PASSWORD', 'Zzl33818!'),
        database=os.environ.get('CLICKHOUSE_DATABASE', 'stock'),
    )
@app.route('/api/stock/<stock_code>/kline')
def get_stock_kline(stock_code):
    """Return K-line data for *stock_code*.

    Query params: ``chart_type`` (``daily`` by default, or ``minute``) and an
    optional ISO-8601 ``event_time`` (defaults to now). Returns 400 for a
    malformed ``event_time`` or an unsupported ``chart_type``.
    """
    chart_type = request.args.get('chart_type', 'daily')
    raw_event_time = request.args.get('event_time')
    try:
        if raw_event_time:
            event_datetime = datetime.fromisoformat(raw_event_time)
        else:
            event_datetime = datetime.now()
    except ValueError:
        return jsonify({'error': 'Invalid event_time format'}), 400

    # Display name; ea_stocklist keys on the code without the exchange suffix.
    stock_name = 'Unknown'
    try:
        with engine.connect() as conn:
            name_row = conn.execute(text(
                "SELECT SECNAME FROM ea_stocklist WHERE SECCODE = :code"
            ), {"code": stock_code.split('.')[0]}).fetchone()
        if name_row:
            stock_name = name_row[0]
    except Exception as e:
        print(f"Error getting stock name: {e}")
        stock_name = 'Unknown'

    renderers = {'daily': get_daily_kline, 'minute': get_minute_kline}
    renderer = renderers.get(chart_type)
    if renderer is None:
        return jsonify({
            'error': 'Invalid chart type',
            'message': 'Supported types: daily, minute',
            'code': stock_code,
            'name': stock_name
        }), 400
    return renderer(stock_code, event_datetime, stock_name)
def get_daily_kline(stock_code, event_datetime, stock_name):
    """Return daily bars for the 60 calendar days up to the event date.

    *stock_code* may carry an exchange suffix (e.g. ``000001.SZ``); only the
    part before the dot is matched against ``ea_trade``. Rows come back newest
    first. If that window is empty, falls back to the latest 100 complete rows
    on or before the event date. A database error yields a 500 JSON response.
    """
    stock_code = stock_code.split('.')[0]
    app.logger.debug("get_daily_kline: stock_code=%s, event_datetime=%s, stock_name=%s",
                     stock_code, event_datetime, stock_name)
    try:
        trade_date = event_datetime.date()
        params = {"stock_code": stock_code, "trade_date": trade_date}
        with engine.connect() as conn:
            # Plain range filter for the window (no CTE needed to pick the dates).
            kline_sql = """
                SELECT TRADEDATE,
                       CAST(F003N AS FLOAT) as open,
                       CAST(F007N AS FLOAT) as close,
                       CAST(F005N AS FLOAT) as high,
                       CAST(F006N AS FLOAT) as low,
                       CAST(F004N AS FLOAT) as volume
                FROM ea_trade
                WHERE SECCODE = :stock_code
                  AND TRADEDATE BETWEEN DATE_SUB(:trade_date, INTERVAL 60 DAY) AND :trade_date
                ORDER BY TRADEDATE DESC
            """
            result = conn.execute(text(kline_sql), params).fetchall()
            app.logger.debug("get_daily_kline: window rows=%d", len(result))
            if not result:
                # Nothing in the window: use the most recent complete bars instead.
                fallback_sql = """
                    SELECT TRADEDATE,
                           CAST(F003N AS FLOAT) as open,
                           CAST(F007N AS FLOAT) as close,
                           CAST(F005N AS FLOAT) as high,
                           CAST(F006N AS FLOAT) as low,
                           CAST(F004N AS FLOAT) as volume
                    FROM ea_trade
                    WHERE SECCODE = :stock_code
                      AND TRADEDATE <= :trade_date
                      AND F003N IS NOT NULL
                      AND F007N IS NOT NULL
                      AND F005N IS NOT NULL
                      AND F006N IS NOT NULL
                      AND F004N IS NOT NULL
                    ORDER BY TRADEDATE DESC
                    LIMIT 100
                """
                result = conn.execute(text(fallback_sql), params).fetchall()
                app.logger.debug("get_daily_kline: fallback rows=%d", len(result))

        if not result:
            return jsonify({
                'error': 'No data available',
                'code': stock_code,
                'name': stock_name,
                'data': [],
                'trade_date': trade_date.strftime('%Y-%m-%d'),
                'type': 'daily'
            })

        kline_data = []
        for row in result:
            try:
                kline_data.append({
                    'time': row.TRADEDATE.strftime('%Y-%m-%d'),
                    'open': float(row.open) if row.open else 0,
                    'high': float(row.high) if row.high else 0,
                    'low': float(row.low) if row.low else 0,
                    'close': float(row.close) if row.close else 0,
                    'volume': float(row.volume) if row.volume else 0
                })
            except (ValueError, TypeError) as e:
                app.logger.debug("get_daily_kline: skipping row: %s", e)
        app.logger.debug("get_daily_kline: final rows=%d", len(kline_data))
        return jsonify({
            'code': stock_code,
            'name': stock_name,
            'data': kline_data,
            'trade_date': trade_date.strftime('%Y-%m-%d'),
            'type': 'daily',
            'is_history': True,
            'data_count': len(kline_data)
        })
    except Exception as e:
        app.logger.exception("Error in get_daily_kline: %s", e)
        return jsonify({
            'error': f'Database error: {str(e)}',
            'code': stock_code,
            'name': stock_name,
            'data': [],
            'trade_date': event_datetime.date().strftime('%Y-%m-%d'),
            'type': 'daily'
        }), 500
def get_minute_kline(stock_code, event_datetime, stock_name):
    """Return 1-minute bars (09:30-15:00) for the trading day relevant to the event.

    Trading days are read from ``tdays.csv`` (column ``DateTime``, format
    ``%Y/%m/%d``). The target day is chosen as follows:
    - event on a trading day after 15:00 -> the next trading day (if any);
    - event on a non-trading day -> the next trading day, else the latest earlier one;
    - otherwise -> the event date itself.

    ``is_history`` is True when the chosen day precedes the event date.
    """
    import bisect

    # Sorted list for bisect lookups, plus a set for O(1) membership tests.
    with open('tdays.csv', 'r', newline='') as f:
        calendar = sorted({
            datetime.strptime(row['DateTime'], '%Y/%m/%d').date()
            for row in csv.DictReader(f)
        })
    trading_days = set(calendar)

    def find_next_trading_day(current_date):
        """Return the first trading day strictly after *current_date*, or None."""
        idx = bisect.bisect_right(calendar, current_date)
        return calendar[idx] if idx < len(calendar) else None

    def find_prev_trading_day(current_date):
        """Return the last trading day strictly before *current_date*, or None."""
        idx = bisect.bisect_left(calendar, current_date)
        return calendar[idx - 1] if idx > 0 else None

    target_date = event_datetime.date()
    is_after_market = event_datetime.time() > dt_time(15, 0)
    if target_date in trading_days and is_after_market:
        # Market already closed: show the next session when one is known.
        target_date = find_next_trading_day(target_date) or target_date
    elif target_date not in trading_days:
        target_date = find_next_trading_day(target_date) or find_prev_trading_day(target_date)

    if not target_date:
        return jsonify({
            'error': 'No data available',
            'code': stock_code,
            'name': stock_name,
            'data': [],
            'trade_date': event_datetime.date().strftime('%Y-%m-%d'),
            'type': 'minute'
        })

    client = get_clickhouse_client()
    try:
        data = client.execute("""
            SELECT
                timestamp,
                open,
                high,
                low,
                close,
                volume,
                amt
            FROM stock_minute
            WHERE code = %(code)s
            AND timestamp BETWEEN %(start)s AND %(end)s
            ORDER BY timestamp
        """, {
            'code': stock_code,
            'start': datetime.combine(target_date, dt_time(9, 30)),
            'end': datetime.combine(target_date, dt_time(15, 0))
        })
    finally:
        client.disconnect()

    kline_data = [{
        'time': row[0].strftime('%H:%M'),
        'open': float(row[1]),
        'high': float(row[2]),
        'low': float(row[3]),
        'close': float(row[4]),
        'volume': float(row[5]),
        'amount': float(row[6])
    } for row in data]
    return jsonify({
        'code': stock_code,
        'name': stock_name,
        'data': kline_data,
        'trade_date': target_date.strftime('%Y-%m-%d'),
        'type': 'minute',
        'is_history': target_date < event_datetime.date()
    })
@app.route('/api/related-stock/add', methods=['POST'])
@login_required
def add_related_stock():
    """Attach a stock to an event as a RelatedStock row.

    JSON body: ``event_id`` and ``stock_code`` are required; ``relation_desc``
    is optional. The code must have at least one row in ClickHouse
    ``stock_minute``. Responds with ``{'success': bool[, 'message': str]}``.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'message': '请求数据格式错误'})
    event_id = data.get('event_id')
    stock_code = data.get('stock_code')
    relation_desc = data.get('relation_desc')
    if not event_id or not stock_code:
        return jsonify({'success': False, 'message': '缺少事件ID或股票代码'})

    # Check the stock exists; a ClickHouse failure is reported rather than escaping as a 500.
    client = get_clickhouse_client()
    try:
        stock_exists = client.execute(
            "SELECT 1 FROM stock_minute WHERE code = %(code)s LIMIT 1",
            {'code': stock_code}
        )
    except Exception as e:
        app.logger.exception(f"Error checking stock code {stock_code}: {str(e)}")
        return jsonify({'success': False, 'message': str(e)})
    finally:
        client.disconnect()
    if not stock_exists:
        return jsonify({
            'success': False,
            'message': '股票代码不存在'
        })

    try:
        related_stock = RelatedStock(
            event_id=event_id,
            stock_code=stock_code,
            relation_desc=relation_desc
        )
        db.session.add(related_stock)
        db.session.commit()
        return jsonify({'success': True})
    except Exception as e:
        db.session.rollback()
        return jsonify({
            'success': False,
            'message': str(e)
        })
@app.route('/api/related-stock/<int:stock_id>', methods=['DELETE'])
@login_required
def delete_related_stock(stock_id):
    """Delete the RelatedStock row *stock_id*.

    Returns 404 if the row does not exist. Otherwise responds with
    ``{'success': bool[, 'message': str]}``.
    """
    # Look up outside the try: get_or_404 aborts with an HTTPException (an
    # Exception subclass), which the broad handler below would turn into a 200.
    stock = RelatedStock.query.get_or_404(stock_id)
    # NOTE(review): any logged-in user can delete any event's related stock;
    # consider an ownership check (e.g. stock.event.creator_id) — confirm policy.
    try:
        db.session.delete(stock)
        db.session.commit()
        return jsonify({'success': True})
    except Exception as e:
        db.session.rollback()
        return jsonify({
            'success': False,
            'message': str(e)
        })
# Event-related routes
@app.route('/event/create', methods=['GET', 'POST'])
@token_required
def create_event():
    """Render the create-event form (GET) or create an event from it (POST).

    Required form fields: ``title``, ``description`` and ``event_type``.
    Optional fields:
    - ``is_top`` (any non-empty value);
    - repeated ``keywords``;
    - repeated ``related_stocks`` (stock codes);
    - repeated ``related_industries``.

    Blank values are dropped. Each related stock code becomes a RelatedStock
    row linked to the new event. POST responds with JSON.
    """
    if request.method != 'POST':
        return render_template('projects/create_event.html')

    def form_list(name):
        """Non-blank, whitespace-stripped values of a repeated form field."""
        return [value.strip() for value in request.form.getlist(name) if value and value.strip()]

    try:
        app.logger.info("Received event creation request")
        title = request.form.get('title')
        description = request.form.get('description')
        event_type = request.form.get('event_type')
        if not all([title, description, event_type]):
            return jsonify({
                'success': False,
                'message': '请填写所有必填字段'
            }), 400

        event = Event(
            title=title,
            description=description,
            event_type=event_type,
            creator_id=request.user.id,
            status='active'
        )
        if request.form.get('is_top'):
            event.is_top = True
        keywords = form_list('keywords')
        if keywords:
            # Stored as JSON text with ensure_ascii=False by the model helper.
            event.set_keywords(keywords)
        industries = form_list('related_industries')
        if industries:
            event.related_industries = json.dumps(industries)
        db.session.add(event)

        # related_stocks is a relationship to RelatedStock, not a column;
        # assigning a JSON string to it fails, so create one linked row per code.
        for stock_code in form_list('related_stocks'):
            db.session.add(RelatedStock(event=event, stock_code=stock_code))

        db.session.commit()
        app.logger.info(f"Event created successfully with ID: {event.id}")
        return jsonify({
            'success': True,
            'event_id': event.id,
            'message': '话题创建成功!'
        })
    except Exception as e:
        db.session.rollback()
        app.logger.error(f"Error creating event: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'创建失败:{str(e)}'
        }), 500
@app.route('/file-upload', methods=['POST'])
@token_required
def upload_file():
    """Store an uploaded event attachment under UPLOAD_FOLDER/events.

    The saved name is ``<YYYYmmdd_HHMMSS>_<8 hex chars>_<sanitized name>``.
    The random token prevents two uploads with the same name in the same
    second from overwriting each other.

    Returns JSON with the stored file name and its static URL, 400 for a
    missing or disallowed file, and 500 on unexpected errors.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No selected file'}), 400
        if file and allowed_file(file.filename):
            stem, ext = os.path.splitext(secure_filename(file.filename))
            if not ext:
                # secure_filename keeps only ASCII, so e.g. "报告.pdf" collapses
                # to "pdf" with no extension; restore the sanitized original one.
                raw_ext = secure_filename(os.path.splitext(file.filename)[1].lstrip('.'))
                ext = f".{raw_ext}" if raw_ext else ''
            unique_filename = (
                f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}_{stem}{ext}"
            )
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], 'events', unique_filename)
            # Make sure the upload directory exists.
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            file.save(filepath)
            return jsonify({
                'success': True,
                'file': unique_filename,
                'url': url_for('static', filename=f'uploads/events/{unique_filename}')
            })
        return jsonify({'error': 'File type not allowed'}), 400
    except Exception as e:
        app.logger.error(f"Error uploading file: {str(e)}")
        return jsonify({'error': 'Upload failed'}), 500
@app.route('/data/concepts/<path:filename>')
def concept_images(filename):
    """Serve a file from ``<app root>/data/concepts``.

    send_from_directory joins ``filename`` safely and refuses paths that
    escape the directory.
    """
    concepts_dir = os.path.join(app.root_path, 'data', 'concepts')
    return send_from_directory(concepts_dir, filename)
class HistoricalEvent(db.Model):
    """Historical event model.

    A past event optionally linked to a row of the ``event`` table through
    ``event_id``. Related stocks are stored both in the legacy ``related_stock``
    JSON column and as ``HistoricalEventStock`` rows (``stocks`` relationship).
    """
    id = db.Column(db.Integer, primary_key=True)
    event_id = db.Column(db.Integer, db.ForeignKey('event.id'))
    title = db.Column(db.String(200))
    content = db.Column(db.Text)
    event_date = db.Column(db.DateTime)
    relevance = db.Column(db.Integer)  # relevance score (scale not defined here — TODO confirm)
    importance = db.Column(db.Integer)  # importance level
    related_stock = db.Column(db.JSON)  # legacy JSON field, kept alongside `stocks`
    created_at = db.Column(db.DateTime, default=beijing_now)
    # One-to-many link to HistoricalEventStock. lazy='dynamic' makes `.stocks`
    # a query (callers use `.stocks.all()`); 'delete-orphan' deletes stock rows
    # together with their event or when they are detached from it.
    stocks = db.relationship('HistoricalEventStock', backref='historical_event', lazy='dynamic',
                             cascade='all, delete-orphan')
class HistoricalEventStock(db.Model):
    """A stock related to a HistoricalEvent (table ``historical_event_stocks``).

    Each stock code may appear at most once per historical event
    (``unique_event_stock`` constraint).
    """
    __tablename__ = 'historical_event_stocks'
    id = db.Column(db.Integer, primary_key=True)
    # NOTE(review): 'historical_event' relies on Flask-SQLAlchemy's default
    # table name for HistoricalEvent — confirm if that model ever sets __tablename__.
    historical_event_id = db.Column(db.Integer, db.ForeignKey('historical_event.id'), nullable=False)
    stock_code = db.Column(db.String(20), nullable=False)
    stock_name = db.Column(db.String(50))
    relation_desc = db.Column(db.Text)  # free-text description of why the stock is related
    correlation = db.Column(db.Float, default=0.5)  # defaults to 0.5 when not provided
    sector = db.Column(db.String(100))
    created_at = db.Column(db.DateTime, default=beijing_now)
    __table_args__ = (
        db.UniqueConstraint('historical_event_id', 'stock_code', name='unique_event_stock'),
    )
@app.route('/api/historical-event/<int:event_id>/stocks')
def get_historical_event_stocks(event_id):
    """List the stocks linked to a historical event (404 if the event is missing)."""

    def to_json(item):
        # Public fields of one HistoricalEventStock row.
        return {
            'id': item.id,
            'stock_code': item.stock_code,
            'stock_name': item.stock_name,
            'relation_desc': item.relation_desc,
            'correlation': item.correlation,
            'sector': item.sector
        }

    linked_stocks = HistoricalEvent.query.get_or_404(event_id).stocks.all()
    return jsonify({'success': True, 'stocks': list(map(to_json, linked_stocks))})
@app.route('/api/related-data/<int:data_id>')
def get_related_data_details(data_id):
    """Return one RelatedData record as JSON (404 if it does not exist)."""
    record = RelatedData.query.get_or_404(data_id)
    exposed_fields = ('id', 'title', 'data_type', 'data_content', 'description')
    return jsonify({field: getattr(record, field) for field in exposed_fields})
@app.route('/event/follow/<int:event_id>', methods=['POST'])
@token_required
def follow_event(event_id):
    """Toggle the current user's follow on an event.

    Deletes an existing EventFollow row, or creates one, and adjusts
    ``Event.follower_count`` to match. A NULL count is treated as 0, and the
    count never goes below 0.

    Returns ``{'success': True, 'message': ...}``, or
    ``{'success': False, ...}`` after a rollback on failure.
    """
    event = Event.query.get_or_404(event_id)
    user_id = request.user.id
    follow = EventFollow.query.filter_by(
        user_id=user_id,
        event_id=event_id
    ).first()
    current_count = event.follower_count or 0
    try:
        if follow:
            db.session.delete(follow)
            event.follower_count = max(current_count - 1, 0)
            message = '已取消关注'
        else:
            db.session.add(EventFollow(user_id=user_id, event_id=event_id))
            event.follower_count = current_count + 1
            message = '已关注'
        db.session.commit()
        return jsonify({'success': True, 'message': message})
    except Exception as e:
        db.session.rollback()
        app.logger.error(f"Error toggling follow for event {event_id}: {str(e)}")
        return jsonify({'success': False, 'message': '操作失败,请重试'})
# 帖子相关路由
@app.route('/post/create/<int:event_id>', methods=['GET', 'POST'])
@token_required
def create_post(event_id):
    """Create a post under an event.

    GET renders the post form. POST accepts either form data or a JSON body
    with ``content`` (required, non-empty), ``title`` and ``content_type``
    (default ``'text'``), then increments ``Event.post_count``.

    API callers (Accept or Content-Type ``application/json``) get JSON back,
    with 400 on failure. Browser form posts get a flash message and a
    redirect on success; on failure the form is re-rendered.
    """
    event = Event.query.get_or_404(event_id)
    if request.method == 'POST':
        wants_json = (request.headers.get('Accept') == 'application/json' or
                      request.headers.get('Content-Type', '').startswith('application/json'))
        # JSON clients send fields in the body; request.form is empty for them.
        payload = (request.get_json(silent=True) or {}) if request.is_json else request.form
        try:
            content = payload.get('content')
            if not content:
                raise ValueError('post content is required')
            post = Post(
                event_id=event_id,
                user_id=request.user.id,
                title=payload.get('title'),
                content=content,
                content_type=payload.get('content_type', 'text')
            )
            db.session.add(post)
            event.post_count = (event.post_count or 0) + 1
            db.session.commit()
            if wants_json:
                return jsonify({
                    'success': True,
                    'message': '发布成功',
                    'data': {
                        'post_id': post.id,
                        'event_id': event_id,
                        'redirect_url': url_for('event_detail', event_id=event_id)
                    }
                })
            # Classic form submission: flash and redirect back to the event.
            flash('发布成功', 'success')
            return redirect(url_for('event_detail', event_id=event_id))
        except Exception as e:
            db.session.rollback()
            # Log on both paths (the JSON path used to fail silently).
            app.logger.error(f"Error creating post: {str(e)}")
            if wants_json:
                return jsonify({
                    'success': False,
                    'message': '发布失败,请重试'
                }), 400
            flash('发布失败,请重试', 'error')
    return render_template('projects/create_post.html', event=event)
# 点赞相关路由
@app.route('/post/like/<int:post_id>', methods=['POST'])
@token_required
def like_post(post_id):
    """Toggle the current user's like on a post.

    Adjusts ``Post.likes_count`` (a NULL count is treated as 0 and the count
    never goes below 0). When liking someone else's post, the author is
    notified.

    Returns ``{'success', 'message', 'likes_count'}``, or a failure JSON
    after a rollback.
    """
    post = Post.query.get_or_404(post_id)
    user_id = request.user.id
    like = PostLike.query.filter_by(
        user_id=user_id,
        post_id=post_id
    ).first()
    current_count = post.likes_count or 0
    try:
        if like:
            # Remove the like.
            db.session.delete(like)
            post.likes_count = max(current_count - 1, 0)
            message = '已取消点赞'
        else:
            # Add a like.
            db.session.add(PostLike(user_id=user_id, post_id=post_id))
            post.likes_count = current_count + 1
            message = '已点赞'
            # Notify the author, but not for self-likes.
            if post.user_id != user_id:
                notify_user_post_liked(post)
        db.session.commit()
        return jsonify({
            'success': True,
            'message': message,
            'likes_count': post.likes_count
        })
    except Exception as e:
        db.session.rollback()
        app.logger.error(f"Error toggling like for post {post_id}: {str(e)}")
        return jsonify({'success': False, 'message': '操作失败,请重试'})
# 通知相关函数
def notify_user_post_liked(post):
    """Queue a 'post_like' notification for the author of ``post``.

    The liker is ``request.user``. The Notification is only added to the
    session; the caller commits. If the author has email notifications
    enabled, an email is also sent. Failures are logged and swallowed so
    liking still succeeds.
    """
    try:
        liker = request.user
        db.session.add(Notification(
            user_id=post.user_id,
            type='post_like',
            content=f'{liker.username} 点赞了你的帖子',
            link=url_for('event_detail', event_id=post.event_id, _anchor=f'post-{post.id}'),
            related_user_id=liker.id,
            related_post_id=post.id
        ))
        author = User.query.get(post.user_id)
        if not author.email_notifications:
            return
        send_notification_email(
            recipient=author.email,
            subject='你的帖子收到了新的点赞',
            template='emails/notification_post_liked.html',
            user=author,
            post=post,
            liker=request.user
        )
    except Exception as e:
        app.logger.error(f"Error creating like notification: {str(e)}")
def notify_user_post_commented(post):
    """Queue a 'post_comment' notification for the author of ``post``.

    The commenter is ``request.user``. The Notification is added to the
    session for the caller to commit. An email is sent when the author opted
    in. Any failure is logged and swallowed.
    """
    try:
        commenter = request.user
        db.session.add(Notification(
            user_id=post.user_id,
            type='post_comment',
            content=f'{commenter.username} 评论了你的帖子',
            link=url_for('event_detail', event_id=post.event_id, _anchor=f'post-{post.id}'),
            related_user_id=commenter.id,
            related_post_id=post.id
        ))
        author = User.query.get(post.user_id)
        if not author.email_notifications:
            return
        send_notification_email(
            recipient=author.email,
            subject='你的帖子收到了新的评论',
            template='emails/notification_post_commented.html',
            user=author,
            post=post,
            commenter=request.user
        )
    except Exception as e:
        app.logger.error(f"Error creating comment notification: {str(e)}")
def notify_user_comment_replied(parent_comment):
    """Queue a 'comment_reply' notification for the author of ``parent_comment``.

    The replier is ``request.user``. The link anchors to the replied-to
    comment on its event page. The Notification is added to the session for
    the caller to commit; an optional email follows. Errors are logged, not
    raised.
    """
    try:
        replier = request.user
        target_link = url_for('event_detail',
                              event_id=parent_comment.post.event_id,
                              _anchor=f'comment-{parent_comment.id}')
        db.session.add(Notification(
            user_id=parent_comment.user_id,
            type='comment_reply',
            content=f'{replier.username} 回复了你的评论',
            link=target_link,
            related_user_id=replier.id,
            related_post_id=parent_comment.post_id,
            related_comment_id=parent_comment.id
        ))
        author = User.query.get(parent_comment.user_id)
        if not author.email_notifications:
            return
        send_notification_email(
            recipient=author.email,
            subject='你的评论收到了新的回复',
            template='emails/notification_comment_replied.html',
            user=author,
            comment=parent_comment,
            replier=request.user
        )
    except Exception as e:
        app.logger.error(f"Error creating reply notification: {str(e)}")
def cleanup_old_notifications():
    """Bulk-delete read notifications older than 30 days.

    Runs in its own app context so it can be called outside a request. On
    failure the session is rolled back and the error is logged, not raised.
    """
    with app.app_context():
        try:
            cutoff = beijing_now() - timedelta(days=30)
            stale = Notification.query.filter(
                Notification.created_at < cutoff,
                Notification.is_read == True  # noqa: E712 - SQLAlchemy column expression
            )
            stale.delete()
            db.session.commit()
            current_app.logger.info("Successfully cleaned up old notifications")
        except Exception as e:
            db.session.rollback()
            current_app.logger.error(f"Error cleaning up notifications: {str(e)}")
def update_user_activity():
    """Recompute activity scores for users active in the last 7 days.

    Score = posts * 2 + comments * 1 + likes * 0.5, each counted only over
    the last 7 days. ``last_active`` is set to the current Beijing time for
    every scored user. Runs in its own app context; errors are rolled back
    and logged, not raised.
    """
    with app.app_context():
        try:
            seven_days_ago = beijing_now() - timedelta(days=7)

            def recent_counts(model):
                # {user_id: rows of `model` the user created inside the window}.
                # Counting each table separately avoids the row multiplication
                # that joining three one-to-many tables at once produced.
                rows = db.session.query(model.user_id, db.func.count(model.id)) \
                    .filter(model.created_at >= seven_days_ago,
                            model.user_id.isnot(None)) \
                    .group_by(model.user_id).all()
                return {user_id: count for user_id, count in rows}

            post_counts = recent_counts(Post)
            comment_counts = recent_counts(Comment)
            like_counts = recent_counts(PostLike)

            now = beijing_now()
            for user_id in set(post_counts) | set(comment_counts) | set(like_counts):
                activity_score = (post_counts.get(user_id, 0) * 2
                                  + comment_counts.get(user_id, 0) * 1
                                  + like_counts.get(user_id, 0) * 0.5)
                User.query.filter_by(id=user_id).update({
                    'activity_score': activity_score,
                    'last_active': now
                })
            db.session.commit()
            current_app.logger.info("Successfully updated user activity scores")
        except Exception as e:
            db.session.rollback()
            current_app.logger.error(f"Error updating user activity: {str(e)}")
@app.route('/post/comment/<int:post_id>', methods=['POST'])
@token_required
def add_comment(post_id):
    """Add a comment, or a reply when ``parent_id`` is given, to a post.

    Form fields: ``content`` (required, not blank) and ``parent_id``
    (optional; must be a comment on this same post). Increments
    ``Post.comments_count`` and notifies the replied-to comment's author, or
    the post author for top-level comments, unless they are the current user.

    Returns ``{'success': True, 'comment': {...}}``, or
    ``{'success': False, 'message': ...}``.
    """
    post = Post.query.get_or_404(post_id)
    try:
        content = request.form.get('content')
        parent_id = request.form.get('parent_id', type=int)
        if not content or not content.strip():
            return jsonify({'success': False, 'message': '评论内容不能为空'})
        parent_comment = None
        if parent_id:
            # Replies must target an existing comment on this post.
            parent_comment = Comment.query.get(parent_id)
            if parent_comment is None or parent_comment.post_id != post_id:
                return jsonify({'success': False, 'message': '回复的评论不存在'})
        comment = Comment(
            post_id=post_id,
            user_id=request.user.id,
            content=content,
            parent_id=parent_id
        )
        db.session.add(comment)
        post.comments_count = (post.comments_count or 0) + 1
        if parent_comment is not None:
            # Reply: notify the parent comment's author (not for self-replies).
            if parent_comment.user_id != request.user.id:
                notify_user_comment_replied(parent_comment)
        elif post.user_id != request.user.id:
            # Top-level comment: notify the post author.
            notify_user_post_commented(post)
        db.session.commit()
        return jsonify({
            'success': True,
            'message': '评论成功',
            'comment': {
                'id': comment.id,
                'content': comment.content,
                'user_name': request.user.username,
                'user_avatar': get_full_avatar_url(request.user.avatar_url),
                'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S')
            }
        })
    except Exception as e:
        db.session.rollback()
        app.logger.error(f"Error adding comment to post {post_id}: {str(e)}")
        return jsonify({'success': False, 'message': '评论失败,请重试'})
@app.route('/post/comments/<int:post_id>')
@token_required
def get_comments(post_id):
    """Page through a post's active top-level comments, newest first.

    Query arg ``page`` (default 1) selects a page of 20. Each comment carries
    up to 3 of its oldest active replies under ``replies``.
    """
    page = request.args.get('page', 1, type=int)

    def as_dict(item):
        # Shared shape for comments and replies (replies carry no 'replies').
        author = item.user
        return {
            'id': item.id,
            'content': item.content,
            'user': {
                'id': author.id,
                'username': author.username,
                'avatar_url': get_full_avatar_url(author.avatar_url),
            },
            'created_at': item.created_at.strftime('%Y-%m-%d %H:%M:%S')
        }

    top_level = Comment.query.filter_by(
        post_id=post_id, parent_id=None, status='active'
    ).order_by(Comment.created_at.desc()).paginate(page=page, per_page=20)

    comments_data = []
    for parent in top_level.items:
        preview = Comment.query.filter_by(
            parent_id=parent.id, status='active'
        ).order_by(Comment.created_at.asc()).limit(3).all()
        entry = as_dict(parent)
        entry['replies'] = [as_dict(reply) for reply in preview]
        comments_data.append(entry)

    return jsonify({
        'comments': comments_data,
        'total': top_level.total,
        'pages': top_level.pages,
        'current_page': top_level.page
    })
# Beijing timezone; used to localize naive datetimes (e.g. Event.created_at)
# so they can be subtracted from beijing_now() in the hot-score code.
beijing_tz = pytz.timezone('Asia/Shanghai')
def update_hot_scores():
    """
    Recompute hot_score / trending_score for every active event and record
    an EventHotHistory row per event.

    Score = (interaction + follow + view + recent-activity) * exp(-hours / 72).
    Runs in its own Flask app context. Everything is committed in one
    transaction; on error the session is rolled back, the error is logged,
    and the exception is re-raised.
    """
    with app.app_context():
        try:
            events = Event.query.filter_by(status='active').all()
            current_time = beijing_now()
            recent_cutoff = current_time - timedelta(hours=24)
            for event in events:
                # Localize naive created_at so it can be subtracted from the
                # aware current_time.
                created_at = beijing_tz.localize(
                    event.created_at) if event.created_at.tzinfo is None else event.created_at
                hours_passed = (current_time - created_at).total_seconds() / 3600

                # One aggregate query instead of loading every Post row.
                # SUM can come back as Decimal on MySQL (Decimal * float raises)
                # or as NULL, so coerce to int.
                post_count, comment_total, like_total = db.session.query(
                    func.count(Post.id),
                    func.sum(Post.comments_count),
                    func.sum(Post.likes_count)
                ).filter(Post.event_id == event.id).one()
                post_count = int(post_count or 0)
                comment_count = int(comment_total or 0)
                like_count = int(like_total or 0)

                # Posts created in the last 24 hours.
                recent_posts = Post.query.filter(
                    Post.event_id == event.id,
                    Post.created_at >= recent_cutoff
                ).count()

                # Interaction = posts * 2 + comments * 1 + likes * 0.5
                interaction_score = (post_count * 2) + (comment_count * 1) + (like_count * 0.5)
                # Follow = followers * 3 (NULL treated as 0)
                follow_score = (event.follower_count or 0) * 3
                # View = 2 * ln(views); 0 when there are no views
                view_count = event.view_count or 0
                view_score = math.log(view_count) * 2 if view_count > 0 else 0
                # Decays to 1/e of the raw score after 72 hours
                time_decay = math.exp(-hours_passed / 72)
                # New posts in the last 24 hours weigh heavily
                recent_activity_weight = recent_posts * 5

                total_score = (interaction_score + follow_score + view_score + recent_activity_weight) * time_decay
                event.hot_score = round(total_score, 2)
                # Log of the score, decayed again, drives trending order.
                event.trending_score = math.log(total_score) * time_decay if total_score > 0 else 0

                db.session.add(EventHotHistory(
                    event_id=event.id,
                    score=event.hot_score,
                    interaction_score=interaction_score,
                    follow_score=follow_score,
                    view_score=view_score,
                    recent_activity_score=recent_activity_weight,
                    time_decay=time_decay
                ))
            db.session.commit()
            app.logger.info("Successfully updated event hot scores")
        except Exception as e:
            db.session.rollback()
            app.logger.error(f"Error updating hot scores: {str(e)}")
            raise
# On-demand hot score for a single event (simpler formula than update_hot_scores).
def calculate_hot_score(event):
    """Return an event's hot score.

    base = views * 0.1 + posts * 0.5 + followers * 1 (NULL counters count
    as 0), scaled by a linear time factor that falls from 1 to a floor of
    0.1 at 72 hours (3 days) after creation.
    """
    current_time = beijing_now()
    created_at = event.created_at
    # Subtracting a naive datetime from an aware one raises TypeError.
    # Localize the stored naive timestamp, as update_hot_scores does.
    if created_at.tzinfo is None and current_time.tzinfo is not None:
        created_at = beijing_tz.localize(created_at)
    hours_passed = (current_time - created_at).total_seconds() / 3600
    base_score = (
        (event.view_count or 0) * 0.1 +
        (event.post_count or 0) * 0.5 +
        (event.follower_count or 0) * 1
    )
    # Events within 72 hours (3 days) keep a higher weight; floor at 0.1.
    time_factor = max(1 - (hours_passed / 72), 0.1)
    return base_score * time_factor
@app.route('/api/sector/hierarchy', methods=['GET'])
def api_sector_hierarchy():
    """Return the industry trees of several classification systems.

    For each system that has rows in SectorInfo, F004V..F007V are levels
    1..4. Every level-1/2/3 node reports the number of distinct SECCODEs
    beneath it. Level-4 names are listed in first-seen order. Siblings, and
    the systems themselves, are ordered by stock count, largest first.
    """

    def largest_first(items, key='stocks_count'):
        # sorted() is stable, so ties keep their first-seen order.
        return sorted(items, key=lambda item: item[key], reverse=True)

    def build_tree(rows):
        # Nested dicts keyed by sector name; each node collects SECCODEs.
        tree = {}
        for row in rows:
            code = row.SECCODE
            node1 = tree.setdefault(row.F004V, {'children': {}, 'stocks': set()})
            if code:
                node1['stocks'].add(code)
            if not row.F005V:
                continue
            node2 = node1['children'].setdefault(row.F005V, {'children': {}, 'stocks': set()})
            if code:
                node2['stocks'].add(code)
            if not row.F006V:
                continue
            node3 = node2['children'].setdefault(row.F006V, {'leaves': [], 'stocks': set()})
            if code:
                node3['stocks'].add(code)
            leaf = row.F007V
            if leaf and leaf not in node3['leaves']:
                node3['leaves'].append(leaf)
        return tree

    def flatten(tree):
        # Convert the tree to the JSON shape, replacing stock sets with counts.
        level1_items = []
        for name1, node1 in tree.items():
            level2_items = []
            for name2, node2 in node1['children'].items():
                level3_items = [
                    {
                        'level3_sector': name3,
                        'stocks_count': len(node3['stocks']),
                        'level4_sectors': node3['leaves']
                    }
                    for name3, node3 in node2['children'].items()
                ]
                level2_items.append({
                    'level2_sector': name2,
                    'stocks_count': len(node2['stocks']),
                    'level3_sectors': largest_first(level3_items)
                })
            level1_items.append({
                'level1_sector': name1,
                'stocks_count': len(node1['stocks']),
                'level2_sectors': largest_first(level2_items)
            })
        return largest_first(level1_items)

    try:
        systems = (
            '申银万国行业分类',
            '巨潮行业分类',
            '新财富行业分类',
            '证监会行业分类2001',
            '中上协行业分类'
        )
        summaries = []
        for system in systems:
            rows = SectorInfo.query.filter_by(F002V=system).all()
            if not rows:
                continue
            levels = flatten(build_tree(rows))
            summaries.append({
                'classification_name': system,
                'total_level1_count': len(levels),
                'total_stocks_count': sum(item['stocks_count'] for item in levels),
                'hierarchy': levels
            })
        return jsonify({
            "code": 200,
            "message": "success",
            "data": largest_first(summaries, key='total_stocks_count')
        })
    except Exception as e:
        return jsonify({
            "code": 500,
            "message": str(e),
            "data": None
        }), 500
@app.route('/api/sector/hierarchy/simple', methods=['GET'])
def api_sector_hierarchy_simple():
    """Return broad segment -> SW level-1 industries -> sorted SW level-2 names.

    Only Shenwan (申银万国) rows are used. Segments come from
    get_primary_sector_by_sw_primary and are sorted by name.
    """
    try:
        rows = SectorInfo.query.filter_by(F002V='申银万国行业分类').all()
        grouped = {}
        for row in rows:
            sw_primary = row.F004V
            segment = grouped.setdefault(get_primary_sector_by_sw_primary(sw_primary), {})
            secondaries = segment.setdefault(sw_primary, set())
            if row.F005V:
                secondaries.add(row.F005V)
        result = sorted(
            (
                {
                    'primary_sector': segment_name,
                    'sw_primary_sectors': [
                        {'sw_primary_sector': sw_name, 'sw_secondary_sectors': sorted(subs)}
                        for sw_name, subs in members.items()
                    ]
                }
                for segment_name, members in grouped.items()
            ),
            key=lambda entry: entry['primary_sector']
        )
        return jsonify({
            "code": 200,
            "message": "success",
            "data": result
        })
    except Exception as e:
        return jsonify({
            "code": 500,
            "message": str(e),
            "data": None
        }), 500
@app.route('/api/sector/mapping', methods=['GET'])
def api_sector_mapping():
    """Return each broad segment with its SW level-1 industries.

    Segments and their industry lists are both sorted alphabetically.
    NOTE(review): this grouping mirrors get_primary_sector_by_sw_primary's
    table and must be kept in sync with it by hand.
    """
    try:
        segment_members = {
            '大周期': [
                '石油石化', '煤炭', '有色金属', '钢铁', '基础化工', '建筑材料',
                '机械设备', '电力设备及新能源', '国防军工', '电力设备', '电网设备',
                '风力发电', '太阳能发电', '建筑装饰', '建筑', '交通运输', '采掘', '公用事业'
            ],
            '大消费': [
                '汽车', '家用电器', '酒类', '食品饮料', '医药生物', '纺织服饰',
                '农林牧渔', '商贸零售', '轻工制造', '消费者服务', '美容护理', '社会服务',
                '纺织服装', '商业贸易', '休闲服务'
            ],
            '大金融地产': [
                '银行', '证券', '保险', '多元金融', '综合金融', '房地产', '非银金融'
            ],
            'TMT板块': ['计算机', '电子', '传媒', '通信'],
            '公共产业板块': ['环保', '综合'],
        }
        formatted_result = sorted(
            (
                {'primary_sector': segment, 'sw_primary_sectors': sorted(members)}
                for segment, members in segment_members.items()
            ),
            key=lambda entry: entry['primary_sector']
        )
        return jsonify({
            "code": 200,
            "message": "success",
            "data": formatted_result
        })
    except Exception as e:
        return jsonify({
            "code": 500,
            "message": str(e),
            "data": None
        }), 500
@app.route('/api/sector/banner', methods=['GET'])
def api_sector_banner():
    """Return the banner grouping: each broad segment with its sub-industries.

    ``data`` is a list of ``{"primary_sector", "sub_sectors"}`` in the fixed
    order below (not sorted).
    NOTE(review): unlike get_primary_sector_by_sw_primary, this grouping puts
    交通运输 / 建筑 / 公用事业 under 公共产业板块 — confirm which is intended.
    """
    try:
        banner_groups = (
            ('大周期', (
                '石油石化', '煤炭', '有色金属', '钢铁', '基础化工',
                '建筑材料', '机械设备', '电力设备及新能源', '国防军工',
                '电力设备', '电网设备', '风力发电', '太阳能发电',
                '建筑装饰',
            )),
            ('大消费', (
                '汽车', '家用电器', '酒类', '食品饮料', '医药生物',
                '纺织服饰', '农林牧渔', '商贸零售', '轻工制造',
                '消费者服务', '美容护理', '社会服务',
            )),
            ('大金融地产', (
                '银行', '证券', '保险', '多元金融',
                '综合金融', '房地产', '非银金融',
            )),
            ('TMT板块', ('计算机', '电子', '传媒', '通信')),
            ('公共产业板块', (
                '交通运输', '电力公用事业', '建筑',
                '环保', '综合', '公用事业',
            )),
        )
        result_list = [
            {"primary_sector": segment, "sub_sectors": list(subs)}
            for segment, subs in banner_groups
        ]
        return jsonify({
            "code": 200,
            "message": "success",
            "data": result_list
        })
    except Exception as e:
        return jsonify({
            "code": 500,
            "message": str(e),
            "data": None
        }), 500
# Shenwan (SW) level-1 industry -> broad market segment.
# Built once at import time instead of on every call; the lookup runs once
# per SectorInfo row in api_sector_hierarchy_simple.
_SW_PRIMARY_TO_PRIMARY_SECTOR = {
    # 大周期 (cyclicals)
    '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期',
    '钢铁': '大周期', '基础化工': '大周期', '建筑材料': '大周期',
    '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期',
    '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期',
    '太阳能发电': '大周期', '建筑装饰': '大周期', '建筑': '大周期',
    '交通运输': '大周期', '采掘': '大周期', '公用事业': '大周期',
    # 大消费 (consumer)
    '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费',
    '食品饮料': '大消费', '医药生物': '大消费', '纺织服饰': '大消费',
    '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费',
    '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费',
    '纺织服装': '大消费', '商业贸易': '大消费', '休闲服务': '大消费',
    # 大金融地产 (finance & real estate)
    '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产',
    '多元金融': '大金融地产', '综合金融': '大金融地产',
    '房地产': '大金融地产', '非银金融': '大金融地产',
    # TMT板块 (note: 电子 -> TMT板块)
    '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块',
    # 公共产业 (public industries)
    '环保': '公共产业板块', '综合': '公共产业板块'
}


def get_primary_sector_by_sw_primary(sw_primary):
    """Map a Shenwan level-1 industry name to its broad segment.

    Args:
        sw_primary: SW level-1 industry name (may be None).

    Returns:
        The segment name, or '其他' for unknown or missing industries.
    """
    return _SW_PRIMARY_TO_PRIMARY_SECTOR.get(sw_primary, '其他')
@app.route('/api/stock/<stock_id>/primary')
@token_required
def get_stock_primary_sector(stock_id):
    """Return a stock's SW level-1 industry and the broad segment it maps to.

    Any exchange suffix (e.g. ``.SH``/``.SZ``) is dropped, and the remaining
    code is used as a case-insensitive prefix match on ``SectorInfo.SECCODE``
    among Shenwan rows. Responds 404 when nothing matches and 500 with the
    exception text on errors.
    """
    try:
        code_prefix = stock_id.partition('.')[0]
        match = SectorInfo.query.filter(
            SectorInfo.SECCODE.ilike(f"{code_prefix}%"),
            SectorInfo.F002V == '申银万国行业分类'
        ).first()
        if not match:
            return jsonify({
                'code': 404,
                'message': f'未找到股票 {stock_id} 的行业分类信息',
                'data': None
            }), 404
        sw_primary = match.F004V
        payload = {
            'stock_code': match.SECCODE,
            'stock_name': match.SECNAME,
            'sw_primary_sector': sw_primary,
            'primary_sector': get_primary_sector_by_sw_primary(sw_primary)
        }
        return jsonify({'code': 200, 'message': 'success', 'data': payload})
    except Exception as e:
        print(f"Error in get_stock_primary_sector: {str(e)}")
        return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
@app.route('/api/stock/<stock_id>/secondary')
@token_required
def get_stock_secondary_sector(stock_id):
    """Return a stock's SW level-1 and level-2 industries plus its broad segment.

    Lookup is the same as the /primary endpoint: suffix stripped,
    prefix-matched on SECCODE among Shenwan rows. 404 if nothing matches,
    500 on errors.
    """
    try:
        code_prefix = stock_id.partition('.')[0]
        match = SectorInfo.query.filter(
            SectorInfo.SECCODE.ilike(f"{code_prefix}%"),
            SectorInfo.F002V == '申银万国行业分类'
        ).first()
        if not match:
            return jsonify({
                'code': 404,
                'message': f'未找到股票 {stock_id} 的行业分类信息',
                'data': None
            }), 404
        payload = {
            'stock_code': match.SECCODE,
            'stock_name': match.SECNAME,
            'sw_primary_sector': match.F004V,
            'sw_secondary_sector': match.F005V,
            # Segment is derived from the level-1 industry.
            'primary_sector': get_primary_sector_by_sw_primary(match.F004V)
        }
        return jsonify({'code': 200, 'message': 'success', 'data': payload})
    except Exception as e:
        print(f"Error in get_stock_secondary_sector: {str(e)}")
        return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
@app.route('/api/stock/<stock_id>/third')
@token_required
def get_stock_third_sector(stock_id):
    """Return a stock's SW industries up to the "third" level.

    ``sw_third_sector`` is read from F005V, the same column as
    ``sw_secondary_sector``.
    NOTE(review): api_sector_hierarchy treats F006V as level 3; the original
    comment says F005V was requested explicitly — confirm before changing.
    """
    try:
        code_prefix = stock_id.partition('.')[0]
        match = SectorInfo.query.filter(
            SectorInfo.SECCODE.ilike(f"{code_prefix}%"),
            SectorInfo.F002V == '申银万国行业分类'
        ).first()
        if not match:
            return jsonify({
                'code': 404,
                'message': f'未找到股票 {stock_id} 的行业分类信息',
                'data': None
            }), 404
        payload = {
            'stock_code': match.SECCODE,
            'stock_name': match.SECNAME,
            'sw_primary_sector': match.F004V,
            'sw_secondary_sector': match.F005V,
            'sw_third_sector': match.F005V,
            'primary_sector': get_primary_sector_by_sw_primary(match.F004V)
        }
        return jsonify({'code': 200, 'message': 'success', 'data': payload})
    except Exception as e:
        print(f"Error in get_stock_third_sector: {str(e)}")
        return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
@app.route('/api/stock/<stock_id>/fourth')
@token_required
def get_stock_fourth_sector(stock_id):
    """Return a stock's SW industries up to the "fourth" level (F006V).

    ``sw_third_sector`` repeats F005V, matching the /third endpoint.
    NOTE(review): confirm that level mapping.
    """
    try:
        code_prefix = stock_id.partition('.')[0]
        match = SectorInfo.query.filter(
            SectorInfo.SECCODE.ilike(f"{code_prefix}%"),
            SectorInfo.F002V == '申银万国行业分类'
        ).first()
        if not match:
            return jsonify({
                'code': 404,
                'message': f'未找到股票 {stock_id} 的行业分类信息',
                'data': None
            }), 404
        payload = {
            'stock_code': match.SECCODE,
            'stock_name': match.SECNAME,
            'sw_primary_sector': match.F004V,
            'sw_secondary_sector': match.F005V,
            'sw_third_sector': match.F005V,
            'sw_fourth_sector': match.F006V,
            'primary_sector': get_primary_sector_by_sw_primary(match.F004V)
        }
        return jsonify({'code': 200, 'message': 'success', 'data': payload})
    except Exception as e:
        print(f"Error in get_stock_fourth_sector: {str(e)}")
        return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
@app.route('/api/stock/<stock_id>/fifth')
@token_required
def get_stock_fifth_sector(stock_id):
    """Return a stock's SW industries up to the "fifth" level (F007V).

    ``sw_third_sector`` repeats F005V and ``sw_fourth_sector`` is F006V,
    matching the /third and /fourth endpoints.
    NOTE(review): confirm that level mapping.
    """
    try:
        code_prefix = stock_id.partition('.')[0]
        match = SectorInfo.query.filter(
            SectorInfo.SECCODE.ilike(f"{code_prefix}%"),
            SectorInfo.F002V == '申银万国行业分类'
        ).first()
        if not match:
            return jsonify({
                'code': 404,
                'message': f'未找到股票 {stock_id} 的行业分类信息',
                'data': None
            }), 404
        payload = {
            'stock_code': match.SECCODE,
            'stock_name': match.SECNAME,
            'sw_primary_sector': match.F004V,
            'sw_secondary_sector': match.F005V,
            'sw_third_sector': match.F005V,
            'sw_fourth_sector': match.F006V,
            'sw_fifth_sector': match.F007V,
            'primary_sector': get_primary_sector_by_sw_primary(match.F004V)
        }
        return jsonify({'code': 200, 'message': 'success', 'data': payload})
    except Exception as e:
        print(f"Error in get_stock_fifth_sector: {str(e)}")
        return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
def get_calendar_events():
    """Return future_events rows for one calendar day.

    Query arg ``date`` (``YYYY-MM-DD``) is required. Only events with
    ``date <= calendar_time < date + 1 day`` are included. Rows sharing a
    title are collapsed to the one with the highest ``star``. Results are
    ordered by calendar_time.

    Returns a JSON list of events, 400 for a missing or malformed date, and
    500 (after rollback) on database errors.
    """
    date_str = request.args.get('date')
    if not date_str:
        return jsonify({'error': 'Date parameter is required'}), 400
    try:
        target_date = datetime.strptime(date_str, '%Y-%m-%d')
    except ValueError:
        return jsonify({'error': 'Invalid date format, expected YYYY-MM-DD'}), 400
    end_date = target_date + timedelta(days=1)
    # Half-open range: BETWEEN would also include events at exactly 00:00 of
    # the next day. The comma after inferred_tag was missing, which made the
    # statement a syntax error.
    query = """
    WITH RankedEvents AS (
        SELECT
            data_id,
            calendar_time,
            type,
            star,
            title,
            former,
            forecast,
            fact,
            related_stocks,
            concepts,
            primary_sectors,
            inferred_tag,
            ROW_NUMBER() OVER (PARTITION BY title ORDER BY star DESC) as rn
        FROM future_events
        WHERE calendar_time >= :start_date AND calendar_time < :end_date
    )
    SELECT DISTINCT
        data_id,
        calendar_time,
        type,
        star,
        title,
        former,
        forecast,
        fact,
        related_stocks,
        concepts,
        primary_sectors,
        inferred_tag
    FROM RankedEvents
    WHERE rn = 1
    ORDER BY calendar_time
    """
    try:
        result = db.session.execute(text(query), {
            'start_date': target_date,
            'end_date': end_date
        })
        events = [{
            'calendar_time': event.calendar_time.isoformat() if event.calendar_time else None,
            'type': event.type,
            'star': event.star,
            'title': event.title,
            'former': event.former,
            'forecast': event.forecast,
            'fact': event.fact,
            'related_stocks': event.related_stocks,
            'concepts': event.concepts,
            'primary_sectors': event.primary_sectors,
            'inferred_tag': event.inferred_tag
        } for event in result]
        # End the read transaction.
        db.session.commit()
        return jsonify(events)
    except Exception as e:
        db.session.rollback()
        return jsonify({'error': str(e)}), 500
# New API endpoint for hot news
@app.route('/api/hot-news')
def hot_news():
    """Return up to 4 hot events from the last 3 days.

    Active events with a positive ``related_avg_chg`` come first, ranked by
    it. If fewer than 4 exist, the list is topped up with other recent active
    events ranked by ``hot_score``. An event that fails to serialize is
    logged and skipped, so one bad row cannot blank the whole response.
    """

    def format_event(event):
        """Serialize one event to a JSON-compatible dict."""
        creator = event.creator
        # NOTE(review): exposed as 'created_at' but sourced from updated_at
        # (kept for compatibility); falls back to created_at when unset.
        timestamp = event.updated_at or event.created_at
        return {
            'id': event.id,
            'title': event.title,
            'description': event.description,
            'created_at': timestamp.strftime('%Y-%m-%d %H:%M:%S') if timestamp else None,
            'importance': event.importance,
            'creator': {
                'username': creator.username,
                'avatar_url': get_full_avatar_url(creator.avatar_url)
            } if creator else None,
            'related_avg_chg': event.related_avg_chg,
            'related_max_chg': event.related_max_chg,
            'related_week_chg': event.related_week_chg,
            'post_count': event.post_count,
            'follower_count': event.follower_count,
            'view_count': event.view_count
        }

    def format_events(events):
        """Serialize events, skipping (and logging) any that fail."""
        formatted = []
        for event in events:
            try:
                formatted.append(format_event(event))
            except Exception as e:
                logger.error(f"Error formatting event {event.id}: {str(e)}", exc_info=True)
        return formatted

    try:
        three_days_ago = datetime.now() - timedelta(days=3)
        # Recent events with positive average price change, best first.
        hot_events = Event.query.filter(
            Event.status == 'active',
            Event.created_at >= three_days_ago,
            Event.related_avg_chg.isnot(None),
            Event.related_avg_chg > 0
        ).order_by(Event.related_avg_chg.desc()).limit(4).all()
        # Top up with other popular recent events when fewer than 4 qualify.
        if len(hot_events) < 4:
            additional_events = Event.query.filter(
                Event.status == 'active',
                Event.created_at >= three_days_ago,
                ~Event.id.in_([event.id for event in hot_events])
            ).order_by(Event.hot_score.desc()).limit(4 - len(hot_events)).all()
            hot_events.extend(additional_events)
        return jsonify(format_events(hot_events))
    except Exception as e:
        logger.error(f"Error getting hot news: {str(e)}", exc_info=True)
        return jsonify({'error': str(e)}), 500
# Industry name (Shenwan level-1, or the legacy RelatedStock.sector value) -> primary
# sector group reported to the frontend. Unknown names map to '未知'.
_SECTOR_TO_PRIMARY = {
    # Cyclicals
    '石油石化': '大周期',
    '煤炭': '大周期',
    '有色金属': '大周期',
    '钢铁': '大周期',
    '基础化工': '大周期',
    '建筑材料': '大周期',
    '机械设备': '大周期',
    '电力设备及新能源': '大周期',
    '国防军工': '大周期',
    '电力设备': '大周期',
    '电网设备': '大周期',
    '风力发电': '大周期',
    '太阳能发电': '大周期',
    '建筑装饰': '大周期',
    # Consumer
    '汽车': '大消费',
    '家用电器': '大消费',
    '酒类': '大消费',
    '食品饮料': '大消费',
    '医药生物': '大消费',
    '纺织服饰': '大消费',
    '农林牧渔': '大消费',
    '商贸零售': '大消费',
    '轻工制造': '大消费',
    '消费者服务': '大消费',
    '美容护理': '大消费',
    '社会服务': '大消费',
    # Finance & real estate
    '银行': '大金融地产',
    '证券': '大金融地产',
    '保险': '大金融地产',
    '多元金融': '大金融地产',
    '综合金融': '大金融地产',
    '房地产': '大金融地产',
    '非银金融': '大金融地产',
    # TMT
    '计算机': 'TMT板块',
    '电子': 'TMT板块',
    '传媒': 'TMT板块',
    '通信': 'TMT板块',
    # Public utilities / infrastructure
    '交通运输': '公共产业板块',
    '电力公用事业': '公共产业板块',
    '建筑': '公共产业板块',
    '环保': '公共产业板块',
    '综合': '公共产业板块',
    '公用事业': '公共产业板块',
}


@app.route('/api/event/<int:event_id>/related-stocks')
def get_event_related_stocks(event_id):
    """Return the related stocks of an event, ordered by correlation (highest first).

    Each stock is enriched with its Shenwan level-1 industry (``sw_sector``, looked up
    in SectorInfo) and a ``primary_sector`` group. The group is derived from ``sw_sector``
    when found, otherwise from the stock's own ``sector``.

    Returns:
        ``{'code': 200, 'message': 'success', 'data': [...]}``.
        ``{'code': 500, 'message': <error>, 'data': None}`` with HTTP 500 on failure.
        A missing event yields the standard 404 from ``get_or_404``.
    """
    # Resolve the event before the try block: get_or_404 raises an HTTP 404 exception,
    # which the broad `except Exception` below would otherwise convert into a 500.
    event = Event.query.get_or_404(event_id)
    try:
        related_stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all()
        # Cache lookups so a stock code that repeats only triggers one SectorInfo query.
        sw_sector_cache = {}
        stocks_data = []
        for stock in related_stocks:
            # Strip any market suffix such as '.SH' before matching SectorInfo codes.
            base_stock_code = stock.stock_code.split('.')[0]
            if base_stock_code not in sw_sector_cache:
                sector_info = SectorInfo.query.filter(
                    SectorInfo.SECCODE.ilike(f"{base_stock_code}%"),  # case-insensitive prefix match
                    SectorInfo.F002V == '申银万国行业分类'
                ).first()
                sw_sector_cache[base_stock_code] = sector_info.F004V if sector_info else None
            sw_sector = sw_sector_cache[base_stock_code]
            lookup_name = sw_sector if sw_sector else stock.sector
            stocks_data.append({
                'stock_code': stock.stock_code,
                'stock_name': stock.stock_name,
                'sector': stock.sector,
                'sw_sector': sw_sector,  # Shenwan level-1 industry name, if found
                'primary_sector': _SECTOR_TO_PRIMARY.get(lookup_name, '未知'),
                'relation_desc': stock.relation_desc,
                'correlation': stock.correlation,
                'momentum': stock.momentum
            })
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': stocks_data
        })
    except Exception as e:
        app.logger.error(f"Error in get_event_related_stocks: {str(e)}", exc_info=True)
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
def calculate_change_distribution(stocks_data):
    """Count stocks per daily price-change bucket.

    Args:
        stocks_data: iterable of dicts with ``'daily_change'`` (a percentage, on the same
            scale as the limit rates from ``get_limit_rate``, e.g. 10.0) and
            ``'stock_code'``. A missing or ``None`` change is counted as 0.

    Returns:
        dict: bucket name -> count. The buckets are checked in this order:
        limit_down / limit_up (within 0.01 of the code's limit rate), up_over_5 (>5),
        up_1_to_5 (>1), up_within_1 (>0.1), flat (between -0.1 and 0.1 inclusive),
        down_within_1 (>-1), down_5_to_1 (>-5), down_over_5 (the rest).
    """
    distribution = {
        'limit_down': 0,  # limit down
        'down_over_5': 0,  # down more than 5%
        'down_5_to_1': 0,  # down 1%-5%
        'down_within_1': 0,  # down 0.1%-1%
        'flat': 0,  # within ±0.1%
        'up_within_1': 0,  # up 0.1%-1%
        'up_1_to_5': 0,  # up 1%-5%
        'up_over_5': 0,  # up more than 5% but not limit up
        'limit_up': 0  # limit up
    }
    for stock in stocks_data:
        raw_change = stock.get('daily_change')
        # None used to raise TypeError in the comparisons below; float() also accepts Decimal.
        change = float(raw_change) if raw_change is not None else 0.0
        limit_rate = get_limit_rate(stock.get('stock_code', ''))
        # Limit hits are detected with a 0.01 percentage-point tolerance.
        if change <= -limit_rate + 0.01:
            distribution['limit_down'] += 1
        elif change >= limit_rate - 0.01:
            distribution['limit_up'] += 1
        elif change > 5:
            distribution['up_over_5'] += 1
        elif change > 1:
            distribution['up_1_to_5'] += 1
        elif change > 0.1:
            distribution['up_within_1'] += 1
        elif change >= -0.1:
            distribution['flat'] += 1
        elif change > -1:
            distribution['down_within_1'] += 1
        elif change > -5:
            distribution['down_5_to_1'] += 1
        else:
            distribution['down_over_5'] += 1
    return distribution
def get_limit_rate(stock_code):
    """Return the daily price-limit percentage for an A-share stock code.

    Accepts bare codes ('600000'), suffixed codes ('600000.SH') and market-prefixed codes
    ('SH600000'), case-insensitively.

    Args:
        stock_code: stock code string; empty/None is allowed.

    Returns:
        float:
        - 5.0 if the text contains 'ST'. This only triggers when a name or ST tag is part
          of the input, since numeric codes never contain it.
        - 20.0 for STAR Market (688/689) and ChiNext (30x).
        - 30.0 for Beijing Stock Exchange codes (.BJ / BJ prefix, or 43/83/87/920 prefixes).
        - 10.0 otherwise, including empty input.
    """
    if not stock_code:
        return 10.0
    code = stock_code.strip().upper()
    if 'ST' in code:
        return 5.0
    # Split off the exchange marker, whether written as a suffix or a prefix.
    base, dot, market = code.partition('.')
    if not dot:
        market = ''
        if code[:2] in ('SH', 'SZ', 'BJ'):
            market, base = code[:2], code[2:]
    # STAR Market (incl. 689 CDRs) and ChiNext: 20% limit.
    if base.startswith(('688', '689', '30')):
        return 20.0
    # Beijing Stock Exchange: 30% limit.
    if market == 'BJ' or base.startswith(('43', '83', '87', '920')):
        return 30.0
    # Main board: 10% limit.
    return 10.0
def _event_keyword_condition(value):
    """Build a filter matching events whose JSON ``keywords`` array contains *value*.

    On MySQL this is ``JSON_CONTAINS(event.keywords, <json-encoded value>, '$') = 1``. The
    value is an ordinary bound parameter, so any number of these conditions can be combined
    in one query without bind-name collisions. On other dialects the column is cast to JSONB
    (PostgreSQL) and a containment test is used.
    """
    if isinstance(db.engine.dialect, MySQLDialect):
        return func.JSON_CONTAINS(Event.keywords, json.dumps(value), '$') == 1
    return Event.keywords.cast(JSONB).contains([value])


def _strip_market_suffix(stock_code):
    """Drop a '.SH' / '.SZ' / '.BJ' suffix so the code matches ea_trade.SECCODE."""
    return stock_code.replace('.SH', '').replace('.SZ', '').replace('.BJ', '')


def _fetch_stock_changes(clean_codes):
    """Fetch the latest daily change and an approximate weekly change for stock codes.

    Reads ea_trade rows (non-NULL F007N close) from the last 10 calendar days.
    - The daily change is F010N of the most recent row.
    - The weekly change compares the latest close with the close 5 trading days earlier
      (the 6th most recent row). If the window is shorter, the oldest row is used.

    Args:
        clean_codes: iterable of bare stock codes (no market suffix).

    Returns:
        dict: code -> {'stock_name', 'daily_change', 'week_change'}. Codes without rows
        in the window are absent.
    """
    clean_codes = sorted(set(clean_codes))
    if not clean_codes:
        return {}
    from sqlalchemy import bindparam

    # Expanding bind parameter instead of splicing the codes into the SQL text.
    recent_trades_sql = text("""
        SELECT
            SECCODE,
            SECNAME,
            F007N AS close_price,
            F010N AS daily_change,
            TRADEDATE,
            ROW_NUMBER() OVER (PARTITION BY SECCODE ORDER BY TRADEDATE DESC) AS rn
        FROM ea_trade
        WHERE SECCODE IN :codes
          AND F007N IS NOT NULL
          AND TRADEDATE >= DATE_SUB(CURDATE(), INTERVAL 10 DAY)
        ORDER BY SECCODE, TRADEDATE DESC
    """).bindparams(bindparam('codes', expanding=True))

    grouped = {}
    for row in db.session.execute(recent_trades_sql, {'codes': clean_codes}).fetchall():
        entry = grouped.setdefault(row[0], {'stock_name': row[1], 'prices': []})
        entry['prices'].append({
            'close_price': float(row[2]) if row[2] else 0,
            'daily_change': float(row[3]) if row[3] else 0,
            'rank': row[5]
        })

    changes = {}
    for sec_code, data in grouped.items():
        prices = data['prices']  # newest first (rank 1 = latest trading day)
        daily_change = prices[0]['daily_change'] if prices[0]['rank'] == 1 else 0
        week_change = 0
        if len(prices) >= 2:
            latest_price = prices[0]['close_price']
            # Rank 6 is 5 trading days before the latest (rank 1) row.
            week_ago_price = next((p['close_price'] for p in prices if p['rank'] >= 6), None)
            if week_ago_price is None:
                week_ago_price = prices[-1]['close_price']
            if week_ago_price and week_ago_price > 0:
                week_change = (latest_price - week_ago_price) / week_ago_price * 100
        changes[sec_code] = {
            'stock_name': data['stock_name'],
            'daily_change': daily_change,
            'week_change': week_change
        }
    return changes


def _summarize_event_stocks(event_stocks, stock_changes):
    """Build the per-stock rows and the aggregate ``stocks_stats`` dict for one event.

    Only stocks present in *stock_changes* count as valid for the averages and maxima.
    All statistics are 0 when no stock is valid.
    """
    stocks_data = []
    daily_values = []
    week_values = []
    for stock in event_stocks:
        stock_info = stock_changes.get(_strip_market_suffix(stock.stock_code), {})
        daily_change = stock_info.get('daily_change', 0)
        week_change = stock_info.get('week_change', 0)
        if stock_info:
            daily_values.append(daily_change)
            week_values.append(week_change)
        stocks_data.append({
            "stock_code": stock.stock_code,
            "stock_name": stock.stock_name,
            "sector": stock.sector,
            "week_change": round(week_change, 2),
            "daily_change": round(daily_change, 2)
        })
    valid_count = len(daily_values)
    stocks_stats = {
        'stocks_count': len(event_stocks),
        'valid_stocks_count': valid_count,
        # Weekly change statistics
        'avg_week_change': round(sum(week_values) / valid_count if valid_count else 0, 2),
        'max_week_change': round(max(week_values) if week_values else 0, 2),
        # Daily change statistics
        'avg_daily_change': round(sum(daily_values) / valid_count if valid_count else 0, 2),
        'max_daily_change': round(max(daily_values) if daily_values else 0, 2)
    }
    return stocks_data, stocks_stats


@app.route('/api/events', methods=['GET'])
def api_get_events():
    """List events with filtering, search, sorting and pagination.

    Main query parameters (all optional):
    - Pagination: page, per_page (1-100).
    - Basic filters: type, status (default 'active'), importance, ind_type
      (legacy alias: stock_sector), creator_id.
    - Dates: recent_days; or start_date / end_date / date_range.
    - Keywords: tag, tags (comma-separated, AND), keywords (comma-separated, OR).
    - Search: q with search_type 'topic' (default), 'stock' or 'all', plus search_fields.
    - Numeric ranges: min_/max_ avg/max/week return, min_/max_hot_score, min_view_count.
    - Sorting: sort (new, hot, returns with return_type, importance, view_count, follow)
      and order (asc/desc).
    - Response shaping: include_creator, include_stats, include_related_data.

    Returns:
        ``{'success': True, 'data': {'events', 'pagination', 'filters', 'overall_stats'}}``.
        ``overall_stats`` buckets the latest daily change of the related stocks of up to
        500 matching events. On error: ``{'success': False, 'error', 'error_type'}`` with
        HTTP 500.
    """
    try:
        # ==================== Parameter parsing ====================
        page = max(1, request.args.get('page', 1, type=int))
        per_page = min(100, max(1, request.args.get('per_page', 10, type=int)))
        event_type = request.args.get('type', 'all')
        event_status = request.args.get('status', 'active')
        importance = request.args.get('importance', 'all')
        start_date = request.args.get('start_date')
        end_date = request.args.get('end_date')
        date_range = request.args.get('date_range')
        recent_days = request.args.get('recent_days', type=int)
        # Industry filter: ind_type, with the legacy stock_sector parameter as a fallback.
        ind_type = request.args.get('ind_type', 'all')
        stock_sector = request.args.get('stock_sector', 'all')
        secondary_sector = request.args.get('secondary_sector', 'all')
        if ind_type == 'all' and stock_sector != 'all':
            ind_type = stock_sector
        tag = request.args.get('tag')
        tags = request.args.get('tags')
        keywords = request.args.get('keywords')
        search_query = request.args.get('q')
        search_type = request.args.get('search_type', 'topic')
        search_fields = request.args.get('search_fields', 'title,description').split(',')
        sort_by = request.args.get('sort', 'new')
        return_type = request.args.get('return_type', 'avg')
        order = request.args.get('order', 'desc')
        min_avg_return = request.args.get('min_avg_return', type=float)
        max_avg_return = request.args.get('max_avg_return', type=float)
        min_max_return = request.args.get('min_max_return', type=float)
        max_max_return = request.args.get('max_max_return', type=float)
        min_week_return = request.args.get('min_week_return', type=float)
        max_week_return = request.args.get('max_week_return', type=float)
        min_hot_score = request.args.get('min_hot_score', type=float)
        max_hot_score = request.args.get('max_hot_score', type=float)
        min_view_count = request.args.get('min_view_count', type=int)
        creator_id = request.args.get('creator_id', type=int)
        include_creator = request.args.get('include_creator', 'true').lower() == 'true'
        include_stats = request.args.get('include_stats', 'true').lower() == 'true'
        include_related_data = request.args.get('include_related_data', 'false').lower() == 'true'

        # ==================== Basic filters ====================
        query = Event.query
        if event_status != 'all':
            query = query.filter_by(status=event_status)
        if event_type != 'all':
            query = query.filter_by(event_type=event_type)
        if importance != 'all':
            query = query.filter_by(importance=importance)
        if ind_type != 'all':
            query = query.filter_by(ind_type=ind_type)
        if creator_id:
            query = query.filter_by(creator_id=creator_id)

        # ==================== Date filters ====================
        if recent_days:
            query = query.filter(Event.created_at >= datetime.now() - timedelta(days=recent_days))
        else:
            # NOTE(review): the separator literal below is empty in this copy of the file, so
            # split('') raises ValueError and date_range is effectively ignored. Confirm the
            # intended separator (possibly a full-width character lost in an encoding round-trip).
            if date_range and '' in date_range:
                try:
                    start_date_str, end_date_str = date_range.split('')
                    start_date = start_date_str.strip()
                    end_date = end_date_str.strip()
                except ValueError:
                    pass
            if start_date:
                try:
                    fmt = '%Y-%m-%d' if len(start_date) == 10 else '%Y-%m-%d %H:%M:%S'
                    query = query.filter(Event.created_at >= datetime.strptime(start_date, fmt))
                except ValueError:
                    pass
            if end_date:
                try:
                    if len(end_date) == 10:
                        # A bare date includes the whole day.
                        end_datetime = datetime.strptime(end_date, '%Y-%m-%d').replace(hour=23, minute=59, second=59)
                    else:
                        end_datetime = datetime.strptime(end_date, '%Y-%m-%d %H:%M:%S')
                    query = query.filter(Event.created_at <= end_datetime)
                except ValueError:
                    pass

        # ==================== Keyword / tag filters ====================
        if tag:
            query = query.filter(_event_keyword_condition(tag))
        if tags:
            # AND semantics: every tag must be present.
            for single_tag in (t.strip() for t in tags.split(',')):
                if single_tag:
                    query = query.filter(_event_keyword_condition(single_tag))
        if keywords:
            # OR semantics: any keyword may match.
            keyword_conditions = [_event_keyword_condition(k.strip()) for k in keywords.split(',') if k.strip()]
            if keyword_conditions:
                query = query.filter(or_(*keyword_conditions))

        # ==================== Search ====================
        if search_query:
            search_terms = search_query.strip().split()

            def term_condition(term):
                """OR of the enabled text-field matches for one search term, or None."""
                parts = []
                if 'title' in search_fields:
                    parts.append(Event.title.ilike(f'%{term}%'))
                if 'description' in search_fields:
                    parts.append(Event.description.ilike(f'%{term}%'))
                if 'keywords' in search_fields:
                    parts.append(_event_keyword_condition(term))
                return or_(*parts) if parts else None

            stock_match = or_(
                RelatedStock.stock_code.ilike(f'%{search_query}%'),
                RelatedStock.stock_name.ilike(f'%{search_query}%')
            )
            if search_type == 'stock':
                query = query.join(RelatedStock).filter(stock_match).distinct()
            elif search_type == 'all':
                # Any term's text match OR a related-stock match.
                search_filters = [c for c in map(term_condition, search_terms) if c is not None]
                stock_subquery = db.session.query(RelatedStock.event_id).filter(stock_match).subquery()
                search_filters.append(Event.id.in_(stock_subquery))
                query = query.filter(or_(*search_filters))
            else:
                # Topic search (default): every term must match at least one enabled field.
                for term in search_terms:
                    condition = term_condition(term)
                    if condition is not None:
                        query = query.filter(condition)

        # ==================== Numeric range filters ====================
        range_filters = (
            (min_avg_return, Event.related_avg_chg, '>='),
            (max_avg_return, Event.related_avg_chg, '<='),
            (min_max_return, Event.related_max_chg, '>='),
            (max_max_return, Event.related_max_chg, '<='),
            (min_week_return, Event.related_week_chg, '>='),
            (max_week_return, Event.related_week_chg, '<='),
            (min_hot_score, Event.hot_score, '>='),
            (max_hot_score, Event.hot_score, '<='),
            (min_view_count, Event.view_count, '>='),
        )
        for bound, column, op in range_filters:
            if bound is not None:
                query = query.filter(column >= bound if op == '>=' else column <= bound)

        # ==================== Sorting ====================
        order_func = desc if order.lower() == 'desc' else asc
        if sort_by == 'hot':
            query = query.order_by(order_func(Event.hot_score))
        elif sort_by == 'new':
            query = query.order_by(order_func(Event.created_at))
        elif sort_by == 'returns':
            return_columns = {
                'avg': Event.related_avg_chg,
                'max': Event.related_max_chg,
                'week': Event.related_week_chg,
            }
            if return_type in return_columns:
                query = query.order_by(order_func(return_columns[return_type]))
        elif sort_by == 'importance':
            importance_order = case(
                (Event.importance == 'S', 1),
                (Event.importance == 'A', 2),
                (Event.importance == 'B', 3),
                (Event.importance == 'C', 4),
                else_=5
            )
            # 'desc' means most important (S) first, i.e. ascending rank.
            if order.lower() == 'desc':
                query = query.order_by(importance_order)
            else:
                query = query.order_by(desc(importance_order))
        elif sort_by == 'view_count':
            query = query.order_by(order_func(Event.view_count))
        elif sort_by == 'follow' and hasattr(request, 'user') and request.user.is_authenticated:
            query = query.join(EventFollow).filter(
                EventFollow.user_id == request.user.id
            ).order_by(order_func(Event.created_at))

        # ==================== Pagination ====================
        paginated = query.paginate(page=page, per_page=per_page, error_out=False)

        # Related stocks of the events on this page, grouped by event id.
        stocks_by_event = defaultdict(list)
        page_event_ids = [event.id for event in paginated.items]
        if page_event_ids:
            for stock in RelatedStock.query.filter(RelatedStock.event_id.in_(page_event_ids)).all():
                stocks_by_event[stock.event_id].append(stock)

        # Related stocks of up to 500 matching events, used for the overall statistics.
        stats_stocks = []
        stats_event_ids = [e.id for e in query.limit(500).all()]
        if stats_event_ids:
            stats_stocks = RelatedStock.query.filter(RelatedStock.event_id.in_(stats_event_ids)).all()

        # One price query covering both the page and the statistics stocks. The statistics
        # used to see only page stocks, which skewed overall_stats.
        all_codes = {_strip_market_suffix(s.stock_code) for s in stats_stocks}
        all_codes.update(
            _strip_market_suffix(s.stock_code) for stocks in stocks_by_event.values() for s in stocks
        )
        stock_changes = _fetch_stock_changes(all_codes)

        # ==================== Overall change distribution ====================
        stats_rows = []
        for stock in stats_stocks:
            info = stock_changes.get(_strip_market_suffix(stock.stock_code))
            if info:
                stats_rows.append({'stock_code': stock.stock_code, 'daily_change': info['daily_change']})
        overall_distribution = calculate_change_distribution(stats_rows)

        # ==================== Response rows ====================
        events_data = []
        for event in paginated.items:
            event_stocks = stocks_by_event.get(event.id, [])
            stocks_data, stocks_stats = _summarize_event_stocks(event_stocks, stock_changes)
            event_dict = {
                'id': event.id,
                'title': event.title,
                'description': event.description,
                'event_type': event.event_type,
                'importance': event.importance,
                'status': event.status,
                'created_at': event.created_at.isoformat() if event.created_at else None,
                'updated_at': event.updated_at.isoformat() if event.updated_at else None,
                'start_time': event.start_time.isoformat() if event.start_time else None,
                'end_time': event.end_time.isoformat() if event.end_time else None,
                'related_stocks': stocks_data,
                'stocks_stats': stocks_stats
            }
            if include_stats:
                event_dict.update({
                    'hot_score': event.hot_score,
                    'view_count': event.view_count,
                    'post_count': event.post_count,
                    'follower_count': event.follower_count,
                    'related_avg_chg': event.related_avg_chg,
                    'related_max_chg': event.related_max_chg,
                    'related_week_chg': event.related_week_chg,
                    'invest_score': event.invest_score,
                    'trending_score': event.trending_score,
                })
            if include_creator:
                creator = event.creator
                event_dict['creator'] = {
                    'id': creator.id if creator else None,
                    'username': creator.username if creator else 'Anonymous',
                    'avatar_url': get_full_avatar_url(creator.avatar_url) if creator else None,
                    'is_creator': creator.is_creator if creator else False,
                    'creator_type': creator.creator_type if creator else None
                }
            event_dict['keywords'] = event.keywords if isinstance(event.keywords, list) else []
            event_dict['related_industries'] = event.related_industries
            if include_stats:
                event_dict['stats'] = {
                    'related_stocks_count': len(event_stocks),
                    'historical_events_count': 0,  # requires an extra query; not computed
                    'related_data_count': 0,  # requires an extra query; not computed
                    'related_concepts_count': 0  # requires an extra query; not computed
                }
            if include_related_data:
                # Replaces the price-enriched list with correlation data for the first 5 stocks.
                event_dict['related_stocks'] = [{
                    'id': stock.id,
                    'stock_code': stock.stock_code,
                    'stock_name': stock.stock_name,
                    'sector': stock.sector,
                    'correlation': float(stock.correlation) if stock.correlation else 0
                } for stock in event_stocks[:5]]
            events_data.append(event_dict)

        return jsonify({
            'success': True,
            'data': {
                'events': events_data,
                'pagination': {
                    'page': paginated.page,
                    'per_page': paginated.per_page,
                    'total': paginated.total,
                    'pages': paginated.pages,
                    'has_prev': paginated.has_prev,
                    'has_next': paginated.has_next
                },
                'filters': {
                    'applied_filters': {
                        'type': event_type,
                        'status': event_status,
                        'importance': importance,
                        'stock_sector': stock_sector,  # kept for compatibility
                        'secondary_sector': secondary_sector,  # kept for compatibility
                        'sort': sort_by,
                        'order': order
                    }
                },
                'overall_stats': {
                    'total_stocks': len(stats_stocks),
                    'change_distribution': overall_distribution,
                    # Despite the name these are raw counts, kept as-is for compatibility.
                    'change_distribution_percentages': dict(overall_distribution)
                }
            }
        })
    except Exception as e:
        app.logger.error(f"获取事件列表出错: {str(e)}", exc_info=True)
        return jsonify({
            'success': False,
            'error': str(e),
            'error_type': type(e).__name__
        }), 500
def get_filter_counts(base_query):
    """Count the events selected by *base_query*, grouped by importance and by event type.

    Intended for showing how many events each filter option would match.

    Args:
        base_query: a query over Event; only its event ids are used.

    Returns:
        dict: ``{'importance': {value: count}, 'event_type': {value: count}}``. NULL values
        are reported under ``'unknown'``. Returns ``{}`` (and logs the error) if a count
        query fails, so callers can still respond without counts.
    """
    try:
        id_subquery = base_query.with_entities(Event.id).subquery()
        counts = {}
        for key, column in (('importance', Event.importance), ('event_type', Event.event_type)):
            rows = db.session.query(
                column,
                func.count(Event.id)
            ).filter(
                Event.id.in_(id_subquery)
            ).group_by(column).all()
            # Unpack positionally: on SQLAlchemy 1.4+ Row objects, `row.count` resolves to the
            # tuple `count` method rather than a column labelled 'count'.
            counts[key] = {(value or 'unknown'): total for value, total in rows}
        return counts
    except Exception:
        app.logger.error("Failed to compute event filter counts", exc_info=True)
        return {}
@app.route('/api/events/filters', methods=['GET'])
def api_get_event_filters():
    """Describe the filter options for the event list.

    Returns per-option counts for active events, plus the static lists of sort keys,
    return types and importance levels the frontend can offer.
    """
    try:
        filter_counts = get_filter_counts(Event.query.filter_by(status='active'))

        sort_options = [
            {'key': key, 'name': label, 'desc': hint}
            for key, label, hint in (
                ('new', '最新', '按创建时间排序'),
                ('hot', '热门', '按热度分数排序'),
                ('returns', '收益率', '按收益率排序'),
                ('importance', '重要性', '按重要性等级排序'),
                ('view_count', '浏览量', '按浏览次数排序'),
            )
        ]
        return_type_options = [
            {'key': key, 'name': label}
            for key, label in (
                ('avg', '平均收益率'),
                ('max', '最大收益率'),
                ('week', '周收益率'),
            )
        ]
        importance_options = [
            {'key': key, 'name': label, 'desc': hint}
            for key, label, hint in (
                ('S', 'S级', '重大事件'),
                ('A', 'A级', '重要事件'),
                ('B', 'B级', '普通事件'),
                ('C', 'C级', '参考事件'),
            )
        ]

        payload = {
            'filter_counts': filter_counts,
            'available_sorts': sort_options,
            'available_return_types': return_type_options,
            'available_importance_levels': importance_options,
        }
        return jsonify({'success': True, 'data': payload})
    except Exception as e:
        app.logger.error(f"获取筛选选项出错: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
@app.route('/api/industry-classifications', methods=['GET'])
def api_get_industry_classifications():
    """List every industry classification system found in ea_sector.

    Response example::

        {
            "success": true,
            "data": [
                {"classification_name": "申万行业分类标准"},
                {"classification_name": "证监会行业分类标准"}
            ]
        }
    """
    try:
        distinct_systems = text("""
            SELECT DISTINCT f002v as classification_name
            FROM ea_sector
            WHERE f002v IS NOT NULL
            ORDER BY f002v
        """)
        rows = db.session.execute(distinct_systems).fetchall()
        payload = []
        for row in rows:
            payload.append({'classification_name': row.classification_name})
        return jsonify({'success': True, 'data': payload})
    except Exception as e:
        app.logger.error(f"获取行业分类体系出错: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
@app.route('/api/industry-level-codes')
def industry_level_codes_api():
    """Return the industries at one level of a classification system.

    Query parameters:
        classification: classification system name (matches ea_sector.f002v).
        level: industry level, 1-4.
        level1_name / level2_name / level3_name: ancestor names, used only when the
            requested level is deeper than that ancestor.

    Returns:
        JSON list of ``{"code", "name"}``, sorted by name, with empty names dropped.
        Invalid arguments give ``[]``; a query failure gives ``{"error": ...}`` with HTTP 500.
    """
    classification = request.args.get('classification')
    level = request.args.get('level', type=int)
    level1_name = request.args.get('level1_name', '')
    level2_name = request.args.get('level2_name', '')
    level3_name = request.args.get('level3_name', '')

    if not (classification and level and 1 <= level <= 4):
        return jsonify([])

    try:
        # Column holding the industry name for each level (1 -> f004v ... 4 -> f007v).
        column = ('f004v', 'f005v', 'f006v', 'f007v')[level - 1]
        # Ancestor constraints: level N is filtered by the names of levels 1..N-1.
        ancestors = (
            ('f004v', 'level1_name', level1_name),
            ('f005v', 'level2_name', level2_name),
            ('f006v', 'level3_name', level3_name),
        )[:level - 1]

        params = {"classification": classification}
        conditions = ["f002v = :classification"]
        for ancestor_column, param_name, param_value in ancestors:
            conditions.append(f"{ancestor_column} = :{param_name}")
            params[param_name] = param_value
        conditions.append(f"{column} IS NOT NULL")
        where_clause = "\n              AND ".join(conditions)

        if level == 4:
            # Leaf level: one row per distinct (code, name) pair.
            sql = f"""
            SELECT DISTINCT f003v as code,
                   {column} as name
            FROM ea_sector
            WHERE {where_clause}
            ORDER BY name
            """
        else:
            # Upper levels: one row per name, represented by its smallest code.
            sql = f"""
            SELECT {column} as name,
                   MIN(f003v) as code
            FROM ea_sector
            WHERE {where_clause}
            GROUP BY name
            ORDER BY name
            """

        rows = db.session.execute(text(sql), params).all()
        return jsonify([{"code": row.code, "name": row.name} for row in rows if row.name])
    except Exception as e:
        app.logger.error(f"获取行业代码出错: {str(e)}")
        return jsonify({"error": str(e)}), 500
@app.route('/trending')
def trending_events():
    """Return the 10 hottest active events created in the last 24 hours (homepage trends).

    Descriptions longer than 100 characters are truncated and suffixed with '...'.
    A missing (None) description is passed through unchanged instead of raising.
    """
    def preview(description, limit=100):
        """Truncate *description* to *limit* characters plus '...'; falsy values pass through."""
        if description and len(description) > limit:
            return description[:limit] + '...'
        return description

    cutoff = beijing_now() - timedelta(days=1)
    recent_events = Event.query.filter(
        Event.created_at >= cutoff,
        Event.status == 'active'
    ).order_by(
        Event.hot_score.desc()
    ).limit(10).all()
    events_data = [{
        'id': event.id,
        'title': event.title,
        'description': preview(event.description),
        'hot_score': event.hot_score,
        'post_count': event.post_count,
        'follower_count': event.follower_count
    } for event in recent_events]
    return jsonify({'events': events_data})
# (minimum count, CSS class), checked from the highest threshold down.
_EVENT_COUNT_CLASSES = (
    (10, 'bg-gradient-danger'),
    (7, 'bg-gradient-warning'),
    (4, 'bg-gradient-info'),
)


def get_event_class(count):
    """Map a day's event count to a calendar badge CSS class.

    10+ -> danger, 7-9 -> warning, 4-6 -> info, anything lower -> success.
    """
    return next(
        (css_class for threshold, css_class in _EVENT_COUNT_CLASSES if count >= threshold),
        'bg-gradient-success',
    )
@app.route('/api/calendar-event-counts')
def get_calendar_event_counts():
    """Return per-day counts of future_events with type 'event' for the current month.

    Returns:
        JSON list of calendar entries ``{'title', 'start', 'className'}``, one per day that
        has events, or ``{'error': ...}`` with HTTP 500 on failure.
    """
    try:
        now = datetime.now()
        # Month bounds at midnight; the upper bound is exclusive. The old version kept the
        # current time of day and used an inclusive BETWEEN, which dropped early hours of
        # day 1 and leaked into the next month.
        start_date = datetime(now.year, now.month, 1)
        if now.month == 12:
            end_date = datetime(now.year + 1, 1, 1)
        else:
            end_date = datetime(now.year, now.month + 1, 1)
        query = """
            SELECT DATE(calendar_time) as date, COUNT(*) as count
            FROM future_events
            WHERE calendar_time >= :start_date
              AND calendar_time < :end_date
              AND type = 'event'
            GROUP BY DATE(calendar_time)
        """
        result = db.session.execute(text(query), {
            'start_date': start_date,
            'end_date': end_date
        })
        # Unpack positionally: on SQLAlchemy 1.4+ `row.count` is the tuple `count` method,
        # not the column labelled 'count'.
        events = [{
            'title': f'{count} 个事件',
            'start': event_date.isoformat() if event_date else None,
            'className': get_event_class(count)
        } for event_date, count in result]
        return jsonify(events)
    except Exception as e:
        app.logger.error(f"Failed to load calendar event counts: {str(e)}", exc_info=True)
        return jsonify({'error': str(e)}), 500
def get_full_avatar_url(avatar_url):
    """Turn a stored avatar reference into an absolute URL.

    Args:
        avatar_url: stored avatar value; may be empty, an absolute http(s) URL, or a path.

    Returns:
        - The site's default avatar URL when *avatar_url* is empty.
        - *avatar_url* unchanged when it already starts with http:// or https://.
        - Otherwise the path joined onto DOMAIN with exactly one '/' between them.
    """
    if not avatar_url:
        return f"{DOMAIN}/static/assets/img/default-avatar.png"
    is_absolute = avatar_url.startswith(('http://', 'https://'))
    if is_absolute:
        return avatar_url
    if not avatar_url.startswith('/'):
        return f"{DOMAIN}/{avatar_url}"
    return f"{DOMAIN}{avatar_url}"
# Modified to_dict method for the User model
def to_dict(self):
    """Serialize a user into a plain dict for API responses.

    The avatar is normalized through get_full_avatar_url, and the created_at / last_seen
    timestamps are ISO-formatted (None when unset). All other fields are copied as-is.
    """
    def iso_or_none(moment):
        return moment.isoformat() if moment else None

    payload = {field: getattr(self, field) for field in ('id', 'username', 'email', 'nickname')}
    payload['avatar_url'] = get_full_avatar_url(self.avatar_url)  # always a complete URL
    for field in ('bio', 'location', 'is_verified', 'user_level', 'reputation_score',
                  'post_count', 'follower_count', 'following_count'):
        payload[field] = getattr(self, field)
    payload['created_at'] = iso_or_none(self.created_at)
    payload['last_seen'] = iso_or_none(self.last_seen)
    return payload
# ==================== Standardized API endpoints ====================
# 1. Home page endpoint
@app.route('/api/home', methods=['GET'])
def api_home():
    """Return the 10 hottest active events of the last 7 days with related-stock statistics.

    For each related stock, the latest ea_trade row provides the close (F007N) and the daily
    change (F010N). The weekly change compares that close with the 5th earlier trading row,
    or the oldest available if fewer exist.
    - Daily averages and maxima cover stocks that have a latest close.
    - Weekly averages and maxima, and valid_stocks_count, cover stocks with a weekly change.
    - Every statistic is 0 when no stock qualifies.

    Returns:
        ``{"code": 200, "message": "success", "data": {"events": [...]}}``, or
        ``{"code": 500, "message": <error>, "data": None}`` with HTTP 500.
    """
    try:
        seven_days_ago = datetime.now() - timedelta(days=7)
        hot_events = Event.query.filter(
            Event.status == 'active',
            Event.created_at >= seven_days_ago
        ).order_by(Event.hot_score.desc()).limit(10).all()

        # Fetch all related stocks in one query instead of one query per event.
        stocks_by_event = defaultdict(list)
        event_ids = [event.id for event in hot_events]
        if event_ids:
            for stock in RelatedStock.query.filter(RelatedStock.event_id.in_(event_ids)).all():
                stocks_by_event[stock.event_id].append(stock)

        # Latest row plus up to 5 earlier rows, in a single query per stock.
        recent_trades_sql = text("""
            SELECT F007N, F010N, TRADEDATE FROM ea_trade
            WHERE SECCODE = :stock_code
            ORDER BY TRADEDATE DESC
            LIMIT 6
        """)

        events_data = []
        for event in hot_events:
            related_stocks = stocks_by_event.get(event.id, [])
            correlations = [float(stock.correlation or 0) for stock in related_stocks]
            avg_correlation = sum(correlations) / len(correlations) if correlations else 0
            max_correlation = max(correlations) if correlations else 0

            stocks_data = []
            daily_values = []
            week_values = []
            for stock in related_stocks:
                stock_code = stock.stock_code.split('.')[0]
                trades = db.session.execute(recent_trades_sql, {"stock_code": stock_code}).fetchall()
                latest_trade = trades[0] if trades else None
                week_change = 0
                daily_change = 0
                if latest_trade and latest_trade.F007N:
                    latest_price = float(latest_trade.F007N)
                    daily_change = float(latest_trade.F010N or 0)  # F010N: daily % change
                    daily_values.append(daily_change)
                    earlier_trades = trades[1:]
                    if earlier_trades and earlier_trades[-1].F007N:
                        week_ago_price = float(earlier_trades[-1].F007N)
                        if week_ago_price > 0:
                            week_change = (latest_price - week_ago_price) / week_ago_price * 100
                            week_values.append(week_change)
                stocks_data.append({
                    "stock_code": stock.stock_code,
                    "stock_name": stock.stock_name,
                    "correlation": float(stock.correlation or 0),
                    "sector": stock.sector,
                    "week_change": round(week_change, 2),
                    "daily_change": round(daily_change, 2),
                    "latest_trade_date": latest_trade.TRADEDATE.strftime("%Y-%m-%d")
                    if latest_trade and latest_trade.TRADEDATE else None
                })

            # Average each metric over the stocks that actually contributed to it, and report
            # the true maximum (which may be negative) rather than flooring it at 0.
            avg_week_change = sum(week_values) / len(week_values) if week_values else 0
            avg_daily_change = sum(daily_values) / len(daily_values) if daily_values else 0
            max_week_change = max(week_values) if week_values else 0
            max_daily_change = max(daily_values) if daily_values else 0

            events_data.append({
                "id": event.id,
                "title": event.title,
                "description": event.description,
                'event_type': event.event_type,
                'importance': event.importance,
                'status': event.status,
                "created_at": event.created_at.strftime("%Y-%m-%d %H:%M:%S"),
                'updated_at': event.updated_at.isoformat() if event.updated_at else None,
                'start_time': event.start_time.isoformat() if event.start_time else None,
                'end_time': event.end_time.isoformat() if event.end_time else None,
                'view_count': event.view_count,
                'post_count': event.post_count,
                'follower_count': event.follower_count,
                "related_stocks": stocks_data,
                "stocks_stats": {
                    "avg_correlation": round(avg_correlation, 2),
                    "max_correlation": round(max_correlation, 2),
                    "stocks_count": len(related_stocks),
                    "valid_stocks_count": len(week_values),
                    # Weekly change statistics
                    "avg_week_change": round(avg_week_change, 2),
                    "max_week_change": round(max_week_change, 2),
                    # Daily change statistics
                    "avg_daily_change": round(avg_daily_change, 2),
                    "max_daily_change": round(max_daily_change, 2)
                }
            })
        return jsonify({
            "code": 200,
            "message": "success",
            "data": {
                "events": events_data
            }
        })
    except Exception as e:
        app.logger.error(f"Error in api_home: {str(e)}", exc_info=True)
        return jsonify({
            "code": 500,
            "message": str(e),
            "data": None
        }), 500
def generate_jwt_token(user_id):
    """Encode a signed JWT that identifies ``user_id``.

    The payload carries ``user_id`` and an ``exp`` claim set
    ``JWT_EXPIRES_SECONDS`` seconds after the current UTC time. It is signed
    with ``JWT_SECRET`` using ``JWT_ALGORITHM``.
    """
    expires_at = datetime.utcnow() + timedelta(seconds=JWT_EXPIRES_SECONDS)
    claims = {'user_id': user_id, 'exp': expires_at}
    return jwt.encode(claims, JWT_SECRET, algorithm=JWT_ALGORITHM)
def save_sms_code_to_redis(phone, code, purpose='login', expire=600):
    """Store an SMS verification code in Redis with a time-to-live.

    The value is written under ``sms:<purpose>:<phone>`` and expires after
    ``expire`` seconds (600 by default).
    """
    redis_client.setex(f"sms:{purpose}:{phone}", expire, code)
def verify_sms_code(phone, code, purpose='login'):
    """Check ``code`` against the SMS code stored for ``phone`` and ``purpose``.

    Reads the key written by ``save_sms_code_to_redis`` (``sms:<purpose>:<phone>``).

    The stored value is compared as text, so the check works in all of these cases:
    - the Redis client returns ``str`` (``decode_responses=True``);
    - the Redis client returns ``bytes``;
    - the caller passes the code as a string;
    - the caller passes the code as a number.

    Returns:
        True only when a code is stored and it equals ``code``.
    """
    key = f"sms:{purpose}:{phone}"
    stored_code = redis_client.get(key)
    if stored_code is None or code is None:
        # Nothing stored (never sent / expired) or nothing supplied.
        return False
    if isinstance(stored_code, bytes):
        stored_code = stored_code.decode('utf-8', errors='replace')
    expected = str(stored_code).strip().encode('utf-8')
    supplied = str(code).strip().encode('utf-8')
    # Constant-time comparison so response timing does not reveal how many
    # leading characters matched. Bytes avoid compare_digest's ASCII-only str rule.
    return hmac.compare_digest(expected, supplied)
@app.route('/api/auth/send-sms', methods=['POST'])
def api_send_sms_code():
    """Send an SMS verification code.

    JSON body:
        phone:   mainland-China mobile number (11 digits starting with 1[3-9]).
        purpose: what the code is for (default ``'login'``). For ``'login'``
                 the number must already belong to a registered user.

    Responses:
        400: invalid number.
        429: a code was sent to this number within the last 60 seconds.
        404: a login code was requested for an unregistered number.
        500: the SMS sender reported a failure.
        200: success. In debug mode the code is echoed back as ``debug_code``.
    """
    # silent=True plus the dict check: a missing or non-object JSON body is
    # treated as empty input (-> 400) instead of crashing on ``None.get``.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    phone = data.get('phone')
    purpose = data.get('purpose', 'login')
    # fullmatch + [0-9] rejects a trailing newline, which '^...$' with
    # re.match accepted. It also rejects non-ASCII digits, which ``\d`` matched.
    if not isinstance(phone, str) or not re.fullmatch(r'1[3-9][0-9]{9}', phone):
        return jsonify({'code': 400, 'message': '请输入正确手机号'}), 400
    # Rate limit: at most one code per phone number every 60 seconds.
    freq_key = f"sms:freq:{phone}"
    if redis_client.exists(freq_key):
        return jsonify({'code': 429, 'message': '请勿频繁请求验证码'}), 429
    # Login codes are only sent to numbers that belong to an existing user.
    user = User.query.filter_by(phone=phone).first()
    if purpose == 'login' and not user:
        return jsonify({'code': 404, 'message': '该手机号尚未注册'}), 404
    # The sender returns (success, message, generated_code).
    success, msg, verification_code = send_sms_verification_minimal(phone, redis_client)
    if not success:
        return jsonify({'code': 500, 'message': msg}), 500
    # Persist the code under the key that verify_sms_code() reads.
    save_sms_code_to_redis(phone, verification_code, purpose)
    redis_client.setex(freq_key, 60, '1')
    response_data = {'code': 200, 'message': '验证码已发送'}
    if app.debug:
        # Only expose the code itself when running in debug mode.
        response_data['debug_code'] = verification_code
    return jsonify(response_data)
@app.route('/api/auth/login/phone', methods=['POST'])
def login_by_phone():
    """Log in with a phone number and an SMS verification code.

    JSON body: ``phone`` and ``code`` (a code sent with purpose ``'login'``).

    Responses:
        200: success, with a JWT and ``user.to_dict()``.
        400: a field is missing, or the code is wrong or expired.
        404: no user owns the phone number.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    phone = data.get('phone')
    code = data.get('code')
    if not phone or not code:
        return jsonify({'code': 400, 'message': '手机号和验证码不能为空'}), 400
    # Codes are stored as text; accept numeric JSON values such as 123456 too.
    code = str(code).strip()
    if not verify_sms_code(phone, code):
        return jsonify({'code': 400, 'message': '验证码错误或已过期'}), 400
    user = User.query.filter_by(phone=phone).first()
    if not user:
        return jsonify({'code': 404, 'message': '该手机号尚未注册'}), 404
    # One-time use: delete the code so it cannot be replayed within its TTL.
    # The key format matches save_sms_code_to_redis with purpose 'login'.
    redis_client.delete(f"sms:login:{phone}")
    user.update_last_seen()
    db.session.commit()
    token = generate_jwt_token(user.id)
    return jsonify({
        'code': 200,
        'message': '登录成功',
        'data': {
            'token': token,
            'user': user.to_dict()
        }
    })
@app.route('/api/auth/login/wechat', methods=['POST'])
def api_login_wechat():
    """Log in, or sign up, with a WeChat identity.

    JSON body (at least one field is required):
        code:    WeChat OAuth authorization code. It is exchanged server-side
                 at ``sns/oauth2/access_token`` for the user's unionid, falling
                 back to openid.
        unionId: WeChat UnionID supplied directly by the client.

    When ``code`` is present, the identity returned by WeChat takes precedence
    over any client-supplied ``unionId``. A user row is created on first login.

    Returns a JWT and a minimal profile. Responds 400 for missing input or a
    rejected code, and 500 on unexpected errors.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        code = data.get('code')  # WeChat authorization code
        union_id = data.get('unionId')  # WeChat UnionID sent by the client
        if not code and not union_id:
            return jsonify({
                'code': 400,
                'message': '缺少必要的参数',
                'data': None
            }), 400

        if code:
            # Exchange the code for the user's identity. Previously the URL was
            # built but never requested, so a code-only request left user=None
            # and crashed on ``user.id``.
            token_response = requests.get(
                "https://api.weixin.qq.com/sns/oauth2/access_token",
                params={
                    'appid': app.config.get('WECHAT_APP_ID'),
                    'secret': app.config.get('WECHAT_APP_SECRET'),
                    'code': code,
                    'grant_type': 'authorization_code',
                },
                timeout=10,
            )
            token_data = token_response.json()
            if 'errcode' in token_data:
                return jsonify({
                    'code': 400,
                    'message': f"微信授权失败: {token_data.get('errmsg', '未知错误')}",
                    'data': None
                }), 400
            verified_id = token_data.get('unionid') or token_data.get('openid')
            if not verified_id:
                return jsonify({
                    'code': 400,
                    'message': '缺少必要的参数',
                    'data': None
                }), 400
            # Same keying as /api/wechat/callback: unionid, else openid.
            union_id = verified_id
        # SECURITY NOTE(review): when only ``unionId`` is sent, it is trusted as-is,
        # so any caller who knows (or guesses) a UnionID can log in as that user.
        # This path should be removed once all clients send ``code``.

        user = User.query.filter_by(wechat_union_id=union_id).first()
        if not user:
            # NOTE(review): second-resolution usernames can collide under
            # concurrent sign-ups; confirm whether ``username`` is unique.
            username = f"wx_user_{int(time.time())}"
            user = User(
                username=username,
                wechat_union_id=union_id,
                status='active'
            )
            db.session.add(user)
            db.session.commit()

        token = generate_jwt_token(user.id)
        # Guard against a missing created_at rather than failing the comparison.
        is_new_user = bool(
            user.created_at and user.created_at > (datetime.now() - timedelta(minutes=1))
        )
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'token': token,
                'user': {
                    'id': user.id,
                    'username': user.username,
                    'nickname': user.nickname,
                    'avatar_url': get_full_avatar_url(user.avatar_url),
                    'is_new_user': is_new_user
                }
            }
        })
    except Exception as e:
        # Leave the scoped session usable for the next request.
        db.session.rollback()
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
# Get the WeChat login QR-code URL (open-platform qrconnect flow)
@app.route('/api/wechat/qrcode', methods=['GET'])
def get_wechat_qrcode():
    """Return a WeChat ``qrconnect`` URL plus the ``state`` used to track it.

    The state is recorded in the session as pending. WeChat later redirects to
    ``/api/wechat/callback`` with the same state. Responds 500 when
    ``WECHAT_APP_ID`` is not configured.
    """
    try:
        state = str(uuid.uuid4())
        app_id = app.config.get('WECHAT_APP_ID')
        if not app_id:
            return jsonify({
                'code': 500,
                'message': '微信配置未设置',
                'data': None
            }), 500
        # Store the same dict shape as /api/wechat/generate-login-qr so that
        # /api/wechat/check-login can read ``status`` (a bare True made it crash).
        session[f"wx_state_{state}"] = {
            'created_at': datetime.now(),
            'status': 'pending'
        }
        base_url = app.config.get('BASE_URL', 'http://43.143.189.195:5002')
        # WeChat requires redirect_uri to be fully URL-encoded.
        redirect_uri = quote(f"{base_url}/api/wechat/callback", safe='')
        qrcode_url = (
            "https://open.weixin.qq.com/connect/qrconnect"
            f"?appid={app_id}&redirect_uri={redirect_uri}"
            f"&response_type=code&scope=snsapi_login&state={state}#wechat_redirect"
        )
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'qr_url': qrcode_url,  # key name expected by the frontend
                'state': state
            }
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': f'生成二维码失败: {str(e)}',
            'data': None
        }), 500
# Generate a WeChat login QR code (newer flow: the QR encodes /wechat-auth)
@app.route('/api/wechat/generate-login-qr', methods=['POST'])
def generate_wechat_login_qr():
    """Create a login ``state`` and the URL to encode in the QR code.

    The state is recorded in the session as ``pending``. Scanning the QR code
    opens ``/wechat-auth?state=...``, which redirects into WeChat's OAuth flow.
    """
    try:
        login_state = str(uuid.uuid4())
        session[f"wx_state_{login_state}"] = {
            'created_at': datetime.now(),
            'status': 'pending',
        }
        server_root = app.config.get('BASE_URL', 'http://43.143.189.195:5002')
        payload = {
            'login_url': f"{server_root}/wechat-auth?state={login_state}",
            'state': login_state,
        }
        return jsonify({'code': 200, 'message': 'success', 'data': payload})
    except Exception as exc:
        return jsonify({'code': 500, 'message': str(exc), 'data': None}), 500
# WeChat authorization entry page: the URL encoded in the QR code
@app.route('/wechat-auth')
def wechat_auth_page():
    """Validate ``state`` and redirect the scanning browser into WeChat OAuth.

    Renders an "invalid link" page when ``state`` is missing or unknown to the
    current session. Otherwise redirects to WeChat's web authorization
    endpoint (``snsapi_userinfo`` scope), with ``/api/wechat/callback`` as the
    redirect target.
    """
    state = request.args.get('state')
    # NOTE(review): the state was written to the session of the browser that
    # requested the QR code. If the QR is scanned on another device (e.g. a
    # phone), that session is not shared and this lookup fails. A shared store
    # such as Redis is likely needed; confirm against the deployment.
    if not state or not session.get(f"wx_state_{state}"):
        return render_template_string("""
<html>
<head><title>登录失败</title><meta charset="utf-8"></head>
<body style="text-align: center; padding: 50px;">
    <h2>登录链接无效</h2>
    <p>请重新扫描二维码</p>
</body>
</html>
""")
    app_id = app.config['WECHAT_APP_ID']
    base_url = app.config.get('BASE_URL', 'http://43.143.189.195:5002')
    # WeChat expects redirect_uri fully URL-encoded, including ':' and '/'.
    # state is encoded too, so it cannot inject extra query parameters.
    redirect_uri = quote(f"{base_url}/api/wechat/callback", safe='')
    encoded_state = quote(state, safe='')
    auth_url = (
        "https://open.weixin.qq.com/connect/oauth2/authorize"
        f"?appid={app_id}&redirect_uri={redirect_uri}"
        f"&response_type=code&scope=snsapi_userinfo&state={encoded_state}"
        "#wechat_redirect"
    )
    return redirect(auth_url)
# WeChat OAuth callback: exchanges the code and logs the user in
_WECHAT_ERROR_PAGE_TEMPLATE = """
<html>
<head><title>登录失败</title><meta charset="utf-8"></head>
<body style="text-align: center; padding: 50px;">
    <h2>登录失败</h2>
    <p>{{ detail }}</p>
    <button onclick="window.close()">关闭</button>
</body>
</html>
"""


def _render_wechat_error_page(detail):
    """Render the WeChat-login failure page, showing ``detail`` HTML-escaped.

    ``detail`` is passed as a template variable rather than interpolated into
    the template source. Text that comes from WeChat or from exception
    messages is therefore escaped by Jinja and can never run as template code.
    """
    return render_template_string(_WECHAT_ERROR_PAGE_TEMPLATE, detail=detail)


@app.route('/api/wechat/callback')
def wechat_callback():
    """Handle WeChat's OAuth redirect and complete the login.

    Steps:
    1. Check ``code`` and ``state`` against the session.
    2. Exchange the code for an access token.
    3. Fetch the WeChat profile.
    4. Find the user by unionid (falling back to openid), or create one.
    5. Issue a JWT and store the result in the session under
       ``wx_state_<state>`` for the check-login poll.
    6. Render a page that hands the token to the opener window, or to
       localStorage when there is no opener.
    """
    try:
        code = request.args.get('code')
        state = request.args.get('state')
        app.logger.info(f"微信回调: code={code}, state={state}")
        if not code or not state:
            return _render_wechat_error_page('参数错误,请重新尝试')
        state_key = f"wx_state_{state}"
        if not session.get(state_key):
            return _render_wechat_error_page('无效的登录状态')

        # 1. Exchange the authorization code for an access token.
        token_params = {
            'appid': app.config['WECHAT_APP_ID'],
            'secret': app.config['WECHAT_APP_SECRET'],
            'code': code,
            'grant_type': 'authorization_code'
        }
        # Never write the app secret to the logs.
        app.logger.info(f"请求微信token: {dict(token_params, secret='***')}")
        response = requests.get(
            "https://api.weixin.qq.com/sns/oauth2/access_token",
            params=token_params,
            timeout=10,
        )
        token_data = response.json()
        # Log the response without the bearer credentials it carries.
        loggable_token_data = {
            k: v for k, v in token_data.items()
            if k not in ('access_token', 'refresh_token')
        }
        app.logger.info(f"微信token响应: {loggable_token_data}")
        if 'errcode' in token_data:
            error_msg = token_data.get('errmsg', '未知错误')
            app.logger.error(f"获取微信token失败: {error_msg}")
            return _render_wechat_error_page(f"获取授权失败: {error_msg}")
        access_token = token_data['access_token']
        openid = token_data['openid']

        # 2. Fetch the WeChat user profile.
        user_response = requests.get(
            "https://api.weixin.qq.com/sns/userinfo",
            params={'access_token': access_token, 'openid': openid, 'lang': 'zh_CN'},
            timeout=10,
        )
        user_info = user_response.json()
        app.logger.info(f"微信用户信息: {user_info}")
        if 'errcode' in user_info:
            error_msg = user_info.get('errmsg', '未知错误')
            app.logger.error(f"获取微信用户信息失败: {error_msg}")
            return _render_wechat_error_page(f"获取用户信息失败: {error_msg}")

        # 3. Find or create the local user, keyed by unionid (else openid).
        unionid = user_info.get('unionid', openid)
        user = User.query.filter_by(wechat_union_id=unionid).first()
        if not user:
            username = f"wx_user_{int(time.time())}"
            user = User(
                username=username,
                email=f"wx_{openid}@temp.com",
                password_hash=generate_password_hash(''),  # no usable password
                wechat_union_id=unionid,
                wechat_open_id=openid,
                nickname=user_info.get('nickname', username),
                avatar_url=get_full_avatar_url(user_info.get('headimgurl')),
                # NOTE(review): WeChat's ``sex`` is 0 when unknown; that maps to
                # 'female' here. Confirm the intended mapping and column constraints.
                gender='male' if user_info.get('sex') == 1 else 'female',
                status='active',
                email_confirmed=True  # WeChat users are treated as verified
            )
            db.session.add(user)
            db.session.commit()
            app.logger.info(f"创建新用户: {user.username}")

        # 4. Issue the JWT and record activity.
        token = generate_jwt_token(user.id)
        user.update_last_seen()
        db.session.commit()

        # 5. Publish the result for the check-login poll.
        is_new_user = bool(
            user.created_at and user.created_at > (datetime.now() - timedelta(minutes=1))
        )
        login_info = {
            'status': 'completed',
            'token': token,
            'user': {
                'id': user.id,
                'username': user.username,
                'nickname': user.nickname,
                'avatar_url': get_full_avatar_url(user.avatar_url),
                'is_new_user': is_new_user
            }
        }
        session[state_key] = login_info
        app.logger.info(f"微信登录成功: {user.username}")

        # 6. Success page. ``tojson`` emits HTML-safe JSON, so profile fields
        # such as the nickname cannot break out of the <script> block.
        # NOTE(review): postMessage targets '*', so any opener origin receives
        # the token. Consider restricting it to the frontend's origin.
        return render_template_string("""
<html>
<head>
    <title>登录成功</title>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
</head>
<body style="text-align: center; padding: 50px; font-family: Arial, sans-serif;">
    <div style="max-width: 400px; margin: 0 auto;">
        <div style="color: #07c160; font-size: 48px; margin-bottom: 20px;">✓</div>
        <h2 style="color: #333;">登录成功</h2>
        <p style="color: #666;">欢迎 {{ user.nickname }}</p>
        <p style="color: #999; font-size: 14px;">正在跳转到首页...</p>
    </div>
    <script>
        var token = {{ token | tojson }};
        var userData = {{ user_data | tojson }};
        if (window.opener) {
            // Opened as a popup: hand the result to the parent window.
            window.opener.postMessage({
                type: 'wechat_login_success',
                token: token,
                user: userData
            }, '*');
            window.close();
        } else {
            // Not a popup: persist the login and go to the home page.
            localStorage.setItem('token', token);
            localStorage.setItem('user', JSON.stringify(userData));
            setTimeout(function() {
                window.location.href = '/';
            }, 2000);
        }
    </script>
</body>
</html>
""", user=user, token=token, user_data=login_info['user'])
    except Exception as e:
        db.session.rollback()
        app.logger.error(f"微信回调错误: {str(e)}")
        return _render_wechat_error_page(f"系统错误: {str(e)}"), 500
# Check-login poll endpoint
@app.route('/api/wechat/check-login', methods=['POST'])
def check_wechat_login():
    """Report whether the WeChat login identified by ``state`` has completed.

    JSON body: ``state``.

    Responses:
        200 with data status 'authorized': the callback has stored a completed
            login. The response carries the token and user, and the state is
            cleared from the session.
        JSON code 202 (HTTP 200) with status 'pending': still waiting.
        400: the state is missing or unknown.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        state = data.get('state')
        if not state:
            return jsonify({'code': 400, 'message': '参数错误'}), 400
        state_key = f"wx_state_{state}"
        state_info = session.get(state_key)
        if not state_info:
            return jsonify({'code': 400, 'message': '无效状态'}), 400
        # Some producers stored a bare ``True`` here. Treat any non-dict value
        # as still pending instead of crashing on ``.get``.
        if isinstance(state_info, dict) and state_info.get('status') == 'completed':
            session.pop(state_key, None)
            return jsonify({
                'code': 200,
                'message': 'success',
                'data': {
                    'status': 'authorized',
                    'token': state_info['token'],
                    'user': state_info['user']
                }
            })
        return jsonify({
            'code': 202,
            'message': '等待扫码',
            'data': {'status': 'pending'}
        })
    except Exception as e:
        return jsonify({'code': 500, 'message': str(e)}), 500
@app.route('/api/auth/login/email', methods=['POST'])
def api_login_email():
    """Log in with email and password.

    JSON body: ``email`` and ``password`` (both non-empty strings).

    On success, issues a JWT, also starts a Flask-Login session, and updates
    the user's last-seen time. Responds 400 for missing or invalid credentials
    and 500 on unexpected errors.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        email = data.get('email')
        password = data.get('password')
        # Non-string values (e.g. objects) are treated as missing rather than
        # being passed into the query or the password check.
        if not isinstance(email, str) or not isinstance(password, str) \
                or not email or not password:
            return jsonify({
                'code': 400,
                'message': '邮箱和密码不能为空',
                'data': None
            }), 400
        user = User.query.filter_by(email=email).first()
        if not user or not user.check_password(password):
            return jsonify({
                'code': 400,
                'message': '邮箱或密码错误',
                'data': None
            }), 400
        token = generate_jwt_token(user.id)
        login_user(user)
        user.update_last_seen()
        db.session.commit()
        return jsonify({
            'code': 200,
            'message': '登录成功',
            'data': {
                'token': token,
                'user_id': user.id,
                'username': user.username,
                'email': user.email,
                'is_verified': user.is_verified,
                'user_level': user.user_level
            }
        })
    except Exception as e:
        # Leave the scoped session usable after a failed commit or query.
        db.session.rollback()
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
@app.route('/api/all-industry-data')
def get_all_industry_data():
    """Return all industry-classification rows from ``ea_sector``.

    Each item has ``classification_name``, ``code`` and ``level1``..``level4``.
    Rows whose classification system is in the excluded list below (index
    constituents, market, concept board, region and one association scheme)
    are filtered out. Rows are distinct and ordered by code.
    """
    try:
        query = """
            SELECT DISTINCT
                f002v as classification_name,
                f003v as code,
                f004v as level1,
                f005v as level2,
                f006v as level3,
                f007v as level4
            FROM ea_sector
            WHERE f002v NOT IN ('指数成份股', '市场分类', '概念板块', '地区省市分类', '中上协行业分类')
            ORDER BY f003v
        """
        with engine.connect() as conn:
            result = conn.execute(text(query))
            # ``dict(row)`` only works on SQLAlchemy 1.x legacy rows. The
            # mappings() view yields dict-like rows on both 1.4 and 2.0.
            data = [dict(row) for row in result.mappings()]
        return jsonify({
            "code": 200,
            "message": "success",
            "data": data
        })
    except Exception as e:
        app.logger.error(f"获取行业数据出错: {str(e)}")
        return jsonify({
            "code": 500,
            "message": str(e),
            "data": None
        }), 500
# 5. Event detail: related stocks
@app.route('/api/event/<int:event_id>/related-stocks-detail', methods=['GET'])
def api_event_related_stocks(event_id):
    """Return an event's related stocks with latest quotes and intraday bars.

    Each of the event's related stocks is returned in order of correlation,
    highest first. For each one the response includes:
    - basic listing info;
    - the latest daily trade row;
    - the change versus the previous trading day;
    - minute bars from ClickHouse: today's session, or the most recent date
      that has data when today has none.

    Responds 404 when the event does not exist and 500 on unexpected errors.
    """
    from datetime import datetime, time as dt_time

    def to_float(value):
        """Return float(value), keeping None as None; 0 stays 0.0."""
        return float(value) if value is not None else None

    try:
        event = Event.query.get(event_id)
        if event is None:
            # Explicit JSON 404: get_or_404 inside this try was reported as a 500.
            return jsonify({'code': 404, 'message': '事件不存在', 'data': None}), 404
        related_stocks = event.related_stocks.order_by(RelatedStock.correlation.desc()).all()
        client = get_clickhouse_client()

        def get_minute_chart_data(stock_code):
            """Return minute bars for ``stock_code``; ``[]`` on any error."""
            try:
                today = datetime.now().date()
                data = client.execute("""
                    SELECT timestamp, open, high, low, close, volume, amt
                    FROM stock_minute
                    WHERE code = %(code)s
                      AND timestamp >= %(start)s
                      AND timestamp <= %(end)s
                    ORDER BY timestamp
                """, {
                    'code': stock_code,
                    'start': datetime.combine(today, dt_time(9, 30)),
                    'end': datetime.combine(today, dt_time(15, 0))
                })
                if not data:
                    # No bars today: use the whole latest date with data. A
                    # "max - 1 day" window would also pull in the previous
                    # day's closing bar and break the HH:MM series.
                    data = client.execute("""
                        SELECT timestamp, open, high, low, close, volume, amt
                        FROM stock_minute
                        WHERE code = %(code)s
                          AND toDate(timestamp) = (
                              SELECT toDate(max(timestamp))
                              FROM stock_minute
                              WHERE code = %(code)s
                          )
                        ORDER BY timestamp
                    """, {'code': stock_code})
                # ``is not None`` conversion: zero volume or amount is real data.
                return [
                    {
                        'time': row[0].strftime('%H:%M'),
                        'open': to_float(row[1]),
                        'high': to_float(row[2]),
                        'low': to_float(row[3]),
                        'close': to_float(row[4]),
                        'volume': to_float(row[5]),
                        'amount': to_float(row[6])
                    }
                    for row in data
                ]
            except Exception as e:
                app.logger.warning(f"Error fetching minute data for {stock_code}: {e}")
                return []

        stocks_data = []
        for stock in related_stocks:
            base_code = stock.stock_code.split('.')[0]
            # Basic info: try the full code first, then the code without suffix.
            stock_info = StockBasicInfo.query.filter_by(SECCODE=stock.stock_code).first()
            if not stock_info:
                stock_info = StockBasicInfo.query.filter_by(SECCODE=base_code).first()

            # Latest trade row, using the same flexible code matching.
            latest_trade = None
            for code in (stock.stock_code, base_code):
                latest_trade = TradeData.query.filter_by(SECCODE=code) \
                    .order_by(TradeData.TRADEDATE.desc()).first()
                if latest_trade:
                    break

            # Previous trading day's row, used for the change calculation.
            prev_trade = None
            if latest_trade:
                prev_trade = TradeData.query.filter_by(SECCODE=latest_trade.SECCODE) \
                    .filter(TradeData.TRADEDATE < latest_trade.TRADEDATE) \
                    .order_by(TradeData.TRADEDATE.desc()).first()

            # Prefer computing the change from the two closes. Without a previous
            # row, fall back to the stored change fields (F010N pct, F009N amount).
            change_pct = None
            change_amount = None
            if latest_trade and prev_trade:
                if prev_trade.F007N and latest_trade.F007N is not None:
                    change_amount = float(latest_trade.F007N) - float(prev_trade.F007N)
                    change_pct = (change_amount / float(prev_trade.F007N)) * 100
            elif latest_trade and latest_trade.F010N is not None:
                change_pct = float(latest_trade.F010N)
                change_amount = to_float(latest_trade.F009N)

            stocks_data.append({
                'id': stock.id,
                'stock_code': stock.stock_code,
                'stock_name': stock.stock_name,
                'sector': stock.sector,
                'relation_desc': stock.relation_desc,
                'correlation': stock.correlation,
                'momentum': stock.momentum,
                'listing_date': stock_info.F006D.isoformat() if stock_info and stock_info.F006D else None,
                'market': stock_info.F005V if stock_info else None,
                'trade_data': {
                    'latest_price': to_float(latest_trade.F007N),
                    'open_price': to_float(latest_trade.F003N),
                    'high_price': to_float(latest_trade.F005N),
                    'low_price': to_float(latest_trade.F006N),
                    'prev_close': to_float(latest_trade.F002N),
                    'change_amount': change_amount,
                    'change_pct': round(change_pct, 2) if change_pct is not None else None,
                    'volume': to_float(latest_trade.F004N),
                    'amount': to_float(latest_trade.F011N),
                    'trade_date': latest_trade.TRADEDATE.isoformat(),
                } if latest_trade else None,
                'minute_chart_data': get_minute_chart_data(stock.stock_code),
                'charts': {
                    'minute_chart_url': f"/api/stock/{stock.stock_code}/minute-chart",
                    'daily_chart_url': f"/api/stock/{stock.stock_code}/kline",
                }
            })

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'event_id': event_id,
                'event_title': event.title,
                'event_desc': event.description,
                'event_type': event.event_type,
                'event_importance': event.importance,
                'event_status': event.status,
                'event_created_at': event.created_at.strftime("%Y-%m-%d %H:%M:%S"),
                'event_start_time': event.start_time.isoformat() if event.start_time else None,
                'event_end_time': event.end_time.isoformat() if event.end_time else None,
                'keywords': event.keywords,
                'view_count': event.view_count,
                'post_count': event.post_count,
                'follower_count': event.follower_count,
                'related_stocks': stocks_data,
                'total_count': len(stocks_data)
            }
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
@app.route('/api/stock/<stock_code>/minute-chart', methods=['GET'])
def get_minute_chart_data(stock_code):
    """Return minute bars for ``stock_code`` as a list of dicts.

    Serves ``/api/stock/<stock_code>/minute-chart`` and is also called directly
    as a helper that returns the plain list.

    Bars cover today's 09:30-15:00 session. When today has no bars, all bars
    of the most recent date that has data are returned. Each bar has
    ``time`` (HH:MM), ``open``, ``high``, ``low``, ``close``, ``volume`` and
    ``amount``. Any failure is logged and yields ``[]``.
    """
    from datetime import datetime, time as dt_time

    def to_float(value):
        # Keep 0 as 0.0: zero volume or amount is real data, not missing data.
        return float(value) if value is not None else None

    try:
        # Created inside ``try`` so that connection failures also yield [].
        client = get_clickhouse_client()
        today = datetime.now().date()
        data = client.execute("""
            SELECT timestamp, open, high, low, close, volume, amt
            FROM stock_minute
            WHERE code = %(code)s
              AND timestamp >= %(start)s
              AND timestamp <= %(end)s
            ORDER BY timestamp
        """, {
            'code': stock_code,
            'start': datetime.combine(today, dt_time(9, 30)),
            'end': datetime.combine(today, dt_time(15, 0))
        })
        if not data:
            # Restrict the fallback to the latest date with data. The former
            # "max - 1 day" window also included the previous day's final bar.
            data = client.execute("""
                SELECT timestamp, open, high, low, close, volume, amt
                FROM stock_minute
                WHERE code = %(code)s
                  AND toDate(timestamp) = (
                      SELECT toDate(max(timestamp))
                      FROM stock_minute
                      WHERE code = %(code)s
                  )
                ORDER BY timestamp
            """, {'code': stock_code})
        return [
            {
                'time': row[0].strftime('%H:%M'),
                'open': to_float(row[1]),
                'high': to_float(row[2]),
                'low': to_float(row[3]),
                'close': to_float(row[4]),
                'volume': to_float(row[5]),
                'amount': to_float(row[6])
            }
            for row in data
        ]
    except Exception as e:
        app.logger.warning(f"Error getting minute chart data: {e}")
        return []
# 6. Event detail: single-stock detail (enhanced)
@app.route('/api/event/<int:event_id>/stock/<stock_code>/detail', methods=['GET'])
def api_stock_detail(event_id, stock_code):
    """Return basic, company, latest-trade and intraday data for one stock.

    Query parameters:
        include_minute_data: ``'true'`` (default) to include minute bars.

    The stock's basic info is looked up by exact code, then by code prefix,
    then by the code without its exchange suffix (the part before the first '.').
    ``related_desc`` describes the stock's relation to this event when it is
    one of the event's related stocks.

    Responds 404 when the event or the stock does not exist, and 500 on
    unexpected errors.
    """
    def to_float(value):
        # 0 is a valid change or volume; only None means "missing".
        return float(value) if value is not None else None

    try:
        event = Event.query.get(event_id)
        if event is None:
            # Explicit JSON 404: get_or_404 inside this try was reported as a 500.
            return jsonify({'code': 404, 'message': '事件不存在', 'data': None}), 404
        include_minute_data = request.args.get('include_minute_data', 'true').lower() == 'true'
        base_code = stock_code.split('.')[0]  # drop the exchange suffix

        # Basic info, from the most to the least specific match.
        basic_info = StockBasicInfo.query.filter_by(SECCODE=stock_code).first()
        if not basic_info:
            basic_info = StockBasicInfo.query.filter(
                StockBasicInfo.SECCODE.ilike(f"{stock_code}%")
            ).first()
        if not basic_info:
            basic_info = StockBasicInfo.query.filter(
                StockBasicInfo.SECCODE.ilike(f"{base_code}%")
            ).first()
        if not basic_info:
            return jsonify({
                'code': 404,
                'stock_code': stock_code,
                'message': '股票不存在',
                'data': None
            }), 404

        company_info = CompanyInfo.query.filter_by(SECCODE=stock_code).first()
        if not company_info:
            company_info = CompanyInfo.query.filter_by(SECCODE=base_code).first()

        latest_trade = TradeData.query.filter_by(SECCODE=stock_code) \
            .order_by(TradeData.TRADEDATE.desc()).first()
        if not latest_trade:
            latest_trade = TradeData.query.filter_by(SECCODE=base_code) \
                .order_by(TradeData.TRADEDATE.desc()).first()

        minute_chart_data = get_minute_chart_data(stock_code) if include_minute_data else []

        # The event's own description of this stock, matched by full code,
        # base code, or base code with any suffix.
        related_stock = RelatedStock.query.filter_by(
            event_id=event_id
        ).filter(
            db.or_(
                RelatedStock.stock_code == stock_code,
                RelatedStock.stock_code == base_code,
                RelatedStock.stock_code.like(f"{base_code}.%")
            )
        ).first()
        related_desc = None
        if related_stock:
            related_desc = {
                'event_id': related_stock.event_id,
                'relation_desc': related_stock.relation_desc,
                'sector': related_stock.sector,
                'correlation': to_float(related_stock.correlation),
                'momentum': related_stock.momentum
            }

        response_data = {
            'code': 200,
            'message': 'success',
            'data': {
                'event_info': {
                    'event_id': event.id,
                    'event_title': event.title,
                    'event_description': event.description
                },
                'basic_info': {
                    'stock_code': basic_info.SECCODE,
                    'stock_name': basic_info.SECNAME,
                    'org_name': basic_info.ORGNAME,
                    'pinyin': basic_info.F001V,
                    'category': basic_info.F003V,
                    'market': basic_info.F005V,
                    'listing_date': basic_info.F006D.isoformat() if basic_info.F006D else None,
                    'status': basic_info.F011V
                },
                'company_info': {
                    'english_name': company_info.F001V if company_info else None,
                    'legal_representative': company_info.F003V if company_info else None,
                    'main_business': company_info.F015V if company_info else None,
                    'business_scope': company_info.F016V if company_info else None,
                    'company_intro': company_info.F017V if company_info else None,
                    'csrc_industry_l1': company_info.F030V if company_info else None,
                    'csrc_industry_l2': company_info.F032V if company_info else None
                },
                'latest_trade': {
                    'trade_date': latest_trade.TRADEDATE.isoformat(),
                    'close_price': to_float(latest_trade.F007N),
                    'change': to_float(latest_trade.F009N),
                    'change_pct': to_float(latest_trade.F010N),
                    'volume': to_float(latest_trade.F004N),
                    'amount': to_float(latest_trade.F011N)
                } if latest_trade else None,
                'minute_chart_data': minute_chart_data,
                'related_desc': related_desc
            }
        }
        response = jsonify(response_data)
        response.headers['Content-Type'] = 'application/json; charset=utf-8'
        return response
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
def get_stock_minute_chart_data(stock_code):
    """Return minute bars of the latest trading day for ``stock_code``.

    Trading days are read from ``tdays.csv`` (column ``DateTime``, format
    ``%Y/%m/%d``). The latest trading day on or before today is chosen, and its
    09:30-15:00 bars are fetched from ClickHouse. Returns ``[]`` when no
    trading day qualifies or on any error (the error is printed).
    """
    from datetime import datetime, timedelta, time as dt_time
    import csv

    def to_float(value):
        # A single NULL field must not discard the whole series.
        return float(value) if value is not None else None

    try:
        # NOTE(review): path is relative to the process working directory.
        with open('tdays.csv', 'r', newline='') as f:
            trading_days = {
                datetime.strptime(row['DateTime'], '%Y/%m/%d').date()
                for row in csv.DictReader(f)
            }
        if not trading_days:
            return []

        # Walk back from today to the latest trading day. min() is computed
        # once instead of on every step.
        earliest = min(trading_days)
        target_date = datetime.now().date()
        while target_date >= earliest and target_date not in trading_days:
            target_date -= timedelta(days=1)
        if target_date < earliest:
            return []

        client = get_clickhouse_client()
        data = client.execute("""
            SELECT timestamp, open, high, low, close, volume, amt
            FROM stock_minute
            WHERE code = %(code)s
              AND timestamp BETWEEN %(start)s AND %(end)s
            ORDER BY timestamp
        """, {
            'code': stock_code,
            'start': datetime.combine(target_date, dt_time(9, 30)),
            'end': datetime.combine(target_date, dt_time(15, 0))
        })
        return [
            {
                'time': row[0].strftime('%H:%M'),
                'open': to_float(row[1]),
                'high': to_float(row[2]),
                'low': to_float(row[3]),
                'close': to_float(row[4]),
                'volume': to_float(row[5]),
                'amount': to_float(row[6])
            }
            for row in data
        ]
    except Exception as e:
        print(f"Error getting minute chart data: {e}")
        return []
# 7. Event detail: related concepts
@app.route('/api/event/<int:event_id>/related-concepts', methods=['GET'])
def api_event_related_concepts(event_id):
    """List the concepts linked to an event, with absolute image URLs.

    Each image URL is ``<host_url>data/concepts/<path>`` for every entry of
    ``concept.image_paths_list``. Any exception is reported as a JSON 500.
    This includes the not-found error raised by ``get_or_404``, because it is
    caught here.
    """
    try:
        event = Event.query.get_or_404(event_id)
        concepts = event.related_concepts.all()
        url_prefix = request.host_url + 'data/concepts/'

        def describe(concept):
            paths = concept.image_paths_list
            urls = [url_prefix + path for path in paths]
            return {
                'id': concept.id,
                'concept_code': concept.concept_code,
                'concept': concept.concept,
                'reason': concept.reason,
                'image_paths': paths,
                'image_urls': urls,
                'first_image': urls[0] if urls else None,
            }

        concepts_data = [describe(item) for item in concepts]
        payload = {
            'event_id': event_id,
            'event_title': event.title,
            'related_concepts': concepts_data,
            'total_count': len(concepts_data),
        }
        return jsonify({'code': 200, 'message': 'success', 'data': payload})
    except Exception as exc:
        return jsonify({'code': 500, 'message': str(exc), 'data': None}), 500
# 8. Event detail: historical events
@app.route('/api/event/<int:event_id>/historical-events', methods=['GET'])
def api_event_historical_events(event_id):
    """Return an event's historical analog events with stock-change statistics.

    Historical events are ordered by importance, then date, both descending.
    For each one and for the current event, the latest daily change (F010N) of
    every related stock is collected. The average and maximum are reported
    rounded to 2 decimals, or None when no stock has data.

    Responds 404 when the event does not exist and 500 on unexpected errors.
    """
    try:
        event = Event.query.get(event_id)
        if event is None:
            # Explicit JSON 404: get_or_404 inside this try was reported as a 500.
            return jsonify({'code': 404, 'message': '事件不存在', 'data': None}), 404
        historical_events = event.historical_events.order_by(
            HistoricalEvent.importance.desc(),
            HistoricalEvent.event_date.desc()
        ).all()

        # The same stock often appears in many events. Query its latest trade
        # row once per request instead of once per appearance.
        trade_cache = {}

        def latest_trade_for(stock_code):
            """Return the newest TradeData row whose SECCODE starts with the base code."""
            base_code = stock_code.split('.')[0]
            if base_code not in trade_cache:
                trade_cache[base_code] = TradeData.query.filter(
                    TradeData.SECCODE.startswith(base_code)
                ).order_by(TradeData.TRADEDATE.desc()).first()
            return trade_cache[base_code]

        def rounded(value):
            return round(value, 2) if value is not None else None

        events_data = []
        for hist_event in historical_events:
            related_stocks = []
            valid_changes = []
            for stock in hist_event.stocks.all():
                trade_data = latest_trade_for(stock.stock_code)
                if trade_data and trade_data.F010N is not None:
                    daily_change = float(trade_data.F010N)
                    valid_changes.append(daily_change)
                else:
                    daily_change = None
                related_stocks.append({
                    'stock_code': stock.stock_code,
                    'stock_name': stock.stock_name,
                    'relation_desc': stock.relation_desc,
                    'correlation': stock.correlation,
                    'sector': stock.sector,
                    'daily_change': daily_change,
                    'has_trade_data': bool(trade_data)
                })
            avg_change = sum(valid_changes) / len(valid_changes) if valid_changes else None
            max_change = max(valid_changes) if valid_changes else None
            events_data.append({
                'id': hist_event.id,
                'title': hist_event.title,
                'content': hist_event.content,
                'event_date': hist_event.event_date.isoformat() if hist_event.event_date else None,
                'relevance': hist_event.relevance,
                'importance': hist_event.importance,
                'related_stocks': related_stocks,
                'related_avg_chg': rounded(avg_change),
                'related_max_chg': rounded(max_change)
            })

        # The same statistics for the current event's related stocks.
        current_valid_changes = []
        for stock in event.related_stocks:
            trade_data = latest_trade_for(stock.stock_code)
            if trade_data and trade_data.F010N is not None:
                current_valid_changes.append(float(trade_data.F010N))
        current_avg_change = (sum(current_valid_changes) / len(current_valid_changes)
                              if current_valid_changes else None)
        current_max_change = max(current_valid_changes) if current_valid_changes else None

        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'event_id': event_id,
                'event_title': event.title,
                'invest_score': event.invest_score,
                'related_avg_chg': rounded(current_avg_change),
                'related_max_chg': rounded(current_max_change),
                'historical_events': events_data,
                'total_count': len(events_data)
            }
        })
    except Exception as e:
        print(f"Error in api_event_historical_events: {str(e)}")
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
@app.route('/api/event/<int:event_id>/comments', methods=['GET'])
@token_required
def get_event_comments(event_id):
    """Return an event's posts and its paginated top-level comments with nested replies.

    Query parameters:
        page:          page number (default 1; values < 1 become 1).
        per_page:      top-level comments per page (default 20; values outside
                       1..100 become 20).
        sort:          ``time_desc`` (default), ``time_asc`` or ``hot``
                       (``hot`` currently sorts like ``time_desc``).
        include_posts: whether to include post details (default ``true``).
        reply_limit:   replies embedded per comment (default 3; values outside
                       0..50 become 3).

    Response shape::

        {
            "success": true,
            "data": {
                "event": {id, title, description, event_type, importance, status,
                          created_at, hot_score, view_count, post_count,
                          follower_count},
                "posts": [post info; empty unless include_posts],
                "posts_count": number of entries in "posts",
                "total": total top-level comments,
                "current_page": page number,
                "total_pages": page count,
                "comments": [
                    {
                        "comment_id", "content", "created_at",
                        "post_id", "post_title", "post_content_preview",
                        "user": {user_id, username, nickname, avatar_url,
                                 user_level, is_verified},
                        "reply_count": total active replies,
                        "has_more_replies": reply_count > len(list),
                        "list": [   # oldest replies first, at most reply_limit
                            {
                                "comment_id", "content", "created_at",
                                "user": {...},
                                "reply_to": {user_id, nickname}  # the parent comment's author
                            }
                        ]
                    }
                ]
            }
        }

    Responds 404 when the event does not exist, and 500 with
    ``success: false`` on unexpected errors.
    """
    def user_brief(user):
        """Return the public profile fields shown next to posts, comments and replies."""
        return {
            'user_id': user.id,
            'username': user.username,
            'nickname': user.nickname or user.username,
            'avatar_url': get_full_avatar_url(user.avatar_url),
            'user_level': user.user_level,
            'is_verified': user.is_verified
        }

    def content_preview(post):
        """Return the first 100 chars of the post body plus '...', tolerating None."""
        if post is None:
            return None
        content = post.content
        if content and len(content) > 100:
            return content[:100] + '...'
        return content

    try:
        page = request.args.get('page', 1, type=int)
        per_page = request.args.get('per_page', 20, type=int)
        sort = request.args.get('sort', 'time_desc')
        include_posts = request.args.get('include_posts', 'true').lower() == 'true'
        reply_limit = request.args.get('reply_limit', 3, type=int)
        # Clamp out-of-range parameters to defaults instead of rejecting them.
        if page < 1:
            page = 1
        if per_page < 1 or per_page > 100:
            per_page = 20
        if reply_limit < 0 or reply_limit > 50:
            reply_limit = 3

        event = Event.query.get(event_id)
        if event is None:
            # Explicit JSON 404: get_or_404 inside this try was reported as a 500.
            return jsonify({'success': False, 'message': '事件不存在'}), 404
        # One event summary for both the empty and the non-empty response.
        event_data = {
            'id': event.id,
            'title': event.title,
            'description': event.description,
            'event_type': event.event_type,
            'importance': event.importance,
            'status': event.status,
            'created_at': event.created_at.strftime('%Y-%m-%d %H:%M:%S'),
            'hot_score': event.hot_score,
            'view_count': event.view_count,
            'post_count': event.post_count,
            'follower_count': event.follower_count
        }

        posts = Post.query.filter_by(event_id=event_id, status='active') \
            .order_by(Post.is_top.desc(), Post.created_at.desc()).all()
        posts_data = []
        if include_posts:
            for post in posts:
                author = user_brief(post.user)
                author['is_creator'] = post.user.is_creator
                posts_data.append({
                    'post_id': post.id,
                    'title': post.title,
                    'content': post.content,
                    'content_type': post.content_type,
                    'created_at': post.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                    'updated_at': post.updated_at.strftime('%Y-%m-%d %H:%M:%S') if post.updated_at else None,
                    'likes_count': post.likes_count,
                    'comments_count': post.comments_count,
                    'view_count': post.view_count,
                    'is_top': post.is_top,
                    'user': author
                })

        post_ids = [p.id for p in posts]
        if not post_ids:
            return jsonify({
                'success': True,
                'data': {
                    'event': event_data,
                    'posts': posts_data,
                    'posts_count': len(posts_data),
                    'total': 0,
                    'current_page': page,
                    'total_pages': 0,
                    'comments': []
                }
            })

        # Only top-level comments are paginated; replies are nested below.
        base_query = Comment.query.filter(
            Comment.post_id.in_(post_ids),
            Comment.parent_id == None,  # noqa: E711 - SQL "IS NULL"
            Comment.status == 'active'
        )
        if sort == 'time_asc':
            base_query = base_query.order_by(Comment.created_at.asc())
        else:
            # 'hot' has no dedicated ranking yet; like the default, newest first.
            base_query = base_query.order_by(Comment.created_at.desc())
        pagination = base_query.paginate(page=page, per_page=per_page, error_out=False)

        # Reply counts for the whole page in one grouped query (was one per comment).
        page_comment_ids = [c.id for c in pagination.items]
        reply_counts = {}
        if page_comment_ids:
            reply_counts = dict(
                db.session.query(Comment.parent_id, func.count(Comment.id))
                .filter(Comment.parent_id.in_(page_comment_ids), Comment.status == 'active')
                .group_by(Comment.parent_id)
                .all()
            )

        comments_data = []
        for comment in pagination.items:
            reply_count = reply_counts.get(comment.id, 0)
            replies = []
            if reply_limit > 0:
                replies = Comment.query.filter_by(
                    parent_id=comment.id,
                    status='active'
                ).order_by(Comment.created_at.asc()).limit(reply_limit).all()

            # Only direct replies are fetched, so each one answers this comment's author.
            reply_to_user = {
                'user_id': comment.user.id,
                'nickname': comment.user.nickname or comment.user.username
            }
            replies_list = [
                {
                    'comment_id': reply.id,
                    'content': reply.content,
                    'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                    'user': user_brief(reply.user),
                    'reply_to': dict(reply_to_user)
                }
                for reply in replies
            ]

            post = comment.post
            comments_data.append({
                'comment_id': comment.id,
                'content': comment.content,
                'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'post_id': comment.post_id,
                'post_title': post.title if post else None,
                'post_content_preview': content_preview(post),
                'user': user_brief(comment.user),
                'reply_count': reply_count,
                'has_more_replies': reply_count > len(replies_list),
                'list': replies_list
            })

        return jsonify({
            'success': True,
            'data': {
                'event': event_data,
                'posts': posts_data,
                'posts_count': len(posts_data),
                'total': pagination.total,
                'current_page': pagination.page,
                'total_pages': pagination.pages,
                'comments': comments_data
            }
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': '获取评论列表失败',
            'error': str(e)
        }), 500
@app.route('/api/comment/<int:comment_id>/replies', methods=['GET'])
@token_required
def get_comment_replies(comment_id):
    """Return one comment together with a page of its direct replies.

    Query parameters:
        page: 1-based page number (default 1; values < 1 fall back to 1).
        per_page: replies per page (default 20; values outside 1..100 fall back to 20).
        sort: 'time_asc' for oldest first; any other value (default 'time_desc')
            sorts newest first.

    Response body::

        {
            "code": 200,
            "message": "success",
            "data": {
                "comment": {
                    "comment_id": ..., "content": ..., "created_at": ...,
                    "user": {"user_id": ..., "nickname": ..., "avatar_url": ...}
                },
                "replies": {
                    "total": ..., "current_page": ..., "total_pages": ...,
                    "items": [
                        {
                            "reply_id": ..., "content": ..., "created_at": ...,
                            "user": {"user_id": ..., "nickname": ..., "avatar_url": ...},
                            "reply_to": {"id": ..., "nickname": ...}
                        }
                    ]
                }
            }
        }

    Responds 404 when the comment does not exist or its status is not
    'active'; any other failure is reported as a 500 carrying the error text.
    """
    try:
        page = request.args.get('page', 1, type=int)
        per_page = request.args.get('per_page', 20, type=int)
        sort = request.args.get('sort', 'time_desc')
        if page < 1:
            page = 1
        if per_page < 1 or per_page > 100:
            per_page = 20
        # Use .get() plus an explicit check instead of get_or_404(): the
        # NotFound raised by get_or_404 would be caught by the broad
        # `except Exception` below and reported as a 500.
        comment = Comment.query.get(comment_id)
        if comment is None or comment.status != 'active':
            return jsonify({
                'code': 404,
                'message': '评论不存在或已被删除',
                'data': None
            }), 404
        comment_author = comment.user
        comment_data = {
            'comment_id': comment.id,
            'content': comment.content,
            'created_at': comment.created_at.strftime('%Y-%m-%d %H:%M:%S'),
            'user': {
                'user_id': comment_author.id,
                'nickname': comment_author.nickname or comment_author.username,
                'avatar_url': get_full_avatar_url(comment_author.avatar_url),
            }
        }
        replies_query = Comment.query.filter_by(
            parent_id=comment_id,
            status='active'
        )
        if sort == 'time_asc':
            replies_query = replies_query.order_by(Comment.created_at.asc())
        else:  # default: newest first
            replies_query = replies_query.order_by(Comment.created_at.desc())
        pagination = replies_query.paginate(page=page, per_page=per_page, error_out=False)
        # Every reply selected above has parent_id == comment_id, so the
        # replied-to user is always the author of `comment`. Build it once
        # instead of re-querying the parent comment for every reply.
        reply_to_user = {
            'id': comment_author.id,
            'nickname': comment_author.nickname or comment_author.username
        }
        replies_data = []
        for reply in pagination.items:
            replies_data.append({
                'reply_id': reply.id,
                'content': reply.content,
                'created_at': reply.created_at.strftime('%Y-%m-%d %H:%M:%S'),
                'user': {
                    'user_id': reply.user.id,
                    'nickname': reply.user.nickname or reply.user.username,
                    'avatar_url': get_full_avatar_url(reply.user.avatar_url),
                },
                'reply_to': reply_to_user
            })
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'comment': comment_data,
                'replies': {
                    'total': pagination.total,
                    'current_page': pagination.page,
                    'total_pages': pagination.pages,
                    'items': replies_data
                }
            }
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
# 9. 事件详情-关联数据接口
@app.route('/api/event/<int:event_id>/related-data-list', methods=['GET'])
def api_event_related_data(event_id):
    """Return every RelatedData row attached to an event.

    Responds 404 when the event does not exist and 500 (with the error text)
    on any other failure.
    """
    try:
        # .get() plus an explicit check instead of get_or_404(): the NotFound
        # raised by get_or_404 would otherwise be swallowed by the
        # `except Exception` below and reported as a 500.
        event = Event.query.get(event_id)
        if event is None:
            return jsonify({
                'code': 404,
                'message': 'Event not found',
                'data': None
            }), 404
        related_data = event.related_data.all()
        data_list = [{
            'id': data.id,
            'title': data.title,
            'data_type': data.data_type,
            'data_content': data.data_content,
            'description': data.description,
            'created_at': data.created_at.isoformat() if data.created_at else None
        } for data in related_data]
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'event_id': event_id,
                'event_title': event.title,
                'related_data': data_list,
                'total_count': len(data_list)
            }
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
# 10. 投资日历-事件接口(增强版)
@app.route('/api/calendar/events', methods=['GET'])
def api_calendar_events():
    """Investment-calendar event list backed by the ``future_events`` table.

    Query parameters:
        start / end: ISO-8601 datetimes bounding ``calendar_time`` (inclusive).
        importance: exact ``star`` value to match, or 'all' (default).
        category: exact ``type`` value to match, or 'all' (default).
        page / per_page: pagination (defaults 1 / 10). Non-numeric values fall
            back to the defaults; values below 1 are raised to 1.

    Each event's ``related_stocks`` entries are enriched with the Shenwan
    level-1 industry (``ea_sector.F004V``), a coarse sector bucket, the latest
    daily change and a 5-trading-day change computed from ``ea_trade``. The
    response also carries per-sector counts and average/max/weekly changes.
    """
    try:
        start_date = request.args.get('start')
        end_date = request.args.get('end')
        importance = request.args.get('importance', 'all')
        category = request.args.get('category', 'all')
        # type=int falls back to the default on non-numeric input. Clamping to
        # >= 1 keeps OFFSET non-negative and the total_pages division safe.
        page = max(1, request.args.get('page', 1, type=int))
        per_page = max(1, request.args.get('per_page', 10, type=int))
        offset = (page - 1) * per_page
        # Build the filter once and share it between the page query and the
        # count query, so the two can never drift apart.
        where_sql = " WHERE 1=1"
        filter_params = {}
        if start_date:
            where_sql += " AND calendar_time >= :start_date"
            filter_params['start_date'] = datetime.fromisoformat(start_date)
        if end_date:
            where_sql += " AND calendar_time <= :end_date"
            filter_params['end_date'] = datetime.fromisoformat(end_date)
        if importance != 'all':
            where_sql += " AND star = :importance"
            filter_params['importance'] = importance
        if category != 'all':
            where_sql += " AND type = :category"
            filter_params['category'] = category
        page_sql = (
            "SELECT data_id, calendar_time, type, star, title, former, forecast, fact,"
            " related_stocks, concepts FROM future_events"
            + where_sql
            + " ORDER BY calendar_time LIMIT :limit OFFSET :offset"
        )
        events = db.session.execute(
            text(page_sql), {**filter_params, 'limit': per_page, 'offset': offset}
        ).fetchall()
        # Read the count with scalar() rather than `row.count`. On SQLAlchemy
        # 1.4+ a Row is a tuple-like Sequence, so `.count` resolves to the
        # inherited count() method instead of a column labelled "count".
        total_count = db.session.execute(
            text("SELECT COUNT(*) FROM future_events" + where_sql), filter_params
        ).scalar() or 0
        # Shenwan (申万) level-1 industry -> coarse sector bucket.
        sector_map = {
            # 大周期
            '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期',
            '钢铁': '大周期', '基础化工': '大周期', '建筑材料': '大周期',
            '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期',
            '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期',
            '太阳能发电': '大周期', '建筑装饰': '大周期', '建筑': '大周期',
            '交通运输': '大周期', '采掘': '大周期', '公用事业': '大周期',
            # 大消费
            '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费',
            '食品饮料': '大消费', '医药生物': '大消费', '纺织服饰': '大消费',
            '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费',
            '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费',
            '纺织服装': '大消费', '商业贸易': '大消费', '休闲服务': '大消费',
            # 大金融地产
            '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产',
            '多元金融': '大金融地产', '综合金融': '大金融地产',
            '房地产': '大金融地产', '非银金融': '大金融地产',
            # TMT板块
            '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块',
            # 公共产业
            '环保': '公共产业板块', '综合': '公共产业板块'
        }
        # Per-stock lookups; both are prefix matches on SECCODE.
        sector_sql = text("""
            SELECT F004V as sw_primary_sector
            FROM ea_sector
            WHERE SECCODE LIKE :stock_code_pattern
            AND F002V = '申银万国行业分类'
            LIMIT 1
        """)
        trade_sql = text("""
            SELECT F007N as close_price, F010N as change_pct, TRADEDATE
            FROM ea_trade
            WHERE SECCODE LIKE :stock_code_pattern
            ORDER BY TRADEDATE DESC
            LIMIT 7
        """)
        events_data = []
        for event in events:
            related_stocks_list = []
            sector_stats = {
                '全部股票': 0,
                '大周期': 0,
                '大消费': 0,
                'TMT板块': 0,
                '大金融地产': 0,
                '公共产业板块': 0,
                '其他': 0
            }
            related_avg_chg = 0
            related_max_chg = 0
            related_week_chg = 0
            if event.related_stocks:
                try:
                    import ast
                    if isinstance(event.related_stocks, str):
                        try:
                            stock_data = json.loads(event.related_stocks)
                        except ValueError:
                            # Not JSON; the column may hold a Python-literal repr.
                            stock_data = ast.literal_eval(event.related_stocks)
                    else:
                        stock_data = event.related_stocks
                    if stock_data:
                        daily_changes = []
                        week_changes = []
                        # Each entry: [stock_code, stock_name, description?, score?]
                        for stock_info in stock_data:
                            if not (isinstance(stock_info, list) and len(stock_info) >= 2):
                                continue
                            stock_code = stock_info[0]
                            stock_name = stock_info[1]
                            description = stock_info[2] if len(stock_info) > 2 else ''
                            score = stock_info[3] if len(stock_info) > 3 else 0
                            if not stock_code:
                                continue
                            # Strip the exchange suffix so the LIKE prefix match works.
                            clean_code = stock_code.replace('.SZ', '').replace('.SH', '').replace('.BJ', '')
                            code_pattern = {'stock_code_pattern': f'{clean_code}%'}
                            sector_row = db.session.execute(sector_sql, code_pattern).fetchone()
                            sw_primary_sector = sector_row.sw_primary_sector if sector_row else None
                            primary_sector = sector_map.get(sw_primary_sector,
                                                            '其他') if sw_primary_sector else '其他'
                            trade_data = db.session.execute(trade_sql, code_pattern).fetchall()
                            daily_chg = 0
                            week_chg = 0
                            if trade_data:
                                # Latest trading day's change.
                                daily_chg = float(trade_data[0].change_pct or 0)
                                # Change versus the 5th most recent close.
                                if len(trade_data) >= 5:
                                    current_price = float(trade_data[0].close_price or 0)
                                    week_ago_price = float(trade_data[4].close_price or 0)
                                    if week_ago_price > 0:
                                        week_chg = ((current_price - week_ago_price) / week_ago_price) * 100
                            sector_stats['全部股票'] += 1
                            if primary_sector in sector_stats:
                                sector_stats[primary_sector] += 1
                            else:
                                sector_stats['其他'] += 1
                            daily_changes.append(daily_chg)
                            week_changes.append(week_chg)
                            related_stocks_list.append({
                                'code': stock_code,
                                'name': stock_name,
                                'description': description,
                                'score': score,
                                'sw_primary_sector': sw_primary_sector,
                                'primary_sector': primary_sector,
                                'daily_chg': daily_chg,
                                'week_chg': week_chg
                            })
                        if daily_changes:
                            related_avg_chg = round(sum(daily_changes) / len(daily_changes), 4)
                            related_max_chg = round(max(daily_changes), 4)
                        if week_changes:
                            related_week_chg = round(sum(week_changes) / len(week_changes), 4)
                except Exception as e:
                    print(f"Error processing related stocks for event {event.data_id}: {e}")
            related_concepts = extract_concepts_from_concepts_field(event.concepts)
            star_rating = event.star
            events_data.append({
                'id': event.data_id,
                'title': event.title,
                'description': f"前值: {event.former}, 预测: {event.forecast}, 实际: {event.fact}" if event.former or event.forecast or event.fact else "",
                'start_time': event.calendar_time.isoformat() if event.calendar_time else None,
                'end_time': None,  # future_events has no end-time column
                'category': {
                    'event_type': event.type,
                    'importance': event.star,
                    'star_rating': star_rating
                },
                'star_rating': star_rating,
                'related_concepts': related_concepts,
                'related_stocks': related_stocks_list,
                'related_avg_chg': round(related_avg_chg, 2),
                'related_max_chg': round(related_max_chg, 2),
                'related_week_chg': round(related_week_chg, 2),
                'sector_stats': sector_stats,
                'former': event.former,
                'forecast': event.forecast,
                'fact': event.fact
            })
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'events': events_data,
                'total_count': total_count,
                'page': page,
                'per_page': per_page,
                'total_pages': (total_count + per_page - 1) // per_page
            }
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
# 11. 投资日历-数据接口
@app.route('/api/calendar/data', methods=['GET'])
def api_calendar_data():
    """Investment-calendar data list merged from two sources.

    Combines ``RelatedData`` rows (filtered on ``created_at``) with
    ``future_events`` rows whose ``type`` is 'data' (filtered on
    ``calendar_time``), sorts the union newest first and paginates it in memory.

    Query parameters:
        start / end: ISO-8601 bounds (inclusive).
        type: ``data_type`` to match, or 'all' (default).
        page / page_size: pagination (defaults 1 / 20; page_size outside
            1..100 falls back to 20). Non-integer values produce a 400.
    """
    try:
        start_date = request.args.get('start')
        end_date = request.args.get('end')
        data_type = request.args.get('type', 'all')
        page = int(request.args.get('page', 1))
        page_size = int(request.args.get('page_size', 20))
        if page < 1:
            page = 1
        if page_size < 1 or page_size > 100:
            page_size = 20
        query1 = RelatedData.query
        if start_date:
            query1 = query1.filter(RelatedData.created_at >= datetime.fromisoformat(start_date))
        if end_date:
            query1 = query1.filter(RelatedData.created_at <= datetime.fromisoformat(end_date))
        if data_type != 'all':
            query1 = query1.filter_by(data_type=data_type)
        data_list1 = query1.order_by(RelatedData.created_at.desc()).all()
        query2_sql = """
            SELECT
                data_id as id,
                title,
                type as data_type,
                former,
                forecast,
                fact,
                star,
                calendar_time as created_at
            FROM future_events
            WHERE type = 'data'
        """
        params = {}
        if start_date:
            query2_sql += " AND calendar_time >= :start_date"
            params['start_date'] = start_date
        if end_date:
            query2_sql += " AND calendar_time <= :end_date"
            params['end_date'] = end_date
        if data_type != 'all':
            query2_sql += " AND type = :data_type"
            params['data_type'] = data_type
        query2_sql += " ORDER BY calendar_time DESC"
        # Materialise once. A Result is a one-shot cursor: iterating it for the
        # payload and then calling list() on it for the count always gave 0.
        future_rows = db.session.execute(text(query2_sql), params).fetchall()
        result_data = []
        for data in data_list1:
            result_data.append({
                'id': data.id,
                'title': data.title,
                'data_type': data.data_type,
                'data_content': data.data_content,
                'description': data.description,
                'created_at': data.created_at.isoformat() if data.created_at else None,
                'event_id': data.event_id,
                'source': 'related_data',  # which table the row came from
                'former': None,
                'forecast': None,
                'fact': None,
                'star': None
            })
        for row in future_rows:
            result_data.append({
                'id': row.id,
                'title': row.title,
                'data_type': row.data_type,
                'data_content': None,
                'description': None,
                'created_at': row.created_at.isoformat() if row.created_at else None,
                'event_id': None,
                'source': 'future_events',  # which table the row came from
                'former': row.former,
                'forecast': row.forecast,
                'fact': row.fact,
                'star': row.star
            })
        # Newest first. ISO strings sort chronologically; rows without a
        # timestamp sink to the end.
        result_data.sort(key=lambda x: x['created_at'] or '1900-01-01', reverse=True)
        total_count = len(result_data)
        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
        start_index = (page - 1) * page_size
        end_index = start_index + page_size
        current_page_data = result_data[start_index:end_index]
        related_data_count = len(data_list1)
        future_events_count = len(future_rows)
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'data_list': current_page_data,
                'pagination': {
                    'current_page': page,
                    'page_size': page_size,
                    'total_count': total_count,
                    'total_pages': total_pages,
                    'has_next': page < total_pages,
                    'has_prev': page > 1
                },
                # Legacy top-level fields kept for backward compatibility.
                'total_count': total_count,
                'related_data_count': related_data_count,
                'future_events_count': future_events_count
            }
        })
    except ValueError as ve:
        # Raised by int() on the pagination parameters (and by fromisoformat).
        return jsonify({
            'code': 400,
            'message': f'分页参数格式错误: {str(ve)}',
            'data': None
        }), 400
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
# 12. 投资日历-详情接口
def extract_concepts_from_concepts_field(concepts_text):
    """Parse the ``concepts`` column of ``future_events`` into a list of dicts.

    The value is a sequence of ``[name, reason, score, ...]`` entries. It may
    be stored as JSON text, as a Python-literal string (single quotes or
    tuples, which ``ast.literal_eval`` can read), or as an already-decoded
    sequence. Entries that are not a list/tuple with at least three items are
    skipped.

    Args:
        concepts_text: raw column value (str, decoded sequence, or empty).

    Returns:
        A list of ``{'name', 'reason', 'score'}`` dicts. Returns ``[]`` for
        empty input, or when the value cannot be parsed or iterated (the error
        is printed).
    """
    if not concepts_text:
        return []
    try:
        import ast
        if isinstance(concepts_text, str):
            try:
                concepts_data = json.loads(concepts_text)
            except ValueError:
                # Not valid JSON (json.JSONDecodeError subclasses ValueError).
                # Fall back to a safe Python-literal parse.
                concepts_data = ast.literal_eval(concepts_text)
        else:
            concepts_data = concepts_text
        return [
            {'name': info[0], 'reason': info[1], 'score': info[2]}
            for info in concepts_data
            if isinstance(info, (list, tuple)) and len(info) >= 3
        ]
    except Exception as e:
        print(f"Error extracting concepts: {e}")
        return []
@app.route('/api/calendar/detail/<int:item_id>', methods=['GET'])
def api_future_event_detail(item_id):
    """Detail view for a single ``future_events`` row.

    Returns the raw event columns plus:
      * ``extracted_concepts`` parsed from the ``concepts`` column;
      * ``related_stocks`` enriched with the Shenwan level-1 industry
        (``ea_sector.F004V``), a coarse sector bucket, the latest daily change
        and a 5-trading-day change computed from ``ea_trade``;
      * per-sector counts and average / max / weekly-average changes.

    Responds 404 when no row has ``data_id == item_id``.
    """
    try:
        query = """
            SELECT
                data_id,
                calendar_time,
                type,
                star,
                title,
                former,
                forecast,
                fact,
                related_stocks,
                concepts
            FROM future_events
            WHERE data_id = :item_id
        """
        event = db.session.execute(text(query), {'item_id': item_id}).fetchone()
        if not event:
            return jsonify({
                'code': 404,
                'message': 'Event not found',
                'data': None
            }), 404
        extracted_concepts = extract_concepts_from_concepts_field(event.concepts)
        related_stocks_list = []
        sector_stats = {
            '全部股票': 0,
            '大周期': 0,
            '大消费': 0,
            'TMT板块': 0,
            '大金融地产': 0,
            '公共产业板块': 0,
            '其他': 0
        }
        # Shenwan (申万) level-1 industry -> coarse sector bucket. Every value
        # is a key of sector_stats, so the increment below cannot KeyError.
        sector_map = {
            # 大周期
            '石油石化': '大周期', '煤炭': '大周期', '有色金属': '大周期',
            '钢铁': '大周期', '基础化工': '大周期', '建筑材料': '大周期',
            '机械设备': '大周期', '电力设备及新能源': '大周期', '国防军工': '大周期',
            '电力设备': '大周期', '电网设备': '大周期', '风力发电': '大周期',
            '太阳能发电': '大周期', '建筑装饰': '大周期', '建筑': '大周期',
            '交通运输': '大周期', '采掘': '大周期', '公用事业': '大周期',
            # 大消费
            '汽车': '大消费', '家用电器': '大消费', '酒类': '大消费',
            '食品饮料': '大消费', '医药生物': '大消费', '纺织服饰': '大消费',
            '农林牧渔': '大消费', '商贸零售': '大消费', '轻工制造': '大消费',
            '消费者服务': '大消费', '美容护理': '大消费', '社会服务': '大消费',
            '纺织服装': '大消费', '商业贸易': '大消费', '休闲服务': '大消费',
            # 大金融地产
            '银行': '大金融地产', '证券': '大金融地产', '保险': '大金融地产',
            '多元金融': '大金融地产', '综合金融': '大金融地产',
            '房地产': '大金融地产', '非银金融': '大金融地产',
            # TMT板块
            '计算机': 'TMT板块', '电子': 'TMT板块', '传媒': 'TMT板块', '通信': 'TMT板块',
            # 公共产业
            '环保': '公共产业板块', '综合': '公共产业板块'
        }
        related_avg_chg = 0
        related_max_chg = 0
        related_week_chg = 0
        if event.related_stocks:
            try:
                import ast
                if isinstance(event.related_stocks, str):
                    try:
                        stock_data = json.loads(event.related_stocks)
                    except ValueError:
                        # Not JSON; the column may hold a Python-literal repr.
                        stock_data = ast.literal_eval(event.related_stocks)
                else:
                    stock_data = event.related_stocks
                if stock_data:
                    daily_changes = []
                    week_changes = []
                    # Per-stock lookups; both are prefix matches on SECCODE.
                    sector_sql = text("""
                        SELECT F004V as sw_primary_sector
                        FROM ea_sector
                        WHERE SECCODE LIKE :stock_code_pattern
                        AND F002V = '申银万国行业分类'
                        LIMIT 1
                    """)
                    trade_sql = text("""
                        SELECT F007N as close_price, F010N as change_pct, TRADEDATE
                        FROM ea_trade
                        WHERE SECCODE LIKE :stock_code_pattern
                        ORDER BY TRADEDATE DESC
                        LIMIT 7
                    """)
                    # Each entry: [stock_code, stock_name, description?, score?]
                    for stock_info in stock_data:
                        if not (isinstance(stock_info, list) and len(stock_info) >= 2):
                            continue  # skip malformed entries
                        stock_code = stock_info[0]
                        stock_name = stock_info[1]
                        description = stock_info[2] if len(stock_info) > 2 else ''
                        score = stock_info[3] if len(stock_info) > 3 else 0
                        if not stock_code:
                            continue
                        # Strip the exchange suffix so the LIKE prefix match works.
                        clean_code = stock_code.replace('.SZ', '').replace('.SH', '').replace('.BJ', '')
                        code_pattern = {'stock_code_pattern': f'{clean_code}%'}
                        sector_row = db.session.execute(sector_sql, code_pattern).fetchone()
                        sw_primary_sector = sector_row.sw_primary_sector if sector_row else None
                        primary_sector = sector_map.get(sw_primary_sector, '其他') if sw_primary_sector else '其他'
                        trade_data = db.session.execute(trade_sql, code_pattern).fetchall()
                        daily_chg = 0
                        week_chg = 0
                        if trade_data:
                            # Latest trading day's change.
                            daily_chg = float(trade_data[0].change_pct or 0)
                            # Change versus the 5th most recent close.
                            if len(trade_data) >= 5:
                                current_price = float(trade_data[0].close_price or 0)
                                week_ago_price = float(trade_data[4].close_price or 0)
                                if week_ago_price > 0:
                                    week_chg = ((current_price - week_ago_price) / week_ago_price) * 100
                        sector_stats['全部股票'] += 1
                        sector_stats[primary_sector] += 1
                        daily_changes.append(daily_chg)
                        week_changes.append(week_chg)
                        related_stocks_list.append({
                            'code': stock_code,
                            'name': stock_name,
                            'description': description,
                            'score': score,
                            'sw_primary_sector': sw_primary_sector,
                            'primary_sector': primary_sector,
                            'daily_change': daily_chg,
                            'week_change': week_chg
                        })
                    if daily_changes:
                        related_avg_chg = sum(daily_changes) / len(daily_changes)
                        related_max_chg = max(daily_changes)
                    if week_changes:
                        related_week_chg = sum(week_changes) / len(week_changes)
            except Exception as e:
                print(f"Error processing related stocks: {e}")
                import traceback
                traceback.print_exc()
        detail_data = {
            'id': event.data_id,
            'title': event.title,
            'type': event.type,
            'star': event.star,
            'calendar_time': event.calendar_time.isoformat() if event.calendar_time else None,
            'former': event.former,
            'forecast': event.forecast,
            'fact': event.fact,
            'concepts': event.concepts,
            'extracted_concepts': extracted_concepts,
            'related_stocks': related_stocks_list,
            'sector_stats': sector_stats,
            'related_avg_chg': round(related_avg_chg, 2),
            'related_max_chg': round(related_max_chg, 2),
            'related_week_chg': round(related_week_chg, 2)
        }
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'type': 'future_event',
                'detail': detail_data
            }
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
# 13-15. 筛选弹窗接口(已有,优化格式)
@app.route('/api/filter/options', methods=['GET'])
def api_filter_options():
    """Return the option lists for the event filter dialog.

    Sort and importance options are static. Industry options are the distinct
    ``ea_sector.f002v`` classification names, each with its row count.
    """
    try:
        sort_options = [
            {'key': 'new', 'name': '最新', 'desc': '按创建时间排序'},
            {'key': 'hot', 'name': '热门', 'desc': '按热度分数排序'},
            {'key': 'returns', 'name': '收益率', 'desc': '按收益率排序'},
            {'key': 'importance', 'name': '重要性', 'desc': '按重要性等级排序'},
            {'key': 'view_count', 'name': '浏览量', 'desc': '按浏览次数排序'}
        ]
        industry_rows = db.session.execute(text("""
            SELECT DISTINCT f002v as classification_name, COUNT(*) as count
            FROM ea_sector
            WHERE f002v IS NOT NULL
            GROUP BY f002v
            ORDER BY f002v
        """)).fetchall()
        importance_options = [
            {'key': 'S', 'name': 'S级', 'desc': '重大事件'},
            {'key': 'A', 'name': 'A级', 'desc': '重要事件'},
            {'key': 'B', 'name': 'B级', 'desc': '普通事件'},
            {'key': 'C', 'name': 'C级', 'desc': '参考事件'}
        ]
        # Unpack rows positionally. On SQLAlchemy 1.4+ a Row is a tuple-like
        # Sequence, so `row.count` returns the inherited count() method (which
        # jsonify cannot serialise) instead of the "count" column.
        industry_options = [
            {'name': classification_name, 'count': row_count}
            for classification_name, row_count in industry_rows
        ]
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'sort_options': sort_options,
                'industry_options': industry_options,
                'importance_options': importance_options
            }
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
# 16-17. 会员权益接口
@app.route('/api/membership/status', methods=['GET'])
@token_required
def api_membership_status():
    """Report the authenticated user's membership state and unlocked benefits.

    Membership is read defensively with getattr(), because the user model may
    not define ``is_member`` / ``member_expire_date``. In that case the user is
    reported as a non-member with no expiry date. Every benefit flag simply
    mirrors ``is_member``.
    """
    try:
        current = request.user
        # TODO: replace with the real membership business rules.
        member_flag = getattr(current, 'is_member', False)
        expires_at = getattr(current, 'member_expire_date', None)
        benefit_names = ('unlimited_access', 'priority_support', 'advanced_analytics', 'custom_alerts')
        payload = {
            'user_id': current.id,
            'is_member': member_flag,
            'member_expire_date': None if not expires_at else expires_at.isoformat(),
            'user_level': current.user_level,
            'benefits': {name: member_flag for name in benefit_names},
        }
        return jsonify({'code': 200, 'message': 'success', 'data': payload})
    except Exception as exc:
        return jsonify({'code': 500, 'message': str(exc), 'data': None}), 500
# 18-19. 个人中心接口
@app.route('/api/user/profile', methods=['GET'])
@token_required
def api_user_profile():
    """Return the authenticated user's profile, activity counters and settings.

    ``statistics.comments_detail`` contains:
      * comments_made: comments written by the user;
      * comments_received: comments written by *other* users on the user's posts;
      * replies_received: replies written by *other* users to the user's comments.

    The user's own comments are excluded from the "received" counts. This
    matches the ``commented`` activity list, which also filters out comments
    whose author is the current user. ``total_comments`` is the sum of the
    three counts.
    """
    try:
        user = request.user
        likes_count = PostLike.query.filter_by(user_id=user.id).count()
        follows_count = EventFollow.query.filter_by(user_id=user.id).count()
        comments_made = Comment.query.filter_by(user_id=user.id).count()
        comments_received = db.session.query(Comment) \
            .join(Post, Comment.post_id == Post.id) \
            .filter(Post.user_id == user.id, Comment.user_id != user.id).count()
        replies_received = Comment.query.filter(
            Comment.parent_id.in_(
                db.session.query(Comment.id).filter_by(user_id=user.id)
            ),
            Comment.user_id != user.id
        ).count()
        total_comments = comments_made + comments_received + replies_received
        profile_data = {
            'basic_info': {
                'user_id': user.id,
                'username': user.username,
                'email': user.email,
                'phone': user.phone,
                'nickname': user.nickname,
                'avatar_url': get_full_avatar_url(user.avatar_url),
                'bio': user.bio,
                'gender': user.gender,
                'birth_date': user.birth_date.isoformat() if user.birth_date else None,
                'location': user.location
            },
            'account_status': {
                'email_confirmed': user.email_confirmed,
                'phone_confirmed': user.phone_confirmed,
                'is_verified': user.is_verified,
                'verify_time': user.verify_time.isoformat() if user.verify_time else None,
                'created_at': user.created_at.isoformat() if user.created_at else None,
                'last_seen': user.last_seen.isoformat() if user.last_seen else None
            },
            'statistics': {
                'likes_count': likes_count,  # posts the user liked
                'follows_count': follows_count,  # events the user follows
                'total_comments': total_comments,
                'comments_detail': {
                    'comments_made': comments_made,
                    'comments_received': comments_received,
                    'replies_received': replies_received
                }
            },
            'investment_preferences': {
                'trading_experience': user.trading_experience,
                'investment_style': user.investment_style,
                'risk_preference': user.risk_preference,
                'investment_amount': user.investment_amount,
                'preferred_markets': user.preferred_markets
            },
            'community_stats': {
                'user_level': user.user_level,
                'reputation_score': user.reputation_score,
                'contribution_point': user.contribution_point,
                'post_count': user.post_count,
                'comment_count': user.comment_count,
                'follower_count': user.follower_count,
                'following_count': user.following_count
            },
            'settings': {
                'email_notifications': user.email_notifications,
                'sms_notifications': user.sms_notifications,
                'privacy_level': user.privacy_level,
                'theme_preference': user.theme_preference
            }
        }
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': profile_data
        })
    except Exception as e:
        return jsonify({
            'code': 500,
            'message': str(e),
            'data': None
        }), 500
# Lazily populated cache of agreement documents, keyed by agreement type.
# _cache_loaded is set once a load pass completes; the agreement endpoints
# reset both to force a re-read.
_agreements_cache, _cache_loaded = {}, False
def load_agreements_from_docx():
    """Return the agreement texts, reading the .docx files only on first use.

    Each agreement is read from a .docx file in this module's directory. The
    text of all non-blank paragraphs is joined with blank lines. A missing or
    unreadable file still produces an entry, with a placeholder ``content`` and
    an ``error`` field, so a completed pass yields one entry per agreement type.

    Returns:
        The module-level ``_agreements_cache`` dict (the same object every call).
    """
    global _agreements_cache, _cache_loaded
    if _cache_loaded:
        return _agreements_cache
    try:
        # agreement type -> (file name, display title)
        sources = {
            'about_us': ('about_us.docx', '关于我们'),
            'service_terms': ('service_terms.docx', '服务条款'),
            'privacy_policy': ('privacy_policy.docx', '隐私政策'),
        }
        for kind, (filename, title) in sources.items():
            path = os.path.join(os.path.dirname(__file__), filename)
            if not os.path.exists(path):
                print(f"File not found: {filename}")
                _agreements_cache[kind] = {
                    'title': title,
                    'content': f"协议文件未找到,请联系管理员。(文件:{filename})",
                    'last_updated': None,
                    'version': '1.0',
                    'file_path': filename,
                    'error': 'File not found'
                }
                continue
            try:
                paragraphs = [p.text.strip() for p in Document(path).paragraphs if p.text.strip()]
                _agreements_cache[kind] = {
                    'title': title,
                    'content': '\n\n'.join(paragraphs),
                    # File mtime doubles as a change marker for clients.
                    'last_updated': os.stat(path).st_mtime,
                    'version': '1.0',
                    'file_path': filename
                }
                print(f"Successfully loaded {kind} from {filename}")
            except Exception as err:
                print(f"Error reading {filename}: {str(err)}")
                _agreements_cache[kind] = {
                    'title': title,
                    'content': f"协议内容正在加载中,请稍后再试。(文件:{filename})",
                    'last_updated': None,
                    'version': '1.0',
                    'file_path': filename,
                    'error': str(err)
                }
        _cache_loaded = True
        print(f"Agreements cache loaded successfully. Total: {len(_agreements_cache)} agreements")
    except Exception as err:
        print(f"Error loading agreements: {str(err)}")
        _cache_loaded = False
    return _agreements_cache
@app.route('/api/agreements', methods=['GET'])
def api_agreements():
    """Serve the platform agreements loaded from the .docx files.

    Query parameters:
        type: 'about_us', 'service_terms' or 'privacy_policy'. When it names a
            loaded agreement, only that agreement is returned; otherwise all
            agreements are returned.
        reload: 'true' (case-insensitive) clears the cache before loading.
    """
    global _cache_loaded
    try:
        requested_type = request.args.get('type')
        wants_reload = request.args.get('reload', 'false').lower() == 'true'
        if wants_reload:
            _cache_loaded = False
            _agreements_cache.clear()
        agreements = load_agreements_from_docx()
        if not agreements:
            return jsonify({
                'code': 500,
                'message': 'Failed to load agreements',
                'data': None
            }), 500
        if requested_type and requested_type in agreements:
            body = {'agreement_type': requested_type, **agreements[requested_type]}
        else:
            body = {
                'agreements': agreements,
                'available_types': list(agreements.keys()),
                'cache_loaded': _cache_loaded,
                'total_agreements': len(agreements)
            }
        return jsonify({'code': 200, 'message': 'success', 'data': body})
    except Exception as exc:
        return jsonify({'code': 500, 'message': str(exc), 'data': None}), 500
# 可选:添加一个管理接口来手动重新加载协议
@app.route('/api/agreements/reload', methods=['POST'])
def api_reload_agreements():
    """Drop the cached agreement texts and re-read them from disk.

    Intended as an administrator endpoint.
    NOTE(review): no authentication/authorization decorator is applied here,
    so any client can trigger a reload. Confirm whether access control is needed.
    """
    global _cache_loaded
    try:
        _cache_loaded = False
        _agreements_cache.clear()
        reloaded = load_agreements_from_docx()
        summary = {'total_agreements': len(reloaded), 'agreements': list(reloaded)}
        return jsonify({'code': 200, 'message': 'Agreements reloaded successfully', 'data': summary})
    except Exception as exc:
        return jsonify({'code': 500, 'message': str(exc), 'data': None}), 500
# 20. 个人中心-我的关注接口
@app.route('/api/user/activities', methods=['GET'])
@token_required
def api_user_activities():
    """Paginated activity feed for the authenticated user.

    Query parameters:
        type: one of
            'follows'   - events the user follows (default);
            'likes'     - posts the user liked;
            'comments'  - comments the user wrote (inner joins restrict this
                          to comments on posts attached to an event);
            'commented' - comments by other users on the user's posts.
            Any other value is rejected with a 400.
        page: 1-based page number (values < 1 are treated as 1).
        per_page: page size, clamped to 1..50 (default 20).
    """
    try:
        user = request.user
        activity_type = request.args.get('type', 'follows')
        # Reject unknown types up front. Previously they fell through every
        # branch and left `activities`/`total`/`pages` unbound, which surfaced
        # as a generic 500.
        if activity_type not in ('follows', 'likes', 'comments', 'commented'):
            return jsonify({
                'code': 400,
                'message': '无效的活动类型',
                'data': None
            }), 400
        page = max(1, request.args.get('page', 1, type=int))
        per_page = max(1, min(50, request.args.get('per_page', 20, type=int)))
        if activity_type == 'follows':
            follows = EventFollow.query.filter_by(user_id=user.id) \
                .order_by(EventFollow.created_at.desc()) \
                .paginate(page=page, per_page=per_page, error_out=False)
            activities = []
            for follow in follows.items:
                # Up to 5 related stocks, each with its latest daily change.
                related_stocks_data = []
                for stock in follow.event.related_stocks.limit(5):
                    # Drop an exchange suffix such as .SZ / .SH before the prefix match.
                    base_stock_code = stock.stock_code.split('.')[0]
                    trade_data = TradeData.query.filter(
                        TradeData.SECCODE.startswith(base_stock_code)
                    ).order_by(TradeData.TRADEDATE.desc()).first()
                    daily_change = None
                    if trade_data and trade_data.F010N is not None:
                        daily_change = float(trade_data.F010N)
                    related_stocks_data.append({
                        'stock_code': stock.stock_code,
                        'stock_name': stock.stock_name,
                        'correlation': stock.correlation,
                        'daily_change': daily_change,
                        'daily_change_formatted': f"{daily_change:.2f}%" if daily_change is not None else "暂无数据"
                    })
                activities.append({
                    'event_id': follow.event_id,
                    'event_title': follow.event.title,
                    'event_description': follow.event.description,
                    'follow_time': follow.created_at.isoformat() if follow.created_at else None,
                    'event_hot_score': follow.event.hot_score,
                    'importance': follow.event.importance,
                    'related_avg_chg': follow.event.related_avg_chg,
                    'related_max_chg': follow.event.related_max_chg,
                    'related_week_chg': follow.event.related_week_chg,
                    'related_stocks': related_stocks_data,
                    'created_at': follow.event.created_at.isoformat() if follow.event.created_at else None,
                    'preview': follow.event.description[:200] if follow.event.description else None,  # first 200 chars
                    'comment_count': follow.event.post_count,  # posts under the event
                    'view_count': follow.event.view_count,  # view count
                    'follower_count': follow.event.follower_count
                })
            total = follows.total
            pages = follows.pages
        elif activity_type == 'likes':
            likes = PostLike.query.filter_by(user_id=user.id) \
                .order_by(PostLike.created_at.desc()) \
                .paginate(page=page, per_page=per_page, error_out=False)
            activities = [{
                'like_id': like.id,
                'post_id': like.post_id,
                'post_content': like.post.content,
                'like_time': like.created_at.isoformat() if like.created_at else None,
                'author': {
                    'nickname': like.post.user.nickname or like.post.user.username,
                    'avatar_url': get_full_avatar_url(like.post.user.avatar_url),
                }
            } for like in likes.items]
            total = likes.total
            pages = likes.pages
        elif activity_type == 'comments':
            comments = Comment.query.filter_by(user_id=user.id) \
                .join(Post, Comment.post_id == Post.id) \
                .join(Event, Post.event_id == Event.id) \
                .order_by(Comment.created_at.desc()) \
                .paginate(page=page, per_page=per_page, error_out=False)
            activities = []
            for comment in comments.items:
                # comment -> post -> event
                post = comment.post
                event = post.event if post else None
                activities.append({
                    'comment_id': comment.id,
                    'post_id': comment.post_id,
                    'content': comment.content,
                    'created_at': comment.created_at.isoformat() if comment.created_at else None,
                    'post_title': post.title if post and post.title else None,
                    'post_content': post.content if post else None,
                    'commenter': {
                        'id': comment.user.id,
                        'username': comment.user.username,
                        'nickname': comment.user.nickname or comment.user.username,
                        'avatar_url': get_full_avatar_url(comment.user.avatar_url),
                        'user_level': comment.user.user_level,
                        'is_verified': comment.user.is_verified
                    },
                    'event': {
                        'id': event.id if event else None,
                        'title': event.title if event else None,
                        'description': event.description if event else None,
                        'importance': event.importance if event else None,
                        'event_type': event.event_type if event else None,
                        'hot_score': event.hot_score if event else None,
                        'view_count': event.view_count if event else None,
                        'related_avg_chg': event.related_avg_chg if event else None,
                        'created_at': event.created_at.isoformat() if event and event.created_at else None
                    },
                    'post_author': {
                        'id': post.user.id if post else None,
                        'username': post.user.username if post else None,
                        'nickname': post.user.nickname or post.user.username if post else None,
                        'avatar_url': get_full_avatar_url(post.user.avatar_url) if post else None,
                    }
                })
            total = comments.total
            pages = comments.pages
        else:  # 'commented': other users' comments on my posts
            my_posts = Post.query.filter_by(user_id=user.id).subquery()
            comments = Comment.query.join(my_posts, Comment.post_id == my_posts.c.id) \
                .filter(Comment.user_id != user.id) \
                .order_by(Comment.created_at.desc()) \
                .paginate(page=page, per_page=per_page, error_out=False)
            activities = []
            for comment in comments.items:
                # A post need not be attached to an event; guard against None.
                post_event = comment.post.event
                activities.append({
                    'comment_id': comment.id,
                    'comment_content': comment.content,
                    'comment_time': comment.created_at.isoformat() if comment.created_at else None,
                    'commenter_nickname': comment.user.nickname or comment.user.username,
                    'commenter_avatar': get_full_avatar_url(comment.user.avatar_url),
                    'post_content': comment.post.content,
                    'event_title': post_event.title if post_event else None,
                    'event_id': comment.post.event_id
                })
            total = comments.total
            pages = comments.pages
        return jsonify({
            'code': 200,
            'message': 'success',
            'data': {
                'activities': activities,
                'total': total,
                'pages': pages,
                'current_page': page
            }
        })
    except Exception as e:
        print(f"Error in api_user_activities: {str(e)}")
        return jsonify({
            'code': 500,
            'message': '服务器内部错误',
            'data': None
        }), 500
class UserFeedback(db.Model):
    """User feedback record (submitted via the personal-center feedback endpoint).

    Columns:
        id: primary key.
        user_id: id of the submitting ``User`` (required, FK to ``user.id``).
        type: feedback category string (required, max 50 chars).
        content: free-text feedback body (required).
        contact_info: optional contact details (max 100 chars).
        status: workflow state, pending/processing/resolved/closed;
            defaults to 'pending' when the row is inserted.
        admin_reply: optional reply text written by an administrator.
        created_at / updated_at: timestamps produced by ``beijing_now``
            (presumably Beijing local time -- confirm); ``updated_at`` is
            refreshed on every UPDATE.

    Note that column ``default=`` values are applied by SQLAlchemy at INSERT
    time, so ``status``/``created_at``/``updated_at`` are None until the
    instance has been flushed.
    """
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
    type = db.Column(db.String(50), nullable=False)  # feedback category
    content = db.Column(db.Text, nullable=False)  # feedback body
    contact_info = db.Column(db.String(100))  # contact details
    status = db.Column(db.String(20), default='pending')  # pending/processing/resolved/closed
    admin_reply = db.Column(db.Text)  # administrator reply
    created_at = db.Column(db.DateTime, default=beijing_now)
    updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now)
    # Relationship: feedback.user, and user.feedbacks via backref
    user = db.relationship('User', backref='feedbacks')

    def __init__(self, user_id, type, content, contact_info=None):
        """Create an unsaved feedback record for ``user_id``."""
        self.user_id = user_id
        self.type = type
        self.content = content
        self.contact_info = contact_info

    def to_dict(self):
        """Serialize the feedback to a JSON-friendly dict.

        Timestamps are rendered as 'YYYY-MM-DD HH:MM:SS'. They are None when
        the record has not been flushed yet, because the column defaults have
        not been applied. In the original code that case raised AttributeError.
        """
        time_fmt = '%Y-%m-%d %H:%M:%S'
        return {
            'id': self.id,
            'type': self.type,
            'content': self.content,
            'contact_info': self.contact_info,
            'status': self.status,
            'admin_reply': self.admin_reply,
            'created_at': self.created_at.strftime(time_fmt) if self.created_at else None,
            'updated_at': self.updated_at.strftime(time_fmt) if self.updated_at else None,
        }
# 21. Personal center - feedback submission endpoint
@app.route('/api/user/feedback', methods=['POST'])
@token_required
def api_user_feedback():
    """Submit a piece of feedback on behalf of the authenticated user.

    JSON body:
        type: one of 'bug', 'feature', 'suggestion', 'other' (default 'other').
        content: non-empty feedback text (required).
        contact_info: optional contact details (default '').

    The submitting user is taken from ``request.user`` (presumably populated
    by ``token_required`` -- confirm).

    Returns the standard ``{'code', 'message', 'data'}`` envelope:
    200 with the serialized feedback, 400 on invalid input, 500 on an
    unexpected server-side failure.
    """
    valid_types = {'bug', 'feature', 'suggestion', 'other'}
    try:
        # silent=True: a missing or malformed JSON body yields None instead of
        # raising, so it is reported as a 400 rather than crashing into a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        feedback_type = data.get('type', 'other')
        content = data.get('content')
        contact_info = data.get('contact_info', '')
        if not isinstance(content, str) or not content.strip():
            return jsonify({
                'code': 400,
                'message': '反馈类型和内容不能为空',
                'data': None
            }), 400
        # Reject categories outside the supported whitelist
        if feedback_type not in valid_types:
            return jsonify({
                'code': 400,
                'message': '无效的反馈类型',
                'data': None
            }), 400
        feedback = UserFeedback(
            user_id=request.user.id,
            type=feedback_type,
            content=content,
            contact_info=contact_info
        )
        db.session.add(feedback)
        db.session.commit()
        # Best-effort admin notification; it logs and swallows its own errors.
        notify_admin_new_feedback(feedback)
        return jsonify({
            'code': 200,
            'message': '反馈提交成功,我们会尽快处理',
            'data': feedback.to_dict()
        })
    except Exception:
        # Leave the scoped session usable for subsequent requests, and keep
        # internal error details (e.g. SQL errors) out of the client response.
        db.session.rollback()
        app.logger.exception("Error in api_user_feedback")
        return jsonify({
            'code': 500,
            'message': '服务器内部错误',
            'data': None
        }), 500
def notify_admin_new_feedback(feedback):
    """Email each administrator about a newly submitted feedback record.

    Best-effort: every failure is logged via ``app.logger`` and never raised
    to the caller. A failure for one recipient does not prevent delivery to
    the remaining recipients.

    Recipients come from ``app.config['ADMIN_EMAILS']`` (a list of addresses,
    or a single address string) when set. Otherwise the original placeholder
    address is used, which should be replaced with the real list.

    Args:
        feedback: the ``UserFeedback`` instance that was just saved; it is
            passed to the email template as ``feedback``.
    """
    try:
        admin_emails = app.config.get('ADMIN_EMAILS') or ['admin@example.com']
        if isinstance(admin_emails, str):
            # Guard against iterating a single address character by character
            admin_emails = [admin_emails]
        subject = f'新的用户反馈 - {feedback.type}'
        for admin_email in admin_emails:
            try:
                # The template renders the feedback details itself from `feedback`.
                send_notification_email(admin_email, subject, 'emails/admin_notification.html',
                                        feedback=feedback)
            except Exception as e:
                app.logger.error(f"发送管理员通知失败 ({admin_email}): {str(e)}")
    except Exception as e:
        app.logger.error(f"发送管理员通知失败: {str(e)}")
# Shared error handling: JSON envelopes for /api/ routes, default pages otherwise.
@app.errorhandler(404)
def api_not_found(error):
    """Handle 404 errors.

    Requests under '/api/' get a JSON envelope with code 404. Any other path
    gets the original HTTPException back, which Flask renders as its default
    error page.
    """
    if not request.path.startswith('/api/'):
        return error
    payload = {'code': 404, 'message': '接口不存在', 'data': None}
    return jsonify(payload), 404
@app.errorhandler(405)
def api_method_not_allowed(error):
    """Handle 405 errors.

    Requests under '/api/' get a JSON envelope with code 405. Any other path
    gets the original HTTPException back, which Flask renders as its default
    error page.
    """
    if not request.path.startswith('/api/'):
        return error
    payload = {'code': 405, 'message': '请求方法不允许', 'data': None}
    return jsonify(payload), 405
if __name__ == '__main__':
    # Development server only. The Werkzeug interactive debugger allows
    # arbitrary code execution, so it must never be exposed on a public
    # interface. Debug mode is therefore opt-in via FLASK_DEBUG instead of
    # always on. Host and port keep their previous defaults but can be
    # overridden through the environment.
    app.run(
        host=os.environ.get('FLASK_RUN_HOST', '0.0.0.0'),
        port=int(os.environ.get('FLASK_RUN_PORT', '5002')),
        debug=os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on'),
    )