# -*- coding: utf-8 -*- """行业深度研报查询脚本 —— 供 Claude Skill 使用。 连接 MySQL 数据库,查询行业报告的结构化数据和原始 Markdown 文本。 依赖: pymysql(项目已安装)。 用法: python query_industry.py [args...] 子命令: lookup 按关键词或代码查找行业 summary 获取行业核心指标摘要 module 获取结构化 Module JSON part 获取 Part 原始 Markdown rank [--top N] 跨行业排名 list 列出所有行业 """ from __future__ import annotations import json import sys import os # Windows 终端 GBK 编码问题 sys.stdout.reconfigure(encoding="utf-8") sys.stderr.reconfigure(encoding="utf-8") # 数据库配置 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from db_config import DB_CONFIG try: import pymysql except ImportError: import subprocess print("pymysql 未安装,正在自动安装...", file=sys.stderr) subprocess.check_call([sys.executable, "-m", "pip", "install", "pymysql", "-q"]) import pymysql def _get_conn(): """获取数据库连接。""" return pymysql.connect(**DB_CONFIG, cursorclass=pymysql.cursors.DictCursor) # --------------------------------------------------------------------------- # 子命令实现 # --------------------------------------------------------------------------- def cmd_lookup(keyword: str): """按关键词查找行业,模糊匹配 name 和 code。""" conn = _get_conn() try: with conn.cursor() as cur: like = f"%{keyword}%" cur.execute( """SELECT code, name, stars, stars_label, headline FROM industry_reports WHERE name LIKE %s OR code LIKE %s ORDER BY stars DESC""", (like, like), ) rows = cur.fetchall() finally: conn.close() if not rows: print(f"未找到匹配「{keyword}」的行业") return print(f"找到 {len(rows)} 个行业:\n") for r in rows: stars_display = "★" * r["stars"] + "☆" * (5 - r["stars"]) print(f" {r['code']} {r['name']} {stars_display} ({r['stars_label']})") if r["headline"]: print(f" {r['headline']}") print() def cmd_summary(code: str): """获取行业核心指标摘要。""" conn = _get_conn() try: with conn.cursor() as cur: cur.execute( """SELECT code, name, stars, stars_label, cagr, gross_margin, net_margin, cr5, cr10, hhi, concentration_trend, headline, total_characters, report_version, generated_at FROM industry_reports WHERE code = %s""", (code,), ) row = cur.fetchone() finally: conn.close() if not row: print(f"未找到行业: {code}") return print(json.dumps(row, ensure_ascii=False, indent=2, default=str)) def cmd_module(code: str, module_num: str): """获取结构化 Module JSON 数据。""" conn = _get_conn() try: with conn.cursor() as cur: if module_num == "all": cur.execute( """SELECT rm.module_num, rm.module_json FROM report_modules rm JOIN industry_reports ir ON ir.id = rm.report_id WHERE ir.code = %s ORDER BY rm.module_num""", (code,), ) else: cur.execute( """SELECT rm.module_num, rm.module_json FROM report_modules rm JOIN industry_reports ir ON ir.id = rm.report_id WHERE ir.code = %s AND rm.module_num = %s""", (code, int(module_num)), ) rows = cur.fetchall() finally: conn.close() if not rows: print(f"未找到行业 {code} 的 Module 数据") return for r in rows: mj = r["module_json"] data = json.loads(mj) if isinstance(mj, str) else mj print(f"=== Module {r['module_num']} ===") print(json.dumps(data, ensure_ascii=False, indent=2)) print() def cmd_part(code: str, part_nums: str): """获取 Part 原始 Markdown 文本。""" conn = _get_conn() try: with conn.cursor() as cur: if part_nums == "all": cur.execute( """SELECT rrp.part_num, rrp.content, rrp.char_count FROM report_raw_parts rrp JOIN industry_reports ir ON ir.id = rrp.report_id WHERE ir.code = %s ORDER BY rrp.part_num""", (code,), ) else: nums = [int(x.strip()) for x in part_nums.split(",")] placeholders = ",".join(["%s"] * len(nums)) cur.execute( f"""SELECT rrp.part_num, rrp.content, rrp.char_count FROM report_raw_parts rrp JOIN industry_reports ir ON ir.id = rrp.report_id WHERE ir.code = %s AND rrp.part_num IN ({placeholders}) ORDER BY rrp.part_num""", [code] + nums, ) rows = cur.fetchall() finally: conn.close() if not rows: print(f"未找到行业 {code} 的 Part 数据") return for r in rows: label = "meta (目录+执行摘要)" if r["part_num"] == 0 else f"Part {r['part_num']}" print(f"{'=' * 60}") print(f" {label} ({r['char_count']} 字)") print(f"{'=' * 60}") print(r["content"]) print() def cmd_rank(field: str, top: int = 5): """跨行业排名,按指定字段降序。""" field_map = { "stars": "stars", "cagr": "cagr_value", "cagr_value": "cagr_value", "gross_margin": "gross_margin_value", "gross_margin_value": "gross_margin_value", } column = field_map.get(field) if not column: print(f"不支持的排序字段: {field}") print(f"可用字段: {', '.join(field_map.keys())}") return conn = _get_conn() try: with conn.cursor() as cur: cur.execute( f"""SELECT code, name, stars, stars_label, cagr, gross_margin, headline FROM industry_reports WHERE {column} IS NOT NULL ORDER BY {column} DESC LIMIT %s""", (top,), ) rows = cur.fetchall() finally: conn.close() if not rows: print("暂无数据") return print(f"Top {top} — 按 {field} 排序:\n") for i, r in enumerate(rows, 1): print(f" {i}. {r['code']} {r['name']}") print(f" 景气度={r['stars']}★ CAGR={r['cagr']} 毛利率={r['gross_margin']}") if r["headline"]: print(f" {r['headline']}") print() def cmd_list(): """列出所有行业。""" conn = _get_conn() try: with conn.cursor() as cur: cur.execute( "SELECT code, name, stars FROM industry_reports ORDER BY code" ) rows = cur.fetchall() finally: conn.close() if not rows: print("暂无行业数据") return print(f"共 {len(rows)} 个行业:\n") for r in rows: print(f" {r['code']} {r['name']} {'★' * r['stars']}") # --------------------------------------------------------------------------- # 主入口 # --------------------------------------------------------------------------- def main(): if len(sys.argv) < 2: print(__doc__) sys.exit(0) cmd = sys.argv[1] if cmd == "lookup": if len(sys.argv) < 3: print("用法: query_industry.py lookup ") sys.exit(1) cmd_lookup(sys.argv[2]) elif cmd == "summary": if len(sys.argv) < 3: print("用法: query_industry.py summary ") sys.exit(1) cmd_summary(sys.argv[2]) elif cmd == "module": if len(sys.argv) < 4: print("用法: query_industry.py module <1-5|all>") sys.exit(1) cmd_module(sys.argv[2], sys.argv[3]) elif cmd == "part": if len(sys.argv) < 4: print("用法: query_industry.py part <0-8|all>") sys.exit(1) cmd_part(sys.argv[2], sys.argv[3]) elif cmd == "rank": if len(sys.argv) < 3: print("用法: query_industry.py rank [--top N]") sys.exit(1) field = sys.argv[2] top = 5 if "--top" in sys.argv: idx = sys.argv.index("--top") if idx + 1 < len(sys.argv): top = int(sys.argv[idx + 1]) cmd_rank(field, top) elif cmd == "list": cmd_list() else: print(f"未知命令: {cmd}") print(__doc__) sys.exit(1) if __name__ == "__main__": main()