diff --git a/app.py b/app.py index c4ad24b2..d12ab182 100644 --- a/app.py +++ b/app.py @@ -8,6 +8,7 @@ import uuid from functools import wraps import qrcode from flask_mail import Mail, Message +from flask_socketio import SocketIO, emit, join_room, leave_room import pytz import requests from celery import Celery @@ -40,6 +41,7 @@ from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentClo from sqlalchemy import text, desc, and_ import pandas as pd from decimal import Decimal +from apscheduler.schedulers.background import BackgroundScheduler # 交易日数据缓存 trading_days = [] @@ -242,6 +244,18 @@ db = SQLAlchemy(app) # 初始化邮件服务 mail = Mail(app) +# 初始化 Flask-SocketIO(用于实时事件推送) +socketio = SocketIO( + app, + cors_allowed_origins=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:5173", + "https://valuefrontier.cn", "http://valuefrontier.cn"], + async_mode='gevent', + logger=True, + engineio_logger=False, + ping_timeout=120, # 心跳超时时间(秒),客户端120秒内无响应才断开 + ping_interval=25 # 心跳检测间隔(秒),每25秒发送一次ping +) + @login_manager.user_loader def load_user(user_id): @@ -1835,6 +1849,15 @@ def send_verification_code(): if not credential or not code_type: return jsonify({'success': False, 'error': '缺少必要参数'}), 400 + # 清理格式字符(空格、横线、括号等) + if code_type == 'phone': + # 移除手机号中的空格、横线、括号、加号等格式字符 + credential = re.sub(r'[\s\-\(\)\+]', '', credential) + print(f"📱 清理后的手机号: {credential}") + elif code_type == 'email': + # 邮箱只移除空格 + credential = credential.strip() + # 生成验证码 verification_code = generate_verification_code() @@ -1883,7 +1906,7 @@ def send_verification_code(): @app.route('/api/auth/login-with-code', methods=['POST']) def login_with_verification_code(): - """使用验证码登录""" + """使用验证码登录/注册(自动注册)""" try: data = request.get_json() credential = data.get('credential') # 手机号或邮箱 @@ -1893,6 +1916,17 @@ def login_with_verification_code(): if not credential or not verification_code or not login_type: return jsonify({'success': False, 'error': '缺少必要参数'}), 400 + # 清理格式字符(空格、横线、括号等) + if login_type == 'phone': + # 移除手机号中的空格、横线、括号、加号等格式字符 + original_credential = credential + credential = re.sub(r'[\s\-\(\)\+]', '', credential) + if original_credential != credential: + print(f"📱 登录时清理手机号: {original_credential} -> {credential}") + elif login_type == 'email': + # 邮箱只移除前后空格 + credential = credential.strip() + # 检查验证码 session_key = f'verification_code_{login_type}_{credential}_login' stored_code_info = session.get(session_key) @@ -1918,13 +1952,86 @@ def login_with_verification_code(): # 验证码正确,查找用户 user = None + is_new_user = False + if login_type == 'phone': user = User.query.filter_by(phone=credential).first() + if not user: + # 自动注册新用户 + is_new_user = True + # 生成唯一用户名 + base_username = f"user_{credential}" + username = base_username + counter = 1 + while User.query.filter_by(username=username).first(): + username = f"{base_username}_{counter}" + counter += 1 + + # 创建新用户 + user = User(username=username, phone=credential) + user.phone_confirmed = True + user.email = f"{username}@valuefrontier.temp" # 临时邮箱 + db.session.add(user) + db.session.commit() + elif login_type == 'email': user = User.query.filter_by(email=credential).first() + if not user: + # 自动注册新用户 + is_new_user = True + # 从邮箱生成用户名 + email_prefix = credential.split('@')[0] + base_username = f"user_{email_prefix}" + username = base_username + counter = 1 + while User.query.filter_by(username=username).first(): + username = f"{base_username}_{counter}" + counter += 1 + # 如果用户不存在,自动创建新用户 if not user: - return jsonify({'success': False, 'error': '用户不存在'}), 404 + try: + # 生成用户名 + if login_type == 'phone': + # 使用手机号生成用户名 + base_username = f"用户{credential[-4:]}" + elif login_type == 'email': + # 使用邮箱前缀生成用户名 + base_username = credential.split('@')[0] + else: + base_username = "新用户" + + # 确保用户名唯一 + username = base_username + counter = 1 + while User.is_username_taken(username): + username = f"{base_username}_{counter}" + counter += 1 + + # 创建新用户 + user = User(username=username) + + # 设置手机号或邮箱 + if login_type == 'phone': + user.phone = credential + elif login_type == 'email': + user.email = credential + + # 设置默认密码(使用随机密码,用户后续可以修改) + user.set_password(uuid.uuid4().hex) + user.status = 'active' + user.nickname = username + + db.session.add(user) + db.session.commit() + + is_new_user = True + print(f"✅ 自动创建新用户: {username}, {login_type}: {credential}") + + except Exception as e: + print(f"❌ 创建用户失败: {e}") + db.session.rollback() + return jsonify({'success': False, 'error': '创建用户失败'}), 500 # 清除验证码 session.pop(session_key, None) @@ -1941,9 +2048,13 @@ def login_with_verification_code(): # 更新最后登录时间 user.update_last_seen() + # 根据是否为新用户返回不同的消息 + message = '注册成功,欢迎加入!' if is_new_user else '登录成功' + return jsonify({ 'success': True, - 'message': '登录成功', + 'message': message, + 'is_new_user': is_new_user, 'user': { 'id': user.id, 'username': user.username, @@ -1957,6 +2068,7 @@ def login_with_verification_code(): except Exception as e: print(f"验证码登录错误: {e}") + db.session.rollback() return jsonify({'success': False, 'error': '登录失败'}), 500 @@ -2009,8 +2121,8 @@ def register(): except Exception as e: db.session.rollback() - print(f"注册失败: {e}") - return jsonify({'success': False, 'error': '注册失败,请重试'}), 500 + print(f"验证码登录/注册错误: {e}") + return jsonify({'success': False, 'error': '登录失败'}), 500 def send_sms_code(phone, code, template_id): @@ -2123,8 +2235,10 @@ def register_with_phone(): data = request.get_json() phone = data.get('phone') code = data.get('code') + password = data.get('password') + username = data.get('username') - if not all([phone, code]): + if not all([phone, code, password, username]): return jsonify({'success': False, 'error': '所有字段都是必填的'}), 400 # 验证验证码 @@ -2135,12 +2249,14 @@ def register_with_phone(): if stored_code['code'] != code: return jsonify({'success': False, 'error': '验证码错误'}), 400 - if User.query.filter_by(phone=phone).first(): - return jsonify({'success': False, 'error': '手机号已存在'}), 400 + if User.query.filter_by(username=username).first(): + return jsonify({'success': False, 'error': '用户名已存在'}), 400 try: # 创建用户 - user = User(username='用户', phone=phone) + user = User(username=username, phone=phone) + user.email = f"{username}@valuefrontier.temp" + user.set_password(password) user.phone_confirmed = True db.session.add(user) @@ -2506,12 +2622,13 @@ def get_wechat_qrcode(): 'wechat_unionid': None } - return jsonify({'code':0, - 'data':{ - 'auth_url': wechat_auth_url, - 'session_id': state, - 'expires_in': 300 - }}), 200 + return jsonify({"code":0, + "data": + { + 'auth_url': wechat_auth_url, + 'session_id': state, + 'expires_in': 300 + }}), 200 @app.route('/api/account/wechat/qrcode', methods=['GET']) @@ -2609,8 +2726,19 @@ def wechat_callback(): state = request.args.get('state') error = request.args.get('error') - # 错误处理 - if error or not code or not state: + # 错误处理:用户拒绝授权 + if error: + if state in wechat_qr_sessions: + wechat_qr_sessions[state]['status'] = 'auth_denied' + wechat_qr_sessions[state]['error'] = '用户拒绝授权' + print(f"❌ 用户拒绝授权: state={state}") + return redirect('/auth/signin?error=wechat_auth_denied') + + # 参数验证 + if not code or not state: + if state in wechat_qr_sessions: + wechat_qr_sessions[state]['status'] = 'auth_failed' + wechat_qr_sessions[state]['error'] = '授权参数缺失' return redirect('/auth/signin?error=wechat_auth_failed') # 验证state @@ -2625,14 +2753,28 @@ def wechat_callback(): return redirect('/auth/signin?error=session_expired') try: - # 获取access_token + # 步骤1: 用户已扫码并授权(微信回调过来说明用户已完成扫码+授权) + session_data['status'] = 'scanned' + print(f"✅ 微信扫码回调: state={state}, code={code[:10]}...") + + # 步骤2: 获取access_token token_data = get_wechat_access_token(code) if not token_data: + session_data['status'] = 'auth_failed' + session_data['error'] = '获取访问令牌失败' + print(f"❌ 获取微信access_token失败: state={state}") return redirect('/auth/signin?error=token_failed') - # 获取用户信息 + # 步骤3: Token获取成功,标记为已授权 + session_data['status'] = 'authorized' + print(f"✅ 微信授权成功: openid={token_data['openid']}") + + # 步骤4: 获取用户信息 user_info = get_wechat_userinfo(token_data['access_token'], token_data['openid']) if not user_info: + session_data['status'] = 'auth_failed' + session_data['error'] = '获取用户信息失败' + print(f"❌ 获取微信用户信息失败: openid={token_data['openid']}") return redirect('/auth/signin?error=userinfo_failed') # 查找或创建用户 / 或处理绑定 @@ -2677,6 +2819,8 @@ def wechat_callback(): return redirect('/home?bind=failed') user = None + is_new_user = False + if unionid: user = User.query.filter_by(wechat_union_id=unionid).first() if not user: @@ -2707,6 +2851,9 @@ def wechat_callback(): db.session.add(user) db.session.commit() + is_new_user = True + print(f"✅ 微信扫码自动创建新用户: {username}, openid: {openid}") + # 更新最后登录时间 user.update_last_seen() @@ -2720,18 +2867,30 @@ def wechat_callback(): # Flask-Login 登录 login_user(user, remember=True) - # 清理微信session(仅登录/注册流程清理;绑定流程在上方已处理,不在此处清理) + # 更新微信session状态,供前端轮询检测 if state in wechat_qr_sessions: - # 仅当不是绑定流程,或没有模式信息时清理 - if not wechat_qr_sessions[state].get('mode'): - del wechat_qr_sessions[state] + session_item = wechat_qr_sessions[state] + # 仅处理登录/注册流程,不处理绑定流程 + if not session_item.get('mode'): + # 更新状态和用户信息 + session_item['status'] = 'register_ready' if is_new_user else 'login_ready' + session_item['user_info'] = {'user_id': user.id} + print(f"✅ 微信扫码状态已更新: {session_item['status']}, user_id: {user.id}") # 直接跳转到首页 return redirect('/home') except Exception as e: print(f"❌ 微信登录失败: {e}") + import traceback + traceback.print_exc() db.session.rollback() + + # 更新session状态为失败 + if state in wechat_qr_sessions: + wechat_qr_sessions[state]['status'] = 'auth_failed' + wechat_qr_sessions[state]['error'] = str(e) + return redirect('/auth/signin?error=login_failed') @@ -2802,61 +2961,6 @@ def login_with_wechat(): }), 500 -@app.route('/api/auth/register/wechat', methods=['POST']) -def register_with_wechat(): - """微信注册(保留用于特殊情况)""" - data = request.get_json() - session_id = data.get('session_id') - username = data.get('username') - password = data.get('password') - - if not all([session_id, username, password]): - return jsonify({'error': '所有字段都是必填的'}), 400 - - # 验证session - session = wechat_qr_sessions.get(session_id) - if not session: - return jsonify({'error': '微信验证失败或状态无效'}), 400 - - if User.query.filter_by(username=username).first(): - return jsonify({'error': '用户名已存在'}), 400 - - # 检查微信OpenID是否已被其他用户使用 - wechat_openid = session.get('wechat_openid') - wechat_unionid = session.get('wechat_unionid') - - if wechat_unionid and User.query.filter_by(wechat_union_id=wechat_unionid).first(): - return jsonify({'error': '该微信号已被其他用户绑定'}), 400 - if User.query.filter_by(wechat_open_id=wechat_openid).first(): - return jsonify({'error': '该微信号已被其他用户绑定'}), 400 - - # 创建用户 - try: - wechat_info = session['user_info'] - user = User(username=username) - user.set_password(password) - # 使用清理后的昵称 - user.nickname = user._sanitize_nickname(wechat_info.get('nickname', username)) - user.avatar_url = wechat_info.get('avatar_url') - user.wechat_open_id = wechat_openid - user.wechat_union_id = wechat_unionid - - db.session.add(user) - db.session.commit() - - # 清除session - del wechat_qr_sessions[session_id] - - return jsonify({ - 'message': '注册成功', - 'user': user.to_dict() - }), 201 - except Exception as e: - db.session.rollback() - print(f"WeChat register error: {e}") - return jsonify({'error': '注册失败,请重试'}), 500 - - @app.route('/api/account/wechat/unbind', methods=['POST']) def unbind_wechat_account(): """解绑当前登录用户的微信""" @@ -3691,6 +3795,7 @@ class RelatedStock(db.Model): updated_at = db.Column(db.DateTime, default=beijing_now, onupdate=beijing_now) correlation = db.Column(db.Float()) momentum = db.Column(db.String(1024)) # 动量 + retrieved_sources = db.Column(db.JSON) # 动量 class RelatedData(db.Model): @@ -4003,17 +4108,31 @@ def get_related_stocks(event_id): stocks_data = [] for stock in stocks: - stocks_data.append({ - 'id': stock.id, - 'stock_code': stock.stock_code, - 'stock_name': stock.stock_name, - 'sector': stock.sector, - 'relation_desc': stock.relation_desc, - 'correlation': stock.correlation, - 'momentum': stock.momentum, - 'created_at': stock.created_at.isoformat() if stock.created_at else None, - 'updated_at': stock.updated_at.isoformat() if stock.updated_at else None - }) + if stock.retrieved_sources is not None: + stocks_data.append({ + 'id': stock.id, + 'stock_code': stock.stock_code, + 'stock_name': stock.stock_name, + 'sector': stock.sector, + 'relation_desc': {"data":stock.retrieved_sources}, + 'retrieved_sources': stock.retrieved_sources, + 'correlation': stock.correlation, + 'momentum': stock.momentum, + 'created_at': stock.created_at.isoformat() if stock.created_at else None, + 'updated_at': stock.updated_at.isoformat() if stock.updated_at else None + }) + else: + stocks_data.append({ + 'id': stock.id, + 'stock_code': stock.stock_code, + 'stock_name': stock.stock_name, + 'sector': stock.sector, + 'relation_desc': stock.relation_desc, + 'correlation': stock.correlation, + 'momentum': stock.momentum, + 'created_at': stock.created_at.isoformat() if stock.created_at else None, + 'updated_at': stock.updated_at.isoformat() if stock.updated_at else None + }) return jsonify({ 'success': True, @@ -4544,7 +4663,7 @@ def get_clickhouse_client(): return Cclient( host='222.128.1.157', port=18000, - user='default', + user='default',a password='Zzl33818!', database='stock' ) @@ -6290,196 +6409,100 @@ def parse_json_field(field_value): # ==================== 行业API ==================== @app.route('/api/classifications', methods=['GET']) def get_classifications(): - """获取所有行业分类系统""" + """获取申银万国行业分类树形结构""" try: + # 查询申银万国行业分类的所有数据 sql = """ - SELECT DISTINCT f002v as classification_name + SELECT f003v as code, f004v as level1, f005v as level2, f006v as level3,f007v as level4 FROM ea_sector - WHERE f002v NOT IN ('指数成份股', '市场分类', '概念板块', '地区省市分类', '中上协行业分类') - ORDER BY classification_name \ + WHERE f002v = '申银万国行业分类' + AND f003v IS NOT NULL + AND f004v IS NOT NULL + ORDER BY f003v """ result = db.session.execute(text(sql)).all() - classifications = [{'name': row.classification_name} for row in result] + # 构建树形结构 + tree_dict = {} - return jsonify({ - 'success': True, - 'data': classifications - }) + for row in result: + code = row.code + level1 = row.level1 + level2 = row.level2 + level3 = row.level3 - except Exception as e: - return jsonify({ - 'success': False, - 'error': str(e) - }), 500 + # 跳过空数据 + if not level1: + continue - -@app.route('/api/levels', methods=['GET']) -def get_industry_levels(): - """获取行业层级数据""" - try: - classification = request.args.get('classification') - level = request.args.get('level', type=int) - level1_name = request.args.get('level1_name', '') - level2_name = request.args.get('level2_name', '') - level3_name = request.args.get('level3_name', '') - - if not classification or not level or level < 1 or level > 4: - return jsonify({ - 'success': False, - 'error': 'Invalid parameters' - }), 400 - - # 层级到字段的映射 - level_fields = { - 1: "f004v", - 2: "f005v", - 3: "f006v", - 4: "f007v" - } - - field_name = level_fields[level] - - # 构建查询 - if level == 1: - sql = f""" - SELECT DISTINCT {field_name} as name, - MIN(f003v) as code - FROM ea_sector - WHERE f002v = :classification - AND {field_name} IS NOT NULL - GROUP BY name - ORDER BY name - """ - params = {"classification": classification} - - elif level == 2: - sql = f""" - SELECT DISTINCT {field_name} as name, - MIN(f003v) as code - FROM ea_sector - WHERE f002v = :classification - AND f004v = :level1_name - AND {field_name} IS NOT NULL - GROUP BY name - ORDER BY name - """ - params = {"classification": classification, "level1_name": level1_name} - - elif level == 3: - sql = f""" - SELECT DISTINCT {field_name} as name, - MIN(f003v) as code - FROM ea_sector - WHERE f002v = :classification - AND f004v = :level1_name - AND f005v = :level2_name - AND {field_name} IS NOT NULL - GROUP BY name - ORDER BY name - """ - params = { - "classification": classification, - "level1_name": level1_name, - "level2_name": level2_name - } - - elif level == 4: - sql = f""" - SELECT DISTINCT f003v as code, - {field_name} as name - FROM ea_sector - WHERE f002v = :classification - AND f004v = :level1_name - AND f005v = :level2_name - AND f006v = :level3_name - AND {field_name} IS NOT NULL - ORDER BY name - """ - params = { - "classification": classification, - "level1_name": level1_name, - "level2_name": level2_name, - "level3_name": level3_name - } - - results = db.session.execute(text(sql), params).all() - - industries = [{"code": row.code, "name": row.name} for row in results if row.name] - - return jsonify({ - 'success': True, - 'data': industries - }) - - except Exception as e: - return jsonify({ - 'success': False, - 'error': str(e) - }), 500 - - -@app.route('/api/info', methods=['GET']) -def get_industry_info(): - """获取行业详细信息""" - try: - classification = request.args.get('classification') - code = request.args.get('code') - - if not classification or not code: - return jsonify({ - 'success': False, - 'error': 'Missing parameters' - }), 400 - - # 根据代码长度确定字段 - if len(code) >= 8: - field_name = "f007v" - elif len(code) >= 6: - field_name = "f006v" - elif len(code) >= 4: - field_name = "f005v" - else: - field_name = "f004v" - - sql = f""" - SELECT {field_name} as name, - f004v as level1_name, - f005v as level2_name, - f006v as level3_name, - f007v as level4_name - FROM ea_sector - WHERE f002v = :classification - AND f003v = :code - AND {field_name} IS NOT NULL - LIMIT 1 - """ - - result = db.session.execute(text(sql), { - "classification": classification, - "code": code - }).first() - - if not result: - return jsonify({ - 'success': False, - 'error': 'Industry not found' - }), 404 - - return jsonify({ - 'success': True, - 'data': { - 'name': result.name, - 'code': code, - 'classification': classification, - 'hierarchy': { - 'level1': result.level1_name, - 'level2': result.level2_name, - 'level3': result.level3_name, - 'level4': result.level4_name + # 第一层 + if level1 not in tree_dict: + # 获取第一层的code(取前3位或前缀) + level1_code = code[:3] if len(code) >= 3 else code + tree_dict[level1] = { + 'value': level1_code, + 'label': level1, + 'children_dict': {} } + + # 第二层 + if level2: + if level2 not in tree_dict[level1]['children_dict']: + # 获取第二层的code(取前6位) + level2_code = code[:6] if len(code) >= 6 else code + tree_dict[level1]['children_dict'][level2] = { + 'value': level2_code, + 'label': level2, + 'children_dict': {} + } + + # 第三层 + if level3: + if level3 not in tree_dict[level1]['children_dict'][level2]['children_dict']: + tree_dict[level1]['children_dict'][level2]['children_dict'][level3] = { + 'value': code, + 'label': level3 + } + + # 转换为最终格式 + result_list = [] + for level1_name, level1_data in tree_dict.items(): + level1_node = { + 'value': level1_data['value'], + 'label': level1_data['label'] } + + # 处理第二层 + if level1_data['children_dict']: + level1_children = [] + for level2_name, level2_data in level1_data['children_dict'].items(): + level2_node = { + 'value': level2_data['value'], + 'label': level2_data['label'] + } + + # 处理第三层 + if level2_data['children_dict']: + level2_children = [] + for level3_name, level3_data in level2_data['children_dict'].items(): + level2_children.append({ + 'value': level3_data['value'], + 'label': level3_data['label'] + }) + if level2_children: + level2_node['children'] = level2_children + + level1_children.append(level2_node) + + if level1_children: + level1_node['children'] = level1_children + + result_list.append(level1_node) + + return jsonify({ + 'success': True, + 'data': result_list }) except Exception as e: @@ -6489,6 +6512,29 @@ def get_industry_info(): }), 500 +@app.route('/api/stocklist', methods=['GET']) +def get_stock_list(): + """获取股票列表""" + try: + sql = """ + SELECT DISTINCT SECCODE as code, SECNAME as name + FROM ea_stocklist + ORDER BY SECCODE + """ + + result = db.session.execute(text(sql)).all() + + stocks = [{'code': row.code, 'name': row.name} for row in result] + + return jsonify(stocks) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + @app.route('/api/events', methods=['GET'], strict_slashes=False) def api_get_events(): """ @@ -6510,10 +6556,8 @@ def api_get_events(): date_range = request.args.get('date_range') recent_days = request.args.get('recent_days', type=int) - # 行业筛选参数 - industry_classification = request.args.get('industry_classification') - industry_code = request.args.get('industry_code') - industry_level = request.args.get('industry_level', type=int) + # 行业筛选参数(只支持申银万国行业分类) + industry_code = request.args.get('industry_code') # 申万行业代码,如 "S370502" # 概念/标签筛选参数 tag = request.args.get('tag') @@ -6555,16 +6599,39 @@ def api_get_events(): query = query.filter_by(status=event_status) if event_type != 'all': query = query.filter_by(event_type=event_type) + # 支持多个重要性级别筛选,用逗号分隔(如 importance=S,A) if importance != 'all': - query = query.filter_by(importance=importance) + if ',' in importance: + # 多个重要性级别 + importance_list = [imp.strip() for imp in importance.split(',') if imp.strip()] + query = query.filter(Event.importance.in_(importance_list)) + else: + # 单个重要性级别 + query = query.filter_by(importance=importance) if creator_id: query = query.filter_by(creator_id=creator_id) - # 新增:行业代码过滤(MySQL JSON,对象数组模式) - if industry_classification and industry_code: - json_path = f'$[*]."{industry_classification}"' - query = query.filter( - text("JSON_CONTAINS(JSON_EXTRACT(related_industries, :json_path), :industry_code)") - ).params(json_path=json_path, industry_code=json.dumps(industry_code)) + # 新增:行业代码过滤(申银万国行业分类) + if industry_code: + # related_industries 格式: [{"申银万国行业分类": "S370502"}, ...] + # 支持多个行业代码,用逗号分隔 + json_path = '$[*]."申银万国行业分类"' + + # 如果包含逗号,说明是多个行业代码 + if ',' in industry_code: + codes = [code.strip() for code in industry_code.split(',') if code.strip()] + # 使用 OR 条件匹配任意一个行业代码 + conditions = [] + for code in codes: + conditions.append( + text("JSON_CONTAINS(JSON_EXTRACT(related_industries, :json_path), :code)") + .bindparams(json_path=json_path, code=json.dumps(code)) + ) + query = query.filter(db.or_(*conditions)) + else: + # 单个行业代码 + query = query.filter( + text("JSON_CONTAINS(JSON_EXTRACT(related_industries, :json_path), :industry_code)") + ).params(json_path=json_path, industry_code=json.dumps(industry_code)) # 新增:关键词/全文搜索过滤(MySQL JSON) if search_query: like_pattern = f"%{search_query}%" @@ -6685,8 +6752,6 @@ def api_get_events(): applied_filters['start_date'] = start_date if end_date: applied_filters['end_date'] = end_date - if industry_classification: - applied_filters['industry_classification'] = industry_classification if industry_code: applied_filters['industry_code'] = industry_code if tag: @@ -7436,6 +7501,333 @@ def add_event_comment(event_id): }), 500 +# ==================== WebSocket 事件处理器(实时事件推送) ==================== + +@socketio.on('connect') +def handle_connect(): + """客户端连接事件""" + print(f'\n[WebSocket DEBUG] ========== 客户端连接 ==========') + print(f'[WebSocket DEBUG] Socket ID: {request.sid}') + print(f'[WebSocket DEBUG] Remote Address: {request.remote_addr if hasattr(request, "remote_addr") else "N/A"}') + print(f'[WebSocket] 客户端已连接: {request.sid}') + + emit('connection_response', { + 'status': 'connected', + 'sid': request.sid, + 'message': '已连接到事件推送服务' + }) + print(f'[WebSocket DEBUG] ✓ 已发送 connection_response') + print(f'[WebSocket DEBUG] ========== 连接完成 ==========\n') + + +@socketio.on('subscribe_events') +def handle_subscribe(data): + """ + 客户端订阅事件推送 + data: { + 'event_type': 'all' | 'policy' | 'market' | 'tech' | ..., + 'importance': 'all' | 'S' | 'A' | 'B' | 'C', + 'filters': {...} # 可选的其他筛选条件 + } + """ + try: + print(f'\n[WebSocket DEBUG] ========== 收到订阅请求 ==========') + print(f'[WebSocket DEBUG] Socket ID: {request.sid}') + print(f'[WebSocket DEBUG] 订阅数据: {data}') + + event_type = data.get('event_type', 'all') + importance = data.get('importance', 'all') + + print(f'[WebSocket DEBUG] 事件类型: {event_type}') + print(f'[WebSocket DEBUG] 重要性: {importance}') + + # 加入对应的房间 + room_name = f"events_{event_type}" + print(f'[WebSocket DEBUG] 准备加入房间: {room_name}') + join_room(room_name) + print(f'[WebSocket DEBUG] ✓ 已加入房间: {room_name}') + + print(f'[WebSocket] 客户端 {request.sid} 订阅了房间: {room_name}') + + response_data = { + 'success': True, + 'room': room_name, + 'event_type': event_type, + 'importance': importance, + 'message': f'已订阅 {event_type} 类型的事件推送' + } + print(f'[WebSocket DEBUG] 准备发送 subscription_confirmed: {response_data}') + emit('subscription_confirmed', response_data) + print(f'[WebSocket DEBUG] ✓ 已发送 subscription_confirmed') + print(f'[WebSocket DEBUG] ========== 订阅完成 ==========\n') + + except Exception as e: + print(f'[WebSocket ERROR] 订阅失败: {e}') + import traceback + traceback.print_exc() + emit('subscription_error', { + 'success': False, + 'error': str(e) + }) + + +@socketio.on('unsubscribe_events') +def handle_unsubscribe(data): + """取消订阅事件推送""" + try: + print(f'\n[WebSocket DEBUG] ========== 收到取消订阅请求 ==========') + print(f'[WebSocket DEBUG] Socket ID: {request.sid}') + print(f'[WebSocket DEBUG] 数据: {data}') + + event_type = data.get('event_type', 'all') + room_name = f"events_{event_type}" + + print(f'[WebSocket DEBUG] 准备离开房间: {room_name}') + leave_room(room_name) + print(f'[WebSocket DEBUG] ✓ 已离开房间: {room_name}') + + print(f'[WebSocket] 客户端 {request.sid} 取消订阅房间: {room_name}') + + emit('unsubscription_confirmed', { + 'success': True, + 'room': room_name, + 'message': f'已取消订阅 {event_type} 类型的事件推送' + }) + print(f'[WebSocket DEBUG] ========== 取消订阅完成 ==========\n') + + except Exception as e: + print(f'[WebSocket ERROR] 取消订阅失败: {e}') + import traceback + traceback.print_exc() + emit('unsubscription_error', { + 'success': False, + 'error': str(e) + }) + + +@socketio.on('disconnect') +def handle_disconnect(): + """客户端断开连接事件""" + print(f'\n[WebSocket DEBUG] ========== 客户端断开 ==========') + print(f'[WebSocket DEBUG] Socket ID: {request.sid}') + print(f'[WebSocket] 客户端已断开: {request.sid}') + print(f'[WebSocket DEBUG] ========== 断开完成 ==========\n') + + +# ==================== WebSocket 辅助函数 ==================== + +def broadcast_new_event(event): + """ + 广播新事件到所有订阅的客户端 + 在创建新事件时调用此函数 + + Args: + event: Event 模型实例 + """ + try: + print(f'\n[WebSocket DEBUG] ========== 广播新事件 ==========') + print(f'[WebSocket DEBUG] 事件ID: {event.id}') + print(f'[WebSocket DEBUG] 事件标题: {event.title}') + print(f'[WebSocket DEBUG] 事件类型: {event.event_type}') + print(f'[WebSocket DEBUG] 重要性: {event.importance}') + + event_data = { + 'id': event.id, + 'title': event.title, + 'description': event.description, + 'event_type': event.event_type, + 'importance': event.importance, + 'status': event.status, + 'created_at': event.created_at.isoformat() if event.created_at else None, + 'hot_score': event.hot_score, + 'view_count': event.view_count, + 'related_avg_chg': event.related_avg_chg, + 'related_max_chg': event.related_max_chg, + 'keywords': event.keywords_list if hasattr(event, 'keywords_list') else event.keywords, + } + + print(f'[WebSocket DEBUG] 准备发送的数据: {event_data}') + + # 发送到所有订阅者(all 房间) + print(f'[WebSocket DEBUG] 正在发送到房间: events_all') + socketio.emit('new_event', event_data, room='events_all', namespace='/') + print(f'[WebSocket DEBUG] ✓ 已发送到 events_all') + + # 发送到特定类型订阅者 + if event.event_type: + room_name = f"events_{event.event_type}" + print(f'[WebSocket DEBUG] 正在发送到房间: {room_name}') + socketio.emit('new_event', event_data, room=room_name, namespace='/') + print(f'[WebSocket DEBUG] ✓ 已发送到 {room_name}') + print(f'[WebSocket] 已推送新事件到房间: events_all, {room_name}') + else: + print(f'[WebSocket] 已推送新事件到房间: events_all') + + print(f'[WebSocket DEBUG] ========== 广播完成 ==========\n') + + except Exception as e: + print(f'[WebSocket ERROR] 推送新事件失败: {e}') + import traceback + traceback.print_exc() + + +# ==================== WebSocket 轮询机制(检测新事件) ==================== + +# 内存变量:记录近24小时内已知的事件ID集合和最大ID +known_event_ids_in_24h = set() # 近24小时内已知的所有事件ID +last_max_event_id = 0 # 已知的最大事件ID + +def poll_new_events(): + """ + 定期轮询数据库,检查是否有新事件 + 每 30 秒执行一次 + + 新的设计思路(修复 created_at 不是入库时间的问题): + 1. 查询近24小时内的所有活跃事件(按 created_at,因为这是事件发生时间) + 2. 通过对比事件ID(自增ID)来判断是否为新插入的事件 + 3. 推送 ID > last_max_event_id 的事件 + 4. 更新已知事件ID集合和最大ID + """ + global known_event_ids_in_24h, last_max_event_id + + try: + with app.app_context(): + from datetime import datetime, timedelta + + current_time = datetime.now() + print(f'\n[轮询 DEBUG] ========== 开始轮询 ==========') + print(f'[轮询 DEBUG] 当前时间: {current_time.strftime("%Y-%m-%d %H:%M:%S")}') + print(f'[轮询 DEBUG] 已知事件ID数量: {len(known_event_ids_in_24h)}') + print(f'[轮询 DEBUG] 当前最大事件ID: {last_max_event_id}') + + # 查询近24小时内的所有活跃事件(按事件发生时间 created_at) + time_24h_ago = current_time - timedelta(hours=24) + print(f'[轮询 DEBUG] 查询时间范围: 近24小时({time_24h_ago.strftime("%Y-%m-%d %H:%M:%S")} ~ 现在)') + + # 查询所有近24小时内的活跃事件 + events_in_24h = Event.query.filter( + Event.created_at >= time_24h_ago, + Event.status == 'active' + ).order_by(Event.id.asc()).all() + + print(f'[轮询 DEBUG] 数据库查询结果: 找到 {len(events_in_24h)} 个近24小时内的事件') + + # 找出新插入的事件(ID > last_max_event_id) + new_events = [ + event for event in events_in_24h + if event.id > last_max_event_id + ] + + print(f'[轮询 DEBUG] 新事件数量(ID > {last_max_event_id}): {len(new_events)} 个') + + if new_events: + print(f'[轮询] 发现 {len(new_events)} 个新事件') + + for event in new_events: + print(f'[轮询 DEBUG] 新事件详情:') + print(f'[轮询 DEBUG] - ID: {event.id}') + print(f'[轮询 DEBUG] - 标题: {event.title}') + print(f'[轮询 DEBUG] - 事件发生时间(created_at): {event.created_at}') + print(f'[轮询 DEBUG] - 事件类型: {event.event_type}') + + # 推送新事件 + print(f'[轮询 DEBUG] 准备推送事件 ID={event.id}') + broadcast_new_event(event) + print(f'[轮询] ✓ 已推送事件 ID={event.id}, 标题={event.title}') + + # 更新已知事件ID集合(所有近24小时内的事件ID) + known_event_ids_in_24h = set(event.id for event in events_in_24h) + + # 更新最大事件ID + new_max_id = max(event.id for event in events_in_24h) + print(f'[轮询 DEBUG] 更新最大事件ID: {last_max_event_id} -> {new_max_id}') + last_max_event_id = new_max_id + + print(f'[轮询 DEBUG] 更新后已知事件ID数量: {len(known_event_ids_in_24h)}') + + else: + print(f'[轮询 DEBUG] 没有新事件需要推送') + + # 即使没有新事件,也要更新已知事件集合(清理超过24小时的) + if events_in_24h: + known_event_ids_in_24h = set(event.id for event in events_in_24h) + current_max_id = max(event.id for event in events_in_24h) + if current_max_id != last_max_event_id: + print(f'[轮询 DEBUG] 更新最大事件ID: {last_max_event_id} -> {current_max_id}') + last_max_event_id = current_max_id + + print(f'[轮询 DEBUG] ========== 轮询结束 ==========\n') + + except Exception as e: + print(f'[轮询 ERROR] 检查新事件时出错: {e}') + import traceback + traceback.print_exc() + + +def initialize_event_polling(): + """ + 初始化事件轮询机制 + 在应用启动时调用 + """ + global known_event_ids_in_24h, last_max_event_id + + try: + from datetime import datetime, timedelta + + with app.app_context(): + current_time = datetime.now() + time_24h_ago = current_time - timedelta(hours=24) + + print(f'\n[轮询] ========== 初始化事件轮询 ==========') + print(f'[轮询] 当前时间: {current_time.strftime("%Y-%m-%d %H:%M:%S")}') + + # 查询近24小时内的所有活跃事件 + events_in_24h = Event.query.filter( + Event.created_at >= time_24h_ago, + Event.status == 'active' + ).order_by(Event.id.asc()).all() + + # 初始化已知事件ID集合 + known_event_ids_in_24h = set(event.id for event in events_in_24h) + + # 初始化最大事件ID + if events_in_24h: + last_max_event_id = max(event.id for event in events_in_24h) + print(f'[轮询] 近24小时内共有 {len(events_in_24h)} 个活跃事件') + print(f'[轮询] 初始最大事件ID: {last_max_event_id}') + print(f'[轮询] 事件ID范围: {min(event.id for event in events_in_24h)} ~ {last_max_event_id}') + else: + last_max_event_id = 0 + print(f'[轮询] 近24小时内没有活跃事件') + print(f'[轮询] 初始最大事件ID: 0') + + # 统计数据库中的事件总数 + total_events = Event.query.filter_by(status='active').count() + print(f'[轮询] 数据库中共有 {total_events} 个活跃事件(所有时间)') + print(f'[轮询] 只会推送 ID > {last_max_event_id} 的新事件') + print(f'[轮询] ========== 初始化完成 ==========\n') + + # 创建后台调度器 + scheduler = BackgroundScheduler() + # 每 30 秒执行一次轮询 + scheduler.add_job( + func=poll_new_events, + trigger='interval', + seconds=30, + id='poll_new_events', + name='检查新事件并推送', + replace_existing=True + ) + scheduler.start() + print('[轮询] 调度器已启动,每 30 秒检查一次新事件') + + except Exception as e: + print(f'[轮询] 初始化失败: {e}') + + +# ==================== 结束 WebSocket 部分 ==================== + + @app.route('/api/posts//like', methods=['POST']) @login_required def like_post(post_id): @@ -7611,6 +8003,98 @@ def format_date(date_obj): return str(date_obj) +def remove_cycles_from_sankey_flows(flows_data): + """ + 移除Sankey图数据中的循环边,确保数据是DAG(有向无环图) + 使用拓扑排序算法检测循环,优先保留flow_ratio高的边 + + Args: + flows_data: list of flow objects with 'source', 'target', 'flow_metrics' keys + + Returns: + list of flows without cycles + """ + if not flows_data: + return flows_data + + # 按flow_ratio降序排序,优先保留重要的边 + sorted_flows = sorted( + flows_data, + key=lambda x: x.get('flow_metrics', {}).get('flow_ratio', 0) or 0, + reverse=True + ) + + # 构建图的邻接表和入度表 + def build_graph(flows): + graph = {} # node -> list of successors + in_degree = {} # node -> in-degree count + all_nodes = set() + + for flow in flows: + source = flow['source']['node_name'] + target = flow['target']['node_name'] + all_nodes.add(source) + all_nodes.add(target) + + if source not in graph: + graph[source] = [] + graph[source].append(target) + + if target not in in_degree: + in_degree[target] = 0 + in_degree[target] += 1 + + if source not in in_degree: + in_degree[source] = 0 + + return graph, in_degree, all_nodes + + # 使用Kahn算法检测是否有环 + def has_cycle(graph, in_degree, all_nodes): + # 找到所有入度为0的节点 + queue = [node for node in all_nodes if in_degree.get(node, 0) == 0] + visited_count = 0 + + while queue: + node = queue.pop(0) + visited_count += 1 + + # 访问所有邻居 + for neighbor in graph.get(node, []): + in_degree[neighbor] -= 1 + if in_degree[neighbor] == 0: + queue.append(neighbor) + + # 如果访问的节点数等于总节点数,说明没有环 + return visited_count < len(all_nodes) + + # 逐个添加边,如果添加后产生环则跳过 + result_flows = [] + + for flow in sorted_flows: + # 尝试添加这条边 + temp_flows = result_flows + [flow] + + # 检查是否产生环 + graph, in_degree, all_nodes = build_graph(temp_flows) + + # 复制in_degree用于检测(因为检测过程会修改它) + in_degree_copy = in_degree.copy() + + if not has_cycle(graph, in_degree_copy, all_nodes): + # 没有产生环,可以添加 + result_flows.append(flow) + else: + # 产生环,跳过这条边 + print(f"Skipping edge that creates cycle: {flow['source']['node_name']} -> {flow['target']['node_name']}") + + removed_count = len(flows_data) - len(result_flows) + if removed_count > 0: + print(f"Removed {removed_count} edges to eliminate cycles in Sankey diagram") + + return result_flows + + def get_report_type(date_str): """获取报告期类型""" if not date_str: @@ -10321,6 +10805,9 @@ def get_value_chain_analysis(company_code): } }) + # 移除循环边,确保Sankey图数据是DAG(有向无环图) + flows_data = remove_cycles_from_sankey_flows(flows_data) + # 统计各层级节点数量 level_stats = {} for level_key, nodes in nodes_by_level.items(): @@ -11574,4 +12061,8 @@ if __name__ == '__main__': except Exception as e: app.logger.error(f"数据库初始化失败: {e}") - app.run(host='0.0.0.0', port=5001, debug=False) \ No newline at end of file + # 初始化事件轮询机制(WebSocket 推送) + initialize_event_polling() + + # 使用 socketio.run 替代 app.run 以支持 WebSocket + socketio.run(app, host='0.0.0.0', port=5001, debug=False, allow_unsafe_werkzeug=True) \ No newline at end of file