diff --git a/app.py b/app.py index 73e2006e..b5b31b6b 100755 --- a/app.py +++ b/app.py @@ -6324,196 +6324,100 @@ def parse_json_field(field_value): # ==================== 行业API ==================== @app.route('/api/classifications', methods=['GET']) def get_classifications(): - """获取所有行业分类系统""" + """获取申银万国行业分类树形结构""" try: + # 查询申银万国行业分类的所有数据 sql = """ - SELECT DISTINCT f002v as classification_name + SELECT f003v as code, f004v as level1, f005v as level2, f006v as level3,f007v as level4 FROM ea_sector - WHERE f002v NOT IN ('指数成份股', '市场分类', '概念板块', '地区省市分类', '中上协行业分类') - ORDER BY classification_name \ + WHERE f002v = '申银万国行业分类' + AND f003v IS NOT NULL + AND f004v IS NOT NULL + ORDER BY f003v """ result = db.session.execute(text(sql)).all() - classifications = [{'name': row.classification_name} for row in result] + # 构建树形结构 + tree_dict = {} - return jsonify({ - 'success': True, - 'data': classifications - }) + for row in result: + code = row.code + level1 = row.level1 + level2 = row.level2 + level3 = row.level3 - except Exception as e: - return jsonify({ - 'success': False, - 'error': str(e) - }), 500 + # 跳过空数据 + if not level1: + continue - -@app.route('/api/levels', methods=['GET']) -def get_industry_levels(): - """获取行业层级数据""" - try: - classification = request.args.get('classification') - level = request.args.get('level', type=int) - level1_name = request.args.get('level1_name', '') - level2_name = request.args.get('level2_name', '') - level3_name = request.args.get('level3_name', '') - - if not classification or not level or level < 1 or level > 4: - return jsonify({ - 'success': False, - 'error': 'Invalid parameters' - }), 400 - - # 层级到字段的映射 - level_fields = { - 1: "f004v", - 2: "f005v", - 3: "f006v", - 4: "f007v" - } - - field_name = level_fields[level] - - # 构建查询 - if level == 1: - sql = f""" - SELECT DISTINCT {field_name} as name, - MIN(f003v) as code - FROM ea_sector - WHERE f002v = :classification - AND {field_name} IS NOT NULL - GROUP BY name - ORDER BY name - """ - params = {"classification": classification} - - elif level == 2: - sql = f""" - SELECT DISTINCT {field_name} as name, - MIN(f003v) as code - FROM ea_sector - WHERE f002v = :classification - AND f004v = :level1_name - AND {field_name} IS NOT NULL - GROUP BY name - ORDER BY name - """ - params = {"classification": classification, "level1_name": level1_name} - - elif level == 3: - sql = f""" - SELECT DISTINCT {field_name} as name, - MIN(f003v) as code - FROM ea_sector - WHERE f002v = :classification - AND f004v = :level1_name - AND f005v = :level2_name - AND {field_name} IS NOT NULL - GROUP BY name - ORDER BY name - """ - params = { - "classification": classification, - "level1_name": level1_name, - "level2_name": level2_name - } - - elif level == 4: - sql = f""" - SELECT DISTINCT f003v as code, - {field_name} as name - FROM ea_sector - WHERE f002v = :classification - AND f004v = :level1_name - AND f005v = :level2_name - AND f006v = :level3_name - AND {field_name} IS NOT NULL - ORDER BY name - """ - params = { - "classification": classification, - "level1_name": level1_name, - "level2_name": level2_name, - "level3_name": level3_name - } - - results = db.session.execute(text(sql), params).all() - - industries = [{"code": row.code, "name": row.name} for row in results if row.name] - - return jsonify({ - 'success': True, - 'data': industries - }) - - except Exception as e: - return jsonify({ - 'success': False, - 'error': str(e) - }), 500 - - -@app.route('/api/info', methods=['GET']) -def get_industry_info(): - """获取行业详细信息""" - try: - classification = request.args.get('classification') - code = request.args.get('code') - - if not classification or not code: - return jsonify({ - 'success': False, - 'error': 'Missing parameters' - }), 400 - - # 根据代码长度确定字段 - if len(code) >= 8: - field_name = "f007v" - elif len(code) >= 6: - field_name = "f006v" - elif len(code) >= 4: - field_name = "f005v" - else: - field_name = "f004v" - - sql = f""" - SELECT {field_name} as name, - f004v as level1_name, - f005v as level2_name, - f006v as level3_name, - f007v as level4_name - FROM ea_sector - WHERE f002v = :classification - AND f003v = :code - AND {field_name} IS NOT NULL - LIMIT 1 - """ - - result = db.session.execute(text(sql), { - "classification": classification, - "code": code - }).first() - - if not result: - return jsonify({ - 'success': False, - 'error': 'Industry not found' - }), 404 - - return jsonify({ - 'success': True, - 'data': { - 'name': result.name, - 'code': code, - 'classification': classification, - 'hierarchy': { - 'level1': result.level1_name, - 'level2': result.level2_name, - 'level3': result.level3_name, - 'level4': result.level4_name + # 第一层 + if level1 not in tree_dict: + # 获取第一层的code(取前3位或前缀) + level1_code = code[:3] if len(code) >= 3 else code + tree_dict[level1] = { + 'value': level1_code, + 'label': level1, + 'children_dict': {} } + + # 第二层 + if level2: + if level2 not in tree_dict[level1]['children_dict']: + # 获取第二层的code(取前6位) + level2_code = code[:6] if len(code) >= 6 else code + tree_dict[level1]['children_dict'][level2] = { + 'value': level2_code, + 'label': level2, + 'children_dict': {} + } + + # 第三层 + if level3: + if level3 not in tree_dict[level1]['children_dict'][level2]['children_dict']: + tree_dict[level1]['children_dict'][level2]['children_dict'][level3] = { + 'value': code, + 'label': level3 + } + + # 转换为最终格式 + result_list = [] + for level1_name, level1_data in tree_dict.items(): + level1_node = { + 'value': level1_data['value'], + 'label': level1_data['label'] } + + # 处理第二层 + if level1_data['children_dict']: + level1_children = [] + for level2_name, level2_data in level1_data['children_dict'].items(): + level2_node = { + 'value': level2_data['value'], + 'label': level2_data['label'] + } + + # 处理第三层 + if level2_data['children_dict']: + level2_children = [] + for level3_name, level3_data in level2_data['children_dict'].items(): + level2_children.append({ + 'value': level3_data['value'], + 'label': level3_data['label'] + }) + if level2_children: + level2_node['children'] = level2_children + + level1_children.append(level2_node) + + if level1_children: + level1_node['children'] = level1_children + + result_list.append(level1_node) + + return jsonify({ + 'success': True, + 'data': result_list }) except Exception as e: @@ -6523,6 +6427,29 @@ def get_industry_info(): }), 500 +@app.route('/api/stocklist', methods=['GET']) +def get_stock_list(): + """获取股票列表""" + try: + sql = """ + SELECT DISTINCT SECCODE as code, SECNAME as name + FROM ea_stocklist + ORDER BY SECCODE + """ + + result = db.session.execute(text(sql)).all() + + stocks = [{'code': row.code, 'name': row.name} for row in result] + + return jsonify(stocks) + + except Exception as e: + return jsonify({ + 'success': False, + 'error': str(e) + }), 500 + + @app.route('/api/events', methods=['GET'], strict_slashes=False) def api_get_events(): """ @@ -6544,10 +6471,8 @@ def api_get_events(): date_range = request.args.get('date_range') recent_days = request.args.get('recent_days', type=int) - # 行业筛选参数 - industry_classification = request.args.get('industry_classification') - industry_code = request.args.get('industry_code') - industry_level = request.args.get('industry_level', type=int) + # 行业筛选参数(只支持申银万国行业分类) + industry_code = request.args.get('industry_code') # 申万行业代码,如 "S370502" # 概念/标签筛选参数 tag = request.args.get('tag') @@ -6593,12 +6518,28 @@ def api_get_events(): query = query.filter_by(importance=importance) if creator_id: query = query.filter_by(creator_id=creator_id) - # 新增:行业代码过滤(MySQL JSON,对象数组模式) - if industry_classification and industry_code: - json_path = f'$[*]."{industry_classification}"' - query = query.filter( - text("JSON_CONTAINS(JSON_EXTRACT(related_industries, :json_path), :industry_code)") - ).params(json_path=json_path, industry_code=json.dumps(industry_code)) + # 新增:行业代码过滤(申银万国行业分类) + if industry_code: + # related_industries 格式: [{"申银万国行业分类": "S370502"}, ...] + # 支持多个行业代码,用逗号分隔 + json_path = '$[*]."申银万国行业分类"' + + # 如果包含逗号,说明是多个行业代码 + if ',' in industry_code: + codes = [code.strip() for code in industry_code.split(',') if code.strip()] + # 使用 OR 条件匹配任意一个行业代码 + conditions = [] + for code in codes: + conditions.append( + text("JSON_CONTAINS(JSON_EXTRACT(related_industries, :json_path), :code)") + .bindparams(json_path=json_path, code=json.dumps(code)) + ) + query = query.filter(db.or_(*conditions)) + else: + # 单个行业代码 + query = query.filter( + text("JSON_CONTAINS(JSON_EXTRACT(related_industries, :json_path), :industry_code)") + ).params(json_path=json_path, industry_code=json.dumps(industry_code)) # 新增:关键词/全文搜索过滤(MySQL JSON) if search_query: like_pattern = f"%{search_query}%" @@ -6719,8 +6660,6 @@ def api_get_events(): applied_filters['start_date'] = start_date if end_date: applied_filters['end_date'] = end_date - if industry_classification: - applied_filters['industry_classification'] = industry_classification if industry_code: applied_filters['industry_code'] = industry_code if tag: diff --git a/src/views/Community/components/EventFilters.js b/src/views/Community/components/EventFilters.js index 40bde007..7efc6788 100644 --- a/src/views/Community/components/EventFilters.js +++ b/src/views/Community/components/EventFilters.js @@ -1,6 +1,6 @@ // src/views/Community/components/EventFilters.js import React, { useState, useEffect } from 'react'; -import { Card, Row, Col, DatePicker, Button, Select, Form, Cascader } from 'antd'; +import { Card, Row, Col, DatePicker, Button, Select, Form, Input, Cascader } from 'antd'; import { FilterOutlined } from '@ant-design/icons'; import moment from 'moment'; import locale from 'antd/es/date-picker/locale/zh_CN'; @@ -12,29 +12,38 @@ const { Option } = Select; const EventFilters = ({ filters, onFilterChange, loading }) => { const [form] = Form.useForm(); - const [industryCascaderValue, setIndustryCascaderValue] = useState([]); - - // 使用全局行业数据 - const { industryData, loadIndustryData, loading: industryLoading } = useIndustry(); + const [industryOptions, setIndustryOptions] = useState([]); // 初始化表单值 useEffect(() => { const initialValues = { date_range: filters.date_range ? filters.date_range.split(' 至 ').map(d => moment(d)) : null, sort: filters.sort, - importance: filters.importance + importance: filters.importance, + industry_code: filters.industry_code ? filters.industry_code.split(',') : [] }; form.setFieldsValue(initialValues); }, [filters, form]); - // Cascader 获得焦点时加载数据 - const handleCascaderFocus = async () => { - if (!industryData || industryData.length === 0) { - logger.debug('EventFilters', 'Cascader 获得焦点,开始加载行业数据'); - await loadIndustryData(); + // 加载申银万国行业分类树形数据 + const loadIndustryClassifications = async () => { + try { + const response = await industryService.getClassifications(); + if (response.success && response.data) { + setIndustryOptions(response.data); + logger.debug('EventFilters', '申银万国行业分类加载成功', { + count: response.data?.length || 0 + }); + } + } catch (error) { + logger.error('EventFilters', 'loadIndustryClassifications', error); } }; + useEffect(() => { + loadIndustryClassifications(); + }, []); + const handleDateRangeChange = (dates) => { if (dates && dates.length === 2) { const dateRange = `${dates[0].format('YYYY-MM-DD')} 至 ${dates[1].format('YYYY-MM-DD')}`; @@ -52,29 +61,60 @@ const EventFilters = ({ filters, onFilterChange, loading }) => { onFilterChange('importance', value); }; - // Cascader 选择变化 - const handleIndustryCascaderChange = (value, selectedOptions) => { - setIndustryCascaderValue(value); + // 收集所有叶子节点的 value(递归) + const collectLeafValues = (node) => { + // 如果没有子节点,说明是叶子节点 + if (!node.children || node.children.length === 0) { + return [node.value]; + } - if (value && value.length > 0) { - // value[0] = 分类体系名称 - // value[1...n] = 行业代码(一级~四级) - const industryCode = value[value.length - 1]; // 最后一级的 code - const classification = value[0]; // 分类体系名称 + // 有子节点,递归收集所有子节点的叶子节点 + let leafValues = []; + node.children.forEach(child => { + leafValues = leafValues.concat(collectLeafValues(child)); + }); + return leafValues; + }; - onFilterChange('industry_classification', classification); - onFilterChange('industry_code', industryCode); + // 根据级联路径找到对应的节点 + const findNodeByPath = (options, path) => { + let current = options; + let node = null; - logger.debug('EventFilters', 'Cascader 选择变化', { - value, - classification, - industryCode, - path: selectedOptions.map(o => o.label).join(' > ') - }); - } else { - // 清空 - onFilterChange('industry_classification', ''); + for (let i = 0; i < path.length; i++) { + node = current.find(item => item.value === path[i]); + if (!node) return null; + if (i < path.length - 1) { + current = node.children || []; + } + } + return node; + }; + + // 行业级联选择变化 + const handleIndustryChange = (value, selectedOptions) => { + if (!value || value.length === 0) { onFilterChange('industry_code', ''); + return; + } + + // 获取选中的节点 + const selectedNode = findNodeByPath(industryOptions, value); + + if (!selectedNode) { + // 如果找不到节点,使用最后一个值 + onFilterChange('industry_code', value[value.length - 1]); + return; + } + + // 如果选中的节点有子节点,收集所有叶子节点的 value + // 这样可以匹配该级别下的所有事件 + if (selectedNode.children && selectedNode.children.length > 0) { + const leafValues = collectLeafValues(selectedNode); + onFilterChange('industry_code', leafValues.join(',')); + } else { + // 叶子节点,直接使用该 value + onFilterChange('industry_code', selectedNode.value); } }; @@ -123,24 +163,20 @@ const EventFilters = ({ filters, onFilterChange, loading }) => { {/* 行业分类级联选择器 - 替换原来的 5 个独立 Select */} - + labels.join(' > ')} showSearch={{ filter: (inputValue, path) => - path.some(option => option.label.toLowerCase().includes(inputValue.toLowerCase())) + path.some(option => option.label.toLowerCase().indexOf(inputValue.toLowerCase()) > -1) }} + disabled={loading} + allowClear style={{ width: '100%' }} + displayRender={(labels) => labels.join(' / ')} />