Files
public-resources/industry-deep/scripts/query_industry.py

305 lines
9.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""行业深度研报查询脚本 —— 供 Claude Skill 使用。
连接 MySQL 数据库,查询行业报告的结构化数据和原始 Markdown 文本。
依赖: pymysql项目已安装
用法:
python query_industry.py <command> [args...]
子命令:
lookup <keyword> 按关键词或代码查找行业
summary <code> 获取行业核心指标摘要
module <code> <num|all> 获取结构化 Module JSON
part <code> <nums|all> 获取 Part 原始 Markdown
rank <field> [--top N] 跨行业排名
list 列出所有行业
"""
from __future__ import annotations
import json
import sys
import os
# Force UTF-8 on the standard streams so Chinese output is not mangled by a
# GBK-encoded Windows console. Under some hosts (e.g. pythonw, where the
# streams are None, or a replaced stream without reconfigure()) this cannot
# be done; skip those streams instead of crashing at import time.
for _stream in (sys.stdout, sys.stderr):
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8")
del _stream
# Database configuration: the script's own directory is put first on
# sys.path so a db_config module sitting next to this file is found.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from db_config import DB_CONFIG
try:
    import pymysql
except ImportError:
    # NOTE(review): this installs a package via pip at runtime without
    # asking the user; consider requiring it to be pre-installed instead.
    import subprocess
    print("pymysql 未安装,正在自动安装...", file=sys.stderr)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pymysql", "-q"])
    import pymysql
def _get_conn():
    """Open a new MySQL connection configured from ``DB_CONFIG``.

    Cursors created from the returned connection use pymysql's
    ``DictCursor``, so fetched rows are dicts keyed by column name.
    The caller is responsible for closing the connection.
    """
    dict_rows = pymysql.cursors.DictCursor
    return pymysql.connect(cursorclass=dict_rows, **DB_CONFIG)
# ---------------------------------------------------------------------------
# 子命令实现
# ---------------------------------------------------------------------------
def cmd_lookup(keyword: str):
    """Find industries whose name or code contains *keyword* and print them.

    Matching is a SQL ``LIKE`` on ``%keyword%`` against both ``name`` and
    ``code``, so ``%`` and ``_`` inside *keyword* act as wildcards. Hits are
    ordered by ``stars`` descending. Each is printed as code, name, a
    five-slot star bar and the star label, followed by the headline when it
    is non-empty. Prints a not-found message when nothing matches.
    """
    conn = _get_conn()
    try:
        with conn.cursor() as cur:
            like = f"%{keyword}%"
            cur.execute(
                "SELECT code, name, stars, stars_label, headline\n"
                "FROM industry_reports\n"
                "WHERE name LIKE %s OR code LIKE %s\n"
                "ORDER BY stars DESC",
                (like, like),
            )
            rows = cur.fetchall()
    finally:
        conn.close()
    if not rows:
        print(f"未找到匹配「{keyword}」的行业")
        return
    print(f"找到 {len(rows)} 个行业:\n")
    for r in rows:
        # One filled ★ per star (matching the "N★" notation used by
        # cmd_rank), hollow ☆ for the remaining slots out of 5.
        # NOTE(review): stars is presumably nullable (cmd_rank filters with
        # IS NOT NULL); treat NULL as 0 rather than failing on str * None.
        filled = int(r["stars"] or 0)
        stars_display = "★" * filled + "☆" * (5 - filled)
        print(f" {r['code']} {r['name']} {stars_display} ({r['stars_label']})")
        if r["headline"]:
            print(f" {r['headline']}")
        print()
def cmd_summary(code: str):
    """Print the core metrics of the industry identified by *code* as JSON.

    The row is looked up by exact ``code`` and dumped with two-space
    indentation. ``ensure_ascii=False`` emits non-ASCII text as-is, and
    ``default=str`` stringifies any column value json cannot encode
    natively. If no row matches, a not-found message is printed instead.
    """
    sql = (
        "SELECT code, name, stars, stars_label, cagr, gross_margin,\n"
        "net_margin, cr5, cr10, hhi, concentration_trend,\n"
        "headline, total_characters, report_version, generated_at\n"
        "FROM industry_reports WHERE code = %s"
    )
    connection = _get_conn()
    try:
        with connection.cursor() as cursor:
            cursor.execute(sql, (code,))
            record = cursor.fetchone()
    finally:
        connection.close()
    if record:
        print(json.dumps(record, ensure_ascii=False, indent=2, default=str))
    else:
        print(f"未找到行业: {code}")
def cmd_module(code: str, module_num: str):
    """Print the structured Module JSON stored for one industry.

    Args:
        code: Industry code, matched exactly against ``industry_reports.code``.
        module_num: ``"all"`` for every module of the report (ordered by
            module number), or a single integer module number as a string.

    Each module is printed under a ``=== Module N ===`` header as indented
    JSON. A non-integer *module_num* prints an error message and returns
    without connecting to the database. No matching rows prints a
    not-found message.
    """
    num = None
    if module_num != "all":
        # Validate before connecting so bad CLI input yields a message
        # instead of a ValueError traceback from inside the query block.
        try:
            num = int(module_num)
        except ValueError:
            print(f"无效的 Module 编号: {module_num}")
            return
    conn = _get_conn()
    try:
        with conn.cursor() as cur:
            if num is None:
                cur.execute(
                    "SELECT rm.module_num, rm.module_json\n"
                    "FROM report_modules rm\n"
                    "JOIN industry_reports ir ON ir.id = rm.report_id\n"
                    "WHERE ir.code = %s\n"
                    "ORDER BY rm.module_num",
                    (code,),
                )
            else:
                cur.execute(
                    "SELECT rm.module_num, rm.module_json\n"
                    "FROM report_modules rm\n"
                    "JOIN industry_reports ir ON ir.id = rm.report_id\n"
                    "WHERE ir.code = %s AND rm.module_num = %s",
                    (code, num),
                )
            rows = cur.fetchall()
    finally:
        conn.close()
    if not rows:
        print(f"未找到行业 {code} 的 Module 数据")
        return
    for r in rows:
        mj = r["module_json"]
        # module_json may arrive as raw text/bytes or as an already-decoded
        # object depending on column type and driver; only decode raw forms.
        raw = isinstance(mj, (str, bytes, bytearray))
        data = json.loads(mj) if raw else mj
        print(f"=== Module {r['module_num']} ===")
        print(json.dumps(data, ensure_ascii=False, indent=2))
        print()
def cmd_part(code: str, part_nums: str):
    """Print the raw Markdown text of selected Parts of one industry report.

    Args:
        code: Industry code, matched exactly against ``industry_reports.code``.
        part_nums: ``"all"`` for every Part in order, or a comma-separated
            list of integer Part numbers such as ``"0,2,5"``. Blank entries
            (e.g. from a trailing comma) are ignored.

    Part 0 is labelled as the meta section (table of contents + executive
    summary), other Parts as ``Part N``. Each is framed by ``=`` rules and
    shows its stored character count. Malformed *part_nums* prints an error
    and returns without connecting to the database.
    """
    nums = None
    if part_nums != "all":
        try:
            nums = [int(tok) for tok in part_nums.split(",") if tok.strip()]
        except ValueError:
            nums = []
        if not nums:
            # Reject before connecting: a non-integer entry would otherwise
            # raise ValueError, and an empty list would build invalid "IN ()".
            print(f"无效的 Part 编号: {part_nums}")
            return
    conn = _get_conn()
    try:
        with conn.cursor() as cur:
            if nums is None:
                cur.execute(
                    "SELECT rrp.part_num, rrp.content, rrp.char_count\n"
                    "FROM report_raw_parts rrp\n"
                    "JOIN industry_reports ir ON ir.id = rrp.report_id\n"
                    "WHERE ir.code = %s\n"
                    "ORDER BY rrp.part_num",
                    (code,),
                )
            else:
                # One bound placeholder per requested Part number.
                placeholders = ",".join(["%s"] * len(nums))
                cur.execute(
                    "SELECT rrp.part_num, rrp.content, rrp.char_count\n"
                    "FROM report_raw_parts rrp\n"
                    "JOIN industry_reports ir ON ir.id = rrp.report_id\n"
                    f"WHERE ir.code = %s AND rrp.part_num IN ({placeholders})\n"
                    "ORDER BY rrp.part_num",
                    [code] + nums,
                )
            rows = cur.fetchall()
    finally:
        conn.close()
    if not rows:
        print(f"未找到行业 {code} 的 Part 数据")
        return
    for r in rows:
        label = "meta (目录+执行摘要)" if r["part_num"] == 0 else f"Part {r['part_num']}"
        print(f"{'=' * 60}")
        print(f" {label} ({r['char_count']} 字)")
        print(f"{'=' * 60}")
        print(r["content"])
        print()
def cmd_rank(field: str, top: int = 5):
    """Print the *top* industries ranked by *field*, highest first.

    Args:
        field: Ranking key. ``stars`` ranks by the ``stars`` column;
            ``cagr``/``cagr_value`` rank by ``cagr_value``;
            ``gross_margin``/``gross_margin_value`` rank by
            ``gross_margin_value``.
        top: Maximum number of rows to print; must be at least 1.

    Rows whose ranking column is NULL are skipped. An unknown *field* or a
    non-positive *top* prints an error and returns without connecting to
    the database.
    """
    field_map = {
        "stars": "stars",
        "cagr": "cagr_value",
        "cagr_value": "cagr_value",
        "gross_margin": "gross_margin_value",
        "gross_margin_value": "gross_margin_value",
    }
    column = field_map.get(field)
    if not column:
        print(f"不支持的排序字段: {field}")
        print(f"可用字段: {', '.join(field_map.keys())}")
        return
    if top < 1:
        # LIMIT 0 would misleadingly report "no data", and a negative LIMIT
        # is a MySQL syntax error, so reject these up front.
        print(f"--top 必须为正整数: {top}")
        return
    conn = _get_conn()
    try:
        with conn.cursor() as cur:
            # Interpolating the column name is safe: it only ever comes from
            # the field_map whitelist above, never from raw user input.
            cur.execute(
                "SELECT code, name, stars, stars_label, cagr, gross_margin, headline\n"
                "FROM industry_reports\n"
                f"WHERE {column} IS NOT NULL\n"
                f"ORDER BY {column} DESC\n"
                "LIMIT %s",
                (top,),
            )
            rows = cur.fetchall()
    finally:
        conn.close()
    if not rows:
        print("暂无数据")
        return
    print(f"Top {top} — 按 {field} 排序:\n")
    for i, r in enumerate(rows, 1):
        print(f" {i}. {r['code']} {r['name']}")
        print(f" 景气度={r['stars']}★ CAGR={r['cagr']} 毛利率={r['gross_margin']}")
        if r["headline"]:
            print(f" {r['headline']}")
        print()
def cmd_list():
    """Print every industry as code, name and a star bar, ordered by code.

    Prints a "no data" message when the table is empty.
    """
    conn = _get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT code, name, stars FROM industry_reports ORDER BY code"
            )
            rows = cur.fetchall()
    finally:
        conn.close()
    if not rows:
        print("暂无行业数据")
        return
    print(f"{len(rows)} 个行业:\n")
    for r in rows:
        # One ★ per star, matching the "N★" notation used by cmd_rank.
        # NOTE(review): stars is presumably nullable (cmd_rank filters with
        # IS NOT NULL); show NULL as no stars instead of failing on str * None.
        stars = int(r["stars"] or 0)
        print(f" {r['code']} {r['name']} {'★' * stars}")
# ---------------------------------------------------------------------------
# 主入口
# ---------------------------------------------------------------------------
def main():
    """Command-line entry point: dispatch ``sys.argv`` to a ``cmd_*`` handler.

    With no sub-command, the module docstring (usage text) is printed and
    the process exits with status 0. The following print a usage message
    and exit with status 1:

    - missing sub-command arguments,
    - a non-integer ``--top`` value,
    - an unknown sub-command.

    For ``rank``, ``--top N`` may appear before or after the field name.
    """
    argv = sys.argv
    if len(argv) < 2:
        print(__doc__)
        sys.exit(0)
    cmd = argv[1]
    if cmd == "lookup":
        if len(argv) < 3:
            print("用法: query_industry.py lookup <keyword>")
            sys.exit(1)
        cmd_lookup(argv[2])
    elif cmd == "summary":
        if len(argv) < 3:
            print("用法: query_industry.py summary <code>")
            sys.exit(1)
        cmd_summary(argv[2])
    elif cmd == "module":
        if len(argv) < 4:
            print("用法: query_industry.py module <code> <1-5|all>")
            sys.exit(1)
        cmd_module(argv[2], argv[3])
    elif cmd == "part":
        if len(argv) < 4:
            print("用法: query_industry.py part <code> <0-8|all>")
            sys.exit(1)
        cmd_part(argv[2], argv[3])
    elif cmd == "rank":
        rank_usage = "用法: query_industry.py rank <field> [--top N]"
        # Pull "--top N" out wherever it appears, so the remaining first
        # argument is the field. A non-integer N is a usage error rather
        # than a ValueError traceback.
        rest = argv[2:]
        top = 5
        if "--top" in rest:
            idx = rest.index("--top")
            if idx + 1 < len(rest):
                try:
                    top = int(rest[idx + 1])
                except ValueError:
                    print(rank_usage)
                    sys.exit(1)
                del rest[idx:idx + 2]
            else:
                # A trailing "--top" with no value keeps the default.
                del rest[idx]
        if not rest:
            print(rank_usage)
            sys.exit(1)
        cmd_rank(rest[0], top)
    elif cmd == "list":
        cmd_list()
    else:
        print(f"未知命令: {cmd}")
        print(__doc__)
        sys.exit(1)


if __name__ == "__main__":
    main()