305 lines
9.2 KiB
Python
305 lines
9.2 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""行业深度研报查询脚本 —— 供 Claude Skill 使用。
|
||
|
||
连接 MySQL 数据库,查询行业报告的结构化数据和原始 Markdown 文本。
|
||
依赖: pymysql(项目已安装)。
|
||
|
||
用法:
|
||
python query_industry.py <command> [args...]
|
||
|
||
子命令:
|
||
lookup <keyword> 按关键词或代码查找行业
|
||
summary <code> 获取行业核心指标摘要
|
||
module <code> <num|all> 获取结构化 Module JSON
|
||
part <code> <nums|all> 获取 Part 原始 Markdown
|
||
rank <field> [--top N] 跨行业排名
|
||
list 列出所有行业
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import sys
|
||
import os
|
||
|
||
# Windows consoles often default to GBK; force UTF-8 on both standard streams
# so the non-ASCII text this script prints is encoded consistently.
sys.stdout.reconfigure(encoding="utf-8")
sys.stderr.reconfigure(encoding="utf-8")

# Database configuration: put this script's own directory at the front of
# sys.path before importing db_config (presumably a module that lives next to
# this file; confirm against the project layout).
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from db_config import DB_CONFIG

try:
    import pymysql
except ImportError:
    # pymysql is missing: install it with pip for the running interpreter,
    # then retry the import. check_call raises if the installation fails.
    import subprocess
    print("pymysql 未安装,正在自动安装...", file=sys.stderr)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pymysql", "-q"])
    import pymysql
|
||
|
||
|
||
def _get_conn():
    """Open a new database connection configured from ``DB_CONFIG``.

    The connection uses ``pymysql.cursors.DictCursor`` as its cursor class.
    The caller owns the returned connection and must close it.
    """
    dict_cursor = pymysql.cursors.DictCursor
    return pymysql.connect(cursorclass=dict_cursor, **DB_CONFIG)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 子命令实现
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def cmd_lookup(keyword: str):
    """Print every industry whose name or code contains *keyword*.

    The keyword is wrapped in ``%`` wildcards and matched with SQL ``LIKE``
    against both ``name`` and ``code``; results are ordered by ``stars``
    descending. Each hit is printed with a five-slot star bar, its label and,
    when present, its headline.

    Args:
        keyword: Substring to search for in the industry name or code.

    Returns:
        None. Results (or a "not found" message) are written to stdout.
    """
    pattern = f"%{keyword}%"
    query = (
        "SELECT code, name, stars, stars_label, headline\n"
        "FROM industry_reports\n"
        "WHERE name LIKE %s OR code LIKE %s\n"
        "ORDER BY stars DESC"
    )
    conn = _get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(query, (pattern, pattern))
            rows = cur.fetchall()
    finally:
        conn.close()

    if not rows:
        print(f"未找到匹配「{keyword}」的行业")
        return

    print(f"找到 {len(rows)} 个行业:\n")
    for r in rows:
        # A NULL rating arrives as None; treat it as zero stars instead of
        # crashing on `"★" * None`.
        stars = r["stars"] or 0
        stars_display = "★" * stars + "☆" * (5 - stars)
        print(f" {r['code']} {r['name']} {stars_display} ({r['stars_label']})")
        if r["headline"]:
            print(f" {r['headline']}")
        print()
|
||
|
||
|
||
def cmd_summary(code: str):
    """Print the core metrics row of one industry as pretty-printed JSON.

    Looks up the ``industry_reports`` row whose ``code`` equals *code* and
    dumps the selected columns as JSON. Values that are not natively JSON
    serializable are converted with ``str``.

    Args:
        code: Exact industry code to look up.

    Returns:
        None. The JSON (or a "not found" message) is written to stdout.
    """
    query = (
        "SELECT code, name, stars, stars_label, cagr, gross_margin,\n"
        "net_margin, cr5, cr10, hhi, concentration_trend,\n"
        "headline, total_characters, report_version, generated_at\n"
        "FROM industry_reports WHERE code = %s"
    )
    connection = _get_conn()
    try:
        with connection.cursor() as cursor:
            cursor.execute(query, (code,))
            record = cursor.fetchone()
    finally:
        connection.close()

    if record:
        print(json.dumps(record, ensure_ascii=False, indent=2, default=str))
    else:
        print(f"未找到行业: {code}")
|
||
|
||
|
||
def cmd_module(code: str, module_num: str):
    """Print the structured Module JSON stored for an industry.

    Args:
        code: Exact industry code.
        module_num: Either the literal ``"all"`` (every module, ordered by
            module number) or a string holding a single integer module number.

    Returns:
        None. Output is written to stdout. An invalid *module_num* is
        reported with a message and no database connection is opened.
    """
    base_query = (
        "SELECT rm.module_num, rm.module_json\n"
        "FROM report_modules rm\n"
        "JOIN industry_reports ir ON ir.id = rm.report_id\n"
        "WHERE ir.code = %s"
    )
    if module_num == "all":
        query = base_query + "\nORDER BY rm.module_num"
        params = (code,)
    else:
        # Validate before connecting so bad CLI input yields a readable
        # message instead of an uncaught ValueError traceback.
        try:
            wanted = int(module_num)
        except ValueError:
            print(f"无效的 Module 编号: {module_num}(应为整数或 all)")
            return
        query = base_query + " AND rm.module_num = %s"
        params = (code, wanted)

    conn = _get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(query, params)
            rows = cur.fetchall()
    finally:
        conn.close()

    if not rows:
        print(f"未找到行业 {code} 的 Module 数据")
        return

    for r in rows:
        mj = r["module_json"]
        # The column may arrive as JSON text or as an already-decoded object.
        data = json.loads(mj) if isinstance(mj, str) else mj
        print(f"=== Module {r['module_num']} ===")
        print(json.dumps(data, ensure_ascii=False, indent=2))
        print()
|
||
|
||
|
||
def cmd_part(code: str, part_nums: str):
    """Print the raw Markdown text of one or more report Parts.

    Args:
        code: Exact industry code.
        part_nums: Either the literal ``"all"`` or a comma-separated list of
            integer part numbers (e.g. ``"0,2,5"``). Empty tokens, such as
            the one produced by a trailing comma, are ignored.

    Returns:
        None. Output is written to stdout. An invalid *part_nums* is
        reported with a message and no database connection is opened.
    """
    base_query = (
        "SELECT rrp.part_num, rrp.content, rrp.char_count\n"
        "FROM report_raw_parts rrp\n"
        "JOIN industry_reports ir ON ir.id = rrp.report_id\n"
        "WHERE ir.code = %s"
    )
    if part_nums == "all":
        query = base_query + "\nORDER BY rrp.part_num"
        params = [code]
    else:
        # Validate before connecting: a non-numeric token used to raise an
        # uncaught ValueError, and an empty list would build `IN ()`.
        try:
            nums = [int(tok) for tok in part_nums.split(",") if tok.strip()]
        except ValueError:
            nums = []
        if not nums:
            print(f"无效的 Part 编号: {part_nums}(应为逗号分隔的整数或 all)")
            return
        # Only "%s" placeholders are interpolated into the SQL text; the
        # values themselves are passed as bound parameters.
        placeholders = ",".join(["%s"] * len(nums))
        query = (
            base_query
            + f" AND rrp.part_num IN ({placeholders})\n"
            + "ORDER BY rrp.part_num"
        )
        params = [code, *nums]

    conn = _get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(query, params)
            rows = cur.fetchall()
    finally:
        conn.close()

    if not rows:
        print(f"未找到行业 {code} 的 Part 数据")
        return

    for r in rows:
        # Part 0 gets a dedicated "meta" label instead of "Part 0".
        label = "meta (目录+执行摘要)" if r["part_num"] == 0 else f"Part {r['part_num']}"
        print(f"{'=' * 60}")
        print(f" {label} ({r['char_count']} 字)")
        print(f"{'=' * 60}")
        print(r["content"])
        print()
|
||
|
||
|
||
def cmd_rank(field: str, top: int = 5):
    """Print the top industries ranked by one metric, highest first.

    Args:
        field: Ranking key. Accepted values are ``stars``, ``cagr`` /
            ``cagr_value`` and ``gross_margin`` / ``gross_margin_value``.
        top: Maximum number of rows to show; must be a positive integer.

    Returns:
        None. Output is written to stdout. An unsupported *field* or a
        non-positive *top* is reported with a message and no database
        connection is opened.
    """
    # Whitelist mapping user-facing field names to real column names. Only
    # these values are ever interpolated into the SQL text below.
    field_map = {
        "stars": "stars",
        "cagr": "cagr_value",
        "cagr_value": "cagr_value",
        "gross_margin": "gross_margin_value",
        "gross_margin_value": "gross_margin_value",
    }
    column = field_map.get(field)
    if not column:
        print(f"不支持的排序字段: {field}")
        print(f"可用字段: {', '.join(field_map)}")
        return

    # A negative LIMIT is a SQL error and LIMIT 0 would be misreported as
    # "no data", so reject non-positive values up front.
    if top < 1:
        print(f"--top 必须为正整数: {top}")
        return

    query = (
        "SELECT code, name, stars, stars_label, cagr, gross_margin, headline\n"
        "FROM industry_reports\n"
        f"WHERE {column} IS NOT NULL\n"
        f"ORDER BY {column} DESC\n"
        "LIMIT %s"
    )
    conn = _get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(query, (top,))
            rows = cur.fetchall()
    finally:
        conn.close()

    if not rows:
        print("暂无数据")
        return

    print(f"Top {top} — 按 {field} 排序:\n")
    for i, r in enumerate(rows, 1):
        print(f" {i}. {r['code']} {r['name']}")
        print(f" 景气度={r['stars']}★ CAGR={r['cagr']} 毛利率={r['gross_margin']}")
        if r["headline"]:
            print(f" {r['headline']}")
        print()
|
||
|
||
|
||
def cmd_list():
    """Print every industry (code, name and star rating), ordered by code.

    Returns:
        None. Output (or a "no data" message) is written to stdout.
    """
    conn = _get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT code, name, stars FROM industry_reports ORDER BY code"
            )
            rows = cur.fetchall()
    finally:
        conn.close()

    if not rows:
        print("暂无行业数据")
        return

    print(f"共 {len(rows)} 个行业:\n")
    for r in rows:
        # A NULL rating arrives as None; show no stars instead of crashing
        # on `'★' * None`.
        print(f" {r['code']} {r['name']} {'★' * (r['stars'] or 0)}")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 主入口
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def main():
    """Parse ``sys.argv`` and dispatch to the matching ``cmd_*`` function.

    With no subcommand the module docstring is printed and the process exits
    with status 0. A missing required argument, a non-integer ``--top`` value
    or an unknown subcommand prints a usage/error message and exits with
    status 1.
    """
    argv = sys.argv
    if len(argv) < 2:
        print(__doc__)
        sys.exit(0)

    cmd = argv[1]

    if cmd == "lookup":
        if len(argv) < 3:
            print("用法: query_industry.py lookup <keyword>")
            sys.exit(1)
        cmd_lookup(argv[2])

    elif cmd == "summary":
        if len(argv) < 3:
            print("用法: query_industry.py summary <code>")
            sys.exit(1)
        cmd_summary(argv[2])

    elif cmd == "module":
        if len(argv) < 4:
            print("用法: query_industry.py module <code> <1-5|all>")
            sys.exit(1)
        cmd_module(argv[2], argv[3])

    elif cmd == "part":
        if len(argv) < 4:
            print("用法: query_industry.py part <code> <0-8|all>")
            sys.exit(1)
        cmd_part(argv[2], argv[3])

    elif cmd == "rank":
        usage = "用法: query_industry.py rank <field> [--top N]"
        if len(argv) < 3:
            print(usage)
            sys.exit(1)
        field = argv[2]
        top = 5
        if "--top" in argv:
            idx = argv.index("--top")
            # A trailing "--top" with no value keeps the default of 5.
            if idx + 1 < len(argv):
                try:
                    top = int(argv[idx + 1])
                except ValueError:
                    # Report bad input as a usage error, not a traceback.
                    print(f"--top 需要整数参数: {argv[idx + 1]}")
                    print(usage)
                    sys.exit(1)
        cmd_rank(field, top)

    elif cmd == "list":
        cmd_list()

    else:
        print(f"未知命令: {cmd}")
        print(__doc__)
        sys.exit(1)
|
||
|
||
|
||
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|